allowed less exhaustive patterns
authornoschinl
Fri, 06 Sep 2013 10:56:40 +0200
changeset 53429 9d9945941eab
parent 53428 3083c611ec40
child 53430 d92578436d47
allowed less exhaustive patterns
src/HOL/Library/simps_case_conv.ML
--- a/src/HOL/Library/simps_case_conv.ML	Fri Sep 06 10:56:40 2013 +0200
+++ b/src/HOL/Library/simps_case_conv.ML	Fri Sep 06 10:56:40 2013 +0200
@@ -32,12 +32,59 @@
 
 local
 
-(*Creates free variables for a list of types*)
-fun mk_Frees Ts ctxt =
+  fun transpose [] = []
+    | transpose ([] :: xss) = transpose xss
+    | transpose xss = map hd xss :: transpose (map tl xss);
+
+  fun same_fun (ts as _ $ _ :: _) =
+      let
+        val (fs, argss) = map strip_comb ts |> split_list
+        val f = hd fs
+      in if forall (fn x => f = x) fs then SOME (f, argss) else NONE end
+    | same_fun _ = NONE
+
+  (* pats must be non-empty *)
+  fun split_pat pats ctxt =
+      case same_fun pats of
+        NONE =>
+          let
+            val (name, ctxt') = yield_singleton Variable.variant_fixes "x" ctxt
+            val var = Free (name, fastype_of (hd pats))
+          in (((var, [var]), map single pats), ctxt') end
+      | SOME (f, argss) =>
+          let
+            val (((def_pats, def_frees), case_patss), ctxt') =
+              split_pats argss ctxt
+            val def_pat = list_comb (f, def_pats)
+          in (((def_pat, flat def_frees), case_patss), ctxt') end
+  and
+      split_pats patss ctxt =
+        let
+          val (splitted, ctxt') = fold_map split_pat (transpose patss) ctxt
+          val r = splitted |> split_list |> apfst split_list |> apsnd (transpose #> map flat)
+        in (r, ctxt') end
+
+(*
+  Takes a list lhss of left hand sides (which are lists of patterns)
+  and a list rhss of right hand sides. Returns
+    - a single equation with a (nested) case-expression on the rhs
+    - a list of all split-thms needed to split the rhs
+  Patterns which have the same outer context in all lhss remain
+  on the lhs of the computed equation.
+*)
+fun build_case_t fun_t lhss rhss ctxt =
   let
-    val (names,ctxt') = Variable.variant_fixes (replicate (length Ts) "x") ctxt
-    val ts = map Free (names ~~ Ts)
-  in (ts, ctxt') end
+    val (((def_pats, def_frees), case_patss), ctxt') =
+      split_pats lhss ctxt
+    val pattern = map HOLogic.mk_tuple case_patss
+    val case_arg = HOLogic.mk_tuple (flat def_frees)
+    val cases = Case_Translation.make_case ctxt' Case_Translation.Warning Name.context
+      case_arg (pattern ~~ rhss)
+    val split_thms = get_split_ths (Proof_Context.theory_of ctxt') (fastype_of case_arg)
+    val t = (list_comb (fun_t, def_pats), cases)
+      |> HOLogic.mk_eq
+      |> HOLogic.mk_Trueprop
+  in ((t, split_thms), ctxt') end
 
 fun tac ctxt {splits, intros, defs} =
   let val ctxt' = Classical.addSIs (ctxt, intros) in
@@ -67,16 +114,16 @@
     f p_mn ... p_mn = tm
   of theorems, prove a single theorem
     f x1 ... xn = t
-  where t is a (nested) case expression. The terms p_11, ..., p_mn must
-  be exhaustive, non-overlapping datatype patterns. f must not be a function
-  application.
+  where t is a (nested) case expression. f must not be a function
+  application. Moreover, the terms p_11, ..., p_mn must be non-overlapping
+  datatype patterns. The patterns must be exhausting up to common constructor
+  contexts.
 *)
 fun to_case ctxt ths =
   let
     val (iths, ctxt') = import ths ctxt
-    val (fun_t, arg_ts) = hd iths |> strip_eq |> fst |> strip_comb
+    val fun_t = hd iths |> strip_eq |> fst |> head_of
     val eqs = map (strip_eq #> apfst (snd o strip_comb)) iths
-    val (arg_Frees, ctxt'') = mk_Frees (map fastype_of arg_ts) ctxt'
 
     fun hide_rhs ((pat, rhs), name) lthy = let
         val frees = fold Term.add_frees pat []
@@ -85,23 +132,13 @@
           ((Binding.name name, Mixfix.NoSyn), abs_rhs) lthy
       in ((list_comb (f, map Free (rev frees)), def), lthy') end
 
-    val ((def_ts, def_thms), ctxt3) = let
-        val nctxt = Variable.names_of ctxt''
+    val ((def_ts, def_thms), ctxt2) = let
+        val nctxt = Variable.names_of ctxt'
         val names = Name.invent nctxt "rhs" (length eqs)
-      in fold_map hide_rhs (eqs ~~ names) ctxt'' |> apfst split_list end
+      in fold_map hide_rhs (eqs ~~ names) ctxt' |> apfst split_list end
 
-    val (cases, split_thms) =
-      let
-        val pattern = map (fst #> HOLogic.mk_tuple) eqs
-        val case_arg = HOLogic.mk_tuple arg_Frees
-        val cases = Case_Translation.make_case ctxt Case_Translation.Warning Name.context
-          case_arg (pattern ~~ def_ts)
-        val split_thms = get_split_ths (Proof_Context.theory_of ctxt3) (fastype_of case_arg)
-      in (cases, split_thms) end
+    val ((t, split_thms), ctxt3) = build_case_t fun_t (map fst eqs) def_ts ctxt2
 
-    val t = (list_comb (fun_t, arg_Frees), cases)
-      |> HOLogic.mk_eq
-      |> HOLogic.mk_Trueprop
     val th = Goal.prove ctxt3 [] [] t (fn {context=ctxt, ...} =>
           tac ctxt {splits=split_thms, intros=ths, defs=def_thms})
   in th