src/HOL/Tools/record_package.ML
changeset 5698 2b5d9bdec5af
parent 5290 b755c7240348
child 5707 b0e631634b5a
     1.1 --- a/src/HOL/Tools/record_package.ML	Tue Oct 20 16:37:02 1998 +0200
     1.2 +++ b/src/HOL/Tools/record_package.ML	Tue Oct 20 16:38:37 1998 +0200
     1.3 @@ -3,15 +3,18 @@
     1.4      Author:     Wolfgang Naraschewski and Markus Wenzel, TU Muenchen
     1.5  
     1.6  Extensible records with structural subtyping in HOL.
     1.7 +*)
     1.8  
     1.9 -TODO:
    1.10 -  - field types: datatype;
    1.11 -  - operations and theorems: split, split_all/ex, ...;
    1.12 -  - field constructor: more specific type for snd component (x_more etc. classes);
    1.13 -*)
    1.14 +signature BASIC_RECORD_PACKAGE =
    1.15 +sig
    1.16 +  val record_split_tac: int -> tactic
    1.17 +  val record_split_wrapper: string * wrapper
    1.18 +end;
    1.19  
    1.20  signature RECORD_PACKAGE =
    1.21  sig
    1.22 +  include BASIC_RECORD_PACKAGE
    1.23 +  val quiet_mode: bool ref
    1.24    val moreS: sort
    1.25    val mk_fieldT: (string * typ) * typ -> typ
    1.26    val dest_fieldT: typ -> (string * typ) * typ
    1.27 @@ -37,9 +40,15 @@
    1.28  
    1.29  (*** utilities ***)
    1.30  
    1.31 +(* messages *)
    1.32 +
    1.33 +val quiet_mode = ref false;
    1.34 +fun message s = if ! quiet_mode then () else writeln s;
    1.35 +
    1.36 +
    1.37  (* definitions and equations *)
    1.38  
    1.39 -infix 0 :== === ;
    1.40 +infix 0 :== ===;
    1.41  
    1.42  val (op :==) = Logic.mk_defpair;
    1.43  val (op ===) = HOLogic.mk_Trueprop o HOLogic.mk_eq;
    1.44 @@ -49,7 +58,7 @@
    1.45  
    1.46  (* proof by simplification *)
    1.47  
    1.48 -fun prove_simp thy simps =
    1.49 +fun prove_simp thy tacs simps =
    1.50    let
    1.51      val sign = Theory.sign_of thy;
    1.52      val ss = Simplifier.addsimps (HOL_basic_ss, map Attribute.thm_of simps);
    1.53 @@ -57,7 +66,7 @@
    1.54      fun prove goal =
    1.55        Attribute.tthm_of
    1.56          (Goals.prove_goalw_cterm [] (Thm.cterm_of sign goal)
    1.57 -          (K [ALLGOALS (Simplifier.simp_tac ss)])
    1.58 +          (K (tacs @ [ALLGOALS (Simplifier.simp_tac ss)]))
    1.59          handle ERROR => error ("The error(s) above occurred while trying to prove "
    1.60            ^ quote (Sign.string_of_term sign goal)));
    1.61    in prove end;
    1.62 @@ -71,13 +80,30 @@
    1.63  val moreN = "more";
    1.64  val schemeN = "_scheme";
    1.65  val fieldN = "_field";
    1.66 +val raw_fieldN = "_raw_field";
    1.67  val field_typeN = "_field_type";
    1.68 -val fstN = "_fst";
    1.69 -val sndN = "_snd";
    1.70 +val fstN = "_val";
    1.71 +val sndN = "_more";
    1.72  val updateN = "_update";
    1.73  val makeN = "make";
    1.74  val make_schemeN = "make_scheme";
    1.75  
    1.76 +(*see datatype package*)
    1.77 +val caseN = "_case";
    1.78 +
    1.79 +
    1.80 +
    1.81 +(** generic operations **)
    1.82 +
    1.83 +fun fst_fn T U = Abs ("x", T, Abs ("y", U, Bound 1));
    1.84 +fun snd_fn T U = Abs ("x", T, Abs ("y", U, Bound 0));
    1.85 +
    1.86 +fun mk_prod_case name f p =
    1.87 +  let
    1.88 +    val fT as Type ("fun", [A, Type ("fun", [B, C])]) = fastype_of f;
    1.89 +    val pT = fastype_of p;
    1.90 +  in Const (suffix caseN name, fT --> pT --> C) $ f $ p end;
    1.91 +
    1.92  
    1.93  
    1.94  (** tuple operations **)
    1.95 @@ -102,9 +128,12 @@
    1.96  
    1.97  fun mk_fieldC U (c, T) = (suffix fieldN c, T --> U --> mk_fieldT ((c, T), U));
    1.98  
    1.99 -fun mk_field ((c, t), u) =
   1.100 +fun gen_mk_field sfx ((c, t), u) =
   1.101    let val T = fastype_of t and U = fastype_of u
   1.102 -  in Const (suffix fieldN c, [T, U] ---> mk_fieldT ((c, T), U)) $ t $ u end;
   1.103 +  in Const (suffix sfx c, [T, U] ---> mk_fieldT ((c, T), U)) $ t $ u end;
   1.104 +
   1.105 +val mk_field = gen_mk_field fieldN;
   1.106 +val mk_raw_field = gen_mk_field raw_fieldN;
   1.107  
   1.108  
   1.109  (* destructors *)
   1.110 @@ -289,13 +318,17 @@
   1.111  structure RecordsArgs =
   1.112  struct
   1.113    val name = "HOL/records";
   1.114 -  type T = record_info Symtab.table;
   1.115 +  type T =
   1.116 +    record_info Symtab.table *				(*records*)
   1.117 +      (thm Symtab.table * Simplifier.simpset);		(*field split rules*)
   1.118  
   1.119 -  val empty = Symtab.empty;
   1.120 +  val empty = (Symtab.empty, (Symtab.empty, HOL_basic_ss));
   1.121    val prep_ext = I;
   1.122 -  val merge: T * T -> T = Symtab.merge (K true);
   1.123 +  fun merge ((recs1, (sps1, ss1)), (recs2, (sps2, ss2))) =
   1.124 +    (Symtab.merge (K true) (recs1, recs2),
   1.125 +      (Symtab.merge (K true) (sps1, sps2), Simplifier.merge_ss (ss1, ss2)));
   1.126  
   1.127 -  fun print sg tab =
   1.128 +  fun print sg (recs, _) =
   1.129      let
   1.130        val prt_typ = Sign.pretty_typ sg;
   1.131        val ext_const = Sign.cond_extern sg Sign.constK;
   1.132 @@ -311,7 +344,7 @@
   1.133          (Pretty.block [prt_typ (Type (name, map TFree args)), Pretty.str " = "] ::
   1.134            pretty_parent parent @ map pretty_field fields));
   1.135      in
   1.136 -      seq (Pretty.writeln o pretty_record) (Symtab.dest tab)
   1.137 +      seq (Pretty.writeln o pretty_record) (Symtab.dest recs)
   1.138      end;
   1.139  end;
   1.140  
   1.141 @@ -319,12 +352,19 @@
   1.142  val print_records = RecordsData.print;
   1.143  
   1.144  
   1.145 -(* get and put records *)
   1.146 +(* get and put data *)
   1.147  
   1.148 -fun get_record thy name = Symtab.lookup (RecordsData.get thy, name);
   1.149 +fun get_record thy name = Symtab.lookup (#1 (RecordsData.get thy), name);
   1.150  
   1.151  fun put_record name info thy =
   1.152 -  RecordsData.put (Symtab.update ((name, info), RecordsData.get thy)) thy;
   1.153 +  let val (tab, sp) = RecordsData.get thy
   1.154 +  in RecordsData.put (Symtab.update ((name, info), tab), sp) thy end;
   1.155 +
   1.156 +fun add_record_splits splits thy =
   1.157 +  let
   1.158 +    val (tab, (sps, ss)) = RecordsData.get thy;
   1.159 +    val simps = map #2 splits;
   1.160 +  in RecordsData.put (tab, (Symtab.extend (sps, splits), Simplifier.addsimps (ss, simps))) thy end;
   1.161  
   1.162  
   1.163  (* parent records *)
   1.164 @@ -357,84 +397,139 @@
   1.165  
   1.166  
   1.167  
   1.168 +(** record field splitting **)
   1.169 +
   1.170 +fun record_split_tac i st =
   1.171 +  let
   1.172 +    val (_, (sps, ss)) = RecordsData.get_sg (Thm.sign_of_thm st);
   1.173 +
   1.174 +    fun is_fieldT (_, Type (a, [_, _])) = is_some (Symtab.lookup (sps, a))
   1.175 +      | is_fieldT _ = false;
   1.176 +    val params = Logic.strip_params (Library.nth_elem (i - 1, Thm.prems_of st));
   1.177 +  in
   1.178 +    if exists is_fieldT params then Simplifier.full_simp_tac ss i st
   1.179 +    else Seq.empty
   1.180 +  end handle Library.LIST _ => Seq.empty;
   1.181 +
   1.182 +val record_split_wrapper = ("record_split_tac", fn tac => record_split_tac ORELSE' tac);
   1.183 +
   1.184 +
   1.185 +
   1.186  (** internal theory extenders **)
   1.187  
   1.188 +(* field_type_defs *)
   1.189 +
   1.190 +fun field_type_def ((thy, simps), (name, tname, vs, T, U)) =
   1.191 +  let
   1.192 +    val full = Sign.full_name (sign_of thy);
   1.193 +    val (thy', {simps = simps', ...}) =
   1.194 +      thy
   1.195 +      |> setmp DatatypePackage.quiet_mode true
   1.196 +        (DatatypePackage.add_datatype_i true [tname]
   1.197 +          [(vs, tname, Syntax.NoSyn, [(name, [T, U], Syntax.NoSyn)])]);
   1.198 +    val thy'' =
   1.199 +      thy'
   1.200 +      |> setmp AxClass.quiet_mode true
   1.201 +        (AxClass.add_inst_arity_i (full tname, [HOLogic.termS, moreS], moreS) [] [] None);
   1.202 +  in (thy'', simps' @ simps) end;
   1.203 +
   1.204 +fun field_type_defs args thy = foldl field_type_def ((thy, []), args);
   1.205 +
   1.206 +
   1.207  (* field_definitions *)
   1.208  
   1.209 -(*theorems from Prod.thy*)
   1.210 -val prod_convs = map Attribute.tthm_of [fst_conv, snd_conv];
   1.211 -
   1.212 -
   1.213  fun field_definitions fields names zeta moreT more vars named_vars thy =
   1.214    let
   1.215 +    val sign = Theory.sign_of thy;
   1.216      val base = Sign.base_name;
   1.217 +    val full_path = Sign.full_name_path sign;
   1.218  
   1.219  
   1.220      (* prepare declarations and definitions *)
   1.221  
   1.222      (*field types*)
   1.223      fun mk_fieldT_spec c =
   1.224 -      (suffix field_typeN c, ["'a", zeta],
   1.225 -        HOLogic.mk_prodT (TFree ("'a", HOLogic.termS), moreT), Syntax.NoSyn);
   1.226 +      (suffix raw_fieldN c, suffix field_typeN c,
   1.227 +        ["'a", zeta], TFree ("'a", HOLogic.termS), moreT);
   1.228      val fieldT_specs = map (mk_fieldT_spec o base) names;
   1.229  
   1.230 -    (*field declarations*)
   1.231 +    (*field constructors*)
   1.232      val field_decls = map (mk_fieldC moreT) fields;
   1.233 -    val dest_decls = map (mk_fstC moreT) fields @ map (mk_sndC moreT) fields;
   1.234  
   1.235 -    (*field constructors*)
   1.236      fun mk_field_spec (c, v) =
   1.237 -      mk_field ((c, v), more) :== HOLogic.mk_prod (v, more);
   1.238 +      mk_field ((c, v), more) :== mk_raw_field ((c, v), more);
   1.239      val field_specs = map mk_field_spec named_vars;
   1.240  
   1.241      (*field destructors*)
   1.242 -    fun mk_dest_spec dest dest' (c, T) =
   1.243 -      let
   1.244 -        val p = Free ("p", mk_fieldT ((c, T), moreT));
   1.245 -        val p' = Free ("p", HOLogic.mk_prodT (T, moreT));
   1.246 -          (*note: field types are just abbreviations*)
   1.247 -      in dest p :== dest' p' end;
   1.248 +    val dest_decls = map (mk_fstC moreT) fields @ map (mk_sndC moreT) fields;
   1.249 +
   1.250 +    fun mk_dest_spec dest f (c, T) =
   1.251 +      let val p = Free ("p", mk_fieldT ((c, T), moreT));
   1.252 +      in dest p :== mk_prod_case (suffix field_typeN c) (f T moreT) p end;
   1.253      val dest_specs =
   1.254 -      map (mk_dest_spec mk_fst HOLogic.mk_fst) fields @
   1.255 -      map (mk_dest_spec mk_snd HOLogic.mk_snd) fields;
   1.256 +      map (mk_dest_spec mk_fst fst_fn) fields @
   1.257 +      map (mk_dest_spec mk_snd snd_fn) fields;
   1.258  
   1.259  
   1.260      (* prepare theorems *)
   1.261  
   1.262 +    (*destructor conversions*)
   1.263      fun mk_dest_prop dest dest' (c, v) =
   1.264        dest (mk_field ((c, v), more)) === dest' (v, more);
   1.265      val dest_props =
   1.266        map (mk_dest_prop mk_fst fst) named_vars @
   1.267        map (mk_dest_prop mk_snd snd) named_vars;
   1.268  
   1.269 +    (*surjective pairing*)
   1.270 +    fun mk_surj_prop (c, T) =
   1.271 +      let val p = Free ("p", mk_fieldT ((c, T), moreT));
   1.272 +      in p === mk_field ((c, mk_fst p), mk_snd p) end;
   1.273 +    val surj_props = map mk_surj_prop fields;
   1.274  
   1.275 -    (* 1st stage: defs_thy *)
   1.276 +
   1.277 +    (* 1st stage: types_thy *)
   1.278 +
   1.279 +    val (types_thy, simps) =
   1.280 +      thy
   1.281 +      |> field_type_defs fieldT_specs;
   1.282 +
   1.283 +    val datatype_simps = map Attribute.tthm_of simps;
   1.284 +
   1.285 +
   1.286 +    (* 2nd stage: defs_thy *)
   1.287  
   1.288      val defs_thy =
   1.289 -      thy
   1.290 -      |> Theory.add_tyabbrs_i fieldT_specs
   1.291 -      |> (Theory.add_consts_i o map (Syntax.no_syn o apfst base))
   1.292 -        (field_decls @ dest_decls)
   1.293 -      |> (PureThy.add_defs_i o map (fn x => (x, [Attribute.tag_internal])))
   1.294 -        (field_specs @ dest_specs);
   1.295 +      types_thy
   1.296 +       |> (Theory.add_consts_i o map (Syntax.no_syn o apfst base))
   1.297 +         (field_decls @ dest_decls)
   1.298 +       |> (PureThy.add_defs_i o map (fn x => (x, [Attribute.tag_internal])))
   1.299 +         (field_specs @ dest_specs);
   1.300  
   1.301      val field_defs = get_defs defs_thy field_specs;
   1.302      val dest_defs = get_defs defs_thy dest_specs;
   1.303  
   1.304  
   1.305 -    (* 2nd stage: thms_thy *)
   1.306 +    (* 3rd stage: thms_thy *)
   1.307 +
   1.308 +    val prove = prove_simp defs_thy;
   1.309  
   1.310 -    val dest_convs =
   1.311 -      map (prove_simp defs_thy (field_defs @ dest_defs @ prod_convs)) dest_props;
   1.312 +    val dest_convs = map (prove [] (field_defs @ dest_defs @ datatype_simps)) dest_props;
   1.313 +    val surj_pairs = map (prove [DatatypePackage.induct_tac "p" 1]
   1.314 +      (map (apfst Thm.symmetric) field_defs @ dest_convs)) surj_props;
   1.315 +
   1.316 +    fun mk_split th = SplitPairedAll.rule (standard (th RS eq_reflection));
   1.317 +    val splits = map (fn (th, _) => Attribute.tthm_of (mk_split th)) surj_pairs;
   1.318  
   1.319      val thms_thy =
   1.320        defs_thy
   1.321        |> (PureThy.add_tthmss o map Attribute.none)
   1.322          [("field_defs", field_defs),
   1.323            ("dest_defs", dest_defs),
   1.324 -          ("dest_convs", dest_convs)];
   1.325 +          ("dest_convs", dest_convs),
   1.326 +          ("surj_pairs", surj_pairs),
   1.327 +          ("splits", splits)];
   1.328  
   1.329 -  in (thms_thy, dest_convs) end;
   1.330 +  in (thms_thy, dest_convs, splits) end;
   1.331  
   1.332  
   1.333  (* record_definition *)
   1.334 @@ -493,9 +588,8 @@
   1.335  
   1.336      (* prepare print translation functions *)
   1.337  
   1.338 -    val accesses = distinct (flat (map NameSpace.accesses (full_moreN :: names)));
   1.339 -    val (_, _, tr'_names, _) = Syntax.trfun_names (Theory.syn_of thy);
   1.340 -    val field_tr's = filter_out (fn (c, _) => c mem tr'_names) (print_translation accesses);
   1.341 +    val field_tr's =
   1.342 +      print_translation (distinct (flat (map NameSpace.accesses (full_moreN :: names))));
   1.343  
   1.344  
   1.345      (* prepare declarations *)
   1.346 @@ -561,11 +655,13 @@
   1.347  
   1.348      (* 1st stage: fields_thy *)
   1.349  
   1.350 -    val (fields_thy, field_simps) =
   1.351 +    val (fields_thy, field_simps, splits) =
   1.352        thy
   1.353        |> Theory.add_path bname
   1.354        |> field_definitions fields names zeta moreT more vars named_vars;
   1.355  
   1.356 +    val field_splits = map2 (fn (c, (th, _)) => (suffix field_typeN c, th)) (names, splits);
   1.357 +
   1.358  
   1.359      (* 2nd stage: defs_thy *)
   1.360  
   1.361 @@ -589,7 +685,7 @@
   1.362      (* 3rd stage: thms_thy *)
   1.363  
   1.364      val parent_simps = flat (map #simps parents);
   1.365 -    val prove = prove_simp defs_thy;
   1.366 +    val prove = prove_simp defs_thy [];
   1.367  
   1.368      val sel_convs = map (prove (parent_simps @ sel_defs @ field_simps)) sel_props;
   1.369      val update_convs = map (prove (parent_simps @ update_defs @ sel_convs)) update_props;
   1.370 @@ -612,6 +708,7 @@
   1.371      val final_thy =
   1.372        thms_thy
   1.373        |> put_record name {args = args, parent = parent, fields = fields, simps = simps}
   1.374 +      |> add_record_splits field_splits
   1.375        |> Theory.parent_path;
   1.376  
   1.377    in final_thy end;
   1.378 @@ -649,7 +746,7 @@
   1.379    let
   1.380      val _ = Theory.requires thy "Record" "record definitions";
   1.381      val sign = Theory.sign_of thy;
   1.382 -    val _ = writeln ("Defining record " ^ quote bname ^ " ...");
   1.383 +    val _ = message ("Defining record " ^ quote bname ^ " ...");
   1.384  
   1.385  
   1.386      (* parents *)
   1.387 @@ -732,9 +829,18 @@
   1.388  
   1.389  (** setup theory **)
   1.390  
   1.391 +fun add_wrapper wrapper thy =
   1.392 +  let val r = claset_ref_of thy
   1.393 +  in r := ! r addSWrapper wrapper; thy end;
   1.394 +
   1.395  val setup =
   1.396   [RecordsData.init,
   1.397 -  Theory.add_trfuns ([], parse_translation, [], [])];
   1.398 +  Theory.add_trfuns ([], parse_translation, [], []),
   1.399 +  add_wrapper record_split_wrapper];
   1.400  
   1.401  
   1.402  end;
   1.403 +
   1.404 +
   1.405 +structure BasicRecordPackage: BASIC_RECORD_PACKAGE = RecordPackage;
   1.406 +open BasicRecordPackage;