src/HOL/Tools/record.ML
changeset 35138 ad213c602ec1
parent 35137 405bb7e38057
child 35142 495c623f1e3c
--- a/src/HOL/Tools/record.ML	Mon Feb 15 22:40:03 2010 +0100
+++ b/src/HOL/Tools/record.ML	Mon Feb 15 23:58:24 2010 +0100
@@ -9,6 +9,18 @@
 
 signature BASIC_RECORD =
 sig
+  type record_info =
+   {args: (string * sort) list,
+    parent: (typ list * string) option,
+    fields: (string * typ) list,
+    extension: (string * typ list),
+    ext_induct: thm, ext_inject: thm, ext_surjective: thm, ext_split: thm, ext_def: thm,
+    select_convs: thm list, update_convs: thm list, select_defs: thm list, update_defs: thm list,
+    fold_congs: thm list, unfold_congs: thm list, splits: thm list, defs: thm list,
+    surjective: thm, equality: thm, induct_scheme: thm, induct: thm, cases_scheme: thm,
+    cases: thm, simps: thm list, iffs: thm list}
+  val get_record: theory -> string -> record_info option
+  val the_record: theory -> string -> record_info
   val record_simproc: simproc
   val record_eq_simproc: simproc
   val record_upd_simproc: simproc
@@ -337,24 +349,55 @@
   parent: (typ list * string) option,
   fields: (string * typ) list,
   extension: (string * typ list),
+
+  ext_induct: thm,
+  ext_inject: thm,
+  ext_surjective: thm,
+  ext_split: thm,
+  ext_def: thm,
+
+  select_convs: thm list,
+  update_convs: thm list,
+  select_defs: thm list,
+  update_defs: thm list,
+  fold_congs: thm list,
+  unfold_congs: thm list,
+  splits: thm list,
+  defs: thm list,
+
+  surjective: thm,
+  equality: thm,
+  induct_scheme: thm,
   induct: thm,
-  extdef: thm};
-
-fun make_record_info args parent fields extension induct extdef =
+  cases_scheme: thm,
+  cases: thm,
+
+  simps: thm list,
+  iffs: thm list};
+
+fun make_record_info args parent fields extension
+    ext_induct ext_inject ext_surjective ext_split ext_def
+    select_convs update_convs select_defs update_defs fold_congs unfold_congs splits defs
+    surjective equality induct_scheme induct cases_scheme cases
+    simps iffs : record_info =
  {args = args, parent = parent, fields = fields, extension = extension,
-  induct = induct, extdef = extdef}: record_info;
-
+  ext_induct = ext_induct, ext_inject = ext_inject, ext_surjective = ext_surjective,
+  ext_split = ext_split, ext_def = ext_def, select_convs = select_convs,
+  update_convs = update_convs, select_defs = select_defs, update_defs = update_defs,
+  fold_congs = fold_congs, unfold_congs = unfold_congs, splits = splits, defs = defs,
+  surjective = surjective, equality = equality, induct_scheme = induct_scheme,
+  induct = induct, cases_scheme = cases_scheme, cases = cases, simps = simps, iffs = iffs};
 
 type parent_info =
  {name: string,
   fields: (string * typ) list,
   extension: (string * typ list),
-  induct: thm,
-  extdef: thm};
-
-fun make_parent_info name fields extension induct extdef =
+  induct_scheme: thm,
+  ext_def: thm};
+
+fun make_parent_info name fields extension ext_def induct_scheme : parent_info =
  {name = name, fields = fields, extension = extension,
-  induct = induct, extdef = extdef}: parent_info;
+  ext_def = ext_def, induct_scheme = induct_scheme};
 
 
 (* theory data *)
@@ -456,6 +499,11 @@
 
 val get_record = Symtab.lookup o #records o Records_Data.get;
 
+fun the_record thy name =
+  (case get_record thy name of
+    SOME info => info
+  | NONE => error ("Unknown record type " ^ quote name));
+
 fun put_record name info thy =
   let
     val {records, sel_upd, equalities, extinjects, extsplit, splits, extfields, fieldext} =
@@ -625,7 +673,7 @@
       let
         fun err msg = error (msg ^ " parent record " ^ quote name);
 
-        val {args, parent, fields, extension, induct, extdef} =
+        val {args, parent, fields, extension, induct_scheme, ext_def, ...} =
           (case get_record thy name of SOME info => info | NONE => err "Unknown");
         val _ = if length types <> length args then err "Bad number of arguments for" else ();
 
@@ -641,7 +689,7 @@
         val extension' = apsnd (map subst) extension;
       in
         add_parents thy parent'
-          (make_parent_info name fields' extension' induct extdef :: parents)
+          (make_parent_info name fields' extension' ext_def induct_scheme :: parents)
       end;
 
 
@@ -1783,16 +1831,16 @@
       end;
     val induct = timeit_msg "record extension induct proof:" induct_prf;
 
-    val ([inject', induct', surjective', split_meta'], thm_thy) =
+    val ([induct', inject', surjective', split_meta'], thm_thy) =
       defs_thy
       |> PureThy.add_thms (map (Thm.no_attributes o apfst Binding.name)
-           [("ext_inject", inject),
-            ("ext_induct", induct),
+           [("ext_induct", induct),
+            ("ext_inject", inject),
             ("ext_surjective", surject),
             ("ext_split", split_meta)])
       ||> Code.add_default_eqn ext_def;
 
-  in (thm_thy, extT, induct', inject', split_meta', ext_def) end;
+  in ((extT, induct', inject', surjective', split_meta', ext_def), thm_thy) end;
 
 fun chunks [] [] = []
   | chunks [] xs = [xs]
@@ -1895,7 +1943,7 @@
 
     (* 1st stage: extension_thy *)
 
-    val (extension_thy, extT, ext_induct, ext_inject, ext_split, ext_def) =
+    val ((extT, ext_induct, ext_inject, ext_surjective, ext_split, ext_def), extension_thy) =
       thy
       |> Sign.add_path base_name
       |> extension_definition extN fields alphas_ext zeta moreT more vars;
@@ -1979,7 +2027,7 @@
       [(Binding.suffix_name schemeN b, alphas @ [zeta], rec_schemeT0, Syntax.NoSyn),
         (b, alphas, recT0, Syntax.NoSyn)];
 
-    val ext_defs = ext_def :: map #extdef parents;
+    val ext_defs = ext_def :: map #ext_def parents;
 
     (*Theorems from the iso_tuple intros.
       By unfolding ext_defs from r_rec0 we create a tree of constructor
@@ -2182,13 +2230,13 @@
     val (fold_congs, unfold_congs) =
       timeit_msg "record upd fold/unfold congs:" get_upd_acc_congs;
 
-    val parent_induct = if null parents then [] else [#induct (hd (rev parents))];
+    val parent_induct = Option.map #induct_scheme (try List.last parents);
 
     fun induct_scheme_prf () =
       prove_standard [] induct_scheme_prop
         (fn _ =>
           EVERY
-           [if null parent_induct then all_tac else try_param_tac rN (hd parent_induct) 1,
+           [case parent_induct of NONE => all_tac | SOME ind => try_param_tac rN ind 1,
             try_param_tac rN ext_induct 1,
             asm_simp_tac HOL_basic_ss 1]);
     val induct_scheme = timeit_msg "record induct_scheme proof:" induct_scheme_prf;
@@ -2311,7 +2359,7 @@
 
     val ((([sel_convs', upd_convs', sel_defs', upd_defs',
             fold_congs', unfold_congs',
-          [split_meta', split_object', split_ex'], derived_defs'],
+          splits' as [split_meta', split_object', split_ex'], derived_defs'],
           [surjective', equality']),
           [induct_scheme', induct', cases_scheme', cases']), thms_thy) =
       defs_thy
@@ -2337,12 +2385,22 @@
     val sel_upd_defs = sel_defs' @ upd_defs';
     val iffs = [ext_inject]
     val depth = parent_len + 1;
-    val final_thy =
+
+    val ([simps', iffs'], thms_thy') =
       thms_thy
-      |> (snd oo PureThy.add_thmss)
+      |> PureThy.add_thmss
           [((Binding.name "simps", sel_upd_simps), [Simplifier.simp_add]),
-           ((Binding.name "iffs", iffs), [iff_add])]
-      |> put_record name (make_record_info args parent fields extension induct_scheme' ext_def)
+           ((Binding.name "iffs", iffs), [iff_add])];
+
+    val info =
+      make_record_info args parent fields extension
+        ext_induct ext_inject ext_surjective ext_split ext_def
+        sel_convs' upd_convs' sel_defs' upd_defs' fold_congs' unfold_congs' splits' derived_defs'
+        surjective' equality' induct_scheme' induct' cases_scheme' cases' simps' iffs';
+
+    val final_thy =
+      thms_thy'
+      |> put_record name info
       |> put_sel_upd names full_moreN depth sel_upd_simps sel_upd_defs (fold_congs', unfold_congs')
       |> add_record_equalities extension_id equality'
       |> add_extinjects ext_inject