--- a/src/HOL/Library/simps_case_conv.ML Fri Sep 06 10:56:40 2013 +0200
+++ b/src/HOL/Library/simps_case_conv.ML Fri Sep 06 10:56:40 2013 +0200
@@ -32,12 +32,59 @@
local
-(*Creates free variables for a list of types*)
-fun mk_Frees Ts ctxt =
+ fun transpose [] = []
+ | transpose ([] :: xss) = transpose xss
+ | transpose xss = map hd xss :: transpose (map tl xss);
+
+ fun same_fun (ts as _ $ _ :: _) =
+ let
+ val (fs, argss) = map strip_comb ts |> split_list
+ val f = hd fs
+ in if forall (fn x => f = x) fs then SOME (f, argss) else NONE end
+ | same_fun _ = NONE
+
+ (* pats must be non-empty *)
+ fun split_pat pats ctxt =
+ case same_fun pats of
+ NONE =>
+ let
+ val (name, ctxt') = yield_singleton Variable.variant_fixes "x" ctxt
+ val var = Free (name, fastype_of (hd pats))
+ in (((var, [var]), map single pats), ctxt') end
+ | SOME (f, argss) =>
+ let
+ val (((def_pats, def_frees), case_patss), ctxt') =
+ split_pats argss ctxt
+ val def_pat = list_comb (f, def_pats)
+ in (((def_pat, flat def_frees), case_patss), ctxt') end
+ and
+ split_pats patss ctxt =
+ let
+ val (splitted, ctxt') = fold_map split_pat (transpose patss) ctxt
+ val r = splitted |> split_list |> apfst split_list |> apsnd (transpose #> map flat)
+ in (r, ctxt') end
+
+(*
+ Takes a list lhss of left hand sides (which are lists of patterns)
+ and a list rhss of right hand sides. Returns
+ - a single equation with a (nested) case-expression on the rhs
+ - a list of all split-thms needed to split the rhs
+ Patterns which have the same outer context in all lhss remain
+ on the lhs of the computed equation.
+*)
+fun build_case_t fun_t lhss rhss ctxt =
let
- val (names,ctxt') = Variable.variant_fixes (replicate (length Ts) "x") ctxt
- val ts = map Free (names ~~ Ts)
- in (ts, ctxt') end
+ val (((def_pats, def_frees), case_patss), ctxt') =
+ split_pats lhss ctxt
+ val pattern = map HOLogic.mk_tuple case_patss
+ val case_arg = HOLogic.mk_tuple (flat def_frees)
+ val cases = Case_Translation.make_case ctxt' Case_Translation.Warning Name.context
+ case_arg (pattern ~~ rhss)
+ val split_thms = get_split_ths (Proof_Context.theory_of ctxt') (fastype_of case_arg)
+ val t = (list_comb (fun_t, def_pats), cases)
+ |> HOLogic.mk_eq
+ |> HOLogic.mk_Trueprop
+ in ((t, split_thms), ctxt') end
fun tac ctxt {splits, intros, defs} =
let val ctxt' = Classical.addSIs (ctxt, intros) in
@@ -67,16 +114,16 @@
f p_mn ... p_mn = tm
of theorems, prove a single theorem
f x1 ... xn = t
- where t is a (nested) case expression. The terms p_11, ..., p_mn must
- be exhaustive, non-overlapping datatype patterns. f must not be a function
- application.
+ where t is a (nested) case expression. f must not be a function
+ application. Moreover, the terms p_11, ..., p_mn must be non-overlapping
+ datatype patterns. The patterns must be exhausting up to common constructor
+ contexts.
*)
fun to_case ctxt ths =
let
val (iths, ctxt') = import ths ctxt
- val (fun_t, arg_ts) = hd iths |> strip_eq |> fst |> strip_comb
+ val fun_t = hd iths |> strip_eq |> fst |> head_of
val eqs = map (strip_eq #> apfst (snd o strip_comb)) iths
- val (arg_Frees, ctxt'') = mk_Frees (map fastype_of arg_ts) ctxt'
fun hide_rhs ((pat, rhs), name) lthy = let
val frees = fold Term.add_frees pat []
@@ -85,23 +132,13 @@
((Binding.name name, Mixfix.NoSyn), abs_rhs) lthy
in ((list_comb (f, map Free (rev frees)), def), lthy') end
- val ((def_ts, def_thms), ctxt3) = let
- val nctxt = Variable.names_of ctxt''
+ val ((def_ts, def_thms), ctxt2) = let
+ val nctxt = Variable.names_of ctxt'
val names = Name.invent nctxt "rhs" (length eqs)
- in fold_map hide_rhs (eqs ~~ names) ctxt'' |> apfst split_list end
+ in fold_map hide_rhs (eqs ~~ names) ctxt' |> apfst split_list end
- val (cases, split_thms) =
- let
- val pattern = map (fst #> HOLogic.mk_tuple) eqs
- val case_arg = HOLogic.mk_tuple arg_Frees
- val cases = Case_Translation.make_case ctxt Case_Translation.Warning Name.context
- case_arg (pattern ~~ def_ts)
- val split_thms = get_split_ths (Proof_Context.theory_of ctxt3) (fastype_of case_arg)
- in (cases, split_thms) end
+ val ((t, split_thms), ctxt3) = build_case_t fun_t (map fst eqs) def_ts ctxt2
- val t = (list_comb (fun_t, arg_Frees), cases)
- |> HOLogic.mk_eq
- |> HOLogic.mk_Trueprop
val th = Goal.prove ctxt3 [] [] t (fn {context=ctxt, ...} =>
tac ctxt {splits=split_thms, intros=ths, defs=def_thms})
in th