src/HOL/Tools/Predicate_Compile/predicate_compile_aux.ML
changeset 39802 7cadad6a18cc
parent 39787 a44f6b11cdc4
child 40048 f3a46d524101
--- a/src/HOL/Tools/Predicate_Compile/predicate_compile_aux.ML	Thu Sep 30 11:52:22 2010 +0200
+++ b/src/HOL/Tools/Predicate_Compile/predicate_compile_aux.ML	Thu Sep 30 15:37:11 2010 +0200
@@ -147,6 +147,7 @@
   (* simple transformations *)
   val split_conjuncts_in_assms : Proof.context -> thm -> thm
   val expand_tuples : theory -> thm -> thm
+  val case_betapply : theory -> term -> term
   val eta_contract_ho_arguments : theory -> thm -> thm
   val remove_equalities : theory -> thm -> thm
   val remove_pointless_clauses : thm -> thm list
@@ -859,6 +860,85 @@
     intro'''''
   end
 
+(** making case distributivity rules **)
+(*** this should be part of the datatype package ***)
+
+fun datatype_names_of_case_name thy case_name =
+  map (#1 o #2) (#descr (the (Datatype_Data.info_of_case thy case_name)))
+
+fun make_case_distribs new_type_names descr sorts thy =
+  let
+    val case_combs = Datatype_Prop.make_case_combs new_type_names descr sorts thy "f";
+    fun make comb =
+      let
+        val Type ("fun", [T, T']) = fastype_of comb;
+        val (Const (case_name, _), fs) = strip_comb comb
+        val used = Term.add_tfree_names comb []
+        val U = TFree (Name.variant used "'t", HOLogic.typeS)
+        val x = Free ("x", T)
+        val f = Free ("f", T' --> U)
+        fun apply_f f' =
+          let
+            val Ts = binder_types (fastype_of f')
+            val bs = map Bound ((length Ts - 1) downto 0)
+          in
+            fold (curry absdummy) (rev Ts) (f $ (list_comb (f', bs)))
+          end
+        val fs' = map apply_f fs
+        val case_c' = Const (case_name, (map fastype_of fs') @ [T] ---> U)
+      in
+        HOLogic.mk_eq (f $ (comb $ x), list_comb (case_c', fs') $ x)
+      end
+  in
+    map make case_combs
+  end
+
+fun case_rewrites thy Tcon =
+  let
+    val info = Datatype.the_info thy Tcon
+    val descr = #descr info
+    val sorts = #sorts info
+    val typ_names = the_default [Tcon] (#alt_names info)
+  in
+    map (Drule.export_without_context o Skip_Proof.make_thm thy o HOLogic.mk_Trueprop)
+      (make_case_distribs typ_names [descr] sorts thy)
+  end
+
+fun instantiated_case_rewrites thy Tcon =
+  let
+    val rew_ths = case_rewrites thy Tcon
+    val ctxt = ProofContext.init_global thy
+    fun instantiate th =
+    let
+      val f = (fst (strip_comb (fst (HOLogic.dest_eq (HOLogic.dest_Trueprop (prop_of th))))))
+      val Type ("fun", [uninst_T, uninst_T']) = fastype_of f
+      val ([tname, tname', uname, yname], ctxt') = Variable.add_fixes ["'t", "'t'", "'u", "y"] ctxt
+      val T = TFree (tname, HOLogic.typeS)
+      val T' = TFree (tname', HOLogic.typeS)
+      val U = TFree (uname, HOLogic.typeS)
+      val y = Free (yname, U)
+      val f' = absdummy (U --> T', Bound 0 $ y)
+      val th' = Thm.certify_instantiate
+        ([(dest_TVar uninst_T, U --> T'), (dest_TVar uninst_T', T')],
+         [((fst (dest_Var f), (U --> T') --> T'), f')]) th
+      val [th'] = Variable.export ctxt' ctxt [th']
+   in
+     th'
+   end
+ in
+   map instantiate rew_ths
+ end
+
+fun case_betapply thy t =
+  let
+    val case_name = fst (dest_Const (fst (strip_comb t)))
+    val Tcons = datatype_names_of_case_name thy case_name
+    val ths = maps (instantiated_case_rewrites thy) Tcons
+  in
+    MetaSimplifier.rewrite_term thy
+      (map (fn th => th RS @{thm eq_reflection}) ths) [] t
+  end
+
 (*** conversions ***)
 
 fun imp_prems_conv cv ct =