src/Tools/eqsubst.ML
author wenzelm
Wed Sep 12 23:18:26 2012 +0200 (2012-09-12)
changeset 49340 25fc6e0da459
parent 49339 d1fcb4de8349
child 51717 9e7d1c139569
permissions -rw-r--r--
observe context more carefully when producing "fresh" variables -- for increased chances that method "subst" works in local context (including that of forked proofs);
     1 (*  Title:      Tools/eqsubst.ML
     2     Author:     Lucas Dixon, University of Edinburgh
     3 
     4 A proof method to perform a substitution using an equation.
     5 *)
     6 
     7 signature EQSUBST =
     8 sig
     9   (* a type abbreviation for match information *)
    10   type match =
    11        ((indexname * (sort * typ)) list (* type instantiations *)
    12         * (indexname * (typ * term)) list) (* term instantiations *)
    13        * (string * typ) list (* fake named type abs env *)
    14        * (string * typ) list (* type abs env *)
    15        * term (* outer term *)
    16 
    17   type searchinfo =
    18        theory
    19        * int (* maxidx *)
    20        * Zipper.T (* focusterm to search under *)
    21 
    22   datatype 'a skipseq = SkipMore of int | SkipSeq of 'a Seq.seq Seq.seq
    23 
    24   val skip_first_asm_occs_search :
    25      ('a -> 'b -> 'c Seq.seq Seq.seq) ->
    26      'a -> int -> 'b -> 'c skipseq
    27   val skip_first_occs_search :
    28      int -> ('a -> 'b -> 'c Seq.seq Seq.seq) -> 'a -> 'b -> 'c Seq.seq
    29   val skipto_skipseq : int -> 'a Seq.seq Seq.seq -> 'a skipseq
    30 
    31   (* tactics *)
    32   val eqsubst_asm_tac :
    33      Proof.context ->
    34      int list -> thm list -> int -> tactic
    35   val eqsubst_asm_tac' :
    36      Proof.context ->
    37      (searchinfo -> int -> term -> match skipseq) ->
    38      int -> thm -> int -> tactic
    39   val eqsubst_tac :
    40      Proof.context ->
    41      int list -> (* list of occurences to rewrite, use [0] for any *)
    42      thm list -> int -> tactic
    43   val eqsubst_tac' :
    44      Proof.context -> (* proof context *)
    45      (searchinfo -> term -> match Seq.seq) (* search function *)
    46      -> thm (* equation theorem to rewrite with *)
    47      -> int (* subgoal number in goal theorem *)
    48      -> thm (* goal theorem *)
    49      -> thm Seq.seq (* rewritten goal theorem *)
    50 
    51   (* search for substitutions *)
    52   val valid_match_start : Zipper.T -> bool
    53   val search_lr_all : Zipper.T -> Zipper.T Seq.seq
    54   val search_lr_valid : (Zipper.T -> bool) -> Zipper.T -> Zipper.T Seq.seq
    55   val searchf_lr_unify_all :
    56      searchinfo -> term -> match Seq.seq Seq.seq
    57   val searchf_lr_unify_valid :
    58      searchinfo -> term -> match Seq.seq Seq.seq
    59   val searchf_bt_unify_valid :
    60      searchinfo -> term -> match Seq.seq Seq.seq
    61 
    62   val setup : theory -> theory
    63 
    64 end;
    65 
    66 structure EqSubst: EQSUBST =
    67 struct
    68 
    69 (* changes object "=" to meta "==" which prepares a given rewrite rule *)
    70 fun prep_meta_eq ctxt =
    71   Simplifier.mksimps (simpset_of ctxt) #> map Drule.zero_var_indexes;
    72 
    73 
    74   (* a type abriviation for match information *)
    75   type match =
    76        ((indexname * (sort * typ)) list (* type instantiations *)
    77         * (indexname * (typ * term)) list) (* term instantiations *)
    78        * (string * typ) list (* fake named type abs env *)
    79        * (string * typ) list (* type abs env *)
    80        * term (* outer term *)
    81 
    82   type searchinfo =
    83        theory
    84        * int (* maxidx *)
    85        * Zipper.T (* focusterm to search under *)
    86 
    87 
    88 (* skipping non-empty sub-sequences but when we reach the end
    89    of the seq, remembering how much we have left to skip. *)
    90 datatype 'a skipseq = SkipMore of int
    91   | SkipSeq of 'a Seq.seq Seq.seq;
    92 (* given a seqseq, skip the first m non-empty seq's, note deficit *)
    93 fun skipto_skipseq m s =
    94     let
    95       fun skip_occs n sq =
    96           case Seq.pull sq of
    97             NONE => SkipMore n
    98           | SOME (h,t) =>
    99             (case Seq.pull h of NONE => skip_occs n t
   100              | SOME _ => if n <= 1 then SkipSeq (Seq.cons h t)
   101                          else skip_occs (n - 1) t)
   102     in (skip_occs m s) end;
   103 
   104 (* note: outerterm is the taget with the match replaced by a bound
   105          variable : ie: "P lhs" beocmes "%x. P x"
   106          insts is the types of instantiations of vars in lhs
   107          and typinsts is the type instantiations of types in the lhs
   108          Note: Final rule is the rule lifted into the ontext of the
   109          taget thm. *)
   110 fun mk_foo_match mkuptermfunc Ts t =
   111     let
   112       val ty = Term.type_of t
   113       val bigtype = (rev (map snd Ts)) ---> ty
   114       fun mk_foo 0 t = t
   115         | mk_foo i t = mk_foo (i - 1) (t $ (Bound (i - 1)))
   116       val num_of_bnds = (length Ts)
   117       (* foo_term = "fooabs y0 ... yn" where y's are local bounds *)
   118       val foo_term = mk_foo num_of_bnds (Bound num_of_bnds)
   119     in Abs("fooabs", bigtype, mkuptermfunc foo_term) end;
   120 
   121 (* T is outer bound vars, n is number of locally bound vars *)
   122 (* THINK: is order of Ts correct...? or reversed? *)
   123 fun mk_fake_bound_name n = ":b_" ^ n;
   124 fun fakefree_badbounds Ts t =
   125     let val (FakeTs,Ts,newnames) =
   126             List.foldr (fn ((n,ty),(FakeTs,Ts,usednames)) =>
   127                            let val newname = singleton (Name.variant_list usednames) n
   128                            in ((mk_fake_bound_name newname,ty)::FakeTs,
   129                                (newname,ty)::Ts,
   130                                newname::usednames) end)
   131                        ([],[],[])
   132                        Ts
   133     in (FakeTs, Ts, Term.subst_bounds (map Free FakeTs, t)) end;
   134 
   135 (* before matching we need to fake the bound vars that are missing an
   136 abstraction. In this function we additionally construct the
   137 abstraction environment, and an outer context term (with the focus
   138 abstracted out) for use in rewriting with RWInst.rw *)
   139 fun prep_zipper_match z =
   140     let
   141       val t = Zipper.trm z
   142       val c = Zipper.ctxt z
   143       val Ts = Zipper.C.nty_ctxt c
   144       val (FakeTs', Ts', t') = fakefree_badbounds Ts t
   145       val absterm = mk_foo_match (Zipper.C.apply c) Ts' t'
   146     in
   147       (t', (FakeTs', Ts', absterm))
   148     end;
   149 
   150 (* Unification with exception handled *)
   151 (* given theory, max var index, pat, tgt; returns Seq of instantiations *)
   152 fun clean_unify thry ix (a as (pat, tgt)) =
   153     let
   154       (* type info will be re-derived, maybe this can be cached
   155          for efficiency? *)
   156       val pat_ty = Term.type_of pat;
   157       val tgt_ty = Term.type_of tgt;
   158       (* is it OK to ignore the type instantiation info?
   159          or should I be using it? *)
   160       val typs_unify =
   161           SOME (Sign.typ_unify thry (pat_ty, tgt_ty) (Vartab.empty, ix))
   162             handle Type.TUNIFY => NONE;
   163     in
   164       case typs_unify of
   165         SOME (typinsttab, ix2) =>
   166         let
   167       (* is it right to throw away the flexes?
   168          or should I be using them somehow? *)
   169           fun mk_insts env =
   170             (Vartab.dest (Envir.type_env env),
   171              Vartab.dest (Envir.term_env env));
   172           val initenv =
   173             Envir.Envir {maxidx = ix2, tenv = Vartab.empty, tyenv = typinsttab};
   174           val useq = Unify.smash_unifiers thry [a] initenv
   175               handle ListPair.UnequalLengths => Seq.empty
   176                    | Term.TERM _ => Seq.empty;
   177           fun clean_unify' useq () =
   178               (case (Seq.pull useq) of
   179                  NONE => NONE
   180                | SOME (h,t) => SOME (mk_insts h, Seq.make (clean_unify' t)))
   181               handle ListPair.UnequalLengths => NONE
   182                    | Term.TERM _ => NONE
   183         in
   184           (Seq.make (clean_unify' useq))
   185         end
   186       | NONE => Seq.empty
   187     end;
   188 
   189 (* Unification for zippers *)
   190 (* Note: Ts is a modified version of the original names of the outer
   191 bound variables. New names have been introduced to make sure they are
   192 unique w.r.t all names in the term and each other. usednames' is
   193 oldnames + new names. *)
   194 (* ix = max var index *)
   195 fun clean_unify_z sgn ix pat z =
   196     let val (t, (FakeTs, Ts,absterm)) = prep_zipper_match z in
   197     Seq.map (fn insts => (insts, FakeTs, Ts, absterm))
   198             (clean_unify sgn ix (t, pat)) end;
   199 
   200 
   201 fun bot_left_leaf_of (l $ r) = bot_left_leaf_of l
   202   | bot_left_leaf_of (Abs(s,ty,t)) = bot_left_leaf_of t
   203   | bot_left_leaf_of x = x;
   204 
   205 (* Avoid considering replacing terms which have a var at the head as
   206    they always succeed trivially, and uninterestingly. *)
   207 fun valid_match_start z =
   208     (case bot_left_leaf_of (Zipper.trm z) of
   209       Var _ => false
   210       | _ => true);
   211 
   212 (* search from top, left to right, then down *)
   213 val search_lr_all = ZipperSearch.all_bl_ur;
   214 
   215 (* search from top, left to right, then down *)
   216 fun search_lr_valid validf =
   217     let
   218       fun sf_valid_td_lr z =
   219           let val here = if validf z then [Zipper.Here z] else [] in
   220             case Zipper.trm z
   221              of _ $ _ => [Zipper.LookIn (Zipper.move_down_left z)]
   222                          @ here
   223                          @ [Zipper.LookIn (Zipper.move_down_right z)]
   224               | Abs _ => here @ [Zipper.LookIn (Zipper.move_down_abs z)]
   225               | _ => here
   226           end;
   227     in Zipper.lzy_search sf_valid_td_lr end;
   228 
   229 (* search from bottom to top, left to right *)
   230 
   231 fun search_bt_valid validf =
   232     let
   233       fun sf_valid_td_lr z =
   234           let val here = if validf z then [Zipper.Here z] else [] in
   235             case Zipper.trm z
   236              of _ $ _ => [Zipper.LookIn (Zipper.move_down_left z),
   237                           Zipper.LookIn (Zipper.move_down_right z)] @ here
   238               | Abs _ => [Zipper.LookIn (Zipper.move_down_abs z)] @ here
   239               | _ => here
   240           end;
   241     in Zipper.lzy_search sf_valid_td_lr end;
   242 
   243 fun searchf_unify_gen f (sgn, maxidx, z) lhs =
   244     Seq.map (clean_unify_z sgn maxidx lhs)
   245             (Zipper.limit_apply f z);
   246 
   247 (* search all unifications *)
   248 val searchf_lr_unify_all =
   249     searchf_unify_gen search_lr_all;
   250 
   251 (* search only for 'valid' unifiers (non abs subterms and non vars) *)
   252 val searchf_lr_unify_valid =
   253     searchf_unify_gen (search_lr_valid valid_match_start);
   254 
   255 val searchf_bt_unify_valid =
   256     searchf_unify_gen (search_bt_valid valid_match_start);
   257 
   258 (* apply a substitution in the conclusion of the theorem th *)
   259 (* cfvs are certified free var placeholders for goal params *)
   260 (* conclthm is a theorem of for just the conclusion *)
   261 (* m is instantiation/match information *)
   262 (* rule is the equation for substitution *)
   263 fun apply_subst_in_concl ctxt i th (cfvs, conclthm) rule m =
   264     (RWInst.rw ctxt m rule conclthm)
   265       |> IsaND.unfix_frees cfvs
   266       |> RWInst.beta_eta_contract
   267       |> (fn r => Tactic.rtac r i th);
   268 
   269 (* substitute within the conclusion of goal i of gth, using a meta
   270 equation rule. Note that we assume rule has var indicies zero'd *)
   271 fun prep_concl_subst ctxt i gth =
   272     let
   273       val th = Thm.incr_indexes 1 gth;
   274       val tgt_term = Thm.prop_of th;
   275 
   276       val sgn = Thm.theory_of_thm th;
   277       val ctermify = Thm.cterm_of sgn;
   278       val trivify = Thm.trivial o ctermify;
   279 
   280       val (fixedbody, fvs) = IsaND.fix_alls_term ctxt i tgt_term;
   281       val cfvs = rev (map ctermify fvs);
   282 
   283       val conclterm = Logic.strip_imp_concl fixedbody;
   284       val conclthm = trivify conclterm;
   285       val maxidx = Thm.maxidx_of th;
   286       val ft = ((Zipper.move_down_right (* ==> *)
   287                  o Zipper.move_down_left (* Trueprop *)
   288                  o Zipper.mktop
   289                  o Thm.prop_of) conclthm)
   290     in
   291       ((cfvs, conclthm), (sgn, maxidx, ft))
   292     end;
   293 
   294 (* substitute using an object or meta level equality *)
   295 fun eqsubst_tac' ctxt searchf instepthm i th =
   296     let
   297       val (cvfsconclthm, searchinfo) = prep_concl_subst ctxt i th;
   298       val stepthms = Seq.of_list (prep_meta_eq ctxt instepthm);
   299       fun rewrite_with_thm r =
   300           let val (lhs,_) = Logic.dest_equals (Thm.concl_of r);
   301           in searchf searchinfo lhs
   302              |> Seq.maps (apply_subst_in_concl ctxt i th cvfsconclthm r) end;
   303     in stepthms |> Seq.maps rewrite_with_thm end;
   304 
   305 
   306 (* distinct subgoals *)
   307 fun distinct_subgoals th =
   308   the_default th (SINGLE distinct_subgoals_tac th);
   309 
   310 (* General substitution of multiple occurances using one of
   311    the given theorems*)
   312 
   313 
   314 fun skip_first_occs_search occ srchf sinfo lhs =
   315     case (skipto_skipseq occ (srchf sinfo lhs)) of
   316       SkipMore _ => Seq.empty
   317     | SkipSeq ss => Seq.flat ss;
   318 
   319 (* The occL is a list of integers indicating which occurence
   320 w.r.t. the search order, to rewrite. Backtracking will also find later
   321 occurences, but all earlier ones are skipped. Thus you can use [0] to
   322 just find all rewrites. *)
   323 
   324 fun eqsubst_tac ctxt occL thms i th =
   325     let val nprems = Thm.nprems_of th in
   326       if nprems < i then Seq.empty else
   327       let val thmseq = (Seq.of_list thms)
   328         fun apply_occ occ th =
   329             thmseq |> Seq.maps
   330                     (fn r => eqsubst_tac'
   331                                ctxt
   332                                (skip_first_occs_search
   333                                   occ searchf_lr_unify_valid) r
   334                                  (i + ((Thm.nprems_of th) - nprems))
   335                                  th);
   336         val sortedoccL =
   337             Library.sort (Library.rev_order o Library.int_ord) occL;
   338       in
   339         Seq.map distinct_subgoals (Seq.EVERY (map apply_occ sortedoccL) th)
   340       end
   341     end;
   342 
   343 
   344 (* inthms are the given arguments in Isar, and treated as eqstep with
   345    the first one, then the second etc *)
   346 fun eqsubst_meth ctxt occL inthms =
   347     SIMPLE_METHOD' (eqsubst_tac ctxt occL inthms);
   348 
   349 (* apply a substitution inside assumption j, keeps asm in the same place *)
   350 fun apply_subst_in_asm ctxt i th rule ((cfvs, j, ngoalprems, pth),m) =
   351     let
   352       val th2 = Thm.rotate_rule (j - 1) i th; (* put premice first *)
   353       val preelimrule =
   354           (RWInst.rw ctxt m rule pth)
   355             |> (Seq.hd o prune_params_tac)
   356             |> Thm.permute_prems 0 ~1 (* put old asm first *)
   357             |> IsaND.unfix_frees cfvs (* unfix any global params *)
   358             |> RWInst.beta_eta_contract; (* normal form *)
   359   (*    val elimrule =
   360           preelimrule
   361             |> Tactic.make_elim (* make into elim rule *)
   362             |> Thm.lift_rule (th2, i); (* lift into context *)
   363    *)
   364     in
   365       (* ~j because new asm starts at back, thus we subtract 1 *)
   366       Seq.map (Thm.rotate_rule (~j) ((Thm.nprems_of rule) + i))
   367       (Tactic.dtac preelimrule i th2)
   368 
   369       (* (Thm.bicompose
   370                  false (* use unification *)
   371                  (true, (* elim resolution *)
   372                   elimrule, (2 + (Thm.nprems_of rule)) - ngoalprems)
   373                  i th2) *)
   374     end;
   375 
   376 
   377 (* prepare to substitute within the j'th premise of subgoal i of gth,
   378 using a meta-level equation. Note that we assume rule has var indicies
   379 zero'd. Note that we also assume that premt is the j'th premice of
   380 subgoal i of gth. Note the repetition of work done for each
   381 assumption, i.e. this can be made more efficient for search over
   382 multiple assumptions.  *)
   383 fun prep_subst_in_asm ctxt i gth j =
   384     let
   385       val th = Thm.incr_indexes 1 gth;
   386       val tgt_term = Thm.prop_of th;
   387 
   388       val sgn = Thm.theory_of_thm th;
   389       val ctermify = Thm.cterm_of sgn;
   390       val trivify = Thm.trivial o ctermify;
   391 
   392       val (fixedbody, fvs) = IsaND.fix_alls_term ctxt i tgt_term;
   393       val cfvs = rev (map ctermify fvs);
   394 
   395       val asmt = nth (Logic.strip_imp_prems fixedbody) (j - 1);
   396       val asm_nprems = length (Logic.strip_imp_prems asmt);
   397 
   398       val pth = trivify asmt;
   399       val maxidx = Thm.maxidx_of th;
   400 
   401       val ft = ((Zipper.move_down_right (* trueprop *)
   402                  o Zipper.mktop
   403                  o Thm.prop_of) pth)
   404     in ((cfvs, j, asm_nprems, pth), (sgn, maxidx, ft)) end;
   405 
   406 (* prepare subst in every possible assumption *)
   407 fun prep_subst_in_asms ctxt i gth =
   408     map (prep_subst_in_asm ctxt i gth)
   409         ((fn l => Library.upto (1, length l))
   410            (Logic.prems_of_goal (Thm.prop_of gth) i));
   411 
   412 
   413 (* substitute in an assumption using an object or meta level equality *)
   414 fun eqsubst_asm_tac' ctxt searchf skipocc instepthm i th =
   415     let
   416       val asmpreps = prep_subst_in_asms ctxt i th;
   417       val stepthms = Seq.of_list (prep_meta_eq ctxt instepthm);
   418       fun rewrite_with_thm r =
   419           let val (lhs,_) = Logic.dest_equals (Thm.concl_of r)
   420             fun occ_search occ [] = Seq.empty
   421               | occ_search occ ((asminfo, searchinfo)::moreasms) =
   422                 (case searchf searchinfo occ lhs of
   423                    SkipMore i => occ_search i moreasms
   424                  | SkipSeq ss =>
   425                    Seq.append (Seq.map (Library.pair asminfo) (Seq.flat ss))
   426                                (occ_search 1 moreasms))
   427                               (* find later substs also *)
   428           in
   429             occ_search skipocc asmpreps |> Seq.maps (apply_subst_in_asm ctxt i th r)
   430           end;
   431     in stepthms |> Seq.maps rewrite_with_thm end;
   432 
   433 
   434 fun skip_first_asm_occs_search searchf sinfo occ lhs =
   435     skipto_skipseq occ (searchf sinfo lhs);
   436 
   437 fun eqsubst_asm_tac ctxt occL thms i th =
   438     let val nprems = Thm.nprems_of th
   439     in
   440       if nprems < i then Seq.empty else
   441       let val thmseq = (Seq.of_list thms)
   442         fun apply_occ occK th =
   443             thmseq |> Seq.maps
   444                     (fn r =>
   445                         eqsubst_asm_tac' ctxt (skip_first_asm_occs_search
   446                                             searchf_lr_unify_valid) occK r
   447                                          (i + ((Thm.nprems_of th) - nprems))
   448                                          th);
   449         val sortedoccs =
   450             Library.sort (Library.rev_order o Library.int_ord) occL
   451       in
   452         Seq.map distinct_subgoals
   453                 (Seq.EVERY (map apply_occ sortedoccs) th)
   454       end
   455     end;
   456 
   457 (* inthms are the given arguments in Isar, and treated as eqstep with
   458    the first one, then the second etc *)
   459 fun eqsubst_asm_meth ctxt occL inthms =
   460     SIMPLE_METHOD' (eqsubst_asm_tac ctxt occL inthms);
   461 
   462 (* combination method that takes a flag (true indicates that subst
   463    should be done to an assumption, false = apply to the conclusion of
   464    the goal) as well as the theorems to use *)
   465 val setup =
   466   Method.setup @{binding subst}
   467     (Args.mode "asm" -- Scan.lift (Scan.optional (Args.parens (Scan.repeat Parse.nat)) [0]) --
   468         Attrib.thms >>
   469       (fn ((asm, occL), inthms) => fn ctxt =>
   470         (if asm then eqsubst_asm_meth else eqsubst_meth) ctxt occL inthms))
   471     "single-step substitution";
   472 
   473 end;