src/HOL/Tools/Function/induction_schema.ML
author wenzelm
Sun Mar 07 12:19:47 2010 +0100 (2010-03-07)
changeset 35625 9c818cab0dd0
parent 34232 36a2a3029fd3
child 36945 9bec62c10714
permissions -rw-r--r--
modernized structure Object_Logic;
     1 (*  Title:      HOL/Tools/Function/induction_schema.ML
     2     Author:     Alexander Krauss, TU Muenchen
     3 
     4 A method to prove induction schemas.
     5 *)
     6 
     7 signature INDUCTION_SCHEMA =
     8 sig
     9   val mk_ind_tac : (int -> tactic) -> (int -> tactic) -> (int -> tactic)
    10                    -> Proof.context -> thm list -> tactic
    11   val induction_schema_tac : Proof.context -> thm list -> tactic
    12   val setup : theory -> theory
    13 end
    14 
    15 
    16 structure Induction_Schema : INDUCTION_SCHEMA =
    17 struct
    18 
    19 open Function_Lib
    20 
    21 type rec_call_info = int * (string * typ) list * term list * term list
    22 
    23 datatype scheme_case = SchemeCase of
    24  {bidx : int,
    25   qs: (string * typ) list,
    26   oqnames: string list,
    27   gs: term list,
    28   lhs: term list,
    29   rs: rec_call_info list}
    30 
    31 datatype scheme_branch = SchemeBranch of
    32  {P : term,
    33   xs: (string * typ) list,
    34   ws: (string * typ) list,
    35   Cs: term list}
    36 
    37 datatype ind_scheme = IndScheme of
    38  {T: typ, (* sum of products *)
    39   branches: scheme_branch list,
    40   cases: scheme_case list}
    41 
    42 val ind_atomize = MetaSimplifier.rewrite true @{thms induct_atomize}
    43 val ind_rulify = MetaSimplifier.rewrite true @{thms induct_rulify}
    44 
    45 fun meta thm = thm RS eq_reflection
    46 
    47 val sum_prod_conv = MetaSimplifier.rewrite true
    48   (map meta (@{thm split_conv} :: @{thms sum.cases}))
    49 
    50 fun term_conv thy cv t =
    51   cv (cterm_of thy t)
    52   |> prop_of |> Logic.dest_equals |> snd
    53 
    54 fun mk_relT T = HOLogic.mk_setT (HOLogic.mk_prodT (T, T))
    55 
    56 fun dest_hhf ctxt t =
    57   let
    58     val (ctxt', vars, imp) = dest_all_all_ctx ctxt t
    59   in
    60     (ctxt', vars, Logic.strip_imp_prems imp, Logic.strip_imp_concl imp)
    61   end
    62 
    63 fun mk_scheme' ctxt cases concl =
    64   let
    65     fun mk_branch concl =
    66       let
    67         val (_, ws, Cs, _ $ Pxs) = dest_hhf ctxt concl
    68         val (P, xs) = strip_comb Pxs
    69       in
    70         SchemeBranch { P=P, xs=map dest_Free xs, ws=ws, Cs=Cs }
    71       end
    72 
    73     val (branches, cases') = (* correction *)
    74       case Logic.dest_conjunction_list concl of
    75         [conc] =>
    76         let
    77           val _ $ Pxs = Logic.strip_assums_concl conc
    78           val (P, _) = strip_comb Pxs
    79           val (cases', conds) =
    80             take_prefix (Term.exists_subterm (curry op aconv P)) cases
    81           val concl' = fold_rev (curry Logic.mk_implies) conds conc
    82         in
    83           ([mk_branch concl'], cases')
    84         end
    85       | concls => (map mk_branch concls, cases)
    86 
    87     fun mk_case premise =
    88       let
    89         val (ctxt', qs, prems, _ $ Plhs) = dest_hhf ctxt premise
    90         val (P, lhs) = strip_comb Plhs
    91 
    92         fun bidx Q =
    93           find_index (fn SchemeBranch {P=P',...} => Q aconv P') branches
    94 
    95         fun mk_rcinfo pr =
    96           let
    97             val (_, Gvs, Gas, _ $ Phyp) = dest_hhf ctxt' pr
    98             val (P', rcs) = strip_comb Phyp
    99           in
   100             (bidx P', Gvs, Gas, rcs)
   101           end
   102 
   103         fun is_pred v = exists (fn SchemeBranch {P,...} => v aconv P) branches
   104 
   105         val (gs, rcprs) =
   106           take_prefix (not o Term.exists_subterm is_pred) prems
   107       in
   108         SchemeCase {bidx=bidx P, qs=qs, oqnames=map fst qs(*FIXME*),
   109           gs=gs, lhs=lhs, rs=map mk_rcinfo rcprs}
   110       end
   111 
   112     fun PT_of (SchemeBranch { xs, ...}) =
   113       foldr1 HOLogic.mk_prodT (map snd xs)
   114 
   115     val ST = Balanced_Tree.make (uncurry SumTree.mk_sumT) (map PT_of branches)
   116   in
   117     IndScheme {T=ST, cases=map mk_case cases', branches=branches }
   118   end
   119 
   120 fun mk_completeness ctxt (IndScheme {cases, branches, ...}) bidx =
   121   let
   122     val SchemeBranch { xs, ws, Cs, ... } = nth branches bidx
   123     val relevant_cases = filter (fn SchemeCase {bidx=bidx', ...} => bidx' = bidx) cases
   124 
   125     val allqnames = fold (fn SchemeCase {qs, ...} => fold (insert (op =) o Free) qs) relevant_cases []
   126     val (Pbool :: xs') = map Free (Variable.variant_frees ctxt allqnames (("P", HOLogic.boolT) :: xs))
   127     val Cs' = map (Pattern.rewrite_term (ProofContext.theory_of ctxt) (filter_out (op aconv) (map Free xs ~~ xs')) []) Cs
   128 
   129     fun mk_case (SchemeCase {qs, oqnames, gs, lhs, ...}) =
   130       HOLogic.mk_Trueprop Pbool
   131       |> fold_rev (fn x_l => curry Logic.mk_implies (HOLogic.mk_Trueprop(HOLogic.mk_eq x_l)))
   132            (xs' ~~ lhs)
   133       |> fold_rev (curry Logic.mk_implies) gs
   134       |> fold_rev mk_forall_rename (oqnames ~~ map Free qs)
   135   in
   136     HOLogic.mk_Trueprop Pbool
   137     |> fold_rev (curry Logic.mk_implies o mk_case) relevant_cases
   138     |> fold_rev (curry Logic.mk_implies) Cs'
   139     |> fold_rev (Logic.all o Free) ws
   140     |> fold_rev mk_forall_rename (map fst xs ~~ xs')
   141     |> mk_forall_rename ("P", Pbool)
   142   end
   143 
   144 fun mk_wf R (IndScheme {T, ...}) =
   145   HOLogic.Trueprop $ (Const (@{const_name wf}, mk_relT T --> HOLogic.boolT) $ R)
   146 
   147 fun mk_ineqs R (IndScheme {T, cases, branches}) =
   148   let
   149     fun inject i ts =
   150        SumTree.mk_inj T (length branches) (i + 1) (foldr1 HOLogic.mk_prod ts)
   151 
   152     val thesis = Free ("thesis", HOLogic.boolT) (* FIXME *)
   153 
   154     fun mk_pres bdx args =
   155       let
   156         val SchemeBranch { xs, ws, Cs, ... } = nth branches bdx
   157         fun replace (x, v) t = betapply (lambda (Free x) t, v)
   158         val Cs' = map (fold replace (xs ~~ args)) Cs
   159         val cse =
   160           HOLogic.mk_Trueprop thesis
   161           |> fold_rev (curry Logic.mk_implies) Cs'
   162           |> fold_rev (Logic.all o Free) ws
   163       in
   164         Logic.mk_implies (cse, HOLogic.mk_Trueprop thesis)
   165       end
   166 
   167     fun f (SchemeCase {bidx, qs, oqnames, gs, lhs, rs, ...}) =
   168       let
   169         fun g (bidx', Gvs, Gas, rcarg) =
   170           let val export =
   171             fold_rev (curry Logic.mk_implies) Gas
   172             #> fold_rev (curry Logic.mk_implies) gs
   173             #> fold_rev (Logic.all o Free) Gvs
   174             #> fold_rev mk_forall_rename (oqnames ~~ map Free qs)
   175           in
   176             (HOLogic.mk_mem (HOLogic.mk_prod (inject bidx' rcarg, inject bidx lhs), R)
   177              |> HOLogic.mk_Trueprop
   178              |> export,
   179              mk_pres bidx' rcarg
   180              |> export
   181              |> Logic.all thesis)
   182           end
   183       in
   184         map g rs
   185       end
   186   in
   187     map f cases
   188   end
   189 
   190 
   191 fun mk_ind_goal thy branches =
   192   let
   193     fun brnch (SchemeBranch { P, xs, ws, Cs, ... }) =
   194       HOLogic.mk_Trueprop (list_comb (P, map Free xs))
   195       |> fold_rev (curry Logic.mk_implies) Cs
   196       |> fold_rev (Logic.all o Free) ws
   197       |> term_conv thy ind_atomize
   198       |> Object_Logic.drop_judgment thy
   199       |> tupled_lambda (foldr1 HOLogic.mk_prod (map Free xs))
   200   in
   201     SumTree.mk_sumcases HOLogic.boolT (map brnch branches)
   202   end
   203 
   204 fun mk_induct_rule ctxt R x complete_thms wf_thm ineqss
   205   (IndScheme {T, cases=scases, branches}) =
   206   let
   207     val n = length branches
   208     val scases_idx = map_index I scases
   209 
   210     fun inject i ts =
   211       SumTree.mk_inj T n (i + 1) (foldr1 HOLogic.mk_prod ts)
   212     val P_of = nth (map (fn (SchemeBranch { P, ... }) => P) branches)
   213 
   214     val thy = ProofContext.theory_of ctxt
   215     val cert = cterm_of thy
   216 
   217     val P_comp = mk_ind_goal thy branches
   218 
   219     (* Inductive Hypothesis: !!z. (z,x):R ==> P z *)
   220     val ihyp = Term.all T $ Abs ("z", T,
   221       Logic.mk_implies
   222         (HOLogic.mk_Trueprop (
   223           Const ("op :", HOLogic.mk_prodT (T, T) --> mk_relT T --> HOLogic.boolT) 
   224           $ (HOLogic.pair_const T T $ Bound 0 $ x)
   225           $ R),
   226          HOLogic.mk_Trueprop (P_comp $ Bound 0)))
   227       |> cert
   228 
   229     val aihyp = assume ihyp
   230 
   231     (* Rule for case splitting along the sum types *)
   232     val xss = map (fn (SchemeBranch { xs, ... }) => map Free xs) branches
   233     val pats = map_index (uncurry inject) xss
   234     val sum_split_rule =
   235       Pat_Completeness.prove_completeness thy [x] (P_comp $ x) xss (map single pats)
   236 
   237     fun prove_branch (bidx, (SchemeBranch { P, xs, ws, Cs, ... }, (complete_thm, pat))) =
   238       let
   239         val fxs = map Free xs
   240         val branch_hyp = assume (cert (HOLogic.mk_Trueprop (HOLogic.mk_eq (x, pat))))
   241 
   242         val C_hyps = map (cert #> assume) Cs
   243 
   244         val (relevant_cases, ineqss') =
   245           (scases_idx ~~ ineqss)
   246           |> filter (fn ((_, SchemeCase {bidx=bidx', ...}), _) => bidx' = bidx)
   247           |> split_list
   248 
   249         fun prove_case (cidx, SchemeCase {qs, gs, lhs, rs, ...}) ineq_press =
   250           let
   251             val case_hyps = map (assume o cert o HOLogic.mk_Trueprop o HOLogic.mk_eq) (fxs ~~ lhs)
   252 
   253             val cqs = map (cert o Free) qs
   254             val ags = map (assume o cert) gs
   255 
   256             val replace_x_ss = HOL_basic_ss addsimps (branch_hyp :: case_hyps)
   257             val sih = full_simplify replace_x_ss aihyp
   258 
   259             fun mk_Prec (idx, Gvs, Gas, rcargs) (ineq, pres) =
   260               let
   261                 val cGas = map (assume o cert) Gas
   262                 val cGvs = map (cert o Free) Gvs
   263                 val import = fold forall_elim (cqs @ cGvs)
   264                   #> fold Thm.elim_implies (ags @ cGas)
   265                 val ipres = pres
   266                   |> forall_elim (cert (list_comb (P_of idx, rcargs)))
   267                   |> import
   268               in
   269                 sih
   270                 |> forall_elim (cert (inject idx rcargs))
   271                 |> Thm.elim_implies (import ineq) (* Psum rcargs *)
   272                 |> Conv.fconv_rule sum_prod_conv
   273                 |> Conv.fconv_rule ind_rulify
   274                 |> (fn th => th COMP ipres) (* P rs *)
   275                 |> fold_rev (implies_intr o cprop_of) cGas
   276                 |> fold_rev forall_intr cGvs
   277               end
   278 
   279             val P_recs = map2 mk_Prec rs ineq_press   (*  [P rec1, P rec2, ... ]  *)
   280 
   281             val step = HOLogic.mk_Trueprop (list_comb (P, lhs))
   282               |> fold_rev (curry Logic.mk_implies o prop_of) P_recs
   283               |> fold_rev (curry Logic.mk_implies) gs
   284               |> fold_rev (Logic.all o Free) qs
   285               |> cert
   286 
   287             val Plhs_to_Pxs_conv =
   288               foldl1 (uncurry Conv.combination_conv)
   289                 (Conv.all_conv :: map (fn ch => K (Thm.symmetric (ch RS eq_reflection))) case_hyps)
   290 
   291             val res = assume step
   292               |> fold forall_elim cqs
   293               |> fold Thm.elim_implies ags
   294               |> fold Thm.elim_implies P_recs (* P lhs *)
   295               |> Conv.fconv_rule (Conv.arg_conv Plhs_to_Pxs_conv) (* P xs *)
   296               |> fold_rev (implies_intr o cprop_of) (ags @ case_hyps)
   297               |> fold_rev forall_intr cqs (* !!qs. Gas ==> xs = lhss ==> P xs *)
   298           in
   299             (res, (cidx, step))
   300           end
   301 
   302         val (cases, steps) = split_list (map2 prove_case relevant_cases ineqss')
   303 
   304         val bstep = complete_thm
   305           |> forall_elim (cert (list_comb (P, fxs)))
   306           |> fold (forall_elim o cert) (fxs @ map Free ws)
   307           |> fold Thm.elim_implies C_hyps
   308           |> fold Thm.elim_implies cases (* P xs *)
   309           |> fold_rev (implies_intr o cprop_of) C_hyps
   310           |> fold_rev (forall_intr o cert o Free) ws
   311 
   312         val Pxs = cert (HOLogic.mk_Trueprop (P_comp $ x))
   313           |> Goal.init
   314           |> (MetaSimplifier.rewrite_goals_tac (map meta (branch_hyp :: @{thm split_conv} :: @{thms sum.cases}))
   315               THEN CONVERSION ind_rulify 1)
   316           |> Seq.hd
   317           |> Thm.elim_implies (Conv.fconv_rule Drule.beta_eta_conversion bstep)
   318           |> Goal.finish ctxt
   319           |> implies_intr (cprop_of branch_hyp)
   320           |> fold_rev (forall_intr o cert) fxs
   321       in
   322         (Pxs, steps)
   323       end
   324 
   325     val (branches, steps) =
   326       map_index prove_branch (branches ~~ (complete_thms ~~ pats))
   327       |> split_list |> apsnd flat
   328 
   329     val istep = sum_split_rule
   330       |> fold (fn b => fn th => Drule.compose_single (b, 1, th)) branches
   331       |> implies_intr ihyp
   332       |> forall_intr (cert x) (* "!!x. (!!y<x. P y) ==> P x" *)
   333 
   334     val induct_rule =
   335       @{thm "wf_induct_rule"}
   336       |> (curry op COMP) wf_thm
   337       |> (curry op COMP) istep
   338 
   339     val steps_sorted = map snd (sort (int_ord o pairself fst) steps)
   340   in
   341     (steps_sorted, induct_rule)
   342   end
   343 
   344 
   345 fun mk_ind_tac comp_tac pres_tac term_tac ctxt facts =
   346   (ALLGOALS (Method.insert_tac facts)) THEN HEADGOAL (SUBGOAL (fn (t, i) =>
   347   let
   348     val (ctxt', _, cases, concl) = dest_hhf ctxt t
   349     val scheme as IndScheme {T=ST, branches, ...} = mk_scheme' ctxt' cases concl
   350     val ([Rn,xn], ctxt'') = Variable.variant_fixes ["R","x"] ctxt'
   351     val R = Free (Rn, mk_relT ST)
   352     val x = Free (xn, ST)
   353     val cert = cterm_of (ProofContext.theory_of ctxt)
   354 
   355     val ineqss = mk_ineqs R scheme
   356       |> map (map (pairself (assume o cert)))
   357     val complete =
   358       map_range (mk_completeness ctxt scheme #> cert #> assume) (length branches)
   359     val wf_thm = mk_wf R scheme |> cert |> assume
   360 
   361     val (descent, pres) = split_list (flat ineqss)
   362     val newgoals = complete @ pres @ wf_thm :: descent
   363 
   364     val (steps, indthm) =
   365       mk_induct_rule ctxt'' R x complete wf_thm ineqss scheme
   366 
   367     fun project (i, SchemeBranch {xs, ...}) =
   368       let
   369         val inst = (foldr1 HOLogic.mk_prod (map Free xs))
   370           |> SumTree.mk_inj ST (length branches) (i + 1)
   371           |> cert
   372       in
   373         indthm
   374         |> Drule.instantiate' [] [SOME inst]
   375         |> simplify SumTree.sumcase_split_ss
   376         |> Conv.fconv_rule ind_rulify
   377       end
   378 
   379     val res = Conjunction.intr_balanced (map_index project branches)
   380       |> fold_rev implies_intr (map cprop_of newgoals @ steps)
   381       |> Drule.generalize ([], [Rn])
   382 
   383     val nbranches = length branches
   384     val npres = length pres
   385   in
   386     Thm.compose_no_flatten false (res, length newgoals) i
   387     THEN term_tac (i + nbranches + npres)
   388     THEN (EVERY (map (TRY o pres_tac) ((i + nbranches + npres - 1) downto (i + nbranches))))
   389     THEN (EVERY (map (TRY o comp_tac) ((i + nbranches - 1) downto i)))
   390   end))
   391 
   392 
   393 fun induction_schema_tac ctxt =
   394   mk_ind_tac (K all_tac) (assume_tac APPEND' Goal.assume_rule_tac ctxt) (K all_tac) ctxt;
   395 
   396 val setup =
   397   Method.setup @{binding induction_schema} (Scan.succeed (RAW_METHOD o induction_schema_tac))
   398     "proves an induction principle"
   399 
   400 end