src/HOL/Library/datatype_records.ML
changeset 68686 7f8db1c4ebec
parent 67611 7929240e44d4
child 68910 a21202dfe3eb
equal deleted inserted replaced
68685:4b367da119ed 68686:7f8db1c4ebec
     5   val default_ctr_options: ctr_options
     5   val default_ctr_options: ctr_options
     6   val default_ctr_options_cmd: ctr_options_cmd
     6   val default_ctr_options_cmd: ctr_options_cmd
     7 
     7 
     8   val mk_update_defs: string -> local_theory -> local_theory
     8   val mk_update_defs: string -> local_theory -> local_theory
     9 
     9 
    10   val bnf_record: binding -> ctr_options -> (binding option * (typ * sort)) list ->
    10   val record: binding -> ctr_options -> (binding option * (typ * sort)) list ->
    11     (binding * typ) list -> local_theory -> local_theory
    11     (binding * typ) list -> local_theory -> local_theory
    12 
    12 
    13   val bnf_record_cmd: binding -> ctr_options_cmd ->
    13   val record_cmd: binding -> ctr_options_cmd ->
    14     (binding option * (string * string option)) list -> (binding * string) list -> local_theory ->
    14     (binding option * (string * string option)) list -> (binding * string) list -> local_theory ->
    15     local_theory
    15     local_theory
    16 
    16 
    17   val setup: theory -> theory
    17   val setup: theory -> theory
    18 end
    18 end
    33   val empty = Symtab.empty
    33   val empty = Symtab.empty
    34   val merge = Symtab.merge op =
    34   val merge = Symtab.merge op =
    35   val extend = I
    35   val extend = I
    36 )
    36 )
    37 
    37 
       
    38 fun mk_eq_dummy (lhs, rhs) =
       
    39   Const (@{const_name HOL.eq}, dummyT --> dummyT --> @{typ bool}) $ lhs $ rhs
       
    40 
       
    41 val dummify = map_types (K dummyT)
       
    42 fun repeat_split_tac ctxt thm = REPEAT_ALL_NEW (CHANGED o Splitter.split_tac ctxt [thm])
       
    43 
    38 fun mk_update_defs typ_name lthy =
    44 fun mk_update_defs typ_name lthy =
    39   let
    45   let
    40     val short_name = Long_Name.base_name typ_name
    46     val short_name = Long_Name.base_name typ_name
    41 
    47     val {ctrs, casex, selss, split, sel_thmss, injects, ...} =
    42     val {ctrs, casex, selss, ...} = the (Ctr_Sugar.ctr_sugar_of lthy typ_name)
    48       the (Ctr_Sugar.ctr_sugar_of lthy typ_name)
    43     val ctr = case ctrs of [ctr] => ctr | _ => error "BNF_Record.mk_update_defs: expected only single constructor"
    49     val ctr = case ctrs of [ctr] => ctr | _ => error "Datatype_Records.mk_update_defs: expected only single constructor"
    44     val sels = case selss of [sels] => sels | _ => error "BNF_Record.mk_update_defs: expected selectors"
    50     val sels = case selss of [sels] => sels | _ => error "Datatype_Records.mk_update_defs: expected selectors"
    45     val ctr_dummy = Const (fst (dest_Const ctr), dummyT)
    51     val sels_dummy = map dummify sels
    46     val casex_dummy = Const (fst (dest_Const casex), dummyT)
    52     val ctr_dummy = dummify ctr
    47 
    53     val casex_dummy = dummify casex
    48     val len = length sels
    54     val len = length sels
       
    55 
       
    56     val simp_thms = flat sel_thmss @ injects
    49 
    57 
    50     fun mk_name sel =
    58     fun mk_name sel =
    51       Binding.name ("update_" ^ Long_Name.base_name (fst (dest_Const sel)))
    59       Binding.name ("update_" ^ Long_Name.base_name (fst (dest_Const sel)))
       
    60 
       
    61     val thms_binding = (@{binding record_simps}, @{attributes [simp]})
    52 
    62 
    53     fun mk_t idx =
    63     fun mk_t idx =
    54       let
    64       let
    55         val body =
    65         val body =
    56           fold_rev (fn pos => fn t => t $ (if len - pos = idx + 1 then Bound len $ Bound pos else Bound pos)) (0 upto len - 1) ctr_dummy
    66           fold_rev (fn pos => fn t => t $ (if len - pos = idx + 1 then Bound len $ Bound pos else Bound pos)) (0 upto len - 1) ctr_dummy
    57           |> fold_rev (fn idx => fn t => Abs ("x" ^ Value.print_int idx, dummyT, t)) (1 upto len)
    67           |> fold_rev (fn idx => fn t => Abs ("x" ^ Value.print_int idx, dummyT, t)) (1 upto len)
    58       in
    68       in
    59         Abs ("f", dummyT, casex_dummy $ body)
    69         Abs ("f", dummyT, casex_dummy $ body)
    60       end
    70       end
    61 
    71 
       
    72     fun simp_only_tac ctxt =
       
    73       REPEAT_ALL_NEW (resolve_tac ctxt @{thms impI allI}) THEN'
       
    74         asm_full_simp_tac (put_simpset HOL_basic_ss ctxt addsimps simp_thms)
       
    75 
       
    76     fun prove ctxt defs ts n =
       
    77       let
       
    78         val t = nth ts n
       
    79 
       
    80         val sel_dummy = nth sels_dummy n
       
    81         val t_dummy = dummify t
       
    82         fun tac {context = ctxt, ...} =
       
    83           Goal.conjunction_tac 1 THEN
       
    84             Local_Defs.unfold_tac ctxt defs THEN
       
    85             PARALLEL_ALLGOALS (repeat_split_tac ctxt split THEN' simp_only_tac ctxt)
       
    86 
       
    87         val sel_upd_same_thm =
       
    88           let
       
    89             val ([f, x], ctxt') = Variable.add_fixes ["f", "x"] ctxt
       
    90             val f = Free (f, dummyT)
       
    91             val x = Free (x, dummyT)
       
    92 
       
    93             val lhs = sel_dummy $ (t_dummy $ f $ x)
       
    94             val rhs = f $ (sel_dummy $ x)
       
    95             val prop = Syntax.check_term ctxt' (HOLogic.mk_Trueprop (mk_eq_dummy (lhs, rhs)))
       
    96           in
       
    97             [Goal.prove_future ctxt' [] [] prop tac]
       
    98             |> Variable.export ctxt' ctxt
       
    99           end
       
   100 
       
   101         val sel_upd_diff_thms =
       
   102           let
       
   103             val ([f, x], ctxt') = Variable.add_fixes ["f", "x"] ctxt
       
   104             val f = Free (f, dummyT)
       
   105             val x = Free (x, dummyT)
       
   106 
       
   107             fun lhs sel = sel $ (t_dummy $ f $ x)
       
   108             fun rhs sel = sel $ x
       
   109             fun eq sel = (lhs sel, rhs sel)
       
   110             fun is_n i = i = n
       
   111             val props =
       
   112               sels_dummy ~~ (0 upto len - 1)
       
   113               |> filter_out (is_n o snd)
       
   114               |> map (HOLogic.mk_Trueprop o mk_eq_dummy o eq o fst)
       
   115               |> Syntax.check_terms ctxt'
       
   116           in
       
   117             if length props > 0 then
       
   118               Goal.prove_common ctxt' (SOME ~1) [] [] props tac
       
   119               |> Variable.export ctxt' ctxt
       
   120             else
       
   121               []
       
   122           end
       
   123 
       
   124         val upd_comp_thm =
       
   125           let
       
   126             val ([f, g, x], ctxt') = Variable.add_fixes ["f", "g", "x"] ctxt
       
   127             val f = Free (f, dummyT)
       
   128             val g = Free (g, dummyT)
       
   129             val x = Free (x, dummyT)
       
   130 
       
   131             val lhs = t_dummy $ f $ (t_dummy $ g $ x)
       
   132             val rhs = t_dummy $ Abs ("a", dummyT, f $ (g $ Bound 0)) $ x
       
   133             val prop = Syntax.check_term ctxt' (HOLogic.mk_Trueprop (mk_eq_dummy (lhs, rhs)))
       
   134           in
       
   135             [Goal.prove_future ctxt' [] [] prop tac]
       
   136             |> Variable.export ctxt' ctxt
       
   137           end
       
   138 
       
   139         val upd_comm_thms =
       
   140           let
       
   141             fun prop i ctxt =
       
   142               let
       
   143                 val ([f, g, x], ctxt') = Variable.variant_fixes ["f", "g", "x"] ctxt
       
   144                 val self = t_dummy $ Free (f, dummyT)
       
   145                 val other = dummify (nth ts i) $ Free (g, dummyT)
       
   146                 val lhs = other $ (self $ Free (x, dummyT))
       
   147                 val rhs = self $ (other $ Free (x, dummyT))
       
   148               in
       
   149                 (HOLogic.mk_Trueprop (mk_eq_dummy (lhs, rhs)), ctxt')
       
   150               end
       
   151             val (props, ctxt') = fold_map prop (0 upto n - 1) ctxt
       
   152             val props = Syntax.check_terms ctxt' props
       
   153           in
       
   154             if length props > 0 then
       
   155               Goal.prove_common ctxt' (SOME ~1) [] [] props tac
       
   156               |> Variable.export ctxt' ctxt
       
   157             else
       
   158               []
       
   159           end
       
   160 
       
   161         val upd_sel_thm =
       
   162           let
       
   163             val ([x], ctxt') = Variable.add_fixes ["x"] ctxt
       
   164 
       
   165             val lhs = t_dummy $ Abs("_", dummyT, (sel_dummy $ Free(x, dummyT))) $ Free (x, dummyT)
       
   166             val rhs = Free (x, dummyT)
       
   167             val prop = Syntax.check_term ctxt (HOLogic.mk_Trueprop (mk_eq_dummy (lhs, rhs)))
       
   168           in
       
   169             [Goal.prove_future ctxt [] [] prop tac]
       
   170             |> Variable.export ctxt' ctxt
       
   171           end
       
   172       in
       
   173         sel_upd_same_thm @ sel_upd_diff_thms @ upd_comp_thm @ upd_comm_thms @ upd_sel_thm
       
   174       end
       
   175 
    62     fun define name t =
   176     fun define name t =
    63       Local_Theory.define ((name, NoSyn), ((Binding.empty, @{attributes [datatype_record_update, code]}), t)) #> snd
   177       Local_Theory.define ((name, NoSyn), ((Binding.empty, @{attributes [datatype_record_update, code]}),t))
    64 
   178       #> apfst (apsnd snd)
    65     val lthy' =
   179 
    66       Local_Theory.map_background_naming (Name_Space.qualified_path false (Binding.name short_name)) lthy
   180     val (updates, (lthy'', lthy')) =
       
   181       lthy
       
   182       |> Local_Theory.open_target
       
   183       |> snd
       
   184       |> Local_Theory.map_background_naming (Name_Space.qualified_path false (Binding.name short_name))
       
   185       |> @{fold_map 2} define (map mk_name sels) (Syntax.check_terms lthy (map mk_t (0 upto len - 1)))
       
   186       ||> `Local_Theory.close_target
       
   187 
       
   188     val phi = Proof_Context.export_morphism lthy' lthy''
       
   189 
       
   190     val (update_ts, update_defs) =
       
   191       split_list updates
       
   192       |>> map (Morphism.term phi)
       
   193       ||> map (Morphism.thm phi)
       
   194 
       
   195     val thms = flat (map (prove lthy'' update_defs update_ts) (0 upto len-1))
    67 
   196 
    68     fun insert sel =
   197     fun insert sel =
    69       Symtab.insert op = (fst (dest_Const sel), Local_Theory.full_name lthy' (mk_name sel))
   198       Symtab.insert op = (fst (dest_Const sel), Local_Theory.full_name lthy' (mk_name sel))
    70   in
   199   in
    71     lthy'
   200     lthy''
    72     |> @{fold 2} define (map mk_name sels) (Syntax.check_terms lthy (map mk_t (0 upto len - 1)))
   201     |> Local_Theory.map_background_naming (Name_Space.mandatory_path short_name)
       
   202     |> Local_Theory.note (thms_binding, thms)
       
   203     |> snd
       
   204     |> Local_Theory.restore_background_naming lthy
    73     |> Local_Theory.background_theory (Data.map (fold insert sels))
   205     |> Local_Theory.background_theory (Data.map (fold insert sels))
    74     |> Local_Theory.restore_background_naming lthy
       
    75   end
   206   end
    76 
   207 
    77 fun bnf_record binding opts tyargs args lthy =
   208 fun record binding opts tyargs args lthy =
    78   let
   209   let
    79     val constructor =
   210     val constructor =
    80       (((Binding.empty, Binding.map_name (fn c => "make_" ^ c) binding), args), NoSyn)
   211       (((Binding.empty, Binding.map_name (fn c => "make_" ^ c) binding), args), NoSyn)
    81 
   212 
    82     val datatyp =
   213     val datatyp =
    91       |> mk_update_defs (Local_Theory.full_name lthy binding)
   222       |> mk_update_defs (Local_Theory.full_name lthy binding)
    92   in
   223   in
    93     lthy'
   224     lthy'
    94   end
   225   end
    95 
   226 
    96 fun bnf_record_cmd binding opts tyargs args lthy =
   227 fun record_cmd binding opts tyargs args lthy =
    97   bnf_record binding (opts lthy)
   228   record binding (opts lthy)
    98     (map (apsnd (apfst (Syntax.parse_typ lthy) o apsnd (Typedecl.read_constraint lthy))) tyargs)
   229     (map (apsnd (apfst (Syntax.parse_typ lthy) o apsnd (Typedecl.read_constraint lthy))) tyargs)
    99     (map (apsnd (Syntax.parse_typ lthy)) args) lthy
   230     (map (apsnd (Syntax.parse_typ lthy)) args) lthy
   100 
   231 
   101 (* syntax *)
   232 (* syntax *)
   102 (* copied and adapted from record.ML *)
   233 (* copied and adapted from record.ML *)
   170 val _ =
   301 val _ =
   171   Outer_Syntax.local_theory
   302   Outer_Syntax.local_theory
   172     @{command_keyword datatype_record}
   303     @{command_keyword datatype_record}
   173     "Defines a record based on the BNF/datatype machinery"
   304     "Defines a record based on the BNF/datatype machinery"
   174     (parser >> (fn (((ctr_options, tyargs), binding), args) =>
   305     (parser >> (fn (((ctr_options, tyargs), binding), args) =>
   175       bnf_record_cmd binding ctr_options tyargs args))
   306       record_cmd binding ctr_options tyargs args))
   176 
   307 
   177 val setup =
   308 val setup =
   178    (Sign.parse_translation
   309    (Sign.parse_translation
   179      [(\<^syntax_const>\<open>_datatype_record_update\<close>, record_update_tr),
   310      [(\<^syntax_const>\<open>_datatype_record_update\<close>, record_update_tr),
   180       (\<^syntax_const>\<open>_datatype_record\<close>, record_tr)]);
   311       (\<^syntax_const>\<open>_datatype_record\<close>, record_tr)]);