generate elim rules for elimination of function equalities;
authorManuel Eberl
Sun, 08 Sep 2013 22:32:47 +0200
changeset 53603 59ef06cda7b9
parent 53474 077a2758ceb4
child 53604 c1db98d7c66f
generate elim rules for elimination of function equalities; added fun_cases command; recover proper cases rules for mutual recursive case (no sum types)
src/HOL/FunDef.thy
src/HOL/Tools/Function/fun_cases.ML
src/HOL/Tools/Function/function.ML
src/HOL/Tools/Function/function_common.ML
src/HOL/Tools/Function/function_core.ML
src/HOL/Tools/Function/function_elims.ML
src/HOL/Tools/Function/mutual.ML
--- a/src/HOL/FunDef.thy	Sun Sep 08 19:25:06 2013 +0200
+++ b/src/HOL/FunDef.thy	Sun Sep 08 22:32:47 2013 +0200
@@ -6,7 +6,7 @@
 
 theory FunDef
 imports Partial_Function SAT Wellfounded
-keywords "function" "termination" :: thy_goal and "fun" :: thy_decl
+keywords "function" "termination" :: thy_goal and "fun" "fun_cases" :: thy_decl
 begin
 
 subsection {* Definitions with default value. *}
@@ -89,6 +89,7 @@
 ML_file "Tools/Function/mutual.ML"
 ML_file "Tools/Function/pattern_split.ML"
 ML_file "Tools/Function/relation.ML"
+ML_file "Tools/Function/function_elims.ML"
 
 method_setup relation = {*
   Args.term >> (fn t => fn ctxt => SIMPLE_METHOD' (Function_Relation.relation_infer_tac ctxt t))
@@ -307,6 +308,7 @@
 ML_file "Tools/Function/termination.ML"
 ML_file "Tools/Function/scnp_solve.ML"
 ML_file "Tools/Function/scnp_reconstruct.ML"
+ML_file "Tools/Function/fun_cases.ML"
 
 setup {* ScnpReconstruct.setup *}
 
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/HOL/Tools/Function/fun_cases.ML	Sun Sep 08 22:32:47 2013 +0200
@@ -0,0 +1,90 @@
+(*  Title:      HOL/Tools/Function/fun_cases.ML
+    Author:     Manuel Eberl <eberlm@in.tum.de>, TU München
+
+Provides the fun_cases command for generating specialised elimination
+rules for function package functions.
+*)
+
+signature FUN_CASES =
+sig
+  val mk_fun_cases : local_theory -> term -> thm
+end;
+
+
+structure Fun_Cases : FUN_CASES =
+struct
+
+local
+  open Function_Elims;
+
+  val refl_thin = Goal.prove_global @{theory HOL} [] [] @{prop "!!P. a = a ==> P ==> P"}
+    (fn _ => assume_tac 1);
+  val elim_rls = [asm_rl, FalseE, refl_thin, conjE, exE];
+  val elim_tac = REPEAT o Tactic.eresolve_tac elim_rls;
+
+  fun simp_case_tac ctxt i =
+    EVERY' [elim_tac, TRY o asm_full_simp_tac ctxt, elim_tac, REPEAT o bound_hyp_subst_tac ctxt] i;
+in
+fun mk_fun_cases ctxt prop =
+  let val thy = Proof_Context.theory_of ctxt;
+      fun err () =
+        error (Pretty.string_of (Pretty.block
+          [Pretty.str "Proposition is not a function equation:",
+           Pretty.fbrk, Syntax.pretty_term ctxt prop]));
+      val ((f,_),_) = dest_funprop (HOLogic.dest_Trueprop prop)
+              handle TERM _ => err ();
+      val info = Function.get_info ctxt f handle Empty => err ();
+      val {elims, pelims, is_partial, ...} = info;
+      val elims = if is_partial then pelims else the elims
+      val cprop = cterm_of thy prop;
+      val tac = ALLGOALS (simp_case_tac ctxt) THEN prune_params_tac;
+      fun mk_elim rl =
+        Thm.implies_intr cprop (Tactic.rule_by_tactic ctxt tac (Thm.assume cprop RS rl))
+        |> singleton (Variable.export (Variable.auto_fixes prop ctxt) ctxt);
+  in
+    case get_first (try mk_elim) (flat elims) of
+      SOME r => r
+    | NONE => err ()
+  end;
+end;
+
+
+(* Setting up the fun_cases command *)
+local
+  (* Converts the schematic variables and type variables in a term into free
+     variables and takes care of schematic variables originating from dummy
+     patterns by renaming them to something sensible. *)
+  fun pat_to_term ctxt t =
+    let
+       fun prep_var ((x,_),T) =
+            if x = "_dummy_" then ("x",T) else (x,T);
+       val schem_vars = Term.add_vars t [];
+       val prepped_vars = map prep_var schem_vars;
+       val fresh_vars = map Free (Variable.variant_frees ctxt [t] prepped_vars);
+       val subst = ListPair.zip (map fst schem_vars, fresh_vars);
+    in fst (yield_singleton (Variable.import_terms true)
+           (subst_Vars subst t) ctxt)
+    end;
+
+  fun fun_cases args ctxt =
+    let
+      val thy = Proof_Context.theory_of ctxt
+      val thmss = map snd args
+                  |> burrow (grouped 10 Par_List.map
+                      (mk_fun_cases ctxt
+                       o pat_to_term ctxt
+                       o HOLogic.mk_Trueprop
+                       o Proof_Context.read_term_pattern ctxt));
+      val facts = map2 (fn ((a,atts), _) => fn thms =>
+        ((a, map (Attrib.intern_src thy) atts), [(thms, [])])) args thmss;
+    in
+      ctxt |> Local_Theory.notes facts |>> map snd
+    end;
+in
+val _ =
+  Outer_Syntax.local_theory @{command_spec "fun_cases"}
+    "automatic derivation of simplified elimination rules for function equations"
+    (Parse.and_list1 Parse_Spec.specs >> (snd oo fun_cases));
+end;
+end;
+
--- a/src/HOL/Tools/Function/function.ML	Sun Sep 08 19:25:06 2013 +0200
+++ b/src/HOL/Tools/Function/function.ML	Sun Sep 08 22:32:47 2013 +0200
@@ -24,7 +24,7 @@
     (Attrib.binding * string) list -> Function_Common.function_config ->
     bool -> local_theory -> Proof.state
 
-  val prove_termination: term option -> tactic -> local_theory -> 
+  val prove_termination: term option -> tactic -> local_theory ->
     info * local_theory
   val prove_termination_cmd: string option -> tactic -> local_theory ->
     info * local_theory
@@ -94,9 +94,11 @@
 
     fun afterqed [[proof]] lthy =
       let
+        val result = cont (Thm.close_derivation proof)
         val FunctionResult {fs, R, dom, psimps, simple_pinducts,
-          termination, domintros, cases, ...} =
-          cont (Thm.close_derivation proof)
+                termination, domintros, cases, ...} = result
+
+        val pelims = Function_Elims.mk_partial_elim_rules lthy result
 
         val fnames = map (fst o fst) fixes
         fun qualify n = Binding.name n
@@ -105,7 +107,29 @@
 
         val addsmps = add_simps fnames post sort_cont
 
-        val ((((psimps', [pinducts']), [termination']), [cases']), lthy) =
+        (* TODO: case names *)
+        fun addcases lthy =
+          let fun go name thm (thms_acc, lthy) =
+                 case Local_Theory.note ((Binding.name "cases" |> Binding.qualify true name,
+                          [Attrib.internal (K (Rule_Cases.case_names cnames))]), [thm]) lthy
+                 of ((_,thms), lthy') => (thms :: thms_acc, lthy')
+              val (thms, lthy') = fold2 go fnames cases ([], lthy);
+          in
+            (rev thms, lthy')
+          end;
+
+        fun addpelims lthy =
+          let fun go name thm (thms_acc, lthy) =
+                 case Local_Theory.note ((Binding.name "pelims" |> Binding.qualify true name,
+                          [Attrib.internal (K (Rule_Cases.consumes 1)),
+                           Attrib.internal (K (Rule_Cases.constraints 1))]), thm) lthy
+                 of ((_,thms), lthy') => (thms :: thms_acc, lthy')
+              val (thms, lthy') = fold2 go fnames pelims ([], lthy);
+          in
+            (rev thms, lthy')
+          end;
+
+          val (((((psimps', [pinducts']), [termination']), cases'), pelims'), lthy) =
           lthy
           |> addsmps (conceal_partial o Binding.qualify false "partial")
                "psimps" conceal_partial psimp_attribs psimps
@@ -115,14 +139,15 @@
                   Attrib.internal (K (Rule_Cases.consumes (1 - Thm.nprems_of th))),
                   Attrib.internal (K (Induct.induct_pred ""))])))]
           ||>> (apfst snd o Local_Theory.note ((Binding.conceal (qualify "termination"), []), [termination]))
-          ||>> (apfst snd o Local_Theory.note ((qualify "cases",
-                 [Attrib.internal (K (Rule_Cases.case_names cnames))]), [cases]))
-          ||> (case domintros of NONE => I | SOME thms => 
+          ||>> addcases
+          ||>> addpelims
+          ||> (case domintros of NONE => I | SOME thms =>
                    Local_Theory.note ((qualify "domintros", []), thms) #> snd)
 
-        val info = { add_simps=addsmps, case_names=cnames, psimps=psimps',
+        val info = { add_simps=addsmps, fnames=fnames, case_names=cnames, psimps=psimps',
           pinducts=snd pinducts', simps=NONE, inducts=NONE, termination=termination',
-          fs=fs, R=R, dom=dom, defname=defname, is_partial=true, cases=cases'}
+          fs=fs, R=R, dom=dom, defname=defname, is_partial=true, cases=flat cases',
+          pelims=pelims',elims=NONE}
 
         val _ = Proof_Display.print_consts do_print lthy (K false) (map fst fixes)
       in
@@ -180,7 +205,7 @@
           | NONE => error "Not a function"))
 
     val { termination, fs, R, add_simps, case_names, psimps,
-      pinducts, defname, cases, dom, ...} = info
+      pinducts, defname, fnames, cases, dom, pelims, ...} = info
     val domT = domain_type (fastype_of R)
     val goal = HOLogic.mk_Trueprop (HOLogic.mk_all ("x", domT, mk_acc domT R $ Free ("x", domT)))
     fun afterqed [[totality]] lthy =
@@ -191,9 +216,23 @@
             addsimps [totality, @{thm True_implies_equals}])
         val tsimps = map remove_domain_condition psimps
         val tinduct = map remove_domain_condition pinducts
+        val telims = map (map remove_domain_condition) pelims
 
         fun qualify n = Binding.name n
           |> Binding.qualify true defname
+
+        fun addtelims lthy =
+          let fun go name thm (thms_acc, lthy) =
+                 case Local_Theory.note ((Binding.name "elims" |> Binding.qualify true name,
+                          [Attrib.internal (K (Rule_Cases.consumes 1)),
+                           Attrib.internal (K (Rule_Cases.constraints 1)),
+                           Attrib.internal (K (Induct.cases_pred defname))]), thm) lthy
+                 of ((_,thms), lthy') => (thms :: thms_acc, lthy')
+              val (thms, lthy') = fold2 go fnames telims ([], lthy);
+          in
+            (rev thms, lthy')
+          end;
+
       in
         lthy
         |> add_simps I "simps" I simp_attribs tsimps
@@ -201,13 +240,14 @@
            ((qualify "induct",
              [Attrib.internal (K (Rule_Cases.case_names case_names))]),
             tinduct)
-        |-> (fn (simps, (_, inducts)) => fn lthy =>
-          let val info' = { is_partial=false, defname=defname, add_simps=add_simps,
+        ||>> addtelims
+        |-> (fn ((simps,(_,inducts)), elims) => fn lthy =>
+          let val info' = { is_partial=false, defname=defname, fnames=fnames, add_simps=add_simps,
             case_names=case_names, fs=fs, R=R, dom=dom, psimps=psimps, pinducts=pinducts,
-            simps=SOME simps, inducts=SOME inducts, termination=termination, cases=cases }
+            simps=SOME simps, inducts=SOME inducts, termination=termination, cases=cases, pelims=pelims, elims=SOME elims}
           in
             (info',
-             lthy 
+             lthy
              |> Local_Theory.declaration {syntax = false, pervasive = false}
                (add_function_data o transform_function_data info')
              |> Spec_Rules.add Spec_Rules.Equational (fs, tsimps))
--- a/src/HOL/Tools/Function/function_common.ML	Sun Sep 08 19:25:06 2013 +0200
+++ b/src/HOL/Tools/Function/function_common.ML	Sun Sep 08 22:32:47 2013 +0200
@@ -13,6 +13,7 @@
     (* contains no logical entities: invariant under morphisms: *)
   add_simps : (binding -> binding) -> string -> (binding -> binding) ->
     Attrib.src list -> thm list -> local_theory -> thm list * local_theory,
+  fnames : string list,
   case_names : string list,
   fs : term list,
   R : term,
@@ -22,7 +23,9 @@
   simps : thm list option,
   inducts : thm list option,
   termination : thm,
-  cases : thm}
+  cases : thm list,
+  pelims: thm list list,
+  elims: thm list list option}
 
 end
 
@@ -35,6 +38,7 @@
     (* contains no logical entities: invariant under morphisms: *)
   add_simps : (binding -> binding) -> string -> (binding -> binding) ->
     Attrib.src list -> thm list -> local_theory -> thm list * local_theory,
+  fnames : string list,
   case_names : string list,
   fs : term list,
   R : term,
@@ -44,7 +48,9 @@
   simps : thm list option,
   inducts : thm list option,
   termination : thm,
-  cases : thm}
+  cases : thm list,
+  pelims : thm list list,
+  elims : thm list list option}
 
 end
 
@@ -66,7 +72,8 @@
     dom: term,
     psimps : thm list,
     simple_pinducts : thm list,
-    cases : thm,
+    cases : thm list,
+    pelims : thm list list,
     termination : thm,
     domintros : thm list option}
   val transform_function_data : info -> morphism -> info
@@ -146,23 +153,25 @@
   dom: term,
   psimps : thm list,
   simple_pinducts : thm list,
-  cases : thm,
+  cases : thm list,
+  pelims : thm list list,
   termination : thm,
   domintros : thm list option}
 
-fun transform_function_data ({add_simps, case_names, fs, R, dom, psimps, pinducts,
-  simps, inducts, termination, defname, is_partial, cases} : info) phi =
+fun transform_function_data ({add_simps, case_names, fnames, fs, R, dom, psimps, pinducts,
+  simps, inducts, termination, defname, is_partial, cases, pelims, elims} : info) phi =
     let
       val term = Morphism.term phi
       val thm = Morphism.thm phi
       val fact = Morphism.fact phi
       val name = Binding.name_of o Morphism.binding phi o Binding.name
     in
-      { add_simps = add_simps, case_names = case_names,
+      { add_simps = add_simps, case_names = case_names, fnames = fnames,
         fs = map term fs, R = term R, dom = term dom, psimps = fact psimps,
         pinducts = fact pinducts, simps = Option.map fact simps,
         inducts = Option.map fact inducts, termination = thm termination,
-        defname = name defname, is_partial=is_partial, cases = thm cases }
+        defname = name defname, is_partial=is_partial, cases = fact cases,
+        elims = Option.map (map fact) elims, pelims = map fact pelims }
     end
 
 (* FIXME just one data slot (record) per program unit *)
--- a/src/HOL/Tools/Function/function_core.ML	Sun Sep 08 19:25:06 2013 +0200
+++ b/src/HOL/Tools/Function/function_core.ML	Sun Sep 08 22:32:47 2013 +0200
@@ -915,8 +915,9 @@
              (map (mk_domain_intro lthy globals R R_elim)) xclauses)
            else NONE
       in
-        FunctionResult {fs=[f], G=G, R=R, dom=dom, cases=complete_thm,
-          psimps=psimps, simple_pinducts=[simple_pinduct],
+        FunctionResult {fs=[f], G=G, R=R, dom=dom,
+          cases=[complete_thm], psimps=psimps, pelims=[],
+          simple_pinducts=[simple_pinduct],
           termination=total_intro, domintros=dom_intros}
       end
   in
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/HOL/Tools/Function/function_elims.ML	Sun Sep 08 22:32:47 2013 +0200
@@ -0,0 +1,150 @@
+(*  Title:      HOL/Tools/Function/function_elims.ML
+    Author:     Manuel Eberl <eberlm@in.tum.de>, TU München
+
+Generates the pelims rules for a function. These are of the shape
+[|f x y z = w; !!…. [|x = …; y = …; z = …; w = …|] ==> P; …|] ==> P
+and are derived from the cases rule. There is at least one pelim rule for
+each function (cf. mutually recursive functions)
+There may be more than one pelim rule for a function in case of functions
+that return a boolean. For such a function, e.g. P x, not only the normal
+elim rule with the premise P x = z is generated, but also two additional
+elim rules with P x resp. ¬P x as premises.
+*)
+
+signature FUNCTION_ELIMS =
+sig
+  val dest_funprop : term -> (term * term list) * term
+  val mk_partial_elim_rules :
+          local_theory -> Function_Common.function_result -> thm list list
+end;
+
+structure Function_Elims : FUNCTION_ELIMS =
+struct
+
+open Function_Lib
+open Function_Common
+
+(* Extracts a function and its arguments from a proposition that is
+   either of the form "f x y z = ..." or, in case of function that
+   returns a boolean, "f x y z" *)
+fun dest_funprop (Const ("HOL.eq", _) $ lhs $ rhs) = (strip_comb lhs, rhs)
+  | dest_funprop (Const ("HOL.Not", _) $ trm) = (strip_comb trm, @{term "False"})
+  | dest_funprop trm = (strip_comb trm, @{term "True"});
+
+local
+  fun propagate_tac i thm =
+    let fun inspect eq = case eq of
+                Const ("HOL.Trueprop",_) $ (Const ("HOL.eq",_) $ Free x $ t) =>
+                    if Logic.occs (Free x, t) then raise Match else true
+              | Const ("HOL.Trueprop",_) $ (Const ("HOL.eq",_) $ t $ Free x) =>
+                    if Logic.occs (Free x, t) then raise Match else false
+              | _ => raise Match;
+        fun mk_eq thm = (if inspect (prop_of thm) then
+                            [thm RS eq_reflection]
+                        else
+                            [Thm.symmetric (thm RS eq_reflection)])
+                        handle Match => [];
+        val ss = Simplifier.global_context (Thm.theory_of_thm thm) empty_ss
+                 |> Simplifier.set_mksimps (K mk_eq)
+    in
+      asm_lr_simp_tac ss i thm
+    end;
+
+  val eqBoolI = @{lemma "!!P. P ==> P = True" "!!P. ~P ==> P = False" by iprover+}
+  val boolE = @{thms HOL.TrueE HOL.FalseE}
+  val boolD = @{lemma "!!P. True = P ==> P" "!!P. False = P ==> ~P" by iprover+}
+  val eqBool = @{thms HOL.eq_True HOL.eq_False HOL.not_False_eq_True HOL.not_True_eq_False}
+
+  fun bool_subst_tac ctxt i =
+      REPEAT (EqSubst.eqsubst_asm_tac ctxt [1] eqBool i)
+      THEN REPEAT (dresolve_tac boolD i)
+      THEN REPEAT (eresolve_tac boolE i)
+
+  fun mk_bool_elims ctxt elim =
+    let val tac = ALLGOALS (bool_subst_tac ctxt)
+        fun mk_bool_elim b =
+          elim
+          |> Thm.forall_elim b
+          |> Tactic.rule_by_tactic ctxt (TRY (resolve_tac eqBoolI 1))
+          |> Tactic.rule_by_tactic ctxt tac
+    in
+        map mk_bool_elim [@{cterm True}, @{cterm False}]
+    end;
+
+in
+
+  fun mk_partial_elim_rules ctxt result=
+    let val FunctionResult {fs, G, R, dom, psimps, simple_pinducts, cases,
+                            termination, domintros, ...} = result;
+        val n_fs = length fs;
+
+        fun mk_partial_elim_rule (idx,f) =
+          let fun mk_funeq 0 T (acc_vars, acc_lhs) =
+                  let val y = Free("y",T) in
+                    (y :: acc_vars, (HOLogic.mk_Trueprop (HOLogic.mk_eq (acc_lhs, y))), T)
+                  end
+                | mk_funeq n (Type("fun",[S,T])) (acc_vars, acc_lhs) =
+                  let val xn = Free ("x" ^ Int.toString n,S) in
+                    mk_funeq (n - 1) T (xn :: acc_vars, acc_lhs $ xn)
+                  end
+                | mk_funeq _ _ _ = raise (TERM ("Not a function.", [f]))
+
+              val f_simps = filter (fn r => (prop_of r |> Logic.strip_assums_concl
+                                             |> HOLogic.dest_Trueprop
+                                             |> dest_funprop |> fst |> fst) = f)
+                                   psimps
+
+              val arity = hd f_simps |> prop_of |> Logic.strip_assums_concl
+                                     |> HOLogic.dest_Trueprop
+                                     |> snd o fst o dest_funprop |> length;
+              val (free_vars,prop,ranT) = mk_funeq arity (fastype_of f) ([],f)
+              val (rhs_var, arg_vars) = case free_vars of x::xs => (x, rev xs)
+              val args = HOLogic.mk_tuple arg_vars;
+              val domT = R |> dest_Free |> snd |> hd o snd o dest_Type
+
+              val sumtree_inj = SumTree.mk_inj domT n_fs (idx+1) args;
+
+              val thy = Proof_Context.theory_of ctxt;
+              val cprop = cterm_of thy prop
+
+              val asms = [cprop, cterm_of thy (HOLogic.mk_Trueprop (dom $ sumtree_inj))];
+              val asms_thms = map Thm.assume asms;
+
+              fun prep_subgoal i =
+                REPEAT (eresolve_tac @{thms Pair_inject} i)
+                THEN Method.insert_tac (case asms_thms of
+                                          thm::thms => (thm RS sym) :: thms) i
+                THEN propagate_tac i
+                THEN TRY
+                    ((EqSubst.eqsubst_asm_tac ctxt [1] psimps i) THEN atac i)
+                THEN bool_subst_tac ctxt i;
+
+            val tac = ALLGOALS prep_subgoal;
+
+            val elim_stripped =
+                  nth cases idx
+                  |> Thm.forall_elim @{cterm "P::bool"}
+                  |> Thm.forall_elim (cterm_of thy args)
+                  |> Tactic.rule_by_tactic ctxt tac
+                  |> fold_rev Thm.implies_intr asms
+                  |> Thm.forall_intr (cterm_of thy rhs_var)
+
+            val bool_elims = (case ranT of
+                                Type ("HOL.bool", []) => mk_bool_elims ctxt elim_stripped
+                                | _ => []);
+
+            fun unstrip rl =
+                  rl  |> (fn thm => List.foldr (uncurry Thm.forall_intr) thm
+                             (map (cterm_of thy) arg_vars))
+                      |> Thm.forall_intr @{cterm "P::bool"}
+
+        in
+          map unstrip (elim_stripped :: bool_elims)
+        end;
+
+    in
+      map_index mk_partial_elim_rule fs
+    end;
+  end;
+end;
+
--- a/src/HOL/Tools/Function/mutual.ML	Sun Sep 08 19:25:06 2013 +0200
+++ b/src/HOL/Tools/Function/mutual.ML	Sun Sep 08 22:32:47 2013 +0200
@@ -252,7 +252,7 @@
   let
     val result = inner_cont proof
     val FunctionResult {G, R, cases, psimps, simple_pinducts=[simple_pinduct],
-      termination, domintros, dom, ...} = result
+      termination, domintros, dom, pelims, ...} = result
 
     val (all_f_defs, fs) =
       map (fn MutualPart {f_defthm = SOME f_def, f = SOME f, cargTs, ...} =>
@@ -271,13 +271,82 @@
     val minducts = mutual_induct_rules lthy simple_pinduct all_f_defs m
     val mtermination = full_simplify rew_simpset termination
     val mdomintros = Option.map (map (full_simplify rew_simpset)) domintros
+
   in
     FunctionResult { fs=fs, G=G, R=R, dom=dom,
       psimps=mpsimps, simple_pinducts=minducts,
-      cases=cases, termination=mtermination,
+      cases=cases, pelims=pelims, termination=mtermination,
       domintros=mdomintros}
   end
 
+
+fun postprocess_cases_rules ctxt cont proof =
+  let val result = cont proof;
+      val FunctionResult {fs, G, R, dom, psimps, simple_pinducts, cases, pelims,
+                        termination, domintros, ...} = result;
+      val n_fs = length fs;
+
+      fun postprocess_cases_rule (idx,f) =
+        let fun dest_funprop (Const ("HOL.eq", _) $ lhs $ rhs) = (strip_comb lhs, rhs)
+              | dest_funprop (Const ("HOL.Not", _) $ trm) = (strip_comb trm, @{term "False"})
+              | dest_funprop trm = (strip_comb trm, @{term "True"});
+
+            fun mk_fun_args 0 _ acc_vars = rev acc_vars
+              | mk_fun_args n (Type("fun",[S,T])) acc_vars =
+                let val xn = Free ("x" ^ Int.toString n,S) in
+                  mk_fun_args (n - 1) T (xn :: acc_vars)
+                end
+              | mk_fun_args _ _ _ = raise (TERM ("Not a function.", [f]))
+
+
+            val f_simps = filter (fn r => (prop_of r |> Logic.strip_assums_concl
+                                           |> HOLogic.dest_Trueprop
+                                           |> dest_funprop |> fst |> fst) = f)
+                                 psimps
+
+            val arity = hd f_simps |> prop_of |> Logic.strip_assums_concl
+                                   |> HOLogic.dest_Trueprop
+                                   |> snd o fst o dest_funprop |> length;
+            val arg_vars = mk_fun_args arity (fastype_of f) []
+            val argsT = fastype_of (HOLogic.mk_tuple arg_vars);
+            val args = Free ("x", argsT);
+
+            val thy = Proof_Context.theory_of ctxt;
+            val domT = R |> dest_Free |> snd |> hd o snd o dest_Type
+
+            val sumtree_inj = SumTree.mk_inj domT n_fs (idx+1) args;
+
+            val sum_elims = @{thms HOL.notE[OF Sum_Type.sum.distinct(1)]
+                                   HOL.notE[OF Sum_Type.sum.distinct(2)]};
+            fun prep_subgoal i =
+              REPEAT (eresolve_tac @{thms Pair_inject Inl_inject[elim_format]
+                                          Inr_inject[elim_format]} i)
+(*              THEN propagate_tac i*)
+(*              THEN bool_subst_tac ctxt i*)
+              THEN REPEAT (Tactic.eresolve_tac sum_elims i);
+
+            val tac = ALLGOALS prep_subgoal;
+
+        in
+            hd cases
+              |> Thm.forall_elim @{cterm "P::bool"}
+              |> Thm.forall_elim (cterm_of thy sumtree_inj)
+              |> Tactic.rule_by_tactic ctxt tac
+              |> Thm.forall_intr (cterm_of thy args)
+              |> Thm.forall_intr @{cterm "P::bool"}
+
+        end;
+
+  val cases' = map_index postprocess_cases_rule fs;
+
+in
+  FunctionResult {fs=fs, G=G, R=R, dom=dom, psimps=psimps,
+                  simple_pinducts=simple_pinducts,
+                  cases=cases', pelims=pelims, termination=termination,
+                  domintros=domintros}
+end;
+
+
 fun prepare_function_mutual config defname fixes eqss lthy =
   let
     val mutual as Mutual {fsum_var=(n, T), qglrs, ...} =
@@ -288,9 +357,10 @@
 
     val (mutual', lthy'') = define_projections fixes mutual fsum lthy'
 
-    val mutual_cont = mk_partial_rules_mutual lthy'' cont mutual'
+    val cont' = mk_partial_rules_mutual lthy'' cont mutual'
+    val cont'' = postprocess_cases_rules lthy'' cont'
   in
-    ((goalstate, mutual_cont), lthy'')
+    ((goalstate, cont''), lthy'')
   end
 
 end