src/HOL/Tools/SMT/lethe_proof.ML
changeset 78177 ea7a3cc64df5
parent 76183 8089593a364a
--- a/src/HOL/Tools/SMT/lethe_proof.ML	Sat Jun 17 17:41:02 2023 +0200
+++ b/src/HOL/Tools/SMT/lethe_proof.ML	Mon Jun 19 22:28:09 2023 +0200
@@ -32,6 +32,17 @@
      term -> (string * typ) list -> term Symtab.table -> (string * term) list ->
      (string * typ) list * term list * term list * lethe_replay_node list -> lethe_replay_node
 
+datatype raw_lethe_node = Raw_Lethe_Node of { (*proof step as returned by parse_raw_proof_steps, before postprocessing*)
+  id: string,
+  rule: string,
+  args: SMTLIB.tree,
+  prems: string list,
+  concl: SMTLIB.tree,
+  declarations: (string * SMTLIB.tree) list,
+  subproof: raw_lethe_node list}
+ val parse_raw_proof_steps: string option -> SMTLIB.tree list -> SMTLIB_Proof.name_bindings -> int ->
+ raw_lethe_node list * SMTLIB.tree list * SMTLIB_Proof.name_bindings (*args: limit, input trees, name bindings, count of assume/assert commands seen so far; result: parsed nodes, unconsumed trees, updated bindings*)
+
   (*proof parser*)
   val parse: typ Symtab.table -> term Symtab.table -> string list ->
     Proof.context -> lethe_step list * Proof.context
@@ -60,6 +71,7 @@
   val theory_resolution2_rule: string
   val equiv_pos2_rule: string
   val and_pos_rule: string
+  val hole: string
   val th_resolution_rule: string
 
   val is_skolemization: string -> bool
@@ -145,6 +157,7 @@
 val equiv_pos2_rule = "equiv_pos2"
 val th_resolution_rule = "th_resolution"
 val and_pos_rule = "and_pos"
+val hole = "hole"
 
 val is_lethe_def = String.isSuffix lethe_def
 val skolemization_steps = ["sko_forall", "sko_ex"]
@@ -202,6 +215,7 @@
                 (case node_of (SMTLIB.Sym y) cx of
                   ((_, []), _) => [([x], typ)]
                 | _ => [([x, y], typ)])
+             | (SMTLIB.S (SMTLIB.Sym "=" :: SMTLIB.S [SMTLIB.Sym x, typ] :: SMTLIB.Sym y :: []), _) => [([x, y], SOME typ)]
              | (SMTLIB.S (SMTLIB.Sym "=" :: SMTLIB.Sym x :: _), typ) => [([x], typ)]
              |  t => raise (Fail ("match error " ^ @{make_string} t)))
     |> flat
@@ -220,7 +234,7 @@
   fold (fn (SMTLIB.S [SMTLIB.Sym "=", _, SMTLIB.Sym y]) => curry (op ::) y) bds []
 
 (*FIXME there is probably a way to use the information given by onepoint*)
-fun bound_vars_by_rule _ "bind" (bds) = extract_symbols bds
+fun bound_vars_by_rule _ "bind" bds = extract_symbols bds
   | bound_vars_by_rule cx "onepoint" bds = extract_qnt_symbols cx bds
   | bound_vars_by_rule _ "sko_forall" bds = extract_symbols_map bds
   | bound_vars_by_rule _ "sko_ex" bds = extract_symbols_map bds
@@ -243,17 +257,19 @@
 
 end
 
-datatype step_kind = ASSUME | ANCHOR | NO_STEP | NORMAL_STEP | SKOLEM
+datatype step_kind = ASSUME | ASSERT | ANCHOR | NO_STEP | NORMAL_STEP | SKOLEM
 
-fun parse_raw_proof_steps (limit : string option) (ls : SMTLIB.tree list) (cx : name_bindings) :
+fun parse_raw_proof_steps (limit : string option) (ls : SMTLIB.tree list) (cx : name_bindings) (assms_nbr : int):
      (raw_lethe_node list * SMTLIB.tree list * name_bindings) =
   let
     fun rotate_pair (a, (b, c)) = ((a, b), c)
     fun step_kind [] = (NO_STEP, SMTLIB.S [], [])
       | step_kind ((p as SMTLIB.S (SMTLIB.Sym "anchor" :: _)) :: l) = (ANCHOR, p, l)
       | step_kind ((p as SMTLIB.S (SMTLIB.Sym "assume" :: _)) :: l) = (ASSUME, p, l)
+      | step_kind ((p as SMTLIB.S (SMTLIB.Sym "assert" :: _)) :: l) = (ASSERT, p, l)
       | step_kind ((p as SMTLIB.S (SMTLIB.Sym "step" :: _)) :: l) = (NORMAL_STEP, p, l)
       | step_kind ((p as SMTLIB.S (SMTLIB.Sym "define-fun" :: _)) :: l) = (SKOLEM, p, l)
+      | step_kind ((p as SMTLIB.S (SMTLIB.Sym "declare-fun" :: _)) :: l) = (SKOLEM, p, l)
       | step_kind p = raise (Fail ("step_kind unrec: " ^ @{make_string} p))
     fun parse_skolem (SMTLIB.S [SMTLIB.Sym "define-fun", SMTLIB.Sym id,  _, typ,
            SMTLIB.S (SMTLIB.Sym "!" :: t :: [SMTLIB.Key _, SMTLIB.Sym name])]) cx =
@@ -266,11 +282,19 @@
               (SMTLIB.S [SMTLIB.Sym "=", SMTLIB.Sym id, l]) [], cx)
          end
       | parse_skolem (SMTLIB.S [SMTLIB.Sym "define-fun", SMTLIB.Sym id,  _, typ, SMTLIB.S l]) cx =
-         let val (l, cx) = (fst oo SMTLIB_Proof.extract_and_update_name_bindings) (SMTLIB.S l ) cx
+         let val (l, cx) = (fst oo SMTLIB_Proof.extract_and_update_name_bindings) (SMTLIB.S l) cx
          in
            (mk_raw_node (id ^ lethe_def) lethe_def (SMTLIB.S [SMTLIB.Sym id, typ, l]) [] []
               (SMTLIB.S [SMTLIB.Sym "=", SMTLIB.Sym id, l]) [], cx)
          end
+      | parse_skolem (SMTLIB.S [SMTLIB.Sym "declare-fun", SMTLIB.Sym id, typ, def]) cx =
+         (*replace the name binding by the constant instead of the full term in order to reduce
+           the size of the generated terms and therefore the reconstruction time. NOTE(review): the conclusion below uses the raw def while the args use the processed l (the preceding define-fun clause uses l for both) -- confirm this is intended*)
+         let val (l, cx) = (fst oo SMTLIB_Proof.extract_and_update_name_bindings) def cx
+         in
+           (mk_raw_node (id ^ lethe_def) lethe_def (SMTLIB.S [SMTLIB.Sym id, typ, l]) [] []
+              (SMTLIB.S [SMTLIB.Sym "=", SMTLIB.Sym id, def]) [], cx)
+         end
       | parse_skolem t _ = raise Fail ("unrecognized Lethe proof " ^ \<^make_string> t)
     fun get_id_cx (SMTLIB.S ((SMTLIB.Sym _) :: (SMTLIB.Sym id) :: l), cx) = (id, (l, cx))
       | get_id_cx t = raise Fail ("unrecognized Lethe proof " ^ \<^make_string> t)
@@ -318,28 +342,41 @@
               val (s, (_, cx)) =  (p, cx)
                 |> parse_normal_step
                 |>>  (to_raw_node [])
-              val (rp, rl, cx) = parse_raw_proof_steps limit l cx
+              val (rp, rl, cx) = parse_raw_proof_steps limit l cx assms_nbr
           in (s :: rp, rl, cx) end
       | (ASSUME, p, l) =>
           let
             val (id, t :: []) = p
               |> get_id
+
             val ((t, cx), _) = SMTLIB_Proof.extract_and_update_name_bindings t cx
             val s = mk_raw_node id input_rule (SMTLIB.S []) [] [] t []
-            val (rp, rl, cx) = parse_raw_proof_steps limit l cx
+            (*Recursive call to parse rest of the steps.*)
+            val (rp, rl, cx) = parse_raw_proof_steps limit l cx (assms_nbr + 1)
+          in (s :: rp, rl, cx) end
+      | (ASSERT, p, l) =>
+          let (*a named assertion keeps its name as id; an unnamed one is numbered by assms_nbr*)
+            val (id, term) = (case p of
+                SMTLIB.S [SMTLIB.Sym "assert", SMTLIB.S [SMTLIB.Sym "!", term, SMTLIB.Key "named", SMTLIB.Sym id]] => (id, term)
+              | SMTLIB.S [SMTLIB.Sym "assert", term] => (Int.toString assms_nbr, term)
+              | _ => raise Fail ("unrecognized Lethe proof " ^ \<^make_string> p))
+            val ((t, cx), _) = SMTLIB_Proof.extract_and_update_name_bindings term cx
+            val s = mk_raw_node id input_rule (SMTLIB.S []) [] [] t []
+            (*Recursive call to parse rest of the steps; this assertion counts as one more assumption.*)
+            val (rp, rl, cx) = parse_raw_proof_steps limit l cx (assms_nbr + 1)
+          in (s :: rp, rl, cx) end
       | (ANCHOR, p, l) =>
           let
             val (anchor_id, (anchor_args, (_, cx))) = (p, cx) |> (parse_anchor_step ##> parse_args)
-            val (subproof, discharge_step :: remaining_proof, cx) = parse_raw_proof_steps (SOME anchor_id) l cx
+            val (subproof, discharge_step :: remaining_proof, cx) = parse_raw_proof_steps (SOME anchor_id) l cx assms_nbr
             val (curss, (_, cx)) = parse_normal_step (discharge_step, cx)
             val s = to_raw_node subproof (fst curss, anchor_args)
-            val (rp, rl, cx) = parse_raw_proof_steps limit remaining_proof cx
+            val (rp, rl, cx) = parse_raw_proof_steps limit remaining_proof cx assms_nbr
           in (s :: rp, rl, cx) end
       | (SKOLEM, p, l) =>
           let
             val (s, cx) = parse_skolem p cx
-            val (rp, rl, cx) = parse_raw_proof_steps limit l cx
+            val (rp, rl, cx) = parse_raw_proof_steps limit l cx (assms_nbr)
           in (s :: rp, rl, cx) end
   end
 
@@ -352,6 +389,7 @@
 
 fun args_of_rule "bind" t = t
   | args_of_rule "la_generic" t = t
+  | args_of_rule "all_simplify" t = t
   | args_of_rule _ _ = []
 
 fun insts_of_forall_inst "forall_inst" t = map (fn SMTLIB.S [_, SMTLIB.Sym x, a] => (x, a)) t
@@ -389,6 +427,13 @@
   let
     fun expand_assms cs =
       map (fn t => case AList.lookup (op =) cs t of NONE => t | SOME a => a)
+    fun match_typing_arguments (SMTLIB.S [SMTLIB.Sym var, typ as SMTLIB.Sym _] :: SMTLIB.S [SMTLIB.Sym "=", SMTLIB.Sym x1, x2] :: xs) =
+       if var = x1 then (*CVC5: a typing pair (var typ) directly followed by (= var t) is fused into (= (var typ) t)*)
+         SMTLIB.S [SMTLIB.Sym "=", SMTLIB.S [SMTLIB.Sym x1, typ], x2] :: match_typing_arguments xs
+       else (*the equality is about another variable: keep the typing pair and re-examine the equality with what follows*)
+         SMTLIB.S [SMTLIB.Sym var, typ] :: match_typing_arguments (SMTLIB.S [SMTLIB.Sym "=", SMTLIB.Sym x1, x2] :: xs)
+     | match_typing_arguments (a :: xs) = a :: match_typing_arguments xs (*any other argument is passed through unchanged*)
+     | match_typing_arguments [] = []
     fun expand_lonely_arguments (args as SMTLIB.S [SMTLIB.Sym "=", _, _]) = [args]
       | expand_lonely_arguments (x as SMTLIB.S [SMTLIB.Sym var, _]) = [SMTLIB.S [SMTLIB.Sym "=", x, SMTLIB.Sym var]]
 
@@ -399,7 +444,7 @@
           |> map
               (fn SMTLIB.S [SMTLIB.Key "=", x, y] => SMTLIB.S [SMTLIB.Sym "=", x, y]
               | x => x)
-          |> (rule = "bind" orelse rule = "onepoint") ? flat o (map expand_lonely_arguments)
+          |> (rule = "bind" orelse rule = "onepoint") ? flat o (map expand_lonely_arguments) o match_typing_arguments
           |> `(if rule = lethe_def then single o extract_skolem else K [])
           ||> SMTLIB.S
 
@@ -428,7 +473,7 @@
     |> single
  | extract_types_of_args (SMTLIB.S t) =
   let
-    fun extract_types_of_arg (SMTLIB.S [eq, SMTLIB.S [var, typ], t]) = (SMTLIB.S [eq, var, t], SOME typ)
+    fun extract_types_of_arg (SMTLIB.S [eq as SMTLIB.Sym "=", SMTLIB.S [var, typ], t]) = (SMTLIB.S [eq, var, t], SOME typ)
       | extract_types_of_arg t = (t, NONE)
   in
     t
@@ -439,6 +484,8 @@
   (if is_skolemization rule then map (fn id => id ^ lethe_def) (skolems_introduced_by_rule args) else []) @
   flat (map collect_skolem_defs subproof)
 
+val desymbolize = Name.desymbolize (SOME false) o perhaps (try (unprefix "?")) (*drop a leading question mark if present, then apply Name.desymbolize*)
+
 (*The postprocessing does:
   1. translate the terms to Isabelle syntax, taking care of free variables
   2. remove the ambiguity in the proof terms:
@@ -453,8 +500,6 @@
   let
     fun postprocess (Raw_Lethe_Node {id, rule, args, prems, declarations, concl, subproof}) (cx, rew) =
     let
-      val _ = (SMT_Config.verit_msg ctxt) (fn () => @{print} ("id =", id, "concl =", concl))
-
       val args = extract_types_of_args args
       val globally_bound_vars = declared_csts cx rule args
       val cx = fold (update_binding o (fn (s, typ) => (s, Term (Free (s, type_of cx typ)))))
@@ -486,13 +531,9 @@
       (* postprocess conclusion *)
       val concl = SMTLIB_Isar.unskolemize_names ctxt (subproof_rewriter concl)
 
-      val _ = (SMT_Config.verit_msg ctxt) (fn () => \<^print> ("id =", id, "concl =", concl))
-      val _ = (SMT_Config.verit_msg ctxt) (fn () => \<^print> ("id =", id, "cx' =", cx',
-        "bound_vars =", bound_vars))
-
-      val bound_tvars =
-          map (fn (s, SOME typ) => (s, type_of cx typ))
-            (shadowing_vars @ new_lhs_vars)
+      fun give_proper_type (s, SOME typ) = (s, type_of cx typ)
+       | give_proper_type (s, NONE) = raise (Fail ("could not find type of var " ^ @{make_string} s ^ " in step " ^ id ^ " in " ^  @{make_string} concl))
+      val bound_tvars = map give_proper_type (shadowing_vars @ new_lhs_vars)
       val subproof_cx =
          add_bound_variables_to_ctxt cx (shadowing_vars @ new_lhs_vars) cx
 
@@ -531,15 +572,18 @@
 
       (* postprocess arguments *)
       val rule_args = args_of_rule rule stripped_args
+
       val (termified_args, _) = fold_map term_of rule_args subproof_cx
       val normalized_args = map unsk_and_rewrite termified_args
+
       val rule_args = map subproof_rewriter normalized_args
 
       val raw_insts = insts_of_forall_inst rule stripped_args
       fun termify_term (x, t) cx = let val (t, cx) = term_of t cx in ((x, t), cx) end
       val (termified_args, _) = fold_map termify_term raw_insts subproof_cx
+
       val insts = Symtab.empty
-        |> fold (fn (x, t) => fn insts => Symtab.update_new (x, t) insts) termified_args
+        |> fold (fn (x, t) => fn insts => Symtab.update_new (desymbolize x, t) insts) termified_args
         |> Symtab.map (K unsk_and_rewrite)
 
       (* declarations *)
@@ -742,14 +786,16 @@
   fun import_proof_and_post_process typs funs lines ctxt =
     let
       val compress = SMT_Config.compress_verit_proofs ctxt
+
       val smtlib_lines_without_qm =
         lines
+        |> filter_out (fn x => x = "")
         |> map single
         |> map SMTLIB.parse
         |> map remove_all_qm2
         |> map remove_pattern
       val (raw_steps, _, _) =
-        parse_raw_proof_steps NONE smtlib_lines_without_qm SMTLIB_Proof.empty_name_binding
+        parse_raw_proof_steps NONE smtlib_lines_without_qm SMTLIB_Proof.empty_name_binding 0
 
       fun process step (cx, cx') =
         let fun postprocess step (cx, cx') =
@@ -779,7 +825,6 @@
 fun parse_replay typs funs lines ctxt =
   let
     val (u, env) = import_proof_and_post_process typs funs lines ctxt
-    val _ = (SMT_Config.verit_msg ctxt) (fn () => \<^print> u)
   in
     (u, ctxt_of env)
   end