field types: datatype;
authorwenzelm
Tue, 20 Oct 1998 16:38:37 +0200
changeset 5698 2b5d9bdec5af
parent 5697 e816c4f1a396
child 5699 5b9a359e083c
field types: datatype; record_split_tac; quiet_mode; renamed fst/snd to val/more; structure BasicRecordPackage;
src/HOL/Tools/record_package.ML
--- a/src/HOL/Tools/record_package.ML	Tue Oct 20 16:37:02 1998 +0200
+++ b/src/HOL/Tools/record_package.ML	Tue Oct 20 16:38:37 1998 +0200
@@ -3,15 +3,18 @@
     Author:     Wolfgang Naraschewski and Markus Wenzel, TU Muenchen
 
 Extensible records with structural subtyping in HOL.
+*)
 
-TODO:
-  - field types: datatype;
-  - operations and theorems: split, split_all/ex, ...;
-  - field constructor: more specific type for snd component (x_more etc. classes);
-*)
+signature BASIC_RECORD_PACKAGE =
+sig
+  val record_split_tac: int -> tactic
+  val record_split_wrapper: string * wrapper
+end;
 
 signature RECORD_PACKAGE =
 sig
+  include BASIC_RECORD_PACKAGE
+  val quiet_mode: bool ref
   val moreS: sort
   val mk_fieldT: (string * typ) * typ -> typ
   val dest_fieldT: typ -> (string * typ) * typ
@@ -37,9 +40,15 @@
 
 (*** utilities ***)
 
+(* messages *)
+
+val quiet_mode = ref false;
+fun message s = if ! quiet_mode then () else writeln s;
+
+
 (* definitions and equations *)
 
-infix 0 :== === ;
+infix 0 :== ===;
 
 val (op :==) = Logic.mk_defpair;
 val (op ===) = HOLogic.mk_Trueprop o HOLogic.mk_eq;
@@ -49,7 +58,7 @@
 
 (* proof by simplification *)
 
-fun prove_simp thy simps =
+fun prove_simp thy tacs simps =
   let
     val sign = Theory.sign_of thy;
     val ss = Simplifier.addsimps (HOL_basic_ss, map Attribute.thm_of simps);
@@ -57,7 +66,7 @@
     fun prove goal =
       Attribute.tthm_of
         (Goals.prove_goalw_cterm [] (Thm.cterm_of sign goal)
-          (K [ALLGOALS (Simplifier.simp_tac ss)])
+          (K (tacs @ [ALLGOALS (Simplifier.simp_tac ss)]))
         handle ERROR => error ("The error(s) above occurred while trying to prove "
           ^ quote (Sign.string_of_term sign goal)));
   in prove end;
@@ -71,13 +80,30 @@
 val moreN = "more";
 val schemeN = "_scheme";
 val fieldN = "_field";
+val raw_fieldN = "_raw_field";
 val field_typeN = "_field_type";
-val fstN = "_fst";
-val sndN = "_snd";
+val fstN = "_val";
+val sndN = "_more";
 val updateN = "_update";
 val makeN = "make";
 val make_schemeN = "make_scheme";
 
+(*see datatype package*)
+val caseN = "_case";
+
+
+
+(** generic operations **)
+
+fun fst_fn T U = Abs ("x", T, Abs ("y", U, Bound 1));
+fun snd_fn T U = Abs ("x", T, Abs ("y", U, Bound 0));
+
+fun mk_prod_case name f p =
+  let
+    val fT as Type ("fun", [A, Type ("fun", [B, C])]) = fastype_of f;
+    val pT = fastype_of p;
+  in Const (suffix caseN name, fT --> pT --> C) $ f $ p end;
+
 
 
 (** tuple operations **)
@@ -102,9 +128,12 @@
 
 fun mk_fieldC U (c, T) = (suffix fieldN c, T --> U --> mk_fieldT ((c, T), U));
 
-fun mk_field ((c, t), u) =
+fun gen_mk_field sfx ((c, t), u) =
   let val T = fastype_of t and U = fastype_of u
-  in Const (suffix fieldN c, [T, U] ---> mk_fieldT ((c, T), U)) $ t $ u end;
+  in Const (suffix sfx c, [T, U] ---> mk_fieldT ((c, T), U)) $ t $ u end;
+
+val mk_field = gen_mk_field fieldN;
+val mk_raw_field = gen_mk_field raw_fieldN;
 
 
 (* destructors *)
@@ -289,13 +318,17 @@
 structure RecordsArgs =
 struct
   val name = "HOL/records";
-  type T = record_info Symtab.table;
+  type T =
+    record_info Symtab.table *				(*records*)
+      (thm Symtab.table * Simplifier.simpset);		(*field split rules*)
 
-  val empty = Symtab.empty;
+  val empty = (Symtab.empty, (Symtab.empty, HOL_basic_ss));
   val prep_ext = I;
-  val merge: T * T -> T = Symtab.merge (K true);
+  fun merge ((recs1, (sps1, ss1)), (recs2, (sps2, ss2))) =
+    (Symtab.merge (K true) (recs1, recs2),
+      (Symtab.merge (K true) (sps1, sps2), Simplifier.merge_ss (ss1, ss2)));
 
-  fun print sg tab =
+  fun print sg (recs, _) =
     let
       val prt_typ = Sign.pretty_typ sg;
       val ext_const = Sign.cond_extern sg Sign.constK;
@@ -311,7 +344,7 @@
         (Pretty.block [prt_typ (Type (name, map TFree args)), Pretty.str " = "] ::
           pretty_parent parent @ map pretty_field fields));
     in
-      seq (Pretty.writeln o pretty_record) (Symtab.dest tab)
+      seq (Pretty.writeln o pretty_record) (Symtab.dest recs)
     end;
 end;
 
@@ -319,12 +352,19 @@
 val print_records = RecordsData.print;
 
 
-(* get and put records *)
+(* get and put data *)
 
-fun get_record thy name = Symtab.lookup (RecordsData.get thy, name);
+fun get_record thy name = Symtab.lookup (#1 (RecordsData.get thy), name);
 
 fun put_record name info thy =
-  RecordsData.put (Symtab.update ((name, info), RecordsData.get thy)) thy;
+  let val (tab, sp) = RecordsData.get thy
+  in RecordsData.put (Symtab.update ((name, info), tab), sp) thy end;
+
+fun add_record_splits splits thy =
+  let
+    val (tab, (sps, ss)) = RecordsData.get thy;
+    val simps = map #2 splits;
+  in RecordsData.put (tab, (Symtab.extend (sps, splits), Simplifier.addsimps (ss, simps))) thy end;
 
 
 (* parent records *)
@@ -357,84 +397,139 @@
 
 
 
+(** record field splitting **)
+
+fun record_split_tac i st =
+  let
+    val (_, (sps, ss)) = RecordsData.get_sg (Thm.sign_of_thm st);
+
+    fun is_fieldT (_, Type (a, [_, _])) = is_some (Symtab.lookup (sps, a))
+      | is_fieldT _ = false;
+    val params = Logic.strip_params (Library.nth_elem (i - 1, Thm.prems_of st));
+  in
+    if exists is_fieldT params then Simplifier.full_simp_tac ss i st
+    else Seq.empty
+  end handle Library.LIST _ => Seq.empty;
+
+val record_split_wrapper = ("record_split_tac", fn tac => record_split_tac ORELSE' tac);
+
+
+
 (** internal theory extenders **)
 
+(* field_type_defs *)
+
+fun field_type_def ((thy, simps), (name, tname, vs, T, U)) =
+  let
+    val full = Sign.full_name (sign_of thy);
+    val (thy', {simps = simps', ...}) =
+      thy
+      |> setmp DatatypePackage.quiet_mode true
+        (DatatypePackage.add_datatype_i true [tname]
+          [(vs, tname, Syntax.NoSyn, [(name, [T, U], Syntax.NoSyn)])]);
+    val thy'' =
+      thy'
+      |> setmp AxClass.quiet_mode true
+        (AxClass.add_inst_arity_i (full tname, [HOLogic.termS, moreS], moreS) [] [] None);
+  in (thy'', simps' @ simps) end;
+
+fun field_type_defs args thy = foldl field_type_def ((thy, []), args);
+
+
 (* field_definitions *)
 
-(*theorems from Prod.thy*)
-val prod_convs = map Attribute.tthm_of [fst_conv, snd_conv];
-
-
 fun field_definitions fields names zeta moreT more vars named_vars thy =
   let
+    val sign = Theory.sign_of thy;
     val base = Sign.base_name;
+    val full_path = Sign.full_name_path sign;
 
 
     (* prepare declarations and definitions *)
 
     (*field types*)
     fun mk_fieldT_spec c =
-      (suffix field_typeN c, ["'a", zeta],
-        HOLogic.mk_prodT (TFree ("'a", HOLogic.termS), moreT), Syntax.NoSyn);
+      (suffix raw_fieldN c, suffix field_typeN c,
+        ["'a", zeta], TFree ("'a", HOLogic.termS), moreT);
     val fieldT_specs = map (mk_fieldT_spec o base) names;
 
-    (*field declarations*)
+    (*field constructors*)
     val field_decls = map (mk_fieldC moreT) fields;
-    val dest_decls = map (mk_fstC moreT) fields @ map (mk_sndC moreT) fields;
 
-    (*field constructors*)
     fun mk_field_spec (c, v) =
-      mk_field ((c, v), more) :== HOLogic.mk_prod (v, more);
+      mk_field ((c, v), more) :== mk_raw_field ((c, v), more);
     val field_specs = map mk_field_spec named_vars;
 
     (*field destructors*)
-    fun mk_dest_spec dest dest' (c, T) =
-      let
-        val p = Free ("p", mk_fieldT ((c, T), moreT));
-        val p' = Free ("p", HOLogic.mk_prodT (T, moreT));
-          (*note: field types are just abbreviations*)
-      in dest p :== dest' p' end;
+    val dest_decls = map (mk_fstC moreT) fields @ map (mk_sndC moreT) fields;
+
+    fun mk_dest_spec dest f (c, T) =
+      let val p = Free ("p", mk_fieldT ((c, T), moreT));
+      in dest p :== mk_prod_case (suffix field_typeN c) (f T moreT) p end;
     val dest_specs =
-      map (mk_dest_spec mk_fst HOLogic.mk_fst) fields @
-      map (mk_dest_spec mk_snd HOLogic.mk_snd) fields;
+      map (mk_dest_spec mk_fst fst_fn) fields @
+      map (mk_dest_spec mk_snd snd_fn) fields;
 
 
     (* prepare theorems *)
 
+    (*destructor conversions*)
     fun mk_dest_prop dest dest' (c, v) =
       dest (mk_field ((c, v), more)) === dest' (v, more);
     val dest_props =
       map (mk_dest_prop mk_fst fst) named_vars @
       map (mk_dest_prop mk_snd snd) named_vars;
 
+    (*surjective pairing*)
+    fun mk_surj_prop (c, T) =
+      let val p = Free ("p", mk_fieldT ((c, T), moreT));
+      in p === mk_field ((c, mk_fst p), mk_snd p) end;
+    val surj_props = map mk_surj_prop fields;
 
-    (* 1st stage: defs_thy *)
+
+    (* 1st stage: types_thy *)
+
+    val (types_thy, simps) =
+      thy
+      |> field_type_defs fieldT_specs;
+
+    val datatype_simps = map Attribute.tthm_of simps;
+
+
+    (* 2nd stage: defs_thy *)
 
     val defs_thy =
-      thy
-      |> Theory.add_tyabbrs_i fieldT_specs
-      |> (Theory.add_consts_i o map (Syntax.no_syn o apfst base))
-        (field_decls @ dest_decls)
-      |> (PureThy.add_defs_i o map (fn x => (x, [Attribute.tag_internal])))
-        (field_specs @ dest_specs);
+      types_thy
+       |> (Theory.add_consts_i o map (Syntax.no_syn o apfst base))
+         (field_decls @ dest_decls)
+       |> (PureThy.add_defs_i o map (fn x => (x, [Attribute.tag_internal])))
+         (field_specs @ dest_specs);
 
     val field_defs = get_defs defs_thy field_specs;
     val dest_defs = get_defs defs_thy dest_specs;
 
 
-    (* 2nd stage: thms_thy *)
+    (* 3rd stage: thms_thy *)
+
+    val prove = prove_simp defs_thy;
 
-    val dest_convs =
-      map (prove_simp defs_thy (field_defs @ dest_defs @ prod_convs)) dest_props;
+    val dest_convs = map (prove [] (field_defs @ dest_defs @ datatype_simps)) dest_props;
+    val surj_pairs = map (prove [DatatypePackage.induct_tac "p" 1]
+      (map (apfst Thm.symmetric) field_defs @ dest_convs)) surj_props;
+
+    fun mk_split th = SplitPairedAll.rule (standard (th RS eq_reflection));
+    val splits = map (fn (th, _) => Attribute.tthm_of (mk_split th)) surj_pairs;
 
     val thms_thy =
       defs_thy
       |> (PureThy.add_tthmss o map Attribute.none)
         [("field_defs", field_defs),
           ("dest_defs", dest_defs),
-          ("dest_convs", dest_convs)];
+          ("dest_convs", dest_convs),
+          ("surj_pairs", surj_pairs),
+          ("splits", splits)];
 
-  in (thms_thy, dest_convs) end;
+  in (thms_thy, dest_convs, splits) end;
 
 
 (* record_definition *)
@@ -493,9 +588,8 @@
 
     (* prepare print translation functions *)
 
-    val accesses = distinct (flat (map NameSpace.accesses (full_moreN :: names)));
-    val (_, _, tr'_names, _) = Syntax.trfun_names (Theory.syn_of thy);
-    val field_tr's = filter_out (fn (c, _) => c mem tr'_names) (print_translation accesses);
+    val field_tr's =
+      print_translation (distinct (flat (map NameSpace.accesses (full_moreN :: names))));
 
 
     (* prepare declarations *)
@@ -561,11 +655,13 @@
 
     (* 1st stage: fields_thy *)
 
-    val (fields_thy, field_simps) =
+    val (fields_thy, field_simps, splits) =
       thy
       |> Theory.add_path bname
       |> field_definitions fields names zeta moreT more vars named_vars;
 
+    val field_splits = map2 (fn (c, (th, _)) => (suffix field_typeN c, th)) (names, splits);
+
 
     (* 2nd stage: defs_thy *)
 
@@ -589,7 +685,7 @@
     (* 3rd stage: thms_thy *)
 
     val parent_simps = flat (map #simps parents);
-    val prove = prove_simp defs_thy;
+    val prove = prove_simp defs_thy [];
 
     val sel_convs = map (prove (parent_simps @ sel_defs @ field_simps)) sel_props;
     val update_convs = map (prove (parent_simps @ update_defs @ sel_convs)) update_props;
@@ -612,6 +708,7 @@
     val final_thy =
       thms_thy
       |> put_record name {args = args, parent = parent, fields = fields, simps = simps}
+      |> add_record_splits field_splits
       |> Theory.parent_path;
 
   in final_thy end;
@@ -649,7 +746,7 @@
   let
     val _ = Theory.requires thy "Record" "record definitions";
     val sign = Theory.sign_of thy;
-    val _ = writeln ("Defining record " ^ quote bname ^ " ...");
+    val _ = message ("Defining record " ^ quote bname ^ " ...");
 
 
     (* parents *)
@@ -732,9 +829,18 @@
 
 (** setup theory **)
 
+fun add_wrapper wrapper thy =
+  let val r = claset_ref_of thy
+  in r := ! r addSWrapper wrapper; thy end;
+
 val setup =
  [RecordsData.init,
-  Theory.add_trfuns ([], parse_translation, [], [])];
+  Theory.add_trfuns ([], parse_translation, [], []),
+  add_wrapper record_split_wrapper];
 
 
 end;
+
+
+structure BasicRecordPackage: BASIC_RECORD_PACKAGE = RecordPackage;
+open BasicRecordPackage;