diff -r 077a2758ceb4 -r 59ef06cda7b9 src/HOL/Tools/Function/mutual.ML --- a/src/HOL/Tools/Function/mutual.ML Sun Sep 08 19:25:06 2013 +0200 +++ b/src/HOL/Tools/Function/mutual.ML Sun Sep 08 22:32:47 2013 +0200 @@ -252,7 +252,7 @@ let val result = inner_cont proof val FunctionResult {G, R, cases, psimps, simple_pinducts=[simple_pinduct], - termination, domintros, dom, ...} = result + termination, domintros, dom, pelims, ...} = result val (all_f_defs, fs) = map (fn MutualPart {f_defthm = SOME f_def, f = SOME f, cargTs, ...} => @@ -271,13 +271,82 @@ val minducts = mutual_induct_rules lthy simple_pinduct all_f_defs m val mtermination = full_simplify rew_simpset termination val mdomintros = Option.map (map (full_simplify rew_simpset)) domintros + in FunctionResult { fs=fs, G=G, R=R, dom=dom, psimps=mpsimps, simple_pinducts=minducts, - cases=cases, termination=mtermination, + cases=cases, pelims=pelims, termination=mtermination, domintros=mdomintros} end + +fun postprocess_cases_rules ctxt cont proof = + let val result = cont proof; + val FunctionResult {fs, G, R, dom, psimps, simple_pinducts, cases, pelims, + termination, domintros, ...} = result; + val n_fs = length fs; + + fun postprocess_cases_rule (idx,f) = + let fun dest_funprop (Const ("HOL.eq", _) $ lhs $ rhs) = (strip_comb lhs, rhs) + | dest_funprop (Const ("HOL.Not", _) $ trm) = (strip_comb trm, @{term "False"}) + | dest_funprop trm = (strip_comb trm, @{term "True"}); + + fun mk_fun_args 0 _ acc_vars = rev acc_vars + | mk_fun_args n (Type("fun",[S,T])) acc_vars = + let val xn = Free ("x" ^ Int.toString n,S) in + mk_fun_args (n - 1) T (xn :: acc_vars) + end + | mk_fun_args _ _ _ = raise (TERM ("Not a function.", [f])) + + + val f_simps = filter (fn r => (prop_of r |> Logic.strip_assums_concl + |> HOLogic.dest_Trueprop + |> dest_funprop |> fst |> fst) = f) + psimps + + val arity = hd f_simps |> prop_of |> Logic.strip_assums_concl + |> HOLogic.dest_Trueprop + |> snd o fst o dest_funprop |> length; + val arg_vars = mk_fun_args arity (fastype_of f) [] + val argsT = fastype_of (HOLogic.mk_tuple arg_vars); + val args = Free ("x", argsT); + + val thy = Proof_Context.theory_of ctxt; + val domT = R |> dest_Free |> snd |> hd o snd o dest_Type + + val sumtree_inj = SumTree.mk_inj domT n_fs (idx+1) args; + + val sum_elims = @{thms HOL.notE[OF Sum_Type.sum.distinct(1)] + HOL.notE[OF Sum_Type.sum.distinct(2)]}; + fun prep_subgoal i = + REPEAT (eresolve_tac @{thms Pair_inject Inl_inject[elim_format] + Inr_inject[elim_format]} i) +(* THEN propagate_tac i*) +(* THEN bool_subst_tac ctxt i*) + THEN REPEAT (Tactic.eresolve_tac sum_elims i); + + val tac = ALLGOALS prep_subgoal; + + in + hd cases + |> Thm.forall_elim @{cterm "P::bool"} + |> Thm.forall_elim (cterm_of thy sumtree_inj) + |> Tactic.rule_by_tactic ctxt tac + |> Thm.forall_intr (cterm_of thy args) + |> Thm.forall_intr @{cterm "P::bool"} + + end; + + val cases' = map_index postprocess_cases_rule fs; + +in + FunctionResult {fs=fs, G=G, R=R, dom=dom, psimps=psimps, + simple_pinducts=simple_pinducts, + cases=cases', pelims=pelims, termination=termination, + domintros=domintros} +end; + + fun prepare_function_mutual config defname fixes eqss lthy = let val mutual as Mutual {fsum_var=(n, T), qglrs, ...} = @@ -288,9 +357,10 @@ val (mutual', lthy'') = define_projections fixes mutual fsum lthy' - val mutual_cont = mk_partial_rules_mutual lthy'' cont mutual' + val cont' = mk_partial_rules_mutual lthy'' cont mutual' + val cont'' = postprocess_cases_rules lthy'' cont' in - ((goalstate, mutual_cont), lthy'') + ((goalstate, cont''), lthy'') end end