generalized induct_scheme method to prove conditional induction schemes.
authorkrauss
Thu, 19 Jun 2008 11:46:14 +0200
changeset 27271 ba2a00d35df1
parent 27270 6a353260735e
child 27272 75b251e9cdb7
generalized induct_scheme method to prove conditional induction schemes.
src/HOL/FunDef.thy
src/HOL/Tools/function_package/fundef_datatype.ML
src/HOL/Tools/function_package/induction_scheme.ML
src/HOL/ex/Induction_Scheme.thy
--- a/src/HOL/FunDef.thy	Thu Jun 19 00:02:08 2008 +0200
+++ b/src/HOL/FunDef.thy	Thu Jun 19 11:46:14 2008 +0200
@@ -18,10 +18,10 @@
   ("Tools/function_package/pattern_split.ML")
   ("Tools/function_package/fundef_package.ML")
   ("Tools/function_package/auto_term.ML")
-  ("Tools/function_package/induction_scheme.ML")
   ("Tools/function_package/measure_functions.ML")
   ("Tools/function_package/lexicographic_order.ML")
   ("Tools/function_package/fundef_datatype.ML")
+  ("Tools/function_package/induction_scheme.ML")
 begin
 
 text {* Definitions with default value. *}
@@ -110,10 +110,10 @@
 use "Tools/function_package/pattern_split.ML"
 use "Tools/function_package/auto_term.ML"
 use "Tools/function_package/fundef_package.ML"
-use "Tools/function_package/induction_scheme.ML"
 use "Tools/function_package/measure_functions.ML"
 use "Tools/function_package/lexicographic_order.ML"
 use "Tools/function_package/fundef_datatype.ML"
+use "Tools/function_package/induction_scheme.ML"
 
 setup {* 
   FundefPackage.setup 
--- a/src/HOL/Tools/function_package/fundef_datatype.ML	Thu Jun 19 00:02:08 2008 +0200
+++ b/src/HOL/Tools/function_package/fundef_datatype.ML	Thu Jun 19 11:46:14 2008 +0200
@@ -8,13 +8,14 @@
 
 signature FUNDEF_DATATYPE =
 sig
-    val pat_complete_tac: int -> tactic
+    val pat_complete_tac: Proof.context -> int -> tactic
+    val prove_completeness : theory -> term list -> term -> term list list -> term list list -> thm
 
-    val pat_completeness : method
+    val pat_completeness : Proof.context -> method
     val setup : theory -> theory
 end
 
-structure FundefDatatype: FUNDEF_DATATYPE =
+structure FundefDatatype : FUNDEF_DATATYPE =
 struct
 
 open FundefLib
@@ -146,60 +147,60 @@
   | o_alg _ _ _ _ _ = raise Match
 
 
-fun prove_completeness thy x P qss pats =
+fun prove_completeness thy xs P qss patss =
     let
-        fun mk_assum qs pat = Logic.mk_implies (HOLogic.mk_Trueprop (HOLogic.mk_eq (x,pat)),
-                                                HOLogic.mk_Trueprop P)
-                                               |> fold_rev mk_forall qs
-                                               |> cterm_of thy
+        fun mk_assum qs pats = 
+            HOLogic.mk_Trueprop P
+            |> fold_rev (curry Logic.mk_implies o HOLogic.mk_Trueprop o HOLogic.mk_eq) (xs ~~ pats)
+            |> fold_rev mk_forall qs
+            |> cterm_of thy
 
-        val hyps = map2 mk_assum qss pats
+        val hyps = map2 mk_assum qss patss
 
         fun inst_hyps hyp qs = fold (forall_elim o cterm_of thy) qs (assume hyp)
 
         val assums = map2 inst_hyps hyps qss
     in
-        o_alg thy P 2 [x] (map2 (pair o single) pats assums)
+        o_alg thy P 2 xs (patss ~~ assums)
               |> fold_rev implies_intr hyps
     end
 
 
 
-fun pat_complete_tac i thm =
+fun pat_complete_tac ctxt = SUBGOAL (fn (subgoal, i) =>
     let
-      val thy = theory_of_thm thm
-
-        val subgoal = nth (prems_of thm) (i - 1)   (* FIXME SUBGOAL tactical *)
+      val thy = ProofContext.theory_of ctxt
+      val (vs, subgf) = dest_all_all subgoal
+      val (cases, _ $ thesis) = Logic.strip_horn subgf
+          handle Bind => raise COMPLETENESS
 
-        val ([P, x], subgf) = dest_all_all subgoal
-
-        val assums = Logic.strip_imp_prems subgf
-
-        fun pat_of assum =
+      fun pat_of assum =
             let
                 val (qs, imp) = dest_all_all assum
+                val prems = Logic.strip_imp_prems imp
             in
-                case Logic.dest_implies imp of
-                    (_ $ (_ $ _ $ pat), _) => (qs, pat)
-                  | _ => raise COMPLETENESS
+              (qs, map (HOLogic.dest_eq o HOLogic.dest_Trueprop) prems)
             end
 
-        val (qss, pats) = split_list (map pat_of assums)
+        val (qss, x_pats) = split_list (map pat_of cases)
+        val xs = map fst (hd x_pats)
+                 handle Empty => raise COMPLETENESS
+                 
+        val patss = map (map snd) x_pats 
 
-        val complete_thm = prove_completeness thy x P qss pats
-                                              |> forall_intr (cterm_of thy x)
-                                              |> forall_intr (cterm_of thy P)
+        val complete_thm = prove_completeness thy xs thesis qss patss
+             |> fold_rev (forall_intr o cterm_of thy) vs
     in
-        Seq.single (Drule.compose_single(complete_thm, i, thm))
+      PRIMITIVE (fn st => Drule.compose_single(complete_thm, i, st))
     end
-    handle COMPLETENESS => Seq.empty
+    handle COMPLETENESS => no_tac)
 
 
-val pat_completeness = Method.SIMPLE_METHOD' pat_complete_tac
+fun pat_completeness ctxt = Method.SIMPLE_METHOD' (pat_complete_tac ctxt)
 
 val by_pat_completeness_simp =
     Proof.global_terminal_proof
-      (Method.Basic (K pat_completeness, Position.none),
+      (Method.Basic (pat_completeness, Position.none),
        SOME (Method.Source_i (Args.src (("HOL.auto", []), Position.none))))
 
 val termination_by_lexicographic_order =
@@ -292,7 +293,7 @@
     end
 
 val setup =
-    Method.add_methods [("pat_completeness", Method.no_args pat_completeness, 
+    Method.add_methods [("pat_completeness", Method.ctxt_args pat_completeness, 
                          "Completeness prover for datatype patterns")]
     #> Context.theory_map (FundefCommon.set_preproc sequential_preproc)
 
--- a/src/HOL/Tools/function_package/induction_scheme.ML	Thu Jun 19 00:02:08 2008 +0200
+++ b/src/HOL/Tools/function_package/induction_scheme.ML	Thu Jun 19 11:46:14 2008 +0200
@@ -7,7 +7,8 @@
 
 signature INDUCTION_SCHEME =
 sig
-  val mk_ind_tac : Proof.context -> thm list -> tactic  
+  val mk_ind_tac : (int -> tactic) -> (int -> tactic) -> (int -> tactic)
+                   -> Proof.context -> thm list -> tactic  
   val setup : theory -> theory
 end
 
@@ -17,26 +18,48 @@
 
 open FundefLib
 
-type rec_call_info = (string * typ) list * term list * term
+
+type rec_call_info = int * (string * typ) list * term list * term list
 
 datatype scheme_case =
   SchemeCase of
   {
+   bidx : int,
    qs: (string * typ) list,
+   oqnames: string list,
    gs: term list,
-   lhs: term,
+   lhs: term list,
    rs: rec_call_info list
   }
 
+datatype scheme_branch = 
+  SchemeBranch of
+  {
+   P : term,
+   xs: (string * typ) list,
+   ws: (string * typ) list,
+   Cs: term list
+  }
+
 datatype ind_scheme =
   IndScheme of
   {
-   (*cvars : (string, typ) list,       
-   cassms : term list,    *)     (* additional context for partial rules *)
-   T: typ,
+   T: typ, (* sum of products *)
+   branches: scheme_branch list,
    cases: scheme_case list
   }
 
+val ind_atomize = MetaSimplifier.rewrite true @{thms induct_atomize}
+val ind_rulify = MetaSimplifier.rewrite true @{thms induct_rulify}
+
+fun meta thm = thm RS eq_reflection
+
+val sum_prod_conv = MetaSimplifier.rewrite true 
+                    (map meta (@{thm split_conv} :: @{thms sum_cases}))
+
+fun term_conv thy cv t = 
+    cv (cterm_of thy t)
+    |> prop_of |> Logic.dest_equals |> snd
 
 fun mk_relT T = HOLogic.mk_setT (HOLogic.mk_prodT (T, T))
 
@@ -47,57 +70,126 @@
       (ctxt', vars, Logic.strip_imp_prems imp, Logic.strip_imp_concl imp)
     end
 
-fun mk_case P ctxt premise =
+
+fun mk_scheme' ctxt cases concl =
     let
-      val (ctxt', qs, prems, concl) = dest_hhf ctxt premise
-      val _ $ (_ $ lhs) = concl 
-
-      fun mk_rcinfo pr =
+      fun mk_branch concl =
           let
-            val (ctxt'', Gvs, Gas, _ $ (_ $ rcarg)) = dest_hhf ctxt' pr
+            val (ctxt', ws, Cs, _ $ Pxs) = dest_hhf ctxt concl
+            val (P, xs) = strip_comb Pxs
           in
-            (Gvs, Gas, rcarg)
+            SchemeBranch { P=P, xs=map dest_Free xs, ws=ws, Cs=Cs }
           end
 
-      val (gs, rcprs) = take_prefix (not o exists_aterm (fn Free v => v = P | _ => false)) prems
+      val (branches, cases') = (* correction *)
+          case Logic.dest_conjunction_list concl of
+            [conc] => 
+            let 
+              val _ $ Pxs = Logic.strip_assums_concl conc
+              val (P, _) = strip_comb Pxs
+              val (cases', conds) = take_prefix (Term.exists_subterm (curry op aconv P)) cases
+              val concl' = fold_rev (curry Logic.mk_implies) conds conc
+            in
+              ([mk_branch concl'], cases')
+            end
+          | concls => (map mk_branch concls, cases)
+
+      fun mk_case premise =
+          let
+            val (ctxt', qs, prems, _ $ Plhs) = dest_hhf ctxt premise
+            val (P, lhs) = strip_comb Plhs
+                                
+            fun bidx Q = find_index (fn SchemeBranch {P=P',...} => Q aconv P') branches
+
+            fun mk_rcinfo pr =
+                let
+                  val (ctxt'', Gvs, Gas, _ $ Phyp) = dest_hhf ctxt' pr
+                  val (P', rcs) = strip_comb Phyp
+                in
+                  (bidx P', Gvs, Gas, rcs)
+                end
+                
+            fun is_pred v = exists (fn SchemeBranch {P,...} => v aconv P) branches
+
+            val (gs, rcprs) = 
+                take_prefix (not o Term.exists_subterm is_pred) prems
+          in
+            SchemeCase {bidx=bidx P, qs=qs, oqnames=map fst qs(*FIXME*), gs=gs, lhs=lhs, rs=map mk_rcinfo rcprs}
+          end
+
+      fun PT_of (SchemeBranch { xs, ...}) =
+            foldr1 HOLogic.mk_prodT (map snd xs)
+
+      val ST = BalancedTree.make (uncurry SumTree.mk_sumT) (map PT_of branches)
     in
-      SchemeCase {qs=qs, gs=gs, lhs=lhs, rs=map mk_rcinfo rcprs}
+      IndScheme {T=ST, cases=map mk_case cases', branches=branches }
     end
 
-fun mk_scheme' ctxt cases (Pn, PT) =
-    IndScheme {T=domain_type PT, cases=map (mk_case (Pn,PT) ctxt) cases }
+
 
-fun mk_completeness ctxt (IndScheme {T, cases}) =
+fun mk_completeness ctxt (IndScheme {cases, branches, ...}) bidx =
     let
-      val allqnames = fold (fn SchemeCase {qs, ...} => fold (insert (op =) o Free) qs) cases []
-      val [Pbool, x] = map Free (Variable.variant_frees ctxt allqnames [("P", HOLogic.boolT), ("x", T)])
+      val SchemeBranch { xs, ws, Cs, ... } = nth branches bidx
+      val relevant_cases = filter (fn SchemeCase {bidx=bidx', ...} => bidx' = bidx) cases
+
+      val allqnames = fold (fn SchemeCase {qs, ...} => fold (insert (op =) o Free) qs) relevant_cases []
+      val (Pbool :: xs') = map Free (Variable.variant_frees ctxt allqnames (("P", HOLogic.boolT) :: xs))
+      val Cs' = map (Pattern.rewrite_term (ProofContext.theory_of ctxt) (filter_out (op aconv) (map Free xs ~~ xs')) []) Cs
                        
-      fun mk_case (SchemeCase {qs, gs, lhs, ...}) =
+      fun mk_case (SchemeCase {qs, oqnames, gs, lhs, ...}) =
           HOLogic.mk_Trueprop Pbool
-                     |> curry Logic.mk_implies (HOLogic.mk_Trueprop (HOLogic.mk_eq (x, lhs)))
+                     |> fold_rev (fn x_l => curry Logic.mk_implies (HOLogic.mk_Trueprop(HOLogic.mk_eq x_l)))
+                                 (xs' ~~ lhs)
                      |> fold_rev (curry Logic.mk_implies) gs
-                     |> fold_rev (mk_forall o Free) qs
+                     |> fold_rev mk_forall_rename (oqnames ~~ map Free qs)
     in
       HOLogic.mk_Trueprop Pbool
-       |> fold_rev (curry Logic.mk_implies o mk_case) cases
-       |> mk_forall_rename ("x", x)
+       |> fold_rev (curry Logic.mk_implies o mk_case) relevant_cases
+       |> fold_rev (curry Logic.mk_implies) Cs'
+       |> fold_rev (mk_forall o Free) ws
+       |> fold_rev mk_forall_rename (map fst xs ~~ xs')
        |> mk_forall_rename ("P", Pbool)
     end
 
 fun mk_wf ctxt R (IndScheme {T, ...}) =
     HOLogic.Trueprop $ (Const (@{const_name "wf"}, mk_relT T --> HOLogic.boolT) $ R)
 
-fun mk_ineqs R (IndScheme {T, cases}) =
+fun mk_ineqs R (IndScheme {T, cases, branches}) =
     let
-      fun f (SchemeCase {qs, gs, lhs, rs, ...}) = 
+      fun inject i ts =
+          SumTree.mk_inj T (length branches) (i + 1) (foldr1 HOLogic.mk_prod ts)
+
+      val thesis = Free ("thesis", HOLogic.boolT) (* FIXME *)
+
+      fun mk_pres bdx args = 
           let
-            fun g (Gvs, Gas, rcarg) =
-                HOLogic.mk_mem (HOLogic.mk_prod (rcarg, lhs), R)
-                  |> HOLogic.mk_Trueprop
-                  |> fold_rev (curry Logic.mk_implies) Gas
-                  |> fold_rev (curry Logic.mk_implies) gs
-                  |> fold_rev (mk_forall o Free) Gvs
-                  |> fold_rev (mk_forall o Free) qs
+            val SchemeBranch { xs, ws, Cs, ... } = nth branches bdx
+            fun replace (x, v) t = betapply (lambda (Free x) t, v)
+            val Cs' = map (fold replace (xs ~~ args)) Cs
+            val cse = 
+                HOLogic.mk_Trueprop thesis
+                |> fold_rev (curry Logic.mk_implies) Cs'
+                |> fold_rev (mk_forall o Free) ws
+          in
+            Logic.mk_implies (cse, HOLogic.mk_Trueprop thesis)
+          end
+
+      fun f (SchemeCase {bidx, qs, oqnames, gs, lhs, rs, ...}) = 
+          let
+            fun g (bidx', Gvs, Gas, rcarg) =
+                let val export = 
+                         fold_rev (curry Logic.mk_implies) Gas
+                         #> fold_rev (curry Logic.mk_implies) gs
+                         #> fold_rev (mk_forall o Free) Gvs
+                         #> fold_rev mk_forall_rename (oqnames ~~ map Free qs)
+                in
+                (HOLogic.mk_mem (HOLogic.mk_prod (inject bidx' rcarg, inject bidx lhs), R)
+                 |> HOLogic.mk_Trueprop
+                 |> export,
+                 mk_pres bidx' rcarg
+                 |> export
+                 |> mk_forall thesis)
+                end
           in
             map g rs
           end
@@ -106,11 +198,37 @@
     end
 
 
-fun mk_induct_rule ctxt R P x complete_thm wf_thm ineqss (IndScheme {T, cases=scases}) =
+fun mk_hol_imp a b = HOLogic.imp $ a $ b
+
+fun mk_ind_goal thy branches =
     let
+      fun brnch (SchemeBranch { P, xs, ws, Cs, ... }) =
+          HOLogic.mk_Trueprop (list_comb (P, map Free xs))
+          |> fold_rev (curry Logic.mk_implies) Cs
+          |> fold_rev (mk_forall o Free) ws
+          |> term_conv thy ind_atomize
+          |> ObjectLogic.drop_judgment thy
+          |> tupled_lambda (foldr1 HOLogic.mk_prod (map Free xs))
+    in
+      SumTree.mk_sumcases HOLogic.boolT (map brnch branches)
+    end
+
+
+fun mk_induct_rule ctxt R x complete_thms wf_thm ineqss (IndScheme {T, cases=scases, branches}) =
+    let
+      val n = length branches
+
+      val scases_idx = map_index I scases
+
+      fun inject i ts =
+          SumTree.mk_inj T n (i + 1) (foldr1 HOLogic.mk_prod ts)
+      val P_of = nth (map (fn (SchemeBranch { P, ... }) => P) branches)
+
       val thy = ProofContext.theory_of ctxt
       val cert = cterm_of thy 
 
+      val P_comp = mk_ind_goal thy branches
+
       (* Inductive Hypothesis: !!z. (z,x):R ==> P z *)
       val ihyp = all T $ Abs ("z", T, 
                implies $ 
@@ -118,63 +236,106 @@
                   Const ("op :", HOLogic.mk_prodT (T, T) --> mk_relT T --> HOLogic.boolT) 
                     $ (HOLogic.pair_const T T $ Bound 0 $ x) 
                     $ R)
-             $ HOLogic.mk_Trueprop (P $ Bound 0))
+             $ HOLogic.mk_Trueprop (P_comp $ Bound 0))
            |> cert
 
       val aihyp = assume ihyp
 
-      fun prove_case (SchemeCase {qs, gs, lhs, rs, ...}) ineqs =
+     (* Rule for case splitting along the sum types *)
+      val xss = map (fn (SchemeBranch { xs, ... }) => map Free xs) branches
+      val pats = map_index (uncurry inject) xss
+      val sum_split_rule = FundefDatatype.prove_completeness thy [x] (P_comp $ x) xss (map single pats)
+
+      fun prove_branch (bidx, (SchemeBranch { P, xs, ws, Cs, ... }, (complete_thm, pat))) =
           let
-            val case_hyp = assume (cert (HOLogic.Trueprop $ (HOLogic.mk_eq (x, lhs))))
+            val fxs = map Free xs
+            val branch_hyp = assume (cert (HOLogic.mk_Trueprop (HOLogic.mk_eq (x, pat))))
+                             
+            val C_hyps = map (cert #> assume) Cs
+
+            val (relevant_cases, ineqss') = filter (fn ((_, SchemeCase {bidx=bidx', ...}), _) => bidx' = bidx) (scases_idx ~~ ineqss)
+                                            |> split_list
+                           
+            fun prove_case (cidx, SchemeCase {qs, oqnames, gs, lhs, rs, ...}) ineq_press =
+                let
+                  val case_hyps = map (assume o cert o HOLogic.mk_Trueprop o HOLogic.mk_eq) (fxs ~~ lhs)
                            
-            val cqs = map (cert o Free) qs
-            val ags = map (assume o cert) gs
-                      
-            val replace_x_ss = HOL_basic_ss addsimps [case_hyp]
-            val sih = full_simplify replace_x_ss aihyp
+                  val cqs = map (cert o Free) qs
+                  val ags = map (assume o cert) gs
+                            
+                  val replace_x_ss = HOL_basic_ss addsimps (branch_hyp :: case_hyps)
+                  val sih = full_simplify replace_x_ss aihyp
+                            
+                  fun mk_Prec (idx, Gvs, Gas, rcargs) (ineq, pres) =
+                      let
+                        val cGas = map (assume o cert) Gas
+                        val cGvs = map (cert o Free) Gvs
+                        val import = fold forall_elim (cqs @ cGvs)
+                                     #> fold Thm.elim_implies (ags @ cGas)
+                        val ipres = pres
+                                     |> forall_elim (cert (list_comb (P_of idx, rcargs)))
+                                     |> import
+                      in
+                        sih |> forall_elim (cert (inject idx rcargs))
+                            |> Thm.elim_implies (import ineq) (* Psum rcargs *)
+                            |> Conv.fconv_rule sum_prod_conv
+                            |> Conv.fconv_rule ind_rulify
+                            |> (fn th => th COMP ipres) (* P rs *)
+                            |> fold_rev (implies_intr o cprop_of) cGas
+                            |> fold_rev forall_intr cGvs
+                      end
                       
-            fun mk_Prec (Gvs, Gas, rcarg) ineq =
-                let
-                  val cGas = map (assume o cert) Gas
-                  val cGvs = map (cert o Free) Gvs
-                  val loc_ineq = ineq 
-                                   |> fold forall_elim (cqs @ cGvs)
-                                   |> fold Thm.elim_implies (ags @ cGas)
+                  val P_recs = map2 mk_Prec rs ineq_press   (*  [P rec1, P rec2, ... ]  *)
+                               
+                  val step = HOLogic.mk_Trueprop (list_comb (P, lhs))
+                             |> fold_rev (curry Logic.mk_implies o prop_of) P_recs
+                             |> fold_rev (curry Logic.mk_implies) gs
+                             |> fold_rev (mk_forall o Free) qs
+                             |> cert
+                             
+                  val Plhs_to_Pxs_conv = 
+                      foldl1 (uncurry Conv.combination_conv) 
+                      (Conv.all_conv :: map (fn ch => K (Thm.symmetric (ch RS eq_reflection))) case_hyps)
+
+                  val res = assume step
+                                   |> fold forall_elim cqs
+                                   |> fold Thm.elim_implies ags
+                                   |> fold Thm.elim_implies P_recs (* P lhs *) 
+                                   |> Conv.fconv_rule (Conv.arg_conv Plhs_to_Pxs_conv) (* P xs *)
+                                   |> fold_rev (implies_intr o cprop_of) (ags @ case_hyps)
+                                   |> fold_rev forall_intr cqs (* !!qs. Gas ==> xs = lhss ==> P xs *)
                 in
-                  sih |> forall_elim (cert rcarg)
-                      |> Thm.elim_implies loc_ineq
-                      |> fold_rev (implies_intr o cprop_of) cGas
-                      |> fold_rev forall_intr cGvs
+                  (res, (cidx, step))
                 end
-                
-            val P_recs = map2 mk_Prec rs ineqs   (*  [P rec1, P rec2, ... ]  *)
-                         
-            val step = HOLogic.mk_Trueprop (P $ lhs)
-                                           |> fold_rev (curry Logic.mk_implies o prop_of) P_recs
-                                           |> fold_rev (curry Logic.mk_implies) gs
-                                           |> fold_rev (mk_forall o Free) qs
-                                           |> cert
-                       
-            val res = assume step
-                       |> fold forall_elim cqs
-                       |> fold Thm.elim_implies ags
-                       |> fold Thm.elim_implies P_recs
-                       |> Conv.fconv_rule 
-                       (Conv.arg_conv (Conv.arg_conv (K (Thm.symmetric (case_hyp RS eq_reflection))))) 
-                       (* "P x" *)
-                       |> implies_intr (cprop_of case_hyp)
-                       |> fold_rev (implies_intr o cprop_of) ags
-                       |> fold_rev forall_intr cqs
+
+            val (cases, steps) = split_list (map2 prove_case relevant_cases ineqss')
+
+            val bstep = complete_thm
+                |> forall_elim (cert (list_comb (P, fxs)))
+                |> fold (forall_elim o cert) (fxs @ map Free ws)
+                |> fold Thm.elim_implies C_hyps             (* FIXME: optimization using rotate_prems *)
+                |> fold Thm.elim_implies cases (* P xs *)
+                |> fold_rev (implies_intr o cprop_of) C_hyps
+                |> fold_rev (forall_intr o cert o Free) ws
+
+            val Pxs = cert (HOLogic.mk_Trueprop (P_comp $ x))
+                     |> Goal.init
+                     |> (MetaSimplifier.rewrite_goals_tac (map meta (branch_hyp :: @{thm split_conv} :: @{thms sum_cases}))
+                         THEN CONVERSION ind_rulify 1)
+                     |> Seq.hd
+                     |> Thm.elim_implies bstep
+                     |> Goal.finish
+                     |> implies_intr (cprop_of branch_hyp)
+                     |> fold_rev (forall_intr o cert) fxs
           in
-            (res, step)
+            (Pxs, steps)
           end
-          
-      val (cases, steps) = split_list (map2 prove_case scases ineqss)
+
+      val (branches, steps) = split_list (map_index prove_branch (branches ~~ (complete_thms ~~ pats)))
+                              |> apsnd flat
                            
-      val istep = complete_thm 
-                |> forall_elim (cert (P $ x))
-                |> forall_elim (cert x)
-                |> fold (Thm.elim_implies) cases
+      val istep = sum_split_rule
+                |> fold (fn b => fn th => Drule.compose_single (b, 1, th)) branches
                 |> implies_intr ihyp
                 |> forall_intr (cert x) (* "!!x. (!!y<x. P y) ==> P x" *)
          
@@ -182,100 +343,60 @@
           @{thm "wf_induct_rule"}
             |> (curry op COMP) wf_thm 
             |> (curry op COMP) istep
-            |> fold_rev implies_intr steps
-            |> forall_intr (cert P)
+
+      val steps_sorted = map snd (sort (int_ord o pairself fst) steps)
     in
-      induct_rule
+      (steps_sorted, induct_rule)
     end
 
-fun mk_ind_tac ctxt facts = (ALLGOALS (Method.insert_tac facts)) THEN HEADGOAL 
+
+fun mk_ind_tac comp_tac pres_tac term_tac ctxt facts = (ALLGOALS (Method.insert_tac facts)) THEN HEADGOAL 
 (SUBGOAL (fn (t, i) =>
   let
     val (ctxt', _, cases, concl) = dest_hhf ctxt t
-                                   
-    fun get_types t = 
-        let
-          val (P, vs) = strip_comb (HOLogic.dest_Trueprop t)
-          val Ts = map fastype_of vs
-          val tupT = foldr1 HOLogic.mk_prodT Ts
-        in 
-          ((P, Ts), tupT)
-        end
-        
-    val concls = Logic.dest_conjunction_list (Logic.strip_imp_concl concl)
-    val (PTss, tupTs) = split_list (map get_types concls)
-                        
-    val n = length tupTs
-    val ST = BalancedTree.make (uncurry SumTree.mk_sumT) tupTs
-
-    val ([Psn, Rn, xn], ctxt'') = Variable.variant_fixes ["Psum", "R", "x"] ctxt'
-    val Psum = (Psn, ST --> HOLogic.boolT)
+    val scheme as IndScheme {T=ST, branches, ...} = mk_scheme' ctxt' cases concl
+(*     val _ = Output.tracing (makestring scheme)*)
+    val ([Rn,xn], ctxt'') = Variable.variant_fixes ["R","x"] ctxt'
     val R = Free (Rn, mk_relT ST)
     val x = Free (xn, ST)
-               
-    fun mk_rews (i, (P, Ts)) = 
-        let
-          val vs = map_index (fn (j,T) => Free ("x" ^ string_of_int j, T)) Ts 
-          val t = Free Psum $ SumTree.mk_inj ST n (i + 1) (foldr1 HOLogic.mk_prod vs)
-                       |> fold_rev lambda vs
-        in
-          (P, t)
-        end
-        
-    val rews = map_index mk_rews PTss
-    val thy = ProofContext.theory_of ctxt''
-    val cases' = map (Pattern.rewrite_term thy rews []) cases
-                 
-    val scheme = mk_scheme' ctxt'' cases' Psum
-
-    val cert = cterm_of thy
+    val cert = cterm_of (ProofContext.theory_of ctxt)
 
     val ineqss = mk_ineqs R scheme
-                   |> map (map (assume o cert))
-    val complete = mk_completeness ctxt scheme |> cert |> assume
+                   |> map (map (pairself (assume o cert)))
+    val complete = map (mk_completeness ctxt scheme #> cert #> assume) (0 upto (length branches - 1))
     val wf_thm = mk_wf ctxt R scheme |> cert |> assume
 
-    val indthm = mk_induct_rule ctxt'' R (Free Psum) x complete wf_thm ineqss scheme
+    val (descent, pres) = split_list (flat ineqss)
+    val newgoals = complete @ pres @ wf_thm :: descent 
 
-    fun mk_P (P, Ts) = 
+    val (steps, indthm) = mk_induct_rule ctxt'' R x complete wf_thm ineqss scheme
+
+    fun project (i, SchemeBranch {xs, ...}) =
         let
-          val avars = map_index (fn (i,T) => Var (("a", i), T)) Ts
-          val atup = foldr1 HOLogic.mk_prod avars
+          val inst = cert (SumTree.mk_inj ST (length branches) (i + 1) (foldr1 HOLogic.mk_prod (map Free xs)))
         in
-          tupled_lambda atup (list_comb (P, avars))
-        end
-          
-    val case_exp = cert (SumTree.mk_sumcases HOLogic.boolT (map mk_P PTss))
-    val acases = map (assume o cert) cases
-    val indthm' = indthm |> forall_elim case_exp
-                         |> full_simplify SumTree.sumcase_split_ss
-                         |> fold Thm.elim_implies acases
-
-    fun project (i,t) = 
-        let
-          val (P, vs) = strip_comb (HOLogic.dest_Trueprop t)
-          val inst = cert (SumTree.mk_inj ST n (i + 1) (foldr1 HOLogic.mk_prod vs))
-        in
-          indthm' |> Drule.instantiate' [] [SOME inst]
-                  |> simplify SumTree.sumcase_split_ss
+          indthm |> Drule.instantiate' [] [SOME inst]
+                 |> simplify SumTree.sumcase_split_ss
+                 |> Conv.fconv_rule ind_rulify
+(*                 |> (fn thm => (Output.tracing (makestring thm); thm))*)
         end                  
 
-    val res = Conjunction.intr_balanced (map_index project concls)
-                |> fold_rev (implies_intr o cprop_of) acases
-                |> Thm.forall_elim_vars 0
-        in
-          (fn st =>
-        Drule.compose_single (res, i, st)
-          |> fold_rev (implies_intr o cprop_of) (complete :: wf_thm :: flat ineqss)
-          |> forall_intr (cert R)
-          |> Thm.forall_elim_vars 0
-          |> Seq.single
-          )
+    val res = Conjunction.intr_balanced (map_index project branches)
+                 |> fold_rev implies_intr (map cprop_of newgoals @ steps)
+                 |> (fn thm => Thm.generalize ([], [Rn]) (Thm.maxidx_of thm + 1) thm)
+
+    val nbranches = length branches
+    val npres = length pres
+  in
+    Thm.compose_no_flatten false (res, length newgoals) i
+    THEN term_tac (i + nbranches + npres)
+    THEN (EVERY (map (TRY o pres_tac) ((i + nbranches + npres - 1) downto (i + nbranches))))
+    THEN (EVERY (map (TRY o comp_tac) ((i + nbranches - 1) downto i)))
   end))
 
 
 val setup = Method.add_methods
-  [("induct_scheme", Method.ctxt_args (Method.RAW_METHOD o mk_ind_tac),
+  [("induct_scheme", Method.ctxt_args (Method.RAW_METHOD o (fn ctxt => mk_ind_tac (K all_tac) (assume_tac APPEND' Goal.assume_rule_tac ctxt) (K all_tac) ctxt)),
     "proves an induction principle")]
 
 end
--- a/src/HOL/ex/Induction_Scheme.thy	Thu Jun 19 00:02:08 2008 +0200
+++ b/src/HOL/ex/Induction_Scheme.thy	Thu Jun 19 11:46:14 2008 +0200
@@ -44,6 +44,6 @@
   assumes "\<And>n. R n \<Longrightarrow> Q (Suc n)"
   shows "R n" "Q n"
   using assms
-by induct_scheme (pat_completeness, lexicographic_order)
+by induct_scheme (pat_completeness+, lexicographic_order)
 
 end
\ No newline at end of file