src/Tools/eqsubst.ML
author blanchet
Thu Sep 11 21:11:03 2014 +0200 (2014-09-11)
changeset 58318 f95754ca7082
parent 54742 7a86358a3c0b
child 58826 2ed2eaabe3df
permissions -rw-r--r--
fixed some spelling mistakes
     1 (*  Title:      Tools/eqsubst.ML
     2     Author:     Lucas Dixon, University of Edinburgh
     3 
     4 Perform a substitution using an equation.
     5 *)
     6 
     7 signature EQSUBST =
     8 sig
     9   type match =
    10     ((indexname * (sort * typ)) list (* type instantiations *)
    11       * (indexname * (typ * term)) list) (* term instantiations *)
    12     * (string * typ) list (* fake named type abs env *)
    13     * (string * typ) list (* type abs env *)
    14     * term (* outer term *)
    15 
    16   type searchinfo =
    17     theory
    18     * int (* maxidx *)
    19     * Zipper.T (* focusterm to search under *)
    20 
    21   datatype 'a skipseq = SkipMore of int | SkipSeq of 'a Seq.seq Seq.seq
    22 
    23   val skip_first_asm_occs_search: ('a -> 'b -> 'c Seq.seq Seq.seq) -> 'a -> int -> 'b -> 'c skipseq
    24   val skip_first_occs_search: int -> ('a -> 'b -> 'c Seq.seq Seq.seq) -> 'a -> 'b -> 'c Seq.seq
    25   val skipto_skipseq: int -> 'a Seq.seq Seq.seq -> 'a skipseq
    26 
    27   (* tactics *)
    28   val eqsubst_asm_tac: Proof.context -> int list -> thm list -> int -> tactic
    29   val eqsubst_asm_tac': Proof.context ->
    30     (searchinfo -> int -> term -> match skipseq) -> int -> thm -> int -> tactic
    31   val eqsubst_tac: Proof.context ->
    32     int list -> (* list of occurrences to rewrite, use [0] for any *)
    33     thm list -> int -> tactic
    34   val eqsubst_tac': Proof.context ->
    35     (searchinfo -> term -> match Seq.seq) (* search function *)
    36     -> thm (* equation theorem to rewrite with *)
    37     -> int (* subgoal number in goal theorem *)
    38     -> thm (* goal theorem *)
    39     -> thm Seq.seq (* rewritten goal theorem *)
    40 
    41   (* search for substitutions *)
    42   val valid_match_start: Zipper.T -> bool
    43   val search_lr_all: Zipper.T -> Zipper.T Seq.seq
    44   val search_lr_valid: (Zipper.T -> bool) -> Zipper.T -> Zipper.T Seq.seq
    45   val searchf_lr_unify_all: searchinfo -> term -> match Seq.seq Seq.seq
    46   val searchf_lr_unify_valid: searchinfo -> term -> match Seq.seq Seq.seq
    47   val searchf_bt_unify_valid: searchinfo -> term -> match Seq.seq Seq.seq
    48 
    49   val setup : theory -> theory
    50 end;
    51 
    52 structure EqSubst: EQSUBST =
    53 struct
    54 
    55 (* changes object "=" to meta "==" which prepares a given rewrite rule *)
    56 fun prep_meta_eq ctxt =
    57   Simplifier.mksimps ctxt #> map Drule.zero_var_indexes;
    58 
    59 (* make free vars into schematic vars with index zero *)
    60 fun unfix_frees frees =
    61    fold (K (Thm.forall_elim_var 0)) frees o Drule.forall_intr_list frees;
    62 
    63 
    64 type match =
    65   ((indexname * (sort * typ)) list (* type instantiations *)
    66    * (indexname * (typ * term)) list) (* term instantiations *)
    67   * (string * typ) list (* fake named type abs env *)
    68   * (string * typ) list (* type abs env *)
    69   * term; (* outer term *)
    70 
    71 type searchinfo =
    72   theory
    73   * int (* maxidx *)
    74   * Zipper.T; (* focusterm to search under *)
    75 
    76 
    77 (* skipping non-empty sub-sequences but when we reach the end
    78    of the seq, remembering how much we have left to skip. *)
    79 datatype 'a skipseq =
    80   SkipMore of int |
    81   SkipSeq of 'a Seq.seq Seq.seq;
    82 
    83 (* given a seqseq, skip the first m non-empty seq's, note deficit *)
    84 fun skipto_skipseq m s =
    85   let
    86     fun skip_occs n sq =
    87       (case Seq.pull sq of
    88         NONE => SkipMore n
    89       | SOME (h, t) =>
    90         (case Seq.pull h of
    91           NONE => skip_occs n t
    92         | SOME _ => if n <= 1 then SkipSeq (Seq.cons h t) else skip_occs (n - 1) t))
    93   in skip_occs m s end;
    94 
    95 (* note: outerterm is the taget with the match replaced by a bound
    96    variable : ie: "P lhs" beocmes "%x. P x"
    97    insts is the types of instantiations of vars in lhs
    98    and typinsts is the type instantiations of types in the lhs
    99    Note: Final rule is the rule lifted into the ontext of the
   100    taget thm. *)
   101 fun mk_foo_match mkuptermfunc Ts t =
   102   let
   103     val ty = Term.type_of t
   104     val bigtype = rev (map snd Ts) ---> ty
   105     fun mk_foo 0 t = t
   106       | mk_foo i t = mk_foo (i - 1) (t $ (Bound (i - 1)))
   107     val num_of_bnds = length Ts
   108     (* foo_term = "fooabs y0 ... yn" where y's are local bounds *)
   109     val foo_term = mk_foo num_of_bnds (Bound num_of_bnds)
   110   in Abs ("fooabs", bigtype, mkuptermfunc foo_term) end;
   111 
   112 (* T is outer bound vars, n is number of locally bound vars *)
   113 (* THINK: is order of Ts correct...? or reversed? *)
   114 fun mk_fake_bound_name n = ":b_" ^ n;
   115 fun fakefree_badbounds Ts t =
   116   let val (FakeTs, Ts, newnames) =
   117     fold_rev (fn (n, ty) => fn (FakeTs, Ts, usednames) =>
   118       let
   119         val newname = singleton (Name.variant_list usednames) n
   120       in
   121         ((mk_fake_bound_name newname, ty) :: FakeTs,
   122           (newname, ty) :: Ts,
   123           newname :: usednames)
   124       end) Ts ([], [], [])
   125   in (FakeTs, Ts, Term.subst_bounds (map Free FakeTs, t)) end;
   126 
   127 (* before matching we need to fake the bound vars that are missing an
   128    abstraction. In this function we additionally construct the
   129    abstraction environment, and an outer context term (with the focus
   130    abstracted out) for use in rewriting with RW_Inst.rw *)
   131 fun prep_zipper_match z =
   132   let
   133     val t = Zipper.trm z
   134     val c = Zipper.ctxt z
   135     val Ts = Zipper.C.nty_ctxt c
   136     val (FakeTs', Ts', t') = fakefree_badbounds Ts t
   137     val absterm = mk_foo_match (Zipper.C.apply c) Ts' t'
   138   in
   139     (t', (FakeTs', Ts', absterm))
   140   end;
   141 
   142 (* Unification with exception handled *)
   143 (* given theory, max var index, pat, tgt; returns Seq of instantiations *)
   144 fun clean_unify thy ix (a as (pat, tgt)) =
   145   let
   146     (* type info will be re-derived, maybe this can be cached
   147        for efficiency? *)
   148     val pat_ty = Term.type_of pat;
   149     val tgt_ty = Term.type_of tgt;
   150     (* FIXME is it OK to ignore the type instantiation info?
   151        or should I be using it? *)
   152     val typs_unify =
   153       SOME (Sign.typ_unify thy (pat_ty, tgt_ty) (Vartab.empty, ix))
   154         handle Type.TUNIFY => NONE;
   155   in
   156     (case typs_unify of
   157       SOME (typinsttab, ix2) =>
   158         let
   159           (* FIXME is it right to throw away the flexes?
   160              or should I be using them somehow? *)
   161           fun mk_insts env =
   162             (Vartab.dest (Envir.type_env env),
   163              Vartab.dest (Envir.term_env env));
   164           val initenv =
   165             Envir.Envir {maxidx = ix2, tenv = Vartab.empty, tyenv = typinsttab};
   166           val useq = Unify.smash_unifiers thy [a] initenv
   167             handle ListPair.UnequalLengths => Seq.empty
   168               | Term.TERM _ => Seq.empty;
   169           fun clean_unify' useq () =
   170             (case (Seq.pull useq) of
   171                NONE => NONE
   172              | SOME (h, t) => SOME (mk_insts h, Seq.make (clean_unify' t)))
   173             handle ListPair.UnequalLengths => NONE
   174               | Term.TERM _ => NONE;
   175         in
   176           (Seq.make (clean_unify' useq))
   177         end
   178     | NONE => Seq.empty)
   179   end;
   180 
   181 (* Unification for zippers *)
   182 (* Note: Ts is a modified version of the original names of the outer
   183    bound variables. New names have been introduced to make sure they are
   184    unique w.r.t all names in the term and each other. usednames' is
   185    oldnames + new names. *)
   186 fun clean_unify_z thy maxidx pat z =
   187   let val (t, (FakeTs, Ts, absterm)) = prep_zipper_match z in
   188     Seq.map (fn insts => (insts, FakeTs, Ts, absterm))
   189       (clean_unify thy maxidx (t, pat))
   190   end;
   191 
   192 
   193 fun bot_left_leaf_of (l $ _) = bot_left_leaf_of l
   194   | bot_left_leaf_of (Abs (_, _, t)) = bot_left_leaf_of t
   195   | bot_left_leaf_of x = x;
   196 
   197 (* Avoid considering replacing terms which have a var at the head as
   198    they always succeed trivially, and uninterestingly. *)
   199 fun valid_match_start z =
   200   (case bot_left_leaf_of (Zipper.trm z) of
   201     Var _ => false
   202   | _ => true);
   203 
   204 (* search from top, left to right, then down *)
   205 val search_lr_all = ZipperSearch.all_bl_ur;
   206 
   207 (* search from top, left to right, then down *)
   208 fun search_lr_valid validf =
   209   let
   210     fun sf_valid_td_lr z =
   211       let val here = if validf z then [Zipper.Here z] else [] in
   212         (case Zipper.trm z of
   213           _ $ _ =>
   214             [Zipper.LookIn (Zipper.move_down_left z)] @ here @
   215             [Zipper.LookIn (Zipper.move_down_right z)]
   216         | Abs _ => here @ [Zipper.LookIn (Zipper.move_down_abs z)]
   217         | _ => here)
   218       end;
   219   in Zipper.lzy_search sf_valid_td_lr end;
   220 
   221 (* search from bottom to top, left to right *)
   222 fun search_bt_valid validf =
   223   let
   224     fun sf_valid_td_lr z =
   225       let val here = if validf z then [Zipper.Here z] else [] in
   226         (case Zipper.trm z of
   227           _ $ _ =>
   228             [Zipper.LookIn (Zipper.move_down_left z),
   229              Zipper.LookIn (Zipper.move_down_right z)] @ here
   230         | Abs _ => [Zipper.LookIn (Zipper.move_down_abs z)] @ here
   231         | _ => here)
   232       end;
   233   in Zipper.lzy_search sf_valid_td_lr end;
   234 
   235 fun searchf_unify_gen f (thy, maxidx, z) lhs =
   236   Seq.map (clean_unify_z thy maxidx lhs) (Zipper.limit_apply f z);
   237 
   238 (* search all unifications *)
   239 val searchf_lr_unify_all = searchf_unify_gen search_lr_all;
   240 
   241 (* search only for 'valid' unifiers (non abs subterms and non vars) *)
   242 val searchf_lr_unify_valid = searchf_unify_gen (search_lr_valid valid_match_start);
   243 
   244 val searchf_bt_unify_valid = searchf_unify_gen (search_bt_valid valid_match_start);
   245 
   246 (* apply a substitution in the conclusion of the theorem *)
   247 (* cfvs are certified free var placeholders for goal params *)
   248 (* conclthm is a theorem of for just the conclusion *)
   249 (* m is instantiation/match information *)
   250 (* rule is the equation for substitution *)
   251 fun apply_subst_in_concl ctxt i st (cfvs, conclthm) rule m =
   252   RW_Inst.rw ctxt m rule conclthm
   253   |> unfix_frees cfvs
   254   |> Conv.fconv_rule Drule.beta_eta_conversion
   255   |> (fn r => rtac r i st);
   256 
   257 (* substitute within the conclusion of goal i of gth, using a meta
   258 equation rule. Note that we assume rule has var indicies zero'd *)
   259 fun prep_concl_subst ctxt i gth =
   260   let
   261     val th = Thm.incr_indexes 1 gth;
   262     val tgt_term = Thm.prop_of th;
   263 
   264     val thy = Thm.theory_of_thm th;
   265     val cert = Thm.cterm_of thy;
   266 
   267     val (fixedbody, fvs) = IsaND.fix_alls_term ctxt i tgt_term;
   268     val cfvs = rev (map cert fvs);
   269 
   270     val conclterm = Logic.strip_imp_concl fixedbody;
   271     val conclthm = Thm.trivial (cert conclterm);
   272     val maxidx = Thm.maxidx_of th;
   273     val ft =
   274       (Zipper.move_down_right (* ==> *)
   275        o Zipper.move_down_left (* Trueprop *)
   276        o Zipper.mktop
   277        o Thm.prop_of) conclthm
   278   in
   279     ((cfvs, conclthm), (thy, maxidx, ft))
   280   end;
   281 
   282 (* substitute using an object or meta level equality *)
   283 fun eqsubst_tac' ctxt searchf instepthm i st =
   284   let
   285     val (cvfsconclthm, searchinfo) = prep_concl_subst ctxt i st;
   286     val stepthms = Seq.of_list (prep_meta_eq ctxt instepthm);
   287     fun rewrite_with_thm r =
   288       let val (lhs,_) = Logic.dest_equals (Thm.concl_of r) in
   289         searchf searchinfo lhs
   290         |> Seq.maps (apply_subst_in_concl ctxt i st cvfsconclthm r)
   291       end;
   292   in stepthms |> Seq.maps rewrite_with_thm end;
   293 
   294 
   295 (* General substitution of multiple occurrences using one of
   296    the given theorems *)
   297 
   298 fun skip_first_occs_search occ srchf sinfo lhs =
   299   (case skipto_skipseq occ (srchf sinfo lhs) of
   300     SkipMore _ => Seq.empty
   301   | SkipSeq ss => Seq.flat ss);
   302 
   303 (* The "occs" argument is a list of integers indicating which occurrence
   304 w.r.t. the search order, to rewrite. Backtracking will also find later
   305 occurrences, but all earlier ones are skipped. Thus you can use [0] to
   306 just find all rewrites. *)
   307 
   308 fun eqsubst_tac ctxt occs thms i st =
   309   let val nprems = Thm.nprems_of st in
   310     if nprems < i then Seq.empty else
   311     let
   312       val thmseq = Seq.of_list thms;
   313       fun apply_occ occ st =
   314         thmseq |> Seq.maps (fn r =>
   315           eqsubst_tac' ctxt
   316             (skip_first_occs_search occ searchf_lr_unify_valid) r
   317             (i + (Thm.nprems_of st - nprems)) st);
   318       val sorted_occs = Library.sort (rev_order o int_ord) occs;
   319     in
   320       Seq.maps distinct_subgoals_tac (Seq.EVERY (map apply_occ sorted_occs) st)
   321     end
   322   end;
   323 
   324 
   325 (* apply a substitution inside assumption j, keeps asm in the same place *)
   326 fun apply_subst_in_asm ctxt i st rule ((cfvs, j, _, pth),m) =
   327   let
   328     val st2 = Thm.rotate_rule (j - 1) i st; (* put premice first *)
   329     val preelimrule =
   330       RW_Inst.rw ctxt m rule pth
   331       |> (Seq.hd o prune_params_tac ctxt)
   332       |> Thm.permute_prems 0 ~1 (* put old asm first *)
   333       |> unfix_frees cfvs (* unfix any global params *)
   334       |> Conv.fconv_rule Drule.beta_eta_conversion; (* normal form *)
   335   in
   336     (* ~j because new asm starts at back, thus we subtract 1 *)
   337     Seq.map (Thm.rotate_rule (~ j) (Thm.nprems_of rule + i)) (dtac preelimrule i st2)
   338   end;
   339 
   340 
   341 (* prepare to substitute within the j'th premise of subgoal i of gth,
   342 using a meta-level equation. Note that we assume rule has var indicies
   343 zero'd. Note that we also assume that premt is the j'th premice of
   344 subgoal i of gth. Note the repetition of work done for each
   345 assumption, i.e. this can be made more efficient for search over
   346 multiple assumptions.  *)
   347 fun prep_subst_in_asm ctxt i gth j =
   348   let
   349     val th = Thm.incr_indexes 1 gth;
   350     val tgt_term = Thm.prop_of th;
   351 
   352     val thy = Thm.theory_of_thm th;
   353     val cert = Thm.cterm_of thy;
   354 
   355     val (fixedbody, fvs) = IsaND.fix_alls_term ctxt i tgt_term;
   356     val cfvs = rev (map cert fvs);
   357 
   358     val asmt = nth (Logic.strip_imp_prems fixedbody) (j - 1);
   359     val asm_nprems = length (Logic.strip_imp_prems asmt);
   360 
   361     val pth = Thm.trivial (cert asmt);
   362     val maxidx = Thm.maxidx_of th;
   363 
   364     val ft =
   365       (Zipper.move_down_right (* trueprop *)
   366          o Zipper.mktop
   367          o Thm.prop_of) pth
   368   in ((cfvs, j, asm_nprems, pth), (thy, maxidx, ft)) end;
   369 
   370 (* prepare subst in every possible assumption *)
   371 fun prep_subst_in_asms ctxt i gth =
   372   map (prep_subst_in_asm ctxt i gth)
   373     ((fn l => Library.upto (1, length l))
   374       (Logic.prems_of_goal (Thm.prop_of gth) i));
   375 
   376 
   377 (* substitute in an assumption using an object or meta level equality *)
   378 fun eqsubst_asm_tac' ctxt searchf skipocc instepthm i st =
   379   let
   380     val asmpreps = prep_subst_in_asms ctxt i st;
   381     val stepthms = Seq.of_list (prep_meta_eq ctxt instepthm);
   382     fun rewrite_with_thm r =
   383       let
   384         val (lhs,_) = Logic.dest_equals (Thm.concl_of r);
   385         fun occ_search occ [] = Seq.empty
   386           | occ_search occ ((asminfo, searchinfo)::moreasms) =
   387               (case searchf searchinfo occ lhs of
   388                 SkipMore i => occ_search i moreasms
   389               | SkipSeq ss =>
   390                   Seq.append (Seq.map (Library.pair asminfo) (Seq.flat ss))
   391                     (occ_search 1 moreasms)) (* find later substs also *)
   392       in
   393         occ_search skipocc asmpreps |> Seq.maps (apply_subst_in_asm ctxt i st r)
   394       end;
   395   in stepthms |> Seq.maps rewrite_with_thm end;
   396 
   397 
   398 fun skip_first_asm_occs_search searchf sinfo occ lhs =
   399   skipto_skipseq occ (searchf sinfo lhs);
   400 
   401 fun eqsubst_asm_tac ctxt occs thms i st =
   402   let val nprems = Thm.nprems_of st in
   403     if nprems < i then Seq.empty
   404     else
   405       let
   406         val thmseq = Seq.of_list thms;
   407         fun apply_occ occ st =
   408           thmseq |> Seq.maps (fn r =>
   409             eqsubst_asm_tac' ctxt
   410               (skip_first_asm_occs_search searchf_lr_unify_valid) occ r
   411               (i + (Thm.nprems_of st - nprems)) st);
   412         val sorted_occs = Library.sort (rev_order o int_ord) occs;
   413       in
   414         Seq.maps distinct_subgoals_tac (Seq.EVERY (map apply_occ sorted_occs) st)
   415       end
   416   end;
   417 
   418 (* combination method that takes a flag (true indicates that subst
   419    should be done to an assumption, false = apply to the conclusion of
   420    the goal) as well as the theorems to use *)
   421 val setup =
   422   Method.setup @{binding subst}
   423     (Scan.lift (Args.mode "asm" -- Scan.optional (Args.parens (Scan.repeat Parse.nat)) [0]) --
   424         Attrib.thms >>
   425       (fn ((asm, occs), inthms) => fn ctxt =>
   426         SIMPLE_METHOD' ((if asm then eqsubst_asm_tac else eqsubst_tac) ctxt occs inthms)))
   427     "single-step substitution";
   428 
   429 end;