src/HOL/Matrix_LP/Compute_Oracle/compute.ML
changeset 77863 760515c45864
parent 77808 b43ee37926a9
child 77869 1156aa9db7f5
--- a/src/HOL/Matrix_LP/Compute_Oracle/compute.ML	Sat Apr 15 23:11:08 2023 +0200
+++ b/src/HOL/Matrix_LP/Compute_Oracle/compute.ML	Mon Apr 17 23:32:46 2023 +0200
@@ -16,7 +16,7 @@
     val make : machine -> theory -> thm list -> computer
     val make_with_cache : machine -> theory -> term list -> thm list -> computer
     val theory_of : computer -> theory
-    val hyps_of : computer -> Termset.T
+    val hyps_of : computer -> term list
     val shyps_of : computer -> Sortset.T
     (* ! *) val update : computer -> thm list -> unit
     (* ! *) val update_with_cache : computer -> term list -> thm list -> unit
@@ -169,7 +169,7 @@
 fun default_naming i = "v_" ^ string_of_int i
 
 datatype computer = Computer of
-  (theory * Encode.encoding * Termset.T * Sortset.T * prog * unit Unsynchronized.ref * naming)
+  (theory * Encode.encoding * term list * Sortset.T * prog * unit Unsynchronized.ref * naming)
     option Unsynchronized.ref
 
 fun theory_of (Computer (Unsynchronized.ref (SOME (thy,_,_,_,_,_,_)))) = thy
@@ -187,7 +187,7 @@
 fun ref_of (Computer r) = r
 
 
-datatype cthm = ComputeThm of Termset.T * Sortset.T * term
+datatype cthm = ComputeThm of term list * Sortset.T * term
 
 fun thm2cthm th = 
     (if not (null (Thm.tpairs_of th)) then raise Make "theorems may not contain tpairs" else ();
@@ -219,10 +219,10 @@
                     (n, vars, AbstractMachine.PConst (c, args@[pb]))
             end
 
-        fun thm2rule (encoding, hypset, shypset) th =
+        fun thm2rule (encoding, hyptable, shypset) th =
             let
                 val (ComputeThm (hyps, shyps, prop)) = th
-                val hypset = Termset.merge (hyps, hypset)
+                val hyptable = fold (fn h => Termtab.update (h, ())) hyps hyptable
                 val shypset = Sortset.merge (shyps, shypset)
                 val (prems, prop) = (Logic.strip_imp_prems prop, Logic.strip_imp_concl prop)
                 val (a, b) = Logic.dest_equals prop
@@ -269,15 +269,15 @@
                 fun rename_guard (AbstractMachine.Guard (a,b)) = 
                     AbstractMachine.Guard (rename 0 vars a, rename 0 vars b)
             in
-                ((encoding, hypset, shypset), (map rename_guard prems, pattern, rename 0 vars right))
+                ((encoding, hyptable, shypset), (map rename_guard prems, pattern, rename 0 vars right))
             end
 
-        val ((encoding, hypset, shypset), rules) =
-          fold_rev (fn th => fn (encoding_hypset, rules) =>
+        val ((encoding, hyptable, shypset), rules) =
+          fold_rev (fn th => fn (encoding_hyptable, rules) =>
             let
-              val (encoding_hypset, rule) = thm2rule encoding_hypset th
-            in (encoding_hypset, rule::rules) end)
-          ths ((encoding, Termset.empty, Sortset.empty), [])
+              val (encoding_hyptable, rule) = thm2rule encoding_hyptable th
+            in (encoding_hyptable, rule::rules) end)
+          ths ((encoding, Termtab.empty, Sortset.empty), [])
 
         fun make_cache_pattern t (encoding, cache_patterns) =
             let
@@ -287,7 +287,7 @@
                 (encoding, p::cache_patterns)
             end
         
-        val (encoding, _) = Termset.fold_rev make_cache_pattern cache_pattern_terms (encoding, [])
+        val (encoding, _) = fold_rev make_cache_pattern cache_pattern_terms (encoding, [])
 
         val prog = 
             case machine of 
@@ -300,18 +300,17 @@
 
         val shypset = Sortset.fold (fn s => has_witness s ? Sortset.remove s) shypset shypset
 
-    in (thy, encoding, hypset, shypset, prog, stamp, default_naming) end
+    in (thy, encoding, Termtab.keys hyptable, shypset, prog, stamp, default_naming) end
 
 fun make_with_cache machine thy cache_patterns raw_thms =
-  Computer (Unsynchronized.ref
-    (SOME (make_internal machine thy (Unsynchronized.ref ()) Encode.empty (Termset.make cache_patterns) raw_thms)))
+  Computer (Unsynchronized.ref (SOME (make_internal machine thy (Unsynchronized.ref ()) Encode.empty cache_patterns raw_thms)))
 
 fun make machine thy raw_thms = make_with_cache machine thy [] raw_thms
 
 fun update_with_cache computer cache_patterns raw_thms =
     let 
         val c = make_internal (machine_of_prog (prog_of computer)) (theory_of computer) (stamp_of computer) 
-                              (encoding_of computer) (Termset.make cache_patterns) raw_thms
+                              (encoding_of computer) cache_patterns raw_thms
         val _ = (ref_of computer) := SOME c     
     in
         ()
@@ -328,6 +327,13 @@
 (* An oracle for exporting theorems; must only be accessible from inside this structure! *)
 (* ------------------------------------------------------------------------------------- *)
 
+fun merge_hyps hyps1 hyps2 = 
+let
+    fun add hyps tab = fold (fn h => fn tab => Termtab.update (h, ()) tab) hyps tab
+in
+    Termtab.keys (add hyps2 (add hyps1 Termtab.empty))
+end
+
 val (_, export_oracle) = Context.>>> (Context.map_theory_result
   (Thm.add_oracle (\<^binding>\<open>compute\<close>, fn (thy, hyps, shyps, prop) =>
     let
@@ -368,7 +374,7 @@
         val t = infer_types naming encoding ty t
         val eq = Logic.mk_equals (t', t)
     in
-        export_thm thy (Termset.dest (hyps_of computer)) (shyps_of computer) eq
+        export_thm thy (hyps_of computer) (shyps_of computer) eq
     end
 
 (* --------- Simplify ------------ *)
@@ -376,7 +382,7 @@
 datatype prem = EqPrem of AbstractMachine.term * AbstractMachine.term * Term.typ * int 
               | Prem of AbstractMachine.term
 datatype theorem = Theorem of theory * unit Unsynchronized.ref * (int * typ) Symtab.table * (AbstractMachine.term option) Inttab.table  
-               * prem list * AbstractMachine.term * Termset.T * Sortset.T
+               * prem list * AbstractMachine.term * term list * Sortset.T
 
 
 exception ParamSimplify of computer * theorem
@@ -607,7 +613,7 @@
         let
             val th = update_varsubst varsubst th
             val th = update_prems (splicein prem_no (prems_of_theorem th') prems) th
-            val th = update_hyps (Termset.merge (hyps_of_theorem th, hyps_of_theorem th')) th
+            val th = update_hyps (merge_hyps (hyps_of_theorem th) (hyps_of_theorem th')) th
             val th = update_shyps (Sortset.merge (shyps_of_theorem th, shyps_of_theorem th')) th
         in
             update_theory thy th
@@ -624,10 +630,10 @@
     fun run t = infer (runprog (prog_of computer) (apply_subst true varsubst t))
     fun runprem p = run (prem2term p)
     val prop = Logic.list_implies (map runprem (prems_of_theorem th), run (concl_of_theorem th))
-    val hyps = Termset.merge (hyps_of computer, hyps_of_theorem th)
+    val hyps = merge_hyps (hyps_of computer) (hyps_of_theorem th)
     val shyps = Sortset.merge (shyps_of_theorem th, shyps_of computer)
 in
-    export_thm (theory_of_theorem th) (Termset.dest hyps) shyps prop
+    export_thm (theory_of_theorem th) hyps shyps prop
 end
 
 end