recdef hints (attributes and modifiers);
authorwenzelm
Tue, 05 Sep 2000 18:49:26 +0200
changeset 9859 2cd338998b53
parent 9858 c3ac6128b649
child 9860 5c5efed691b9
recdef hints (attributes and modifiers);
src/HOL/Tools/recdef_package.ML
--- a/src/HOL/Tools/recdef_package.ML	Tue Sep 05 18:49:02 2000 +0200
+++ b/src/HOL/Tools/recdef_package.ML	Tue Sep 05 18:49:26 2000 +0200
@@ -11,12 +11,26 @@
   val print_recdefs: theory -> unit
   val get_recdef: theory -> string
     -> {simps: thm list, rules: thm list list, induct: thm, tcs: term list} option
-  val add_recdef: xstring -> string -> ((bstring * string) * Args.src list) list
-    -> simpset option -> (xstring * Args.src list) list -> theory
-    -> theory * {simps: thm list, rules: thm list list, induct: thm, tcs: term list}
-  val add_recdef_i: xstring -> term -> ((bstring * term) * theory attribute list) list
-    -> simpset option -> (thm * theory attribute list) list
-    -> theory -> theory * {simps: thm list, rules: thm list list, induct: thm, tcs: term list}
+  val simp_add_global: theory attribute
+  val simp_del_global: theory attribute
+  val cong_add_global: theory attribute
+  val cong_del_global: theory attribute
+  val wf_add_global: theory attribute
+  val wf_del_global: theory attribute
+  val simp_add_local: Proof.context attribute
+  val simp_del_local: Proof.context attribute
+  val cong_add_local: Proof.context attribute
+  val cong_del_local: Proof.context attribute
+  val wf_add_local: Proof.context attribute
+  val wf_del_local: Proof.context attribute
+  val add_recdef: xstring -> string -> ((bstring * string) * Args.src list) list ->
+    Args.src option -> theory -> theory
+      * {simps: thm list, rules: thm list list, induct: thm, tcs: term list}
+  val add_recdef_i: xstring -> term -> ((bstring * term) * theory attribute list) list ->
+    theory -> theory * {simps: thm list, rules: thm list list, induct: thm, tcs: term list}
+  val add_recdef_old: xstring -> string -> ((bstring * string) * Args.src list) list ->
+    simpset * thm list -> theory ->
+    theory * {simps: thm list, rules: thm list list, induct: thm, tcs: term list}
   val defer_recdef: xstring -> string list -> (xstring * Args.src list) list
     -> theory -> theory * {induct_rules: thm}
   val defer_recdef_i: xstring -> term list -> (thm * theory attribute list) list
@@ -27,45 +41,188 @@
 structure RecdefPackage: RECDEF_PACKAGE =
 struct
 
+
 val quiet_mode = Tfl.quiet_mode;
 val message = Tfl.message;
 
 
+(** recdef hints **)
 
-(** theory data **)
+(* type hints *)
+
+type hints = {simps: thm list, congs: (string * thm) list, wfs: thm list};
+
+fun mk_hints (simps, congs, wfs) = {simps = simps, congs = congs, wfs = wfs}: hints;
+fun map_hints f ({simps, congs, wfs}: hints) = mk_hints (f (simps, congs, wfs));
+
+fun map_simps f = map_hints (fn (simps, congs, wfs) => (f simps, congs, wfs));
+fun map_congs f = map_hints (fn (simps, congs, wfs) => (simps, f congs, wfs));
+fun map_wfs f = map_hints (fn (simps, congs, wfs) => (simps, congs, f wfs));
+
+fun pretty_hints ({simps, congs, wfs}: hints) =
+ [Pretty.big_list "recdef simp hints:" (map Display.pretty_thm simps),
+  Pretty.big_list "recdef cong hints:" (map Display.pretty_thm (map #2 congs)),
+  Pretty.big_list "recdef wf hints:" (map Display.pretty_thm wfs)];
+
+
+(* congruence rules *)
+
+local
+
+val cong_head =
+  fst o Term.dest_Const o Term.head_of o fst o Logic.dest_equals o Thm.concl_of;
 
-(* data kind 'HOL/recdef' *)
+fun prep_cong raw_thm =
+  let val thm = safe_mk_meta_eq raw_thm in (cong_head thm, thm) end;
+
+in
+
+fun add_cong raw_thm congs =
+  let val (c, thm) = prep_cong raw_thm
+  in overwrite_warn (congs, (c, thm)) ("Overwriting recdef congruence rule for " ^ quote c) end;
+
+fun del_cong raw_thm congs =
+  let
+    val (c, thm) = prep_cong raw_thm;
+    val (del, rest) = Library.partition (Library.equal c o fst) congs;
+  in if null del then (warning ("No recdef congruence rule for " ^ quote c); congs) else rest end;
+
+val add_congs = curry (foldr (uncurry add_cong));
+
+end;
+
+
+
+(** global and local recdef data **)
+
+(* theory data kind 'HOL/recdef' *)
 
 type recdef_info = {simps: thm list, rules: thm list list, induct: thm, tcs: term list};
 
-structure RecdefArgs =
+structure GlobalRecdefArgs =
 struct
   val name = "HOL/recdef";
-  type T = recdef_info Symtab.table;
+  type T = recdef_info Symtab.table * hints;
 
-  val empty = Symtab.empty;
+  val empty = (Symtab.empty, mk_hints ([], [], []));
   val copy = I;
   val prep_ext = I;
-  val merge: T * T -> T = Symtab.merge (K true);
+  fun merge
+   ((tab1, {simps = simps1, congs = congs1, wfs = wfs1}),
+    (tab2, {simps = simps2, congs = congs2, wfs = wfs2})) =
+      (Symtab.merge (K true) (tab1, tab2),
+        mk_hints (Drule.merge_rules (simps1, simps2),
+          Library.merge_alists congs1 congs2,
+          Drule.merge_rules (wfs1, wfs2)));
 
-  fun print sg tab =
-    Pretty.writeln (Pretty.strs ("recdefs:" ::
-      map #1 (Sign.cond_extern_table sg Sign.constK tab)));
+  fun print sg (tab, hints) =
+   (Pretty.strs ("recdefs:" :: map #1 (Sign.cond_extern_table sg Sign.constK tab)) ::
+     pretty_hints hints) |> Pretty.chunks |> Pretty.writeln;
 end;
 
-structure RecdefData = TheoryDataFun(RecdefArgs);
-val print_recdefs = RecdefData.print;
+structure GlobalRecdefData = TheoryDataFun(GlobalRecdefArgs);
+val print_recdefs = GlobalRecdefData.print;
 
 
-(* get and put data *)
-
-fun get_recdef thy name = Symtab.lookup (RecdefData.get thy, name);
+fun get_recdef thy name = Symtab.lookup (#1 (GlobalRecdefData.get thy), name);
 
 fun put_recdef name info thy =
   let
-    val tab = Symtab.update_new ((name, info), RecdefData.get thy)
+    val (tab, hints) = GlobalRecdefData.get thy;
+    val tab' = Symtab.update_new ((name, info), tab)
       handle Symtab.DUP _ => error ("Duplicate recursive function definition " ^ quote name);
-  in RecdefData.put tab thy end;
+  in GlobalRecdefData.put (tab', hints) thy end;
+
+val get_global_hints = #2 o GlobalRecdefData.get;
+val map_global_hints = GlobalRecdefData.map o apsnd;
+
+
+(* proof data kind 'HOL/recdef' *)
+
+structure LocalRecdefArgs =
+struct
+  val name = "HOL/recdef";
+  type T = hints;
+  val init = get_global_hints;
+  fun print _ hints = pretty_hints hints |> Pretty.chunks |> Pretty.writeln;
+end;
+
+structure LocalRecdefData = ProofDataFun(LocalRecdefArgs);
+val get_local_hints = LocalRecdefData.get;
+val map_local_hints = LocalRecdefData.map;
+
+
+(* attributes *)
+
+local
+
+fun global_local f g =
+ (fn (thy, thm) => (map_global_hints (f (g thm)) thy, thm),
+  fn (ctxt, thm) => (map_local_hints (f (g thm)) ctxt, thm));
+
+fun mk_attr (add1, add2) (del1, del2) =
+ (Attrib.add_del_args add1 del1, Attrib.add_del_args add2 del2);
+
+in
+
+val (simp_add_global, simp_add_local) = global_local map_simps (Drule.add_rules o single);
+val (simp_del_global, simp_del_local) = global_local map_simps (Drule.del_rules o single);
+val (cong_add_global, cong_add_local) = global_local map_congs add_cong;
+val (cong_del_global, cong_del_local) = global_local map_congs del_cong;
+val (wf_add_global, wf_add_local) = global_local map_wfs (Drule.add_rules o single);
+val (wf_del_global, wf_del_local) = global_local map_wfs (Drule.del_rules o single);
+
+val simp_attr = mk_attr (simp_add_global, simp_add_local) (simp_del_global, simp_del_local);
+val cong_attr = mk_attr (cong_add_global, cong_add_local) (cong_del_global, cong_del_local);
+val wf_attr = mk_attr (wf_add_global, wf_add_local) (wf_del_global, wf_del_local);
+
+end;
+
+
+
+(** prepare_hints(_i) **)
+
+local
+
+val simpN = "simp";
+val congN = "cong";
+val wfN = "wf";
+val addN = "add";
+val delN = "del";
+
+val recdef_modifiers =
+ [Args.$$$ simpN -- Args.colon >> K ((I, simp_add_local): Method.modifier),
+  Args.$$$ simpN -- Args.$$$ addN -- Args.colon >> K (I, simp_add_local),
+  Args.$$$ simpN -- Args.$$$ delN -- Args.colon >> K (I, simp_del_local),
+  Args.$$$ congN -- Args.colon >> K (I, cong_add_local),
+  Args.$$$ congN -- Args.$$$ addN -- Args.colon >> K (I, cong_add_local),
+  Args.$$$ congN -- Args.$$$ delN -- Args.colon >> K (I, cong_del_local),
+  Args.$$$ wfN -- Args.colon >> K (I, wf_add_local),
+  Args.$$$ wfN -- Args.$$$ addN -- Args.colon >> K (I, wf_add_local),
+  Args.$$$ wfN -- Args.$$$ delN -- Args.colon >> K (I, wf_del_local)];
+
+val modifiers =  (* FIXME include 'simp cong' (!?) *)
+  recdef_modifiers @ Splitter.split_modifiers @ Classical.cla_modifiers @ Clasimp.iff_modifiers;
+
+in
+
+fun prepare_hints thy opt_src =
+  let
+    val ctxt0 = ProofContext.init thy;
+    val ctxt =
+      (case opt_src of
+        None => ctxt0
+      | Some src => Method.only_sectioned_args modifiers I src ctxt0);
+    val {simps, congs, wfs} = get_local_hints ctxt;
+    val cs = Classical.get_local_claset ctxt;
+    val ss = Simplifier.get_local_simpset ctxt addsimps simps;
+  in (cs, ss, map #2 congs, wfs) end;
+
+fun prepare_hints_i thy () =
+  let val {simps, congs, wfs} = get_global_hints thy
+  in (Classical.claset_of thy, Simplifier.simpset_of thy addsimps simps, map #2 congs, wfs) end;
+
+end;
 
 
 
@@ -73,20 +230,19 @@
 
 fun requires_recdef thy = Theory.requires thy "Recdef" "recursive functions";
 
-fun gen_add_recdef tfl_fn prep_att app_thms raw_name R eq_srcs opt_ss raw_congs thy =
+fun gen_add_recdef tfl_fn prep_att prep_hints raw_name R eq_srcs hints thy =
   let
+    val _ = requires_recdef thy;
+
     val name = Sign.intern_const (Theory.sign_of thy) raw_name;
     val bname = Sign.base_name name;
-
-    val _ = requires_recdef thy;
     val _ = message ("Defining recursive function " ^ quote name ^ " ...");
 
     val ((eq_names, eqs), raw_eq_atts) = apfst split_list (split_list eq_srcs);
     val eq_atts = map (map (prep_att thy)) raw_eq_atts;
-    val ss = if_none opt_ss (Simplifier.simpset_of thy);
-    val (thy, congs) = thy |> app_thms raw_congs;
 
-    val (thy, {rules = rules_idx, induct, tcs}) = tfl_fn thy name R (ss, congs) eqs;
+    val (cs, ss, congs, wfs) = prep_hints thy hints;
+    val (thy, {rules = rules_idx, induct, tcs}) = tfl_fn thy cs ss congs wfs name R eqs;
     val rules = map (map #1) (Library.partition_eq Library.eq_snd rules_idx);
     val simp_att = if null tcs then [Simplifier.simp_add_global] else [];
 
@@ -103,8 +259,17 @@
       |> Theory.parent_path;
   in (thy, result) end;
 
-val add_recdef = gen_add_recdef Tfl.define Attrib.global_attribute IsarThy.apply_theorems;
-val add_recdef_i = gen_add_recdef Tfl.define_i (K I) IsarThy.apply_theorems_i;
+val add_recdef = gen_add_recdef Tfl.define Attrib.global_attribute prepare_hints;
+fun add_recdef_i x y z = gen_add_recdef Tfl.define_i (K I) prepare_hints_i x y z ();
+
+
+(* add_recdef_old -- legacy interface *)
+
+fun prepare_hints_old thy (ss, thms) =
+  let val {simps, congs, wfs} = get_global_hints thy
+  in (Classical.claset_of thy, ss addsimps simps, map #2 (add_congs thms congs), wfs) end;
+
+val add_recdef_old = gen_add_recdef Tfl.define Attrib.global_attribute prepare_hints_old;
 
 
 
@@ -119,7 +284,7 @@
     val _ = message ("Deferred recursive function " ^ quote name ^ " ...");
 
     val (thy1, congs) = thy |> app_thms raw_congs;
-    val (thy2, induct_rules) = tfl_fn thy1 name congs eqs;
+    val (thy2, induct_rules) = tfl_fn thy1 congs name eqs;
     val (thy3, [induct_rules']) =
       thy2
       |> Theory.add_path bname
@@ -136,17 +301,25 @@
 
 (* setup theory *)
 
-val setup = [Prim.init,RecdefData.init];
+val setup =
+ [GlobalRecdefData.init, LocalRecdefData.init,
+  Attrib.add_attributes
+   [("recdef_simp", simp_attr, "declare recdef simp rule"),
+    ("recdef_cong", cong_attr, "declare recdef cong rule"),
+    ("recdef_wf", wf_attr, "declare recdef wf rule")]];
 
 
 (* outer syntax *)
 
 local structure P = OuterParse and K = OuterSyntax.Keyword in
 
+val hints =
+  P.$$$ "(" |-- P.!!! (P.position (P.$$$ "hints" -- P.arguments) --| P.$$$ ")") >> Args.src;
+
 val recdef_decl =
-  P.name -- P.term -- Scan.repeat1 (P.opt_thm_name ":" -- P.prop --| P.marg_comment) --
-  Scan.optional (P.$$$ "congs" |-- P.!!! P.xthms1) []
-  >> (fn (((f, R), eqs), congs) => #1 o add_recdef f R (map P.triple_swap eqs) None congs);
+  P.name -- P.term -- Scan.repeat1 (P.opt_thm_name ":" -- P.prop --| P.marg_comment)
+    -- Scan.option hints
+  >> (fn (((f, R), eqs), src) => #1 o add_recdef f R (map P.triple_swap eqs) src);
 
 val recdefP =
   OuterSyntax.command "recdef" "define general recursive functions (TFL)" K.thy_decl
@@ -162,7 +335,7 @@
   OuterSyntax.command "defer_recdef" "defer general recursive functions (TFL)" K.thy_decl
     (defer_recdef_decl >> Toplevel.theory);
 
-val _ = OuterSyntax.add_keywords ["congs"];
+val _ = OuterSyntax.add_keywords ["hints"];
 val _ = OuterSyntax.add_parsers [recdefP, defer_recdefP];
 
 end;