context dependent components;
authorwenzelm
Sun, 11 Jul 2004 20:35:50 +0200
changeset 15036 cab1c1fc1851
parent 15035 8c57751cd43f
child 15037 19b3b0382303
context dependent components;
src/Provers/classical.ML
src/Provers/simplifier.ML
--- a/src/Provers/classical.ML	Sun Jul 11 20:35:23 2004 +0200
+++ b/src/Provers/classical.ML	Sun Jul 11 20:35:50 2004 +0200
@@ -80,6 +80,7 @@
   val CLASET': (claset -> 'a -> tactic) -> 'a -> tactic
   val claset: unit -> claset
   val claset_ref: unit -> claset ref
+  val local_claset_of   : Proof.context -> claset
 
   val fast_tac          : claset -> int -> tactic
   val slow_tac          : claset -> int -> tactic
@@ -136,6 +137,10 @@
 signature CLASSICAL =
 sig
   include BASIC_CLASSICAL
+  val add_context_safe_wrapper: string * (Proof.context -> wrapper) -> theory -> theory
+  val del_context_safe_wrapper: string -> theory -> theory
+  val add_context_unsafe_wrapper: string * (Proof.context -> wrapper) -> theory -> theory
+  val del_context_unsafe_wrapper: string -> theory -> theory
   val print_local_claset: Proof.context -> unit
   val get_local_claset: Proof.context -> claset
   val put_local_claset: claset -> Proof.context -> Proof.context
@@ -256,12 +261,14 @@
      dup_netpair   = empty_netpair,
      xtra_netpair  = empty_netpair};
 
-fun print_cs (CS {safeIs, safeEs, hazIs, hazEs, ...}) =
+fun print_cs (CS {safeIs, safeEs, hazIs, hazEs, swrappers, uwrappers, ...}) =
   let val pretty_thms = map Display.pretty_thm in
     [Pretty.big_list "safe introduction rules (intro!):" (pretty_thms safeIs),
       Pretty.big_list "introduction rules (intro):" (pretty_thms hazIs),
       Pretty.big_list "safe elimination rules (elim!):" (pretty_thms safeEs),
-      Pretty.big_list "elimination rules (elim):" (pretty_thms hazEs)]
+      Pretty.big_list "elimination rules (elim):" (pretty_thms hazEs),
+      Pretty.strs ("safe wrappers:" :: map #1 swrappers),
+      Pretty.strs ("unsafe wrappers:" :: map #1 uwrappers)]
     |> Pretty.chunks |> Pretty.writeln
   end;
 
@@ -565,15 +572,17 @@
 
 (*Remove a safe wrapper*)
 fun cs delSWrapper name = update_swrappers cs (fn swrappers =>
-    let val (del,rest) = partition (fn (n,_) => n=name) swrappers
-    in if null del then (warning ("No such safe wrapper in claset: "^ name);
-                         swrappers) else rest end);
+  let val swrappers' = filter_out (equal name o #1) swrappers in
+    if length swrappers <> length swrappers' then swrappers'
+    else (warning ("No such safe wrapper in claset: "^ name); swrappers)
+  end);
 
 (*Remove an unsafe wrapper*)
 fun cs delWrapper name = update_uwrappers cs (fn uwrappers =>
-    let val (del,rest) = partition (fn (n,_) => n=name) uwrappers
-    in if null del then (warning ("No such unsafe wrapper in claset: " ^ name);
-                         uwrappers) else rest end);
+  let val uwrappers' = filter_out (equal name o #1) uwrappers in
+    if length uwrappers <> length uwrappers' then uwrappers'
+    else (warning ("No such unsafe wrapper in claset: " ^ name); uwrappers)
+  end);
 
 (* compose a safe tactic alternatively before/after safe_step_tac *)
 fun cs addSbefore  (name,    tac1) =
@@ -772,6 +781,36 @@
 
 
 
+(** context dependent claset components **)
+
+datatype context_cs = ContextCS of
+ {swrappers: (string * (Proof.context -> wrapper)) list,
+  uwrappers: (string * (Proof.context -> wrapper)) list};
+
+fun context_cs ctxt cs (ContextCS {swrappers, uwrappers}) =
+  let
+    fun add_wrapper add (name, f) claset = add (claset, (name, f ctxt));
+  in
+    cs |> fold_rev (add_wrapper (op addSWrapper)) swrappers
+    |> fold_rev (add_wrapper (op addWrapper)) uwrappers
+  end;
+
+fun make_context_cs (swrappers, uwrappers) =
+  ContextCS {swrappers = swrappers, uwrappers = uwrappers};
+
+val empty_context_cs = make_context_cs ([], []);
+
+fun merge_context_cs (ctxt_cs1, ctxt_cs2) =
+  let
+    val ContextCS {swrappers = swrappers1, uwrappers = uwrappers1} = ctxt_cs1;
+    val ContextCS {swrappers = swrappers2, uwrappers = uwrappers2} = ctxt_cs2;
+
+    val swrappers' = merge_alists swrappers1 swrappers2;
+    val uwrappers' = merge_alists uwrappers1 uwrappers2;
+  in make_context_cs (swrappers', uwrappers') end;
+
+
+
 (** claset theory data **)
 
 (* theory data kind 'Provers/claset' *)
@@ -779,19 +818,24 @@
 structure GlobalClasetArgs =
 struct
   val name = "Provers/claset";
-  type T = claset ref;
+  type T = claset ref * context_cs;
 
-  val empty = ref empty_cs;
-  fun copy (ref cs) = (ref cs): T;            (*create new reference!*)
+  val empty = (ref empty_cs, empty_context_cs);
+  fun copy (ref cs, ctxt_cs) = (ref cs, ctxt_cs): T;            (*create new reference!*)
   val prep_ext = copy;
-  fun merge (ref cs1, ref cs2) = ref (merge_cs (cs1, cs2));
-  fun print _ (ref cs) = print_cs cs;
+  fun merge ((ref cs1, ctxt_cs1), (ref cs2, ctxt_cs2)) =
+    (ref (merge_cs (cs1, cs2)), merge_context_cs (ctxt_cs1, ctxt_cs2));
+  fun print _ (ref cs, _) = print_cs cs;
 end;
 
 structure GlobalClaset = TheoryDataFun(GlobalClasetArgs);
 val print_claset = GlobalClaset.print;
-val claset_ref_of_sg = GlobalClaset.get_sg;
-val claset_ref_of = GlobalClaset.get;
+val claset_ref_of_sg = #1 o GlobalClaset.get_sg;
+val claset_ref_of = #1 o GlobalClaset.get;
+val get_context_cs = #2 o GlobalClaset.get o ProofContext.theory_of;
+
+fun map_context_cs f = GlobalClaset.map (apsnd
+  (fn ContextCS {swrappers, uwrappers} => make_context_cs (f (swrappers, uwrappers))));
 
 
 (* access claset *)
@@ -819,6 +863,15 @@
 val Delrules = change_claset (op delrules);
 
 
+(* context dependent components *)
+
+fun add_context_safe_wrapper wrapper = map_context_cs (apfst (merge_alists [wrapper]));
+fun del_context_safe_wrapper name = map_context_cs (apfst (filter_out (equal name o #1)));
+
+fun add_context_unsafe_wrapper wrapper = map_context_cs (apsnd (merge_alists [wrapper]));
+fun del_context_unsafe_wrapper name = map_context_cs (apsnd (filter_out (equal name o #1)));
+
+
 (* proof data kind 'Provers/claset' *)
 
 structure LocalClasetArgs =
@@ -826,7 +879,7 @@
   val name = "Provers/claset";
   type T = claset;
   val init = claset_of;
-  fun print _ cs = print_cs cs;
+  fun print ctxt cs = print_cs (context_cs ctxt cs (get_context_cs ctxt));
 end;
 
 structure LocalClaset = ProofDataFun(LocalClasetArgs);
@@ -834,6 +887,9 @@
 val get_local_claset = LocalClaset.get;
 val put_local_claset = LocalClaset.put;
 
+fun local_claset_of ctxt =
+  context_cs ctxt (get_local_claset ctxt) (get_context_cs ctxt);
+
 
 (* attributes *)
 
@@ -925,10 +981,10 @@
 (** proof methods **)
 
 fun METHOD_CLASET tac ctxt =
-  Method.METHOD (tac ctxt (get_local_claset ctxt));
+  Method.METHOD (tac ctxt (local_claset_of ctxt));
 
 fun METHOD_CLASET' tac ctxt =
-  Method.METHOD (HEADGOAL o tac ctxt (get_local_claset ctxt));
+  Method.METHOD (HEADGOAL o tac ctxt (local_claset_of ctxt));
 
 
 local
@@ -974,10 +1030,10 @@
   Args.del -- Args.colon >> K (I, rule_del_local)];
 
 fun cla_meth tac prems ctxt = Method.METHOD (fn facts =>
-  ALLGOALS (Method.insert_tac (prems @ facts)) THEN tac (get_local_claset ctxt));
+  ALLGOALS (Method.insert_tac (prems @ facts)) THEN tac (local_claset_of ctxt));
 
 fun cla_meth' tac prems ctxt = Method.METHOD (fn facts =>
-  HEADGOAL (Method.insert_tac (prems @ facts) THEN' tac (get_local_claset ctxt)));
+  HEADGOAL (Method.insert_tac (prems @ facts) THEN' tac (local_claset_of ctxt)));
 
 val cla_method = Method.bang_sectioned_args cla_modifiers o cla_meth;
 val cla_method' = Method.bang_sectioned_args cla_modifiers o cla_meth';
--- a/src/Provers/simplifier.ML	Sun Jul 11 20:35:23 2004 +0200
+++ b/src/Provers/simplifier.ML	Sun Jul 11 20:35:50 2004 +0200
@@ -9,10 +9,12 @@
 signature BASIC_SIMPLIFIER =
 sig
   include BASIC_META_SIMPLIFIER
-  val simproc_i: Sign.sg -> string -> term list
-    -> (Sign.sg -> simpset -> term -> thm option) -> simproc
-  val simproc: Sign.sg -> string -> string list
-    -> (Sign.sg -> simpset -> term -> thm option) -> simproc
+  type context_solver
+  val mk_context_solver: string -> (Proof.context -> thm list -> int -> tactic)
+    -> context_solver
+  type context_simproc
+  val mk_context_simproc: string -> cterm list ->
+    (Proof.context -> simpset -> term -> thm option) -> context_simproc
   val print_simpset: theory -> unit
   val simpset_ref_of_sg: Sign.sg -> simpset ref
   val simpset_ref_of: theory -> simpset ref
@@ -28,6 +30,7 @@
   val Delsimprocs: simproc list -> unit
   val Addcongs: thm list -> unit
   val Delcongs: thm list -> unit
+  val local_simpset_of: Proof.context -> simpset
   val safe_asm_full_simp_tac: simpset -> int -> tactic
   val               simp_tac: simpset -> int -> tactic
   val           asm_simp_tac: simpset -> int -> tactic
@@ -49,11 +52,28 @@
 signature SIMPLIFIER =
 sig
   include BASIC_SIMPLIFIER
+  val simproc_i: Sign.sg -> string -> term list
+    -> (Sign.sg -> simpset -> term -> thm option) -> simproc
+  val simproc: Sign.sg -> string -> string list
+    -> (Sign.sg -> simpset -> term -> thm option) -> simproc
+  val context_simproc_i: Sign.sg -> string -> term list
+    -> (Proof.context -> simpset -> term -> thm option) -> context_simproc
+  val context_simproc: Sign.sg -> string -> string list
+    -> (Proof.context -> simpset -> term -> thm option) -> context_simproc
   val          rewrite: simpset -> cterm -> thm
   val      asm_rewrite: simpset -> cterm -> thm
   val     full_rewrite: simpset -> cterm -> thm
   val   asm_lr_rewrite: simpset -> cterm -> thm
   val asm_full_rewrite: simpset -> cterm -> thm
+  val add_context_simprocs: context_simproc list -> theory -> theory
+  val del_context_simprocs: context_simproc list -> theory -> theory
+  val set_context_subgoaler: (Proof.context -> simpset -> int -> tactic) -> theory -> theory
+  val reset_context_subgoaler: theory -> theory
+  val add_context_looper: string * (Proof.context -> int -> Tactical.tactic) ->
+    theory -> theory
+  val del_context_looper: string -> theory -> theory
+  val add_context_unsafe_solver: context_solver -> theory -> theory
+  val add_context_safe_solver: context_solver -> theory -> theory
   val print_local_simpset: Proof.context -> unit
   val get_local_simpset: Proof.context -> simpset
   val put_local_simpset: simpset -> Proof.context -> Proof.context
@@ -81,6 +101,81 @@
 open MetaSimplifier;
 
 
+(** context dependent simpset components **)
+
+(* datatype context_solver *)
+
+datatype context_solver =
+  ContextSolver of (string * (Proof.context -> thm list -> int -> tactic)) * stamp;
+
+fun mk_context_solver name f = ContextSolver ((name, f), stamp ());
+fun eq_context_solver (ContextSolver (_, id1), ContextSolver (_, id2)) = (id1 = id2);
+val merge_context_solvers = gen_merge_lists eq_context_solver;
+
+
+(* datatype context_simproc *)
+
+datatype context_simproc = ContextSimproc of
+  (string * cterm list * (Proof.context -> simpset -> term -> thm option)) * stamp;
+
+fun mk_context_simproc name lhss f = ContextSimproc ((name, lhss, f), stamp ());
+fun eq_context_simproc (ContextSimproc (_, id1), ContextSimproc (_, id2)) = (id1 = id2);
+val merge_context_simprocs = gen_merge_lists eq_context_simproc;
+
+fun context_simproc_i sg name =
+  mk_context_simproc name o map (Thm.cterm_of sg o Logic.varify);
+
+fun context_simproc sg name =
+  context_simproc_i sg name o map (Sign.simple_read_term sg TypeInfer.logicT);
+
+
+(* datatype context_ss *)
+
+datatype context_ss = ContextSS of
+ {simprocs: context_simproc list,
+  subgoal_tac: (Proof.context -> simpset -> int -> tactic) option,
+  loop_tacs: (string * (Proof.context -> int -> tactic)) list,
+  unsafe_solvers: context_solver list,
+  solvers: context_solver list};
+
+fun context_ss ctxt ss ctxt_ss =
+  let
+    val ContextSS {simprocs, subgoal_tac, loop_tacs, unsafe_solvers, solvers} = ctxt_ss;
+    fun prep_simproc (ContextSimproc ((name, lhss, f), _)) =
+      mk_simproc name lhss (K (f ctxt));
+    fun add_loop (name, f) simpset = simpset addloop (name, f ctxt);
+    fun add_solver add (ContextSolver ((name, f), _)) simpset =
+      add (simpset, mk_solver name (f ctxt));
+  in
+    ((case subgoal_tac of None => ss | Some tac => ss setsubgoaler tac ctxt)
+      addsimprocs map prep_simproc simprocs)
+    |> fold_rev add_loop loop_tacs
+    |> fold_rev (add_solver (op addSolver)) unsafe_solvers
+    |> fold_rev (add_solver (op addSSolver)) solvers
+  end;
+
+fun make_context_ss (simprocs, subgoal_tac, loop_tacs, unsafe_solvers, solvers) =
+  ContextSS {simprocs = simprocs, subgoal_tac = subgoal_tac, loop_tacs = loop_tacs,
+    unsafe_solvers = unsafe_solvers, solvers = solvers};
+
+val empty_context_ss = make_context_ss ([], None, [], [], []);
+
+fun merge_context_ss (ctxt_ss1, ctxt_ss2) =
+  let
+    val ContextSS {simprocs = simprocs1, subgoal_tac = subgoal_tac1, loop_tacs = loop_tacs1,
+      unsafe_solvers = unsafe_solvers1, solvers = solvers1} = ctxt_ss1;
+    val ContextSS {simprocs = simprocs2, subgoal_tac = subgoal_tac2, loop_tacs = loop_tacs2,
+      unsafe_solvers = unsafe_solvers2, solvers = solvers2} = ctxt_ss2;
+
+    val simprocs' = merge_context_simprocs simprocs1 simprocs2;
+    val subgoal_tac' = (case subgoal_tac1 of None => subgoal_tac2 | some => some);
+    val loop_tacs' = merge_alists loop_tacs1 loop_tacs2;
+    val unsafe_solvers' = merge_context_solvers unsafe_solvers1 unsafe_solvers2;
+    val solvers' = merge_context_solvers solvers1 solvers2;
+  in make_context_ss (simprocs', subgoal_tac', loop_tacs', unsafe_solvers', solvers') end;
+
+
+
 (** global and local simpset data **)
 
 (* theory data kind 'Provers/simpset' *)
@@ -88,19 +183,25 @@
 structure GlobalSimpsetArgs =
 struct
   val name = "Provers/simpset";
-  type T = simpset ref;
+  type T = simpset ref * context_ss;
 
-  val empty = ref empty_ss;
-  fun copy (ref ss) = (ref ss): T;            (*create new reference!*)
+  val empty = (ref empty_ss, empty_context_ss);
+  fun copy (ref ss, ctxt_ss) = (ref ss, ctxt_ss): T;            (*create new reference!*)
   val prep_ext = copy;
-  fun merge (ref ss1, ref ss2) = ref (merge_ss (ss1, ss2));
-  fun print _ (ref ss) = print_ss ss;
+  fun merge ((ref ss1, ctxt_ss1), (ref ss2, ctxt_ss2)) =
+    (ref (merge_ss (ss1, ss2)), merge_context_ss (ctxt_ss1, ctxt_ss2));
+  fun print _ (ref ss, _) = print_ss ss;
 end;
 
 structure GlobalSimpset = TheoryDataFun(GlobalSimpsetArgs);
 val print_simpset = GlobalSimpset.print;
-val simpset_ref_of_sg = GlobalSimpset.get_sg;
-val simpset_ref_of = GlobalSimpset.get;
+val simpset_ref_of_sg = #1 o GlobalSimpset.get_sg;
+val simpset_ref_of = #1 o GlobalSimpset.get;
+val get_context_ss = #2 o GlobalSimpset.get o ProofContext.theory_of;
+
+fun map_context_ss f = GlobalSimpset.map (apsnd
+  (fn ContextSS {simprocs, subgoal_tac, loop_tacs, unsafe_solvers, solvers} =>
+    make_context_ss (f (simprocs, subgoal_tac, loop_tacs, unsafe_solvers, solvers))));
 
 
 (* access global simpset *)
@@ -131,6 +232,47 @@
 val Delcongs = change_simpset (op delcongs);
 
 
+(* change context dependent components *)
+
+fun add_context_simprocs procs =
+  map_context_ss (fn (simprocs, subgoal_tac, loop_tacs, unsafe_solvers, solvers) =>
+    (merge_context_simprocs procs simprocs, subgoal_tac, loop_tacs,
+      unsafe_solvers, solvers));
+
+fun del_context_simprocs procs =
+  map_context_ss (fn (simprocs, subgoal_tac, loop_tacs, unsafe_solvers, solvers) =>
+    (gen_rems eq_context_simproc (simprocs, procs), subgoal_tac, loop_tacs,
+      unsafe_solvers, solvers));
+
+fun set_context_subgoaler tac =
+  map_context_ss (fn (simprocs, _, loop_tacs, unsafe_solvers, solvers) =>
+    (simprocs, Some tac, loop_tacs, unsafe_solvers, solvers));
+
+val reset_context_subgoaler =
+  map_context_ss (fn (simprocs, _, loop_tacs, unsafe_solvers, solvers) =>
+    (simprocs, None, loop_tacs, unsafe_solvers, solvers));
+
+fun add_context_looper (name, tac) =
+  map_context_ss (fn (simprocs, subgoal_tac, loop_tacs, unsafe_solvers, solvers) =>
+    (simprocs, subgoal_tac, merge_alists [(name, tac)] loop_tacs,
+      unsafe_solvers, solvers));
+
+fun del_context_looper name =
+  map_context_ss (fn (simprocs, subgoal_tac, loop_tacs, unsafe_solvers, solvers) =>
+    (simprocs, subgoal_tac, filter_out (equal name o #1) loop_tacs,
+      unsafe_solvers, solvers));
+
+fun add_context_unsafe_solver solver =
+  map_context_ss (fn (simprocs, subgoal_tac, loop_tacs, unsafe_solvers, solvers) =>
+    (simprocs, subgoal_tac, loop_tacs,
+      merge_context_solvers [solver] unsafe_solvers, solvers));
+
+fun add_context_safe_solver solver =
+  map_context_ss (fn (simprocs, subgoal_tac, loop_tacs, unsafe_solvers, solvers) =>
+    (simprocs, subgoal_tac, loop_tacs, unsafe_solvers,
+      merge_context_solvers [solver] solvers));
+
+
 (* proof data kind 'Provers/simpset' *)
 
 structure LocalSimpsetArgs =
@@ -138,7 +280,7 @@
   val name = "Provers/simpset";
   type T = simpset;
   val init = simpset_of;
-  fun print _ ss = print_ss ss;
+  fun print ctxt ss = print_ss (context_ss ctxt ss (get_context_ss ctxt));
 end;
 
 structure LocalSimpset = ProofDataFun(LocalSimpsetArgs);
@@ -147,6 +289,9 @@
 val put_local_simpset = LocalSimpset.put;
 fun map_local_simpset f ctxt = put_local_simpset (f (get_local_simpset ctxt)) ctxt;
 
+fun local_simpset_of ctxt =
+  context_ss ctxt (get_local_simpset ctxt) (get_context_ss ctxt);
+
 
 (* attributes *)
 
@@ -240,7 +385,7 @@
 
 val simplified_attr =
  (simplified_att simpset_of Attrib.global_thmss,
-  simplified_att get_local_simpset Attrib.local_thmss);
+  simplified_att local_simpset_of Attrib.local_thmss);
 
 end;
 
@@ -289,11 +434,11 @@
 
 fun simp_method (prems, tac) ctxt = Method.METHOD (fn facts =>
   ALLGOALS (Method.insert_tac (prems @ facts)) THEN
-    (CHANGED_PROP o ALLGOALS o tac) (get_local_simpset ctxt));
+    (CHANGED_PROP o ALLGOALS o tac) (local_simpset_of ctxt));
 
 fun simp_method' (prems, tac) ctxt = Method.METHOD (fn facts =>
   HEADGOAL (Method.insert_tac (prems @ facts) THEN'
-      (CHANGED_PROP oo tac) (get_local_simpset ctxt)));
+      (CHANGED_PROP oo tac) (local_simpset_of ctxt)));
 
 
 (* setup_methods *)