datatype_record produces simp theorems; contributed in part by Yu Zhang
authorLars Hupel <lars.hupel@mytum.de>
Sat, 28 Jul 2018 07:28:18 +0200
changeset 68686 7f8db1c4ebec
parent 68685 4b367da119ed
child 68688 3a58abb11840
datatype_record produces simp theorems; contributed in part by Yu Zhang
src/HOL/Library/datatype_records.ML
src/HOL/ex/Datatype_Record_Examples.thy
--- a/src/HOL/Library/datatype_records.ML	Wed Jul 25 22:33:04 2018 +0200
+++ b/src/HOL/Library/datatype_records.ML	Sat Jul 28 07:28:18 2018 +0200
@@ -7,10 +7,10 @@
 
   val mk_update_defs: string -> local_theory -> local_theory
 
-  val bnf_record: binding -> ctr_options -> (binding option * (typ * sort)) list ->
+  val record: binding -> ctr_options -> (binding option * (typ * sort)) list ->
     (binding * typ) list -> local_theory -> local_theory
 
-  val bnf_record_cmd: binding -> ctr_options_cmd ->
+  val record_cmd: binding -> ctr_options_cmd ->
     (binding option * (string * string option)) list -> (binding * string) list -> local_theory ->
     local_theory
 
@@ -35,21 +35,31 @@
   val extend = I
 )
 
+fun mk_eq_dummy (lhs, rhs) =
+  Const (@{const_name HOL.eq}, dummyT --> dummyT --> @{typ bool}) $ lhs $ rhs
+
+val dummify = map_types (K dummyT)
+fun repeat_split_tac ctxt thm = REPEAT_ALL_NEW (CHANGED o Splitter.split_tac ctxt [thm])
+
 fun mk_update_defs typ_name lthy =
   let
     val short_name = Long_Name.base_name typ_name
+    val {ctrs, casex, selss, split, sel_thmss, injects, ...} =
+      the (Ctr_Sugar.ctr_sugar_of lthy typ_name)
+    val ctr = case ctrs of [ctr] => ctr | _ => error "Datatype_Records.mk_update_defs: expected only single constructor"
+    val sels = case selss of [sels] => sels | _ => error "Datatype_Records.mk_update_defs: expected selectors"
+    val sels_dummy = map dummify sels
+    val ctr_dummy = dummify ctr
+    val casex_dummy = dummify casex
+    val len = length sels
 
-    val {ctrs, casex, selss, ...} = the (Ctr_Sugar.ctr_sugar_of lthy typ_name)
-    val ctr = case ctrs of [ctr] => ctr | _ => error "BNF_Record.mk_update_defs: expected only single constructor"
-    val sels = case selss of [sels] => sels | _ => error "BNF_Record.mk_update_defs: expected selectors"
-    val ctr_dummy = Const (fst (dest_Const ctr), dummyT)
-    val casex_dummy = Const (fst (dest_Const casex), dummyT)
-
-    val len = length sels
+    val simp_thms = flat sel_thmss @ injects
 
     fun mk_name sel =
       Binding.name ("update_" ^ Long_Name.base_name (fst (dest_Const sel)))
 
+    val thms_binding = (@{binding record_simps}, @{attributes [simp]})
+
     fun mk_t idx =
       let
         val body =
@@ -59,22 +69,143 @@
         Abs ("f", dummyT, casex_dummy $ body)
       end
 
-    fun define name t =
-      Local_Theory.define ((name, NoSyn), ((Binding.empty, @{attributes [datatype_record_update, code]}), t)) #> snd
+    fun simp_only_tac ctxt =
+      REPEAT_ALL_NEW (resolve_tac ctxt @{thms impI allI}) THEN'
+        asm_full_simp_tac (put_simpset HOL_basic_ss ctxt addsimps simp_thms)
+
+    fun prove ctxt defs ts n =
+      let
+        val t = nth ts n
+
+        val sel_dummy = nth sels_dummy n
+        val t_dummy = dummify t
+        fun tac {context = ctxt, ...} =
+          Goal.conjunction_tac 1 THEN
+            Local_Defs.unfold_tac ctxt defs THEN
+            PARALLEL_ALLGOALS (repeat_split_tac ctxt split THEN' simp_only_tac ctxt)
+
+        val sel_upd_same_thm =
+          let
+            val ([f, x], ctxt') = Variable.add_fixes ["f", "x"] ctxt
+            val f = Free (f, dummyT)
+            val x = Free (x, dummyT)
+
+            val lhs = sel_dummy $ (t_dummy $ f $ x)
+            val rhs = f $ (sel_dummy $ x)
+            val prop = Syntax.check_term ctxt' (HOLogic.mk_Trueprop (mk_eq_dummy (lhs, rhs)))
+          in
+            [Goal.prove_future ctxt' [] [] prop tac]
+            |> Variable.export ctxt' ctxt
+          end
+
+        val sel_upd_diff_thms =
+          let
+            val ([f, x], ctxt') = Variable.add_fixes ["f", "x"] ctxt
+            val f = Free (f, dummyT)
+            val x = Free (x, dummyT)
+
+            fun lhs sel = sel $ (t_dummy $ f $ x)
+            fun rhs sel = sel $ x
+            fun eq sel = (lhs sel, rhs sel)
+            fun is_n i = i = n
+            val props =
+              sels_dummy ~~ (0 upto len - 1)
+              |> filter_out (is_n o snd)
+              |> map (HOLogic.mk_Trueprop o mk_eq_dummy o eq o fst)
+              |> Syntax.check_terms ctxt'
+          in
+            if length props > 0 then
+              Goal.prove_common ctxt' (SOME ~1) [] [] props tac
+              |> Variable.export ctxt' ctxt
+            else
+              []
+          end
+
+        val upd_comp_thm =
+          let
+            val ([f, g, x], ctxt') = Variable.add_fixes ["f", "g", "x"] ctxt
+            val f = Free (f, dummyT)
+            val g = Free (g, dummyT)
+            val x = Free (x, dummyT)
 
-    val lthy' =
-      Local_Theory.map_background_naming (Name_Space.qualified_path false (Binding.name short_name)) lthy
+            val lhs = t_dummy $ f $ (t_dummy $ g $ x)
+            val rhs = t_dummy $ Abs ("a", dummyT, f $ (g $ Bound 0)) $ x
+            val prop = Syntax.check_term ctxt' (HOLogic.mk_Trueprop (mk_eq_dummy (lhs, rhs)))
+          in
+            [Goal.prove_future ctxt' [] [] prop tac]
+            |> Variable.export ctxt' ctxt
+          end
+
+        val upd_comm_thms =
+          let
+            fun prop i ctxt =
+              let
+                val ([f, g, x], ctxt') = Variable.variant_fixes ["f", "g", "x"] ctxt
+                val self = t_dummy $ Free (f, dummyT)
+                val other = dummify (nth ts i) $ Free (g, dummyT)
+                val lhs = other $ (self $ Free (x, dummyT))
+                val rhs = self $ (other $ Free (x, dummyT))
+              in
+                (HOLogic.mk_Trueprop (mk_eq_dummy (lhs, rhs)), ctxt')
+              end
+            val (props, ctxt') = fold_map prop (0 upto n - 1) ctxt
+            val props = Syntax.check_terms ctxt' props
+          in
+            if length props > 0 then
+              Goal.prove_common ctxt' (SOME ~1) [] [] props tac
+              |> Variable.export ctxt' ctxt
+            else
+              []
+          end
+
+        val upd_sel_thm =
+          let
+            val ([x], ctxt') = Variable.add_fixes ["x"] ctxt
+
+            val lhs = t_dummy $ Abs("_", dummyT, (sel_dummy $ Free(x, dummyT))) $ Free (x, dummyT)
+            val rhs = Free (x, dummyT)
+            val prop = Syntax.check_term ctxt (HOLogic.mk_Trueprop (mk_eq_dummy (lhs, rhs)))
+          in
+            [Goal.prove_future ctxt [] [] prop tac]
+            |> Variable.export ctxt' ctxt
+          end
+      in
+        sel_upd_same_thm @ sel_upd_diff_thms @ upd_comp_thm @ upd_comm_thms @ upd_sel_thm
+      end
+
+    fun define name t =
+      Local_Theory.define ((name, NoSyn), ((Binding.empty, @{attributes [datatype_record_update, code]}),t))
+      #> apfst (apsnd snd)
+
+    val (updates, (lthy'', lthy')) =
+      lthy
+      |> Local_Theory.open_target
+      |> snd
+      |> Local_Theory.map_background_naming (Name_Space.qualified_path false (Binding.name short_name))
+      |> @{fold_map 2} define (map mk_name sels) (Syntax.check_terms lthy (map mk_t (0 upto len - 1)))
+      ||> `Local_Theory.close_target
+
+    val phi = Proof_Context.export_morphism lthy' lthy''
+
+    val (update_ts, update_defs) =
+      split_list updates
+      |>> map (Morphism.term phi)
+      ||> map (Morphism.thm phi)
+
+    val thms = flat (map (prove lthy'' update_defs update_ts) (0 upto len-1))
 
     fun insert sel =
       Symtab.insert op = (fst (dest_Const sel), Local_Theory.full_name lthy' (mk_name sel))
   in
-    lthy'
-    |> @{fold 2} define (map mk_name sels) (Syntax.check_terms lthy (map mk_t (0 upto len - 1)))
+    lthy''
+    |> Local_Theory.map_background_naming (Name_Space.mandatory_path short_name)
+    |> Local_Theory.note (thms_binding, thms)
+    |> snd
+    |> Local_Theory.restore_background_naming lthy
     |> Local_Theory.background_theory (Data.map (fold insert sels))
-    |> Local_Theory.restore_background_naming lthy
   end
 
-fun bnf_record binding opts tyargs args lthy =
+fun record binding opts tyargs args lthy =
   let
     val constructor =
       (((Binding.empty, Binding.map_name (fn c => "make_" ^ c) binding), args), NoSyn)
@@ -93,8 +224,8 @@
     lthy'
   end
 
-fun bnf_record_cmd binding opts tyargs args lthy =
-  bnf_record binding (opts lthy)
+fun record_cmd binding opts tyargs args lthy =
+  record binding (opts lthy)
     (map (apsnd (apfst (Syntax.parse_typ lthy) o apsnd (Typedecl.read_constraint lthy))) tyargs)
     (map (apsnd (Syntax.parse_typ lthy)) args) lthy
 
@@ -172,7 +303,7 @@
     @{command_keyword datatype_record}
     "Defines a record based on the BNF/datatype machinery"
     (parser >> (fn (((ctr_options, tyargs), binding), args) =>
-      bnf_record_cmd binding ctr_options tyargs args))
+      record_cmd binding ctr_options tyargs args))
 
 val setup =
    (Sign.parse_translation
--- a/src/HOL/ex/Datatype_Record_Examples.thy	Wed Jul 25 22:33:04 2018 +0200
+++ b/src/HOL/ex/Datatype_Record_Examples.thy	Sat Jul 28 07:28:18 2018 +0200
@@ -45,4 +45,23 @@
 lemma "b_set \<lparr> field_1 = True, field_2 = False \<rparr> = {False}"
   by simp
 
+text \<open>More tests\<close>
+
+datatype_record ('a, 'b) test1 =
+  field_t11 :: 'a
+  field_t12 :: 'b
+  field_t13 :: nat
+  field_t14 :: int
+
+thm test1.record_simps
+
+definition ID where "ID x = x"
+lemma ID_cong[cong]: "ID x = ID x" by (rule refl)
+
+lemma "update_field_t11 f (update_field_t12 g (update_field_t11 h x)) = ID (update_field_t12 g (update_field_t11 (\<lambda>x. f (h x)) x))"
+  apply (simp only: test1.record_simps)
+  apply (subst ID_def)
+  apply (rule refl)
+  done
+
 end
\ No newline at end of file