src/Tools/eqsubst.ML
author wenzelm
Tue Jun 02 09:16:19 2015 +0200 (2015-06-02)
changeset 60358 aebfbcab1eb8
parent 59642 929984c529d3
child 67149 e61557884799
permissions -rw-r--r--
clarified context;
     1 (*  Title:      Tools/eqsubst.ML
     2     Author:     Lucas Dixon, University of Edinburgh
     3 
     4 Perform a substitution using an equation.
     5 *)
     6 
     7 signature EQSUBST =
     8 sig
     9   type match =
    10     ((indexname * (sort * typ)) list (* type instantiations *)
    11       * (indexname * (typ * term)) list) (* term instantiations *)
    12     * (string * typ) list (* fake named type abs env *)
    13     * (string * typ) list (* type abs env *)
    14     * term (* outer term *)
    15 
    16   type searchinfo =
    17     Proof.context
    18     * int (* maxidx *)
    19     * Zipper.T (* focusterm to search under *)
    20 
    21   datatype 'a skipseq = SkipMore of int | SkipSeq of 'a Seq.seq Seq.seq
    22 
    23   val skip_first_asm_occs_search: ('a -> 'b -> 'c Seq.seq Seq.seq) -> 'a -> int -> 'b -> 'c skipseq
    24   val skip_first_occs_search: int -> ('a -> 'b -> 'c Seq.seq Seq.seq) -> 'a -> 'b -> 'c Seq.seq
    25   val skipto_skipseq: int -> 'a Seq.seq Seq.seq -> 'a skipseq
    26 
    27   (* tactics *)
    28   val eqsubst_asm_tac: Proof.context -> int list -> thm list -> int -> tactic
    29   val eqsubst_asm_tac': Proof.context ->
    30     (searchinfo -> int -> term -> match skipseq) -> int -> thm -> int -> tactic
    31   val eqsubst_tac: Proof.context ->
    32     int list -> (* list of occurrences to rewrite, use [0] for any *)
    33     thm list -> int -> tactic
    34   val eqsubst_tac': Proof.context ->
    35     (searchinfo -> term -> match Seq.seq) (* search function *)
    36     -> thm (* equation theorem to rewrite with *)
    37     -> int (* subgoal number in goal theorem *)
    38     -> thm (* goal theorem *)
    39     -> thm Seq.seq (* rewritten goal theorem *)
    40 
    41   (* search for substitutions *)
    42   val valid_match_start: Zipper.T -> bool
    43   val search_lr_all: Zipper.T -> Zipper.T Seq.seq
    44   val search_lr_valid: (Zipper.T -> bool) -> Zipper.T -> Zipper.T Seq.seq
    45   val searchf_lr_unify_all: searchinfo -> term -> match Seq.seq Seq.seq
    46   val searchf_lr_unify_valid: searchinfo -> term -> match Seq.seq Seq.seq
    47   val searchf_bt_unify_valid: searchinfo -> term -> match Seq.seq Seq.seq
    48 end;
    49 
    50 structure EqSubst: EQSUBST =
    51 struct
    52 
    53 (* changes object "=" to meta "==" which prepares a given rewrite rule *)
    54 fun prep_meta_eq ctxt =
    55   Simplifier.mksimps ctxt #> map Drule.zero_var_indexes;
    56 
    57 (* make free vars into schematic vars with index zero *)
    58 fun unfix_frees frees =
    59    fold (K (Thm.forall_elim_var 0)) frees o Drule.forall_intr_list frees;
    60 
    61 
    62 type match =
    63   ((indexname * (sort * typ)) list (* type instantiations *)
    64    * (indexname * (typ * term)) list) (* term instantiations *)
    65   * (string * typ) list (* fake named type abs env *)
    66   * (string * typ) list (* type abs env *)
    67   * term; (* outer term *)
    68 
    69 type searchinfo =
    70   Proof.context
    71   * int (* maxidx *)
    72   * Zipper.T; (* focusterm to search under *)
    73 
    74 
    75 (* skipping non-empty sub-sequences but when we reach the end
    76    of the seq, remembering how much we have left to skip. *)
    77 datatype 'a skipseq =
    78   SkipMore of int |
    79   SkipSeq of 'a Seq.seq Seq.seq;
    80 
    81 (* given a seqseq, skip the first m non-empty seq's, note deficit *)
    82 fun skipto_skipseq m s =
    83   let
    84     fun skip_occs n sq =
    85       (case Seq.pull sq of
    86         NONE => SkipMore n
    87       | SOME (h, t) =>
    88         (case Seq.pull h of
    89           NONE => skip_occs n t
    90         | SOME _ => if n <= 1 then SkipSeq (Seq.cons h t) else skip_occs (n - 1) t))
    91   in skip_occs m s end;
    92 
    93 (* note: outerterm is the taget with the match replaced by a bound
    94    variable : ie: "P lhs" beocmes "%x. P x"
    95    insts is the types of instantiations of vars in lhs
    96    and typinsts is the type instantiations of types in the lhs
    97    Note: Final rule is the rule lifted into the ontext of the
    98    taget thm. *)
    99 fun mk_foo_match mkuptermfunc Ts t =
   100   let
   101     val ty = Term.type_of t
   102     val bigtype = rev (map snd Ts) ---> ty
   103     fun mk_foo 0 t = t
   104       | mk_foo i t = mk_foo (i - 1) (t $ (Bound (i - 1)))
   105     val num_of_bnds = length Ts
   106     (* foo_term = "fooabs y0 ... yn" where y's are local bounds *)
   107     val foo_term = mk_foo num_of_bnds (Bound num_of_bnds)
   108   in Abs ("fooabs", bigtype, mkuptermfunc foo_term) end;
   109 
   110 (* T is outer bound vars, n is number of locally bound vars *)
   111 (* THINK: is order of Ts correct...? or reversed? *)
   112 fun mk_fake_bound_name n = ":b_" ^ n;
   113 fun fakefree_badbounds Ts t =
   114   let val (FakeTs, Ts, newnames) =
   115     fold_rev (fn (n, ty) => fn (FakeTs, Ts, usednames) =>
   116       let
   117         val newname = singleton (Name.variant_list usednames) n
   118       in
   119         ((mk_fake_bound_name newname, ty) :: FakeTs,
   120           (newname, ty) :: Ts,
   121           newname :: usednames)
   122       end) Ts ([], [], [])
   123   in (FakeTs, Ts, Term.subst_bounds (map Free FakeTs, t)) end;
   124 
   125 (* before matching we need to fake the bound vars that are missing an
   126    abstraction. In this function we additionally construct the
   127    abstraction environment, and an outer context term (with the focus
   128    abstracted out) for use in rewriting with RW_Inst.rw *)
   129 fun prep_zipper_match z =
   130   let
   131     val t = Zipper.trm z
   132     val c = Zipper.ctxt z
   133     val Ts = Zipper.C.nty_ctxt c
   134     val (FakeTs', Ts', t') = fakefree_badbounds Ts t
   135     val absterm = mk_foo_match (Zipper.C.apply c) Ts' t'
   136   in
   137     (t', (FakeTs', Ts', absterm))
   138   end;
   139 
   140 (* Unification with exception handled *)
   141 (* given context, max var index, pat, tgt; returns Seq of instantiations *)
   142 fun clean_unify ctxt ix (a as (pat, tgt)) =
   143   let
   144     (* type info will be re-derived, maybe this can be cached
   145        for efficiency? *)
   146     val pat_ty = Term.type_of pat;
   147     val tgt_ty = Term.type_of tgt;
   148     (* FIXME is it OK to ignore the type instantiation info?
   149        or should I be using it? *)
   150     val typs_unify =
   151       SOME (Sign.typ_unify (Proof_Context.theory_of ctxt) (pat_ty, tgt_ty) (Vartab.empty, ix))
   152         handle Type.TUNIFY => NONE;
   153   in
   154     (case typs_unify of
   155       SOME (typinsttab, ix2) =>
   156         let
   157           (* FIXME is it right to throw away the flexes?
   158              or should I be using them somehow? *)
   159           fun mk_insts env =
   160             (Vartab.dest (Envir.type_env env),
   161              Vartab.dest (Envir.term_env env));
   162           val initenv =
   163             Envir.Envir {maxidx = ix2, tenv = Vartab.empty, tyenv = typinsttab};
   164           val useq = Unify.smash_unifiers (Context.Proof ctxt) [a] initenv
   165             handle ListPair.UnequalLengths => Seq.empty
   166               | Term.TERM _ => Seq.empty;
   167           fun clean_unify' useq () =
   168             (case (Seq.pull useq) of
   169                NONE => NONE
   170              | SOME (h, t) => SOME (mk_insts h, Seq.make (clean_unify' t)))
   171             handle ListPair.UnequalLengths => NONE
   172               | Term.TERM _ => NONE;
   173         in
   174           (Seq.make (clean_unify' useq))
   175         end
   176     | NONE => Seq.empty)
   177   end;
   178 
   179 (* Unification for zippers *)
   180 (* Note: Ts is a modified version of the original names of the outer
   181    bound variables. New names have been introduced to make sure they are
   182    unique w.r.t all names in the term and each other. usednames' is
   183    oldnames + new names. *)
   184 fun clean_unify_z ctxt maxidx pat z =
   185   let val (t, (FakeTs, Ts, absterm)) = prep_zipper_match z in
   186     Seq.map (fn insts => (insts, FakeTs, Ts, absterm))
   187       (clean_unify ctxt maxidx (t, pat))
   188   end;
   189 
   190 
   191 fun bot_left_leaf_of (l $ _) = bot_left_leaf_of l
   192   | bot_left_leaf_of (Abs (_, _, t)) = bot_left_leaf_of t
   193   | bot_left_leaf_of x = x;
   194 
   195 (* Avoid considering replacing terms which have a var at the head as
   196    they always succeed trivially, and uninterestingly. *)
   197 fun valid_match_start z =
   198   (case bot_left_leaf_of (Zipper.trm z) of
   199     Var _ => false
   200   | _ => true);
   201 
   202 (* search from top, left to right, then down *)
   203 val search_lr_all = ZipperSearch.all_bl_ur;
   204 
   205 (* search from top, left to right, then down *)
   206 fun search_lr_valid validf =
   207   let
   208     fun sf_valid_td_lr z =
   209       let val here = if validf z then [Zipper.Here z] else [] in
   210         (case Zipper.trm z of
   211           _ $ _ =>
   212             [Zipper.LookIn (Zipper.move_down_left z)] @ here @
   213             [Zipper.LookIn (Zipper.move_down_right z)]
   214         | Abs _ => here @ [Zipper.LookIn (Zipper.move_down_abs z)]
   215         | _ => here)
   216       end;
   217   in Zipper.lzy_search sf_valid_td_lr end;
   218 
   219 (* search from bottom to top, left to right *)
   220 fun search_bt_valid validf =
   221   let
   222     fun sf_valid_td_lr z =
   223       let val here = if validf z then [Zipper.Here z] else [] in
   224         (case Zipper.trm z of
   225           _ $ _ =>
   226             [Zipper.LookIn (Zipper.move_down_left z),
   227              Zipper.LookIn (Zipper.move_down_right z)] @ here
   228         | Abs _ => [Zipper.LookIn (Zipper.move_down_abs z)] @ here
   229         | _ => here)
   230       end;
   231   in Zipper.lzy_search sf_valid_td_lr end;
   232 
   233 fun searchf_unify_gen f (ctxt, maxidx, z) lhs =
   234   Seq.map (clean_unify_z ctxt maxidx lhs) (Zipper.limit_apply f z);
   235 
   236 (* search all unifications *)
   237 val searchf_lr_unify_all = searchf_unify_gen search_lr_all;
   238 
   239 (* search only for 'valid' unifiers (non abs subterms and non vars) *)
   240 val searchf_lr_unify_valid = searchf_unify_gen (search_lr_valid valid_match_start);
   241 
   242 val searchf_bt_unify_valid = searchf_unify_gen (search_bt_valid valid_match_start);
   243 
   244 (* apply a substitution in the conclusion of the theorem *)
   245 (* cfvs are certified free var placeholders for goal params *)
   246 (* conclthm is a theorem of for just the conclusion *)
   247 (* m is instantiation/match information *)
   248 (* rule is the equation for substitution *)
   249 fun apply_subst_in_concl ctxt i st (cfvs, conclthm) rule m =
   250   RW_Inst.rw ctxt m rule conclthm
   251   |> unfix_frees cfvs
   252   |> Conv.fconv_rule Drule.beta_eta_conversion
   253   |> (fn r => resolve_tac ctxt [r] i st);
   254 
   255 (* substitute within the conclusion of goal i of gth, using a meta
   256 equation rule. Note that we assume rule has var indicies zero'd *)
   257 fun prep_concl_subst ctxt i gth =
   258   let
   259     val th = Thm.incr_indexes 1 gth;
   260     val tgt_term = Thm.prop_of th;
   261 
   262     val (fixedbody, fvs) = IsaND.fix_alls_term ctxt i tgt_term;
   263     val cfvs = rev (map (Thm.cterm_of ctxt) fvs);
   264 
   265     val conclterm = Logic.strip_imp_concl fixedbody;
   266     val conclthm = Thm.trivial (Thm.cterm_of ctxt conclterm);
   267     val maxidx = Thm.maxidx_of th;
   268     val ft =
   269       (Zipper.move_down_right (* ==> *)
   270        o Zipper.move_down_left (* Trueprop *)
   271        o Zipper.mktop
   272        o Thm.prop_of) conclthm
   273   in
   274     ((cfvs, conclthm), (ctxt, maxidx, ft))
   275   end;
   276 
   277 (* substitute using an object or meta level equality *)
   278 fun eqsubst_tac' ctxt searchf instepthm i st =
   279   let
   280     val (cvfsconclthm, searchinfo) = prep_concl_subst ctxt i st;
   281     val stepthms = Seq.of_list (prep_meta_eq ctxt instepthm);
   282     fun rewrite_with_thm r =
   283       let val (lhs,_) = Logic.dest_equals (Thm.concl_of r) in
   284         searchf searchinfo lhs
   285         |> Seq.maps (apply_subst_in_concl ctxt i st cvfsconclthm r)
   286       end;
   287   in stepthms |> Seq.maps rewrite_with_thm end;
   288 
   289 
   290 (* General substitution of multiple occurrences using one of
   291    the given theorems *)
   292 
   293 fun skip_first_occs_search occ srchf sinfo lhs =
   294   (case skipto_skipseq occ (srchf sinfo lhs) of
   295     SkipMore _ => Seq.empty
   296   | SkipSeq ss => Seq.flat ss);
   297 
   298 (* The "occs" argument is a list of integers indicating which occurrence
   299 w.r.t. the search order, to rewrite. Backtracking will also find later
   300 occurrences, but all earlier ones are skipped. Thus you can use [0] to
   301 just find all rewrites. *)
   302 
   303 fun eqsubst_tac ctxt occs thms i st =
   304   let val nprems = Thm.nprems_of st in
   305     if nprems < i then Seq.empty else
   306     let
   307       val thmseq = Seq.of_list thms;
   308       fun apply_occ occ st =
   309         thmseq |> Seq.maps (fn r =>
   310           eqsubst_tac' ctxt
   311             (skip_first_occs_search occ searchf_lr_unify_valid) r
   312             (i + (Thm.nprems_of st - nprems)) st);
   313       val sorted_occs = Library.sort (rev_order o int_ord) occs;
   314     in
   315       Seq.maps distinct_subgoals_tac (Seq.EVERY (map apply_occ sorted_occs) st)
   316     end
   317   end;
   318 
   319 
   320 (* apply a substitution inside assumption j, keeps asm in the same place *)
   321 fun apply_subst_in_asm ctxt i st rule ((cfvs, j, _, pth),m) =
   322   let
   323     val st2 = Thm.rotate_rule (j - 1) i st; (* put premice first *)
   324     val preelimrule =
   325       RW_Inst.rw ctxt m rule pth
   326       |> (Seq.hd o prune_params_tac ctxt)
   327       |> Thm.permute_prems 0 ~1 (* put old asm first *)
   328       |> unfix_frees cfvs (* unfix any global params *)
   329       |> Conv.fconv_rule Drule.beta_eta_conversion; (* normal form *)
   330   in
   331     (* ~j because new asm starts at back, thus we subtract 1 *)
   332     Seq.map (Thm.rotate_rule (~ j) (Thm.nprems_of rule + i))
   333       (dresolve_tac ctxt [preelimrule] i st2)
   334   end;
   335 
   336 
   337 (* prepare to substitute within the j'th premise of subgoal i of gth,
   338 using a meta-level equation. Note that we assume rule has var indicies
   339 zero'd. Note that we also assume that premt is the j'th premice of
   340 subgoal i of gth. Note the repetition of work done for each
   341 assumption, i.e. this can be made more efficient for search over
   342 multiple assumptions.  *)
   343 fun prep_subst_in_asm ctxt i gth j =
   344   let
   345     val th = Thm.incr_indexes 1 gth;
   346     val tgt_term = Thm.prop_of th;
   347 
   348     val (fixedbody, fvs) = IsaND.fix_alls_term ctxt i tgt_term;
   349     val cfvs = rev (map (Thm.cterm_of ctxt) fvs);
   350 
   351     val asmt = nth (Logic.strip_imp_prems fixedbody) (j - 1);
   352     val asm_nprems = length (Logic.strip_imp_prems asmt);
   353 
   354     val pth = Thm.trivial ((Thm.cterm_of ctxt) asmt);
   355     val maxidx = Thm.maxidx_of th;
   356 
   357     val ft =
   358       (Zipper.move_down_right (* trueprop *)
   359          o Zipper.mktop
   360          o Thm.prop_of) pth
   361   in ((cfvs, j, asm_nprems, pth), (ctxt, maxidx, ft)) end;
   362 
   363 (* prepare subst in every possible assumption *)
   364 fun prep_subst_in_asms ctxt i gth =
   365   map (prep_subst_in_asm ctxt i gth)
   366     ((fn l => Library.upto (1, length l))
   367       (Logic.prems_of_goal (Thm.prop_of gth) i));
   368 
   369 
   370 (* substitute in an assumption using an object or meta level equality *)
   371 fun eqsubst_asm_tac' ctxt searchf skipocc instepthm i st =
   372   let
   373     val asmpreps = prep_subst_in_asms ctxt i st;
   374     val stepthms = Seq.of_list (prep_meta_eq ctxt instepthm);
   375     fun rewrite_with_thm r =
   376       let
   377         val (lhs,_) = Logic.dest_equals (Thm.concl_of r);
   378         fun occ_search occ [] = Seq.empty
   379           | occ_search occ ((asminfo, searchinfo)::moreasms) =
   380               (case searchf searchinfo occ lhs of
   381                 SkipMore i => occ_search i moreasms
   382               | SkipSeq ss =>
   383                   Seq.append (Seq.map (Library.pair asminfo) (Seq.flat ss))
   384                     (occ_search 1 moreasms)) (* find later substs also *)
   385       in
   386         occ_search skipocc asmpreps |> Seq.maps (apply_subst_in_asm ctxt i st r)
   387       end;
   388   in stepthms |> Seq.maps rewrite_with_thm end;
   389 
   390 
   391 fun skip_first_asm_occs_search searchf sinfo occ lhs =
   392   skipto_skipseq occ (searchf sinfo lhs);
   393 
   394 fun eqsubst_asm_tac ctxt occs thms i st =
   395   let val nprems = Thm.nprems_of st in
   396     if nprems < i then Seq.empty
   397     else
   398       let
   399         val thmseq = Seq.of_list thms;
   400         fun apply_occ occ st =
   401           thmseq |> Seq.maps (fn r =>
   402             eqsubst_asm_tac' ctxt
   403               (skip_first_asm_occs_search searchf_lr_unify_valid) occ r
   404               (i + (Thm.nprems_of st - nprems)) st);
   405         val sorted_occs = Library.sort (rev_order o int_ord) occs;
   406       in
   407         Seq.maps distinct_subgoals_tac (Seq.EVERY (map apply_occ sorted_occs) st)
   408       end
   409   end;
   410 
   411 (* combination method that takes a flag (true indicates that subst
   412    should be done to an assumption, false = apply to the conclusion of
   413    the goal) as well as the theorems to use *)
   414 val _ =
   415   Theory.setup
   416     (Method.setup @{binding subst}
   417       (Scan.lift (Args.mode "asm" -- Scan.optional (Args.parens (Scan.repeat Parse.nat)) [0]) --
   418         Attrib.thms >> (fn ((asm, occs), inthms) => fn ctxt =>
   419           SIMPLE_METHOD' ((if asm then eqsubst_asm_tac else eqsubst_tac) ctxt occs inthms)))
   420       "single-step substitution");
   421 
   422 end;