src/HOL/Tools/Function/induction_schema.ML
author blanchet
Fri Feb 21 00:09:56 2014 +0100 (2014-02-21)
changeset 55642 63beb38e9258
parent 54742 7a86358a3c0b
child 55968 94242fa87638
permissions -rw-r--r--
adapted to renaming of datatype 'cases' and 'recs' to 'case' and 'rec'
     1 (*  Title:      HOL/Tools/Function/induction_schema.ML
     2     Author:     Alexander Krauss, TU Muenchen
     3 
     4 A method to prove induction schemas.
     5 *)
     6 
     7 signature INDUCTION_SCHEMA =
     8 sig
         (* mk_ind_tac comp_tac pres_tac term_tac ctxt facts: tactic for a goal of
            the form  cases ==> branch_1 &&& ... &&& branch_n. It leaves
            completeness, pres, well-foundedness and descent subgoals, and
            applies comp_tac to the completeness subgoals, pres_tac to the pres
            subgoals and term_tac to the well-foundedness subgoal. *)
     9   val mk_ind_tac : (int -> tactic) -> (int -> tactic) -> (int -> tactic)
    10                    -> Proof.context -> thm list -> tactic
         (* mk_ind_tac with no tactic for completeness and well-foundedness
            subgoals; pres subgoals are tried by assumption. *)
    11   val induction_schema_tac : Proof.context -> thm list -> tactic
    12 end
    13 
    14 
    15 structure Induction_Schema : INDUCTION_SCHEMA =
    16 struct
    17 
    18 open Function_Lib
    19 
       (* A recursive-call hypothesis of a case, as built by mk_rcinfo in
          mk_scheme': (index of the branch whose predicate is applied, the
          parameters bound by the hypothesis, its premises, the arguments of
          the predicate in its conclusion). *)
    20 type rec_call_info = int * (string * typ) list * term list * term list
    21 
       (* One case (premise of the schema), as built by mk_case in mk_scheme':
            bidx    - index of the branch whose predicate occurs in the conclusion
            qs      - parameters bound by the premise
            oqnames - names used when re-binding qs; currently just the names of qs
            gs      - leading premises that mention no branch predicate (guards)
            lhs     - arguments of the predicate in the conclusion
            rs      - the remaining premises, read as recursive-call hypotheses *)
    22 datatype scheme_case = SchemeCase of
    23  {bidx : int,
    24   qs: (string * typ) list,
    25   oqnames: string list,
    26   gs: term list,
    27   lhs: term list,
    28   rs: rec_call_info list}
    29 
       (* One branch (conjunct) of the goal, as built by mk_branch in mk_scheme':
            P  - the predicate of the conclusion
            xs - its arguments; dest_Free is applied, so they must be free variables
            ws - parameters bound in the branch goal
            Cs - premises of the branch goal *)
    30 datatype scheme_branch = SchemeBranch of
    31  {P : term,
    32   xs: (string * typ) list,
    33   ws: (string * typ) list,
    34   Cs: term list}
    35 
       (* The complete schema. T is a balanced sum with one summand per branch,
          each summand being the product of that branch's argument types. *)
    36 datatype ind_scheme = IndScheme of
    37  {T: typ, (* sum of products *)
    38   branches: scheme_branch list,
    39   cases: scheme_case list}
    40 
       (* Conversions rewriting with the induct_atomize resp. induct_rulify rules. *)
    41 fun ind_atomize ctxt = Raw_Simplifier.rewrite ctxt true @{thms induct_atomize}
    42 fun ind_rulify ctxt = Raw_Simplifier.rewrite ctxt true @{thms induct_rulify}
    43 
       (* Turn an equation theorem into a meta-equality via eq_reflection. *)
    44 fun meta thm = thm RS eq_reflection
    45 
       (* Conversion rewriting with split_conv and the sum.case equations. *)
    46 fun sum_prod_conv ctxt = Raw_Simplifier.rewrite ctxt true
    47   (map meta (@{thm split_conv} :: @{thms sum.case}))
    48 
       (* Apply the conversion cv to t and return the right-hand side of the
          resulting meta-equation, i.e. the converted term. *)
    49 fun term_conv thy cv t =
    50   cv (cterm_of thy t)
    51   |> prop_of |> Logic.dest_equals |> snd
    52 
       (* Type of relations on T: (T * T) set. *)
    53 fun mk_relT T = HOLogic.mk_setT (HOLogic.mk_prodT (T, T))
    54 
       (* Focus on the outermost parameters of t (Variable.focus) and split the
          body into premises and conclusion. Returns (extended context, the fixed
          parameters as (name, type) pairs, premises, conclusion). *)
    55 fun dest_hhf ctxt t =
    56   let
    57     val ((params, imp), ctxt') = Variable.focus t ctxt
    58   in
    59     (ctxt', map #2 params, Logic.strip_imp_prems imp, Logic.strip_imp_concl imp)
    60   end
    61 
       (* Build an ind_scheme from the goal's premises (cases) and conclusion.
          If concl is a single proposition, only the longest prefix of cases that
          mentions its predicate is used as cases; the remaining premises become
          extra premises of that single branch. Otherwise every conjunct of
          concl is a branch and every premise is a case. *)
    62 fun mk_scheme' ctxt cases concl =
    63   let
    64     fun mk_branch concl =
    65       let
    66         val (_, ws, Cs, _ $ Pxs) = dest_hhf ctxt concl
    67         val (P, xs) = strip_comb Pxs
    68       in
    69         SchemeBranch { P=P, xs=map dest_Free xs, ws=ws, Cs=Cs }
    70       end
    71 
    72     val (branches, cases') = (* correction *)
    73       case Logic.dest_conjunctions concl of
    74         [conc] =>
    75         let
    76           val _ $ Pxs = Logic.strip_assums_concl conc
    77           val (P, _) = strip_comb Pxs
    78           val (cases', conds) =
    79             take_prefix (Term.exists_subterm (curry op aconv P)) cases
    80           val concl' = fold_rev (curry Logic.mk_implies) conds conc
    81         in
    82           ([mk_branch concl'], cases')
    83         end
    84       | concls => (map mk_branch concls, cases)
    85 
    86     fun mk_case premise =
    87       let
    88         val (ctxt', qs, prems, _ $ Plhs) = dest_hhf ctxt premise
    89         val (P, lhs) = strip_comb Plhs
    90 
               (* index of the branch whose predicate is Q *)
    91         fun bidx Q =
    92           find_index (fn SchemeBranch {P=P',...} => Q aconv P') branches
    93 
    94         fun mk_rcinfo pr =
    95           let
    96             val (_, Gvs, Gas, _ $ Phyp) = dest_hhf ctxt' pr
    97             val (P', rcs) = strip_comb Phyp
    98           in
    99             (bidx P', Gvs, Gas, rcs)
   100           end
   101 
   102         fun is_pred v = exists (fn SchemeBranch {P,...} => v aconv P) branches
   103 
               (* Guards are the leading premises that mention no branch
                  predicate; all premises after them are taken as
                  recursive-call hypotheses. *)
   104         val (gs, rcprs) =
   105           take_prefix (not o Term.exists_subterm is_pred) prems
   106       in
   107         SchemeCase {bidx=bidx P, qs=qs, oqnames=map fst qs(*FIXME*),
   108           gs=gs, lhs=lhs, rs=map mk_rcinfo rcprs}
   109       end
   110 
           (* Product type of a branch's arguments. *)
   111     fun PT_of (SchemeBranch { xs, ...}) =
   112       foldr1 HOLogic.mk_prodT (map snd xs)
   113 
   114     val ST = Balanced_Tree.make (uncurry SumTree.mk_sumT) (map PT_of branches)
   115   in
   116     IndScheme {T=ST, cases=map mk_case cases', branches=branches }
   117   end
   118 
       (* Completeness obligation for branch bidx:
            !!P xs' ws. Cs' ==> case_1 ==> ... ==> case_k ==> P
          with one premise per case of this branch, each of the form
            !!qs. gs ==> xs'_1 = lhs_1 ==> ... ==> P.
          xs' are variants of the branch variables chosen fresh with respect to
          ctxt and the case parameters, and Cs' is Cs with xs replaced by xs'. *)
   119 fun mk_completeness ctxt (IndScheme {cases, branches, ...}) bidx =
   120   let
   121     val SchemeBranch { xs, ws, Cs, ... } = nth branches bidx
   122     val relevant_cases = filter (fn SchemeCase {bidx=bidx', ...} => bidx' = bidx) cases
   123 
   124     val allqnames = fold (fn SchemeCase {qs, ...} => fold (insert (op =) o Free) qs) relevant_cases []
   125     val (Pbool :: xs') = map Free (Variable.variant_frees ctxt allqnames (("P", HOLogic.boolT) :: xs))
   126     val Cs' = map (Pattern.rewrite_term (Proof_Context.theory_of ctxt) (filter_out (op aconv) (map Free xs ~~ xs')) []) Cs
   127 
   128     fun mk_case (SchemeCase {qs, oqnames, gs, lhs, ...}) =
   129       HOLogic.mk_Trueprop Pbool
   130       |> fold_rev (fn x_l => curry Logic.mk_implies (HOLogic.mk_Trueprop(HOLogic.mk_eq x_l)))
   131            (xs' ~~ lhs)
   132       |> fold_rev (curry Logic.mk_implies) gs
   133       |> fold_rev mk_forall_rename (oqnames ~~ map Free qs)
   134   in
   135     HOLogic.mk_Trueprop Pbool
   136     |> fold_rev (curry Logic.mk_implies o mk_case) relevant_cases
   137     |> fold_rev (curry Logic.mk_implies) Cs'
   138     |> fold_rev (Logic.all o Free) ws
   139     |> fold_rev mk_forall_rename (map fst xs ~~ xs')
   140     |> mk_forall_rename ("P", Pbool)
   141   end
   142 
       (* The proposition wf R for a relation R on the schema's sum type. *)
   143 fun mk_wf R (IndScheme {T, ...}) =
   144   HOLogic.Trueprop $ (Const (@{const_name wf}, mk_relT T --> HOLogic.boolT) $ R)
   145 
       (* Proof obligations for each recursive call of each case. For a call to
          branch bidx' with arguments rcarg, inside a case of branch bidx with
          arguments lhs, both obligations are quantified over the case
          parameters qs and the call's parameters Gvs, and assume the guards gs
          and then the call's premises Gas:
            descent: (inject bidx' rcarg, inject bidx lhs) : R
            pres:    !!thesis. ((!!ws. Cs' ==> thesis) ==> thesis)
          where Cs' are the premises Cs of branch bidx' with its variables
          instantiated by rcarg. The result is a list (one per case) of lists
          (one per recursive call) of (descent, pres) pairs. *)
   146 fun mk_ineqs R (IndScheme {T, cases, branches}) =
   147   let
           (* Inject the argument tuple ts of branch i into the sum type T. *)
   148     fun inject i ts =
   149        SumTree.mk_inj T (length branches) (i + 1) (foldr1 HOLogic.mk_prod ts)
   150 
           (* NOTE(review): fixed name; Logic.all thesis below would also capture
              any free variable thesis : bool occurring in gs, Gas or Cs. *)
   151     val thesis = Free ("thesis", HOLogic.boolT) (* FIXME *)
   152 
   153     fun mk_pres bdx args =
   154       let
   155         val SchemeBranch { xs, ws, Cs, ... } = nth branches bdx
   156         fun replace (x, v) t = betapply (lambda (Free x) t, v)
   157         val Cs' = map (fold replace (xs ~~ args)) Cs
   158         val cse =
   159           HOLogic.mk_Trueprop thesis
   160           |> fold_rev (curry Logic.mk_implies) Cs'
   161           |> fold_rev (Logic.all o Free) ws
   162       in
   163         Logic.mk_implies (cse, HOLogic.mk_Trueprop thesis)
   164       end
   165 
   166     fun f (SchemeCase {bidx, qs, oqnames, gs, lhs, rs, ...}) =
   167       let
   168         fun g (bidx', Gvs, Gas, rcarg) =
   169           let val export =
   170             fold_rev (curry Logic.mk_implies) Gas
   171             #> fold_rev (curry Logic.mk_implies) gs
   172             #> fold_rev (Logic.all o Free) Gvs
   173             #> fold_rev mk_forall_rename (oqnames ~~ map Free qs)
   174           in
   175             (HOLogic.mk_mem (HOLogic.mk_prod (inject bidx' rcarg, inject bidx lhs), R)
   176              |> HOLogic.mk_Trueprop
   177              |> export,
   178              mk_pres bidx' rcarg
   179              |> export
   180              |> Logic.all thesis)
   181           end
   182       in
   183         map g rs
   184       end
   185   in
   186     map f cases
   187   end
   188 
   189 
       (* The combined induction predicate on the sum type: for each branch, the
          atomized form of !!ws. Cs ==> P xs as a tupled lambda over xs; the
          branches are combined with SumTree.mk_sumcases. *)
   190 fun mk_ind_goal ctxt branches =
   191   let
   192     val thy = Proof_Context.theory_of ctxt
   193 
   194     fun brnch (SchemeBranch { P, xs, ws, Cs, ... }) =
   195       HOLogic.mk_Trueprop (list_comb (P, map Free xs))
   196       |> fold_rev (curry Logic.mk_implies) Cs
   197       |> fold_rev (Logic.all o Free) ws
   198       |> term_conv thy (ind_atomize ctxt)
   199       |> Object_Logic.drop_judgment thy
   200       |> HOLogic.tupled_lambda (foldr1 HOLogic.mk_prod (map Free xs))
   201   in
   202     SumTree.mk_sumcases HOLogic.boolT (map brnch branches)
   203   end
   204 
       (* Derive the combined induction rule of the schema from the given
          theorems: complete_thms (one per branch, cf. mk_completeness), wf_thm
          (wf R) and ineqss (descent/pres pairs, cf. mk_ineqs); mk_ind_tac passes
          assumptions for all of them. Each case's induction step is assumed
          inside. Returns (steps, rule): the assumed step propositions in the
          original case order, and the rule obtained by composing
          wf_induct_rule with wf_thm and the proved step for the combined
          predicate. *)
   205 fun mk_induct_rule ctxt R x complete_thms wf_thm ineqss
   206   (IndScheme {T, cases=scases, branches}) =
   207   let
   208     val thy = Proof_Context.theory_of ctxt
   209     val cert = cterm_of thy
   210 
   211     val n = length branches
   212     val scases_idx = map_index I scases
   213 
   214     fun inject i ts =
   215       SumTree.mk_inj T n (i + 1) (foldr1 HOLogic.mk_prod ts)
   216     val P_of = nth (map (fn (SchemeBranch { P, ... }) => P) branches)
   217 
   218     val P_comp = mk_ind_goal ctxt branches
   219 
   220     (* Inductive Hypothesis: !!z. (z,x):R ==> P z *)
   221     val ihyp = Logic.all_const T $ Abs ("z", T,
   222       Logic.mk_implies
   223         (HOLogic.mk_Trueprop (
   224           Const (@{const_name Set.member}, HOLogic.mk_prodT (T, T) --> mk_relT T --> HOLogic.boolT) 
   225           $ (HOLogic.pair_const T T $ Bound 0 $ x)
   226           $ R),
   227          HOLogic.mk_Trueprop (P_comp $ Bound 0)))
   228       |> cert
   229 
   230     val aihyp = Thm.assume ihyp
   231 
   232     (* Rule for case splitting along the sum types *)
   233     val xss = map (fn (SchemeBranch { xs, ... }) => map Free xs) branches
   234     val pats = map_index (uncurry inject) xss
   235     val sum_split_rule =
   236       Pat_Completeness.prove_completeness ctxt [x] (P_comp $ x) xss (map single pats)
   237 
           (* For branch bidx, prove P_comp x under the assumption x = pat from the
              branch's completeness theorem and one derived step per case of this
              branch. Returns that theorem (with the assumption discharged and the
              branch variables generalized) and the assumed steps paired with
              their case indices. *)
   238     fun prove_branch (bidx, (SchemeBranch { P, xs, ws, Cs, ... }, (complete_thm, pat))) =
   239       let
   240         val fxs = map Free xs
   241         val branch_hyp = Thm.assume (cert (HOLogic.mk_Trueprop (HOLogic.mk_eq (x, pat))))
   242 
   243         val C_hyps = map (cert #> Thm.assume) Cs
   244 
   245         val (relevant_cases, ineqss') =
   246           (scases_idx ~~ ineqss)
   247           |> filter (fn ((_, SchemeCase {bidx=bidx', ...}), _) => bidx' = bidx)
   248           |> split_list
   249 
               (* For one case: derive P for every recursive call from the inductive
                  hypothesis (mk_Prec), assume the step
                    !!qs. gs ==> P rec_1 ==> ... ==> P lhs
                  and turn it into !!qs. gs ==> xs = lhs ==> P xs. *)
   250         fun prove_case (cidx, SchemeCase {qs, gs, lhs, rs, ...}) ineq_press =
   251           let
   252             val case_hyps =
   253               map (Thm.assume o cert o HOLogic.mk_Trueprop o HOLogic.mk_eq) (fxs ~~ lhs)
   254 
   255             val cqs = map (cert o Free) qs
   256             val ags = map (Thm.assume o cert) gs
   257 
                   (* Rewrite the inductive hypothesis with x = pat and xs = lhs. *)
   258             val replace_x_simpset =
   259               put_simpset HOL_basic_ss ctxt addsimps (branch_hyp :: case_hyps)
   260             val sih = full_simplify replace_x_simpset aihyp
   261 
                   (* Instantiate the simplified inductive hypothesis at the call's
                      argument, discharge its R-membership premise with the descent
                      obligation, and combine it (COMP) with the pres obligation to
                      obtain P for the call's arguments; the call's premises and
                      parameters are then abstracted again. *)
   262             fun mk_Prec (idx, Gvs, Gas, rcargs) (ineq, pres) =
   263               let
   264                 val cGas = map (Thm.assume o cert) Gas
   265                 val cGvs = map (cert o Free) Gvs
   266                 val import = fold Thm.forall_elim (cqs @ cGvs)
   267                   #> fold Thm.elim_implies (ags @ cGas)
   268                 val ipres = pres
   269                   |> Thm.forall_elim (cert (list_comb (P_of idx, rcargs)))
   270                   |> import
   271               in
   272                 sih
   273                 |> Thm.forall_elim (cert (inject idx rcargs))
   274                 |> Thm.elim_implies (import ineq) (* Psum rcargs *)
   275                 |> Conv.fconv_rule (sum_prod_conv ctxt)
   276                 |> Conv.fconv_rule (ind_rulify ctxt)
   277                 |> (fn th => th COMP ipres) (* P rs *)
   278                 |> fold_rev (Thm.implies_intr o cprop_of) cGas
   279                 |> fold_rev Thm.forall_intr cGvs
   280               end
   281 
   282             val P_recs = map2 mk_Prec rs ineq_press   (*  [P rec1, P rec2, ... ]  *)
   283 
   284             val step = HOLogic.mk_Trueprop (list_comb (P, lhs))
   285               |> fold_rev (curry Logic.mk_implies o prop_of) P_recs
   286               |> fold_rev (curry Logic.mk_implies) gs
   287               |> fold_rev (Logic.all o Free) qs
   288               |> cert
   289 
                   (* Conversion rewriting each argument lhs_i back to x_i, using the
                      case equations xs = lhs in reverse. *)
   290             val Plhs_to_Pxs_conv =
   291               foldl1 (uncurry Conv.combination_conv)
   292                 (Conv.all_conv :: map (fn ch => K (Thm.symmetric (ch RS eq_reflection))) case_hyps)
   293 
   294             val res = Thm.assume step
   295               |> fold Thm.forall_elim cqs
   296               |> fold Thm.elim_implies ags
   297               |> fold Thm.elim_implies P_recs (* P lhs *)
   298               |> Conv.fconv_rule (Conv.arg_conv Plhs_to_Pxs_conv) (* P xs *)
   299               |> fold_rev (Thm.implies_intr o cprop_of) (ags @ case_hyps)
   300               |> fold_rev Thm.forall_intr cqs (* !!qs. Gas ==> xs = lhss ==> P xs *)
   301           in
   302             (res, (cidx, step))
   303           end
   304 
   305         val (cases, steps) = split_list (map2 prove_case relevant_cases ineqss')
   306 
               (* Instantiate the completeness theorem with P xs for its predicate
                  variable, then with xs and ws, and discharge its premises with Cs
                  and the per-case results; yields !!ws. Cs ==> P xs. *)
   307         val bstep = complete_thm
   308           |> Thm.forall_elim (cert (list_comb (P, fxs)))
   309           |> fold (Thm.forall_elim o cert) (fxs @ map Free ws)
   310           |> fold Thm.elim_implies C_hyps
   311           |> fold Thm.elim_implies cases (* P xs *)
   312           |> fold_rev (Thm.implies_intr o cprop_of) C_hyps
   313           |> fold_rev (Thm.forall_intr o cert o Free) ws
   314 
               (* Goal P_comp x: rewrite with x = pat, split_conv and sum.case,
                  rulify, and close it with bstep. *)
   315         val Pxs = cert (HOLogic.mk_Trueprop (P_comp $ x))
   316           |> Goal.init
   317           |> (Simplifier.rewrite_goals_tac ctxt
   318                 (map meta (branch_hyp :: @{thm split_conv} :: @{thms sum.case}))
   319               THEN CONVERSION (ind_rulify ctxt) 1)
   320           |> Seq.hd
   321           |> Thm.elim_implies (Conv.fconv_rule Drule.beta_eta_conversion bstep)
   322           |> Goal.finish ctxt
   323           |> Thm.implies_intr (cprop_of branch_hyp)
   324           |> fold_rev (Thm.forall_intr o cert) fxs
   325       in
   326         (Pxs, steps)
   327       end
   328 
           (* Note: this shadows branches with the proved branch theorems. *)
   329     val (branches, steps) =
   330       map_index prove_branch (branches ~~ (complete_thms ~~ pats))
   331       |> split_list |> apsnd flat
   332 
   333     val istep = sum_split_rule
   334       |> fold (fn b => fn th => Drule.compose (b, 1, th)) branches
   335       |> Thm.implies_intr ihyp
   336       |> Thm.forall_intr (cert x) (* "!!x. (!!y<x. P y) ==> P x" *)
   337 
           (* wf_induct_rule with its wf premise and its step premise discharged *)
   338     val induct_rule =
   339       @{thm "wf_induct_rule"}
   340       |> (curry op COMP) wf_thm
   341       |> (curry op COMP) istep
   342 
           (* restore the original case order using the recorded case indices *)
   343     val steps_sorted = map snd (sort (int_ord o pairself fst) steps)
   344   in
   345     (steps_sorted, induct_rule)
   346   end
   347 
   348 
       (* Tactic for induction schemas. After inserting facts into all goals, it
          works on the first goal, read as  cases ==> branch_1 &&& ... &&& branch_n.
          The goal is replaced by the obligations, in this order:
            completeness goals (one per branch, mk_completeness),
            pres goals (mk_ineqs),
            the wf goal (mk_wf),
            descent goals (mk_ineqs).
          term_tac is applied to the wf goal, then pres_tac and comp_tac (each
          under TRY) to the pres and completeness goals, last goal first. *)
   349 fun mk_ind_tac comp_tac pres_tac term_tac ctxt facts =
   350   (* FIXME proper use of facts!? *)
   351   (ALLGOALS (Method.insert_tac facts)) THEN HEADGOAL (SUBGOAL (fn (t, i) =>
   352   let
   353     val (ctxt', _, cases, concl) = dest_hhf ctxt t
   354     val scheme as IndScheme {T=ST, branches, ...} = mk_scheme' ctxt' cases concl
   355     val ([Rn,xn], ctxt'') = Variable.variant_fixes ["R","x"] ctxt'
   356     val R = Free (Rn, mk_relT ST)
   357     val x = Free (xn, ST)
   358     val cert = cterm_of (Proof_Context.theory_of ctxt)
   359 
   360     val ineqss = mk_ineqs R scheme
   361       |> map (map (pairself (Thm.assume o cert)))
   362     val complete =
   363       map_range (mk_completeness ctxt scheme #> cert #> Thm.assume) (length branches)
   364     val wf_thm = mk_wf R scheme |> cert |> Thm.assume
   365 
           (* The order of newgoals fixes the subgoal positions used below. *)
   366     val (descent, pres) = split_list (flat ineqss)
   367     val newgoals = complete @ pres @ wf_thm :: descent
   368 
   369     val (steps, indthm) =
   370       mk_induct_rule ctxt'' R x complete wf_thm ineqss scheme
   371 
           (* Instantiate the rule at the injection of branch i's arguments and
              simplify the sum-case away, giving the statement for branch i. *)
   372     fun project (i, SchemeBranch {xs, ...}) =
   373       let
   374         val inst = (foldr1 HOLogic.mk_prod (map Free xs))
   375           |> SumTree.mk_inj ST (length branches) (i + 1)
   376           |> cert
   377       in
   378         indthm
   379         |> Drule.instantiate' [] [SOME inst]
   380         |> simplify (put_simpset SumTree.sumcase_split_ss ctxt'')
   381         |> Conv.fconv_rule (ind_rulify ctxt'')
   382       end
   383 
           (* Discharge all assumed obligations and steps as premises of the
              combined theorem; R is generalized by Drule.generalize.
              NOTE(review): composing with length newgoals presumably turns only
              the obligations into new subgoals, while the step premises are
              matched against the goal's own case premises - confirm against
              Thm.bicompose. *)
   384     val res = Conjunction.intr_balanced (map_index project branches)
   385       |> fold_rev Thm.implies_intr (map cprop_of newgoals @ steps)
   386       |> Drule.generalize ([], [Rn])
   387 
   388     val nbranches = length branches
   389     val npres = length pres
   390   in
           (* Goals are processed from the highest index downwards so that solving
              one does not shift the positions of those still to be processed. *)
   391     Thm.bicompose {flatten = false, match = false, incremented = false}
   392       (false, res, length newgoals) i
   393     THEN term_tac (i + nbranches + npres)
   394     THEN (EVERY (map (TRY o pres_tac) ((i + nbranches + npres - 1) downto (i + nbranches))))
   395     THEN (EVERY (map (TRY o comp_tac) ((i + nbranches - 1) downto i)))
   396   end))
   397 
   398 
       (* Default instance: no tactic on completeness and wf goals; pres goals are
          tried with assume_tac or Goal.assume_rule_tac. *)
   399 fun induction_schema_tac ctxt =
   400   mk_ind_tac (K all_tac) (assume_tac APPEND' Goal.assume_rule_tac ctxt) (K all_tac) ctxt;
   401 
   402 end