src/Tools/eqsubst.ML
author wenzelm
Thu May 30 15:51:55 2013 +0200 (2013-05-30)
changeset 52238 d84ff5a93764
parent 52237 ab3ba550cbe7
child 52239 6a6033fa507c
permissions -rw-r--r--
more standard names;
     1 (*  Title:      Tools/eqsubst.ML
     2     Author:     Lucas Dixon, University of Edinburgh
     3 
     4 Perform a substitution using an equation.
     5 *)
     6 
     signature EQSUBST =
     sig
       (** data exchanged between search functions and tactics **)

       (* one unifier of an equation's lhs with a subterm of the goal *)
       type match =
         ((indexname * (sort * typ)) list (* type instantiations *)
           * (indexname * (typ * term)) list) (* term instantiations *)
         * (string * typ) list (* fake named type abs env *)
         * (string * typ) list (* type abs env *)
         * term (* outer term *)

       (* theory, maxidx of the goal, and the focus term to search under *)
       type searchinfo = theory * int * Zipper.T

       (** skipping occurrences **)

       (* SkipMore n: the search ran out and the n-th occurrence is still
          wanted in whatever is searched next;
          SkipSeq ss: remaining sub-sequences, starting at the selected one *)
       datatype 'a skipseq = SkipMore of int | SkipSeq of 'a Seq.seq Seq.seq

       val skipto_skipseq: int -> 'a Seq.seq Seq.seq -> 'a skipseq
       val skip_first_occs_search: int -> ('a -> 'b -> 'c Seq.seq Seq.seq) -> 'a -> 'b -> 'c Seq.seq
       val skip_first_asm_occs_search: ('a -> 'b -> 'c Seq.seq Seq.seq) -> 'a -> int -> 'b -> 'c skipseq

       (** search for substitutions **)

       val valid_match_start: Zipper.T -> bool
       val search_lr_all: Zipper.T -> Zipper.T Seq.seq
       val search_lr_valid: (Zipper.T -> bool) -> Zipper.T -> Zipper.T Seq.seq
       val searchf_lr_unify_all: searchinfo -> term -> match Seq.seq Seq.seq
       val searchf_lr_unify_valid: searchinfo -> term -> match Seq.seq Seq.seq
       val searchf_bt_unify_valid: searchinfo -> term -> match Seq.seq Seq.seq

       (** tactics **)

       (* rewrite in the conclusion of a subgoal; the int list gives the
          occurrences to rewrite, use [0] for any *)
       val eqsubst_tac: Proof.context -> int list -> thm list -> int -> tactic
       val eqsubst_tac': Proof.context
         -> (searchinfo -> term -> match Seq.seq) (* search function *)
         -> thm (* equation theorem to rewrite with *)
         -> int (* subgoal number in goal theorem *)
         -> thm (* goal theorem *)
         -> thm Seq.seq (* rewritten goal theorem *)

       (* rewrite in the assumptions of a subgoal *)
       val eqsubst_asm_tac: Proof.context -> int list -> thm list -> int -> tactic
       val eqsubst_asm_tac': Proof.context
         -> (searchinfo -> int -> term -> match skipseq) (* search function *)
         -> int (* occurrence to start at *)
         -> thm (* equation theorem to rewrite with *)
         -> int (* subgoal number in goal theorem *)
         -> tactic

       (** theory setup: the "subst" proof method **)

       val setup: theory -> theory
     end;
    51 
    structure EqSubst: EQSUBST =
    struct

    (* Turn the given theorem (object-level "=" or meta-level "==") into a
       list of meta-level rewrite rules via the simplifier's mksimps of ctxt,
       and renumber their schematic variables starting from index 0. *)
    fun prep_meta_eq ctxt =
      Simplifier.mksimps ctxt #> map Drule.zero_var_indexes;


    (* One successful unification of an equation's lhs with a subterm of the
       goal, as produced by clean_unify_z. *)
    type match =
      ((indexname * (sort * typ)) list (* type instantiations *)
       * (indexname * (typ * term)) list) (* term instantiations *)
      * (string * typ) list (* fake named type abs env *)
      * (string * typ) list (* type abs env *)
      * term; (* outer term *)

    (* What a search function is given: the theory used for unification, the
       maximal variable index of the goal, and the zipper whose focus is the
       term to search under. *)
    type searchinfo =
      theory
      * int (* maxidx *)
      * Zipper.T; (* focusterm to search under *)


    (* Result of skipping non-empty sub-sequences of a sequence of sequences.
       SkipSeq ss: ss starts at the first non-empty sub-sequence that was not
       skipped. SkipMore n: the outer sequence ran out before enough non-empty
       sub-sequences were seen; n is the occurrence number still wanted,
       counted afresh in whatever is searched next. *)
    datatype 'a skipseq =
      SkipMore of int |
      SkipSeq of 'a Seq.seq Seq.seq;

    (* skipto_skipseq m s: select the m-th non-empty sub-sequence of s
       (m <= 1 selects the first one). Empty sub-sequences are dropped and do
       not count. If s has fewer than m non-empty sub-sequences, the deficit
       is reported as SkipMore. *)
    fun skipto_skipseq m s =
      let
        fun skip_occs n sq =
          (case Seq.pull sq of
            NONE => SkipMore n
          | SOME (h, t) =>
            (case Seq.pull h of
              NONE => skip_occs n t
            | SOME _ => if n <= 1 then SkipSeq (Seq.cons h t) else skip_occs (n - 1) t))
      in skip_occs m s end;

    (* Build the "outer term" of a match. mkuptermfunc plugs a term back into
       the context of the focus t. The focus is replaced by
       "fooabs y_n ... y_1", the result is abstracted over fooabs, and the y's
       are the loose bound variables described by Ts (presumably the bounds of
       the context in scope at the focus). E.g. with no bound variables, a
       target "P lhs" becomes "%fooabs. P fooabs". The type of fooabs is the
       type of t generalised over the types in Ts. *)
    fun mk_foo_match mkuptermfunc Ts t =
      let
        val ty = Term.type_of t
        val bigtype = rev (map snd Ts) ---> ty
        fun mk_foo 0 t = t
          | mk_foo i t = mk_foo (i - 1) (t $ (Bound (i - 1)))
        val num_of_bnds = length Ts
        (* foo_term = "Bound n $ Bound (n-1) $ ... $ Bound 0" with n = |Ts|;
           Bound n refers to the fooabs abstraction built below *)
        val foo_term = mk_foo num_of_bnds (Bound num_of_bnds)
      in Abs ("fooabs", bigtype, mkuptermfunc foo_term) end;

    (* Name of the free variable standing in for a loose bound variable. *)
    fun mk_fake_bound_name n = ":b_" ^ n;

    (* Replace the loose bound variables of t by fresh free variables. Ts
       lists (name, type) for Bound 0, Bound 1, ... in that order. Returns
       (fake_Ts, Ts', t'): Ts' is Ts with names made distinct from each other,
       fake_Ts the corresponding ":b_"-prefixed free variables, and t' is t
       with Bound k replaced by the k-th element of fake_Ts. *)
    (* THINK: is order of Ts correct...? or reversed? *)
    fun fakefree_badbounds Ts t =
      let
        val (fake_Ts, Ts', _) =
          List.foldr (fn ((n, ty), (fakes, renamed, used)) =>
            let val n' = singleton (Name.variant_list used) n in
              ((mk_fake_bound_name n', ty) :: fakes, (n', ty) :: renamed, n' :: used)
            end) ([], [], []) Ts;
      in (fake_Ts, Ts', Term.subst_bounds (map Free fake_Ts, t)) end;

    (* Before matching, the bound variables of the focus that are missing an
       abstraction are faked as free variables. This additionally builds the
       abstraction environment and an outer context term (with the focus
       abstracted out) for use in rewriting with RWInst.rw. Returns
       (t', (fake_Ts, Ts', outer_term)). *)
    fun prep_zipper_match z =
      let
        val t = Zipper.trm z
        val c = Zipper.ctxt z
        val Ts = Zipper.C.nty_ctxt c
        val (FakeTs', Ts', t') = fakefree_badbounds Ts t
        val absterm = mk_foo_match (Zipper.C.apply c) Ts' t'
      in
        (t', (FakeTs', Ts', absterm))
      end;

    (* Unification with exceptions handled. Given the theory, the maximal
       variable index ix and a pair of terms, first unify their types; on a
       type clash return Seq.empty. Otherwise return the lazy sequence of
       (type instantiations, term instantiations) from Unify.smash_unifiers.
       ListPair.UnequalLengths and TERM raised by the unifier, either
       immediately or while the sequence is being forced, end the sequence.
       Note: clean_unify_z passes (subterm, lhs), so "pat" is the subterm. *)
    fun clean_unify thy ix (a as (pat, tgt)) =
      let
        (* type info will be re-derived, maybe this can be cached
           for efficiency? *)
        val pat_ty = Term.type_of pat;
        val tgt_ty = Term.type_of tgt;
        (* FIXME is it OK to ignore the type instantiation info?
           or should I be using it? *)
        val typs_unify =
          SOME (Sign.typ_unify thy (pat_ty, tgt_ty) (Vartab.empty, ix))
            handle Type.TUNIFY => NONE;
      in
        (case typs_unify of
          SOME (typinsttab, ix2) =>
            let
              (* FIXME is it right to throw away the flexes?
                 or should I be using them somehow? *)
              fun mk_insts env =
                (Vartab.dest (Envir.type_env env),
                 Vartab.dest (Envir.term_env env));
              val initenv =
                Envir.Envir {maxidx = ix2, tenv = Vartab.empty, tyenv = typinsttab};
              val useq = Unify.smash_unifiers thy [a] initenv
                handle ListPair.UnequalLengths => Seq.empty
                  | Term.TERM _ => Seq.empty;
              fun clean_unify' useq () =
                (case (Seq.pull useq) of
                   NONE => NONE
                 | SOME (h, t) => SOME (mk_insts h, Seq.make (clean_unify' t)))
                handle ListPair.UnequalLengths => NONE
                  | Term.TERM _ => NONE;
            in
              (Seq.make (clean_unify' useq))
            end
        | NONE => Seq.empty)
      end;

    (* Unification for zippers: unify the focus of z with pat and pair each
       resulting instantiation with the data from prep_zipper_match.
       Note: Ts is a modified version of the original names of the outer
       bound variables; new names have been introduced to make sure they are
       distinct from each other. *)
    fun clean_unify_z thy maxidx pat z =
      let val (t, (FakeTs, Ts, absterm)) = prep_zipper_match z in
        Seq.map (fn insts => (insts, FakeTs, Ts, absterm))
          (clean_unify thy maxidx (t, pat))
      end;


    (* Head of a term: strip applications (taking the function part) and
       abstractions (taking the body). *)
    fun bot_left_leaf_of (l $ _) = bot_left_leaf_of l
      | bot_left_leaf_of (Abs (_, _, t)) = bot_left_leaf_of t
      | bot_left_leaf_of x = x;

    (* Avoid considering replacing terms which have a var at the head as
       they always succeed trivially, and uninterestingly. *)
    fun valid_match_start z =
      (case bot_left_leaf_of (Zipper.trm z) of
        Var _ => false
      | _ => true);

    (* search all positions, via ZipperSearch.all_bl_ur *)
    val search_lr_all = ZipperSearch.all_bl_ur;

    (* Search positions accepted by validf. At an application: the function
       part, then the node itself, then the argument; at an abstraction: the
       node, then the body. *)
    fun search_lr_valid validf =
      let
        fun sf_valid_td_lr z =
          let val here = if validf z then [Zipper.Here z] else [] in
            (case Zipper.trm z of
              _ $ _ =>
                [Zipper.LookIn (Zipper.move_down_left z)] @ here @
                [Zipper.LookIn (Zipper.move_down_right z)]
            | Abs _ => here @ [Zipper.LookIn (Zipper.move_down_abs z)]
            | _ => here)
          end;
      in Zipper.lzy_search sf_valid_td_lr end;

    (* Search positions accepted by validf, bottom-up and left to right: the
       children of a node (function part, then argument; or the abstraction
       body) before the node itself. *)
    fun search_bt_valid validf =
      let
        fun sf_valid_bt_lr z =
          let val here = if validf z then [Zipper.Here z] else [] in
            (case Zipper.trm z of
              _ $ _ =>
                [Zipper.LookIn (Zipper.move_down_left z),
                 Zipper.LookIn (Zipper.move_down_right z)] @ here
            | Abs _ => [Zipper.LookIn (Zipper.move_down_abs z)] @ here
            | _ => here)
          end;
      in Zipper.lzy_search sf_valid_bt_lr end;

    (* For every position found by search strategy f under the focus, the
       sequence of unifiers of lhs with the term at that position. *)
    fun searchf_unify_gen f (thy, maxidx, z) lhs =
      Seq.map (clean_unify_z thy maxidx lhs) (Zipper.limit_apply f z);

    (* search all unifications *)
    val searchf_lr_unify_all = searchf_unify_gen search_lr_all;

    (* search only for 'valid' unifiers (positions without a var at the head) *)
    val searchf_lr_unify_valid = searchf_unify_gen (search_lr_valid valid_match_start);

    (* as above, but searching bottom-up *)
    val searchf_bt_unify_valid = searchf_unify_gen (search_bt_valid valid_match_start);

    (* Apply a substitution in the conclusion of the theorem.
       cfvs are certified free var placeholders for goal params;
       conclthm is a theorem for just the conclusion;
       m is instantiation/match information;
       rule is the equation for substitution. *)
    fun apply_subst_in_concl ctxt i st (cfvs, conclthm) rule m =
      RWInst.rw ctxt m rule conclthm
      |> IsaND.unfix_frees cfvs
      |> RWInst.beta_eta_contract
      |> (fn r => Tactic.rtac r i st);

    (* Common preparation for substituting in subgoal i of gth: increment the
       variable indexes of gth by 1 (the rule is assumed to have zero'd
       indexes), fix the parameters of subgoal i as free variables with
       IsaND.fix_alls_term, and return (thy, cert, maxidx, fixedbody, cfvs).
       cfvs are the certified fixed variables, in reverse order; fixedbody is
       presumably subgoal i with those parameters fixed. *)
    fun prep_fixed_subgoal ctxt i gth =
      let
        val th = Thm.incr_indexes 1 gth;
        val thy = Thm.theory_of_thm th;
        val cert = Thm.cterm_of thy;
        val (fixedbody, fvs) = IsaND.fix_alls_term ctxt i (Thm.prop_of th);
        val cfvs = rev (map cert fvs);
      in (thy, cert, Thm.maxidx_of th, fixedbody, cfvs) end;

    (* Substitute within the conclusion of goal i of gth, using a meta
       equation rule. Note that we assume rule has var indexes zero'd. *)
    fun prep_concl_subst ctxt i gth =
      let
        val (thy, cert, maxidx, fixedbody, cfvs) = prep_fixed_subgoal ctxt i gth;
        val conclthm = Thm.trivial (cert (Logic.strip_imp_concl fixedbody));
        val ft =
          (Zipper.move_down_right (* ==> *)
           o Zipper.move_down_left (* Trueprop *)
           o Zipper.mktop
           o Thm.prop_of) conclthm;
      in
        ((cfvs, conclthm), (thy, maxidx, ft))
      end;

    (* Rewrite in the conclusion of subgoal i of st, using each meta-equation
       obtained from instepthm by prep_meta_eq. searchf enumerates matches of
       the equation's lhs, and each match is applied by apply_subst_in_concl. *)
    fun eqsubst_tac' ctxt searchf instepthm i st =
      let
        val (cvfsconclthm, searchinfo) = prep_concl_subst ctxt i st;
        val stepthms = Seq.of_list (prep_meta_eq ctxt instepthm);
        fun rewrite_with_thm r =
          let val (lhs, _) = Logic.dest_equals (Thm.concl_of r) in
            searchf searchinfo lhs
            |> Seq.maps (apply_subst_in_concl ctxt i st cvfsconclthm r)
          end;
      in stepthms |> Seq.maps rewrite_with_thm end;


    (* General substitution of multiple occurrences using one of the given
       theorems: skip to the occ-th position that has a unifier and return
       all unifiers from there on (empty if there are fewer positions). *)
    fun skip_first_occs_search occ srchf sinfo lhs =
      (case skipto_skipseq occ (srchf sinfo lhs) of
        SkipMore _ => Seq.empty
      | SkipSeq ss => Seq.flat ss);

    (* Shared driver of eqsubst_tac and eqsubst_asm_tac. For each occurrence
       number in occs, largest first, try single_tac occ r with every theorem
       r of thms. The subgoal index is shifted by the number of subgoals added
       by earlier steps. Duplicate subgoals are removed at the end. Fails
       (Seq.empty) if subgoal i does not exist. *)
    fun gen_eqsubst_occs_tac single_tac occs thms i st =
      let val nprems = Thm.nprems_of st in
        if i < 1 orelse nprems < i then Seq.empty
        else
          let
            fun apply_occ occ st' =
              Seq.of_list thms |> Seq.maps (fn r =>
                single_tac occ r (i + (Thm.nprems_of st' - nprems)) st');
            val sorted_occs = Library.sort (rev_order o int_ord) occs;
          in
            Seq.maps distinct_subgoals_tac (Seq.EVERY (map apply_occ sorted_occs) st)
          end
      end;

    (* The "occs" argument is a list of integers indicating which occurrence,
       w.r.t. the search order, to rewrite. Backtracking will also find later
       occurrences, but all earlier ones are skipped. Thus you can use [0] to
       just find all rewrites. *)
    fun eqsubst_tac ctxt occs thms =
      gen_eqsubst_occs_tac
        (fn occ => eqsubst_tac' ctxt (skip_first_occs_search occ searchf_lr_unify_valid))
        occs thms;


    (* Apply a substitution inside assumption j, keeping the assumption in the
       same place. *)
    fun apply_subst_in_asm ctxt i st rule ((cfvs, j, _, pth), m) =
      let
        val st2 = Thm.rotate_rule (j - 1) i st; (* put premise first *)
        val preelimrule =
          RWInst.rw ctxt m rule pth
          |> (Seq.hd o prune_params_tac)
          |> Thm.permute_prems 0 ~1 (* put old asm first *)
          |> IsaND.unfix_frees cfvs (* unfix any global params *)
          |> RWInst.beta_eta_contract; (* normal form *)
      in
        (* ~j because new asm starts at back, thus we subtract 1 *)
        Seq.map (Thm.rotate_rule (~ j) (Thm.nprems_of rule + i))
          (Tactic.dtac preelimrule i st2)
      end;


    (* Prepare to substitute within the j-th premise of subgoal i of gth,
       using a meta-level equation. Note that we assume rule has var indexes
       zero'd. Note the repetition of work done for each assumption, i.e.
       this can be made more efficient for search over multiple
       assumptions. *)
    fun prep_subst_in_asm ctxt i gth j =
      let
        val (thy, cert, maxidx, fixedbody, cfvs) = prep_fixed_subgoal ctxt i gth;
        val asmt = nth (Logic.strip_imp_prems fixedbody) (j - 1);
        val asm_nprems = length (Logic.strip_imp_prems asmt);
        val pth = Thm.trivial (cert asmt);
        val ft =
          (Zipper.move_down_right (* trueprop *)
           o Zipper.mktop
           o Thm.prop_of) pth;
      in ((cfvs, j, asm_nprems, pth), (thy, maxidx, ft)) end;

    (* prepare subst in every assumption of subgoal i *)
    fun prep_subst_in_asms ctxt i gth =
      map (prep_subst_in_asm ctxt i gth)
        (1 upto length (Logic.prems_of_goal (Thm.prop_of gth) i));


    (* Substitute in an assumption of subgoal i using an object or meta level
       equality. Assumptions are searched in order; skipocc is the occurrence
       to start at, and occurrences not found in one assumption are carried
       over to the next. After a hit, later assumptions are searched from
       their first occurrence. *)
    fun eqsubst_asm_tac' ctxt searchf skipocc instepthm i st =
      let
        val asmpreps = prep_subst_in_asms ctxt i st;
        val stepthms = Seq.of_list (prep_meta_eq ctxt instepthm);
        fun rewrite_with_thm r =
          let
            val (lhs, _) = Logic.dest_equals (Thm.concl_of r);
            fun occ_search _ [] = Seq.empty
              | occ_search occ ((asminfo, searchinfo) :: moreasms) =
                  (case searchf searchinfo occ lhs of
                    SkipMore remaining => occ_search remaining moreasms
                  | SkipSeq ss =>
                      Seq.append (Seq.map (Library.pair asminfo) (Seq.flat ss))
                        (occ_search 1 moreasms)) (* find later substs also *)
          in
            occ_search skipocc asmpreps |> Seq.maps (apply_subst_in_asm ctxt i st r)
          end;
      in stepthms |> Seq.maps rewrite_with_thm end;


    (* Skip to the occ-th position that has a unifier; the deficit, if any,
       is reported as SkipMore. *)
    fun skip_first_asm_occs_search searchf sinfo occ lhs =
      skipto_skipseq occ (searchf sinfo lhs);

    (* As eqsubst_tac, but rewriting in the assumptions of subgoal i. *)
    fun eqsubst_asm_tac ctxt occs thms =
      gen_eqsubst_occs_tac
        (fn occ => eqsubst_asm_tac' ctxt (skip_first_asm_occs_search searchf_lr_unify_valid) occ)
        occs thms;

    (* Proof method "subst". The optional "(asm)" mode selects substitution in
       the assumptions (eqsubst_asm_tac) rather than the conclusion
       (eqsubst_tac). An optional parenthesised list of occurrence numbers
       defaults to [0]. These are followed by the theorems to rewrite with. *)
    val setup =
      Method.setup @{binding subst}
        (Args.mode "asm" -- Scan.lift (Scan.optional (Args.parens (Scan.repeat Parse.nat)) [0]) --
            Attrib.thms >>
          (fn ((asm, occs), inthms) => fn ctxt =>
            SIMPLE_METHOD' ((if asm then eqsubst_asm_tac else eqsubst_tac) ctxt occs inthms)))
        "single-step substitution";

    end;