--- a/src/HOL/Tools/record_package.ML Tue Oct 20 16:37:02 1998 +0200
+++ b/src/HOL/Tools/record_package.ML Tue Oct 20 16:38:37 1998 +0200
@@ -3,15 +3,18 @@
Author: Wolfgang Naraschewski and Markus Wenzel, TU Muenchen
Extensible records with structural subtyping in HOL.
+*)
-TODO:
- - field types: datatype;
- - operations and theorems: split, split_all/ex, ...;
- - field constructor: more specific type for snd component (x_more etc. classes);
-*)
+signature BASIC_RECORD_PACKAGE =
+sig
+ val record_split_tac: int -> tactic
+ val record_split_wrapper: string * wrapper
+end;
signature RECORD_PACKAGE =
sig
+ include BASIC_RECORD_PACKAGE
+ val quiet_mode: bool ref
val moreS: sort
val mk_fieldT: (string * typ) * typ -> typ
val dest_fieldT: typ -> (string * typ) * typ
@@ -37,9 +40,15 @@
(*** utilities ***)
+(* messages *)
+
+val quiet_mode = ref false;
+fun message s = if ! quiet_mode then () else writeln s;
+
+
(* definitions and equations *)
-infix 0 :== === ;
+infix 0 :== ===;
val (op :==) = Logic.mk_defpair;
val (op ===) = HOLogic.mk_Trueprop o HOLogic.mk_eq;
@@ -49,7 +58,7 @@
(* proof by simplification *)
-fun prove_simp thy simps =
+fun prove_simp thy tacs simps =
let
val sign = Theory.sign_of thy;
val ss = Simplifier.addsimps (HOL_basic_ss, map Attribute.thm_of simps);
@@ -57,7 +66,7 @@
fun prove goal =
Attribute.tthm_of
(Goals.prove_goalw_cterm [] (Thm.cterm_of sign goal)
- (K [ALLGOALS (Simplifier.simp_tac ss)])
+ (K (tacs @ [ALLGOALS (Simplifier.simp_tac ss)]))
handle ERROR => error ("The error(s) above occurred while trying to prove "
^ quote (Sign.string_of_term sign goal)));
in prove end;
@@ -71,13 +80,30 @@
val moreN = "more";
val schemeN = "_scheme";
val fieldN = "_field";
+val raw_fieldN = "_raw_field";
val field_typeN = "_field_type";
-val fstN = "_fst";
-val sndN = "_snd";
+val fstN = "_val";
+val sndN = "_more";
val updateN = "_update";
val makeN = "make";
val make_schemeN = "make_scheme";
+(*see datatype package*)
+val caseN = "_case";
+
+
+
+(** generic operations **)
+
+fun fst_fn T U = Abs ("x", T, Abs ("y", U, Bound 1));
+fun snd_fn T U = Abs ("x", T, Abs ("y", U, Bound 0));
+
+fun mk_prod_case name f p =
+ let
+ val fT as Type ("fun", [A, Type ("fun", [B, C])]) = fastype_of f;
+ val pT = fastype_of p;
+ in Const (suffix caseN name, fT --> pT --> C) $ f $ p end;
+
(** tuple operations **)
@@ -102,9 +128,12 @@
fun mk_fieldC U (c, T) = (suffix fieldN c, T --> U --> mk_fieldT ((c, T), U));
-fun mk_field ((c, t), u) =
+fun gen_mk_field sfx ((c, t), u) =
let val T = fastype_of t and U = fastype_of u
- in Const (suffix fieldN c, [T, U] ---> mk_fieldT ((c, T), U)) $ t $ u end;
+ in Const (suffix sfx c, [T, U] ---> mk_fieldT ((c, T), U)) $ t $ u end;
+
+val mk_field = gen_mk_field fieldN;
+val mk_raw_field = gen_mk_field raw_fieldN;
(* destructors *)
@@ -289,13 +318,17 @@
structure RecordsArgs =
struct
val name = "HOL/records";
- type T = record_info Symtab.table;
+ type T =
+ record_info Symtab.table * (*records*)
+ (thm Symtab.table * Simplifier.simpset); (*field split rules*)
- val empty = Symtab.empty;
+ val empty = (Symtab.empty, (Symtab.empty, HOL_basic_ss));
val prep_ext = I;
- val merge: T * T -> T = Symtab.merge (K true);
+ fun merge ((recs1, (sps1, ss1)), (recs2, (sps2, ss2))) =
+ (Symtab.merge (K true) (recs1, recs2),
+ (Symtab.merge (K true) (sps1, sps2), Simplifier.merge_ss (ss1, ss2)));
- fun print sg tab =
+ fun print sg (recs, _) =
let
val prt_typ = Sign.pretty_typ sg;
val ext_const = Sign.cond_extern sg Sign.constK;
@@ -311,7 +344,7 @@
(Pretty.block [prt_typ (Type (name, map TFree args)), Pretty.str " = "] ::
pretty_parent parent @ map pretty_field fields));
in
- seq (Pretty.writeln o pretty_record) (Symtab.dest tab)
+ seq (Pretty.writeln o pretty_record) (Symtab.dest recs)
end;
end;
@@ -319,12 +352,19 @@
val print_records = RecordsData.print;
-(* get and put records *)
+(* get and put data *)
-fun get_record thy name = Symtab.lookup (RecordsData.get thy, name);
+fun get_record thy name = Symtab.lookup (#1 (RecordsData.get thy), name);
fun put_record name info thy =
- RecordsData.put (Symtab.update ((name, info), RecordsData.get thy)) thy;
+ let val (tab, sp) = RecordsData.get thy
+ in RecordsData.put (Symtab.update ((name, info), tab), sp) thy end;
+
+fun add_record_splits splits thy =
+ let
+ val (tab, (sps, ss)) = RecordsData.get thy;
+ val simps = map #2 splits;
+ in RecordsData.put (tab, (Symtab.extend (sps, splits), Simplifier.addsimps (ss, simps))) thy end;
(* parent records *)
@@ -357,84 +397,139 @@
+(** record field splitting **)
+
+fun record_split_tac i st =
+ let
+ val (_, (sps, ss)) = RecordsData.get_sg (Thm.sign_of_thm st);
+
+ fun is_fieldT (_, Type (a, [_, _])) = is_some (Symtab.lookup (sps, a))
+ | is_fieldT _ = false;
+ val params = Logic.strip_params (Library.nth_elem (i - 1, Thm.prems_of st));
+ in
+ if exists is_fieldT params then Simplifier.full_simp_tac ss i st
+ else Seq.empty
+ end handle Library.LIST _ => Seq.empty;
+
+val record_split_wrapper = ("record_split_tac", fn tac => record_split_tac ORELSE' tac);
+
+
+
(** internal theory extenders **)
+(* field_type_defs *)
+
+fun field_type_def ((thy, simps), (name, tname, vs, T, U)) =
+ let
+ val full = Sign.full_name (sign_of thy);
+ val (thy', {simps = simps', ...}) =
+ thy
+ |> setmp DatatypePackage.quiet_mode true
+ (DatatypePackage.add_datatype_i true [tname]
+ [(vs, tname, Syntax.NoSyn, [(name, [T, U], Syntax.NoSyn)])]);
+ val thy'' =
+ thy'
+ |> setmp AxClass.quiet_mode true
+ (AxClass.add_inst_arity_i (full tname, [HOLogic.termS, moreS], moreS) [] [] None);
+ in (thy'', simps' @ simps) end;
+
+fun field_type_defs args thy = foldl field_type_def ((thy, []), args);
+
+
(* field_definitions *)
-(*theorems from Prod.thy*)
-val prod_convs = map Attribute.tthm_of [fst_conv, snd_conv];
-
-
fun field_definitions fields names zeta moreT more vars named_vars thy =
let
+ val sign = Theory.sign_of thy;
val base = Sign.base_name;
+ val full_path = Sign.full_name_path sign;
(* prepare declarations and definitions *)
(*field types*)
fun mk_fieldT_spec c =
- (suffix field_typeN c, ["'a", zeta],
- HOLogic.mk_prodT (TFree ("'a", HOLogic.termS), moreT), Syntax.NoSyn);
+ (suffix raw_fieldN c, suffix field_typeN c,
+ ["'a", zeta], TFree ("'a", HOLogic.termS), moreT);
val fieldT_specs = map (mk_fieldT_spec o base) names;
- (*field declarations*)
+ (*field constructors*)
val field_decls = map (mk_fieldC moreT) fields;
- val dest_decls = map (mk_fstC moreT) fields @ map (mk_sndC moreT) fields;
- (*field constructors*)
fun mk_field_spec (c, v) =
- mk_field ((c, v), more) :== HOLogic.mk_prod (v, more);
+ mk_field ((c, v), more) :== mk_raw_field ((c, v), more);
val field_specs = map mk_field_spec named_vars;
(*field destructors*)
- fun mk_dest_spec dest dest' (c, T) =
- let
- val p = Free ("p", mk_fieldT ((c, T), moreT));
- val p' = Free ("p", HOLogic.mk_prodT (T, moreT));
- (*note: field types are just abbreviations*)
- in dest p :== dest' p' end;
+ val dest_decls = map (mk_fstC moreT) fields @ map (mk_sndC moreT) fields;
+
+ fun mk_dest_spec dest f (c, T) =
+ let val p = Free ("p", mk_fieldT ((c, T), moreT));
+ in dest p :== mk_prod_case (suffix field_typeN c) (f T moreT) p end;
val dest_specs =
- map (mk_dest_spec mk_fst HOLogic.mk_fst) fields @
- map (mk_dest_spec mk_snd HOLogic.mk_snd) fields;
+ map (mk_dest_spec mk_fst fst_fn) fields @
+ map (mk_dest_spec mk_snd snd_fn) fields;
(* prepare theorems *)
+ (*destructor conversions*)
fun mk_dest_prop dest dest' (c, v) =
dest (mk_field ((c, v), more)) === dest' (v, more);
val dest_props =
map (mk_dest_prop mk_fst fst) named_vars @
map (mk_dest_prop mk_snd snd) named_vars;
+ (*surjective pairing*)
+ fun mk_surj_prop (c, T) =
+ let val p = Free ("p", mk_fieldT ((c, T), moreT));
+ in p === mk_field ((c, mk_fst p), mk_snd p) end;
+ val surj_props = map mk_surj_prop fields;
- (* 1st stage: defs_thy *)
+
+ (* 1st stage: types_thy *)
+
+ val (types_thy, simps) =
+ thy
+ |> field_type_defs fieldT_specs;
+
+ val datatype_simps = map Attribute.tthm_of simps;
+
+
+ (* 2nd stage: defs_thy *)
val defs_thy =
- thy
- |> Theory.add_tyabbrs_i fieldT_specs
- |> (Theory.add_consts_i o map (Syntax.no_syn o apfst base))
- (field_decls @ dest_decls)
- |> (PureThy.add_defs_i o map (fn x => (x, [Attribute.tag_internal])))
- (field_specs @ dest_specs);
+ types_thy
+ |> (Theory.add_consts_i o map (Syntax.no_syn o apfst base))
+ (field_decls @ dest_decls)
+ |> (PureThy.add_defs_i o map (fn x => (x, [Attribute.tag_internal])))
+ (field_specs @ dest_specs);
val field_defs = get_defs defs_thy field_specs;
val dest_defs = get_defs defs_thy dest_specs;
- (* 2nd stage: thms_thy *)
+ (* 3rd stage: thms_thy *)
+
+ val prove = prove_simp defs_thy;
- val dest_convs =
- map (prove_simp defs_thy (field_defs @ dest_defs @ prod_convs)) dest_props;
+ val dest_convs = map (prove [] (field_defs @ dest_defs @ datatype_simps)) dest_props;
+ val surj_pairs = map (prove [DatatypePackage.induct_tac "p" 1]
+ (map (apfst Thm.symmetric) field_defs @ dest_convs)) surj_props;
+
+ fun mk_split th = SplitPairedAll.rule (standard (th RS eq_reflection));
+ val splits = map (fn (th, _) => Attribute.tthm_of (mk_split th)) surj_pairs;
val thms_thy =
defs_thy
|> (PureThy.add_tthmss o map Attribute.none)
[("field_defs", field_defs),
("dest_defs", dest_defs),
- ("dest_convs", dest_convs)];
+ ("dest_convs", dest_convs),
+ ("surj_pairs", surj_pairs),
+ ("splits", splits)];
- in (thms_thy, dest_convs) end;
+ in (thms_thy, dest_convs, splits) end;
(* record_definition *)
@@ -493,9 +588,8 @@
(* prepare print translation functions *)
- val accesses = distinct (flat (map NameSpace.accesses (full_moreN :: names)));
- val (_, _, tr'_names, _) = Syntax.trfun_names (Theory.syn_of thy);
- val field_tr's = filter_out (fn (c, _) => c mem tr'_names) (print_translation accesses);
+ val field_tr's =
+ print_translation (distinct (flat (map NameSpace.accesses (full_moreN :: names))));
(* prepare declarations *)
@@ -561,11 +655,13 @@
(* 1st stage: fields_thy *)
- val (fields_thy, field_simps) =
+ val (fields_thy, field_simps, splits) =
thy
|> Theory.add_path bname
|> field_definitions fields names zeta moreT more vars named_vars;
+ val field_splits = map2 (fn (c, (th, _)) => (suffix field_typeN c, th)) (names, splits);
+
(* 2nd stage: defs_thy *)
@@ -589,7 +685,7 @@
(* 3rd stage: thms_thy *)
val parent_simps = flat (map #simps parents);
- val prove = prove_simp defs_thy;
+ val prove = prove_simp defs_thy [];
val sel_convs = map (prove (parent_simps @ sel_defs @ field_simps)) sel_props;
val update_convs = map (prove (parent_simps @ update_defs @ sel_convs)) update_props;
@@ -612,6 +708,7 @@
val final_thy =
thms_thy
|> put_record name {args = args, parent = parent, fields = fields, simps = simps}
+ |> add_record_splits field_splits
|> Theory.parent_path;
in final_thy end;
@@ -649,7 +746,7 @@
let
val _ = Theory.requires thy "Record" "record definitions";
val sign = Theory.sign_of thy;
- val _ = writeln ("Defining record " ^ quote bname ^ " ...");
+ val _ = message ("Defining record " ^ quote bname ^ " ...");
(* parents *)
@@ -732,9 +829,18 @@
(** setup theory **)
+fun add_wrapper wrapper thy =
+ let val r = claset_ref_of thy
+ in r := ! r addSWrapper wrapper; thy end;
+
val setup =
[RecordsData.init,
- Theory.add_trfuns ([], parse_translation, [], [])];
+ Theory.add_trfuns ([], parse_translation, [], []),
+ add_wrapper record_split_wrapper];
end;
+
+
+structure BasicRecordPackage: BASIC_RECORD_PACKAGE = RecordPackage;
+open BasicRecordPackage;