src/Tools/eqsubst.ML
author wenzelm
Thu May 30 14:17:56 2013 +0200 (2013-05-30)
changeset 52235 6aff6b8bec13
parent 52234 6ffcce211047
child 52236 fb82b42eb498
permissions -rw-r--r--
tuned;
     1 (*  Title:      Tools/eqsubst.ML
     2     Author:     Lucas Dixon, University of Edinburgh
     3 
     4 Perform a substitution using an equation.
     5 *)
     6 
     7 signature EQSUBST =
     8 sig
     9   type match =
    10     ((indexname * (sort * typ)) list (* type instantiations *)
    11       * (indexname * (typ * term)) list) (* term instantiations *)
    12     * (string * typ) list (* fake named type abs env *)
    13     * (string * typ) list (* type abs env *)
    14     * term (* outer term *)
    15 
    16   type searchinfo =
    17     theory
    18     * int (* maxidx *)
    19     * Zipper.T (* focusterm to search under *)
    20 
    21   datatype 'a skipseq = SkipMore of int | SkipSeq of 'a Seq.seq Seq.seq
    22 
    23   val skip_first_asm_occs_search: ('a -> 'b -> 'c Seq.seq Seq.seq) -> 'a -> int -> 'b -> 'c skipseq
    24   val skip_first_occs_search: int -> ('a -> 'b -> 'c Seq.seq Seq.seq) -> 'a -> 'b -> 'c Seq.seq
    25   val skipto_skipseq: int -> 'a Seq.seq Seq.seq -> 'a skipseq
    26 
    27   (* tactics *)
    28   val eqsubst_asm_tac: Proof.context -> int list -> thm list -> int -> tactic
    29   val eqsubst_asm_tac': Proof.context ->
    30     (searchinfo -> int -> term -> match skipseq) -> int -> thm -> int -> tactic
    31   val eqsubst_tac: Proof.context ->
    32     int list -> (* list of occurences to rewrite, use [0] for any *)
    33     thm list -> int -> tactic
    34   val eqsubst_tac': Proof.context ->
    35     (searchinfo -> term -> match Seq.seq) (* search function *)
    36     -> thm (* equation theorem to rewrite with *)
    37     -> int (* subgoal number in goal theorem *)
    38     -> thm (* goal theorem *)
    39     -> thm Seq.seq (* rewritten goal theorem *)
    40 
    41   (* search for substitutions *)
    42   val valid_match_start: Zipper.T -> bool
    43   val search_lr_all: Zipper.T -> Zipper.T Seq.seq
    44   val search_lr_valid: (Zipper.T -> bool) -> Zipper.T -> Zipper.T Seq.seq
    45   val searchf_lr_unify_all: searchinfo -> term -> match Seq.seq Seq.seq
    46   val searchf_lr_unify_valid: searchinfo -> term -> match Seq.seq Seq.seq
    47   val searchf_bt_unify_valid: searchinfo -> term -> match Seq.seq Seq.seq
    48 
    49   val setup : theory -> theory
    50 end;
    51 
    52 structure EqSubst: EQSUBST =
    53 struct
    54 
    55 (* changes object "=" to meta "==" which prepares a given rewrite rule *)
    56 fun prep_meta_eq ctxt =
    57   Simplifier.mksimps ctxt #> map Drule.zero_var_indexes;
    58 
    59 
    60 type match =
    61   ((indexname * (sort * typ)) list (* type instantiations *)
    62    * (indexname * (typ * term)) list) (* term instantiations *)
    63   * (string * typ) list (* fake named type abs env *)
    64   * (string * typ) list (* type abs env *)
    65   * term; (* outer term *)
    66 
    67 type searchinfo =
    68   theory
    69   * int (* maxidx *)
    70   * Zipper.T; (* focusterm to search under *)
    71 
    72 
    73 (* skipping non-empty sub-sequences but when we reach the end
    74    of the seq, remembering how much we have left to skip. *)
    75 datatype 'a skipseq =
    76   SkipMore of int |
    77   SkipSeq of 'a Seq.seq Seq.seq;
    78 
    79 (* given a seqseq, skip the first m non-empty seq's, note deficit *)
    80 fun skipto_skipseq m s =
    81   let
    82     fun skip_occs n sq =
    83       (case Seq.pull sq of
    84         NONE => SkipMore n
    85       | SOME (h,t) =>
    86         (case Seq.pull h of
    87           NONE => skip_occs n t
    88         | SOME _ => if n <= 1 then SkipSeq (Seq.cons h t) else skip_occs (n - 1) t))
    89   in skip_occs m s end;
    90 
    91 (* note: outerterm is the taget with the match replaced by a bound
    92    variable : ie: "P lhs" beocmes "%x. P x"
    93    insts is the types of instantiations of vars in lhs
    94    and typinsts is the type instantiations of types in the lhs
    95    Note: Final rule is the rule lifted into the ontext of the
    96    taget thm. *)
    97 fun mk_foo_match mkuptermfunc Ts t =
    98   let
    99     val ty = Term.type_of t
   100     val bigtype = rev (map snd Ts) ---> ty
   101     fun mk_foo 0 t = t
   102       | mk_foo i t = mk_foo (i - 1) (t $ (Bound (i - 1)))
   103     val num_of_bnds = length Ts
   104     (* foo_term = "fooabs y0 ... yn" where y's are local bounds *)
   105     val foo_term = mk_foo num_of_bnds (Bound num_of_bnds)
   106   in Abs ("fooabs", bigtype, mkuptermfunc foo_term) end;
   107 
   108 (* T is outer bound vars, n is number of locally bound vars *)
   109 (* THINK: is order of Ts correct...? or reversed? *)
   110 fun mk_fake_bound_name n = ":b_" ^ n;
   111 fun fakefree_badbounds Ts t =
   112   let val (FakeTs, Ts, newnames) =
   113     List.foldr (fn ((n, ty), (FakeTs, Ts, usednames)) =>
   114       let
   115         val newname = singleton (Name.variant_list usednames) n
   116       in
   117         ((mk_fake_bound_name newname, ty) :: FakeTs,
   118           (newname, ty) :: Ts,
   119           newname :: usednames)
   120       end) ([], [], []) Ts
   121   in (FakeTs, Ts, Term.subst_bounds (map Free FakeTs, t)) end;
   122 
   123 (* before matching we need to fake the bound vars that are missing an
   124    abstraction. In this function we additionally construct the
   125    abstraction environment, and an outer context term (with the focus
   126    abstracted out) for use in rewriting with RWInst.rw *)
   127 fun prep_zipper_match z =
   128   let
   129     val t = Zipper.trm z
   130     val c = Zipper.ctxt z
   131     val Ts = Zipper.C.nty_ctxt c
   132     val (FakeTs', Ts', t') = fakefree_badbounds Ts t
   133     val absterm = mk_foo_match (Zipper.C.apply c) Ts' t'
   134   in
   135     (t', (FakeTs', Ts', absterm))
   136   end;
   137 
   138 (* Unification with exception handled *)
   139 (* given theory, max var index, pat, tgt; returns Seq of instantiations *)
   140 fun clean_unify thy ix (a as (pat, tgt)) =
   141   let
   142     (* type info will be re-derived, maybe this can be cached
   143        for efficiency? *)
   144     val pat_ty = Term.type_of pat;
   145     val tgt_ty = Term.type_of tgt;
   146     (* FIXME is it OK to ignore the type instantiation info?
   147        or should I be using it? *)
   148     val typs_unify =
   149       SOME (Sign.typ_unify thy (pat_ty, tgt_ty) (Vartab.empty, ix))
   150         handle Type.TUNIFY => NONE;
   151   in
   152     (case typs_unify of
   153       SOME (typinsttab, ix2) =>
   154         let
   155           (* FIXME is it right to throw away the flexes?
   156              or should I be using them somehow? *)
   157           fun mk_insts env =
   158             (Vartab.dest (Envir.type_env env),
   159              Vartab.dest (Envir.term_env env));
   160           val initenv =
   161             Envir.Envir {maxidx = ix2, tenv = Vartab.empty, tyenv = typinsttab};
   162           val useq = Unify.smash_unifiers thy [a] initenv
   163             handle ListPair.UnequalLengths => Seq.empty
   164               | Term.TERM _ => Seq.empty;
   165           fun clean_unify' useq () =
   166             (case (Seq.pull useq) of
   167                NONE => NONE
   168              | SOME (h, t) => SOME (mk_insts h, Seq.make (clean_unify' t)))
   169             handle ListPair.UnequalLengths => NONE
   170               | Term.TERM _ => NONE;
   171         in
   172           (Seq.make (clean_unify' useq))
   173         end
   174     | NONE => Seq.empty)
   175   end;
   176 
   177 (* Unification for zippers *)
   178 (* Note: Ts is a modified version of the original names of the outer
   179    bound variables. New names have been introduced to make sure they are
   180    unique w.r.t all names in the term and each other. usednames' is
   181    oldnames + new names. *)
   182 fun clean_unify_z thy maxidx pat z =
   183   let val (t, (FakeTs, Ts, absterm)) = prep_zipper_match z in
   184     Seq.map (fn insts => (insts, FakeTs, Ts, absterm))
   185       (clean_unify thy maxidx (t, pat))
   186   end;
   187 
   188 
   189 fun bot_left_leaf_of (l $ _) = bot_left_leaf_of l
   190   | bot_left_leaf_of (Abs (_, _, t)) = bot_left_leaf_of t
   191   | bot_left_leaf_of x = x;
   192 
   193 (* Avoid considering replacing terms which have a var at the head as
   194    they always succeed trivially, and uninterestingly. *)
   195 fun valid_match_start z =
   196   (case bot_left_leaf_of (Zipper.trm z) of
   197     Var _ => false
   198   | _ => true);
   199 
   200 (* search from top, left to right, then down *)
   201 val search_lr_all = ZipperSearch.all_bl_ur;
   202 
   203 (* search from top, left to right, then down *)
   204 fun search_lr_valid validf =
   205   let
   206     fun sf_valid_td_lr z =
   207       let val here = if validf z then [Zipper.Here z] else [] in
   208         (case Zipper.trm z of
   209           _ $ _ =>
   210             [Zipper.LookIn (Zipper.move_down_left z)] @ here @
   211             [Zipper.LookIn (Zipper.move_down_right z)]
   212         | Abs _ => here @ [Zipper.LookIn (Zipper.move_down_abs z)]
   213         | _ => here)
   214       end;
   215   in Zipper.lzy_search sf_valid_td_lr end;
   216 
   217 (* search from bottom to top, left to right *)
   218 fun search_bt_valid validf =
   219   let
   220     fun sf_valid_td_lr z =
   221       let val here = if validf z then [Zipper.Here z] else [] in
   222         (case Zipper.trm z of
   223           _ $ _ =>
   224             [Zipper.LookIn (Zipper.move_down_left z),
   225              Zipper.LookIn (Zipper.move_down_right z)] @ here
   226         | Abs _ => [Zipper.LookIn (Zipper.move_down_abs z)] @ here
   227         | _ => here)
   228       end;
   229   in Zipper.lzy_search sf_valid_td_lr end;
   230 
   231 fun searchf_unify_gen f (thy, maxidx, z) lhs =
   232   Seq.map (clean_unify_z thy maxidx lhs) (Zipper.limit_apply f z);
   233 
   234 (* search all unifications *)
   235 val searchf_lr_unify_all = searchf_unify_gen search_lr_all;
   236 
   237 (* search only for 'valid' unifiers (non abs subterms and non vars) *)
   238 val searchf_lr_unify_valid = searchf_unify_gen (search_lr_valid valid_match_start);
   239 
   240 val searchf_bt_unify_valid = searchf_unify_gen (search_bt_valid valid_match_start);
   241 
   242 (* apply a substitution in the conclusion of the theorem th *)
   243 (* cfvs are certified free var placeholders for goal params *)
   244 (* conclthm is a theorem of for just the conclusion *)
   245 (* m is instantiation/match information *)
   246 (* rule is the equation for substitution *)
   247 fun apply_subst_in_concl ctxt i th (cfvs, conclthm) rule m =
   248   RWInst.rw ctxt m rule conclthm
   249   |> IsaND.unfix_frees cfvs
   250   |> RWInst.beta_eta_contract
   251   |> (fn r => Tactic.rtac r i th);
   252 
   253 (* substitute within the conclusion of goal i of gth, using a meta
   254 equation rule. Note that we assume rule has var indicies zero'd *)
   255 fun prep_concl_subst ctxt i gth =
   256   let
   257     val th = Thm.incr_indexes 1 gth;
   258     val tgt_term = Thm.prop_of th;
   259 
   260     val thy = Thm.theory_of_thm th;
   261     val cert = Thm.cterm_of thy;
   262 
   263     val (fixedbody, fvs) = IsaND.fix_alls_term ctxt i tgt_term;
   264     val cfvs = rev (map cert fvs);
   265 
   266     val conclterm = Logic.strip_imp_concl fixedbody;
   267     val conclthm = Thm.trivial (cert conclterm);
   268     val maxidx = Thm.maxidx_of th;
   269     val ft =
   270       (Zipper.move_down_right (* ==> *)
   271        o Zipper.move_down_left (* Trueprop *)
   272        o Zipper.mktop
   273        o Thm.prop_of) conclthm
   274   in
   275     ((cfvs, conclthm), (thy, maxidx, ft))
   276   end;
   277 
   278 (* substitute using an object or meta level equality *)
   279 fun eqsubst_tac' ctxt searchf instepthm i th =
   280   let
   281     val (cvfsconclthm, searchinfo) = prep_concl_subst ctxt i th;
   282     val stepthms = Seq.of_list (prep_meta_eq ctxt instepthm);
   283     fun rewrite_with_thm r =
   284       let val (lhs,_) = Logic.dest_equals (Thm.concl_of r) in
   285         searchf searchinfo lhs
   286         |> Seq.maps (apply_subst_in_concl ctxt i th cvfsconclthm r)
   287       end;
   288   in stepthms |> Seq.maps rewrite_with_thm end;
   289 
   290 
   291 (* distinct subgoals *)
   292 fun distinct_subgoals th = the_default th (SINGLE distinct_subgoals_tac th);
   293 
   294 
   295 (* General substitution of multiple occurances using one of
   296    the given theorems *)
   297 
   298 fun skip_first_occs_search occ srchf sinfo lhs =
   299   (case (skipto_skipseq occ (srchf sinfo lhs)) of
   300     SkipMore _ => Seq.empty
   301   | SkipSeq ss => Seq.flat ss);
   302 
   303 (* The occL is a list of integers indicating which occurence
   304 w.r.t. the search order, to rewrite. Backtracking will also find later
   305 occurences, but all earlier ones are skipped. Thus you can use [0] to
   306 just find all rewrites. *)
   307 
   308 fun eqsubst_tac ctxt occL thms i th =
   309   let val nprems = Thm.nprems_of th in
   310     if nprems < i then Seq.empty else
   311     let
   312       val thmseq = (Seq.of_list thms);
   313       fun apply_occ occ th =
   314         thmseq |> Seq.maps (fn r =>
   315           eqsubst_tac' ctxt
   316             (skip_first_occs_search occ searchf_lr_unify_valid) r
   317             (i + (Thm.nprems_of th - nprems)) th);
   318       val sortedoccL = Library.sort (rev_order o int_ord) occL;
   319     in
   320       Seq.map distinct_subgoals (Seq.EVERY (map apply_occ sortedoccL) th)
   321     end
   322   end;
   323 
   324 
   325 (* inthms are the given arguments in Isar, and treated as eqstep with
   326    the first one, then the second etc *)
   327 fun eqsubst_meth ctxt occL inthms = SIMPLE_METHOD' (eqsubst_tac ctxt occL inthms);
   328 
   329 (* apply a substitution inside assumption j, keeps asm in the same place *)
   330 fun apply_subst_in_asm ctxt i th rule ((cfvs, j, _, pth),m) =
   331   let
   332     val th2 = Thm.rotate_rule (j - 1) i th; (* put premice first *)
   333     val preelimrule =
   334       RWInst.rw ctxt m rule pth
   335       |> (Seq.hd o prune_params_tac)
   336       |> Thm.permute_prems 0 ~1 (* put old asm first *)
   337       |> IsaND.unfix_frees cfvs (* unfix any global params *)
   338       |> RWInst.beta_eta_contract; (* normal form *)
   339   in
   340     (* ~j because new asm starts at back, thus we subtract 1 *)
   341     Seq.map (Thm.rotate_rule (~ j) (Thm.nprems_of rule + i))
   342       (Tactic.dtac preelimrule i th2)
   343   end;
   344 
   345 
   346 (* prepare to substitute within the j'th premise of subgoal i of gth,
   347 using a meta-level equation. Note that we assume rule has var indicies
   348 zero'd. Note that we also assume that premt is the j'th premice of
   349 subgoal i of gth. Note the repetition of work done for each
   350 assumption, i.e. this can be made more efficient for search over
   351 multiple assumptions.  *)
   352 fun prep_subst_in_asm ctxt i gth j =
   353   let
   354     val th = Thm.incr_indexes 1 gth;
   355     val tgt_term = Thm.prop_of th;
   356 
   357     val thy = Thm.theory_of_thm th;
   358     val cert = Thm.cterm_of thy;
   359 
   360     val (fixedbody, fvs) = IsaND.fix_alls_term ctxt i tgt_term;
   361     val cfvs = rev (map cert fvs);
   362 
   363     val asmt = nth (Logic.strip_imp_prems fixedbody) (j - 1);
   364     val asm_nprems = length (Logic.strip_imp_prems asmt);
   365 
   366     val pth = Thm.trivial (cert asmt);
   367     val maxidx = Thm.maxidx_of th;
   368 
   369     val ft =
   370       (Zipper.move_down_right (* trueprop *)
   371          o Zipper.mktop
   372          o Thm.prop_of) pth
   373   in ((cfvs, j, asm_nprems, pth), (thy, maxidx, ft)) end;
   374 
   375 (* prepare subst in every possible assumption *)
   376 fun prep_subst_in_asms ctxt i gth =
   377   map (prep_subst_in_asm ctxt i gth)
   378     ((fn l => Library.upto (1, length l))
   379       (Logic.prems_of_goal (Thm.prop_of gth) i));
   380 
   381 
   382 (* substitute in an assumption using an object or meta level equality *)
   383 fun eqsubst_asm_tac' ctxt searchf skipocc instepthm i th =
   384   let
   385     val asmpreps = prep_subst_in_asms ctxt i th;
   386     val stepthms = Seq.of_list (prep_meta_eq ctxt instepthm);
   387     fun rewrite_with_thm r =
   388       let
   389         val (lhs,_) = Logic.dest_equals (Thm.concl_of r);
   390         fun occ_search occ [] = Seq.empty
   391           | occ_search occ ((asminfo, searchinfo)::moreasms) =
   392               (case searchf searchinfo occ lhs of
   393                 SkipMore i => occ_search i moreasms
   394               | SkipSeq ss =>
   395                   Seq.append (Seq.map (Library.pair asminfo) (Seq.flat ss))
   396                     (occ_search 1 moreasms)) (* find later substs also *)
   397       in
   398         occ_search skipocc asmpreps |> Seq.maps (apply_subst_in_asm ctxt i th r)
   399       end;
   400   in stepthms |> Seq.maps rewrite_with_thm end;
   401 
   402 
   403 fun skip_first_asm_occs_search searchf sinfo occ lhs =
   404   skipto_skipseq occ (searchf sinfo lhs);
   405 
   406 fun eqsubst_asm_tac ctxt occL thms i th =
   407   let val nprems = Thm.nprems_of th in
   408     if nprems < i then Seq.empty
   409     else
   410       let
   411         val thmseq = Seq.of_list thms;
   412         fun apply_occ occK th =
   413           thmseq |> Seq.maps (fn r =>
   414             eqsubst_asm_tac' ctxt
   415               (skip_first_asm_occs_search searchf_lr_unify_valid) occK r
   416               (i + (Thm.nprems_of th - nprems)) th);
   417         val sortedoccs = Library.sort (rev_order o int_ord) occL;
   418       in
   419         Seq.map distinct_subgoals (Seq.EVERY (map apply_occ sortedoccs) th)
   420       end
   421   end;
   422 
   423 (* inthms are the given arguments in Isar, and treated as eqstep with
   424    the first one, then the second etc *)
   425 fun eqsubst_asm_meth ctxt occL inthms =
   426   SIMPLE_METHOD' (eqsubst_asm_tac ctxt occL inthms);
   427 
   428 (* combination method that takes a flag (true indicates that subst
   429    should be done to an assumption, false = apply to the conclusion of
   430    the goal) as well as the theorems to use *)
   431 val setup =
   432   Method.setup @{binding subst}
   433     (Args.mode "asm" -- Scan.lift (Scan.optional (Args.parens (Scan.repeat Parse.nat)) [0]) --
   434         Attrib.thms >>
   435       (fn ((asm, occL), inthms) => fn ctxt =>
   436         (if asm then eqsubst_asm_meth else eqsubst_meth) ctxt occL inthms))
   437     "single-step substitution";
   438 
   439 end;