src/Tools/induct.ML
changeset 24861 cc669ca5f382
parent 24832 64cd13299d39
child 24865 62c48c4bee48
--- a/src/Tools/induct.ML	Fri Oct 05 22:00:13 2007 +0200
+++ b/src/Tools/induct.ML	Fri Oct 05 22:00:15 2007 +0200
@@ -18,32 +18,33 @@
   (*rule declarations*)
   val vars_of: term -> term list
   val dest_rules: Proof.context ->
-    {type_cases: (string * thm) list, set_cases: (string * thm) list,
-      type_induct: (string * thm) list, set_induct: (string * thm) list,
-      type_coinduct: (string * thm) list, set_coinduct: (string * thm) list}
+    {type_cases: (string * thm) list, pred_cases: (string * thm) list,
+      type_induct: (string * thm) list, pred_induct: (string * thm) list,
+      type_coinduct: (string * thm) list, pred_coinduct: (string * thm) list}
   val print_rules: Proof.context -> unit
   val lookup_casesT: Proof.context -> string -> thm option
-  val lookup_casesS: Proof.context -> string -> thm option
+  val lookup_casesP: Proof.context -> string -> thm option
   val lookup_inductT: Proof.context -> string -> thm option
-  val lookup_inductS: Proof.context -> string -> thm option
+  val lookup_inductP: Proof.context -> string -> thm option
   val lookup_coinductT: Proof.context -> string -> thm option
-  val lookup_coinductS: Proof.context -> string -> thm option
+  val lookup_coinductP: Proof.context -> string -> thm option
   val find_casesT: Proof.context -> typ -> thm list
-  val find_casesS: Proof.context -> term -> thm list
+  val find_casesP: Proof.context -> term -> thm list
   val find_inductT: Proof.context -> typ -> thm list
-  val find_inductS: Proof.context -> term -> thm list
+  val find_inductP: Proof.context -> term -> thm list
   val find_coinductT: Proof.context -> typ -> thm list
-  val find_coinductS: Proof.context -> term -> thm list
+  val find_coinductP: Proof.context -> term -> thm list
   val cases_type: string -> attribute
-  val cases_set: string -> attribute
+  val cases_pred: string -> attribute
   val induct_type: string -> attribute
-  val induct_set: string -> attribute
+  val induct_pred: string -> attribute
   val coinduct_type: string -> attribute
-  val coinduct_set: string -> attribute
+  val coinduct_pred: string -> attribute
   val casesN: string
   val inductN: string
   val coinductN: string
   val typeN: string
+  val predN: string
   val setN: string
   (*proof methods*)
   val fix_tac: Proof.context -> int -> (string * typ) list -> int -> tactic
@@ -130,33 +131,33 @@
      (init_rules (right_var_concl o #2), init_rules (Thm.major_prem_of o #2)),
      (init_rules (left_var_concl o #2), init_rules (Thm.concl_of o #2)));
   val extend = I;
-  fun merge _ (((casesT1, casesS1), (inductT1, inductS1), (coinductT1, coinductS1)),
-      ((casesT2, casesS2), (inductT2, inductS2), (coinductT2, coinductS2))) =
-    ((NetRules.merge (casesT1, casesT2), NetRules.merge (casesS1, casesS2)),
-      (NetRules.merge (inductT1, inductT2), NetRules.merge (inductS1, inductS2)),
-      (NetRules.merge (coinductT1, coinductT2), NetRules.merge (coinductS1, coinductS2)));
+  fun merge _ (((casesT1, casesP1), (inductT1, inductP1), (coinductT1, coinductP1)),
+      ((casesT2, casesP2), (inductT2, inductP2), (coinductT2, coinductP2))) =
+    ((NetRules.merge (casesT1, casesT2), NetRules.merge (casesP1, casesP2)),
+      (NetRules.merge (inductT1, inductT2), NetRules.merge (inductP1, inductP2)),
+      (NetRules.merge (coinductT1, coinductT2), NetRules.merge (coinductP1, coinductP2)));
 );
 
 val get_local = Induct.get o Context.Proof;
 
 fun dest_rules ctxt =
-  let val ((casesT, casesS), (inductT, inductS), (coinductT, coinductS)) = get_local ctxt in
+  let val ((casesT, casesP), (inductT, inductP), (coinductT, coinductP)) = get_local ctxt in
     {type_cases = NetRules.rules casesT,
-     set_cases = NetRules.rules casesS,
+     pred_cases = NetRules.rules casesP,
      type_induct = NetRules.rules inductT,
-     set_induct = NetRules.rules inductS,
+     pred_induct = NetRules.rules inductP,
      type_coinduct = NetRules.rules coinductT,
-     set_coinduct = NetRules.rules coinductS}
+     pred_coinduct = NetRules.rules coinductP}
   end;
 
 fun print_rules ctxt =
-  let val ((casesT, casesS), (inductT, inductS), (coinductT, coinductS)) = get_local ctxt in
+  let val ((casesT, casesP), (inductT, inductP), (coinductT, coinductP)) = get_local ctxt in
    [pretty_rules ctxt "coinduct type:" coinductT,
-    pretty_rules ctxt "coinduct set:" coinductS,
+    pretty_rules ctxt "coinduct pred:" coinductP,
     pretty_rules ctxt "induct type:" inductT,
-    pretty_rules ctxt "induct set:" inductS,
+    pretty_rules ctxt "induct pred:" inductP,
     pretty_rules ctxt "cases type:" casesT,
-    pretty_rules ctxt "cases set:" casesS]
+    pretty_rules ctxt "cases pred:" casesP]
     |> Pretty.chunks |> Pretty.writeln
   end;
 
@@ -169,22 +170,22 @@
 (* access rules *)
 
 val lookup_casesT = lookup_rule o #1 o #1 o get_local;
-val lookup_casesS = lookup_rule o #2 o #1 o get_local;
+val lookup_casesP = lookup_rule o #2 o #1 o get_local;
 val lookup_inductT = lookup_rule o #1 o #2 o get_local;
-val lookup_inductS = lookup_rule o #2 o #2 o get_local;
+val lookup_inductP = lookup_rule o #2 o #2 o get_local;
 val lookup_coinductT = lookup_rule o #1 o #3 o get_local;
-val lookup_coinductS = lookup_rule o #2 o #3 o get_local;
+val lookup_coinductP = lookup_rule o #2 o #3 o get_local;
 
 
 fun find_rules which how ctxt x =
   map snd (NetRules.retrieve (which (get_local ctxt)) (how x));
 
 val find_casesT = find_rules (#1 o #1) encode_type;
-val find_casesS = find_rules (#2 o #1) I;
+val find_casesP = find_rules (#2 o #1) I;
 val find_inductT = find_rules (#1 o #2) encode_type;
-val find_inductS = find_rules (#2 o #2) I;
+val find_inductP = find_rules (#2 o #2) I;
 val find_coinductT = find_rules (#1 o #3) encode_type;
-val find_coinductS = find_rules (#2 o #3) I;
+val find_coinductP = find_rules (#2 o #3) I;
 
 
 
@@ -200,11 +201,11 @@
 fun map3 f (x, y, z) = (x, y, f z);
 
 fun add_casesT rule x = map1 (apfst (NetRules.insert rule)) x;
-fun add_casesS rule x = map1 (apsnd (NetRules.insert rule)) x;
+fun add_casesP rule x = map1 (apsnd (NetRules.insert rule)) x;
 fun add_inductT rule x = map2 (apfst (NetRules.insert rule)) x;
-fun add_inductS rule x = map2 (apsnd (NetRules.insert rule)) x;
+fun add_inductP rule x = map2 (apsnd (NetRules.insert rule)) x;
 fun add_coinductT rule x = map3 (apfst (NetRules.insert rule)) x;
-fun add_coinductS rule x = map3 (apsnd (NetRules.insert rule)) x;
+fun add_coinductP rule x = map3 (apsnd (NetRules.insert rule)) x;
 
 fun consumes0 x = RuleCases.consumes_default 0 x;
 fun consumes1 x = RuleCases.consumes_default 1 x;
@@ -212,11 +213,11 @@
 in
 
 val cases_type = mk_att add_casesT consumes0;
-val cases_set = mk_att add_casesS consumes1;
+val cases_pred = mk_att add_casesP consumes1;
 val induct_type = mk_att add_inductT consumes0;
-val induct_set = mk_att add_inductS consumes1;
+val induct_pred = mk_att add_inductP consumes1;
 val coinduct_type = mk_att add_coinductT consumes0;
-val coinduct_set = mk_att add_coinductS consumes1;
+val coinduct_pred = mk_att add_coinductP consumes1;
 
 end;
 
@@ -229,6 +230,7 @@
 val coinductN = "coinduct";
 
 val typeN = "type";
+val predN = "pred";
 val setN = "set";
 
 local
@@ -237,19 +239,21 @@
   Scan.lift (Args.$$$ k -- Args.colon) |-- arg ||
   Scan.lift (Args.$$$ k) >> K "";
 
-fun attrib add_type add_set =
-  Attrib.syntax (spec typeN Args.tyname >> add_type || spec setN Args.const >> add_set);
+fun attrib add_type add_pred = Attrib.syntax
+ (spec typeN Args.tyname >> add_type ||
+  spec predN Args.const >> add_pred ||
+  spec setN Args.const >> add_pred);
 
-val cases_att = attrib cases_type cases_set;
-val induct_att = attrib induct_type induct_set;
-val coinduct_att = attrib coinduct_type coinduct_set;
+val cases_att = attrib cases_type cases_pred;
+val induct_att = attrib induct_type induct_pred;
+val coinduct_att = attrib coinduct_type coinduct_pred;
 
 in
 
 val attrib_setup = Attrib.add_attributes
- [(casesN, cases_att, "declaration of cases rule for type or set"),
-  (inductN, induct_att, "declaration of induction rule for type or set"),
-  (coinductN, coinduct_att, "declaration of coinduction rule for type or set")];
+ [(casesN, cases_att, "declaration of cases rule for type or predicate/set"),
+  (inductN, induct_att, "declaration of induction rule for type or predicate/set"),
+  (coinductN, coinduct_att, "declaration of coinduction rule for type or predicate/set")];
 
 end;
 
@@ -314,7 +318,7 @@
 (*
   rule selection scheme:
           cases         - default case split
-    `x:A` cases ...     - set cases
+    `A t` cases ...     - predicate/set cases
           cases t       - type cases
     ...   cases ... r   - explicit rule
 *)
@@ -324,8 +328,8 @@
 fun get_casesT ctxt ((SOME t :: _) :: _) = find_casesT ctxt (Term.fastype_of t)
   | get_casesT _ _ = [];
 
-fun get_casesS ctxt (fact :: _) = find_casesS ctxt (Thm.concl_of fact)
-  | get_casesS _ _ = [];
+fun get_casesP ctxt (fact :: _) = find_casesP ctxt (Thm.concl_of fact)
+  | get_casesP _ _ = [];
 
 in
 
@@ -345,7 +349,7 @@
       (case opt_rule of
         SOME r => Seq.single (inst_rule r)
       | NONE =>
-          (get_casesS ctxt facts @ get_casesT ctxt insts @ [Data.cases_default])
+          (get_casesP ctxt facts @ get_casesT ctxt insts @ [Data.cases_default])
           |> tap (trace_rules ctxt casesN)
           |> Seq.of_list |> Seq.maps (Seq.try inst_rule));
   in
@@ -551,7 +555,7 @@
 
 (*
   rule selection scheme:
-    `x:A` induct ...     - set induction
+    `A x` induct ...     - predicate/set induction
           induct x       - type induction
     ...   induct ... r   - explicit rule
 *)
@@ -563,8 +567,8 @@
     |> map (find_inductT ctxt o Term.fastype_of)) [[]]
   |> filter_out (forall PureThy.is_internal);
 
-fun get_inductS ctxt (fact :: _) = map single (find_inductS ctxt (Thm.concl_of fact))
-  | get_inductS _ _ = [];
+fun get_inductP ctxt (fact :: _) = map single (find_inductP ctxt (Thm.concl_of fact))
+  | get_inductP _ _ = [];
 
 in
 
@@ -589,7 +593,7 @@
       (case opt_rule of
         SOME rs => Seq.single (inst_rule (RuleCases.strict_mutual_rule ctxt rs))
       | NONE =>
-          (get_inductS ctxt facts @
+          (get_inductP ctxt facts @
             map (special_rename_params defs_ctxt insts) (get_inductT ctxt insts))
           |> map_filter (RuleCases.mutual_rule ctxt)
           |> tap (trace_rules ctxt inductN o map #2)
@@ -627,7 +631,7 @@
 
 (*
   rule selection scheme:
-    goal "x:A" coinduct ...   - set coinduction
+    goal "A x" coinduct ...   - predicate/set coinduction
                coinduct x     - type coinduction
                coinduct ... r - explicit rule
 *)
@@ -637,7 +641,10 @@
 fun get_coinductT ctxt (SOME t :: _) = find_coinductT ctxt (Term.fastype_of t)
   | get_coinductT _ _ = [];
 
-fun get_coinductS ctxt goal = find_coinductS ctxt (Logic.strip_assums_concl goal);
+fun get_coinductP ctxt goal = find_coinductP ctxt (Logic.strip_assums_concl goal);
+
+fun main_prop_of th =
+  if RuleCases.get_consumes th > 0 then Thm.major_prem_of th else Thm.concl_of th;
 
 in
 
@@ -649,14 +656,14 @@
 
     fun inst_rule r =
       if null inst then `RuleCases.get r
-      else Drule.cterm_instantiate (prep_inst thy align_left I (Thm.concl_of r, inst)) r
+      else Drule.cterm_instantiate (prep_inst thy align_right I (main_prop_of r, inst)) r
         |> pair (RuleCases.get r);
 
     fun ruleq goal =
       (case opt_rule of
         SOME r => Seq.single (inst_rule r)
       | NONE =>
-          (get_coinductS ctxt goal @ get_coinductT ctxt inst)
+          (get_coinductP ctxt goal @ get_coinductT ctxt inst)
           |> tap (trace_rules ctxt coinductN)
           |> Seq.of_list |> Seq.maps (Seq.try inst_rule));
   in
@@ -693,14 +700,15 @@
       (case get (Context.proof_of context) name of SOME x => x
       | NONE => error ("No rule for " ^ k ^ " " ^ quote name))))));
 
-fun rule get_type get_set =
+fun rule get_type get_pred =
   named_rule typeN Args.tyname get_type ||
-  named_rule setN Args.const get_set ||
+  named_rule predN Args.const get_pred ||
+  named_rule setN Args.const get_pred ||
   Scan.lift (Args.$$$ ruleN -- Args.colon) |-- Attrib.thms;
 
-val cases_rule = rule lookup_casesT lookup_casesS >> single_rule;
-val induct_rule = rule lookup_inductT lookup_inductS;
-val coinduct_rule = rule lookup_coinductT lookup_coinductS >> single_rule;
+val cases_rule = rule lookup_casesT lookup_casesP >> single_rule;
+val induct_rule = rule lookup_inductT lookup_inductP;
+val coinduct_rule = rule lookup_coinductT lookup_coinductP >> single_rule;
 
 val inst = Scan.lift (Args.$$$ "_") >> K NONE || Args.term >> SOME;
 
@@ -714,7 +722,7 @@
 
 fun unless_more_args scan = Scan.unless (Scan.lift
   ((Args.$$$ arbitraryN || Args.$$$ takingN || Args.$$$ typeN ||
-    Args.$$$ setN || Args.$$$ ruleN) -- Args.colon)) scan;
+    Args.$$$ predN || Args.$$$ setN || Args.$$$ ruleN) -- Args.colon)) scan;
 
 val arbitrary = Scan.optional (Scan.lift (Args.$$$ arbitraryN -- Args.colon) |--
   Args.and_list1 (Scan.repeat (unless_more_args free))) [];
@@ -755,8 +763,8 @@
 val setup =
   attrib_setup #>
   Method.add_methods
-    [(casesN, cases_meth, "case analysis on types or sets"),
-     (inductN, induct_meth, "induction on types or sets"),
-     (coinductN, coinduct_meth, "coinduction on types or sets")];
+    [(casesN, cases_meth, "case analysis on types or predicates/sets"),
+     (inductN, induct_meth, "induction on types or predicates/sets"),
+     (coinductN, coinduct_meth, "coinduction on types or predicates/sets")];
 
 end;