src/HOL/Tools/Function/induction_schema.ML
author krauss
Fri Nov 06 14:42:42 2009 +0100 (2009-11-06)
changeset 33471 5aef13872723
child 33697 7d6793ce0a26
permissions -rw-r--r--
renamed method induct_scheme to induction_schema
     1 (*  Title:      HOL/Tools/Function/induction_schema.ML
     2     Author:     Alexander Krauss, TU Muenchen
     3 
     4 A method to prove induction schemas.
     5 *)
     6 
(* Interface of the induction_schema proof method. *)
signature INDUCTION_SCHEMA =
sig
  (* mk_ind_tac comp_tac pres_tac term_tac ctxt facts: inserts facts into all
     subgoals, then reduces the induction schema stated in the first subgoal
     to proof obligations, applying comp_tac to the completeness obligations,
     pres_tac to the preservation obligations and term_tac to the
     well-foundedness obligation. *)
  val mk_ind_tac : (int -> tactic) -> (int -> tactic) -> (int -> tactic)
                   -> Proof.context -> thm list -> tactic
  (* The tactic behind the induction_schema method; it is mk_ind_tac with
     fixed choices for the three sub-tactics. *)
  val induction_schema_tac : Proof.context -> thm list -> tactic
  (* Registers the induction_schema proof method. *)
  val setup : theory -> theory
end
    14 
    15 
structure Induction_Schema : INDUCTION_SCHEMA =
struct

open Function_Lib


(* Description of one recursive-call premise of a case, as built by
   mk_rcinfo in mk_scheme':
   (index of the branch whose predicate is applied,
    variables bound by the premise,
    premises of the premise,
    argument terms of the predicate application). *)
type rec_call_info = int * (string * typ) list * term list * term list

(* One case of the schema, i.e. one premise of the goal.
   bidx:    index into the branch list of the predicate in the case's conclusion
   qs:      variables bound by the case
   oqnames: names used when re-binding qs in generated terms; currently just
            the names of qs (see the FIXME in mk_scheme')
   gs:      the leading premises that mention no branch predicate
   lhs:     arguments of the predicate in the case's conclusion
   rs:      the remaining premises, read as recursive calls *)
datatype scheme_case =
  SchemeCase of
  {
   bidx : int,
   qs: (string * typ) list,
   oqnames: string list,
   gs: term list,
   lhs: term list,
   rs: rec_call_info list
  }

(* One branch of the schema, i.e. one conjunct of the goal's conclusion.
   P:  the predicate being proved
   xs: its arguments, which must be free variables (dest_Free in mk_branch)
   ws: variables bound in front of the conjunct
   Cs: premises of the conjunct *)
datatype scheme_branch = 
  SchemeBranch of
  {
   P : term,
   xs: (string * typ) list,
   ws: (string * typ) list,
   Cs: term list
  }

(* A whole schema. T is the type of the combined argument: a balanced sum with
   one summand per branch, each summand being the product of that branch's
   argument types (see PT_of and ST in mk_scheme'). *)
datatype ind_scheme =
  IndScheme of
  {
   T: typ, (* sum of products *)
   branches: scheme_branch list,
   cases: scheme_case list
  }

(* Conversions rewriting with the induct_atomize and induct_rulify rules. They
   presumably translate between meta-level and object-level connectives. *)
val ind_atomize = MetaSimplifier.rewrite true @{thms induct_atomize}
val ind_rulify = MetaSimplifier.rewrite true @{thms induct_rulify}

(* Turn an object-level equation theorem into a meta-level equation. *)
fun meta thm = thm RS eq_reflection

(* Conversion rewriting with split_conv and the sum.cases equations, used to
   evaluate split and sum-case applications on explicit tuples and
   injections. *)
val sum_prod_conv = MetaSimplifier.rewrite true 
                    (map meta (@{thm split_conv} :: @{thms sum.cases}))

(* Apply conversion cv to term t (certified in thy) and return the right-hand
   side of the resulting meta-equation. *)
fun term_conv thy cv t = 
    cv (cterm_of thy t)
    |> prop_of |> Logic.dest_equals |> snd

(* Type of relations on T, represented as sets of pairs. *)
fun mk_relT T = HOLogic.mk_setT (HOLogic.mk_prodT (T, T))

(* Split a term into (context, bound variables, premises, conclusion).
   dest_all_all_ctx comes from Function_Lib. It presumably strips the outer
   meta-quantifiers and fixes their variables in the returned context. *)
fun dest_hhf ctxt t = 
    let 
      val (ctxt', vars, imp) = dest_all_all_ctx ctxt t
    in
      (ctxt', vars, Logic.strip_imp_prems imp, Logic.strip_imp_concl imp)
    end


(* Build the IndScheme for an induction goal with premises cases and
   conclusion concl. *)
fun mk_scheme' ctxt cases concl =
    let
      (* A conjunct of the conclusion, of the form !!ws. Cs ==> Trueprop (P xs). *)
      fun mk_branch concl =
          let
            val (ctxt', ws, Cs, _ $ Pxs) = dest_hhf ctxt concl
            val (P, xs) = strip_comb Pxs
          in
            SchemeBranch { P=P, xs=map dest_Free xs, ws=ws, Cs=Cs }
          end

      (* With a single conclusion, only the leading premises that mention its
         predicate P are treated as cases. The remaining premises are moved
         into the conclusion as extra premises of the one branch. With several
         conjuncts, each conjunct is a branch and all premises are cases. *)
      val (branches, cases') = (* correction *)
          case Logic.dest_conjunction_list concl of
            [conc] => 
            let 
              val _ $ Pxs = Logic.strip_assums_concl conc
              val (P, _) = strip_comb Pxs
              val (cases', conds) = take_prefix (Term.exists_subterm (curry op aconv P)) cases
              val concl' = fold_rev (curry Logic.mk_implies) conds conc
            in
              ([mk_branch concl'], cases')
            end
          | concls => (map mk_branch concls, cases)

      (* A case premise !!qs. prems ==> Trueprop (P lhs). The leading prems
         that mention no branch predicate become gs; the rest are recursive
         calls. *)
      fun mk_case premise =
          let
            val (ctxt', qs, prems, _ $ Plhs) = dest_hhf ctxt premise
            val (P, lhs) = strip_comb Plhs
                                
            (* Index of the branch whose predicate is Q. NOTE(review):
               find_index presumably yields ~1 when Q is not a branch
               predicate. *)
            fun bidx Q = find_index (fn SchemeBranch {P=P',...} => Q aconv P') branches

            (* A recursive-call premise !!Gvs. Gas ==> Trueprop (P' rcs). *)
            fun mk_rcinfo pr =
                let
                  val (ctxt'', Gvs, Gas, _ $ Phyp) = dest_hhf ctxt' pr
                  val (P', rcs) = strip_comb Phyp
                in
                  (bidx P', Gvs, Gas, rcs)
                end
                
            fun is_pred v = exists (fn SchemeBranch {P,...} => v aconv P) branches

            val (gs, rcprs) = 
                take_prefix (not o Term.exists_subterm is_pred) prems
          in
            SchemeCase {bidx=bidx P, qs=qs, oqnames=map fst qs(*FIXME*), gs=gs, lhs=lhs, rs=map mk_rcinfo rcprs}
          end

      (* Product of a branch's argument types. *)
      fun PT_of (SchemeBranch { xs, ...}) =
            foldr1 HOLogic.mk_prodT (map snd xs)

      (* Balanced sum of all branch product types. *)
      val ST = Balanced_Tree.make (uncurry SumTree.mk_sumT) (map PT_of branches)
    in
      IndScheme {T=ST, cases=map mk_case cases', branches=branches }
    end



(* Completeness obligation for branch bidx:
     !!P xs'. !!ws. Cs' ==> case_1 ==> ... ==> case_k ==> P
   There is one case_j for every case belonging to this branch, each of the
   form
     !!qs. gs ==> xs'_1 = lhs_1 ==> ... ==> P
   P is a fresh boolean variable. xs' are variants of the branch arguments
   chosen to avoid the case variables. Cs' is Cs with xs renamed to xs'. *)
fun mk_completeness ctxt (IndScheme {cases, branches, ...}) bidx =
    let
      val SchemeBranch { xs, ws, Cs, ... } = nth branches bidx
      val relevant_cases = filter (fn SchemeCase {bidx=bidx', ...} => bidx' = bidx) cases

      val allqnames = fold (fn SchemeCase {qs, ...} => fold (insert (op =) o Free) qs) relevant_cases []
      val (Pbool :: xs') = map Free (Variable.variant_frees ctxt allqnames (("P", HOLogic.boolT) :: xs))
      val Cs' = map (Pattern.rewrite_term (ProofContext.theory_of ctxt) (filter_out (op aconv) (map Free xs ~~ xs')) []) Cs
                       
      fun mk_case (SchemeCase {qs, oqnames, gs, lhs, ...}) =
          HOLogic.mk_Trueprop Pbool
                     |> fold_rev (fn x_l => curry Logic.mk_implies (HOLogic.mk_Trueprop(HOLogic.mk_eq x_l)))
                                 (xs' ~~ lhs)
                     |> fold_rev (curry Logic.mk_implies) gs
                     |> fold_rev mk_forall_rename (oqnames ~~ map Free qs)
    in
      HOLogic.mk_Trueprop Pbool
       |> fold_rev (curry Logic.mk_implies o mk_case) relevant_cases
       |> fold_rev (curry Logic.mk_implies) Cs'
       |> fold_rev (Logic.all o Free) ws
       |> fold_rev mk_forall_rename (map fst xs ~~ xs')
       |> mk_forall_rename ("P", Pbool)
    end

(* The obligation Trueprop (wf R) over the scheme's argument type.
   NOTE(review): ctxt is not used. *)
fun mk_wf ctxt R (IndScheme {T, ...}) =
    HOLogic.Trueprop $ (Const (@{const_name wf}, mk_relT T --> HOLogic.boolT) $ R)

(* For every case (outer list) and each of its recursive calls (inner list),
   build a pair of obligations:
   - descent:      !!qs. !!Gvs. gs ==> Gas ==> (inj rcarg, inj lhs) : R
   - preservation: !!thesis. !!qs. !!Gvs. gs ==> Gas ==>
                     ((!!ws. Cs[rcarg/xs] ==> thesis) ==> thesis)
   inj places a tuple into the sum type T. In the preservation obligation,
   ws, Cs and xs belong to the called branch; presumably it states that the
   call's arguments satisfy that branch's premises. *)
fun mk_ineqs R (IndScheme {T, cases, branches}) =
    let
      fun inject i ts =
          SumTree.mk_inj T (length branches) (i + 1) (foldr1 HOLogic.mk_prod ts)

      (* NOTE(review): fixed name; it could presumably clash with a schema
         variable that is also called thesis. *)
      val thesis = Free ("thesis", HOLogic.boolT) (* FIXME *)

      (* (!!ws. Cs[args/xs] ==> thesis) ==> thesis for branch bdx. *)
      fun mk_pres bdx args = 
          let
            val SchemeBranch { xs, ws, Cs, ... } = nth branches bdx
            fun replace (x, v) t = betapply (lambda (Free x) t, v)
            val Cs' = map (fold replace (xs ~~ args)) Cs
            val cse = 
                HOLogic.mk_Trueprop thesis
                |> fold_rev (curry Logic.mk_implies) Cs'
                |> fold_rev (Logic.all o Free) ws
          in
            Logic.mk_implies (cse, HOLogic.mk_Trueprop thesis)
          end

      fun f (SchemeCase {bidx, qs, oqnames, gs, lhs, rs, ...}) = 
          let
            fun g (bidx', Gvs, Gas, rcarg) =
                (* export wraps a statement in the call's and the case's
                   premises and binders. *)
                let val export = 
                         fold_rev (curry Logic.mk_implies) Gas
                         #> fold_rev (curry Logic.mk_implies) gs
                         #> fold_rev (Logic.all o Free) Gvs
                         #> fold_rev mk_forall_rename (oqnames ~~ map Free qs)
                in
                (HOLogic.mk_mem (HOLogic.mk_prod (inject bidx' rcarg, inject bidx lhs), R)
                 |> HOLogic.mk_Trueprop
                 |> export,
                 mk_pres bidx' rcarg
                 |> export
                 |> Logic.all thesis)
                end
          in
            map g rs
          end
    in
      map f cases
    end


(* Object-level implication a --> b.
   NOTE(review): not referenced elsewhere in this file and not exported by
   INDUCTION_SCHEMA. *)
fun mk_hol_imp a b = HOLogic.imp $ a $ b

(* The combined induction predicate over the sum type. For each branch, the
   atomized statement !!ws. Cs ==> P xs is abstracted over the tuple of xs.
   The branches are joined by SumTree.mk_sumcases. *)
fun mk_ind_goal thy branches =
    let
      fun brnch (SchemeBranch { P, xs, ws, Cs, ... }) =
          HOLogic.mk_Trueprop (list_comb (P, map Free xs))
          |> fold_rev (curry Logic.mk_implies) Cs
          |> fold_rev (Logic.all o Free) ws
          |> term_conv thy ind_atomize
          |> ObjectLogic.drop_judgment thy
          |> tupled_lambda (foldr1 HOLogic.mk_prod (map Free xs))
    in
      SumTree.mk_sumcases HOLogic.boolT (map brnch branches)
    end


(* Derive the induction rule for the scheme by well-founded induction over R.
   The inputs are the assumed completeness theorems (one per branch), the
   assumed wf theorem, and the assumed (descent, preservation) pairs from
   mk_ineqs (one list per case).
   Returns (steps, induct_rule):
   - steps are the certified case statements !!qs. gs ==> P recs ==> P lhs,
     sorted by case index. Each enters induct_rule as a hypothesis via
     assume.
   - induct_rule is wf_induct_rule resolved with wf_thm and istep. *)
fun mk_induct_rule ctxt R x complete_thms wf_thm ineqss (IndScheme {T, cases=scases, branches}) =
    let
      val n = length branches

      (* Cases paired with their global index, used later to order steps. *)
      val scases_idx = map_index I scases

      (* Place the tuple ts into summand i of T. *)
      fun inject i ts =
          SumTree.mk_inj T n (i + 1) (foldr1 HOLogic.mk_prod ts)
      val P_of = nth (map (fn (SchemeBranch { P, ... }) => P) branches)

      val thy = ProofContext.theory_of ctxt
      val cert = cterm_of thy 

      val P_comp = mk_ind_goal thy branches

      (* Inductive Hypothesis: !!z. (z,x):R ==> P z *)
      (* NOTE(review): the membership constant is written out here, while
         mk_ineqs uses HOLogic.mk_mem for the same kind of term. *)
      val ihyp = Term.all T $ Abs ("z", T, 
               Logic.mk_implies
                 (HOLogic.mk_Trueprop (
                  Const ("op :", HOLogic.mk_prodT (T, T) --> mk_relT T --> HOLogic.boolT) 
                    $ (HOLogic.pair_const T T $ Bound 0 $ x) 
                    $ R),
                   HOLogic.mk_Trueprop (P_comp $ Bound 0)))
           |> cert

      val aihyp = assume ihyp

     (* Rule for case splitting along the sum types *)
      val xss = map (fn (SchemeBranch { xs, ... }) => map Free xs) branches
      val pats = map_index (uncurry inject) xss
      val sum_split_rule = Pat_Completeness.prove_completeness thy [x] (P_comp $ x) xss (map single pats)

      (* For branch bidx, prove !!xs. x = inj_bidx xs ==> P_comp x from the
         branch's completeness theorem and one proof per case of the branch.
         Also returns the (case index, step) pairs of those cases. *)
      fun prove_branch (bidx, (SchemeBranch { P, xs, ws, Cs, ... }, (complete_thm, pat))) =
          let
            val fxs = map Free xs
            val branch_hyp = assume (cert (HOLogic.mk_Trueprop (HOLogic.mk_eq (x, pat))))
                             
            val C_hyps = map (cert #> assume) Cs

            (* The cases of this branch (with global index) and their
               (descent, preservation) obligations. *)
            val (relevant_cases, ineqss') = filter (fn ((_, SchemeCase {bidx=bidx', ...}), _) => bidx' = bidx) (scases_idx ~~ ineqss)
                                            |> split_list
                           
            (* For one case, assume its step statement and derive
               !!qs. gs ==> xs = lhs ==> P xs. *)
            fun prove_case (cidx, SchemeCase {qs, oqnames, gs, lhs, rs, ...}) ineq_press =
                let
                  val case_hyps = map (assume o cert o HOLogic.mk_Trueprop o HOLogic.mk_eq) (fxs ~~ lhs)
                           
                  val cqs = map (cert o Free) qs
                  val ags = map (assume o cert) gs
                            
                  (* Induction hypothesis simplified with x = pat and
                     xs = lhs, so that x is expressed through lhs. *)
                  val replace_x_ss = HOL_basic_ss addsimps (branch_hyp :: case_hyps)
                  val sih = full_simplify replace_x_ss aihyp
                            
                  (* P rcargs for one recursive call. It is obtained from the
                     induction hypothesis, the descent obligation ineq and
                     the preservation obligation pres, generalized over the
                     call's own premises and variables. *)
                  fun mk_Prec (idx, Gvs, Gas, rcargs) (ineq, pres) =
                      let
                        val cGas = map (assume o cert) Gas
                        val cGvs = map (cert o Free) Gvs
                        (* Instantiate the binders and discharge the
                           premises of an exported obligation. *)
                        val import = fold forall_elim (cqs @ cGvs)
                                     #> fold Thm.elim_implies (ags @ cGas)
                        val ipres = pres
                                     |> forall_elim (cert (list_comb (P_of idx, rcargs)))
                                     |> import
                      in
                        sih |> forall_elim (cert (inject idx rcargs))
                            |> Thm.elim_implies (import ineq) (* Psum rcargs *)
                            |> Conv.fconv_rule sum_prod_conv
                            |> Conv.fconv_rule ind_rulify
                            |> (fn th => th COMP ipres) (* P rs *)
                            |> fold_rev (implies_intr o cprop_of) cGas
                            |> fold_rev forall_intr cGvs
                      end
                      
                  val P_recs = map2 mk_Prec rs ineq_press   (*  [P rec1, P rec2, ... ]  *)
                               
                  (* The case step !!qs. gs ==> P recs ==> P lhs. *)
                  val step = HOLogic.mk_Trueprop (list_comb (P, lhs))
                             |> fold_rev (curry Logic.mk_implies o prop_of) P_recs
                             |> fold_rev (curry Logic.mk_implies) gs
                             |> fold_rev (Logic.all o Free) qs
                             |> cert
                             
                  (* Rewrites P lhs_1 ... lhs_n to P x_1 ... x_n. all_conv
                     leaves the head P unchanged, and each argument is
                     rewritten by the symmetric case hypothesis. *)
                  val Plhs_to_Pxs_conv = 
                      foldl1 (uncurry Conv.combination_conv) 
                      (Conv.all_conv :: map (fn ch => K (Thm.symmetric (ch RS eq_reflection))) case_hyps)

                  val res = assume step
                                   |> fold forall_elim cqs
                                   |> fold Thm.elim_implies ags
                                   |> fold Thm.elim_implies P_recs (* P lhs *) 
                                   |> Conv.fconv_rule (Conv.arg_conv Plhs_to_Pxs_conv) (* P xs *)
                                   |> fold_rev (implies_intr o cprop_of) (ags @ case_hyps)
                                   |> fold_rev forall_intr cqs (* !!qs. Gas ==> xs = lhss ==> P xs *)
                in
                  (res, (cidx, step))
                end

            val (cases, steps) = split_list (map2 prove_case relevant_cases ineqss')

            (* Instantiate the completeness theorem with P xs for its boolean
               variable and with the branch variables, then discharge Cs and
               the proven cases. The result is !!ws. Cs ==> P xs. *)
            val bstep = complete_thm
                |> forall_elim (cert (list_comb (P, fxs)))
                |> fold (forall_elim o cert) (fxs @ map Free ws)
                |> fold Thm.elim_implies C_hyps             (* FIXME: optimization using rotate_prems *)
                |> fold Thm.elim_implies cases (* P xs *)
                |> fold_rev (implies_intr o cprop_of) C_hyps
                |> fold_rev (forall_intr o cert o Free) ws

            (* Prove P_comp x: rewrite the goal with x = pat, split_conv and
               sum.cases, rulify, then close it with bstep. Finally
               discharge branch_hyp and generalize over xs. *)
            val Pxs = cert (HOLogic.mk_Trueprop (P_comp $ x))
                     |> Goal.init
                     |> (MetaSimplifier.rewrite_goals_tac (map meta (branch_hyp :: @{thm split_conv} :: @{thms sum.cases}))
                         THEN CONVERSION ind_rulify 1)
                     |> Seq.hd
                     |> Thm.elim_implies (Conv.fconv_rule Drule.beta_eta_conversion bstep)
                     |> Goal.finish ctxt
                     |> implies_intr (cprop_of branch_hyp)
                     |> fold_rev (forall_intr o cert) fxs
          in
            (Pxs, steps)
          end

      val (branches, steps) = split_list (map_index prove_branch (branches ~~ (complete_thms ~~ pats)))
                              |> apsnd flat
                           
      (* Plug each branch result into the sum case-split rule, then
         discharge the induction hypothesis and generalize over x. *)
      val istep = sum_split_rule
                |> fold (fn b => fn th => Drule.compose_single (b, 1, th)) branches
                |> implies_intr ihyp
                |> forall_intr (cert x) (* "!!x. (!!y<x. P y) ==> P x" *)
         
      val induct_rule =
          @{thm "wf_induct_rule"}
            |> (curry op COMP) wf_thm 
            |> (curry op COMP) istep

      (* Steps in the original order of the cases. *)
      val steps_sorted = map snd (sort (int_ord o pairself fst) steps)
    in
      (steps_sorted, induct_rule)
    end


(* Tactic for an induction schema stated in the first subgoal. The facts are
   first inserted into all subgoals. The schema is reduced to the obligations
     completeness (one per branch), preservation (one per recursive call),
     wf R, descent (one per recursive call),
   which replace the subgoal in that order. R is a fresh relation variable
   that is turned into a schematic variable.
   Then term_tac is applied at the position of the wf obligation, pres_tac
   (under TRY) at each preservation obligation, and comp_tac (under TRY) at
   each completeness obligation. *)
fun mk_ind_tac comp_tac pres_tac term_tac ctxt facts = (ALLGOALS (Method.insert_tac facts)) THEN HEADGOAL 
(SUBGOAL (fn (t, i) =>
  let
    val (ctxt', _, cases, concl) = dest_hhf ctxt t
    val scheme as IndScheme {T=ST, branches, ...} = mk_scheme' ctxt' cases concl
(*     val _ = tracing (makestring scheme)*)
    val ([Rn,xn], ctxt'') = Variable.variant_fixes ["R","x"] ctxt'
    val R = Free (Rn, mk_relT ST)
    val x = Free (xn, ST)
    val cert = cterm_of (ProofContext.theory_of ctxt)

    (* All obligations are assumed here and discharged as premises of res
       below. *)
    val ineqss = mk_ineqs R scheme
                   |> map (map (pairself (assume o cert)))
    val complete = map_range (mk_completeness ctxt scheme #> cert #> assume) (length branches)
    val wf_thm = mk_wf ctxt R scheme |> cert |> assume

    val (descent, pres) = split_list (flat ineqss)
    (* This order determines the subgoal positions used at the end. *)
    val newgoals = complete @ pres @ wf_thm :: descent 

    val (steps, indthm) = mk_induct_rule ctxt'' R x complete wf_thm ineqss scheme

    (* The statement for branch i: instantiate indthm at inj_i xs and
       simplify the sum-case away. *)
    fun project (i, SchemeBranch {xs, ...}) =
        let
          val inst = cert (SumTree.mk_inj ST (length branches) (i + 1) (foldr1 HOLogic.mk_prod (map Free xs)))
        in
          indthm |> Drule.instantiate' [] [SOME inst]
                 |> simplify SumTree.sumcase_split_ss
                 |> Conv.fconv_rule ind_rulify
(*                 |> (fn thm => (tracing (makestring thm); thm))*)
        end                  

    (* newgoals ==> steps ==> (conjunction of all branch statements), with R
       generalized. Composing below turns newgoals into new subgoals. The
       steps presumably match the case premises of the original subgoal. *)
    val res = Conjunction.intr_balanced (map_index project branches)
                 |> fold_rev implies_intr (map cprop_of newgoals @ steps)
                 |> (fn thm => Thm.generalize ([], [Rn]) (Thm.maxidx_of thm + 1) thm)

    val nbranches = length branches
    val npres = length pres
  in
    (* Subgoal positions after composition:
         i .. i+nbranches-1                    completeness
         i+nbranches .. i+nbranches+npres-1    preservation
         i+nbranches+npres                     wf R
       followed by the descent obligations. Later positions are handled
       first, and ranges are walked in descending order, so that solving one
       subgoal does not shift the positions of those still to be visited. *)
    Thm.compose_no_flatten false (res, length newgoals) i
    THEN term_tac (i + nbranches + npres)
    THEN (EVERY (map (TRY o pres_tac) ((i + nbranches + npres - 1) downto (i + nbranches))))
    THEN (EVERY (map (TRY o comp_tac) ((i + nbranches - 1) downto i)))
  end))


(* The tactic behind the induction_schema method. Completeness, wf and
   descent obligations are left untouched (K all_tac). Preservation
   obligations are tried with assume_tac, or alternatively
   Goal.assume_rule_tac. *)
fun induction_schema_tac ctxt =
  mk_ind_tac (K all_tac) (assume_tac APPEND' Goal.assume_rule_tac ctxt) (K all_tac) ctxt;

(* Registers the induction_schema proof method, which takes no arguments. *)
val setup =
  Method.setup @{binding induction_schema} (Scan.succeed (RAW_METHOD o induction_schema_tac))
    "proves an induction principle"

end