1use crate::*;
2use std::io::{self, BufRead, BufReader, IsTerminal, Read, Write};
3
4use clap::Parser;
5use env_logger::Env;
6use std::path::PathBuf;
7
8#[derive(Debug, Parser)]
9#[command(version = env!("FULL_VERSION"), about = env!("CARGO_PKG_DESCRIPTION"))]
10struct Args {
11 #[clap(short = 'F', long)]
13 fact_directory: Option<PathBuf>,
14 #[clap(long)]
16 naive: bool,
17 #[clap(long, default_value_t = RunMode::Normal)]
19 mode: RunMode,
20 inputs: Vec<PathBuf>,
22 #[clap(long)]
24 to_json: bool,
25 #[clap(long)]
27 to_dot: bool,
28 #[clap(long)]
30 to_svg: bool,
31 #[clap(long)]
33 serialize_split_primitive_outputs: bool,
34 #[clap(long, default_value = "40")]
36 max_functions: usize,
37 #[clap(long, default_value = "40")]
39 max_calls_per_function: usize,
40 #[clap(long, default_value = "0")]
42 serialize_n_inline_leaves: usize,
43 #[clap(short = 'j', long, default_value = "1")]
44 threads: usize,
47 #[arg(value_enum)]
48 #[clap(long, default_value_t = ReportLevel::TimeOnly)]
49 report_level: ReportLevel,
50 #[clap(long)]
51 save_report: Option<PathBuf>,
52 #[clap(long = "strict-mode")]
54 strict_mode: bool,
55}
56
57#[allow(clippy::disallowed_macros)]
62pub fn cli(mut egraph: EGraph) {
63 env_logger::Builder::from_env(Env::default().default_filter_or("warn"))
64 .format_timestamp(None)
65 .format_target(false)
66 .parse_default_env()
67 .init();
68
69 let args = Args::parse();
70 rayon::ThreadPoolBuilder::new()
71 .num_threads(args.threads)
72 .build_global()
73 .unwrap();
74 log::debug!(
75 "Initialized thread pool with {} threads",
76 rayon::current_num_threads()
77 );
78 egraph.fact_directory.clone_from(&args.fact_directory);
79 egraph.seminaive = !args.naive;
80 egraph.set_report_level(args.report_level);
81 if args.strict_mode {
82 egraph.set_strict_mode(true);
83 }
84 if args.inputs.is_empty() {
85 match egraph.repl(args.mode) {
86 Ok(()) => std::process::exit(0),
87 Err(err) => {
88 log::error!("{err}");
89 std::process::exit(1)
90 }
91 }
92 } else {
93 for input in &args.inputs {
94 let program = std::fs::read_to_string(input).unwrap_or_else(|_| {
95 let arg = input.to_string_lossy();
96 panic!("Failed to read file {arg}")
97 });
98
99 match run_commands(
100 &mut egraph,
101 Some(input.to_str().unwrap().into()),
102 &program,
103 io::stdout(),
104 args.mode,
105 ) {
106 Ok(None) => {}
107 _ => std::process::exit(1),
108 }
109
110 if args.to_json || args.to_dot || args.to_svg {
111 let serialized_output = egraph.serialize(SerializeConfig {
112 max_functions: Some(args.max_functions),
113 max_calls_per_function: Some(args.max_calls_per_function),
114 ..SerializeConfig::default()
115 });
116 if !serialized_output.is_complete() {
117 log::warn!("{}", serialized_output.omitted_description());
118 }
119 let mut serialized = serialized_output.egraph;
120 if args.serialize_split_primitive_outputs {
121 serialized.split_classes(|id, _| egraph.from_node_id(id).is_primitive())
122 }
123 for _ in 0..args.serialize_n_inline_leaves {
124 serialized.inline_leaves();
125 }
126
127 let serialize_filename = if args.serialize_split_primitive_outputs {
129 input.with_file_name(format!(
130 "{}-split",
131 input.file_stem().unwrap().to_str().unwrap()
132 ))
133 } else {
134 input.clone()
135 };
136 if args.to_dot {
137 let dot_path = serialize_filename.with_extension("dot");
138 serialized
139 .to_dot_file(dot_path.clone())
140 .unwrap_or_else(|_| panic!("Failed to write dot file to {dot_path:?}"));
141 }
142 if args.to_svg {
143 let svg_path = serialize_filename.with_extension("svg");
144 serialized.to_svg_file(svg_path.clone()).unwrap_or_else( |_|
145 panic!("Failed to write svg file to {svg_path:?}. Make sure you have the `dot` executable installed")
146 );
147 }
148 if args.to_json {
149 let json_path = serialize_filename.with_extension("json");
150 serialized
151 .to_json_file(json_path.clone())
152 .unwrap_or_else(|_| panic!("Failed to write json file to {json_path:?}"));
153 }
154 }
155 }
156 }
157
158 if let Some(report_path) = args.save_report {
159 let report = egraph.get_overall_run_report();
160 serde_json::to_writer(
161 std::fs::File::create(&report_path)
162 .unwrap_or_else(|_| panic!("Failed to create report file at {report_path:?}")),
163 &report,
164 )
165 .expect("Failed to serialize report");
166 log::info!("Saved report to {report_path:?}");
167 }
168
169 std::mem::forget(egraph)
171}
172
173impl EGraph {
174 pub fn repl(&mut self, mode: RunMode) -> io::Result<()> {
176 self.repl_with(io::stdin(), io::stdout(), mode, io::stdin().is_terminal())
177 }
178
179 pub fn repl_with<R, W>(
181 &mut self,
182 input: R,
183 mut output: W,
184 mode: RunMode,
185 is_terminal: bool,
186 ) -> io::Result<()>
187 where
188 R: Read,
189 W: Write,
190 {
191 if is_terminal {
193 output.write_all(welcome_prompt().as_bytes())?;
194 output.write_all(b"\n> ")?;
195 output.flush()?;
196 }
197 let mut cmd_buffer = String::new();
198
199 for line in BufReader::new(input).lines() {
200 let line_str = line?;
201 cmd_buffer.push_str(&line_str);
202 cmd_buffer.push('\n');
203 if should_eval(&cmd_buffer) {
205 run_commands(self, None, &cmd_buffer, &mut output, mode)?;
206 cmd_buffer = String::new();
207 if is_terminal {
208 output.write_all(b"> ")?;
209 output.flush()?;
210 }
211 }
212 }
213
214 if !cmd_buffer.is_empty() {
215 run_commands(self, None, &cmd_buffer, &mut output, mode)?;
216 }
217
218 Ok(())
219 }
220}
221
222fn welcome_prompt() -> String {
223 format!("Welcome to Egglog REPL! (build: {})", env!("FULL_VERSION"))
224}
225
226fn should_eval(curr_cmd: &str) -> bool {
227 all_sexps(SexpParser::new(None, curr_cmd)).is_ok()
228}
229
230fn run_commands<W>(
231 egraph: &mut EGraph,
232 filename: Option<String>,
233 command: &str,
234 mut output: W,
235 mode: RunMode,
236) -> io::Result<Option<Error>>
237where
238 W: Write,
239{
240 if mode == RunMode::ShowDesugaredEgglog {
241 return Ok(match egraph.resugar_program(filename, command) {
242 Ok(desugared) => {
243 for line in desugared {
244 writeln!(output, "{line}")?;
245 }
246 None
247 }
248 Err(err) => {
249 log::error!("{err}");
250 Some(err)
251 }
252 });
253 };
254
255 Ok(match egraph.parse_and_run_program(filename, command) {
256 Ok(msgs) => {
257 if mode != RunMode::NoMessages {
258 for msg in msgs {
259 write!(output, "{msg}")?;
260 }
261 }
262 if mode == RunMode::Interactive {
263 writeln!(output, "(done)")?;
264 }
265 None
266 }
267 Err(err) => {
268 log::error!("{err}");
269 if mode == RunMode::Interactive {
270 writeln!(output, "(error)")?;
271 }
272 Some(err)
273 }
274 })
275}
276
/// Selects how `run_commands` treats a program and what it writes to its
/// output.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum RunMode {
    /// Run the program and write every message it produces.
    Normal,
    /// Do not run the program; write each line returned by `resugar_program`
    /// instead (CLI name `resugar`).
    ShowDesugaredEgglog,
    /// Like `Normal`, but follow each batch with a `(done)` or `(error)` line.
    Interactive,
    /// Run the program but write none of its messages.
    NoMessages,
}
284
285impl Display for RunMode {
286 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
287 match self {
288 RunMode::Normal => write!(f, "normal"),
289 RunMode::ShowDesugaredEgglog => write!(f, "resugar"),
290 RunMode::Interactive => write!(f, "interactive"),
291 RunMode::NoMessages => write!(f, "no-messages"),
292 }
293 }
294}
295
296impl FromStr for RunMode {
297 type Err = String;
298
299 fn from_str(s: &str) -> Result<Self, Self::Err> {
300 match s {
301 "normal" => Ok(RunMode::Normal),
302 "resugar" => Ok(RunMode::ShowDesugaredEgglog),
303 "interactive" => Ok(RunMode::Interactive),
304 "no-messages" => Ok(RunMode::NoMessages),
305 _ => Err(format!("Unknown run mode: {s}")),
306 }
307 }
308}
309
#[cfg(test)]
mod tests {
    use super::*;

    /// Feeds each case to `should_eval` one line at a time and checks that
    /// only the final prefix is considered ready to evaluate.
    #[test]
    fn test_should_eval() {
        #[rustfmt::skip]
        let test_cases: [&[&str]; 3] = [
            &[
                "(extract",
                "\"1",
                ")",
                "(",
                ")))",
                "\"",
                ";; )",
                ")"
            ],
            &[
                "(extract 1) (extract",
                "2) (",
                "extract 3) (extract 4) ;;;; ("
            ],
            &[
                "(extract \"\\\")\")"
            ],
        ];
        for lines in test_cases {
            let last = lines.len() - 1;
            let mut cmd_buffer = String::new();
            for (i, line) in lines.iter().enumerate() {
                cmd_buffer.push_str(line);
                cmd_buffer.push('\n');
                assert_eq!(should_eval(&cmd_buffer), i == last);
            }
        }
    }

    /// Runs `input` through `repl_with` on `egraph` and returns everything
    /// the REPL wrote.
    fn repl_output(egraph: &mut EGraph, input: &str, mode: RunMode, is_terminal: bool) -> String {
        let mut output: Vec<u8> = Vec::new();
        egraph
            .repl_with(input.as_bytes(), &mut output, mode, is_terminal)
            .unwrap();
        String::from_utf8(output).unwrap()
    }

    /// Exercises every `RunMode` plus terminal prompting. A single e-graph is
    /// shared across all runs, as a real REPL session would do.
    #[test]
    fn test_repl() {
        let mut egraph = EGraph::default();

        assert_eq!(
            repl_output(&mut egraph, "(extract 1)", RunMode::Normal, false),
            "1\n"
        );
        assert_eq!(repl_output(&mut egraph, "\n\n\n", RunMode::Normal, false), "");
        assert_eq!(
            repl_output(&mut egraph, "(extract 1)", RunMode::Interactive, false),
            "1\n(done)\n"
        );
        assert_eq!(
            repl_output(&mut egraph, "xyz", RunMode::Interactive, false),
            "(error)\n"
        );
        assert_eq!(
            repl_output(
                &mut egraph,
                "(extract 1)",
                RunMode::ShowDesugaredEgglog,
                false
            ),
            "(extract 1 0)\n"
        );
        assert_eq!(
            repl_output(&mut egraph, "(extract 1)", RunMode::NoMessages, false),
            ""
        );
        assert_eq!(
            repl_output(&mut egraph, "(extract 1)", RunMode::Normal, true),
            format!("{}\n> 1\n> ", welcome_prompt())
        );
    }
}