src/HOL/Tools/Function/mutual.ML
changeset 53603 59ef06cda7b9
parent 52384 80c00a851de5
child 53605 462151f900ea
--- a/src/HOL/Tools/Function/mutual.ML	Sun Sep 08 19:25:06 2013 +0200
+++ b/src/HOL/Tools/Function/mutual.ML	Sun Sep 08 22:32:47 2013 +0200
@@ -252,7 +252,7 @@
   let
     val result = inner_cont proof
     val FunctionResult {G, R, cases, psimps, simple_pinducts=[simple_pinduct],
-      termination, domintros, dom, ...} = result
+      termination, domintros, dom, pelims, ...} = result
 
     val (all_f_defs, fs) =
       map (fn MutualPart {f_defthm = SOME f_def, f = SOME f, cargTs, ...} =>
@@ -271,13 +271,82 @@
     val minducts = mutual_induct_rules lthy simple_pinduct all_f_defs m
     val mtermination = full_simplify rew_simpset termination
     val mdomintros = Option.map (map (full_simplify rew_simpset)) domintros
+
   in
     FunctionResult { fs=fs, G=G, R=R, dom=dom,
       psimps=mpsimps, simple_pinducts=minducts,
-      cases=cases, termination=mtermination,
+      cases=cases, pelims=pelims, termination=mtermination,
       domintros=mdomintros}
   end
 
+
+fun postprocess_cases_rules ctxt cont proof =
+  let val result = cont proof;
+      val FunctionResult {fs, G, R, dom, psimps, simple_pinducts, cases, pelims,
+                        termination, domintros, ...} = result;
+      val n_fs = length fs;
+
+      fun postprocess_cases_rule (idx,f) =
+        let fun dest_funprop (Const ("HOL.eq", _) $ lhs $ rhs) = (strip_comb lhs, rhs)
+              | dest_funprop (Const ("HOL.Not", _) $ trm) = (strip_comb trm, @{term "False"})
+              | dest_funprop trm = (strip_comb trm, @{term "True"});
+
+            fun mk_fun_args 0 _ acc_vars = rev acc_vars
+              | mk_fun_args n (Type("fun",[S,T])) acc_vars =
+                let val xn = Free ("x" ^ Int.toString n,S) in
+                  mk_fun_args (n - 1) T (xn :: acc_vars)
+                end
+              | mk_fun_args _ _ _ = raise (TERM ("Not a function.", [f]))
+
+
+            val f_simps = filter (fn r => (prop_of r |> Logic.strip_assums_concl
+                                           |> HOLogic.dest_Trueprop
+                                           |> dest_funprop |> fst |> fst) = f)
+                                 psimps
+
+            val arity = hd f_simps |> prop_of |> Logic.strip_assums_concl
+                                   |> HOLogic.dest_Trueprop
+                                   |> snd o fst o dest_funprop |> length;
+            val arg_vars = mk_fun_args arity (fastype_of f) []
+            val argsT = fastype_of (HOLogic.mk_tuple arg_vars);
+            val args = Free ("x", argsT);
+
+            val thy = Proof_Context.theory_of ctxt;
+            val domT = R |> dest_Free |> snd |> hd o snd o dest_Type
+
+            val sumtree_inj = SumTree.mk_inj domT n_fs (idx+1) args;
+
+            val sum_elims = @{thms HOL.notE[OF Sum_Type.sum.distinct(1)]
+                                   HOL.notE[OF Sum_Type.sum.distinct(2)]};
+            fun prep_subgoal i =
+              REPEAT (eresolve_tac @{thms Pair_inject Inl_inject[elim_format]
+                                          Inr_inject[elim_format]} i)
+(*              THEN propagate_tac i*)
+(*              THEN bool_subst_tac ctxt i*)
+              THEN REPEAT (Tactic.eresolve_tac sum_elims i);
+
+            val tac = ALLGOALS prep_subgoal;
+
+        in
+            hd cases
+              |> Thm.forall_elim @{cterm "P::bool"}
+              |> Thm.forall_elim (cterm_of thy sumtree_inj)
+              |> Tactic.rule_by_tactic ctxt tac
+              |> Thm.forall_intr (cterm_of thy args)
+              |> Thm.forall_intr @{cterm "P::bool"}
+
+        end;
+
+  val cases' = map_index postprocess_cases_rule fs;
+
+in
+  FunctionResult {fs=fs, G=G, R=R, dom=dom, psimps=psimps,
+                  simple_pinducts=simple_pinducts,
+                  cases=cases', pelims=pelims, termination=termination,
+                  domintros=domintros}
+end;
+
+
 fun prepare_function_mutual config defname fixes eqss lthy =
   let
     val mutual as Mutual {fsum_var=(n, T), qglrs, ...} =
@@ -288,9 +357,10 @@
 
     val (mutual', lthy'') = define_projections fixes mutual fsum lthy'
 
-    val mutual_cont = mk_partial_rules_mutual lthy'' cont mutual'
+    val cont' = mk_partial_rules_mutual lthy'' cont mutual'
+    val cont'' = postprocess_cases_rules lthy'' cont'
   in
-    ((goalstate, mutual_cont), lthy'')
+    ((goalstate, cont''), lthy'')
   end
 
 end