src/Tools/eqsubst.ML
author wenzelm
Fri Mar 06 23:56:43 2015 +0100 (2015-03-06)
changeset 59642 929984c529d3
parent 59621 291934bac95e
child 60358 aebfbcab1eb8
permissions -rw-r--r--
clarified context;
     1 (*  Title:      Tools/eqsubst.ML
     2     Author:     Lucas Dixon, University of Edinburgh
     3 
     4 Perform a substitution using an equation.
     5 *)
     6 
     signature EQSUBST =
     sig
       (* One unifier found at a focus position, together with the
          environments needed to rebuild the surrounding term. *)
       type match =
         ((indexname * (sort * typ)) list      (* type instantiations *)
           * (indexname * (typ * term)) list)  (* term instantiations *)
         * (string * typ) list                 (* fake named type abs env *)
         * (string * typ) list                 (* type abs env *)
         * term                                (* outer term *)

       (* Everything a search function needs to know about its target. *)
       type searchinfo =
         theory
         * int        (* maxidx *)
         * Zipper.T   (* focusterm to search under *)

       (* Either the number of occurrences still to be skipped, or the
          remaining results once enough have been skipped. *)
       datatype 'a skipseq = SkipMore of int | SkipSeq of 'a Seq.seq Seq.seq

       (* occurrence skipping *)
       val skip_first_asm_occs_search:
         ('a -> 'b -> 'c Seq.seq Seq.seq) -> 'a -> int -> 'b -> 'c skipseq
       val skip_first_occs_search:
         int -> ('a -> 'b -> 'c Seq.seq Seq.seq) -> 'a -> 'b -> 'c Seq.seq
       val skipto_skipseq: int -> 'a Seq.seq Seq.seq -> 'a skipseq

       (* tactics *)
       val eqsubst_asm_tac: Proof.context -> int list -> thm list -> int -> tactic
       val eqsubst_asm_tac': Proof.context ->
         (searchinfo -> int -> term -> match skipseq) -> int -> thm -> int -> tactic
       val eqsubst_tac: Proof.context ->
         int list ->  (* list of occurrences to rewrite, use [0] for any *)
         thm list -> int -> tactic
       val eqsubst_tac': Proof.context ->
         (searchinfo -> term -> match Seq.seq)  (* search function *)
         -> thm          (* equation theorem to rewrite with *)
         -> int          (* subgoal number in goal theorem *)
         -> thm          (* goal theorem *)
         -> thm Seq.seq  (* rewritten goal theorem *)

       (* search for substitutions *)
       val valid_match_start: Zipper.T -> bool
       val search_lr_all: Zipper.T -> Zipper.T Seq.seq
       val search_lr_valid: (Zipper.T -> bool) -> Zipper.T -> Zipper.T Seq.seq
       val searchf_lr_unify_all: searchinfo -> term -> match Seq.seq Seq.seq
       val searchf_lr_unify_valid: searchinfo -> term -> match Seq.seq Seq.seq
       val searchf_bt_unify_valid: searchinfo -> term -> match Seq.seq Seq.seq
     end;
    49 
    50 structure EqSubst: EQSUBST =
    51 struct
    52 
    (* Prepare a rewrite rule for use as a meta-level equation (object "="
       becomes meta "=="): the rule is passed through Simplifier.mksimps and
       each resulting theorem has its schematic variable indexes zeroed. *)
    fun prep_meta_eq ctxt rule =
      map Drule.zero_var_indexes (Simplifier.mksimps ctxt rule);
    56 
    (* Turn the given free variables of a theorem into schematic variables
       with index zero: first universally quantify over all of them, then
       eliminate one quantifier per variable with a fresh Var of index 0. *)
    fun unfix_frees frees th =
      let val quantified = Drule.forall_intr_list frees th
      in fold (fn _ => Thm.forall_elim_var 0) frees quantified end;
    60 
    61 
    (* A match consists of, in order:
         - the pair of type instantiations and term instantiations,
         - the fake named type abs env,
         - the type abs env,
         - the outer term. *)
    type match =
      ((indexname * (sort * typ)) list * (indexname * (typ * term)) list)
      * (string * typ) list
      * (string * typ) list
      * term;
    68 
    (* Search target: the theory, the maxidx, and the focus term (as a
       zipper) to search under. *)
    type searchinfo = theory * int * Zipper.T;
    73 
    74 
    (* Result of skipping non-empty sub-sequences: when the end of the
       sequence is reached first, SkipMore remembers how much is left to
       skip; otherwise SkipSeq carries the remaining sequence of sequences. *)
    datatype 'a skipseq = SkipMore of int | SkipSeq of 'a Seq.seq Seq.seq;
    80 
    (* Given a sequence of sequences, skip the first m non-empty inner
       sequences and note any deficit.  Empty inner sequences are dropped
       without being counted.  When the count is 1 or less, the first
       non-empty inner sequence found is kept as the head of the result. *)
    fun skipto_skipseq m s =
      let
        fun go remaining outer =
          (case Seq.pull outer of
            NONE => SkipMore remaining
          | SOME (inner, rest) =>
              (case Seq.pull inner of
                NONE => go remaining rest
              | SOME _ =>
                  if remaining > 1 then go (remaining - 1) rest
                  else SkipSeq (Seq.cons inner rest)))
      in go m s end;
    92 
    (* Build the outer term: the target with the match replaced by a bound
       variable, i.e. "P lhs" becomes "%x. P x".
       Ts lists the (name, type) pairs of the enclosing bound variables and t
       is the focus term.  The abstracted variable "fooabs" is applied to
       every local bound, and mkuptermfunc rebuilds the surrounding term
       around that application. *)
    fun mk_foo_match mkuptermfunc Ts t =
      let
        val focus_ty = Term.type_of t;
        val depth = length Ts;
        val abs_ty = rev (map snd Ts) ---> focus_ty;
        (* Bound (depth - 1), ..., Bound 0: the local bounds, outermost first *)
        val local_bounds = List.tabulate (depth, fn k => Bound (depth - 1 - k));
        (* "fooabs y0 ... yn": Bound depth refers to the new abstraction *)
        val hole = List.foldl (fn (b, head) => head $ b) (Bound depth) local_bounds;
      in Abs ("fooabs", abs_ty, mkuptermfunc hole) end;
   109 
   110 (* T is outer bound vars, n is number of locally bound vars *)
   111 (* THINK: is order of Ts correct...? or reversed? *)
   fun mk_fake_bound_name n = String.concat [":b_", n];
   (* Replace the loose bound variables of t by fake free variables.
      Ts lists (name, type) for the enclosing binders.  Each name is made
      distinct from the ones already chosen (processing Ts from the right),
      and the fake free gets that name decorated by mk_fake_bound_name.
      Returns the fake (name, type) list, the renamed (name, type) list and
      the term with its loose bounds substituted by the fake frees. *)
   fun fakefree_badbounds Ts t =
     let
       fun add_binder (n, ty) (fakes, renamed, used) =
         let val fresh = singleton (Name.variant_list used) n in
           ((mk_fake_bound_name fresh, ty) :: fakes,
            (fresh, ty) :: renamed,
            fresh :: used)
         end;
       val (FakeTs, Ts', _) = fold_rev add_binder Ts ([], [], []);
     in (FakeTs, Ts', Term.subst_bounds (map Free FakeTs, t)) end;
   124 
   (* Before matching, the bound variables whose abstraction lies outside the
      focus have to be faked.  This also builds the abstraction environment
      and the outer context term (with the focus abstracted out) that is
      later passed to RW_Inst.rw.
      Returns (faked focus term, (fake env, abs env, outer term)). *)
   fun prep_zipper_match z =
     let
       val zctxt = Zipper.ctxt z;
       val (fake_env, abs_env, focus) =
         fakefree_badbounds (Zipper.C.nty_ctxt zctxt) (Zipper.trm z);
       val outer = mk_foo_match (Zipper.C.apply zctxt) abs_env focus;
     in (focus, (fake_env, abs_env, outer)) end;
   139 
   (* Unification with exceptions handled.
      Given a theory, a max var index and the pair (pat, tgt), returns a lazy
      sequence of instantiations, each a pair of the type-environment entries
      and the term-environment entries of a unifier.
      The sequence is empty when the types of pat and tgt do not unify.
      ListPair.UnequalLengths and Term.TERM raised while creating or pulling
      the unifier sequence end the result instead of propagating. *)
   fun clean_unify thy ix (a as (pat, tgt)) =
     let
       (* type info will be re-derived, maybe this can be cached for
          efficiency? *)
       val pat_ty = Term.type_of pat;
       val tgt_ty = Term.type_of tgt;

       (* FIXME is it right to throw away the flexes?
          or should I be using them somehow? *)
       fun insts_of env =
         (Vartab.dest (Envir.type_env env), Vartab.dest (Envir.term_env env));

       (* pull one unifier at a time, guarding every pull *)
       fun guarded unifiers () =
         (case Seq.pull unifiers of
           SOME (env, rest) => SOME (insts_of env, Seq.make (guarded rest))
         | NONE => NONE)
         handle ListPair.UnequalLengths => NONE
           | Term.TERM _ => NONE;

       fun unify_terms (tyenv, maxidx) =
         let
           val env0 =
             Envir.Envir {maxidx = maxidx, tenv = Vartab.empty, tyenv = tyenv};
           val unifiers =
             Unify.smash_unifiers (Context.Theory thy) [a] env0
               handle ListPair.UnequalLengths => Seq.empty
                 | Term.TERM _ => Seq.empty;
         in Seq.make (guarded unifiers) end;

       (* FIXME is it OK to ignore the type instantiation info?
          or should I be using it? *)
       val typs_unify =
         SOME (Sign.typ_unify thy (pat_ty, tgt_ty) (Vartab.empty, ix))
           handle Type.TUNIFY => NONE;
     in
       (case typs_unify of
         NONE => Seq.empty
       | SOME ty_result => unify_terms ty_result)
     end;
   178 
   (* Unification for zippers.
      The focus of zipper z is prepared with prep_zipper_match and then
      unified with pat; every unifier is paired with the fake environment,
      the abstraction environment and the outer term of that focus.
      Note: the abstraction environment holds modified versions of the
      original names of the outer bound variables; they were renamed so that
      they are distinct from each other. *)
   fun clean_unify_z thy maxidx pat z =
     let
       val (focus, (fake_env, abs_env, outer)) = prep_zipper_match z;
       fun with_envs insts = (insts, fake_env, abs_env, outer);
     in Seq.map with_envs (clean_unify thy maxidx (focus, pat)) end;
   189 
   190 
   (* Follow function positions of applications and bodies of abstractions
      until a term that is neither is reached; return that term. *)
   fun bot_left_leaf_of tm =
     (case tm of
       f $ _ => bot_left_leaf_of f
     | Abs (_, _, body) => bot_left_leaf_of body
     | leaf => leaf);
   194 
   (* Avoid considering replacing terms which have a var at the head, as
      they always succeed trivially and uninterestingly: a focus is valid
      unless its bottom-left leaf is a Var. *)
   fun valid_match_start z =
     let
       fun is_var (Var _) = true
         | is_var _ = false;
     in not (is_var (bot_left_leaf_of (Zipper.trm z))) end;
   201 
   (* Enumerate all positions under a zipper; this simply delegates to
      ZipperSearch.all_bl_ur.  (Original note: search from top, left to
      right, then down.) *)
   fun search_lr_all z = ZipperSearch.all_bl_ur z;
   204 
   (* Lazy search that reports only positions accepted by validf.
      For an application the search items are: left subterm, the position
      itself (if valid), right subterm.  For an abstraction: the position
      itself (if valid), then the body.  Any other term contributes at most
      itself.  (Original note: search from top, left to right, then down.) *)
   fun search_lr_valid validf =
     let
       fun expand z =
         let
           val self = if validf z then [Zipper.Here z] else [];
           fun descend move = Zipper.LookIn (move z);
         in
           (case Zipper.trm z of
             _ $ _ =>
               descend Zipper.move_down_left :: self @ [descend Zipper.move_down_right]
           | Abs _ => self @ [descend Zipper.move_down_abs]
           | _ => self)
         end;
     in Zipper.lzy_search expand end;
   218 
   (* Search from bottom to top, left to right, reporting only positions
      accepted by validf: the subterms of a position are always listed
      before the position itself. *)
   fun search_bt_valid validf =
     let
       fun expand z =
         let
           val self = if validf z then [Zipper.Here z] else [];
           fun descend move = Zipper.LookIn (move z);
         in
           (case Zipper.trm z of
             _ $ _ =>
               descend Zipper.move_down_left :: descend Zipper.move_down_right :: self
           | Abs _ => descend Zipper.move_down_abs :: self
           | _ => self)
         end;
     in Zipper.lzy_search expand end;
   232 
   (* Generic search-and-unify: enumerate positions under z with the search
      function f (through Zipper.limit_apply) and, for each position, produce
      the sequence of unifiers of lhs with that position. *)
   fun searchf_unify_gen f (thy, maxidx, z) lhs =
     let val positions = Zipper.limit_apply f z
     in Seq.map (clean_unify_z thy maxidx lhs) positions end;
   235 
   (* search all unifications, at every position given by search_lr_all *)
   fun searchf_lr_unify_all sinfo lhs =
     searchf_unify_gen search_lr_all sinfo lhs;
   238 
   (* search only for 'valid' unifiers: positions accepted by
      valid_match_start, i.e. whose bottom-left leaf is not a Var *)
   fun searchf_lr_unify_valid sinfo lhs =
     searchf_unify_gen (search_lr_valid valid_match_start) sinfo lhs;
   241 
   (* as searchf_lr_unify_valid, but positions come from search_bt_valid *)
   fun searchf_bt_unify_valid sinfo lhs =
     searchf_unify_gen (search_bt_valid valid_match_start) sinfo lhs;
   243 
   (* Apply a substitution in the conclusion of subgoal i of st.
        cfvs      certified free var placeholders for goal params
        conclthm  a theorem for just the conclusion
        rule      the equation used for substitution
        m         instantiation/match information
      The rewritten conclusion theorem has its placeholders made schematic
      again, is beta-eta normalised, and is then resolved against subgoal i. *)
   fun apply_subst_in_concl ctxt i st (cfvs, conclthm) rule m =
     let
       val rewritten = RW_Inst.rw ctxt m rule conclthm;
       val generalized = unfix_frees cfvs rewritten;
       val normalized = Conv.fconv_rule Drule.beta_eta_conversion generalized;
     in resolve_tac ctxt [normalized] i st end;
   254 
   (* Prepare substitution within the conclusion of goal i of gth using a
      meta equation rule.  Note that the rule is assumed to have its var
      indices zeroed, which is why the goal's indexes are incremented first.
      Returns ((certified fixed params, trivial theorem of the conclusion),
               (theory, maxidx, focus zipper)). *)
   fun prep_concl_subst ctxt i gth =
     let
       val cert = Thm.cterm_of ctxt;
       val th = Thm.incr_indexes 1 gth;

       val (fixedbody, fvs) = IsaND.fix_alls_term ctxt i (Thm.prop_of th);
       val cfvs = rev (map cert fvs);

       val conclthm = Thm.trivial (cert (Logic.strip_imp_concl fixedbody));

       (* focus: from the top of the trivial theorem's proposition, step
          into the left subterm and then into its right subterm *)
       val ft =
         Thm.prop_of conclthm
         |> Zipper.mktop
         |> Zipper.move_down_left
         |> Zipper.move_down_right;
     in
       ((cfvs, conclthm), (Proof_Context.theory_of ctxt, Thm.maxidx_of th, ft))
     end;
   276 
   (* Substitute in the conclusion of subgoal i using an object or meta
      level equality.  instepthm is turned into meta equations; for each of
      them searchf yields the matches of its left-hand side, and every match
      gives rise to the rewritten goal states. *)
   fun eqsubst_tac' ctxt searchf instepthm i st =
     let
       val (concl_info, searchinfo) = prep_concl_subst ctxt i st;
       fun rewrite_with eq =
         let val (lhs, _) = Logic.dest_equals (Thm.concl_of eq) in
           Seq.maps (apply_subst_in_concl ctxt i st concl_info eq)
             (searchf searchinfo lhs)
         end;
     in Seq.maps rewrite_with (Seq.of_list (prep_meta_eq ctxt instepthm)) end;
   288 
   289 
   (* General substitution of multiple occurrences using one of the given
      theorems. *)

   (* Run srchf, skip occurrences according to occ, and flatten what is
      left; the result is empty when the search ran out before enough
      occurrences were skipped. *)
   fun skip_first_occs_search occ srchf sinfo lhs =
     let val skipped = skipto_skipseq occ (srchf sinfo lhs) in
       (case skipped of
         SkipSeq remaining => Seq.flat remaining
       | SkipMore _ => Seq.empty)
     end;
   297 
   (* Substitution in the conclusion of subgoal i.
      The "occs" argument is a list of integers indicating which occurrence,
      w.r.t. the search order, to rewrite.  Backtracking will also find later
      occurrences, but all earlier ones are skipped.  Thus you can use [0] to
      just find all rewrites.
      Occurrences are processed in descending order, and the subgoal index is
      shifted by the number of premises added by earlier steps.  Fails (empty
      sequence) when the state has fewer than i subgoals. *)
   fun eqsubst_tac ctxt occs thms i st =
     let
       val nprems = Thm.nprems_of st;
       fun tac_for occ st' =
         let
           val i' = i + (Thm.nprems_of st' - nprems);
           val searchf = skip_first_occs_search occ searchf_lr_unify_valid;
         in
           Seq.maps (fn r => eqsubst_tac' ctxt searchf r i' st') (Seq.of_list thms)
         end;
       val descending_occs = Library.sort (rev_order o int_ord) occs;
     in
       if nprems < i then Seq.empty
       else
         Seq.maps distinct_subgoals_tac
           (Seq.EVERY (map tac_for descending_occs) st)
     end;
   318 
   319 
   (* Apply a substitution inside assumption j of subgoal i, keeping the
      assumption in the same place.
        rule  the equation used for substitution
        cfvs  certified placeholders for the goal parameters
        pth   trivial theorem of the assumption
        m     instantiation/match information *)
   fun apply_subst_in_asm ctxt i st rule ((cfvs, j, _, pth), m) =
     let
       (* put the premise first *)
       val rotated = Thm.rotate_rule (j - 1) i st;

       val rewritten = RW_Inst.rw ctxt m rule pth;
       val pruned = Seq.hd (prune_params_tac ctxt rewritten);
       (* put old asm first *)
       val reordered = Thm.permute_prems 0 ~1 pruned;
       (* unfix any global params *)
       val generalized = unfix_frees cfvs reordered;
       (* normal form *)
       val preelimrule = Conv.fconv_rule Drule.beta_eta_conversion generalized;

       (* ~j because new asm starts at back, thus we subtract 1 *)
       val restore = Thm.rotate_rule (~ j) (Thm.nprems_of rule + i);
     in Seq.map restore (dresolve_tac ctxt [preelimrule] i rotated) end;
   335 
   336 
   (* Prepare to substitute within the j'th premise of subgoal i of gth,
      using a meta-level equation.  Note that the rule is assumed to have its
      var indices zeroed, which is why the goal's indexes are incremented
      first.  Note the repetition of work done for each assumption, i.e. this
      can be made more efficient for search over multiple assumptions.

      Certification and the theory in the returned searchinfo both come from
      ctxt, exactly as in prep_concl_subst, instead of from the theory stored
      in the goal theorem.

      Returns ((certified fixed params, j, number of premises of the
                assumption, trivial theorem of the assumption),
               (theory, maxidx, focus zipper)).
      Raises whatever nth raises when the subgoal has fewer than j premises. *)
   fun prep_subst_in_asm ctxt i gth j =
     let
       val th = Thm.incr_indexes 1 gth;
       val tgt_term = Thm.prop_of th;

       val cert = Thm.cterm_of ctxt;

       val (fixedbody, fvs) = IsaND.fix_alls_term ctxt i tgt_term;
       val cfvs = rev (map cert fvs);

       val asmt = nth (Logic.strip_imp_prems fixedbody) (j - 1);
       val asm_nprems = length (Logic.strip_imp_prems asmt);

       val pth = Thm.trivial (cert asmt);
       val maxidx = Thm.maxidx_of th;

       (* focus: from the top of the trivial theorem's proposition, step
          into its right subterm *)
       val ft =
         (Zipper.move_down_right
            o Zipper.mktop
            o Thm.prop_of) pth
     in
       ((cfvs, j, asm_nprems, pth), (Proof_Context.theory_of ctxt, maxidx, ft))
     end;
   365 
   (* Prepare substitution in every possible assumption of subgoal i:
      one prep_subst_in_asm result per premise index 1 .. number of premises. *)
   fun prep_subst_in_asms ctxt i gth =
     let
       val asm_count = length (Logic.prems_of_goal (Thm.prop_of gth) i);
     in map (prep_subst_in_asm ctxt i gth) (Library.upto (1, asm_count)) end;
   371 
   372 
   (* Substitute in an assumption of subgoal i using an object or meta level
      equality.  The assumptions are searched in order; skipocc occurrences
      are skipped, and a deficit left over from one assumption is carried on
      to the next.  Once a match has been found, later assumptions are
      searched from their first occurrence so that later substitutions are
      found as well. *)
   fun eqsubst_asm_tac' ctxt searchf skipocc instepthm i st =
     let
       val asmpreps = prep_subst_in_asms ctxt i st;
       fun rewrite_with eq =
         let
           val (lhs, _) = Logic.dest_equals (Thm.concl_of eq);
           fun scan _ [] = Seq.empty
             | scan to_skip ((asminfo, searchinfo) :: rest) =
                 (case searchf searchinfo to_skip lhs of
                   SkipSeq found =>
                     Seq.append (Seq.map (Library.pair asminfo) (Seq.flat found))
                       (scan 1 rest)
                 | SkipMore deficit => scan deficit rest);
         in
           Seq.maps (apply_subst_in_asm ctxt i st eq) (scan skipocc asmpreps)
         end;
     in Seq.maps rewrite_with (Seq.of_list (prep_meta_eq ctxt instepthm)) end;
   392 
   393 
   (* Run searchf and skip occurrences according to occ, keeping the deficit
      information (a skipseq) so the caller can continue elsewhere. *)
   fun skip_first_asm_occs_search searchf sinfo occ lhs =
     let val found = searchf sinfo lhs
     in skipto_skipseq occ found end;
   396 
   (* Substitution in the assumptions of subgoal i; the counterpart of
      eqsubst_tac.  Occurrences are processed in descending order, and the
      subgoal index is shifted by the number of premises added by earlier
      steps.  Fails (empty sequence) when the state has fewer than i
      subgoals. *)
   fun eqsubst_asm_tac ctxt occs thms i st =
     let
       val nprems = Thm.nprems_of st;
       val searchf = skip_first_asm_occs_search searchf_lr_unify_valid;
       fun tac_for occ st' =
         let val i' = i + (Thm.nprems_of st' - nprems) in
           Seq.maps (fn r => eqsubst_asm_tac' ctxt searchf occ r i' st')
             (Seq.of_list thms)
         end;
       val descending_occs = Library.sort (rev_order o int_ord) occs;
     in
       if nprems < i then Seq.empty
       else
         Seq.maps distinct_subgoals_tac
           (Seq.EVERY (map tac_for descending_occs) st)
     end;
   413 
   (* Register the proof method "subst".  Its arguments are an optional mode
      "asm" (true: substitute in an assumption, false: substitute in the
      conclusion of the goal), an optional parenthesised list of occurrence
      numbers (default [0]), and the theorems to use. *)
   val _ =
     let
       val parse_mode = Args.mode "asm";
       val parse_occs = Scan.optional (Args.parens (Scan.repeat Parse.nat)) [0];
       fun subst_method ((asm, occs), inthms) ctxt =
         let val tac = if asm then eqsubst_asm_tac else eqsubst_tac
         in SIMPLE_METHOD' (tac ctxt occs inthms) end;
     in
       Theory.setup
         (Method.setup @{binding subst}
           (Scan.lift (parse_mode -- parse_occs) -- Attrib.thms >> subst_method)
           "single-step substitution")
     end;
   424 
   425 end;