src/HOL/Tools/record_package.ML
changeset 4867 9be2bf0ce909
child 4890 f0a24bad990a
equal deleted inserted replaced
4866:72a46bd00c8d 4867:9be2bf0ce909
       
     1 (*  Title:      HOL/Tools/record_package.ML
       
     2     ID:         $Id$
       
     3     Author:     Wolfgang Naraschewski and Markus Wenzel, TU Muenchen
       
     4 
       
     5 Extensible records with structural subtyping in HOL.
       
     6 
       
     7 TODO:
       
     8   - record_info: tr' funs;
       
     9   - trfuns for record types;
       
    10   - field types: typedef;
       
    11   - make selector types as general as possible (no!?);
       
    12 *)
       
    13 
       
    14 signature RECORD_PACKAGE =
       
    15 sig
       
    16   val print_records: theory -> unit
       
    17   val add_record: (string list * bstring) -> string option
       
    18     -> (bstring * string) list -> theory -> theory
       
    19   val add_record_i: (string list * bstring) -> (typ list * string) option
       
    20     -> (bstring * typ) list -> theory -> theory
       
    21   val setup: (theory -> theory) list
       
    22 end;
       
    23 
       
    24 structure RecordPackage: RECORD_PACKAGE =
       
    25 struct
       
    26 
       
    27 
       
    28 (*** syntax operations ***)
       
    29 
       
    30 (** names **)
       
    31 
       
    32 (* name components *)
       
    33 
       
    34 val moreN = "more";
       
    35 val schemeN = "_scheme";
       
    36 val fieldN = "_field";
       
    37 val field_typeN = "_field_type";
       
    38 val fstN = "_val";
       
    39 val sndN = "_more";
       
    40 val updateN = "_update";
       
    41 val makeN = "make";
       
    42 val make_schemeN = "make_scheme";
       
    43 
       
    44 
       
    45 (* suffixes *)
       
    46 
       
    47 fun suffix sfx s = s ^ sfx;
       
    48 
       
    49 fun unsuffix sfx s =
       
    50   let
       
    51     val cs = explode s;
       
    52     val prfx_len = size s - size sfx;
       
    53   in
       
    54     if prfx_len >= 0 andalso implode (drop (prfx_len, cs)) = sfx then
       
    55       implode (take (prfx_len, cs))
       
    56     else raise LIST "unsuffix"
       
    57   end;
       
    58 
       
    59 
       
    60 
       
    61 (** tuple operations **)
       
    62 
       
    63 (* more type class *)
       
    64 
       
    65 val moreS = ["more"];
       
    66 
       
    67 
       
    68 (* types *)
       
    69 
       
    70 fun mk_fieldT ((c, T), U) = Type (suffix field_typeN c, [T, U]);
       
    71 
       
    72 fun dest_fieldT (typ as Type (c_field_type, [T, U])) =
       
    73       (case try (unsuffix field_typeN) c_field_type of
       
    74         None => raise TYPE ("dest_fieldT", [typ], [])
       
    75       | Some c => ((c, T), U))
       
    76   | dest_fieldT typ = raise TYPE ("dest_fieldT", [typ], []);
       
    77 
       
    78 
       
    79 (* constructors *)
       
    80 
       
    81 fun mk_fieldC U (c, T) = (suffix fieldN c, T --> U --> mk_fieldT ((c, T), U));
       
    82 
       
    83 fun mk_field ((c, t), u) =
       
    84   let val T = fastype_of t and U = fastype_of u
       
    85   in Const (suffix fieldN c, [T, U] ---> mk_fieldT ((c, T), U)) $ t $ u end;
       
    86 
       
    87 
       
    88 (* destructors *)
       
    89 
       
    90 fun mk_fstC U (c, T) = (suffix fstN c, mk_fieldT ((c, T), U) --> T);
       
    91 fun mk_sndC U (c, T) = (suffix sndN c, mk_fieldT ((c, T), U) --> U);
       
    92 
       
    93 fun dest_field fst_or_snd p =
       
    94   let
       
    95     val pT = fastype_of p;
       
    96     val ((c, T), U) = dest_fieldT pT;
       
    97     val (destN, destT) = if fst_or_snd then (fstN, T) else (sndN, U);
       
    98   in Const (suffix destN c, pT --> destT) $ p end;
       
    99 
       
   100 val mk_fst = dest_field true;
       
   101 val mk_snd = dest_field false;
       
   102 
       
   103 
       
   104 
       
   105 (** record operations **)
       
   106 
       
   107 (* types *)
       
   108 
       
   109 val mk_recordT = foldr mk_fieldT;
       
   110 
       
   111 fun dest_recordT T =
       
   112   (case try dest_fieldT T of
       
   113     None => ([], T)
       
   114   | Some (c_T, U) => apfst (cons c_T) (dest_recordT U));
       
   115 
       
   116 fun find_fieldT c rT =
       
   117   (case assoc (fst (dest_recordT rT), c) of
       
   118     None => raise TYPE ("find_field: " ^ c, [rT], [])
       
   119   | Some T => T);
       
   120 
       
   121 
       
   122 (* constructors *)
       
   123 
       
   124 val mk_record = foldr mk_field;
       
   125 
       
   126 
       
   127 (* selectors *)
       
   128 
       
   129 fun mk_selC rT (c, T) = (c, rT --> T);
       
   130 
       
   131 fun mk_sel c r =
       
   132   let val rT = fastype_of r
       
   133   in Const (mk_selC rT (c, find_fieldT c rT)) $ r end;
       
   134 
       
   135 
       
   136 (* updates *)
       
   137 
       
   138 fun mk_updateC rT (c, T) = (suffix updateN c, T --> rT --> rT);
       
   139 
       
   140 fun mk_update c x r =
       
   141   let val rT = fastype_of r
       
   142   in Const (mk_updateC rT (c, find_fieldT c rT)) $ x $ r end;
       
   143 
       
   144 
       
   145 
       
   146 (** concrete syntax for records **)
       
   147 
       
   148 (* parse translations *)
       
   149 
       
   150 fun field_tr (Const ("_field", _) $ Free (name, _) $ arg) =
       
   151       Syntax.const (suffix fieldN name) $ arg
       
   152   | field_tr t = raise TERM ("field_tr", [t]);
       
   153 
       
   154 fun fields_tr (Const ("_fields", _) $ field $ fields) =
       
   155       field_tr field :: fields_tr fields
       
   156   | fields_tr field = [field_tr field];
       
   157 
       
   158 fun record_tr (*"_record"*) [fields] =
       
   159       foldr (op $) (fields_tr fields, HOLogic.unit)
       
   160   | record_tr (*"_record"*) ts = raise TERM ("record_tr", ts);
       
   161 
       
   162 fun record_scheme_tr (*"_record_scheme"*) [fields, more] =
       
   163       foldr (op $) (fields_tr fields, more)
       
   164   | record_scheme_tr (*"_record_scheme"*) ts = raise TERM ("record_scheme_tr", ts);
       
   165 
       
   166 
       
   167 (* print translations *)		(* FIXME tune, activate *)
       
   168 
       
   169 (* FIXME ... :: tms *)
       
   170 fun fields_tr' (tm as Const (name_field, _) $ arg $ more) =
       
   171       (case try (unsuffix fieldN) name_field of
       
   172         Some name =>
       
   173           apfst (cons (Syntax.const "_field" $ Syntax.free name $ arg)) (fields_tr' more)
       
   174       | None => ([], tm))
       
   175   | fields_tr' tm = ([], tm);
       
   176 
       
   177 fun record_tr' tm =
       
   178   let
       
   179     val mk_fields = foldr (fn (field, fields) => Syntax.const "_fields" $ field $ fields);
       
   180     val (fields, more) = fields_tr' tm;
       
   181   in
       
   182     if HOLogic.is_unit more then
       
   183       Syntax.const "_record" $ mk_fields (split_last fields)
       
   184     else Syntax.const "_record_scheme" $ mk_fields (fields, more)
       
   185   end;
       
   186 
       
   187 fun field_tr' name [arg, more] = record_tr' (Syntax.const name $ arg $ more)
       
   188   | field_tr' _ _ = raise Match;
       
   189 
       
   190 
       
   191 
       
   192 (*** extend theory by record definition ***)
       
   193 
       
   194 (** record info **)
       
   195 
       
   196 (* type record_info and parent_info *)
       
   197 
       
   198 type record_info =
       
   199  {args: (string * sort) list,
       
   200   parent: (typ list * string) option,
       
   201   fields: (string * typ) list,
       
   202   simps: tthm list};
       
   203 
       
   204 type parent_info =
       
   205  {name: string,
       
   206   fields: (string * typ) list,
       
   207   simps: tthm list};
       
   208 
       
   209 
       
   210 (* theory data *)
       
   211 
       
   212 val recordsK = "HOL/records";
       
   213 exception Records of record_info Symtab.table;
       
   214 
       
   215 fun print_records thy = Display.print_data thy recordsK;
       
   216 
       
   217 local
       
   218   val empty = Records Symtab.empty;
       
   219 
       
   220   fun prep_ext (x as Records _) = x;
       
   221 
       
   222   fun merge (Records tab1, Records tab2) =
       
   223     Records (Symtab.merge (K true) (tab1, tab2));
       
   224 
       
   225   fun print sg (Records tab) =
       
   226     let
       
   227       val prt_typ = Sign.pretty_typ sg;
       
   228       val ext_const = Sign.cond_extern sg Sign.constK;
       
   229 
       
   230       fun pretty_parent None = []
       
   231         | pretty_parent (Some (Ts, name)) =
       
   232             [Pretty.block [prt_typ (Type (name, Ts)), Pretty.str " +"]];
       
   233 
       
   234       fun pretty_field (c, T) = Pretty.block
       
   235         [Pretty.str (ext_const c), Pretty.str " ::", Pretty.brk 1, Pretty.quote (prt_typ T)];
       
   236 
       
   237       fun pretty_record (name, {args, parent, fields, simps = _}) = Pretty.block (Pretty.fbreaks
       
   238         (Pretty.block [prt_typ (Type (name, map TFree args)), Pretty.str " = "] ::
       
   239           pretty_parent parent @ map pretty_field fields));
       
   240     in
       
   241       seq (Pretty.writeln o pretty_record) (Symtab.dest tab)
       
   242     end;
       
   243 in
       
   244   val record_thy_data = (recordsK, (empty, prep_ext, merge, print));
       
   245 end;
       
   246 
       
   247 
       
   248 (* get and put records *)
       
   249 
       
   250 fun get_records thy =
       
   251   (case Theory.get_data thy recordsK of
       
   252     Records tab => tab
       
   253   | _ => type_error recordsK);
       
   254 
       
   255 fun get_record thy name = Symtab.lookup (get_records thy, name);
       
   256 
       
   257 
       
   258 fun put_records tab thy =
       
   259   Theory.put_data (recordsK, Records tab) thy;
       
   260 
       
   261 fun put_new_record name info thy =
       
   262   thy |> put_records
       
   263     (Symtab.update_new ((name, info), get_records thy)
       
   264       handle Symtab.DUP _ => error ("Duplicate definition of record " ^ quote name));
       
   265 
       
   266 
       
   267 (* parent records *)
       
   268 
       
   269 fun inst_record thy (types, name) =
       
   270   let
       
   271     val sign = Theory.sign_of thy;
       
   272     fun err msg = error (msg ^ " parent record " ^ quote name);
       
   273 
       
   274     val {args, parent, fields, simps} =
       
   275       (case get_record thy name of Some info => info | None => err "Unknown");
       
   276 
       
   277     fun bad_inst ((x, S), T) =
       
   278       if Sign.of_sort sign (T, S) then None else Some x
       
   279     val bads = mapfilter bad_inst (args ~~ types);
       
   280 
       
   281     val inst = map fst args ~~ types;
       
   282     val subst = Term.map_type_tfree (fn (x, _) => the (assoc (inst, x)));
       
   283   in
       
   284     if length types <> length args then
       
   285       err "Bad number of arguments for"
       
   286     else if not (null bads) then
       
   287       err ("Ill-sorted instantiation of " ^ commas bads ^ " in")
       
   288     else (apsome (apfst (map subst)) parent, map (apsnd subst) fields, simps)
       
   289   end;
       
   290 
       
   291 fun add_parents thy (None, parents) = parents
       
   292   | add_parents thy (Some (types, name), parents) =
       
   293       let val (pparent, pfields, psimps) = inst_record thy (types, name)
       
   294       in add_parents thy (pparent, {name = name, fields = pfields, simps = psimps} :: parents) end;
       
   295 
       
   296 
       
   297 
       
   298 (** record theorems **)
       
   299 
       
   300 (* proof by simplification *)
       
   301 
       
   302 fun prove_simp thy opt_ss simps =
       
   303   let val ss = if_none opt_ss HOL_basic_ss addsimps simps in
       
   304     fn goal => Goals.prove_goalw_cterm [] (Thm.cterm_of (sign_of thy) goal)
       
   305       (K [ALLGOALS (Simplifier.simp_tac ss)])
       
   306   end;
       
   307 
       
   308 
       
   309 
       
   310 (** internal theory extender **)
       
   311 
       
   312 (*do the actual record definition, assuming that all arguments are
       
   313   well-formed*)
       
   314 
       
   315 fun record_definition (args, bname) parent (parents: parent_info list) bfields thy =
       
   316   let
       
   317     val sign = Theory.sign_of thy;
       
   318     val full = Sign.full_name_path sign bname;
       
   319 
       
   320 
       
   321     (* input *)
       
   322 
       
   323     val alphas = map fst args;
       
   324     val name = Sign.full_name sign bname;		(* FIXME !? *)
       
   325     val parent_fields = flat (map #fields parents);
       
   326     val fields = map (apfst full) bfields;
       
   327 
       
   328     val all_fields = parent_fields @ fields;
       
   329     val all_types = map snd all_fields;
       
   330 
       
   331 
       
   332     (* term / type components *)
       
   333 
       
   334     val zeta = variant alphas "'z";
       
   335     val moreT = TFree (zeta, moreS);
       
   336 
       
   337     val xs = variantlist (map fst bfields, []);
       
   338     val vars = map2 Free (xs, map snd fields);
       
   339     val more = Free (variant xs moreN, moreT);
       
   340 
       
   341     val rec_schemeT = mk_recordT (all_fields, moreT);
       
   342     val recT = mk_recordT (all_fields, HOLogic.unitT);
       
   343 
       
   344     (* FIXME tune *)
       
   345     val make_schemeT = all_types ---> moreT --> rec_schemeT;
       
   346     val make_scheme = Const (full make_schemeN, make_schemeT);
       
   347     val makeT = all_types ---> recT;
       
   348     val make = Const (full makeN, makeT);
       
   349 
       
   350     val parent_more = funpow (length parent_fields) mk_snd;
       
   351 
       
   352 
       
   353     (* prepare type definitions *)
       
   354 
       
   355     (*field types*)
       
   356     fun mk_fieldT_spec ((c, T), a) =
       
   357       (suffix field_typeN c, [a, zeta],
       
   358         HOLogic.mk_prodT (TFree (a, HOLogic.termS), moreT), Syntax.NoSyn);
       
   359     val fieldT_specs = map2 mk_fieldT_spec (bfields, alphas);
       
   360 
       
   361     (*record types*)
       
   362     val recordT_specs =
       
   363       [(suffix schemeN bname, alphas @ [zeta], rec_schemeT, Syntax.NoSyn),
       
   364         (bname, alphas, recT, Syntax.NoSyn)];
       
   365 
       
   366 
       
   367     (* prepare declarations *)
       
   368 
       
   369     val field_decls = map (mk_fieldC moreT) fields;
       
   370     val dest_decls = map (mk_fstC moreT) fields @ map (mk_sndC moreT) fields;
       
   371     val sel_decls = map (mk_selC rec_schemeT) fields;
       
   372     val update_decls = map (mk_updateC rec_schemeT) fields;
       
   373     val make_decls = [(make_schemeN, make_schemeT), (makeN, makeT)];
       
   374 
       
   375 
       
   376     (* prepare definitions *)
       
   377 
       
   378     (*field constructors*)
       
   379     fun mk_field_spec ((c, _), v) =
       
   380       Logic.mk_defpair (mk_field ((c, v), more), HOLogic.mk_prod (v, more));
       
   381     val field_specs = map2 mk_field_spec (fields, vars);
       
   382 
       
   383     (*field destructors*)
       
   384     fun mk_dest_spec dest dest' (c, T) =
       
   385       let
       
   386         val p = Free ("p",  mk_fieldT ((c, T), moreT));
       
   387         val p' = Free ("p",  HOLogic.mk_prodT (T, moreT));  (*Note: field types are abbreviations*)
       
   388       in Logic.mk_defpair (dest p, dest' p') end;
       
   389     val dest_specs =
       
   390       map (mk_dest_spec mk_fst HOLogic.mk_fst) fields @
       
   391       map (mk_dest_spec mk_snd HOLogic.mk_snd) fields;
       
   392 
       
   393     (*field selectors*)		(* FIXME tune *)
       
   394     fun mk_sel_specs _ [] specs = rev specs
       
   395       | mk_sel_specs prfx ((c, T) :: fs) specs =
       
   396           let
       
   397             val prfx' = prfx @ [(c, T)];
       
   398             val r = Free ("r", mk_recordT (prfx' @ fs, moreT));
       
   399             val spec = Logic.mk_defpair (mk_sel c r, mk_fst (funpow (length prfx) mk_snd r));
       
   400           in mk_sel_specs prfx' fs (spec :: specs) end;
       
   401     val sel_specs = mk_sel_specs parent_fields fields [];
       
   402 
       
   403     (*updates*)
       
   404     val update_specs = [];	(* FIXME *)
       
   405 
       
   406     (*makes*)
       
   407     val make_specs =
       
   408       map Logic.mk_defpair
       
   409         [(list_comb (make_scheme, vars) $ more, mk_record (map fst fields ~~ vars, more)),
       
   410           (list_comb (make, vars), mk_record (map fst fields ~~ vars, HOLogic.unit))];
       
   411 
       
   412 
       
   413     (* 1st stage: defs_thy *)
       
   414 
       
   415     val defs_thy =
       
   416       thy
       
   417       |> Theory.add_path bname
       
   418       |> Theory.add_tyabbrs_i (fieldT_specs @ recordT_specs)
       
   419       |> (Theory.add_consts_i o map (Syntax.no_syn o apfst Sign.base_name))
       
   420         (field_decls @ dest_decls @ sel_decls @ update_decls @ make_decls)
       
   421       |> (PureThy.add_defs_i o map Attribute.none)
       
   422         (field_specs @ dest_specs @ sel_specs @ update_specs @ make_specs);
       
   423 
       
   424     local fun get_defs specs = map (PureThy.get_tthm defs_thy o fst) specs in
       
   425       val make_defs = get_defs make_specs;
       
   426       val field_defs = get_defs field_specs;
       
   427       val sel_defs = get_defs sel_specs;
       
   428       val update_defs = get_defs update_specs;
       
   429     end;
       
   430 
       
   431 
       
   432     (* 2nd stage: thms_thy *)
       
   433 
       
   434     val thms_thy =
       
   435       defs_thy
       
   436       |> (PureThy.add_tthmss o map Attribute.none)
       
   437         [("make_defs", make_defs),
       
   438           ("field_defs", field_defs),
       
   439           ("sel_defs", sel_defs),
       
   440           ("update_defs", update_defs)]
       
   441 (*    |> record_theorems FIXME *)
       
   442 
       
   443 
       
   444     (* 3rd stage: final_thy *)
       
   445 
       
   446     val final_thy =
       
   447       thms_thy
       
   448       |> put_new_record name
       
   449         {args = args, parent = parent, fields = fields, simps = [] (* FIXME *)}
       
   450       |> Theory.parent_path;
       
   451 
       
   452   in final_thy end;
       
   453 
       
   454 
       
   455 
       
   456 (** theory extender interface **)
       
   457 
       
   458 (*do all preparations and error checks here, deferring the real work
       
   459   to record_definition above*)
       
   460 
       
   461 
       
   462 (* prepare arguments *)
       
   463 
       
   464 (*Note: read_raw_typ avoids expanding type abbreviations*)
       
   465 fun read_raw_parent sign s =
       
   466   (case Sign.read_raw_typ (sign, K None) s handle TYPE (msg, _, _) => error msg of
       
   467     Type (name, Ts) => (Ts, name)
       
   468   | _ => error ("Bad parent record specification: " ^ quote s));
       
   469 
       
   470 fun read_typ sign (env, s) =
       
   471   let
       
   472     fun def_type (x, ~1) = assoc (env, x)
       
   473       | def_type _ = None;
       
   474     val T = Type.no_tvars (Sign.read_typ (sign, def_type) s) handle TYPE (msg, _, _) => error msg;
       
   475   in (Term.add_typ_tfrees (T, env), T) end;
       
   476 
       
   477 fun cert_typ sign (env, raw_T) =
       
   478   let val T = Type.no_tvars (Sign.certify_typ sign raw_T) handle TYPE (msg, _, _) => error msg
       
   479   in (Term.add_typ_tfrees (T, env), T) end;
       
   480 
       
   481 
       
   482 (* add_record *)
       
   483 
       
   484 fun gen_add_record prep_typ prep_raw_parent (params, bname) raw_parent raw_fields thy =
       
   485   let
       
   486     val _ = Theory.require thy "Record" "record definitions";
       
   487     val sign = Theory.sign_of thy;
       
   488 
       
   489 
       
   490     (* parents *)
       
   491 
       
   492     fun prep_inst T = snd (cert_typ sign ([], T));
       
   493 
       
   494     val parent = apsome (apfst (map prep_inst) o prep_raw_parent sign) raw_parent
       
   495       handle ERROR => error ("The error(s) above in parent record specification");
       
   496     val parents = add_parents thy (parent, []);
       
   497 
       
   498     val init_env =
       
   499       (case parent of
       
   500         None => []
       
   501       | Some (types, _) => foldr Term.add_typ_tfrees (types, []));
       
   502 
       
   503 
       
   504     (* fields *)
       
   505 
       
   506     fun prep_fields (env, []) = (env, [])
       
   507       | prep_fields (env, (c, raw_T) :: fs) =
       
   508           let
       
   509             val (env', T) = prep_typ sign (env, raw_T) handle ERROR =>
       
   510               error ("The error(s) above occured in field " ^ quote c);
       
   511             val (env'', fs') = prep_fields (env', fs);
       
   512       in (env'', (c, T) :: fs') end;
       
   513 
       
   514     val (envir, bfields) = prep_fields (init_env, raw_fields);
       
   515     val envir_names = map fst envir;
       
   516 
       
   517 
       
   518     (* args *)
       
   519 
       
   520     val defaultS = Sign.defaultS sign;
       
   521     val args = map (fn x => (x, if_none (assoc (envir, x)) defaultS)) params;
       
   522 
       
   523 
       
   524     (* errors *)
       
   525 
       
   526     val err_dup_parms =
       
   527       (case duplicates params of
       
   528         [] => []
       
   529       | dups => ["Duplicate parameters " ^ commas params]);
       
   530 
       
   531     val err_extra_frees =
       
   532       (case gen_rems (op =) (envir_names, params) of
       
   533         [] => []
       
   534       | extras => ["Extraneous free type variables " ^ commas extras]);
       
   535 
       
   536     val err_no_fields = if null bfields then ["No fields"] else [];
       
   537 
       
   538     val err_dup_fields =
       
   539       (case duplicates (map fst bfields) of
       
   540         [] => []
       
   541       | dups => ["Duplicate fields " ^ commas_quote dups]);
       
   542 
       
   543     val err_dup_sorts =
       
   544       (case duplicates envir_names of
       
   545         [] => []
       
   546       | dups => ["Inconsistent sort constraints for " ^ commas dups]);
       
   547 
       
   548     val errs =
       
   549       err_dup_parms @ err_extra_frees @ err_no_fields @ err_dup_fields @ err_dup_sorts;
       
   550   in
       
   551     if null errs then ()
       
   552     else error (cat_lines errs);
       
   553 
       
   554     writeln ("Defining record " ^ quote bname ^ " ...");
       
   555     thy |> record_definition (args, bname) parent parents bfields
       
   556   end
       
   557   handle ERROR => error ("Failed to define record " ^ quote bname);
       
   558 
       
   559 val add_record = gen_add_record read_typ read_raw_parent;
       
   560 val add_record_i = gen_add_record cert_typ (K I);
       
   561 
       
   562 
       
   563 
       
   564 (** setup theory **)
       
   565 
       
   566 val setup =
       
   567  [Theory.init_data [record_thy_data],
       
   568   Theory.add_trfuns
       
   569     ([], [("_record", record_tr), ("_record_scheme", record_scheme_tr)], [], [])];
       
   570 
       
   571 
       
   572 end;