src/HOL/Library/datatype_records.ML
author haftmann
Fri Mar 22 19:18:08 2019 +0000 (4 months ago)
changeset 69946 494934c30f38
parent 69598 81caae4fc4fa
permissions -rw-r--r--
improved code equations taken over from AFP
     1 signature DATATYPE_RECORDS = sig
         (* Option filter that record passes on to BNF_FP_Def_Sugar.co_datatypes. *)
     2   type ctr_options = string -> bool
         (* The same kind of filter, computed from the context the command runs in. *)
     3   type ctr_options_cmd = Proof.context -> string -> bool
     4 
         (* Default filters; both are based on Plugin_Name.default_filter. *)
     5   val default_ctr_options: ctr_options
     6   val default_ctr_options_cmd: ctr_options_cmd
     7 
         (* Given the full name of a type with exactly one constructor: defines one
            update_<field> function per selector, notes the derived record_simps
            theorems, and records each selector's update function for the update syntax. *)
     8   val mk_update_defs: string -> local_theory -> local_theory
     9 
         (* Declares a datatype with the single constructor make_<name> over the given
            arguments and then applies mk_update_defs to the new type. *)
    10   val record: binding -> ctr_options -> (binding option * (typ * sort)) list ->
    11     (binding * typ) list -> local_theory -> local_theory
    12 
         (* Variant of record whose types and sort constraints are still unparsed strings. *)
    13   val record_cmd: binding -> ctr_options_cmd ->
    14     (binding option * (string * string option)) list -> (binding * string) list -> local_theory ->
    15     local_theory
    16 
         (* Installs the parse translations for record construction and record update. *)
    17   val setup: theory -> theory
    18 end
    19 
    20 structure Datatype_Records : DATATYPE_RECORDS = struct
    21 
    (* Option filters for the underlying datatype declaration: record puts the filter
       into the specification it hands to BNF_FP_Def_Sugar.co_datatypes. *)
    22 type ctr_options = string -> bool
    23 type ctr_options_cmd = Proof.context -> string -> bool
    24 
       (* Defaults; the command variant ignores its context argument (K). *)
    25 val default_ctr_options = Plugin_Name.default_filter
    26 val default_ctr_options_cmd = K Plugin_Name.default_filter
    27 
    (* Table from the name of a selector constant to the full name of its update
       function; filled by mk_update_defs and consulted by field_update_tr. *)
    28 type data = string Symtab.table
    29 
       (* Theory data slot holding that table; tables are combined with Symtab.merge,
          comparing entries with =. *)
    30 structure Data = Theory_Data
    31 (
    32   type T = data
    33   val empty = Symtab.empty
    34   val merge = Symtab.merge op =
    35   val extend = I
    36 )
    37 
    (* Builds the HOL equation left = right with dummy argument types, leaving the
       actual types to a later type-checking pass. *)
    38 fun mk_eq_dummy (left, right) =
    39   list_comb (Const (\<^const_name>\<open>HOL.eq\<close>, dummyT --> dummyT --> \<^typ>\<open>bool\<close>), [left, right])
    40 
    (* dummify replaces every type in a term by dummyT.
       repeat_split_tac runs Splitter.split_tac with the single rule thm under
       REPEAT_ALL_NEW, each step wrapped in CHANGED. *)
    41 val dummify = map_types (fn _ => dummyT)
    42 fun repeat_split_tac ctxt thm = REPEAT_ALL_NEW (fn i => CHANGED (Splitter.split_tac ctxt [thm] i))
    43 
    (* Defines the field update functions for the type named typ_name and derives
       their simplification rules.

       typ_name must be known to Ctr_Sugar and have exactly one constructor; otherwise
       the lookup via the fails or an error is raised.  For every selector sel a constant
       update_<sel> is defined (qualified by the base name of the type), the derived
       theorems are noted as <type>.record_simps with the simp attribute, and the map
       from selector name to update function name is stored in Data. *)
    44 fun mk_update_defs typ_name lthy =
    45   let
    46     val short_name = Long_Name.base_name typ_name
    47     val {ctrs, casex, selss, split, sel_thmss, injects, ...} =
    48       the (Ctr_Sugar.ctr_sugar_of lthy typ_name)
    49     val ctr = case ctrs of [ctr] => ctr | _ => error "Datatype_Records.mk_update_defs: expected only single constructor"
    50     val sels = case selss of [sels] => sels | _ => error "Datatype_Records.mk_update_defs: expected selectors"
           (* Untyped copies: all terms below are built with dummy types and type-checked later. *)
    51     val sels_dummy = map dummify sels
    52     val ctr_dummy = dummify ctr
    53     val casex_dummy = dummify casex
    54     val len = length sels
    55 
           (* Rewrite rules used by simp_only_tac. *)
    56     val simp_thms = flat sel_thmss @ injects
    57 
    58     fun mk_name sel =
    59       Binding.name ("update_" ^ Long_Name.base_name (fst (dest_Const sel)))
    60 
    61     val thms_binding = (Binding.name "record_simps", @{attributes [simp]})
    62 
           (* Right-hand side of the update function for field idx (0-based):
                %f. case (%x1 ... xlen. ctr x1 ... (f x_(idx+1)) ... xlen)
              in de Bruijn notation.  Inside the len abstractions Bound (len - 1) is x1,
              Bound 0 is xlen and Bound len is f. *)
    63     fun mk_t idx =
    64       let
    65         val body =
    66           fold_rev (fn pos => fn t => t $ (if len - pos = idx + 1 then Bound len $ Bound pos else Bound pos)) (0 upto len - 1) ctr_dummy
    67           |> fold_rev (fn idx => fn t => Abs ("x" ^ Value.print_int idx, dummyT, t)) (1 upto len)
    68       in
    69         Abs ("f", dummyT, casex_dummy $ body)
    70       end
    71 
           (* Strips implications and universal quantifiers, then simplifies with
              HOL_basic_ss extended by simp_thms only. *)
    72     fun simp_only_tac ctxt =
    73       REPEAT_ALL_NEW (resolve_tac ctxt @{thms impI allI}) THEN'
    74         asm_full_simp_tac (put_simpset HOL_basic_ss ctxt addsimps simp_thms)
    75 
           (* Proves the rules about the n-th update function, given the definitional
              theorems defs and the update terms ts (both indexed like sels). *)
    76     fun prove ctxt defs ts n =
    77       let
    78         val t = nth ts n
    79 
    80         val sel_dummy = nth sels_dummy n
    81         val t_dummy = dummify t
               (* Common proof: split conjunctions, unfold the update definitions, then
                  split with the type's split rule and simplify every goal. *)
    82         fun tac {context = ctxt, ...} =
    83           Goal.conjunction_tac 1 THEN
    84             Local_Defs.unfold_tac ctxt defs THEN
    85             PARALLEL_ALLGOALS (repeat_split_tac ctxt split THEN' simp_only_tac ctxt)
    86 
               (* sel_n (update_n f x) = f (sel_n x) *)
    87         val sel_upd_same_thm =
    88           let
    89             val ([f, x], ctxt') = Variable.add_fixes ["f", "x"] ctxt
    90             val f = Free (f, dummyT)
    91             val x = Free (x, dummyT)
    92 
    93             val lhs = sel_dummy $ (t_dummy $ f $ x)
    94             val rhs = f $ (sel_dummy $ x)
    95             val prop = Syntax.check_term ctxt' (HOLogic.mk_Trueprop (mk_eq_dummy (lhs, rhs)))
    96           in
    97             [Goal.prove_future ctxt' [] [] prop tac]
    98             |> Variable.export ctxt' ctxt
    99           end
   100 
               (* sel_j (update_n f x) = sel_j x for every other selector j *)
   101         val sel_upd_diff_thms =
   102           let
   103             val ([f, x], ctxt') = Variable.add_fixes ["f", "x"] ctxt
   104             val f = Free (f, dummyT)
   105             val x = Free (x, dummyT)
   106 
   107             fun lhs sel = sel $ (t_dummy $ f $ x)
   108             fun rhs sel = sel $ x
   109             fun eq sel = (lhs sel, rhs sel)
   110             fun is_n i = i = n
   111             val props =
   112               sels_dummy ~~ (0 upto len - 1)
   113               |> filter_out (is_n o snd)
   114               |> map (HOLogic.mk_Trueprop o mk_eq_dummy o eq o fst)
   115               |> Syntax.check_terms ctxt'
   116           in
   117             if length props > 0 then
   118               Goal.prove_common ctxt' (SOME ~1) [] [] props tac
   119               |> Variable.export ctxt' ctxt
   120             else
   121               []
   122           end
   123 
               (* update_n f (update_n g x) = update_n (%a. f (g a)) x *)
   124         val upd_comp_thm =
   125           let
   126             val ([f, g, x], ctxt') = Variable.add_fixes ["f", "g", "x"] ctxt
   127             val f = Free (f, dummyT)
   128             val g = Free (g, dummyT)
   129             val x = Free (x, dummyT)
   130 
   131             val lhs = t_dummy $ f $ (t_dummy $ g $ x)
   132             val rhs = t_dummy $ Abs ("a", dummyT, f $ (g $ Bound 0)) $ x
   133             val prop = Syntax.check_term ctxt' (HOLogic.mk_Trueprop (mk_eq_dummy (lhs, rhs)))
   134           in
   135             [Goal.prove_future ctxt' [] [] prop tac]
   136             |> Variable.export ctxt' ctxt
   137           end
   138 
               (* update_i g (update_n f x) = update_n f (update_i g x) for every i < n;
                  each statement gets its own fresh f, g, x. *)
   139         val upd_comm_thms =
   140           let
   141             fun prop i ctxt =
   142               let
   143                 val ([f, g, x], ctxt') = Variable.variant_fixes ["f", "g", "x"] ctxt
   144                 val self = t_dummy $ Free (f, dummyT)
   145                 val other = dummify (nth ts i) $ Free (g, dummyT)
   146                 val lhs = other $ (self $ Free (x, dummyT))
   147                 val rhs = self $ (other $ Free (x, dummyT))
   148               in
   149                 (HOLogic.mk_Trueprop (mk_eq_dummy (lhs, rhs)), ctxt')
   150               end
   151             val (props, ctxt') = fold_map prop (0 upto n - 1) ctxt
   152             val props = Syntax.check_terms ctxt' props
   153           in
   154             if length props > 0 then
   155               Goal.prove_common ctxt' (SOME ~1) [] [] props tac
   156               |> Variable.export ctxt' ctxt
   157             else
   158               []
   159           end
   160 
               (* update_n (%_. sel_n x) x = x *)
   161         val upd_sel_thm =
   162           let
   163             val ([x], ctxt') = Variable.add_fixes ["x"] ctxt
   164 
   165             val lhs = t_dummy $ Abs("_", dummyT, (sel_dummy $ Free(x, dummyT))) $ Free (x, dummyT)
   166             val rhs = Free (x, dummyT)
                   (* x is fixed in ctxt', so the statement is checked and proved there,
                      like the other lemma groups above, before being exported to ctxt. *)
   167             val prop = Syntax.check_term ctxt' (HOLogic.mk_Trueprop (mk_eq_dummy (lhs, rhs)))
   168           in
   169             [Goal.prove_future ctxt' [] [] prop tac]
   170             |> Variable.export ctxt' ctxt
   171           end
   172       in
   173         sel_upd_same_thm @ sel_upd_diff_thms @ upd_comp_thm @ upd_comm_thms @ upd_sel_thm
   174       end
   175 
           (* Defines one update constant; the result is reduced to the pair of the defined
              term and its definitional theorem. *)
   176     fun define name t =
   177       Local_Theory.define ((name, NoSyn), ((Binding.empty, @{attributes [datatype_record_update, code]}),t))
   178       #> apfst (apsnd snd)
   179 
           (* All definitions are made inside a freshly opened target whose naming is
              qualified by short_name.  lthy' is that inner context, lthy'' the context
              obtained by closing the target. *)
   180     val (updates, (lthy'', lthy')) =
   181       lthy
   182       |> Local_Theory.open_target
   183       |> snd
   184       |> Local_Theory.map_background_naming (Name_Space.qualified_path false (Binding.name short_name))
   185       |> @{fold_map 2} define (map mk_name sels) (Syntax.check_terms lthy (map mk_t (0 upto len - 1)))
   186       ||> `Local_Theory.close_target
   187 
           (* Transfers the defined terms and theorems from the inner to the closed context. *)
   188     val phi = Proof_Context.export_morphism lthy' lthy''
   189 
   190     val (update_ts, update_defs) =
   191       split_list updates
   192       |>> map (Morphism.term phi)
   193       ||> map (Morphism.thm phi)
   194 
   195     val thms = flat (map (prove lthy'' update_defs update_ts) (0 upto len-1))
   196 
           (* Table entry: selector constant name -> full name of its update function. *)
   197     fun insert sel =
   198       Symtab.insert op = (fst (dest_Const sel), Local_Theory.full_name lthy' (mk_name sel))
   199   in
   200     lthy''
   201     |> Local_Theory.map_background_naming (Name_Space.mandatory_path short_name)
   202     |> Local_Theory.note (thms_binding, thms)
   203     |> snd
   204     |> Local_Theory.restore_background_naming lthy
   205     |> Local_Theory.background_theory (Data.map (fold insert sels))
   206   end
   207 
   (* Declares the datatype named by binding with type arguments tyargs and the single
      constructor make_<name> over args (least fixpoint, via BNF_FP_Def_Sugar.co_datatypes),
      then generates the update functions for it with mk_update_defs. *)
   208 fun record binding opts tyargs args lthy =
   209   let
   210     val ctr_binding = Binding.map_name (fn c => "make_" ^ c) binding
   211     val ctr_spec = (((Binding.empty, ctr_binding), args), NoSyn)
   212 
   213     val typ_spec = ((tyargs, binding), NoSyn)
   214     val no_bindings = (Binding.empty, Binding.empty, Binding.empty)
   215 
   216     val spec = ((opts, false), [(((typ_spec, [ctr_spec]), no_bindings), [])])
   217 
   218     fun declare_datatype ctxt =
   219       BNF_FP_Def_Sugar.co_datatypes BNF_Util.Least_FP BNF_LFP.construct_lfp spec ctxt
           (* The full type name is computed in the incoming context, before the declaration. *)
   220     val full_typ_name = Local_Theory.full_name lthy binding
   221   in
   222     lthy
   223     |> declare_datatype
   224     |> mk_update_defs full_typ_name
   225   end
   226 
   (* Command-level entry point: instantiates the option filter with lthy, parses the
      type strings with Syntax.parse_typ and the optional sort constraints with
      Typedecl.read_constraint, and delegates to record.  For a type argument the
      constraint is read before the type is parsed. *)
   227 fun record_cmd binding opts tyargs args lthy =
   228   let fun typ s = Syntax.parse_typ lthy s and sort c = Typedecl.read_constraint lthy c
   229   in record binding (opts lthy) (map (fn (b, TS) => (b, (apfst typ o apsnd sort) TS)) tyargs)
   230     (map (fn (b, T) => (b, typ T)) args) lthy end
   231 
   232 (* syntax *)
   233 (* copied and adapted from record.ML *)
   234 
   (* Reads a constant name with Proof_Context.read_const (proper and strict) and
      returns its name and type as produced by dest_Const. *)
   235 fun read_const ctxt name =
   236   dest_Const (Proof_Context.read_const {proper = true, strict = true} ctxt name)
   237 
   (* field_tr takes apart one _datatype_field node into (field name, argument term);
      fields_tr flattens a right-nested _datatype_fields chain into a list of such pairs.
      Any other term shape raises TERM. *)
   238 fun field_tr t = (case t of Const (\<^syntax_const>\<open>_datatype_field\<close>, _) $ Const (name, _) $ arg => (name, arg)
   239   | _ => raise TERM ("field_tr", [t]));
   240 
   241 fun fields_tr t = (case t of Const (\<^syntax_const>\<open>_datatype_fields\<close>, _) $ first $ rest => field_tr first :: fields_tr rest
   242   | single => [field_tr single]);
   243 
   (* Translates the field assignments of a record expression into an application of
      the record's constructor.

      The record type is determined from the type of the first field's selector
      constant.  Every selector of that type must be assigned: the constructor arguments
      are looked up by selector name, in selector order.  Raises an error if the type
      does not have exactly one constructor, if the number of assignments differs from
      the number of selectors, or if a selector has no assignment. *)
   244 fun record_fields_tr ctxt t =
   245   let
           (* Each field name is resolved to the (name, type) of a constant. *)
   246     val assns = map (apfst (read_const ctxt)) (fields_tr t)
   247 
           (* fields_tr never returns an empty list, so hd is safe here. *)
   248     val typ_name =
   249       snd (fst (hd assns))
   250       |> domain_type
   251       |> dest_Type
   252       |> fst
   253 
           (* Association list from constant name to the assigned term. *)
   254     val assns' = map (apfst fst) assns
   255 
   256     val {ctrs, selss, ...} = the (Ctr_Sugar.ctr_sugar_of ctxt typ_name)
   257     val ctr = case ctrs of [ctr] => ctr | _ => error "Datatype_Records.record_fields_tr: expected only single constructor"
   258     val sels = case selss of [sels] => sels | _ => error "Datatype_Records.record_fields_tr: expected selectors"
   259     val ctr_dummy = Const (fst (dest_Const ctr), dummyT)
   260 
   261     fun mk_arg name =
   262       case AList.lookup op = assns' name of
   263         NONE => error ("Datatype_Records.record_fields_tr: missing field " ^ name)
   264       | SOME t => t
   265   in
   266     if length assns = length sels then
   267       list_comb (ctr_dummy, map (mk_arg o fst o dest_Const) sels)
   268     else
   269       error ("Datatype_Records.record_fields_tr: expected " ^ Value.print_int (length sels) ^ " field(s)")
   270   end
   271 
   (* field_update_tr turns one _datatype_field_update node into the registered update
      function for that field, applied to a constant function returning the new value.
      Raises Fail if the field has no entry in Data and TERM on any other term shape.
      field_updates_tr does this for every node of a right-nested
      _datatype_field_updates chain. *)
   272 fun field_update_tr ctxt (Const (\<^syntax_const>\<open>_datatype_field_update\<close>, _) $ Const (raw_name, _) $ arg) =
   273       let
   274         val sel_name = fst (read_const ctxt raw_name)
   275         val updates = Data.get (Proof_Context.theory_of ctxt)
   276       in
   277         (case Symtab.lookup updates sel_name of
   278           SOME upd => Const (upd, dummyT) $ Abs (Name.uu_, dummyT, arg)
   279         | NONE => raise Fail ("not a valid record field: " ^ sel_name))
   280       end
   281   | field_update_tr _ t = raise TERM ("field_update_tr", [t]);
   282 
   283 fun field_updates_tr ctxt t =
   284   (case t of Const (\<^syntax_const>\<open>_datatype_field_updates\<close>, _) $ first $ rest =>
   285       field_update_tr ctxt first :: field_updates_tr ctxt rest | _ => [field_update_tr ctxt t]);
   286 
   (* Parse translation entry points.  record_tr expects exactly one argument (the
      field assignments); record_update_tr expects the record term and the updates and
      applies the update functions to the record one after another, the first update
      innermost.  Any other argument list raises TERM. *)
   287 fun record_tr ctxt ts =
   288   (case ts of [t] => record_fields_tr ctxt t | _ => raise TERM ("record_tr", ts));
   289 
   290 fun record_update_tr ctxt [t, u] = fold (fn upd => fn acc => upd $ acc) (field_updates_tr ctxt u) t
   291   | record_update_tr _ ts = raise TERM ("record_update_tr", ts);
   292 
   (* Optional option block in parentheses: a Parse.list1 of plugin filters.  Each parsed
      filter is wrapped with K and the resulting functions are applied in turn, starting
      from default_ctr_options_cmd; without the block the default is returned. *)
   293 val parse_ctr_options =
   294   Scan.optional (\<^keyword>\<open>(\<close> |-- Parse.list1 (Plugin_Name.parse_filter >> K) --| \<^keyword>\<open>)\<close> >>
   295     (fn fs => fold I fs default_ctr_options_cmd)) default_ctr_options_cmd
   296 
       (* Command syntax: options, type arguments and the record name, then = followed by
          one or more fields of the form name :: type. *)
   297 val parser =
   298   (parse_ctr_options -- BNF_Util.parse_type_args_named_constrained -- Parse.binding) --
   299     (\<^keyword>\<open>=\<close> |-- Scan.repeat1 (Parse.binding -- (Parse.$$$ "::" |-- Parse.!!! Parse.typ)))
   300 
   (* Registers the local-theory command datatype_record; the pieces produced by parser
      are handed to record_cmd. *)
   301 val _ =
   302   Outer_Syntax.local_theory
   303     \<^command_keyword>\<open>datatype_record\<close>
   304     "Defines a record based on the BNF/datatype machinery"
   305     (parser >> (fn (((ctr_options, tyargs), binding), args) =>
   306       record_cmd binding ctr_options tyargs args))
   307 
   (* Theory setup: binds the syntax constants _datatype_record_update and
      _datatype_record to the parse translations record_update_tr and record_tr. *)
   308 val setup =
   309    (Sign.parse_translation
   310      [(\<^syntax_const>\<open>_datatype_record_update\<close>, record_update_tr),
   311       (\<^syntax_const>\<open>_datatype_record\<close>, record_tr)]);
   312 
   313 end