egglog/
cli.rs

1use crate::*;
2use std::io::{self, BufRead, BufReader, IsTerminal, Read, Write};
3
4use clap::Parser;
5use env_logger::Env;
6use std::path::PathBuf;
7
8#[derive(Debug, Parser)]
9#[command(version = env!("FULL_VERSION"), about = env!("CARGO_PKG_DESCRIPTION"))]
10struct Args {
11    /// Directory for files when using `input` and `output` commands
12    #[clap(short = 'F', long)]
13    fact_directory: Option<PathBuf>,
14    /// Turns off the seminaive optimization
15    #[clap(long)]
16    naive: bool,
17    /// Prints extra information, which can be useful for debugging
18    #[clap(long, default_value_t = RunMode::Normal)]
19    mode: RunMode,
20    /// The file names for the egglog files to run
21    inputs: Vec<PathBuf>,
22    /// Serializes the egraph for each egglog file as JSON
23    #[clap(long)]
24    to_json: bool,
25    /// Serializes the egraph for each egglog file as a dot file
26    #[clap(long)]
27    to_dot: bool,
28    /// Serializes the egraph for each egglog file as an SVG
29    #[clap(long)]
30    to_svg: bool,
31    /// Splits the serialized egraph into primitives and non-primitives
32    #[clap(long)]
33    serialize_split_primitive_outputs: bool,
34    /// Maximum number of function nodes to render in dot/svg output
35    #[clap(long, default_value = "40")]
36    max_functions: usize,
37    /// Maximum number of calls per function to render in dot/svg output
38    #[clap(long, default_value = "40")]
39    max_calls_per_function: usize,
40    /// Number of times to inline leaves
41    #[clap(long, default_value = "0")]
42    serialize_n_inline_leaves: usize,
43    #[clap(short = 'j', long, default_value = "1")]
44    /// Number of threads to use for parallel execution. Passing `0` will use the maximum
45    /// inferred parallelism available on the current system.
46    threads: usize,
47    #[arg(value_enum)]
48    #[clap(long, default_value_t = ReportLevel::TimeOnly)]
49    report_level: ReportLevel,
50    #[clap(long)]
51    save_report: Option<PathBuf>,
52    /// Treat missing `$` prefixes on globals as errors instead of warnings
53    #[clap(long = "strict-mode")]
54    strict_mode: bool,
55}
56
57/// Start a command-line interface for the E-graph.
58///
59/// This is what vanilla egglog uses, and custom egglog builds (i.e., "egglog batteries included")
60/// should also call this function.
61#[allow(clippy::disallowed_macros)]
62pub fn cli(mut egraph: EGraph) {
63    env_logger::Builder::from_env(Env::default().default_filter_or("warn"))
64        .format_timestamp(None)
65        .format_target(false)
66        .parse_default_env()
67        .init();
68
69    let args = Args::parse();
70    rayon::ThreadPoolBuilder::new()
71        .num_threads(args.threads)
72        .build_global()
73        .unwrap();
74    log::debug!(
75        "Initialized thread pool with {} threads",
76        rayon::current_num_threads()
77    );
78    egraph.fact_directory.clone_from(&args.fact_directory);
79    egraph.seminaive = !args.naive;
80    egraph.set_report_level(args.report_level);
81    if args.strict_mode {
82        egraph.set_strict_mode(true);
83    }
84    if args.inputs.is_empty() {
85        match egraph.repl(args.mode) {
86            Ok(()) => std::process::exit(0),
87            Err(err) => {
88                log::error!("{err}");
89                std::process::exit(1)
90            }
91        }
92    } else {
93        for input in &args.inputs {
94            let program = std::fs::read_to_string(input).unwrap_or_else(|_| {
95                let arg = input.to_string_lossy();
96                panic!("Failed to read file {arg}")
97            });
98
99            match run_commands(
100                &mut egraph,
101                Some(input.to_str().unwrap().into()),
102                &program,
103                io::stdout(),
104                args.mode,
105            ) {
106                Ok(None) => {}
107                _ => std::process::exit(1),
108            }
109
110            if args.to_json || args.to_dot || args.to_svg {
111                let serialized_output = egraph.serialize(SerializeConfig {
112                    max_functions: Some(args.max_functions),
113                    max_calls_per_function: Some(args.max_calls_per_function),
114                    ..SerializeConfig::default()
115                });
116                if !serialized_output.is_complete() {
117                    log::warn!("{}", serialized_output.omitted_description());
118                }
119                let mut serialized = serialized_output.egraph;
120                if args.serialize_split_primitive_outputs {
121                    serialized.split_classes(|id, _| egraph.from_node_id(id).is_primitive())
122                }
123                for _ in 0..args.serialize_n_inline_leaves {
124                    serialized.inline_leaves();
125                }
126
127                // if we are splitting primitive outputs, add `-split` to the end of the file name
128                let serialize_filename = if args.serialize_split_primitive_outputs {
129                    input.with_file_name(format!(
130                        "{}-split",
131                        input.file_stem().unwrap().to_str().unwrap()
132                    ))
133                } else {
134                    input.clone()
135                };
136                if args.to_dot {
137                    let dot_path = serialize_filename.with_extension("dot");
138                    serialized
139                        .to_dot_file(dot_path.clone())
140                        .unwrap_or_else(|_| panic!("Failed to write dot file to {dot_path:?}"));
141                }
142                if args.to_svg {
143                    let svg_path = serialize_filename.with_extension("svg");
144                    serialized.to_svg_file(svg_path.clone()).unwrap_or_else( |_|
145                        panic!("Failed to write svg file to {svg_path:?}. Make sure you have the `dot` executable installed")
146                    );
147                }
148                if args.to_json {
149                    let json_path = serialize_filename.with_extension("json");
150                    serialized
151                        .to_json_file(json_path.clone())
152                        .unwrap_or_else(|_| panic!("Failed to write json file to {json_path:?}"));
153                }
154            }
155        }
156    }
157
158    if let Some(report_path) = args.save_report {
159        let report = egraph.get_overall_run_report();
160        serde_json::to_writer(
161            std::fs::File::create(&report_path)
162                .unwrap_or_else(|_| panic!("Failed to create report file at {report_path:?}")),
163            &report,
164        )
165        .expect("Failed to serialize report");
166        log::info!("Saved report to {report_path:?}");
167    }
168
169    // no need to drop the egraph if we are going to exit
170    std::mem::forget(egraph)
171}
172
173impl EGraph {
174    /// Start a Read-Eval-Print Loop with standard I/O.
175    pub fn repl(&mut self, mode: RunMode) -> io::Result<()> {
176        self.repl_with(io::stdin(), io::stdout(), mode, io::stdin().is_terminal())
177    }
178
179    /// Start a Read-Eval-Print Loop with the given input and output channel.
180    pub fn repl_with<R, W>(
181        &mut self,
182        input: R,
183        mut output: W,
184        mode: RunMode,
185        is_terminal: bool,
186    ) -> io::Result<()>
187    where
188        R: Read,
189        W: Write,
190    {
191        // https://doc.rust-lang.org/beta/std/io/trait.IsTerminal.html#examples
192        if is_terminal {
193            output.write_all(welcome_prompt().as_bytes())?;
194            output.write_all(b"\n> ")?;
195            output.flush()?;
196        }
197        let mut cmd_buffer = String::new();
198
199        for line in BufReader::new(input).lines() {
200            let line_str = line?;
201            cmd_buffer.push_str(&line_str);
202            cmd_buffer.push('\n');
203            // handles multi-line commands
204            if should_eval(&cmd_buffer) {
205                run_commands(self, None, &cmd_buffer, &mut output, mode)?;
206                cmd_buffer = String::new();
207                if is_terminal {
208                    output.write_all(b"> ")?;
209                    output.flush()?;
210                }
211            }
212        }
213
214        if !cmd_buffer.is_empty() {
215            run_commands(self, None, &cmd_buffer, &mut output, mode)?;
216        }
217
218        Ok(())
219    }
220}
221
222fn welcome_prompt() -> String {
223    format!("Welcome to Egglog REPL! (build: {})", env!("FULL_VERSION"))
224}
225
226fn should_eval(curr_cmd: &str) -> bool {
227    all_sexps(SexpParser::new(None, curr_cmd)).is_ok()
228}
229
230fn run_commands<W>(
231    egraph: &mut EGraph,
232    filename: Option<String>,
233    command: &str,
234    mut output: W,
235    mode: RunMode,
236) -> io::Result<Option<Error>>
237where
238    W: Write,
239{
240    if mode == RunMode::ShowDesugaredEgglog {
241        return Ok(match egraph.resugar_program(filename, command) {
242            Ok(desugared) => {
243                for line in desugared {
244                    writeln!(output, "{line}")?;
245                }
246                None
247            }
248            Err(err) => {
249                log::error!("{err}");
250                Some(err)
251            }
252        });
253    };
254
255    Ok(match egraph.parse_and_run_program(filename, command) {
256        Ok(msgs) => {
257            if mode != RunMode::NoMessages {
258                for msg in msgs {
259                    write!(output, "{msg}")?;
260                }
261            }
262            if mode == RunMode::Interactive {
263                writeln!(output, "(done)")?;
264            }
265            None
266        }
267        Err(err) => {
268            log::error!("{err}");
269            if mode == RunMode::Interactive {
270                writeln!(output, "(error)")?;
271            }
272            Some(err)
273        }
274    })
275}
276
/// Controls how egglog commands are executed and what is written to the output.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum RunMode {
    /// Run commands and print their messages.
    Normal,
    /// Print the desugared form of the program instead of running it (`resugar` on the
    /// command line).
    ShowDesugaredEgglog,
    /// Run commands, printing their messages followed by `(done)` or `(error)`.
    Interactive,
    /// Run commands without printing their messages.
    NoMessages,
}
284
285impl Display for RunMode {
286    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
287        match self {
288            RunMode::Normal => write!(f, "normal"),
289            RunMode::ShowDesugaredEgglog => write!(f, "resugar"),
290            RunMode::Interactive => write!(f, "interactive"),
291            RunMode::NoMessages => write!(f, "no-messages"),
292        }
293    }
294}
295
296impl FromStr for RunMode {
297    type Err = String;
298
299    fn from_str(s: &str) -> Result<Self, Self::Err> {
300        match s {
301            "normal" => Ok(RunMode::Normal),
302            "resugar" => Ok(RunMode::ShowDesugaredEgglog),
303            "interactive" => Ok(RunMode::Interactive),
304            "no-messages" => Ok(RunMode::NoMessages),
305            _ => Err(format!("Unknown run mode: {s}")),
306        }
307    }
308}
309
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_should_eval() {
        #[rustfmt::skip]
        let test_cases = vec![
            vec![
                "(extract",
                "\"1",
                ")",
                "(",
                ")))",
                "\"",
                ";; )",
                ")"
            ],
            vec![
                "(extract 1) (extract",
                "2) (",
                "extract 3) (extract 4) ;;;; ("
            ],
            vec![
                "(extract \"\\\")\")"
            ]];
        for test in test_cases {
            let mut cmd_buffer = String::new();
            for (i, line) in test.iter().enumerate() {
                cmd_buffer.push_str(line);
                cmd_buffer.push('\n');
                assert_eq!(should_eval(&cmd_buffer), i == test.len() - 1);
            }
        }
    }

    #[test]
    fn test_repl() {
        let mut egraph = EGraph::default();

        let input = "(extract 1)";
        let mut output = Vec::new();
        egraph
            .repl_with(input.as_bytes(), &mut output, RunMode::Normal, false)
            .unwrap();
        assert_eq!(String::from_utf8(output).unwrap(), "1\n");

        let input = "\n\n\n";
        let mut output = Vec::new();
        egraph
            .repl_with(input.as_bytes(), &mut output, RunMode::Normal, false)
            .unwrap();
        assert_eq!(String::from_utf8(output).unwrap(), "");

        let input = "(extract 1)";
        let mut output = Vec::new();
        egraph
            .repl_with(input.as_bytes(), &mut output, RunMode::Interactive, false)
            .unwrap();
        assert_eq!(String::from_utf8(output).unwrap(), "1\n(done)\n");

        let input = "xyz";
        let mut output: Vec<u8> = Vec::new();
        egraph
            .repl_with(input.as_bytes(), &mut output, RunMode::Interactive, false)
            .unwrap();
        assert_eq!(String::from_utf8(output).unwrap(), "(error)\n");

        let input = "(extract 1)";
        let mut output = Vec::new();
        egraph
            .repl_with(
                input.as_bytes(),
                &mut output,
                RunMode::ShowDesugaredEgglog,
                false,
            )
            .unwrap();
        assert_eq!(String::from_utf8(output).unwrap(), "(extract 1 0)\n");

        let input = "(extract 1)";
        let mut output = Vec::new();
        egraph
            .repl_with(input.as_bytes(), &mut output, RunMode::NoMessages, false)
            .unwrap();
        assert_eq!(String::from_utf8(output).unwrap(), "");

        let input = "(extract 1)";
        let mut output = Vec::new();
        egraph
            .repl_with(input.as_bytes(), &mut output, RunMode::Normal, true)
            .unwrap();
        assert_eq!(
            String::from_utf8(output).unwrap(),
            format!("{}\n> 1\n> ", welcome_prompt())
        );
    }

    #[test]
    fn test_repl_multiline_and_unterminated() {
        let mut egraph = EGraph::default();

        // A command split across lines is only run once the buffered text is complete.
        let input = "(extract\n1)";
        let mut output = Vec::new();
        egraph
            .repl_with(input.as_bytes(), &mut output, RunMode::Normal, false)
            .unwrap();
        assert_eq!(String::from_utf8(output).unwrap(), "1\n");

        // An unterminated command at end of input is still run, so its error is reported.
        let input = "(extract 1";
        let mut output = Vec::new();
        egraph
            .repl_with(input.as_bytes(), &mut output, RunMode::Interactive, false)
            .unwrap();
        assert_eq!(String::from_utf8(output).unwrap(), "(error)\n");
    }

    #[test]
    fn test_run_mode_round_trip() {
        // clap renders `default_value_t` with `Display` and parses it back with `FromStr`,
        // so the two must agree for every variant.
        for mode in [
            RunMode::Normal,
            RunMode::ShowDesugaredEgglog,
            RunMode::Interactive,
            RunMode::NoMessages,
        ] {
            assert_eq!(mode.to_string().parse::<RunMode>(), Ok(mode));
        }
        assert_eq!(
            "bogus".parse::<RunMode>(),
            Err("Unknown run mode: bogus".to_string())
        );
    }
}