src/HOL/Tools/Predicate_Compile/code_prolog.ML
changeset 38075 3d5e2b7d1374
parent 38073 64062d56ad3c
child 38076 2a9c14d9d2ef
--- a/src/HOL/Tools/Predicate_Compile/code_prolog.ML	Thu Jul 29 17:27:54 2010 +0200
+++ b/src/HOL/Tools/Predicate_Compile/code_prolog.ML	Thu Jul 29 17:27:55 2010 +0200
@@ -13,7 +13,7 @@
 
   val generate : Proof.context -> string list -> logic_program
   val write_program : logic_program -> string
-  val run : logic_program -> string -> string list -> unit
+  val run : logic_program -> string -> string list -> term list
 
 end;
 
@@ -29,6 +29,10 @@
 
 datatype term = Var of string * typ | Cons of string | AppF of string * term list;
 
+fun string_of_prol_term (Var (s, T)) = "Var " ^ s
+  | string_of_prol_term (Cons s) = "Cons " ^ s
+  | string_of_prol_term (AppF (f, args)) = f ^ "(" ^ commas (map string_of_prol_term args) ^ ")" 
+
 datatype prem = Conj of prem list | NotRel of string * term list | Rel of string * term list | Eq of term * term | NotEq of term * term;
 
 fun dest_Rel (Rel (c, ts)) = (c, ts)
@@ -145,6 +149,36 @@
   ":- style_check(-singleton).\n\n" ^
   "main :- catch(eval, E, (print_message(error, E), fail)), halt.\n" ^
   "main :- halt(1).\n"
+
+(* parsing prolog solution *)
+
+val scan_atom =
+  Scan.repeat (Scan.one (fn s => Symbol.is_ascii_lower s orelse Symbol.is_ascii_quasi s))
+
+val scan_var =
+  Scan.repeat (Scan.one
+    (fn s => Symbol.is_ascii_upper s orelse Symbol.is_ascii_digit s orelse Symbol.is_ascii_quasi s))
+
+fun dest_Char (Symbol.Char s) = s
+
+val string_of = concat o map (dest_Char o Symbol.decode)
+
+val scan_term =
+  scan_atom >> (Cons o string_of) || scan_var >> (Var o rpair dummyT o string_of)
+
+val parse_term = fst o Scan.finite Symbol.stopper
+            (Scan.error (!! (fn _ => raise Fail "parsing prolog output failed"))
+                            scan_term)
+  o explode
+  
+fun parse_solution sol =
+  let
+    fun dest_eq s = (tracing s; tracing(commas (space_explode "=" s)); case space_explode "=" s of
+        (l :: r :: []) => parse_term (unprefix " " r)
+      | _ => raise Fail "unexpected equation in prolog output")
+  in
+    map dest_eq (fst (split_last (space_explode "\n" sol)))
+  end 
   
 (* calling external interpreter and getting results *)
 
@@ -155,8 +189,61 @@
     val prolog_file = File.tmp_path (Path.basic "prolog_file")
     val _ = File.write prolog_file prog
     val (solution, _) = bash_output ("/usr/bin/swipl -f " ^ File.shell_path prolog_file)
+    val ts = parse_solution solution
+    val _ = tracing (commas (map string_of_prol_term ts)) 
   in
-    tracing ("Prolog returns result:\n" ^ solution)
+    ts
   end
 
+(* values command *)
+
+fun mk_term (Var (s, T)) = Free (s, T)
+  | mk_term (Cons s) = Const (s, dummyT)
+  | mk_term (AppF (f, args)) = list_comb (Const (f, dummyT), map mk_term args)
+  
+fun values ctxt soln t_compr =
+  let
+    val split = case t_compr of (Const (@{const_name Collect}, _) $ t) => t
+      | _ => error ("Not a set comprehension: " ^ Syntax.string_of_term ctxt t_compr);
+    val (body, Ts, fp) = HOLogic.strip_psplits split;
+    val output_names = Name.variant_list (Term.add_free_names body [])
+      (map (fn i => "x" ^ string_of_int i) (1 upto length Ts))
+    val output_frees = map2 (curry Free) output_names (rev Ts)
+    val body = subst_bounds (output_frees, body)
+    val (pred as Const (name, T), all_args) =
+      case strip_comb body of
+        (Const (name, T), all_args) => (Const (name, T), all_args)
+      | (head, _) => error ("Not a constant: " ^ Syntax.string_of_term ctxt head)
+    val vnames =
+      case try (map (fst o dest_Free)) all_args of
+        SOME vs => vs
+      | NONE => error ("Not only free variables in " ^ commas (map (Syntax.string_of_term ctxt) all_args))
+    val ts = run (generate ctxt [name]) (translate_const name) (map first_upper vnames)
+  in
+    HOLogic.mk_tuple (map mk_term ts)
+  end
+
+fun values_cmd print_modes soln raw_t state =
+  let
+    val ctxt = Toplevel.context_of state
+    val t = Syntax.read_term ctxt raw_t
+    val t' = values ctxt soln t
+    val ty' = Term.type_of t'
+    val ctxt' = Variable.auto_fixes t' ctxt
+    val p = Print_Mode.with_modes print_modes (fn () =>
+      Pretty.block [Pretty.quote (Syntax.pretty_term ctxt' t'), Pretty.fbrk,
+        Pretty.str "::", Pretty.brk 1, Pretty.quote (Syntax.pretty_typ ctxt' ty')]) ();
+  in Pretty.writeln p end;
+
+
+(* renewing the values command for Prolog queries *)
+
+val opt_print_modes =
+  Scan.optional (Parse.$$$ "(" |-- Parse.!!! (Scan.repeat1 Parse.xname --| Parse.$$$ ")")) [];
+
+val _ = Outer_Syntax.improper_command "values" "enumerate and print comprehensions" Keyword.diag
+  (opt_print_modes -- Scan.optional Parse.nat ~1 -- Parse.term
+   >> (fn ((print_modes, soln), t) => Toplevel.keep
+        (values_cmd print_modes soln t)));
+
 end;