consolidate nested thms with persistent result, for improved performance;
authorwenzelm
Fri, 16 Dec 2016 19:07:16 +0100
changeset 64574 1134e4d5e5b7
parent 64573 e6aee01da22d
child 64575 d44f0b714e13
consolidate nested thms with persistent result, for improved performance; always consolidate parts of fulfill_norm_proof: important to exhibit cyclic thms (via non-termination as officially published), but this was lost in f33d5a00c25d;
src/HOL/Library/Old_SMT/old_z3_proof_reconstruction.ML
src/HOL/SPARK/Tools/spark_vcs.ML
src/Pure/Thy/thy_info.ML
src/Pure/more_thm.ML
src/Pure/proofterm.ML
src/Pure/thm.ML
--- a/src/HOL/Library/Old_SMT/old_z3_proof_reconstruction.ML	Fri Dec 16 14:06:31 2016 +0100
+++ b/src/HOL/Library/Old_SMT/old_z3_proof_reconstruction.ML	Fri Dec 16 19:07:16 2016 +0100
@@ -167,7 +167,7 @@
       assms
       |> map (apsnd (rewrite ctxt eqs'))
       |> map (apsnd (Conv.fconv_rule Thm.eta_conversion))
-      |> Old_Z3_Proof_Tools.thm_net_of snd 
+      |> Old_Z3_Proof_Tools.thm_net_of snd
 
     fun revert_conv ctxt = rewrite_conv ctxt eqs' then_conv Thm.eta_conversion
 
@@ -183,14 +183,14 @@
           if exact then (Thm.implies_elim thm1 th, ctxt)
           else assume thm1 ctxt
         val thms' = if exact then thms else th :: thms
-      in 
+      in
         ((insert (op =) i is, thms'),
           (ctxt', Inttab.update (idx, Thm thm) ptab))
       end
 
     fun add (idx, ct) (cx as ((is, thms), (ctxt, ptab))) =
       let
-        val thm1 = 
+        val thm1 =
           Thm.trivial ct
           |> Conv.fconv_rule (Conv.arg1_conv (revert_conv outer_ctxt))
         val thm2 = singleton (Variable.export ctxt outer_ctxt) thm1
@@ -218,7 +218,7 @@
   val mp_c = precomp (Thm.dest_binop o Thm.dest_arg) @{thm mp}
 in
 fun mp (MetaEq thm) p = Thm (Thm.implies_elim (comp meta_iffD1_c thm) p)
-  | mp p_q p = 
+  | mp p_q p =
       let
         val pq = thm_of p_q
         val thm = comp iffD1_c pq handle THM _ => comp mp_c pq
@@ -509,7 +509,7 @@
     (case lookup (Logic.mk_equals (apply2 Thm.term_of cp)) of
       SOME eq => eq
     | NONE => if exn then raise MONO else prove_refl cp)
-  
+
   val prove_exn = prove_eq true
   and prove_safe = prove_eq false
 
@@ -752,7 +752,9 @@
 fun check_after idx r ps ct (p, (ctxt, _)) =
   if not (Config.get ctxt Old_SMT_Config.trace) then ()
   else
-    let val thm = thm_of p |> tap (Thm.join_proofs o single)
+    let
+      val thm = thm_of p
+      val _ = Thm.consolidate thm
     in
       if (Thm.cprop_of thm) aconvc ct then ()
       else
@@ -852,7 +854,7 @@
 
   fun discharge_assms_tac ctxt rules =
     REPEAT (HEADGOAL (resolve_tac ctxt rules ORELSE' SOLVED' (discharge_sk_tac ctxt)))
-    
+
   fun discharge_assms ctxt rules thm =
     if Thm.nprems_of thm = 0 then Goal.norm_result ctxt thm
     else
@@ -881,7 +883,7 @@
     if Config.get ctxt2 Old_SMT_Config.filter_only_facts then (is, @{thm TrueI})
     else
       (Thm @{thm TrueI}, cxp)
-      |> fold (prove simpset vars) steps 
+      |> fold (prove simpset vars) steps
       |> discharge rules outer_ctxt
       |> pair []
   end
--- a/src/HOL/SPARK/Tools/spark_vcs.ML	Fri Dec 16 14:06:31 2016 +0100
+++ b/src/HOL/SPARK/Tools/spark_vcs.ML	Fri Dec 16 19:07:16 2016 +0100
@@ -293,7 +293,7 @@
   | SOME _ => error ("Cannot associate a type with " ^ s ^
       "\nsince it is no record or enumeration type");
 
-fun check_enum [] [] = NONE 
+fun check_enum [] [] = NONE
   | check_enum els [] = SOME ("has no element(s) " ^ commas els)
   | check_enum [] cs = SOME ("has extra element(s) " ^
       commas (map (Long_Name.base_name o fst) cs))
@@ -305,7 +305,7 @@
 fun invert_map [] = I
   | invert_map cmap =
       map (apfst (the o AList.lookup (op =) (map swap cmap)));
- 
+
 fun add_type_def prfx (s, Basic_Type ty) (ids, thy) =
       (check_no_assoc thy prfx s;
        (ids,
@@ -677,7 +677,7 @@
    "+", "-", "*", "/", "div", "mod", "**"]);
 
 fun complex_expr (Number _) = false
-  | complex_expr (Ident _) = false 
+  | complex_expr (Ident _) = false
   | complex_expr (Funct (s, es)) =
       not (Symtab.defined builtin s) orelse exists complex_expr es
   | complex_expr (Quantifier (_, _, _, e)) = complex_expr e
@@ -959,7 +959,7 @@
     | SOME {vcs, path, ...} =>
         let
           val (proved, unproved) = partition_vcs vcs;
-          val _ = Thm.join_proofs (maps (#2 o snd) proved);
+          val _ = List.app Thm.consolidate (maps (#2 o snd) proved);
           val (proved', proved'') =
             List.partition (fn (_, (_, thms, _, _)) =>
               exists (#oracle o Thm.peek_status) thms) proved;
@@ -1117,7 +1117,7 @@
           [(term_of_rule thy' prfx types pfuns ids rl, [])]))
            other_rules),
        Element.Notes ("", [((Binding.name "defns", []), map (rpair [] o single o snd) defs')])]
-          
+
   in
     set_env ctxt defs' types funs ids vcs' path prfx thy'
   end;
--- a/src/Pure/Thy/thy_info.ML	Fri Dec 16 14:06:31 2016 +0100
+++ b/src/Pure/Thy/thy_info.ML	Fri Dec 16 19:07:16 2016 +0100
@@ -159,7 +159,7 @@
     (*toplevel proofs and diags*)
     val _ = Future.join_tasks (maps Future.group_snapshot (Execution.peek exec_id));
     (*fully nested proofs*)
-    val res = Exn.capture Thm.join_theory_proofs theory;
+    val res = Exn.capture Thm.consolidate_theory theory;
   in res :: map Exn.Exn (maps Task_Queue.group_status (Execution.peek exec_id)) end;
 
 datatype task =
--- a/src/Pure/more_thm.ML	Fri Dec 16 14:06:31 2016 +0100
+++ b/src/Pure/more_thm.ML	Fri Dec 16 19:07:16 2016 +0100
@@ -111,7 +111,7 @@
   val untag: string -> attribute
   val kind: string -> attribute
   val register_proofs: thm list -> theory -> theory
-  val join_theory_proofs: theory -> unit
+  val consolidate_theory: theory -> unit
   val show_consts_raw: Config.raw
   val show_consts: bool Config.T
   val show_hyps_raw: Config.raw
@@ -644,8 +644,8 @@
 fun register_proofs more_thms =
   Proofs.map (fold (cons o Thm.trim_context) more_thms);
 
-fun join_theory_proofs thy =
-  Thm.join_proofs (map (Thm.transfer thy) (rev (Proofs.get thy)));
+fun consolidate_theory thy =
+  List.app (Thm.consolidate o Thm.transfer thy) (rev (Proofs.get thy));
 
 
 
--- a/src/Pure/proofterm.ML	Fri Dec 16 14:06:31 2016 +0100
+++ b/src/Pure/proofterm.ML	Fri Dec 16 19:07:16 2016 +0100
@@ -46,7 +46,7 @@
   val fold_proof_atoms: bool -> (proof -> 'a -> 'a) -> proof list -> 'a -> 'a
   val fold_body_thms: ({serial: serial, name: string, prop: term, body: proof_body} -> 'a -> 'a) ->
     proof_body list -> 'a -> 'a
-  val join_bodies: proof_body list -> unit
+  val consolidate: proof_body -> unit
   val peek_status: proof_body list -> {failed: bool, oracle: bool, unfinished: bool}
 
   val oracle_ord: oracle * oracle -> order
@@ -182,7 +182,8 @@
   {oracles: (string * term) Ord_List.T,
    thms: (serial * thm_node) Ord_List.T,
    proof: proof}
-and thm_node = Thm_Node of string * term * proof_body future;
+and thm_node =
+  Thm_Node of {name: string, prop: term, body: proof_body future, consolidate: unit lazy};
 
 type oracle = string * term;
 type pthm = serial * thm_node;
@@ -190,12 +191,24 @@
 fun proof_of (PBody {proof, ...}) = proof;
 val join_proof = Future.join #> proof_of;
 
-fun thm_node_name (Thm_Node (name, _, _)) = name;
-fun thm_node_prop (Thm_Node (_, prop, _)) = prop;
-fun thm_node_body (Thm_Node (_, _, body)) = body;
+fun rep_thm_node (Thm_Node args) = args;
+val thm_node_name = #name o rep_thm_node;
+val thm_node_prop = #prop o rep_thm_node;
+val thm_node_body = #body o rep_thm_node;
+val thm_node_consolidate = #consolidate o rep_thm_node;
 
 fun join_thms (thms: pthm list) =
-  ignore (Future.joins (map (fn (_, Thm_Node (_, _, body)) => body) thms));
+  Future.joins (map (thm_node_body o #2) thms);
+
+fun consolidate (PBody {thms, ...}) =
+  List.app (Lazy.force o thm_node_consolidate o #2) thms;
+
+fun make_thm_node name prop body =
+  Thm_Node {name = name, prop = prop, body = body,
+    consolidate =
+      Lazy.lazy (fn () =>
+        let val PBody {thms, ...} = Future.join body
+        in List.app consolidate (join_thms thms) end)};
 
 
 (***** proof atoms *****)
@@ -218,27 +231,27 @@
 fun fold_body_thms f =
   let
     fun app (PBody {thms, ...}) =
-      tap join_thms thms |> fold (fn (i, Thm_Node (name, prop, body)) => fn (x, seen) =>
+      tap join_thms thms |> fold (fn (i, thm_node) => fn (x, seen) =>
         if Inttab.defined seen i then (x, seen)
         else
           let
-            val body' = Future.join body;
-            val (x', seen') = app body' (x, Inttab.update (i, ()) seen);
-          in (f {serial = i, name = name, prop = prop, body = body'} x', seen') end);
+            val name = thm_node_name thm_node;
+            val prop = thm_node_prop thm_node;
+            val body = Future.join (thm_node_body thm_node);
+            val (x', seen') = app body (x, Inttab.update (i, ()) seen);
+          in (f {serial = i, name = name, prop = prop, body = body} x', seen') end);
   in fn bodies => fn x => #1 (fold app bodies (x, Inttab.empty)) end;
 
-fun join_bodies bodies = fold_body_thms (fn _ => fn () => ()) bodies ();
-
 fun peek_status bodies =
   let
     fun status (PBody {oracles, thms, ...}) x =
       let
         val ((oracle, unfinished, failed), seen) =
-          (thms, x) |-> fold (fn (i, Thm_Node (_, _, body)) => fn (st, seen) =>
+          (thms, x) |-> fold (fn (i, thm_node) => fn (st, seen) =>
             if Inttab.defined seen i then (st, seen)
             else
               let val seen' = Inttab.update (i, ()) seen in
-                (case Future.peek body of
+                (case Future.peek (thm_node_body thm_node) of
                   SOME (Exn.Res body') => status body' (st, seen')
                 | SOME (Exn.Exn _) =>
                     let val (oracle, unfinished, _) = st
@@ -264,12 +277,12 @@
 val all_oracles_of =
   let
     fun collect (PBody {oracles, thms, ...}) =
-      tap join_thms thms |> fold (fn (i, Thm_Node (_, _, body)) => fn (x, seen) =>
+      tap join_thms thms |> fold (fn (i, thm_node) => fn (x, seen) =>
         if Inttab.defined seen i then (x, seen)
         else
           let
-            val body' = Future.join body;
-            val (x', seen') = collect body' (x, Inttab.update (i, ()) seen);
+            val body = Future.join (thm_node_body thm_node);
+            val (x', seen') = collect body (x, Inttab.update (i, ()) seen);
           in (if null oracles then x' else oracles :: x', seen') end);
   in fn body => unions_oracles (#1 (collect body ([], Inttab.empty))) end;
 
@@ -277,7 +290,7 @@
   let
     val (oracles, thms) = fold_proof_atoms false
       (fn Oracle (s, prop, _) => apfst (cons (s, prop))
-        | PThm (i, ((name, prop, _), body)) => apsnd (cons (i, Thm_Node (name, prop, body)))
+        | PThm (i, ((name, prop, _), body)) => apsnd (cons (i, make_thm_node name prop body))
         | _ => I) [prf] ([], []);
   in
     PBody
@@ -321,8 +334,9 @@
     ([int_atom a, b], triple term (option (list typ)) proof_body (c, d, Future.join body))]
 and proof_body (PBody {oracles, thms, proof = prf}) =
   triple (list (pair string term)) (list pthm) proof (oracles, thms, prf)
-and pthm (a, Thm_Node (b, c, body)) =
-  pair int (triple string term proof_body) (a, (b, c, Future.join body));
+and pthm (a, thm_node) =
+  pair int (triple string term proof_body)
+    (a, (thm_node_name thm_node, thm_node_prop thm_node, Future.join (thm_node_body thm_node)));
 
 in
 
@@ -358,7 +372,7 @@
   in PBody {oracles = a, thms = b, proof = c} end
 and pthm x =
   let val (a, (b, c, d)) = pair int (triple string term proof_body) x
-  in (a, Thm_Node (b, c, Future.value d)) end;
+  in (a, make_thm_node b c (Future.value d)) end;
 
 in
 
@@ -1519,6 +1533,8 @@
 
 fun fulfill_norm_proof thy ps body0 =
   let
+    val _ = List.app (consolidate o #2) ps;
+    val _ = consolidate body0;
     val PBody {oracles = oracles0, thms = thms0, proof = proof0} = body0;
     val oracles =
       unions_oracles
@@ -1616,7 +1632,7 @@
           else new_prf ()
       | _ => new_prf ());
     val head = PThm (i, ((name, prop1, NONE), body'));
-  in ((i, Thm_Node (name, prop1, body')), head, args, argsP, args1) end;
+  in ((i, make_thm_node name prop1 body'), head, args, argsP, args1) end;
 
 fun thm_proof thy name shyps hyps concl promises body =
   let val (pthm, head, args, argsP, _) = prepare_thm_proof thy name shyps hyps concl promises body
--- a/src/Pure/thm.ML	Fri Dec 16 14:06:31 2016 +0100
+++ b/src/Pure/thm.ML	Fri Dec 16 19:07:16 2016 +0100
@@ -86,7 +86,7 @@
   val proof_bodies_of: thm list -> proof_body list
   val proof_body_of: thm -> proof_body
   val proof_of: thm -> proof
-  val join_proofs: thm list -> unit
+  val consolidate: thm -> unit
   val peek_status: thm -> {oracle: bool, unfinished: bool, failed: bool}
   val future: thm future -> cterm -> thm
   val derivation_closed: thm -> bool
@@ -588,17 +588,11 @@
   let val fulfilled_promises = map #1 promises ~~ map fulfill_body (Future.joins (map #2 promises))
   in Proofterm.fulfill_norm_proof (theory_of_thm th) fulfilled_promises body end;
 
-fun proof_bodies_of thms =
-  let
-    val _ = join_promises_of thms;
-    val bodies = map fulfill_body thms;
-    val _ = Proofterm.join_bodies bodies;
-  in bodies end;
-
+fun proof_bodies_of thms = (join_promises_of thms; map fulfill_body thms);
 val proof_body_of = singleton proof_bodies_of;
 val proof_of = Proofterm.proof_of o proof_body_of;
 
-val join_proofs = Proofterm.join_bodies o proof_bodies_of;
+val consolidate = ignore o proof_bodies_of o single;
 
 
 (* derivation status *)