derive cases/induct rules for ``more'' parts;
authorwenzelm
Tue, 20 Nov 2001 20:55:33 +0100
changeset 12247 9b029789aff6
parent 12246 fdb65a05fca8
child 12248 f059876ef1d3
derive cases/induct rules for ``more'' parts;
src/HOL/Tools/record_package.ML
--- a/src/HOL/Tools/record_package.ML	Tue Nov 20 20:54:12 2001 +0100
+++ b/src/HOL/Tools/record_package.ML	Tue Nov 20 20:55:33 2001 +0100
@@ -65,6 +65,8 @@
 
 (* fundamental syntax *)
 
+fun prune n xs = Library.drop (n, xs);
+
 fun prefix_base s = NameSpace.map_base (fn bname => s ^ bname);
 
 val Trueprop = HOLogic.mk_Trueprop;
@@ -321,21 +323,24 @@
  {args: (string * sort) list,
   parent: (typ list * string) option,
   fields: (string * typ) list,
-  simps: thm list, induct: thm, cases: thm};
+  field_inducts: thm list,
+  field_cases: thm list,
+  simps: thm list};
 
-fun make_record_info args parent fields simps induct cases =
- {args = args, parent = parent, fields = fields, simps = simps,
-  induct = induct, cases = cases}: record_info;
+fun make_record_info args parent fields field_inducts field_cases simps =
+ {args = args, parent = parent, fields = fields, field_inducts = field_inducts,
+  field_cases = field_cases, simps = simps}: record_info;
 
 type parent_info =
  {name: string,
   fields: (string * typ) list,
-  simps: thm list, induct: thm, cases: thm};
+  field_inducts: thm list,
+  field_cases: thm list,
+  simps: thm list};
 
-fun make_parent_info name fields simps induct cases =
- {name = name, fields = fields, simps = simps,
-  induct = induct, cases = cases}: parent_info;
-
+fun make_parent_info name fields field_inducts field_cases simps =
+ {name = name, fields = fields, field_inducts = field_inducts,
+  field_cases = field_cases, simps = simps}: parent_info;
 
 
 (* data kind 'HOL/records' *)
@@ -393,7 +398,7 @@
         [Pretty.str (Sign.cond_extern sg Sign.constK c), Pretty.str " ::",
           Pretty.brk 1, Pretty.quote (prt_typ T)];
 
-      fun pretty_record (name, {args, parent, fields, simps = _, induct = _, cases = _}) =
+      fun pretty_record (name, {args, parent, fields, ...}: record_info) =
         Pretty.block (Pretty.fbreaks (Pretty.block
           [prt_typ (Type (name, map TFree args)), Pretty.str " = "] ::
           pretty_parent parent @ map pretty_field fields));
@@ -452,31 +457,30 @@
 
 (* parent records *)
 
-fun inst_record thy (types, name) =
-  let
-    val sign = Theory.sign_of thy;
-    fun err msg = error (msg ^ " parent record " ^ quote name);
-
-    val {args, parent, fields, simps, induct, cases} =
-      (case get_record thy name of Some info => info | None => err "Unknown");
-    val _ = if length types <> length args then err "Bad number of arguments for" else ();
-
-    fun bad_inst ((x, S), T) =
-      if Sign.of_sort sign (T, S) then None else Some x
-    val bads = mapfilter bad_inst (args ~~ types);
-
-    val inst = map fst args ~~ types;
-    val subst = Term.map_type_tfree (fn (x, _) => the (assoc (inst, x)));
-  in
-    if not (null bads) then
-      err ("Ill-sorted instantiation of " ^ commas bads ^ " in")
-    else (apsome (apfst (map subst)) parent, map (apsnd subst) fields, simps, induct, cases)
-  end;
-
-fun add_parents thy (None, parents) = parents
-  | add_parents thy (Some (types, name), parents) =
-      let val (parent, fields, simps, induct, cases) = inst_record thy (types, name)
-      in add_parents thy (parent, make_parent_info name fields simps induct cases :: parents) end;
+fun add_parents thy None parents = parents
+  | add_parents thy (Some (types, name)) parents =
+      let
+        val sign = Theory.sign_of thy;
+        fun err msg = error (msg ^ " parent record " ^ quote name);
+    
+        val {args, parent, fields, field_inducts, field_cases, simps} =
+          (case get_record thy name of Some info => info | None => err "Unknown");
+        val _ = if length types <> length args then err "Bad number of arguments for" else ();
+    
+        fun bad_inst ((x, S), T) =
+          if Sign.of_sort sign (T, S) then None else Some x
+        val bads = mapfilter bad_inst (args ~~ types);
+    
+        val inst = map fst args ~~ types;
+        val subst = Term.map_type_tfree (fn (x, _) => the (assoc (inst, x)));
+        val parent' = apsome (apfst (map subst)) parent;
+        val fields' = map (apsnd subst) fields;
+      in
+        if not (null bads) then
+          err ("Ill-sorted instantiation of " ^ commas bads ^ " in")
+        else add_parents thy parent'
+          (make_parent_info name fields' field_inducts field_cases simps :: parents)
+      end;
 
 
 
@@ -637,16 +641,16 @@
 fun record_definition (args, bname) parent (parents: parent_info list) bfields thy =
   let
     val sign = Theory.sign_of thy;
+
+    val alphas = map fst args;
+    val name = Sign.full_name sign bname;
     val full = Sign.full_name_path sign bname;
     val base = Sign.base_name;
 
 
     (* basic components *)
 
-    val alphas = map fst args;
-    val name = Sign.full_name sign bname;       (*not made part of record name space!*)
-
-    val previous = if null parents then None else Some (last_elem parents);
+    val ancestry = map (length o flat o map #fields) (Library.prefixes1 parents);
 
     val parent_fields = flat (map #fields parents);
     val parent_names = map fst parent_fields;
@@ -687,12 +691,12 @@
     val parentT = if null parent_fields then [] else [mk_recordT (parent_fields, HOLogic.unitT)];
     val r_parent = if null parent_fields then [] else [Free (rN, hd parentT)];
 
-    val rec_schemeT = mk_recordT (all_fields, moreT);
-    val rec_scheme = mk_record (all_named_vars, more);
-    val recT = mk_recordT (all_fields, HOLogic.unitT);
-    val rec_ = mk_record (all_named_vars, HOLogic.unit);
-    val r_scheme = Free (rN, rec_schemeT);
-    val r = Free (rN, recT);
+    fun rec_schemeT n = mk_recordT (prune n all_fields, moreT);
+    fun rec_scheme n = mk_record (prune n all_named_vars, more);
+    fun recT n = mk_recordT (prune n all_fields, HOLogic.unitT);
+    fun rec_ n = mk_record (prune n all_named_vars, HOLogic.unit);    
+    fun r_scheme n = Free (rN, rec_schemeT n);
+    fun r n = Free (rN, recT n);
 
 
     (* prepare print translation functions *)
@@ -703,44 +707,44 @@
 
     (* prepare declarations *)
 
-    val sel_decls = map (mk_selC rec_schemeT) bfields @
-      [mk_moreC rec_schemeT (moreN, moreT)];
-    val update_decls = map (mk_updateC rec_schemeT) bfields @
-      [mk_more_updateC rec_schemeT (moreN, moreT)];
-    val make_decl = (makeN, parentT ---> types ---> recT);
-    val extend_decl = (extendN, recT --> moreT --> rec_schemeT);
-    val truncate_decl = (truncateN, rec_schemeT --> recT);
+    val sel_decls = map (mk_selC (rec_schemeT 0)) bfields @
+      [mk_moreC (rec_schemeT 0) (moreN, moreT)];
+    val update_decls = map (mk_updateC (rec_schemeT 0)) bfields @
+      [mk_more_updateC (rec_schemeT 0) (moreN, moreT)];
+    val make_decl = (makeN, parentT ---> types ---> recT 0);
+    val extend_decl = (extendN, recT 0 --> moreT --> rec_schemeT 0);
+    val truncate_decl = (truncateN, rec_schemeT 0 --> recT 0);
 
 
     (* prepare definitions *)
 
     (*record (scheme) type abbreviation*)
     val recordT_specs =
-      [(suffix schemeN bname, alphas @ [zeta], rec_schemeT, Syntax.NoSyn),
-        (bname, alphas, recT, Syntax.NoSyn)];
+      [(suffix schemeN bname, alphas @ [zeta], rec_schemeT 0, Syntax.NoSyn),
+        (bname, alphas, recT 0, Syntax.NoSyn)];
 
     (*selectors*)
     fun mk_sel_spec (i, c) =
-      mk_sel r_scheme c :== mk_fst (funpow i mk_snd (parent_more r_scheme));
+      mk_sel (r_scheme 0) c :== mk_fst (funpow i mk_snd (parent_more (r_scheme 0)));
     val sel_specs =
       ListPair.map mk_sel_spec (idxs, names) @
-        [more_part r_scheme :== funpow len mk_snd (parent_more r_scheme)];
+        [more_part (r_scheme 0) :== funpow len mk_snd (parent_more (r_scheme 0))];
 
     (*updates*)
-    val all_sels = mk_named_sels all_names r_scheme;
+    val all_sels = mk_named_sels all_names (r_scheme 0);
     fun mk_upd_spec (i, (c, x)) =
-      mk_update r_scheme (c, x) :==
-        mk_record (nth_update (c, x) (parent_len + i, all_sels), more_part r_scheme)
+      mk_update (r_scheme 0) (c, x) :==
+        mk_record (nth_update (c, x) (parent_len + i, all_sels), more_part (r_scheme 0))
     val update_specs =
       ListPair.map mk_upd_spec (idxs, named_vars) @
-        [more_part_update r_scheme more :== mk_record (all_sels, more)];
+        [more_part_update (r_scheme 0) more :== mk_record (all_sels, more)];
 
     (*derived operations*)
-    val make_spec = Const (full makeN, parentT ---> types ---> recT) $$ r_parent $$ vars :==
+    val make_spec = Const (full makeN, parentT ---> types ---> recT 0) $$ r_parent $$ vars :==
       mk_record (flat (map (mk_named_sels parent_names) r_parent) @ named_vars, HOLogic.unit);
-    val extend_spec = Const (full extendN, recT --> moreT --> rec_schemeT) $ r $ more :==
-      mk_record (mk_named_sels all_names r, more);
-    val truncate_spec = Const (full truncateN, rec_schemeT --> recT) $ r_scheme :==
+    val extend_spec = Const (full extendN, recT 0 --> moreT --> rec_schemeT 0) $ r 0 $ more :==
+      mk_record (mk_named_sels all_names (r 0), more);
+    val truncate_spec = Const (full truncateN, rec_schemeT 0 --> recT 0) $ r_scheme 0 :==
       mk_record (all_sels, HOLogic.unit);
 
 
@@ -748,45 +752,50 @@
 
     (*selectors*)
     val sel_props =
-      map (fn (c, x) => mk_sel rec_scheme c === x) named_vars @
-        [more_part rec_scheme === more];
+      map (fn (c, x) => mk_sel (rec_scheme 0) c === x) named_vars @
+        [more_part (rec_scheme 0) === more];
 
     (*updates*)
     fun mk_upd_prop (i, (c, T)) =
       let val x' = Free (variant all_xs (base c ^ "'"), T) in
-        mk_update rec_scheme (c, x') ===
+        mk_update (rec_scheme 0) (c, x') ===
           mk_record (nth_update (c, x') (parent_len + i, all_named_vars), more)
       end;
     val update_props =
       ListPair.map mk_upd_prop (idxs, fields) @
         let val more' = Free (variant all_xs (moreN ^ "'"), moreT)
-        in [more_part_update rec_scheme more' === mk_record (all_named_vars, more')] end;
+        in [more_part_update (rec_scheme 0) more' === mk_record (all_named_vars, more')] end;
 
     (*equality*)
     fun mk_sel_eq (t, T) =
-      let val t' = Term.abstract_over (r_scheme, t)
+      let val t' = Term.abstract_over (r_scheme 0, t)
       in Trueprop (HOLogic.eq_const T $ Term.incr_boundvars 1 t' $ t') end;
-    val sel_eqs =
-      map2 mk_sel_eq (map (mk_sel r_scheme) all_names @ [more_part r_scheme], all_types @ [moreT]);
+    val sel_eqs = map2 mk_sel_eq
+      (map (mk_sel (r_scheme 0)) all_names @ [more_part (r_scheme 0)], all_types @ [moreT]);
     val equality_prop =
-      Term.all rec_schemeT $ (Abs ("r", rec_schemeT,
-        Term.all rec_schemeT $ (Abs ("r'", rec_schemeT,
+      Term.all (rec_schemeT 0) $ (Abs ("r", rec_schemeT 0,
+        Term.all (rec_schemeT 0) $ (Abs ("r'", rec_schemeT 0,
           Logic.list_implies (sel_eqs,
-            Trueprop (HOLogic.eq_const rec_schemeT $ Bound 1 $ Bound 0))))));
+            Trueprop (HOLogic.eq_const (rec_schemeT 0) $ Bound 1 $ Bound 0))))));
 
     (*induct*)
-    val P = Free ("P", rec_schemeT --> HOLogic.boolT);
-    val P' = Free ("P", recT --> HOLogic.boolT);
-    val induct_scheme_assm = All (all_xs_more ~~ all_types_more) (Trueprop (P $ rec_scheme));
-    val induct_scheme_concl = Trueprop (P $ r_scheme);
-    val induct_assm = All (all_xs ~~ all_types) (Trueprop (P' $ rec_));
-    val induct_concl = Trueprop (P' $ r);
+    fun induct_scheme_prop n =
+      let val P = Free ("P", rec_schemeT n --> HOLogic.boolT) in
+        (All (prune n all_xs_more ~~ prune n all_types_more)
+          (Trueprop (P $ rec_scheme n)), Trueprop (P $ r_scheme n))
+      end;
+    fun induct_prop n =
+      let val P = Free ("P", recT n --> HOLogic.boolT) in
+        (All (prune n all_xs ~~ prune n all_types) (Trueprop (P $ rec_ n)), Trueprop (P $ r n))
+      end;
 
     (*cases*)
     val C = Trueprop (Free (variant all_xs_more "C", HOLogic.boolT));
-    val cases_scheme_prop =
-      All (all_xs_more ~~ all_types_more) ((r_scheme === rec_scheme) ==> C) ==> C;
-    val cases_prop = All (all_xs ~~ all_types) ((r === rec_) ==> C) ==> C;
+    fun cases_scheme_prop n =
+      All (prune n all_xs_more ~~ prune n all_types_more)
+        ((r_scheme n === rec_scheme n) ==> C) ==> C;
+    fun cases_prop n =
+      All (prune n all_xs ~~ prune n all_types) ((r n === rec_ n) ==> C) ==> C;
 
 
     (* 1st stage: fields_thy *)
@@ -796,6 +805,9 @@
       |> Theory.add_path bname
       |> field_definitions fields names xs alphas zeta moreT more vars named_vars;
 
+    val all_field_inducts = flat (map #field_inducts parents) @ field_inducts;
+    val all_field_cases = flat (map #field_cases parents) @ field_cases;
+
     val named_splits = map2 (fn (c, th) => (suffix field_typeN c, th)) (names, field_splits);
 
 
@@ -815,12 +827,10 @@
       |>>> (PureThy.add_defs_i false o map Thm.no_attributes)
         [make_spec, extend_spec, truncate_spec];
 
-    val defs_sg = Theory.sign_of defs_thy;
-
 
     (* 3rd stage: thms_thy *)
 
-    val prove_standard = Tactic.prove_standard defs_sg;
+    val prove_standard = Tactic.prove_standard (Theory.sign_of defs_thy);
     fun prove_simp simps =
       let val tac = simp_all_tac HOL_basic_ss simps
       in fn prop => prove_standard [] [] prop (K tac) end;
@@ -829,29 +839,27 @@
     val sel_convs = map (prove_simp (parent_simps @ sel_defs @ field_simps)) sel_props;
     val update_convs = map (prove_simp (parent_simps @ update_defs @ sel_convs)) update_props;
 
-    val induct_scheme = prove_standard [] [induct_scheme_assm] induct_scheme_concl (fn prems =>
-        (case previous of Some {induct, ...} => res_inst_tac [(rN, rN)] induct 1
-        | None => all_tac)
-        THEN EVERY (map (fn rule => try_param_tac "p" rN rule 1) field_inducts)
-        THEN resolve_tac prems 1);
+    fun induct_scheme n =
+      let val (assm, concl) = induct_scheme_prop n in
+        prove_standard [] [assm] concl (fn prems =>
+          EVERY (map (fn rule => try_param_tac "p" rN rule 1) (prune n all_field_inducts))
+          THEN resolve_tac prems 1)
+      end;
 
-    val induct = prove_standard [] [induct_assm] induct_concl (fn prems =>
-        res_inst_tac [(rN, rN)] induct_scheme 1
-        THEN try_param_tac "x" "more" unit_induct 1
-        THEN resolve_tac prems 1);
-
-    val cases_scheme = prove_standard [] [] cases_scheme_prop (fn _ =>
-        (case previous of Some {cases, ...} => try_param_tac rN rN cases 1
-        | None => all_tac)
-        THEN EVERY (map (fn rule => try_param_tac "p" rN rule 1) field_cases)
+    fun cases_scheme n =
+      prove_standard [] [] (cases_scheme_prop n) (fn _ =>
+        EVERY (map (fn rule => try_param_tac "p" rN rule 1) (prune n all_field_cases))
         THEN simp_all_tac HOL_basic_ss []);
 
-    val cases = prove_standard [] [] cases_prop (fn _ =>
-        res_inst_tac [(rN, rN)] cases_scheme 1
-        THEN simp_all_tac HOL_basic_ss [unit_all_eq1]);
+    val induct_scheme0 = induct_scheme 0;
+    val cases_scheme0 = cases_scheme 0;
+    val more_induct_scheme = map induct_scheme ancestry;
+    val more_cases_scheme = map cases_scheme ancestry;
 
-    val (thms_thy, ([sel_convs', update_convs', sel_defs', update_defs', _],
-        [induct_scheme', induct', cases_scheme', cases'])) =
+    val case_names = RuleCases.case_names [fieldsN];
+
+    val (thms_thy, (([sel_convs', update_convs', sel_defs', update_defs', _],
+        [induct_scheme', cases_scheme']), [more_induct_scheme', more_cases_scheme'])) =
       defs_thy
       |> (PureThy.add_thmss o map Thm.no_attributes)
        [("select_convs", sel_convs),
@@ -860,40 +868,68 @@
         ("update_defs", update_defs),
         ("derived_defs", derived_defs)]
       |>>> PureThy.add_thms
-       [(("induct_scheme", induct_scheme), [RuleCases.case_names [fieldsN],
-          InductAttrib.induct_type_global (suffix schemeN name)]),
-        (("induct", induct), [RuleCases.case_names [fieldsN],
-          InductAttrib.induct_type_global name]),
-        (("cases_scheme", cases_scheme), [RuleCases.case_names [fieldsN],
-          InductAttrib.cases_type_global (suffix schemeN name)]),
-        (("cases", cases), [RuleCases.case_names [fieldsN],
-          InductAttrib.cases_type_global name])];
+       [(("induct_scheme", induct_scheme0),
+         [case_names, InductAttrib.induct_type_global (suffix schemeN name)]),
+        (("cases_scheme", cases_scheme0),
+         [case_names, InductAttrib.cases_type_global (suffix schemeN name)])]
+      |>>> (PureThy.add_thmss o map Thm.no_attributes)
+        [("more_induct_scheme", more_induct_scheme),
+         ("more_cases_scheme", more_cases_scheme)];
+
+
+    (* 4th stage: more_thms_thy *)
+
+    val prove_standard = Tactic.prove_standard (Theory.sign_of thms_thy);
 
-    val equality = Tactic.prove_standard (Theory.sign_of thms_thy) [] [] equality_prop (fn _ =>
+    fun induct (n, scheme) =
+      let val (assm, concl) = induct_prop n in
+        prove_standard [] [assm] concl (fn prems =>
+          res_inst_tac [(rN, rN)] scheme 1
+          THEN try_param_tac "x" "more" unit_induct 1
+          THEN resolve_tac prems 1)
+      end;
+
+    fun cases (n, scheme) =
+      prove_standard [] [] (cases_prop n) (fn _ =>
+        res_inst_tac [(rN, rN)] scheme 1
+        THEN simp_all_tac HOL_basic_ss [unit_all_eq1]);
+
+    val induct0 = induct (0, induct_scheme');
+    val cases0 = cases (0, cases_scheme');
+    val more_induct = map induct (ancestry ~~ more_induct_scheme');
+    val more_cases = map cases (ancestry ~~ more_cases_scheme');
+
+    val equality = prove_standard [] [] equality_prop (fn _ =>
       fn st => let val [r, r'] = map #1 (rev (Tactic.innermost_params 1 st)) in
         st |> (res_inst_tac [(rN, r)] cases_scheme' 1
         THEN res_inst_tac [(rN, r')] cases_scheme' 1
         THEN simp_all_tac HOL_basic_ss (parent_simps @ sel_convs))
       end);
 
-    val (thms_thy', [equality']) =
-      thms_thy |> PureThy.add_thms [(("equality", equality), [Classical.xtra_intro_global])];
+    val (more_thms_thy, [_, _, equality']) =
+      thms_thy |> PureThy.add_thms
+       [(("induct", induct0), [case_names, InductAttrib.induct_type_global name]),
+        (("cases", cases0), [case_names, InductAttrib.cases_type_global name]),
+        (("equality", equality), [Classical.xtra_intro_global])]
+      |>> (#1 oo (PureThy.add_thmss o map Thm.no_attributes))
+        [("more_induct", more_induct),
+         ("more_cases", more_cases)];
 
     val simps = sel_convs' @ update_convs' @ [equality'];
     val iffs = field_injects;
 
-    val thms_thy'' =
-      thms_thy' |> (#1 oo PureThy.add_thmss)
+    val more_thms_thy' =
+      more_thms_thy |> (#1 oo PureThy.add_thmss)
         [(("simps", simps), [Simplifier.simp_add_global]),
          (("iffs", iffs), [iff_add_global])];
 
 
-    (* 4th stage: final_thy *)
+    (* 5th stage: final_thy *)
 
     val final_thy =
-      thms_thy''
-      |> put_record name (make_record_info args parent fields (field_simps @ simps)
-          induct_scheme' cases_scheme')
+      more_thms_thy'
+      |> put_record name (make_record_info args parent fields field_inducts field_cases
+        (field_simps @ simps))
       |> put_sel_upd (names @ [full_moreN]) (field_simps @ sel_defs' @ update_defs')
       |> Theory.parent_path;
 
@@ -941,7 +977,7 @@
 
     val parent = apsome (apfst (map prep_inst) o prep_raw_parent sign) raw_parent
       handle ERROR => error ("The error(s) above in parent record specification");
-    val parents = add_parents thy (parent, []);
+    val parents = add_parents thy parent [];
 
     val init_env =
       (case parent of