src/HOL/Tools/Function/induction_schema.ML
changeset 33471 5aef13872723
child 33697 7d6793ce0a26
     1.1 --- /dev/null	Thu Jan 01 00:00:00 1970 +0000
     1.2 +++ b/src/HOL/Tools/Function/induction_schema.ML	Fri Nov 06 14:42:42 2009 +0100
     1.3 @@ -0,0 +1,405 @@
     1.4 +(*  Title:      HOL/Tools/Function/induction_schema.ML
     1.5 +    Author:     Alexander Krauss, TU Muenchen
     1.6 +
     1.7 +A method to prove induction schemas.
     1.8 +*)
     1.9 +
    1.10 +signature INDUCTION_SCHEMA =
    1.11 +sig
    1.12 +  val mk_ind_tac : (int -> tactic) -> (int -> tactic) -> (int -> tactic)
    1.13 +                   -> Proof.context -> thm list -> tactic
    1.14 +  val induction_schema_tac : Proof.context -> thm list -> tactic
    1.15 +  val setup : theory -> theory
    1.16 +end
    1.17 +
    1.18 +
    1.19 +structure Induction_Schema : INDUCTION_SCHEMA =
    1.20 +struct
    1.21 +
    1.22 +open Function_Lib
    1.23 +
    1.24 +
    1.25 +type rec_call_info = int * (string * typ) list * term list * term list
    1.26 +
    1.27 +datatype scheme_case =
    1.28 +  SchemeCase of
    1.29 +  {
    1.30 +   bidx : int,
    1.31 +   qs: (string * typ) list,
    1.32 +   oqnames: string list,
    1.33 +   gs: term list,
    1.34 +   lhs: term list,
    1.35 +   rs: rec_call_info list
    1.36 +  }
    1.37 +
    1.38 +datatype scheme_branch = 
    1.39 +  SchemeBranch of
    1.40 +  {
    1.41 +   P : term,
    1.42 +   xs: (string * typ) list,
    1.43 +   ws: (string * typ) list,
    1.44 +   Cs: term list
    1.45 +  }
    1.46 +
    1.47 +datatype ind_scheme =
    1.48 +  IndScheme of
    1.49 +  {
    1.50 +   T: typ, (* sum of products *)
    1.51 +   branches: scheme_branch list,
    1.52 +   cases: scheme_case list
    1.53 +  }
    1.54 +
    1.55 +val ind_atomize = MetaSimplifier.rewrite true @{thms induct_atomize}
    1.56 +val ind_rulify = MetaSimplifier.rewrite true @{thms induct_rulify}
    1.57 +
    1.58 +fun meta thm = thm RS eq_reflection
    1.59 +
    1.60 +val sum_prod_conv = MetaSimplifier.rewrite true 
    1.61 +                    (map meta (@{thm split_conv} :: @{thms sum.cases}))
    1.62 +
    1.63 +fun term_conv thy cv t = 
    1.64 +    cv (cterm_of thy t)
    1.65 +    |> prop_of |> Logic.dest_equals |> snd
    1.66 +
    1.67 +fun mk_relT T = HOLogic.mk_setT (HOLogic.mk_prodT (T, T))
    1.68 +
    1.69 +fun dest_hhf ctxt t = 
    1.70 +    let 
    1.71 +      val (ctxt', vars, imp) = dest_all_all_ctx ctxt t
    1.72 +    in
    1.73 +      (ctxt', vars, Logic.strip_imp_prems imp, Logic.strip_imp_concl imp)
    1.74 +    end
    1.75 +
    1.76 +
    1.77 +fun mk_scheme' ctxt cases concl =
    1.78 +    let
    1.79 +      fun mk_branch concl =
    1.80 +          let
    1.81 +            val (ctxt', ws, Cs, _ $ Pxs) = dest_hhf ctxt concl
    1.82 +            val (P, xs) = strip_comb Pxs
    1.83 +          in
    1.84 +            SchemeBranch { P=P, xs=map dest_Free xs, ws=ws, Cs=Cs }
    1.85 +          end
    1.86 +
    1.87 +      val (branches, cases') = (* correction *)
    1.88 +          case Logic.dest_conjunction_list concl of
    1.89 +            [conc] => 
    1.90 +            let 
    1.91 +              val _ $ Pxs = Logic.strip_assums_concl conc
    1.92 +              val (P, _) = strip_comb Pxs
    1.93 +              val (cases', conds) = take_prefix (Term.exists_subterm (curry op aconv P)) cases
    1.94 +              val concl' = fold_rev (curry Logic.mk_implies) conds conc
    1.95 +            in
    1.96 +              ([mk_branch concl'], cases')
    1.97 +            end
    1.98 +          | concls => (map mk_branch concls, cases)
    1.99 +
   1.100 +      fun mk_case premise =
   1.101 +          let
   1.102 +            val (ctxt', qs, prems, _ $ Plhs) = dest_hhf ctxt premise
   1.103 +            val (P, lhs) = strip_comb Plhs
   1.104 +                                
   1.105 +            fun bidx Q = find_index (fn SchemeBranch {P=P',...} => Q aconv P') branches
   1.106 +
   1.107 +            fun mk_rcinfo pr =
   1.108 +                let
   1.109 +                  val (ctxt'', Gvs, Gas, _ $ Phyp) = dest_hhf ctxt' pr
   1.110 +                  val (P', rcs) = strip_comb Phyp
   1.111 +                in
   1.112 +                  (bidx P', Gvs, Gas, rcs)
   1.113 +                end
   1.114 +                
   1.115 +            fun is_pred v = exists (fn SchemeBranch {P,...} => v aconv P) branches
   1.116 +
   1.117 +            val (gs, rcprs) = 
   1.118 +                take_prefix (not o Term.exists_subterm is_pred) prems
   1.119 +          in
   1.120 +            SchemeCase {bidx=bidx P, qs=qs, oqnames=map fst qs(*FIXME*), gs=gs, lhs=lhs, rs=map mk_rcinfo rcprs}
   1.121 +          end
   1.122 +
   1.123 +      fun PT_of (SchemeBranch { xs, ...}) =
   1.124 +            foldr1 HOLogic.mk_prodT (map snd xs)
   1.125 +
   1.126 +      val ST = Balanced_Tree.make (uncurry SumTree.mk_sumT) (map PT_of branches)
   1.127 +    in
   1.128 +      IndScheme {T=ST, cases=map mk_case cases', branches=branches }
   1.129 +    end
   1.130 +
   1.131 +
   1.132 +
   1.133 +fun mk_completeness ctxt (IndScheme {cases, branches, ...}) bidx =
   1.134 +    let
   1.135 +      val SchemeBranch { xs, ws, Cs, ... } = nth branches bidx
   1.136 +      val relevant_cases = filter (fn SchemeCase {bidx=bidx', ...} => bidx' = bidx) cases
   1.137 +
   1.138 +      val allqnames = fold (fn SchemeCase {qs, ...} => fold (insert (op =) o Free) qs) relevant_cases []
   1.139 +      val (Pbool :: xs') = map Free (Variable.variant_frees ctxt allqnames (("P", HOLogic.boolT) :: xs))
   1.140 +      val Cs' = map (Pattern.rewrite_term (ProofContext.theory_of ctxt) (filter_out (op aconv) (map Free xs ~~ xs')) []) Cs
   1.141 +                       
   1.142 +      fun mk_case (SchemeCase {qs, oqnames, gs, lhs, ...}) =
   1.143 +          HOLogic.mk_Trueprop Pbool
   1.144 +                     |> fold_rev (fn x_l => curry Logic.mk_implies (HOLogic.mk_Trueprop(HOLogic.mk_eq x_l)))
   1.145 +                                 (xs' ~~ lhs)
   1.146 +                     |> fold_rev (curry Logic.mk_implies) gs
   1.147 +                     |> fold_rev mk_forall_rename (oqnames ~~ map Free qs)
   1.148 +    in
   1.149 +      HOLogic.mk_Trueprop Pbool
   1.150 +       |> fold_rev (curry Logic.mk_implies o mk_case) relevant_cases
   1.151 +       |> fold_rev (curry Logic.mk_implies) Cs'
   1.152 +       |> fold_rev (Logic.all o Free) ws
   1.153 +       |> fold_rev mk_forall_rename (map fst xs ~~ xs')
   1.154 +       |> mk_forall_rename ("P", Pbool)
   1.155 +    end
   1.156 +
   1.157 +fun mk_wf ctxt R (IndScheme {T, ...}) =
   1.158 +    HOLogic.Trueprop $ (Const (@{const_name wf}, mk_relT T --> HOLogic.boolT) $ R)
   1.159 +
   1.160 +fun mk_ineqs R (IndScheme {T, cases, branches}) =
   1.161 +    let
   1.162 +      fun inject i ts =
   1.163 +          SumTree.mk_inj T (length branches) (i + 1) (foldr1 HOLogic.mk_prod ts)
   1.164 +
   1.165 +      val thesis = Free ("thesis", HOLogic.boolT) (* FIXME *)
   1.166 +
   1.167 +      fun mk_pres bdx args = 
   1.168 +          let
   1.169 +            val SchemeBranch { xs, ws, Cs, ... } = nth branches bdx
   1.170 +            fun replace (x, v) t = betapply (lambda (Free x) t, v)
   1.171 +            val Cs' = map (fold replace (xs ~~ args)) Cs
   1.172 +            val cse = 
   1.173 +                HOLogic.mk_Trueprop thesis
   1.174 +                |> fold_rev (curry Logic.mk_implies) Cs'
   1.175 +                |> fold_rev (Logic.all o Free) ws
   1.176 +          in
   1.177 +            Logic.mk_implies (cse, HOLogic.mk_Trueprop thesis)
   1.178 +          end
   1.179 +
   1.180 +      fun f (SchemeCase {bidx, qs, oqnames, gs, lhs, rs, ...}) = 
   1.181 +          let
   1.182 +            fun g (bidx', Gvs, Gas, rcarg) =
   1.183 +                let val export = 
   1.184 +                         fold_rev (curry Logic.mk_implies) Gas
   1.185 +                         #> fold_rev (curry Logic.mk_implies) gs
   1.186 +                         #> fold_rev (Logic.all o Free) Gvs
   1.187 +                         #> fold_rev mk_forall_rename (oqnames ~~ map Free qs)
   1.188 +                in
   1.189 +                (HOLogic.mk_mem (HOLogic.mk_prod (inject bidx' rcarg, inject bidx lhs), R)
   1.190 +                 |> HOLogic.mk_Trueprop
   1.191 +                 |> export,
   1.192 +                 mk_pres bidx' rcarg
   1.193 +                 |> export
   1.194 +                 |> Logic.all thesis)
   1.195 +                end
   1.196 +          in
   1.197 +            map g rs
   1.198 +          end
   1.199 +    in
   1.200 +      map f cases
   1.201 +    end
   1.202 +
   1.203 +
   1.204 +fun mk_hol_imp a b = HOLogic.imp $ a $ b
   1.205 +
   1.206 +fun mk_ind_goal thy branches =
   1.207 +    let
   1.208 +      fun brnch (SchemeBranch { P, xs, ws, Cs, ... }) =
   1.209 +          HOLogic.mk_Trueprop (list_comb (P, map Free xs))
   1.210 +          |> fold_rev (curry Logic.mk_implies) Cs
   1.211 +          |> fold_rev (Logic.all o Free) ws
   1.212 +          |> term_conv thy ind_atomize
   1.213 +          |> ObjectLogic.drop_judgment thy
   1.214 +          |> tupled_lambda (foldr1 HOLogic.mk_prod (map Free xs))
   1.215 +    in
   1.216 +      SumTree.mk_sumcases HOLogic.boolT (map brnch branches)
   1.217 +    end
   1.218 +
   1.219 +
   1.220 +fun mk_induct_rule ctxt R x complete_thms wf_thm ineqss (IndScheme {T, cases=scases, branches}) =
   1.221 +    let
   1.222 +      val n = length branches
   1.223 +
   1.224 +      val scases_idx = map_index I scases
   1.225 +
   1.226 +      fun inject i ts =
   1.227 +          SumTree.mk_inj T n (i + 1) (foldr1 HOLogic.mk_prod ts)
   1.228 +      val P_of = nth (map (fn (SchemeBranch { P, ... }) => P) branches)
   1.229 +
   1.230 +      val thy = ProofContext.theory_of ctxt
   1.231 +      val cert = cterm_of thy 
   1.232 +
   1.233 +      val P_comp = mk_ind_goal thy branches
   1.234 +
   1.235 +      (* Inductive Hypothesis: !!z. (z,x):R ==> P z *)
   1.236 +      val ihyp = Term.all T $ Abs ("z", T, 
   1.237 +               Logic.mk_implies
   1.238 +                 (HOLogic.mk_Trueprop (
   1.239 +                  Const ("op :", HOLogic.mk_prodT (T, T) --> mk_relT T --> HOLogic.boolT) 
   1.240 +                    $ (HOLogic.pair_const T T $ Bound 0 $ x) 
   1.241 +                    $ R),
   1.242 +                   HOLogic.mk_Trueprop (P_comp $ Bound 0)))
   1.243 +           |> cert
   1.244 +
   1.245 +      val aihyp = assume ihyp
   1.246 +
   1.247 +     (* Rule for case splitting along the sum types *)
   1.248 +      val xss = map (fn (SchemeBranch { xs, ... }) => map Free xs) branches
   1.249 +      val pats = map_index (uncurry inject) xss
   1.250 +      val sum_split_rule = Pat_Completeness.prove_completeness thy [x] (P_comp $ x) xss (map single pats)
   1.251 +
   1.252 +      fun prove_branch (bidx, (SchemeBranch { P, xs, ws, Cs, ... }, (complete_thm, pat))) =
   1.253 +          let
   1.254 +            val fxs = map Free xs
   1.255 +            val branch_hyp = assume (cert (HOLogic.mk_Trueprop (HOLogic.mk_eq (x, pat))))
   1.256 +                             
   1.257 +            val C_hyps = map (cert #> assume) Cs
   1.258 +
   1.259 +            val (relevant_cases, ineqss') = filter (fn ((_, SchemeCase {bidx=bidx', ...}), _) => bidx' = bidx) (scases_idx ~~ ineqss)
   1.260 +                                            |> split_list
   1.261 +                           
   1.262 +            fun prove_case (cidx, SchemeCase {qs, oqnames, gs, lhs, rs, ...}) ineq_press =
   1.263 +                let
   1.264 +                  val case_hyps = map (assume o cert o HOLogic.mk_Trueprop o HOLogic.mk_eq) (fxs ~~ lhs)
   1.265 +                           
   1.266 +                  val cqs = map (cert o Free) qs
   1.267 +                  val ags = map (assume o cert) gs
   1.268 +                            
   1.269 +                  val replace_x_ss = HOL_basic_ss addsimps (branch_hyp :: case_hyps)
   1.270 +                  val sih = full_simplify replace_x_ss aihyp
   1.271 +                            
   1.272 +                  fun mk_Prec (idx, Gvs, Gas, rcargs) (ineq, pres) =
   1.273 +                      let
   1.274 +                        val cGas = map (assume o cert) Gas
   1.275 +                        val cGvs = map (cert o Free) Gvs
   1.276 +                        val import = fold forall_elim (cqs @ cGvs)
   1.277 +                                     #> fold Thm.elim_implies (ags @ cGas)
   1.278 +                        val ipres = pres
   1.279 +                                     |> forall_elim (cert (list_comb (P_of idx, rcargs)))
   1.280 +                                     |> import
   1.281 +                      in
   1.282 +                        sih |> forall_elim (cert (inject idx rcargs))
   1.283 +                            |> Thm.elim_implies (import ineq) (* Psum rcargs *)
   1.284 +                            |> Conv.fconv_rule sum_prod_conv
   1.285 +                            |> Conv.fconv_rule ind_rulify
   1.286 +                            |> (fn th => th COMP ipres) (* P rs *)
   1.287 +                            |> fold_rev (implies_intr o cprop_of) cGas
   1.288 +                            |> fold_rev forall_intr cGvs
   1.289 +                      end
   1.290 +                      
   1.291 +                  val P_recs = map2 mk_Prec rs ineq_press   (*  [P rec1, P rec2, ... ]  *)
   1.292 +                               
   1.293 +                  val step = HOLogic.mk_Trueprop (list_comb (P, lhs))
   1.294 +                             |> fold_rev (curry Logic.mk_implies o prop_of) P_recs
   1.295 +                             |> fold_rev (curry Logic.mk_implies) gs
   1.296 +                             |> fold_rev (Logic.all o Free) qs
   1.297 +                             |> cert
   1.298 +                             
   1.299 +                  val Plhs_to_Pxs_conv = 
   1.300 +                      foldl1 (uncurry Conv.combination_conv) 
   1.301 +                      (Conv.all_conv :: map (fn ch => K (Thm.symmetric (ch RS eq_reflection))) case_hyps)
   1.302 +
   1.303 +                  val res = assume step
   1.304 +                                   |> fold forall_elim cqs
   1.305 +                                   |> fold Thm.elim_implies ags
   1.306 +                                   |> fold Thm.elim_implies P_recs (* P lhs *) 
   1.307 +                                   |> Conv.fconv_rule (Conv.arg_conv Plhs_to_Pxs_conv) (* P xs *)
   1.308 +                                   |> fold_rev (implies_intr o cprop_of) (ags @ case_hyps)
   1.309 +                                   |> fold_rev forall_intr cqs (* !!qs. Gas ==> xs = lhss ==> P xs *)
   1.310 +                in
   1.311 +                  (res, (cidx, step))
   1.312 +                end
   1.313 +
   1.314 +            val (cases, steps) = split_list (map2 prove_case relevant_cases ineqss')
   1.315 +
   1.316 +            val bstep = complete_thm
   1.317 +                |> forall_elim (cert (list_comb (P, fxs)))
   1.318 +                |> fold (forall_elim o cert) (fxs @ map Free ws)
   1.319 +                |> fold Thm.elim_implies C_hyps             (* FIXME: optimization using rotate_prems *)
   1.320 +                |> fold Thm.elim_implies cases (* P xs *)
   1.321 +                |> fold_rev (implies_intr o cprop_of) C_hyps
   1.322 +                |> fold_rev (forall_intr o cert o Free) ws
   1.323 +
   1.324 +            val Pxs = cert (HOLogic.mk_Trueprop (P_comp $ x))
   1.325 +                     |> Goal.init
   1.326 +                     |> (MetaSimplifier.rewrite_goals_tac (map meta (branch_hyp :: @{thm split_conv} :: @{thms sum.cases}))
   1.327 +                         THEN CONVERSION ind_rulify 1)
   1.328 +                     |> Seq.hd
   1.329 +                     |> Thm.elim_implies (Conv.fconv_rule Drule.beta_eta_conversion bstep)
   1.330 +                     |> Goal.finish ctxt
   1.331 +                     |> implies_intr (cprop_of branch_hyp)
   1.332 +                     |> fold_rev (forall_intr o cert) fxs
   1.333 +          in
   1.334 +            (Pxs, steps)
   1.335 +          end
   1.336 +
   1.337 +      val (branches, steps) = split_list (map_index prove_branch (branches ~~ (complete_thms ~~ pats)))
   1.338 +                              |> apsnd flat
   1.339 +                           
   1.340 +      val istep = sum_split_rule
   1.341 +                |> fold (fn b => fn th => Drule.compose_single (b, 1, th)) branches
   1.342 +                |> implies_intr ihyp
   1.343 +                |> forall_intr (cert x) (* "!!x. (!!y<x. P y) ==> P x" *)
   1.344 +         
   1.345 +      val induct_rule =
   1.346 +          @{thm "wf_induct_rule"}
   1.347 +            |> (curry op COMP) wf_thm 
   1.348 +            |> (curry op COMP) istep
   1.349 +
   1.350 +      val steps_sorted = map snd (sort (int_ord o pairself fst) steps)
   1.351 +    in
   1.352 +      (steps_sorted, induct_rule)
   1.353 +    end
   1.354 +
   1.355 +
   1.356 +fun mk_ind_tac comp_tac pres_tac term_tac ctxt facts = (ALLGOALS (Method.insert_tac facts)) THEN HEADGOAL 
   1.357 +(SUBGOAL (fn (t, i) =>
   1.358 +  let
   1.359 +    val (ctxt', _, cases, concl) = dest_hhf ctxt t
   1.360 +    val scheme as IndScheme {T=ST, branches, ...} = mk_scheme' ctxt' cases concl
   1.361 +(*     val _ = tracing (makestring scheme)*)
   1.362 +    val ([Rn,xn], ctxt'') = Variable.variant_fixes ["R","x"] ctxt'
   1.363 +    val R = Free (Rn, mk_relT ST)
   1.364 +    val x = Free (xn, ST)
   1.365 +    val cert = cterm_of (ProofContext.theory_of ctxt)
   1.366 +
   1.367 +    val ineqss = mk_ineqs R scheme
   1.368 +                   |> map (map (pairself (assume o cert)))
   1.369 +    val complete = map_range (mk_completeness ctxt scheme #> cert #> assume) (length branches)
   1.370 +    val wf_thm = mk_wf ctxt R scheme |> cert |> assume
   1.371 +
   1.372 +    val (descent, pres) = split_list (flat ineqss)
   1.373 +    val newgoals = complete @ pres @ wf_thm :: descent 
   1.374 +
   1.375 +    val (steps, indthm) = mk_induct_rule ctxt'' R x complete wf_thm ineqss scheme
   1.376 +
   1.377 +    fun project (i, SchemeBranch {xs, ...}) =
   1.378 +        let
   1.379 +          val inst = cert (SumTree.mk_inj ST (length branches) (i + 1) (foldr1 HOLogic.mk_prod (map Free xs)))
   1.380 +        in
   1.381 +          indthm |> Drule.instantiate' [] [SOME inst]
   1.382 +                 |> simplify SumTree.sumcase_split_ss
   1.383 +                 |> Conv.fconv_rule ind_rulify
   1.384 +(*                 |> (fn thm => (tracing (makestring thm); thm))*)
   1.385 +        end                  
   1.386 +
   1.387 +    val res = Conjunction.intr_balanced (map_index project branches)
   1.388 +                 |> fold_rev implies_intr (map cprop_of newgoals @ steps)
   1.389 +                 |> (fn thm => Thm.generalize ([], [Rn]) (Thm.maxidx_of thm + 1) thm)
   1.390 +
   1.391 +    val nbranches = length branches
   1.392 +    val npres = length pres
   1.393 +  in
   1.394 +    Thm.compose_no_flatten false (res, length newgoals) i
   1.395 +    THEN term_tac (i + nbranches + npres)
   1.396 +    THEN (EVERY (map (TRY o pres_tac) ((i + nbranches + npres - 1) downto (i + nbranches))))
   1.397 +    THEN (EVERY (map (TRY o comp_tac) ((i + nbranches - 1) downto i)))
   1.398 +  end))
   1.399 +
   1.400 +
   1.401 +fun induction_schema_tac ctxt =
   1.402 +  mk_ind_tac (K all_tac) (assume_tac APPEND' Goal.assume_rule_tac ctxt) (K all_tac) ctxt;
   1.403 +
   1.404 +val setup =
   1.405 +  Method.setup @{binding induction_schema} (Scan.succeed (RAW_METHOD o induction_schema_tac))
   1.406 +    "proves an induction principle"
   1.407 +
   1.408 +end