src/Tools/eqsubst.ML
author wenzelm
Thu Oct 30 16:55:29 2014 +0100 (2014-10-30)
changeset 58838 59203adfc33f
parent 58826 2ed2eaabe3df
child 58950 d07464875dd4
permissions -rw-r--r--
eliminated aliases;
     1 (*  Title:      Tools/eqsubst.ML
     2     Author:     Lucas Dixon, University of Edinburgh
     3 
     4 Perform a substitution using an equation.
     5 *)
     6 
     7 signature EQSUBST =
     8 sig
     9   type match =
    10     ((indexname * (sort * typ)) list (* type instantiations *)
    11       * (indexname * (typ * term)) list) (* term instantiations *)
    12     * (string * typ) list (* fake named type abs env *)
    13     * (string * typ) list (* type abs env *)
    14     * term (* outer term *)
    15 
    16   type searchinfo =
    17     theory
    18     * int (* maxidx *)
    19     * Zipper.T (* focusterm to search under *)
    20 
    21   datatype 'a skipseq = SkipMore of int | SkipSeq of 'a Seq.seq Seq.seq
    22 
    23   val skip_first_asm_occs_search: ('a -> 'b -> 'c Seq.seq Seq.seq) -> 'a -> int -> 'b -> 'c skipseq
    24   val skip_first_occs_search: int -> ('a -> 'b -> 'c Seq.seq Seq.seq) -> 'a -> 'b -> 'c Seq.seq
    25   val skipto_skipseq: int -> 'a Seq.seq Seq.seq -> 'a skipseq
    26 
    27   (* tactics *)
    28   val eqsubst_asm_tac: Proof.context -> int list -> thm list -> int -> tactic
    29   val eqsubst_asm_tac': Proof.context ->
    30     (searchinfo -> int -> term -> match skipseq) -> int -> thm -> int -> tactic
    31   val eqsubst_tac: Proof.context ->
    32     int list -> (* list of occurrences to rewrite, use [0] for any *)
    33     thm list -> int -> tactic
    34   val eqsubst_tac': Proof.context ->
    35     (searchinfo -> term -> match Seq.seq) (* search function *)
    36     -> thm (* equation theorem to rewrite with *)
    37     -> int (* subgoal number in goal theorem *)
    38     -> thm (* goal theorem *)
    39     -> thm Seq.seq (* rewritten goal theorem *)
    40 
    41   (* search for substitutions *)
    42   val valid_match_start: Zipper.T -> bool
    43   val search_lr_all: Zipper.T -> Zipper.T Seq.seq
    44   val search_lr_valid: (Zipper.T -> bool) -> Zipper.T -> Zipper.T Seq.seq
    45   val searchf_lr_unify_all: searchinfo -> term -> match Seq.seq Seq.seq
    46   val searchf_lr_unify_valid: searchinfo -> term -> match Seq.seq Seq.seq
    47   val searchf_bt_unify_valid: searchinfo -> term -> match Seq.seq Seq.seq
    48 end;
    49 
    50 structure EqSubst: EQSUBST =
    51 struct
    52 
    53 (* changes object "=" to meta "==" which prepares a given rewrite rule *)
    54 fun prep_meta_eq ctxt =
    55   Simplifier.mksimps ctxt #> map Drule.zero_var_indexes;
    56 
    57 (* make free vars into schematic vars with index zero *)
    58 fun unfix_frees frees =
    59    fold (K (Thm.forall_elim_var 0)) frees o Drule.forall_intr_list frees;
    60 
    61 
    62 type match =
    63   ((indexname * (sort * typ)) list (* type instantiations *)
    64    * (indexname * (typ * term)) list) (* term instantiations *)
    65   * (string * typ) list (* fake named type abs env *)
    66   * (string * typ) list (* type abs env *)
    67   * term; (* outer term *)
    68 
    69 type searchinfo =
    70   theory
    71   * int (* maxidx *)
    72   * Zipper.T; (* focusterm to search under *)
    73 
    74 
    75 (* skipping non-empty sub-sequences but when we reach the end
    76    of the seq, remembering how much we have left to skip. *)
    77 datatype 'a skipseq =
    78   SkipMore of int |
    79   SkipSeq of 'a Seq.seq Seq.seq;
    80 
    81 (* given a seqseq, skip the first m non-empty seq's, note deficit *)
    82 fun skipto_skipseq m s =
    83   let
    84     fun skip_occs n sq =
    85       (case Seq.pull sq of
    86         NONE => SkipMore n
    87       | SOME (h, t) =>
    88         (case Seq.pull h of
    89           NONE => skip_occs n t
    90         | SOME _ => if n <= 1 then SkipSeq (Seq.cons h t) else skip_occs (n - 1) t))
    91   in skip_occs m s end;
    92 
    93 (* note: outerterm is the taget with the match replaced by a bound
    94    variable : ie: "P lhs" beocmes "%x. P x"
    95    insts is the types of instantiations of vars in lhs
    96    and typinsts is the type instantiations of types in the lhs
    97    Note: Final rule is the rule lifted into the ontext of the
    98    taget thm. *)
    99 fun mk_foo_match mkuptermfunc Ts t =
   100   let
   101     val ty = Term.type_of t
   102     val bigtype = rev (map snd Ts) ---> ty
   103     fun mk_foo 0 t = t
   104       | mk_foo i t = mk_foo (i - 1) (t $ (Bound (i - 1)))
   105     val num_of_bnds = length Ts
   106     (* foo_term = "fooabs y0 ... yn" where y's are local bounds *)
   107     val foo_term = mk_foo num_of_bnds (Bound num_of_bnds)
   108   in Abs ("fooabs", bigtype, mkuptermfunc foo_term) end;
   109 
   110 (* T is outer bound vars, n is number of locally bound vars *)
   111 (* THINK: is order of Ts correct...? or reversed? *)
   112 fun mk_fake_bound_name n = ":b_" ^ n;
   113 fun fakefree_badbounds Ts t =
   114   let val (FakeTs, Ts, newnames) =
   115     fold_rev (fn (n, ty) => fn (FakeTs, Ts, usednames) =>
   116       let
   117         val newname = singleton (Name.variant_list usednames) n
   118       in
   119         ((mk_fake_bound_name newname, ty) :: FakeTs,
   120           (newname, ty) :: Ts,
   121           newname :: usednames)
   122       end) Ts ([], [], [])
   123   in (FakeTs, Ts, Term.subst_bounds (map Free FakeTs, t)) end;
   124 
   125 (* before matching we need to fake the bound vars that are missing an
   126    abstraction. In this function we additionally construct the
   127    abstraction environment, and an outer context term (with the focus
   128    abstracted out) for use in rewriting with RW_Inst.rw *)
   129 fun prep_zipper_match z =
   130   let
   131     val t = Zipper.trm z
   132     val c = Zipper.ctxt z
   133     val Ts = Zipper.C.nty_ctxt c
   134     val (FakeTs', Ts', t') = fakefree_badbounds Ts t
   135     val absterm = mk_foo_match (Zipper.C.apply c) Ts' t'
   136   in
   137     (t', (FakeTs', Ts', absterm))
   138   end;
   139 
   140 (* Unification with exception handled *)
   141 (* given theory, max var index, pat, tgt; returns Seq of instantiations *)
   142 fun clean_unify thy ix (a as (pat, tgt)) =
   143   let
   144     (* type info will be re-derived, maybe this can be cached
   145        for efficiency? *)
   146     val pat_ty = Term.type_of pat;
   147     val tgt_ty = Term.type_of tgt;
   148     (* FIXME is it OK to ignore the type instantiation info?
   149        or should I be using it? *)
   150     val typs_unify =
   151       SOME (Sign.typ_unify thy (pat_ty, tgt_ty) (Vartab.empty, ix))
   152         handle Type.TUNIFY => NONE;
   153   in
   154     (case typs_unify of
   155       SOME (typinsttab, ix2) =>
   156         let
   157           (* FIXME is it right to throw away the flexes?
   158              or should I be using them somehow? *)
   159           fun mk_insts env =
   160             (Vartab.dest (Envir.type_env env),
   161              Vartab.dest (Envir.term_env env));
   162           val initenv =
   163             Envir.Envir {maxidx = ix2, tenv = Vartab.empty, tyenv = typinsttab};
   164           val useq = Unify.smash_unifiers thy [a] initenv
   165             handle ListPair.UnequalLengths => Seq.empty
   166               | Term.TERM _ => Seq.empty;
   167           fun clean_unify' useq () =
   168             (case (Seq.pull useq) of
   169                NONE => NONE
   170              | SOME (h, t) => SOME (mk_insts h, Seq.make (clean_unify' t)))
   171             handle ListPair.UnequalLengths => NONE
   172               | Term.TERM _ => NONE;
   173         in
   174           (Seq.make (clean_unify' useq))
   175         end
   176     | NONE => Seq.empty)
   177   end;
   178 
   179 (* Unification for zippers *)
   180 (* Note: Ts is a modified version of the original names of the outer
   181    bound variables. New names have been introduced to make sure they are
   182    unique w.r.t all names in the term and each other. usednames' is
   183    oldnames + new names. *)
   184 fun clean_unify_z thy maxidx pat z =
   185   let val (t, (FakeTs, Ts, absterm)) = prep_zipper_match z in
   186     Seq.map (fn insts => (insts, FakeTs, Ts, absterm))
   187       (clean_unify thy maxidx (t, pat))
   188   end;
   189 
   190 
   191 fun bot_left_leaf_of (l $ _) = bot_left_leaf_of l
   192   | bot_left_leaf_of (Abs (_, _, t)) = bot_left_leaf_of t
   193   | bot_left_leaf_of x = x;
   194 
   195 (* Avoid considering replacing terms which have a var at the head as
   196    they always succeed trivially, and uninterestingly. *)
   197 fun valid_match_start z =
   198   (case bot_left_leaf_of (Zipper.trm z) of
   199     Var _ => false
   200   | _ => true);
   201 
   202 (* search from top, left to right, then down *)
   203 val search_lr_all = ZipperSearch.all_bl_ur;
   204 
   205 (* search from top, left to right, then down *)
   206 fun search_lr_valid validf =
   207   let
   208     fun sf_valid_td_lr z =
   209       let val here = if validf z then [Zipper.Here z] else [] in
   210         (case Zipper.trm z of
   211           _ $ _ =>
   212             [Zipper.LookIn (Zipper.move_down_left z)] @ here @
   213             [Zipper.LookIn (Zipper.move_down_right z)]
   214         | Abs _ => here @ [Zipper.LookIn (Zipper.move_down_abs z)]
   215         | _ => here)
   216       end;
   217   in Zipper.lzy_search sf_valid_td_lr end;
   218 
   219 (* search from bottom to top, left to right *)
   220 fun search_bt_valid validf =
   221   let
   222     fun sf_valid_td_lr z =
   223       let val here = if validf z then [Zipper.Here z] else [] in
   224         (case Zipper.trm z of
   225           _ $ _ =>
   226             [Zipper.LookIn (Zipper.move_down_left z),
   227              Zipper.LookIn (Zipper.move_down_right z)] @ here
   228         | Abs _ => [Zipper.LookIn (Zipper.move_down_abs z)] @ here
   229         | _ => here)
   230       end;
   231   in Zipper.lzy_search sf_valid_td_lr end;
   232 
   233 fun searchf_unify_gen f (thy, maxidx, z) lhs =
   234   Seq.map (clean_unify_z thy maxidx lhs) (Zipper.limit_apply f z);
   235 
   236 (* search all unifications *)
   237 val searchf_lr_unify_all = searchf_unify_gen search_lr_all;
   238 
   239 (* search only for 'valid' unifiers (non abs subterms and non vars) *)
   240 val searchf_lr_unify_valid = searchf_unify_gen (search_lr_valid valid_match_start);
   241 
   242 val searchf_bt_unify_valid = searchf_unify_gen (search_bt_valid valid_match_start);
   243 
   244 (* apply a substitution in the conclusion of the theorem *)
   245 (* cfvs are certified free var placeholders for goal params *)
   246 (* conclthm is a theorem of for just the conclusion *)
   247 (* m is instantiation/match information *)
   248 (* rule is the equation for substitution *)
   249 fun apply_subst_in_concl ctxt i st (cfvs, conclthm) rule m =
   250   RW_Inst.rw ctxt m rule conclthm
   251   |> unfix_frees cfvs
   252   |> Conv.fconv_rule Drule.beta_eta_conversion
   253   |> (fn r => resolve_tac [r] i st);
   254 
   255 (* substitute within the conclusion of goal i of gth, using a meta
   256 equation rule. Note that we assume rule has var indicies zero'd *)
   257 fun prep_concl_subst ctxt i gth =
   258   let
   259     val th = Thm.incr_indexes 1 gth;
   260     val tgt_term = Thm.prop_of th;
   261 
   262     val thy = Thm.theory_of_thm th;
   263     val cert = Thm.cterm_of thy;
   264 
   265     val (fixedbody, fvs) = IsaND.fix_alls_term ctxt i tgt_term;
   266     val cfvs = rev (map cert fvs);
   267 
   268     val conclterm = Logic.strip_imp_concl fixedbody;
   269     val conclthm = Thm.trivial (cert conclterm);
   270     val maxidx = Thm.maxidx_of th;
   271     val ft =
   272       (Zipper.move_down_right (* ==> *)
   273        o Zipper.move_down_left (* Trueprop *)
   274        o Zipper.mktop
   275        o Thm.prop_of) conclthm
   276   in
   277     ((cfvs, conclthm), (thy, maxidx, ft))
   278   end;
   279 
   280 (* substitute using an object or meta level equality *)
   281 fun eqsubst_tac' ctxt searchf instepthm i st =
   282   let
   283     val (cvfsconclthm, searchinfo) = prep_concl_subst ctxt i st;
   284     val stepthms = Seq.of_list (prep_meta_eq ctxt instepthm);
   285     fun rewrite_with_thm r =
   286       let val (lhs,_) = Logic.dest_equals (Thm.concl_of r) in
   287         searchf searchinfo lhs
   288         |> Seq.maps (apply_subst_in_concl ctxt i st cvfsconclthm r)
   289       end;
   290   in stepthms |> Seq.maps rewrite_with_thm end;
   291 
   292 
   293 (* General substitution of multiple occurrences using one of
   294    the given theorems *)
   295 
   296 fun skip_first_occs_search occ srchf sinfo lhs =
   297   (case skipto_skipseq occ (srchf sinfo lhs) of
   298     SkipMore _ => Seq.empty
   299   | SkipSeq ss => Seq.flat ss);
   300 
   301 (* The "occs" argument is a list of integers indicating which occurrence
   302 w.r.t. the search order, to rewrite. Backtracking will also find later
   303 occurrences, but all earlier ones are skipped. Thus you can use [0] to
   304 just find all rewrites. *)
   305 
   306 fun eqsubst_tac ctxt occs thms i st =
   307   let val nprems = Thm.nprems_of st in
   308     if nprems < i then Seq.empty else
   309     let
   310       val thmseq = Seq.of_list thms;
   311       fun apply_occ occ st =
   312         thmseq |> Seq.maps (fn r =>
   313           eqsubst_tac' ctxt
   314             (skip_first_occs_search occ searchf_lr_unify_valid) r
   315             (i + (Thm.nprems_of st - nprems)) st);
   316       val sorted_occs = Library.sort (rev_order o int_ord) occs;
   317     in
   318       Seq.maps distinct_subgoals_tac (Seq.EVERY (map apply_occ sorted_occs) st)
   319     end
   320   end;
   321 
   322 
   323 (* apply a substitution inside assumption j, keeps asm in the same place *)
   324 fun apply_subst_in_asm ctxt i st rule ((cfvs, j, _, pth),m) =
   325   let
   326     val st2 = Thm.rotate_rule (j - 1) i st; (* put premice first *)
   327     val preelimrule =
   328       RW_Inst.rw ctxt m rule pth
   329       |> (Seq.hd o prune_params_tac ctxt)
   330       |> Thm.permute_prems 0 ~1 (* put old asm first *)
   331       |> unfix_frees cfvs (* unfix any global params *)
   332       |> Conv.fconv_rule Drule.beta_eta_conversion; (* normal form *)
   333   in
   334     (* ~j because new asm starts at back, thus we subtract 1 *)
   335     Seq.map (Thm.rotate_rule (~ j) (Thm.nprems_of rule + i)) (dresolve_tac [preelimrule] i st2)
   336   end;
   337 
   338 
   339 (* prepare to substitute within the j'th premise of subgoal i of gth,
   340 using a meta-level equation. Note that we assume rule has var indicies
   341 zero'd. Note that we also assume that premt is the j'th premice of
   342 subgoal i of gth. Note the repetition of work done for each
   343 assumption, i.e. this can be made more efficient for search over
   344 multiple assumptions.  *)
   345 fun prep_subst_in_asm ctxt i gth j =
   346   let
   347     val th = Thm.incr_indexes 1 gth;
   348     val tgt_term = Thm.prop_of th;
   349 
   350     val thy = Thm.theory_of_thm th;
   351     val cert = Thm.cterm_of thy;
   352 
   353     val (fixedbody, fvs) = IsaND.fix_alls_term ctxt i tgt_term;
   354     val cfvs = rev (map cert fvs);
   355 
   356     val asmt = nth (Logic.strip_imp_prems fixedbody) (j - 1);
   357     val asm_nprems = length (Logic.strip_imp_prems asmt);
   358 
   359     val pth = Thm.trivial (cert asmt);
   360     val maxidx = Thm.maxidx_of th;
   361 
   362     val ft =
   363       (Zipper.move_down_right (* trueprop *)
   364          o Zipper.mktop
   365          o Thm.prop_of) pth
   366   in ((cfvs, j, asm_nprems, pth), (thy, maxidx, ft)) end;
   367 
   368 (* prepare subst in every possible assumption *)
   369 fun prep_subst_in_asms ctxt i gth =
   370   map (prep_subst_in_asm ctxt i gth)
   371     ((fn l => Library.upto (1, length l))
   372       (Logic.prems_of_goal (Thm.prop_of gth) i));
   373 
   374 
   375 (* substitute in an assumption using an object or meta level equality *)
   376 fun eqsubst_asm_tac' ctxt searchf skipocc instepthm i st =
   377   let
   378     val asmpreps = prep_subst_in_asms ctxt i st;
   379     val stepthms = Seq.of_list (prep_meta_eq ctxt instepthm);
   380     fun rewrite_with_thm r =
   381       let
   382         val (lhs,_) = Logic.dest_equals (Thm.concl_of r);
   383         fun occ_search occ [] = Seq.empty
   384           | occ_search occ ((asminfo, searchinfo)::moreasms) =
   385               (case searchf searchinfo occ lhs of
   386                 SkipMore i => occ_search i moreasms
   387               | SkipSeq ss =>
   388                   Seq.append (Seq.map (Library.pair asminfo) (Seq.flat ss))
   389                     (occ_search 1 moreasms)) (* find later substs also *)
   390       in
   391         occ_search skipocc asmpreps |> Seq.maps (apply_subst_in_asm ctxt i st r)
   392       end;
   393   in stepthms |> Seq.maps rewrite_with_thm end;
   394 
   395 
   396 fun skip_first_asm_occs_search searchf sinfo occ lhs =
   397   skipto_skipseq occ (searchf sinfo lhs);
   398 
   399 fun eqsubst_asm_tac ctxt occs thms i st =
   400   let val nprems = Thm.nprems_of st in
   401     if nprems < i then Seq.empty
   402     else
   403       let
   404         val thmseq = Seq.of_list thms;
   405         fun apply_occ occ st =
   406           thmseq |> Seq.maps (fn r =>
   407             eqsubst_asm_tac' ctxt
   408               (skip_first_asm_occs_search searchf_lr_unify_valid) occ r
   409               (i + (Thm.nprems_of st - nprems)) st);
   410         val sorted_occs = Library.sort (rev_order o int_ord) occs;
   411       in
   412         Seq.maps distinct_subgoals_tac (Seq.EVERY (map apply_occ sorted_occs) st)
   413       end
   414   end;
   415 
   416 (* combination method that takes a flag (true indicates that subst
   417    should be done to an assumption, false = apply to the conclusion of
   418    the goal) as well as the theorems to use *)
   419 val _ =
   420   Theory.setup
   421     (Method.setup @{binding subst}
   422       (Scan.lift (Args.mode "asm" -- Scan.optional (Args.parens (Scan.repeat Parse.nat)) [0]) --
   423         Attrib.thms >> (fn ((asm, occs), inthms) => fn ctxt =>
   424           SIMPLE_METHOD' ((if asm then eqsubst_asm_tac else eqsubst_tac) ctxt occs inthms)))
   425       "single-step substitution");
   426 
   427 end;