src/HOL/Tools/SMT/verit_proof.ML
changeset 74403 dbd69d287ec6
parent 74382 8d0294d877bd
child 74561 8e6c973003c8
--- a/src/HOL/Tools/SMT/verit_proof.ML	Fri Oct 01 12:45:47 2021 +0200
+++ b/src/HOL/Tools/SMT/verit_proof.ML	Fri Oct 01 22:35:32 2021 +0200
@@ -296,61 +296,49 @@
      Const(\<^const_name>\<open>Trueprop\<close>, T) $ (synctatic_rew_in_lhs_subst old_name new_name t1)
   | synctatic_rew_in_lhs_subst _ _ t = t
 
-fun add_bound_variables_to_ctxt concl =
+fun add_bound_variables_to_ctxt cx =
   fold (update_binding o
-    (fn s => (s, Term (Free (s, the_default dummyT (find_type_in_formula concl s))))))
+    (fn (s, SOME typ) => (s, Term (Free (s, type_of cx typ)))))
 
 local
 
-  fun remove_Sym (SMTLIB.Sym y) = y
-    | remove_Sym y = (@{print} y; raise (Fail "failed to match"))
-
   fun extract_symbols bds =
     bds
-    |> map (fn SMTLIB.S [SMTLIB.Sym "=", SMTLIB.Sym x, SMTLIB.Sym y] => [[x, y]]
-             | SMTLIB.S [SMTLIB.Key "=", SMTLIB.Sym x, SMTLIB.Sym y] => [[x, y]]
-             | SMTLIB.S syms => map (single o remove_Sym) syms
-             | SMTLIB.Sym x => [[x]]
+    |> map (fn (SMTLIB.S [SMTLIB.Sym "=", SMTLIB.Sym x, SMTLIB.Sym y], typ) => [([x, y], typ)]
              | t => raise (Fail ("match error " ^ @{make_string} t)))
     |> flat
 
   (* onepoint can bind a variable to another variable or to a constant *)
   fun extract_qnt_symbols cx bds =
     bds
-    |> map (fn SMTLIB.S [SMTLIB.Sym "=", SMTLIB.Sym x, SMTLIB.Sym y] =>
-                (case node_of (SMTLIB.Sym y) cx  of
-                  ((_, []), _) => [[x]]
-                | _ => [[x, y]])
-             | SMTLIB.S [SMTLIB.Key "=", SMTLIB.Sym x, SMTLIB.Sym y] =>
-                (case node_of (SMTLIB.Sym y) cx  of
-                  ((_, []), _) => [[x]]
-                | _ => [[x, y]])
-             | SMTLIB.S (SMTLIB.Sym "=" :: SMTLIB.Sym x :: _) => [[x]]
-             | SMTLIB.S (SMTLIB.Key "=" :: SMTLIB.Sym x :: _) => [[x]]
-             | SMTLIB.S syms => map (single o remove_Sym) syms
-             | SMTLIB.Sym x => [[x]]
+    |> map (fn (SMTLIB.S [SMTLIB.Sym "=", SMTLIB.Sym x, SMTLIB.Sym y], typ) =>
+                (case node_of (SMTLIB.Sym y) cx of
+                  ((_, []), _) => [([x], typ)]
+                | _ => [([x, y], typ)])
+             | (SMTLIB.S (SMTLIB.Sym "=" :: SMTLIB.Sym x :: _), typ) => [([x], typ)]
              | t => raise (Fail ("match error " ^ @{make_string} t)))
     |> flat
 
   fun extract_symbols_map bds =
     bds
-    |> map (fn SMTLIB.S [SMTLIB.Sym "=", SMTLIB.Sym x, _] => [[x]]
-             | SMTLIB.S syms =>  map (single o remove_Sym) syms)
+    |> map (fn (SMTLIB.S [SMTLIB.Sym "=", SMTLIB.Sym x, _], typ) => [([x], typ)])
     |> flat
 in
 
-fun declared_csts _ "__skolem_definition" (SMTLIB.S [SMTLIB.Sym x, typ, _]) = [(x, typ)]
+fun declared_csts _ "__skolem_definition" [(SMTLIB.S [SMTLIB.Sym x, typ, _], _)] = [(x, typ)]
+  | declared_csts _ "__skolem_definition" t = raise (Fail ("unrecognized skolem_definition " ^ @{make_string} t))
   | declared_csts _ _ _ = []
 
 fun skolems_introduced_by_rule (SMTLIB.S bds) =
-   fold (fn (SMTLIB.S [SMTLIB.Sym "=", SMTLIB.Sym _, SMTLIB.Sym y]) => curry (op ::) y) bds []
+   fold (fn (SMTLIB.S [SMTLIB.Sym "=", _, SMTLIB.Sym y]) => curry (op ::) y) bds []
 
 (*FIXME there is probably a way to use the information given by onepoint*)
-fun bound_vars_by_rule _ "bind" (SMTLIB.S bds) = extract_symbols bds
-  | bound_vars_by_rule cx "onepoint" (SMTLIB.S bds) = extract_qnt_symbols cx bds
-  | bound_vars_by_rule _ "sko_forall" (SMTLIB.S bds) = extract_symbols_map bds
-  | bound_vars_by_rule _ "sko_ex" (SMTLIB.S bds) = extract_symbols_map bds
-  | bound_vars_by_rule _ "__skolem_definition" (SMTLIB.S [SMTLIB.Sym x, _, _]) = [[x]]
+fun bound_vars_by_rule _ "bind" (bds) = extract_symbols bds
+  | bound_vars_by_rule cx "onepoint" bds = extract_qnt_symbols cx bds
+  | bound_vars_by_rule _ "sko_forall" bds = extract_symbols_map bds
+  | bound_vars_by_rule _ "sko_ex" bds = extract_symbols_map bds
+  | bound_vars_by_rule _ "__skolem_definition" [(SMTLIB.S [SMTLIB.Sym x, typ, _], _)] = [([x], SOME typ)]
+  | bound_vars_by_rule _ "__skolem_definition" [(SMTLIB.S [_, SMTLIB.Sym x, _], _)] = [([x], NONE)]
   | bound_vars_by_rule _ _ _ = []
 
 (* VeriT adds "?" before some variables. *)
@@ -515,23 +503,19 @@
     fun expand_assms cs =
       map (fn t => case AList.lookup (op =) cs t of NONE => t | SOME a => a)
     fun expand_lonely_arguments (args as SMTLIB.S [SMTLIB.Sym "=", _, _]) = [args]
-      | expand_lonely_arguments (SMTLIB.S S) = map (fn x => SMTLIB.S [SMTLIB.Sym "=", x, x]) S
-      | expand_lonely_arguments (x as SMTLIB.Sym _) = [SMTLIB.S [SMTLIB.Sym "=", x, x]]
-      | expand_lonely_arguments t = [t]
+      | expand_lonely_arguments (x as SMTLIB.S [SMTLIB.Sym var, _]) = [SMTLIB.S [SMTLIB.Sym "=", x, SMTLIB.Sym var]]
 
     fun preprocess (Raw_VeriT_Node {id, rule, args, prems, concl, subproof, ...}) (cx, remap_assms)  =
       let
-        val stripped_args = args
+        val (skolem_names, stripped_args) = args
           |> (fn SMTLIB.S args => args)
           |> map
               (fn SMTLIB.S [SMTLIB.Key "=", x, y] => SMTLIB.S [SMTLIB.Sym "=", x, y]
-              | x => x)
+                | x => x)
           |> (rule = "bind" orelse rule = "onepoint") ? flat o (map expand_lonely_arguments)
           |> `(if rule = veriT_def then single o extract_skolem else K [])
           ||> SMTLIB.S
-
         val (subproof, (cx, _)) = fold_map preprocess subproof (cx, remap_assms) |> apfst flat
-        val (skolem_names, stripped_args) = stripped_args
         val remap_assms = (if rule = "or" then (id, hd prems) :: remap_assms else remap_assms)
         (* declare variables in the context *)
         val declarations =
@@ -555,6 +539,18 @@
   (if is_skolemization rule then map (fn id => id ^ veriT_def) (skolems_introduced_by_rule args) else []) @
   flat (map collect_skolem_defs subproof)
 
+fun extract_types_of_args (SMTLIB.S [var, typ, t as SMTLIB.S [SMTLIB.Sym "choice", _, _]]) =
+    (SMTLIB.S [var, typ, t], SOME typ)
+    |> single
+ | extract_types_of_args (SMTLIB.S t) =
+  let
+    fun extract_types_of_arg (SMTLIB.S [eq, SMTLIB.S [var, typ], t]) = (SMTLIB.S [eq, var, t], SOME typ)
+      | extract_types_of_arg t = (t, NONE)
+  in
+    t
+    |> map extract_types_of_arg
+  end
+
 (*The postprocessing does:
   1. translate the terms to Isabelle syntax, taking care of free variables
   2. remove the ambiguity in the proof terms:
@@ -571,29 +567,32 @@
     let
       val _ = (SMT_Config.verit_msg ctxt) (fn () => @{print} ("id =", id, "concl =", concl))
 
+      val (args) = extract_types_of_args args
       val globally_bound_vars = declared_csts cx rule args
       val cx = fold (update_binding o (fn (s, typ) => (s, Term (Free (s, type_of cx typ)))))
            globally_bound_vars cx
 
       (*find rebound variables specific to the LHS of the equivalence symbol*)
       val bound_vars = bound_vars_by_rule cx rule args
-
-      val rhs_vars = fold (fn [t', t] => t <> t' ? (curry (op ::) t) | _ => fn x => x) bound_vars []
+      val bound_vars_no_typ = map fst bound_vars
+      val rhs_vars =
+        fold (fn [t', t] => t <> t' ? (curry (op ::) t) | _ => fn x => x) bound_vars_no_typ []
       fun not_already_bound cx t = SMTLIB_Proof.lookup_binding cx t = None andalso
           not (member (op =) rhs_vars t)
       val (shadowing_vars, rebound_lhs_vars) = bound_vars
-        |> filter_split (fn [t, _] => not_already_bound cx t | _ => true)
-        |>> map (single o hd)
-        |>> (fn vars => vars @ map (fn [_, t] => [t] | _ => []) bound_vars)
-        |>> flat
+        |> filter_split (fn ([t, _], typ) => not_already_bound cx t | _ => true)
+        |>> map (apfst (hd))
+        |>> (fn vars => vars @ flat (map (fn ([_, t], typ) => [(t, typ)] | _ => []) bound_vars))
       val subproof_rew = fold (fn [t, t'] => curry (op ::) (t, t ^ t'))
-        rebound_lhs_vars rew
+        (map fst rebound_lhs_vars) rew
       val subproof_rewriter = fold (fn (t, t') => synctatic_rew_in_lhs_subst t t')
          subproof_rew
 
       val ((concl, bounds), cx') = node_of concl cx
 
-      val extra_lhs_vars = map (fn [a,b] => (a, a^b)) rebound_lhs_vars
+      val extra_lhs_vars = map (fn ([a,b], typ) => (a, a^b, typ)) rebound_lhs_vars
+      val old_lhs_vars = map (fn (a, _, typ) => (a, typ)) extra_lhs_vars
+      val new_lhs_vars = map (fn (_, newvar, typ) => (newvar, typ)) extra_lhs_vars
 
       (* postprocess conclusion *)
       val concl = SMTLIB_Isar.unskolemize_names ctxt (subproof_rewriter concl)
@@ -603,10 +602,10 @@
         "bound_vars =", bound_vars))
 
       val bound_tvars =
-        map (fn s => (s, the (find_type_in_formula concl s)))
-         (shadowing_vars @ map snd extra_lhs_vars)
+        map (fn (s, SOME typ) => (s, type_of cx typ))
+         (shadowing_vars @ new_lhs_vars)
       val subproof_cx =
-         add_bound_variables_to_ctxt concl (shadowing_vars @ map snd extra_lhs_vars) cx
+         add_bound_variables_to_ctxt cx (shadowing_vars @ new_lhs_vars) cx
 
       fun could_unify (Bound i, Bound j) = i = j
         | could_unify (Var v, Var v') = v = v'
@@ -632,12 +631,10 @@
       val unsk_and_rewrite = SMTLIB_Isar.unskolemize_names ctxt o subproof_rewriter
 
       (* postprocess assms *)
-      val stripped_args = args |> (fn SMTLIB.S S => S)
+      val stripped_args = map fst args
       val sanitized_args = proof_ctxt_of_rule rule stripped_args
 
-      val arg_cx =
-        subproof_cx
-        |> add_bound_variables_to_ctxt concl (shadowing_vars @ map fst extra_lhs_vars)
+      val arg_cx = add_bound_variables_to_ctxt cx (shadowing_vars @ old_lhs_vars) subproof_cx
       val (termified_args, _) = fold_map node_of sanitized_args arg_cx |> apfst (map fst)
       val normalized_args = map unsk_and_rewrite termified_args
 
@@ -661,11 +658,9 @@
         |> apfst (map (apsnd unsk_and_rewrite))
 
       (* fix step *)
-      val bound_t = bounds
-        |> map (fn s => (s, the (find_type_in_formula concl s)))
-
+      val _ = if bounds <> [] then raise (Fail "found dangling variable in concl") else ()
       val skolem_defs = (if is_skolemization rule
-         then map (fn id => id ^ veriT_def) (skolems_introduced_by_rule args) else [])
+         then map (fn id => id ^ veriT_def) (skolems_introduced_by_rule (SMTLIB.S (map fst args))) else [])
       val skolems_of_subproof = (if is_skolemization rule
          then flat (map collect_skolem_defs subproof) else [])
       val fixed_prems =
@@ -680,7 +675,7 @@
         (if rule = subproof_rule then extract_assumptions_from_subproof fixed_subproof else [])
 
       val step = mk_replay_node id normalized_rule rule_args fixed_prems subproof_assms concl
-        bound_t insts declarations (bound_tvars, subproof_assms, extra_assms2, fixed_subproof)
+        [] insts declarations (bound_tvars, subproof_assms, extra_assms2, fixed_subproof)
 
     in
        (step, (cx', rew))