use Cache structure instead of passing tables around explicitly
authorkrauss
Tue, 05 Oct 2010 14:19:38 +0200
changeset 39923 0e1bd289c8ea
parent 39922 5a8aeeb2e63f
child 39924 f4d3e70ed3a8
use Cache structure instead of passing tables around explicitly
src/HOL/Tools/Function/scnp_reconstruct.ML
src/HOL/Tools/Function/termination.ML
--- a/src/HOL/Tools/Function/scnp_reconstruct.ML	Tue Oct 05 11:45:16 2010 +0200
+++ b/src/HOL/Tools/Function/scnp_reconstruct.ML	Tue Oct 05 14:19:38 2010 +0200
@@ -383,7 +383,7 @@
     val strategy = Then (Then first_round second_round) third_round
 
   in
-    TERMINATION ctxt (strategy err_cont err_cont)
+    TERMINATION ctxt autom_tac (strategy err_cont err_cont)
   end
 
 fun gen_sizechange_tac orders autom_tac ctxt err_cont =
--- a/src/HOL/Tools/Function/termination.ML	Tue Oct 05 11:45:16 2010 +0200
+++ b/src/HOL/Tools/Function/termination.ML	Tue Oct 05 14:19:38 2010 +0200
@@ -28,7 +28,7 @@
   (* Termination tactics. Sequential composition via continuations. (2nd argument is the error continuation) *)
   type ttac = (data -> int -> tactic) -> (data -> int -> tactic) -> data -> int -> tactic
 
-  val TERMINATION : Proof.context -> (data -> int -> tactic) -> int -> tactic
+  val TERMINATION : Proof.context -> tactic -> (data -> int -> tactic) -> int -> tactic
 
   val REPEAT : ttac -> ttac
 
@@ -126,15 +126,15 @@
   skel                            (* structure of the sum type encoding "program points" *)
   * (int -> typ)                  (* types of program points *)
   * (term list Inttab.table)      (* measures for program points *)
-  * (thm option Term2tab.table)   (* which calls form chains? *)
-  * (cell Term3tab.table)         (* local descents *)
+  * (term * term -> thm option)   (* which calls form chains? (cached) *)
+  * (term * (term * term) -> cell)(* local descents (cached) *)
 
 
 fun map_chains f (p, T, M, C, D) = (p, T, M, f C, D)
 fun map_descent f (p, T, M, C, D) = (p, T, M, C, f D)
 
-fun note_chain c1 c2 res = map_chains (Term2tab.update ((c1, c2), res))
-fun note_descent c m1 m2 res = map_descent (Term3tab.update ((c,(m1, m2)), res))
+fun note_chain c1 c2 res D = D     (*disabled*)
+fun note_descent c m1 m2 res D = D
 
 (* Build case expression *)
 fun mk_sumcases (sk, _, _, _, _) T fs =
@@ -155,33 +155,21 @@
     mk_skel (fold collect_pats cs [])
   end
 
-fun create ctxt T rel =
+fun prove_chain thy chain_tac (c1, c2) =
   let
-    val sk = mk_sum_skel rel
-    val Ts = node_types sk T
-    val M = Inttab.make (map_index (apsnd (MeasureFunctions.get_measure_functions ctxt)) Ts)
+    val goal =
+      HOLogic.mk_eq (HOLogic.mk_binop @{const_name Relation.rel_comp} (c1, c2),
+        Const (@{const_abbrev Set.empty}, fastype_of c1))
+      |> HOLogic.mk_Trueprop (* "C1 O C2 = {}" *)
   in
-    (sk, nth Ts, M, Term2tab.empty, Term3tab.empty)
+    case Function_Lib.try_proof (cterm_of thy goal) chain_tac of
+      Function_Lib.Solved thm => SOME thm
+    | _ => NONE
   end
 
-fun get_num_points (sk, _, _, _, _) =
+
+fun dest_call' sk (Const (@{const_name Collect}, _) $ Abs (_, _, c)) =
   let
-    fun num (SLeaf i) = i + 1
-      | num (SBranch (s, t)) = num t
-  in num sk end
-
-fun get_types (_, T, _, _, _) = T
-fun get_measures (_, _, M, _, _) = Inttab.lookup_list M
-
-fun get_chain (_, _, _, C, _) c1 c2 =
-  Term2tab.lookup C (c1, c2)
-
-fun get_descent (_, _, _, _, D) c m1 m2 =
-  Term3tab.lookup D (c, (m1, m2))
-
-fun dest_call D (Const (@{const_name Collect}, _) $ Abs (_, _, c)) =
-  let
-    val (sk, _, _, _, _) = D
     val vs = Term.strip_qnt_vars @{const_name Ex} c
 
     (* FIXME: throw error "dest_call" for malformed terms *)
@@ -192,8 +180,9 @@
   in
     (vs, p, l', q, r', Gam)
   end
-  | dest_call D t = error "dest_call"
+  | dest_call' _ _ = error "dest_call"
 
+fun dest_call (sk, _, _, _, _) = dest_call' sk
 
 fun mk_desc thy tac vs Gam l r m1 m2 =
   let
@@ -216,15 +205,43 @@
      | _ => raise Match
 end
 
-fun derive_descent thy tac c m1 m2 D =
-  case get_descent D c m1 m2 of
-    SOME _ => D
-  | NONE => 
-    let
-      val (vs, _, l, _, r, Gam) = dest_call D c
-    in 
-      note_descent c m1 m2 (mk_desc thy tac vs Gam l r m1 m2) D
-    end
+fun prove_descent thy tac sk (c, (m1, m2)) =
+  let
+    val (vs, _, l, _, r, Gam) = dest_call' sk c
+  in 
+    mk_desc thy tac vs Gam l r m1 m2
+  end
+
+fun create ctxt chain_tac descent_tac T rel =
+  let
+    val thy = ProofContext.theory_of ctxt
+    val sk = mk_sum_skel rel
+    val Ts = node_types sk T
+    val M = Inttab.make (map_index (apsnd (MeasureFunctions.get_measure_functions ctxt)) Ts)
+    val chain_cache = Cache.create Term2tab.empty Term2tab.lookup Term2tab.update
+      (prove_chain thy chain_tac)
+    val descent_cache = Cache.create Term3tab.empty Term3tab.lookup Term3tab.update
+      (prove_descent thy descent_tac sk)
+  in
+    (sk, nth Ts, M, chain_cache, descent_cache)
+  end
+
+fun get_num_points (sk, _, _, _, _) =
+  let
+    fun num (SLeaf i) = i + 1
+      | num (SBranch (s, t)) = num t
+  in num sk end
+
+fun get_types (_, T, _, _, _) = T
+fun get_measures (_, _, M, _, _) = Inttab.lookup_list M
+
+fun get_chain (_, _, _, C, _) c1 c2 =
+  SOME (C (c1, c2))
+
+fun get_descent (_, _, _, _, D) c m1 m2 =
+  SOME (D (c, (m1, m2)))
+
+fun derive_descent thy tac c m1 m2 D = D (* disabled *)
 
 fun CALLS tac i st =
   if Thm.no_prems st then all_tac st
@@ -234,12 +251,12 @@
 
 type ttac = (data -> int -> tactic) -> (data -> int -> tactic) -> data -> int -> tactic
 
-fun TERMINATION ctxt tac =
+fun TERMINATION ctxt atac tac =
   SUBGOAL (fn (_ $ (Const (@{const_name wf}, wfT) $ rel), i) =>
   let
     val (T, _) = HOLogic.dest_prodT (HOLogic.dest_setT (domain_type wfT))
   in
-    tac (create ctxt T rel) i
+    tac (create ctxt atac atac T rel) i
   end)
 
 
@@ -315,19 +332,8 @@
 
 (*** DEPENDENCY GRAPHS ***)
 
-fun prove_chain thy chain_tac c1 c2 =
-  let
-    val goal =
-      HOLogic.mk_eq (HOLogic.mk_binop @{const_name Relation.rel_comp} (c1, c2),
-        Const (@{const_abbrev Set.empty}, fastype_of c1))
-      |> HOLogic.mk_Trueprop (* "C1 O C2 = {}" *)
-  in
-    case Function_Lib.try_proof (cterm_of thy goal) chain_tac of
-      Function_Lib.Solved thm => SOME thm
-    | _ => NONE
-  end
-
-fun derive_chains ctxt chain_tac cont D = CALLS (fn (cs, i) =>
+fun derive_chains ctxt chain_tac cont = cont
+(* fn D => CALLS (fn (cs, i) =>
   let
     val thy = ProofContext.theory_of ctxt
 
@@ -337,7 +343,7 @@
   in
     cont (fold_product derive_chain cs cs D) i
   end)
-
+*)
 
 fun mk_dgraph D cs =
   Term_Graph.empty
@@ -392,7 +398,9 @@
 
 (*** Local Descent Proofs ***)
 
-fun gen_descent diag ctxt tac cont D = CALLS (fn (cs, i) =>
+fun gen_descent diag ctxt tac cont = cont
+(*
+  fn D => CALLS (fn (cs, i) =>
   let
     val thy = ProofContext.theory_of ctxt
     val measures_of = get_measures D
@@ -409,6 +417,7 @@
   in
     cont (Function_Common.PROFILE "deriving descents" (fold derive cs) D) i
   end)
+*)
 
 fun derive_diag ctxt = gen_descent true ctxt
 fun derive_all ctxt = gen_descent false ctxt