src/Pure/Isar/induct_attrib.ML
changeset 11730 418533653668
parent 11665 7324f018ea15
child 11784 b66b198ee29a
--- a/src/Pure/Isar/induct_attrib.ML	Fri Oct 12 12:08:04 2001 +0200
+++ b/src/Pure/Isar/induct_attrib.ML	Fri Oct 12 12:08:57 2001 +0200
@@ -8,6 +8,7 @@
 
 signature INDUCT_ATTRIB =
 sig
+  val vars_of: term -> term list
   val dest_global_rules: theory ->
     {type_cases: (string * thm) list, set_cases: (string * thm) list,
       type_induct: (string * thm) list, set_induct: (string * thm) list}
@@ -16,12 +17,14 @@
     {type_cases: (string * thm) list, set_cases: (string * thm) list,
       type_induct: (string * thm) list, set_induct: (string * thm) list}
   val print_local_rules: Proof.context -> unit
-  val get_cases : Proof.context -> (string * thm) NetRules.T * (string * thm) NetRules.T
-  val get_induct : Proof.context -> (string * thm) NetRules.T * (string * thm) NetRules.T
+  val lookup_casesT : Proof.context -> string -> thm option
   val lookup_casesS : Proof.context -> string -> thm option
-  val lookup_casesT : Proof.context -> string -> thm option
+  val lookup_inductT : Proof.context -> string -> thm option
   val lookup_inductS : Proof.context -> string -> thm option
-  val lookup_inductT : Proof.context -> string -> thm option
+  val find_casesT: Proof.context -> typ -> thm list
+  val find_casesS: Proof.context -> thm -> thm list
+  val find_inductT: Proof.context -> typ -> thm list
+  val find_inductS: Proof.context -> thm -> thm list
   val cases_type_global: string -> theory attribute
   val cases_set_global: string -> theory attribute
   val cases_type_local: string -> Proof.context attribute
@@ -41,18 +44,49 @@
 struct
 
 
+(** misc utils **)
+
+(* encode_type -- for indexing purposes *)
+
+fun encode_type (Type (c, Ts)) = Term.list_comb (Const (c, dummyT), map encode_type Ts)
+  | encode_type (TFree (a, _)) = Free (a, dummyT)
+  | encode_type (TVar (a, _)) = Var (a, dummyT);
+
+
+(* variables -- ordered left-to-right, preferring right *)
+
+local
+
+fun rev_vars_of tm =
+  Term.foldl_aterms (fn (ts, t as Var _) => t :: ts | (ts, _) => ts) ([], tm)
+  |> Library.distinct;
+
+val mk_var = encode_type o #2 o Term.dest_Var;
+
+in
+
+val vars_of = rev o rev_vars_of;
+
+fun first_var thm = mk_var (hd (vars_of (hd (Thm.prems_of thm)))) handle LIST _ =>
+  raise THM ("No variables in first premise of rule", 0, [thm]);
+
+fun last_var thm = mk_var (hd (rev_vars_of (Thm.concl_of thm))) handle LIST _ =>
+  raise THM ("No variables in conclusion of rule", 0, [thm]);
+
+end;
+
+
+
 (** global and local induct data **)
 
 (* rules *)
 
 type rules = (string * thm) NetRules.T;
 
-fun eq_rule ((s1:string, th1), (s2, th2)) = s1 = s2 andalso Thm.eq_thm (th1, th2);
+val init_rules = NetRules.init (fn ((s1: string, th1), (s2, th2)) => s1 = s2
+  andalso Thm.eq_thm (th1, th2));
 
-val type_rules = NetRules.init eq_rule (Thm.concl_of o #2);
-val set_rules = NetRules.init eq_rule (Thm.major_prem_of o #2);
-
-fun lookup_rule (rs:rules) name = Library.assoc (NetRules.rules rs, name);
+fun lookup_rule (rs: rules) name = Library.assoc (NetRules.rules rs, name);
 
 fun print_rules kind sg rs =
   let val thms = map snd (NetRules.rules rs)
@@ -66,7 +100,9 @@
   val name = "Isar/induction";
   type T = (rules * rules) * (rules * rules);
 
-  val empty = ((type_rules, set_rules), (type_rules, set_rules));
+  val empty =
+    ((init_rules (first_var o #2), init_rules (Thm.major_prem_of o #2)),
+     (init_rules (last_var o #2), init_rules (Thm.major_prem_of o #2)));
   val copy = I;
   val prep_ext = I;
   fun merge (((casesT1, casesS1), (inductT1, inductS1)),
@@ -110,13 +146,19 @@
 
 (* access rules *)
 
-val get_cases = #1 o LocalInduct.get;
-val get_induct = #2 o LocalInduct.get;
+val lookup_casesT = lookup_rule o #1 o #1 o LocalInduct.get;
+val lookup_casesS = lookup_rule o #2 o #1 o LocalInduct.get;
+val lookup_inductT = lookup_rule o #1 o #2 o LocalInduct.get;
+val lookup_inductS = lookup_rule o #2 o #2 o LocalInduct.get;
+
 
-val lookup_casesT = lookup_rule o #1 o get_cases;
-val lookup_casesS = lookup_rule o #2 o get_cases;
-val lookup_inductT = lookup_rule o #1 o get_induct;
-val lookup_inductS = lookup_rule o #2 o get_induct;
+fun find_rules which how ctxt x =
+  map snd (NetRules.may_unify (which (LocalInduct.get ctxt)) (how x));
+
+val find_casesT = find_rules (#1 o #1) encode_type;
+val find_casesS = find_rules (#2 o #1) Thm.concl_of;
+val find_inductT = find_rules (#1 o #2) encode_type;
+val find_inductS = find_rules (#2 o #2) Thm.concl_of;
 
 
 
@@ -164,7 +206,7 @@
 
 fun attrib sign_of add_type add_set = Scan.depend (fn x =>
   let val sg = sign_of x in
-    spec typeN >> (add_type o Sign.certify_tycon sg o Sign.intern_tycon sg) ||
+    spec typeN >> (add_type o Sign.certify_tyname sg o Sign.intern_tycon sg) ||
     spec setN  >> (add_set o Sign.certify_const sg o Sign.intern_const sg)
   end >> pair x);