src/Tools/eqsubst.ML
author wenzelm
Fri Mar 06 15:58:56 2015 +0100 (2015-03-06)
changeset 59621 291934bac95e
parent 59498 50b60f501b05
child 59642 929984c529d3
permissions -rw-r--r--
Thm.cterm_of and Thm.ctyp_of operate on local context;
     1 (*  Title:      Tools/eqsubst.ML
     2     Author:     Lucas Dixon, University of Edinburgh
     3 
     4 Perform a substitution using an equation.
     5 *)
     6 
(* Interface of the single-step equational substitution tool: the data
   exchanged between search and rewriting, the occurrence-skipping helpers,
   the tactics, and the search functions over term zippers. *)
signature EQSUBST =
sig
  (* one unification result at a focus position, together with the
     bound-variable environments and the abstracted outer term *)
  type match =
    ((indexname * (sort * typ)) list (* type instantiations *)
      * (indexname * (typ * term)) list) (* term instantiations *)
    * (string * typ) list (* fake named type abs env *)
    * (string * typ) list (* type abs env *)
    * term (* outer term *)

  (* what a search function needs to know about the term being searched *)
  type searchinfo =
    theory
    * int (* maxidx *)
    * Zipper.T (* focusterm to search under *)

  (* result of skipping occurrences: either the sequence ended and a count
     remains (SkipMore), or the remaining sequence of sequences (SkipSeq) *)
  datatype 'a skipseq = SkipMore of int | SkipSeq of 'a Seq.seq Seq.seq

  (* occurrence-skipping wrappers around a search function *)
  val skip_first_asm_occs_search: ('a -> 'b -> 'c Seq.seq Seq.seq) -> 'a -> int -> 'b -> 'c skipseq
  val skip_first_occs_search: int -> ('a -> 'b -> 'c Seq.seq Seq.seq) -> 'a -> 'b -> 'c Seq.seq
  val skipto_skipseq: int -> 'a Seq.seq Seq.seq -> 'a skipseq

  (* tactics *)
  val eqsubst_asm_tac: Proof.context -> int list -> thm list -> int -> tactic
  val eqsubst_asm_tac': Proof.context ->
    (searchinfo -> int -> term -> match skipseq) -> int -> thm -> int -> tactic
  val eqsubst_tac: Proof.context ->
    int list -> (* list of occurrences to rewrite, use [0] for any *)
    thm list -> int -> tactic
  val eqsubst_tac': Proof.context ->
    (searchinfo -> term -> match Seq.seq) (* search function *)
    -> thm (* equation theorem to rewrite with *)
    -> int (* subgoal number in goal theorem *)
    -> thm (* goal theorem *)
    -> thm Seq.seq (* rewritten goal theorem *)

  (* search for substitutions *)
  (* valid_match_start: false iff the bottom-left leaf of the focus is a Var *)
  val valid_match_start: Zipper.T -> bool
  val search_lr_all: Zipper.T -> Zipper.T Seq.seq
  val search_lr_valid: (Zipper.T -> bool) -> Zipper.T -> Zipper.T Seq.seq
  (* each inner sequence holds the unifiers found at one searched position *)
  val searchf_lr_unify_all: searchinfo -> term -> match Seq.seq Seq.seq
  val searchf_lr_unify_valid: searchinfo -> term -> match Seq.seq Seq.seq
  val searchf_bt_unify_valid: searchinfo -> term -> match Seq.seq Seq.seq
end;
    49 
    50 structure EqSubst: EQSUBST =
    51 struct
    52 
    53 (* changes object "=" to meta "==" which prepares a given rewrite rule *)
    54 fun prep_meta_eq ctxt =
    55   Simplifier.mksimps ctxt #> map Drule.zero_var_indexes;
    56 
    57 (* make free vars into schematic vars with index zero *)
    58 fun unfix_frees frees =
    59    fold (K (Thm.forall_elim_var 0)) frees o Drule.forall_intr_list frees;
    60 
    61 
(* One unification result at a focus position, in the shape assembled by
   clean_unify_z: the (type, term) instantiations from clean_unify, the two
   name/type environments produced by fakefree_badbounds, and the abstracted
   outer term built by mk_foo_match. *)
type match =
  ((indexname * (sort * typ)) list (* type instantiations *)
   * (indexname * (typ * term)) list) (* term instantiations *)
  * (string * typ) list (* fake named type abs env *)
  * (string * typ) list (* type abs env *)
  * term; (* outer term *)
    68 
(* Input handed to the search functions (see searchf_unify_gen): the theory
   and maximal variable index used for unification, and the zipper position
   below which occurrences are looked for. *)
type searchinfo =
  theory
  * int (* maxidx *)
  * Zipper.T; (* focusterm to search under *)
    73 
    74 
(* Result of skipping non-empty sub-sequences of a sequence of sequences:
   SkipMore n means the end of the sequence was reached and n is the skip
   count still outstanding (see skipto_skipseq); SkipSeq ss is the remaining
   sequence of sequences, starting at the first one that was not skipped. *)
datatype 'a skipseq =
  SkipMore of int |
  SkipSeq of 'a Seq.seq Seq.seq;
    80 
    81 (* given a seqseq, skip the first m non-empty seq's, note deficit *)
    82 fun skipto_skipseq m s =
    83   let
    84     fun skip_occs n sq =
    85       (case Seq.pull sq of
    86         NONE => SkipMore n
    87       | SOME (h, t) =>
    88         (case Seq.pull h of
    89           NONE => skip_occs n t
    90         | SOME _ => if n <= 1 then SkipSeq (Seq.cons h t) else skip_occs (n - 1) t))
    91   in skip_occs m s end;
    92 
    93 (* note: outerterm is the taget with the match replaced by a bound
    94    variable : ie: "P lhs" beocmes "%x. P x"
    95    insts is the types of instantiations of vars in lhs
    96    and typinsts is the type instantiations of types in the lhs
    97    Note: Final rule is the rule lifted into the ontext of the
    98    taget thm. *)
    99 fun mk_foo_match mkuptermfunc Ts t =
   100   let
   101     val ty = Term.type_of t
   102     val bigtype = rev (map snd Ts) ---> ty
   103     fun mk_foo 0 t = t
   104       | mk_foo i t = mk_foo (i - 1) (t $ (Bound (i - 1)))
   105     val num_of_bnds = length Ts
   106     (* foo_term = "fooabs y0 ... yn" where y's are local bounds *)
   107     val foo_term = mk_foo num_of_bnds (Bound num_of_bnds)
   108   in Abs ("fooabs", bigtype, mkuptermfunc foo_term) end;
   109 
   110 (* T is outer bound vars, n is number of locally bound vars *)
   111 (* THINK: is order of Ts correct...? or reversed? *)
   112 fun mk_fake_bound_name n = ":b_" ^ n;
   113 fun fakefree_badbounds Ts t =
   114   let val (FakeTs, Ts, newnames) =
   115     fold_rev (fn (n, ty) => fn (FakeTs, Ts, usednames) =>
   116       let
   117         val newname = singleton (Name.variant_list usednames) n
   118       in
   119         ((mk_fake_bound_name newname, ty) :: FakeTs,
   120           (newname, ty) :: Ts,
   121           newname :: usednames)
   122       end) Ts ([], [], [])
   123   in (FakeTs, Ts, Term.subst_bounds (map Free FakeTs, t)) end;
   124 
   125 (* before matching we need to fake the bound vars that are missing an
   126    abstraction. In this function we additionally construct the
   127    abstraction environment, and an outer context term (with the focus
   128    abstracted out) for use in rewriting with RW_Inst.rw *)
   129 fun prep_zipper_match z =
   130   let
   131     val t = Zipper.trm z
   132     val c = Zipper.ctxt z
   133     val Ts = Zipper.C.nty_ctxt c
   134     val (FakeTs', Ts', t') = fakefree_badbounds Ts t
   135     val absterm = mk_foo_match (Zipper.C.apply c) Ts' t'
   136   in
   137     (t', (FakeTs', Ts', absterm))
   138   end;
   139 
   140 (* Unification with exception handled *)
   141 (* given theory, max var index, pat, tgt; returns Seq of instantiations *)
   142 fun clean_unify thy ix (a as (pat, tgt)) =
   143   let
   144     (* type info will be re-derived, maybe this can be cached
   145        for efficiency? *)
   146     val pat_ty = Term.type_of pat;
   147     val tgt_ty = Term.type_of tgt;
   148     (* FIXME is it OK to ignore the type instantiation info?
   149        or should I be using it? *)
   150     val typs_unify =
   151       SOME (Sign.typ_unify thy (pat_ty, tgt_ty) (Vartab.empty, ix))
   152         handle Type.TUNIFY => NONE;
   153   in
   154     (case typs_unify of
   155       SOME (typinsttab, ix2) =>
   156         let
   157           (* FIXME is it right to throw away the flexes?
   158              or should I be using them somehow? *)
   159           fun mk_insts env =
   160             (Vartab.dest (Envir.type_env env),
   161              Vartab.dest (Envir.term_env env));
   162           val initenv =
   163             Envir.Envir {maxidx = ix2, tenv = Vartab.empty, tyenv = typinsttab};
   164           val useq = Unify.smash_unifiers (Context.Theory thy) [a] initenv
   165             handle ListPair.UnequalLengths => Seq.empty
   166               | Term.TERM _ => Seq.empty;
   167           fun clean_unify' useq () =
   168             (case (Seq.pull useq) of
   169                NONE => NONE
   170              | SOME (h, t) => SOME (mk_insts h, Seq.make (clean_unify' t)))
   171             handle ListPair.UnequalLengths => NONE
   172               | Term.TERM _ => NONE;
   173         in
   174           (Seq.make (clean_unify' useq))
   175         end
   176     | NONE => Seq.empty)
   177   end;
   178 
   179 (* Unification for zippers *)
   180 (* Note: Ts is a modified version of the original names of the outer
   181    bound variables. New names have been introduced to make sure they are
   182    unique w.r.t all names in the term and each other. usednames' is
   183    oldnames + new names. *)
   184 fun clean_unify_z thy maxidx pat z =
   185   let val (t, (FakeTs, Ts, absterm)) = prep_zipper_match z in
   186     Seq.map (fn insts => (insts, FakeTs, Ts, absterm))
   187       (clean_unify thy maxidx (t, pat))
   188   end;
   189 
   190 
   191 fun bot_left_leaf_of (l $ _) = bot_left_leaf_of l
   192   | bot_left_leaf_of (Abs (_, _, t)) = bot_left_leaf_of t
   193   | bot_left_leaf_of x = x;
   194 
   195 (* Avoid considering replacing terms which have a var at the head as
   196    they always succeed trivially, and uninterestingly. *)
   197 fun valid_match_start z =
   198   (case bot_left_leaf_of (Zipper.trm z) of
   199     Var _ => false
   200   | _ => true);
   201 
   202 (* search from top, left to right, then down *)
   203 val search_lr_all = ZipperSearch.all_bl_ur;
   204 
   205 (* search from top, left to right, then down *)
   206 fun search_lr_valid validf =
   207   let
   208     fun sf_valid_td_lr z =
   209       let val here = if validf z then [Zipper.Here z] else [] in
   210         (case Zipper.trm z of
   211           _ $ _ =>
   212             [Zipper.LookIn (Zipper.move_down_left z)] @ here @
   213             [Zipper.LookIn (Zipper.move_down_right z)]
   214         | Abs _ => here @ [Zipper.LookIn (Zipper.move_down_abs z)]
   215         | _ => here)
   216       end;
   217   in Zipper.lzy_search sf_valid_td_lr end;
   218 
   219 (* search from bottom to top, left to right *)
   220 fun search_bt_valid validf =
   221   let
   222     fun sf_valid_td_lr z =
   223       let val here = if validf z then [Zipper.Here z] else [] in
   224         (case Zipper.trm z of
   225           _ $ _ =>
   226             [Zipper.LookIn (Zipper.move_down_left z),
   227              Zipper.LookIn (Zipper.move_down_right z)] @ here
   228         | Abs _ => [Zipper.LookIn (Zipper.move_down_abs z)] @ here
   229         | _ => here)
   230       end;
   231   in Zipper.lzy_search sf_valid_td_lr end;
   232 
   233 fun searchf_unify_gen f (thy, maxidx, z) lhs =
   234   Seq.map (clean_unify_z thy maxidx lhs) (Zipper.limit_apply f z);
   235 
   236 (* search all unifications *)
   237 val searchf_lr_unify_all = searchf_unify_gen search_lr_all;
   238 
   239 (* search only for 'valid' unifiers (non abs subterms and non vars) *)
   240 val searchf_lr_unify_valid = searchf_unify_gen (search_lr_valid valid_match_start);
   241 
   242 val searchf_bt_unify_valid = searchf_unify_gen (search_bt_valid valid_match_start);
   243 
   244 (* apply a substitution in the conclusion of the theorem *)
   245 (* cfvs are certified free var placeholders for goal params *)
   246 (* conclthm is a theorem of for just the conclusion *)
   247 (* m is instantiation/match information *)
   248 (* rule is the equation for substitution *)
   249 fun apply_subst_in_concl ctxt i st (cfvs, conclthm) rule m =
   250   RW_Inst.rw ctxt m rule conclthm
   251   |> unfix_frees cfvs
   252   |> Conv.fconv_rule Drule.beta_eta_conversion
   253   |> (fn r => resolve_tac ctxt [r] i st);
   254 
   255 (* substitute within the conclusion of goal i of gth, using a meta
   256 equation rule. Note that we assume rule has var indicies zero'd *)
   257 fun prep_concl_subst ctxt i gth =
   258   let
   259     val th = Thm.incr_indexes 1 gth;
   260     val tgt_term = Thm.prop_of th;
   261 
   262     val thy = Thm.theory_of_thm th;
   263     val cert = Thm.global_cterm_of thy;
   264 
   265     val (fixedbody, fvs) = IsaND.fix_alls_term ctxt i tgt_term;
   266     val cfvs = rev (map cert fvs);
   267 
   268     val conclterm = Logic.strip_imp_concl fixedbody;
   269     val conclthm = Thm.trivial (cert conclterm);
   270     val maxidx = Thm.maxidx_of th;
   271     val ft =
   272       (Zipper.move_down_right (* ==> *)
   273        o Zipper.move_down_left (* Trueprop *)
   274        o Zipper.mktop
   275        o Thm.prop_of) conclthm
   276   in
   277     ((cfvs, conclthm), (thy, maxidx, ft))
   278   end;
   279 
   280 (* substitute using an object or meta level equality *)
   281 fun eqsubst_tac' ctxt searchf instepthm i st =
   282   let
   283     val (cvfsconclthm, searchinfo) = prep_concl_subst ctxt i st;
   284     val stepthms = Seq.of_list (prep_meta_eq ctxt instepthm);
   285     fun rewrite_with_thm r =
   286       let val (lhs,_) = Logic.dest_equals (Thm.concl_of r) in
   287         searchf searchinfo lhs
   288         |> Seq.maps (apply_subst_in_concl ctxt i st cvfsconclthm r)
   289       end;
   290   in stepthms |> Seq.maps rewrite_with_thm end;
   291 
   292 
   293 (* General substitution of multiple occurrences using one of
   294    the given theorems *)
   295 
   296 fun skip_first_occs_search occ srchf sinfo lhs =
   297   (case skipto_skipseq occ (srchf sinfo lhs) of
   298     SkipMore _ => Seq.empty
   299   | SkipSeq ss => Seq.flat ss);
   300 
   301 (* The "occs" argument is a list of integers indicating which occurrence
   302 w.r.t. the search order, to rewrite. Backtracking will also find later
   303 occurrences, but all earlier ones are skipped. Thus you can use [0] to
   304 just find all rewrites. *)
   305 
   306 fun eqsubst_tac ctxt occs thms i st =
   307   let val nprems = Thm.nprems_of st in
   308     if nprems < i then Seq.empty else
   309     let
   310       val thmseq = Seq.of_list thms;
   311       fun apply_occ occ st =
   312         thmseq |> Seq.maps (fn r =>
   313           eqsubst_tac' ctxt
   314             (skip_first_occs_search occ searchf_lr_unify_valid) r
   315             (i + (Thm.nprems_of st - nprems)) st);
   316       val sorted_occs = Library.sort (rev_order o int_ord) occs;
   317     in
   318       Seq.maps distinct_subgoals_tac (Seq.EVERY (map apply_occ sorted_occs) st)
   319     end
   320   end;
   321 
   322 
   323 (* apply a substitution inside assumption j, keeps asm in the same place *)
   324 fun apply_subst_in_asm ctxt i st rule ((cfvs, j, _, pth),m) =
   325   let
   326     val st2 = Thm.rotate_rule (j - 1) i st; (* put premice first *)
   327     val preelimrule =
   328       RW_Inst.rw ctxt m rule pth
   329       |> (Seq.hd o prune_params_tac ctxt)
   330       |> Thm.permute_prems 0 ~1 (* put old asm first *)
   331       |> unfix_frees cfvs (* unfix any global params *)
   332       |> Conv.fconv_rule Drule.beta_eta_conversion; (* normal form *)
   333   in
   334     (* ~j because new asm starts at back, thus we subtract 1 *)
   335     Seq.map (Thm.rotate_rule (~ j) (Thm.nprems_of rule + i))
   336       (dresolve_tac ctxt [preelimrule] i st2)
   337   end;
   338 
   339 
   340 (* prepare to substitute within the j'th premise of subgoal i of gth,
   341 using a meta-level equation. Note that we assume rule has var indicies
   342 zero'd. Note that we also assume that premt is the j'th premice of
   343 subgoal i of gth. Note the repetition of work done for each
   344 assumption, i.e. this can be made more efficient for search over
   345 multiple assumptions.  *)
   346 fun prep_subst_in_asm ctxt i gth j =
   347   let
   348     val th = Thm.incr_indexes 1 gth;
   349     val tgt_term = Thm.prop_of th;
   350 
   351     val thy = Thm.theory_of_thm th;
   352     val cert = Thm.global_cterm_of thy;
   353 
   354     val (fixedbody, fvs) = IsaND.fix_alls_term ctxt i tgt_term;
   355     val cfvs = rev (map cert fvs);
   356 
   357     val asmt = nth (Logic.strip_imp_prems fixedbody) (j - 1);
   358     val asm_nprems = length (Logic.strip_imp_prems asmt);
   359 
   360     val pth = Thm.trivial (cert asmt);
   361     val maxidx = Thm.maxidx_of th;
   362 
   363     val ft =
   364       (Zipper.move_down_right (* trueprop *)
   365          o Zipper.mktop
   366          o Thm.prop_of) pth
   367   in ((cfvs, j, asm_nprems, pth), (thy, maxidx, ft)) end;
   368 
   369 (* prepare subst in every possible assumption *)
   370 fun prep_subst_in_asms ctxt i gth =
   371   map (prep_subst_in_asm ctxt i gth)
   372     ((fn l => Library.upto (1, length l))
   373       (Logic.prems_of_goal (Thm.prop_of gth) i));
   374 
   375 
   376 (* substitute in an assumption using an object or meta level equality *)
   377 fun eqsubst_asm_tac' ctxt searchf skipocc instepthm i st =
   378   let
   379     val asmpreps = prep_subst_in_asms ctxt i st;
   380     val stepthms = Seq.of_list (prep_meta_eq ctxt instepthm);
   381     fun rewrite_with_thm r =
   382       let
   383         val (lhs,_) = Logic.dest_equals (Thm.concl_of r);
   384         fun occ_search occ [] = Seq.empty
   385           | occ_search occ ((asminfo, searchinfo)::moreasms) =
   386               (case searchf searchinfo occ lhs of
   387                 SkipMore i => occ_search i moreasms
   388               | SkipSeq ss =>
   389                   Seq.append (Seq.map (Library.pair asminfo) (Seq.flat ss))
   390                     (occ_search 1 moreasms)) (* find later substs also *)
   391       in
   392         occ_search skipocc asmpreps |> Seq.maps (apply_subst_in_asm ctxt i st r)
   393       end;
   394   in stepthms |> Seq.maps rewrite_with_thm end;
   395 
   396 
   397 fun skip_first_asm_occs_search searchf sinfo occ lhs =
   398   skipto_skipseq occ (searchf sinfo lhs);
   399 
   400 fun eqsubst_asm_tac ctxt occs thms i st =
   401   let val nprems = Thm.nprems_of st in
   402     if nprems < i then Seq.empty
   403     else
   404       let
   405         val thmseq = Seq.of_list thms;
   406         fun apply_occ occ st =
   407           thmseq |> Seq.maps (fn r =>
   408             eqsubst_asm_tac' ctxt
   409               (skip_first_asm_occs_search searchf_lr_unify_valid) occ r
   410               (i + (Thm.nprems_of st - nprems)) st);
   411         val sorted_occs = Library.sort (rev_order o int_ord) occs;
   412       in
   413         Seq.maps distinct_subgoals_tac (Seq.EVERY (map apply_occ sorted_occs) st)
   414       end
   415   end;
   416 
   417 (* combination method that takes a flag (true indicates that subst
   418    should be done to an assumption, false = apply to the conclusion of
   419    the goal) as well as the theorems to use *)
   420 val _ =
   421   Theory.setup
   422     (Method.setup @{binding subst}
   423       (Scan.lift (Args.mode "asm" -- Scan.optional (Args.parens (Scan.repeat Parse.nat)) [0]) --
   424         Attrib.thms >> (fn ((asm, occs), inthms) => fn ctxt =>
   425           SIMPLE_METHOD' ((if asm then eqsubst_asm_tac else eqsubst_tac) ctxt occs inthms)))
   426       "single-step substitution");
   427 
   428 end;