src/HOL/Tools/induct_method.ML
author wenzelm
Sat Jul 01 19:55:22 2000 +0200 (2000-07-01)
changeset 9230 17ae63f82ad8
parent 9066 b1e874e38dab
child 9299 c5cda71de65d
permissions -rw-r--r--
GPLed;
     1 (*  Title:      HOL/Tools/induct_method.ML
     2     ID:         $Id$
     3     Author:     Markus Wenzel, TU Muenchen
     4     License:    GPL (GNU GENERAL PUBLIC LICENSE)
     5 
     6 Proof by cases and induction on types and sets.
     7 *)
     8 
     9 signature INDUCT_METHOD =
       (*Interface of the cases/induct proof method module: access to the
         registered rules, attributes for registering rules, tactics for
         simplifying cases rules, and the theory setup.*)
    10 sig
       (*registered rules as name/theorem pairs, and printing thereof;
         "global" takes a theory, "local" takes a proof context*)
    11   val dest_global_rules: theory ->
    12     {type_cases: (string * thm) list, set_cases: (string * thm) list,
    13       type_induct: (string * thm) list, set_induct: (string * thm) list}
    14   val print_global_rules: theory -> unit
    15   val dest_local_rules: Proof.context ->
    16     {type_cases: (string * thm) list, set_cases: (string * thm) list,
    17       type_induct: (string * thm) list, set_induct: (string * thm) list}
    18   val print_local_rules: Proof.context -> unit
       (*schematic variables of a term; conjuncts of the conclusion of a theorem*)
    19   val vars_of: term -> term list
    20   val concls_of: thm -> term list
       (*attributes that register the theorem as cases/induct rule under the
         given type or set name, in a theory or in a proof context*)
    21   val cases_type_global: string -> theory attribute
    22   val cases_set_global: string -> theory attribute
    23   val cases_type_local: string -> Proof.context attribute
    24   val cases_set_local: string -> Proof.context attribute
    25   val induct_type_global: string -> theory attribute
    26   val induct_set_global: string -> theory attribute
    27   val induct_type_local: string -> Proof.context attribute
    28   val induct_set_local: string -> Proof.context attribute
       (*tactics built from elimination and simplification with the given simpset*)
    29   val con_elim_tac: simpset -> tactic
    30   val con_elim_solved_tac: simpset -> tactic
       (*theory transformers installing data, attributes, methods and case_split*)
    31   val setup: (theory -> theory) list
    32 end;
    33 
    34 structure InductMethod: INDUCT_METHOD =
    35 struct
    36 
    37 
    38 (** global and local induct data **)
    39 
    40 (* rules *)
    41 
       (*collection of theorems, each paired with the name it was registered under*)
    42 type rules = (string * thm) NetRules.T;
    43 
       (*entries coincide iff both the names and the theorems are equal*)
    44 fun eq_rule ((s1:string, th1), (s2, th2)) = s1 = s2 andalso Thm.eq_thm (th1, th2);
    45 
       (*initial collections: type rules are set up with the conclusion of the
         theorem, set rules with its major premise -- presumably the term
         NetRules uses as index for may_unify; confirm against NetRules*)
    46 val type_rules = NetRules.init eq_rule (Thm.concl_of o #2);
    47 val set_rules = NetRules.init eq_rule (Thm.major_prem_of o #2);
    48 
       (*theorem registered under name, as an option (association list lookup
         over all rules of the collection)*)
    49 fun lookup_rule (rs:rules) name = Library.assoc (NetRules.rules rs, name);
    50 
       (*pretty print the theorems of a collection below the heading kind*)
    51 fun print_rules kind rs =
    52   let val thms = map snd (NetRules.rules rs)
    53   in Pretty.writeln (Pretty.big_list kind (map Display.pretty_thm thms)) end;
    54 
    55 
    56 (* theory data kind 'HOL/induct_method' *)
    57 
       (*argument structure for TheoryDataFun; the data is the nested pair
         ((type cases, set cases), (type induct, set induct))*)
    58 structure GlobalInductArgs =
    59 struct
    60   val name = "HOL/induct_method";
    61   type T = (rules * rules) * (rules * rules);
    62 
         (*all four components start from the initial type/set collections*)
    63   val empty = ((type_rules, set_rules), (type_rules, set_rules));
    64   val copy = I;
    65   val prep_ext = I;
         (*merge two data values componentwise by NetRules.merge*)
    66   fun merge (((casesT1, casesS1), (inductT1, inductS1)),
    67       ((casesT2, casesS2), (inductT2, inductS2))) =
    68     ((NetRules.merge (casesT1, casesT2), NetRules.merge (casesS1, casesS2)),
    69       (NetRules.merge (inductT1, inductT2), NetRules.merge (inductS1, inductS2)));
    70 
         (*print the four collections, each under its own heading; the first
           argument is ignored*)
    71   fun print _ ((casesT, casesS), (inductT, inductS)) =
    72     (print_rules "type cases:" casesT;
    73       print_rules "set cases:" casesS;
    74       print_rules "type induct:" inductT;
    75       print_rules "set induct:" inductS);
    76 
         (*turn the data into a record of name/theorem lists*)
    77   fun dest ((casesT, casesS), (inductT, inductS)) =
    78     {type_cases = NetRules.rules casesT,
    79      set_cases = NetRules.rules casesS,
    80      type_induct = NetRules.rules inductT,
    81      set_induct = NetRules.rules inductS};
    82 end;
    83 
    84 structure GlobalInduct = TheoryDataFun(GlobalInductArgs);
       (*exported printing and inspection of the theory data*)
    85 val print_global_rules = GlobalInduct.print;
    86 val dest_global_rules = GlobalInductArgs.dest o GlobalInduct.get;
    87 
    88 
    89 (* proof data kind 'HOL/induct_method' *)
    90 
       (*argument structure for ProofDataFun; same data type as the theory data*)
    91 structure LocalInductArgs =
    92 struct
    93   val name = "HOL/induct_method";
    94   type T = GlobalInductArgs.T;
    95 
         (*the proof data is initialised with the theory data of thy*)
    96   fun init thy = GlobalInduct.get thy;
         (*printing is delegated to the global print function*)
    97   fun print x = GlobalInductArgs.print x;
    98 end;
    99 
   100 structure LocalInduct = ProofDataFun(LocalInductArgs);
       (*exported printing and inspection of the proof data*)
   101 val print_local_rules = LocalInduct.print;
   102 val dest_local_rules = GlobalInductArgs.dest o LocalInduct.get;
   103 
   104 
   105 (* access rules *)
   106 
       (*cases resp. induct component (a pair of type rules and set rules) of
         the proof data of a context*)
   107 val get_cases = #1 o LocalInduct.get;
   108 val get_induct = #2 o LocalInduct.get;
   109 
       (*lookup by name in one of the four collections of a context: suffix T
         selects the type rules (first of the pair), S the set rules (second)*)
   110 val lookup_casesT = lookup_rule o #1 o get_cases;
   111 val lookup_casesS = lookup_rule o #2 o get_cases;
   112 val lookup_inductT = lookup_rule o #1 o get_induct;
   113 val lookup_inductS = lookup_rule o #2 o get_induct;
   114 
   115 
   116 
   117 (** attributes **)
   118 
   119 local
   120 
       (*generic attribute: f maps over the data of x (theory or context), g
         inserts the named theorem into one component; the theorem itself is
         passed through unchanged*)
   121 fun mk_att f g name (x, thm) = (f (g (name, thm)) x, thm);
   122 
       (*insert a rule into one component of ((casesT, casesS), (inductT, inductS))*)
   123 fun add_casesT rule x = apfst (apfst (NetRules.insert rule)) x;
   124 fun add_casesS rule x = apfst (apsnd (NetRules.insert rule)) x;
   125 fun add_inductT rule x = apsnd (apfst (NetRules.insert rule)) x;
   126 fun add_inductS rule x = apsnd (apsnd (NetRules.insert rule)) x;
   127 
   128 in
   129 
       (*attributes operating on the theory data*)
   130 val cases_type_global = mk_att GlobalInduct.map add_casesT;
   131 val cases_set_global = mk_att GlobalInduct.map add_casesS;
   132 val induct_type_global = mk_att GlobalInduct.map add_inductT;
   133 val induct_set_global = mk_att GlobalInduct.map add_inductS;
   134 
       (*attributes operating on the proof data*)
   135 val cases_type_local = mk_att LocalInduct.map add_casesT;
   136 val cases_set_local = mk_att LocalInduct.map add_casesS;
   137 val induct_type_local = mk_att LocalInduct.map add_inductT;
   138 val induct_set_local = mk_att LocalInduct.map add_inductS;
   139 
   140 end;
   141 
   142 
   143 
   144 (** misc utils **)
   145 
   146 (* thms and terms *)
   147 
       (*conclusion of a theorem with Trueprop removed, split by HOLogic.dest_conj*)
   148 val concls_of = HOLogic.dest_conj o HOLogic.dest_Trueprop o Thm.concl_of;
   149 
       (*schematic variables (Var) among the atomic subterms of tm, made
         distinct; the fold conses each Var in front, so the list is reversed
         at the end*)
   150 fun vars_of tm =        (*ordered left-to-right, preferring right!*)
   151   Term.foldl_aterms (fn (ts, t as Var _) => t :: ts | (ts, _) => ts) ([], tm)
   152   |> Library.distinct |> rev;
   153 
       (*name of the type constructor of the type of t; a TYPE exception from
         the computation is turned into TERM mentioning t*)
   154 fun type_name t =
   155   #1 (Term.dest_Type (Term.type_of t))
   156     handle TYPE _ => raise TERM ("Bad type of term argument", [t]);
   157 
   158 
   159 (* simplifying cases rules *)
   160 
   161 local
   162 
   163 (*delete needless equality assumptions*)
   164 val refl_thin = prove_goal HOL.thy "!!P. [| a=a;  P |] ==> P"
   165      (fn _ => [assume_tac 1]);
   166 
       (*elimination rules tried by elim_tac*)
   167 val elim_rls = [asm_rl, FalseE, refl_thin, conjE, exE, Pair_inject];
   168 
       (*eliminate with elim_rls on a subgoal as long as possible*)
   169 val elim_tac = REPEAT o Tactic.eresolve_tac elim_rls;
   170 
       (*on one subgoal, in sequence: eliminate, simplify with ss (including
         assumptions), eliminate again, then substitute hypotheses repeatedly*)
   171 fun simp_case_tac ss = 
   172   EVERY' [elim_tac, asm_full_simp_tac ss, elim_tac, REPEAT o bound_hyp_subst_tac];
   173 
   174 in
   175 
       (*simplify all subgoals, then prune parameters*)
   176 fun con_elim_tac ss = ALLGOALS (simp_case_tac ss) THEN prune_params_tac;
   177 
       (*like con_elim_tac, but each subgoal is wrapped in TRY (... THEN_MAYBE
         no_tac); presumably the simplification is kept only where it solves the
         subgoal and the subgoal is left untouched otherwise -- confirm against
         the definition of THEN_MAYBE*)
   178 fun con_elim_solved_tac ss =
   179   ALLGOALS (fn i => TRY (simp_case_tac ss i THEN_MAYBE no_tac)) THEN prune_params_tac;
   180 
   181 end;
   182 
   183 
   184 
   185 (** cases method **)
   186 
   187 (*
   188   rule selection:
   189         cases         - classical case split
   190         cases t       - datatype exhaustion
   191   <x:A> cases ...     - set elimination
   192   ...   cases ... R   - explicit rule
   193 *)
   194 
       (*case_split_thm with its cases named "True" and "False"*)
   195 val case_split = RuleCases.name ["True", "False"] case_split_thm;
   196 
   197 local
   198 
       (*variable to be instantiated in a cases rule: the first one (according to
         vars_of) of the first hypothesis of the last premise; any failure of
         this chain is reported as THM "Malformed cases rule"*)
   199 fun cases_var thm =
   200   (case try (hd o vars_of o hd o Logic.strip_assums_hyp o Library.last_elem o Thm.prems_of) thm of
   201     None => raise THM ("Malformed cases rule", 0, [thm])
   202   | Some x => x);
   203 
       (*apply con_elim_solved_tac, with the local simpset of ctxt, to a rule*)
   204 fun simplify_cases ctxt =
   205   Tactic.rule_by_tactic (con_elim_solved_tac (Simplifier.get_local_simpset ctxt));
   206 
       (*Tactic behind the cases method.  args is a pair of an optional term and
         an optional explicit rule, facts are the facts handed to the method;
         with simplified = true the selected rules are post-processed by
         simplify_cases.*)
   207 fun cases_tac (ctxt, (simplified, args)) facts =
   208   let
   209     val sg = ProofContext.sign_of ctxt;
   210     val cert = Thm.cterm_of sg;
   211 
           (*instantiate the cases variable of thm by t*)
   212     fun inst_rule t thm =
   213       Drule.cterm_instantiate [(cert (cases_var thm), cert t)] thm;
   214 
   215     val cond_simp = if simplified then simplify_cases ctxt else I;
   216 
           (*set cases rules of the context that may unify with the conclusion
             of the fact th*)
   217     fun find_cases th =
   218       NetRules.may_unify (#2 (get_cases ctxt))
   219         (Logic.strip_assums_concl (#prop (Thm.rep_thm th)));
   220 
           (*candidate rules paired with case information, selected according to
             the table at the top of this section: no arguments and no facts
             gives case_split; a term without facts gives the type cases rule
             of its type; facts select set cases rules via find_cases; an
             explicit rule overrides everything else*)
   221     val rules =
   222       (case (args, facts) of
   223         ((None, None), []) => [RuleCases.add case_split]
   224       | ((Some t, None), []) =>
   225           let val name = type_name t in
   226             (case lookup_casesT ctxt name of
   227               None => error ("No cases rule for type: " ^ quote name)
   228             | Some thm => [(inst_rule t thm, RuleCases.get thm)])
   229           end
   230       | ((None, None), th :: _) => map (RuleCases.add o #2) (find_cases th)
   231       | ((Some t, None), th :: _) =>
   232           (case find_cases th of	(*may instantiate first rule only!*)
   233             (_, thm) :: _ => [(inst_rule t thm, RuleCases.get thm)]
   234           | [] => [])
   235       | ((None, Some thm), _) => [RuleCases.add thm]
   236       | ((Some t, Some thm), _) => [(inst_rule t thm, RuleCases.get thm)]);
   237 
           (*resolve the facts against the rule, apply the optional
             simplification to each result, and attach the case information*)
   238     fun prep_rule (thm, cases) =
   239       Seq.map (rpair cases o cond_simp) (Method.multi_resolves facts [thm]);
   240   in Method.resolveq_cases_tac (Seq.flat (Seq.map prep_rule (Seq.of_list rules))) end;
   241 
   242 in
   243 
       (*the cases method: cases_tac applied via HEADGOAL, wrapped by
         Method.METHOD_CASES*)
   244 val cases_meth = Method.METHOD_CASES o (HEADGOAL oo cases_tac);
   245 
   246 end;
   247 
   248 
   249 
   250 (** induct method **)
   251 
   252 (*
   253   rule selection:
   254         induct         - mathematical induction
   255         induct x       - datatype induction
   256   <x:A> induct ...     - set induction
   257   ...   induct ... R   - explicit rule
   258 *)
   259 
   260 local
   261 
   262 infix 1 THEN_ALL_NEW_CASES;
   263 
       (*apply tac1 to subgoal i (its results carry case information), then tac2
         to the range of subgoals from i up to i plus the difference in the
         number of premises; the case information is passed through*)
   264 fun (tac1 THEN_ALL_NEW_CASES tac2) i st =
   265   st |> Seq.THEN (tac1 i, (fn (st', cases) =>
   266     Seq.map (rpair cases) (Seq.INTERVAL tac2 i (i + nprems_of st' - nprems_of st) st')));
   267 
   268 
       (*type induct rule for the type of t, paired with the type name; fails
         with an error if none is registered in ctxt*)
   269 fun induct_rule ctxt t =
   270   let val name = type_name t in
   271     (case lookup_inductT ctxt name of
   272       None => error ("No induct rule for type: " ^ quote name)
   273     | Some thm => (name, thm))
   274   end;
   275 
       (*combine several named induct rules into one: a single rule is returned
         as it is; otherwise all rules must have the same premises (up to
         aconv, after Drule.freeze_all), and the result states the conjunction
         of their conclusions under these shared premises*)
   276 fun join_rules [(_, thm)] = thm
   277   | join_rules raw_thms =
   278       let
   279         val thms = (map (apsnd Drule.freeze_all) raw_thms);
   280         fun eq_prems ((_, th1), (_, th2)) =
   281           Term.aconvs (Thm.prems_of th1, Thm.prems_of th2);
   282       in
               (*exactly one representative remains iff all premises agree*)
   283         (case Library.gen_distinct eq_prems thms of
   284           [(_, thm)] =>
   285             let
   286               val cprems = Drule.cprems_of thm;
   287               val asms = map Thm.assume cprems;
                     (*discharge the premises of a rule against the assumptions*)
   288               fun strip (_, th) = Drule.implies_elim_list th asms;
   289             in
                     (*conjoin the stripped rules by conjI, then reintroduce the
                       premises and standardise the result*)
   290               foldr1 (fn (th, th') => [th, th'] MRS conjI) (map strip thms)
   291               |> Drule.implies_intr_list cprems
   292               |> Drule.standard
   293             end
   294         | [] => error "No rule given"
   295         | bads => error ("Incompatible rules for " ^ commas_quote (map #1 bads)))
   296       end;
   297 
   298 
   299 fun induct_tac (ctxt, (stripped, args)) facts =
       (*Tactic behind the induct method.  args is a pair of instantiations (one
         list of optional terms per conclusion of the rule) and an optional
         explicit rule; facts are the facts handed to the method.  With
         stripped = true, impI/allI/ballI are matched repeatedly on all new
         subgoals afterwards.  Fails with an error if more instantiations are
         given than the rule has conclusions or variables, or if no type
         induction rule can be determined.*)
   300   let
   301     val sg = ProofContext.sign_of ctxt;
   302     val cert = Thm.cterm_of sg;
   303 
           (*certified instantiation of variable x; a missing term is skipped*)
   304     fun prep_var (x, Some t) = Some (cert x, cert t)
   305       | prep_var (_, None) = None;
   306 
           (*pair the given terms with the last length ts variables of concl*)
   307     fun prep_inst (concl, ts) =
   308       let val xs = vars_of concl; val n = length xs - length ts in
   309         if n < 0 then error "More variables given than in induction rule"
   310         else mapfilter prep_var (Library.drop (n, xs) ~~ ts)
   311       end;
   312 
           (*instantiate the conclusions of thm by the corresponding lists of terms*)
   313     fun inst_rule insts thm =
   314       let val concls = concls_of thm in
   315         if length concls < length insts then
   316           error "More arguments given than in induction rule"
   317         else Drule.cterm_instantiate (flat (map prep_inst (concls ~~ insts))) thm
   318       end;
   319 
           (*set induct rules of the context that may unify with the conclusion
             of the fact th*)
   320     fun find_induct th =
   321       NetRules.may_unify (#2 (get_induct ctxt))
   322         (Logic.strip_assums_concl (#prop (Thm.rep_thm th)));
   323 
           (*candidate rules paired with case information: without facts the
             type induct rules are determined from the last given term of each
             instantiation list and joined; with facts the set induct rules are
             found via find_induct; an explicit rule overrides everything else*)
   324     val rules =
   325       (case (args, facts) of
   326         (([], None), []) => []
   327       | ((insts, None), []) =>
   328           let val thms = map (induct_rule ctxt o last_elem o mapfilter I) insts
   329             handle Library.LIST _ => error "Unable to figure out type induction rule"
   330           in [(inst_rule insts (join_rules thms), RuleCases.get (#2 (hd thms)))] end
   331       | (([], None), th :: _) => map (RuleCases.add o #2) (find_induct th)
   332       | ((insts, None), th :: _) =>
   333           (case find_induct th of  (*may instantiate first rule only!*)
   334             (_, thm) :: _ => [(inst_rule insts thm, RuleCases.get thm)]
   335           | [] => [])
   336       | (([], Some thm), _) => [RuleCases.add thm]
   337       | ((insts, Some thm), _) => [(inst_rule insts thm, RuleCases.get thm)]);
   338 
           (*resolve the facts against the rule and attach the case information*)
   339     fun prep_rule (thm, cases) =
   340       Seq.map (rpair cases) (Method.multi_resolves facts [thm]);
   341     val tac = Method.resolveq_cases_tac (Seq.flat (Seq.map prep_rule (Seq.of_list rules)));
   342   in
   343     if stripped then tac THEN_ALL_NEW_CASES (REPEAT o Tactic.match_tac [impI, allI, ballI])
   344     else tac
   345   end;
   346 
   347 in
   348 
   349 val induct_meth = Method.METHOD_CASES o (HEADGOAL oo induct_tac);
       (*induct_meth: the induct method, i.e. induct_tac applied via HEADGOAL
         and wrapped by Method.METHOD_CASES*)
   350 
   351 end;
   352 
   353 
   354 
   355 (** concrete syntax **)
   356 
       (*names of the attributes and methods*)
   357 val casesN = "cases";
   358 val inductN = "induct";
   359 
       (*names of the method modes, given in parentheses*)
   360 val simplifiedN = "simplified";
   361 val strippedN = "stripped";
   362 
       (*keywords introducing a rule specification*)
   363 val typeN = "type";
   364 val setN = "set";
   365 val ruleN = "rule";
   366 
   367 
   368 (* attributes *)
   369 
       (*parse keyword k followed by a colon and a name; yields the name*)
   370 fun spec k = (Args.$$$ k -- Args.colon) |-- Args.!!! Args.name;
   371 
       (*attribute syntax "type: name" or "set: name"; the name is interned as
         type constructor resp. constant with respect to the signature obtained
         from x by sign_of, and handed to add_type resp. add_set*)
   372 fun attrib sign_of add_type add_set = Scan.depend (fn x =>
   373   let val sg = sign_of x in
   374     spec typeN >> (add_type o Sign.intern_tycon sg) ||
   375     spec setN  >> (add_set o Sign.intern_const sg)
   376   end >> pair x);
   377 
       (*pairs of global (theory) and local (proof context) attribute parsers*)
   378 val cases_attr =
   379   (Attrib.syntax (attrib Theory.sign_of cases_type_global cases_set_global),
   380    Attrib.syntax (attrib ProofContext.sign_of cases_type_local cases_set_local));
   381 
   382 val induct_attr =
   383   (Attrib.syntax (attrib Theory.sign_of induct_type_global induct_set_global),
   384    Attrib.syntax (attrib ProofContext.sign_of induct_type_local induct_set_local));
   385 
   386 
   387 (* methods *)
   388 
   389 local
   390 
       (*result of get name, or an error "No rule for <k> <name>" if there is none*)
   391 fun err k get name =
   392   (case get name of Some x => x
   393   | None => error ("No rule for " ^ k ^ " " ^ quote name));
   394 
       (*rule specification: "type: name" or "set: name" looked up in the
         context via get_type resp. get_set (after interning the name), or
         "rule: thm" referring to a theorem directly*)
   395 fun rule get_type get_set =
   396   Scan.depend (fn ctxt =>
   397     let val sg = ProofContext.sign_of ctxt in
   398       spec typeN >> (err typeN (get_type ctxt) o Sign.intern_tycon sg) ||
   399       spec setN >> (err setN (get_set ctxt) o Sign.intern_const sg)
   400     end >> pair ctxt) ||
   401   Scan.lift (Args.$$$ ruleN -- Args.colon) |-- Attrib.local_thm;
   402 
   403 val cases_rule = rule lookup_casesT lookup_casesS;
   404 val induct_rule = rule lookup_inductT lookup_inductS;
   405 
       (*start of a rule specification: one of the keywords followed by a colon*)
   406 val kind = (Args.$$$ typeN || Args.$$$ setN || Args.$$$ ruleN) -- Args.colon;
       (*a term argument, unless a rule specification starts here*)
   407 val term = Scan.unless (Scan.lift kind) Args.local_term;
       (*like term, but "_" is accepted as placeholder and yields None*)
   408 val term_dummy = Scan.unless (Scan.lift kind)
   409   (Scan.lift (Args.$$$ "_") >> K None || Args.local_term >> Some);
   410 
       (*optional mode "(name)": true if present, false otherwise*)
   411 fun mode name =
   412   Scan.lift (Scan.optional (Args.parens (Args.$$$ name) >> K true) false);
   413 
   414 in
   415 
       (*cases: optional mode (simplified), optional term, optional rule*)
   416 val cases_args = Method.syntax (mode simplifiedN -- (Scan.option term -- Scan.option cases_rule));
       (*induct: optional mode (stripped), lists of terms or placeholders
         combined by Args.and_list, optional rule*)
   417 val induct_args = Method.syntax
   418   (mode strippedN -- (Args.and_list (Scan.repeat term_dummy) -- Scan.option induct_rule));
   419 
   420 end;
   421 
   422 
   423 
   424 (** theory setup **)
   425 
       (*theory transformers of this module: initialise the theory and proof
         data, declare the cases/induct attributes and methods (both under the
         names casesN/inductN), and store the case_split theorem*)
   426 val setup =
   427   [GlobalInduct.init, LocalInduct.init,
   428    Attrib.add_attributes
   429     [(casesN, cases_attr, "cases rule for type or set"),
   430      (inductN, induct_attr, "induction rule for type or set")],
   431    Method.add_methods
   432     [(casesN, cases_meth oo cases_args, "case analysis on types or sets"),
   433      (inductN, induct_meth oo induct_args, "induction on types or sets")],
   434    (#1 o PureThy.add_thms [(("case_split", case_split), [])])];
   435 
   436 end;