src/Tools/eqsubst.ML
author wenzelm
Thu May 30 13:59:38 2013 +0200 (2013-05-30)
changeset 52234 6ffcce211047
parent 52223 5bb6ae8acb87
child 52235 6aff6b8bec13
permissions -rw-r--r--
misc tuning;
     1 (*  Title:      Tools/eqsubst.ML
     2     Author:     Lucas Dixon, University of Edinburgh
     3 
     4 A proof method to perform a substitution using an equation.
     5 *)
     6 
     signature EQSUBST =
     sig
       (* A unifier found during search:
            - the type and term instantiations,
            - the fake Frees standing in for loose bound variables,
            - the renamed outer binder context,
            - the outer term with the focus abstracted out. *)
       type match =
         ((indexname * (sort * typ)) list       (* type instantiations *)
           * (indexname * (typ * term)) list)   (* term instantiations *)
         * (string * typ) list                  (* fake named type abs env *)
         * (string * typ) list                  (* type abs env *)
         * term                                 (* outer term *)

       (* Search context: theory, maximum variable index, focus zipper. *)
       type searchinfo =
         theory
         * int        (* maxidx *)
         * Zipper.T   (* focus term to search under *)

       (* Outcome of skipping occurrences: either the count still to be
          skipped, or the remaining sequence of match groups. *)
       datatype 'a skipseq = SkipMore of int | SkipSeq of 'a Seq.seq Seq.seq

       (* occurrence skipping *)
       val skip_first_asm_occs_search:
         ('a -> 'b -> 'c Seq.seq Seq.seq) -> 'a -> int -> 'b -> 'c skipseq
       val skip_first_occs_search:
         int -> ('a -> 'b -> 'c Seq.seq Seq.seq) -> 'a -> 'b -> 'c Seq.seq
       val skipto_skipseq: int -> 'a Seq.seq Seq.seq -> 'a skipseq

       (* tactics on assumptions *)
       val eqsubst_asm_tac: Proof.context -> int list -> thm list -> int -> tactic
       val eqsubst_asm_tac':
         Proof.context ->
         (searchinfo -> int -> term -> match skipseq) ->
         int -> thm -> int -> tactic

       (* tactics on the conclusion *)
       val eqsubst_tac:
         Proof.context ->
         int list (* occurrences to rewrite; use [0] for any *) ->
         thm list -> int -> tactic
       val eqsubst_tac':
         Proof.context ->
         (searchinfo -> term -> match Seq.seq) (* search function *) ->
         thm (* equation theorem to rewrite with *) ->
         int (* subgoal number in goal theorem *) ->
         thm (* goal theorem *) ->
         thm Seq.seq (* rewritten goal theorems *)

       (* searching for substitution sites *)
       val valid_match_start: Zipper.T -> bool
       val search_lr_all: Zipper.T -> Zipper.T Seq.seq
       val search_lr_valid: (Zipper.T -> bool) -> Zipper.T -> Zipper.T Seq.seq
       val searchf_lr_unify_all: searchinfo -> term -> match Seq.seq Seq.seq
       val searchf_lr_unify_valid: searchinfo -> term -> match Seq.seq Seq.seq
       val searchf_bt_unify_valid: searchinfo -> term -> match Seq.seq Seq.seq

       val setup: theory -> theory
     end;
    51 
    52 structure EqSubst: EQSUBST =
    53 struct
    54 
    (* Prepare a rewrite rule: let the simplifier's mksimps turn it into
       meta-level equations ("==", e.g. from an object-level "="), then
       renumber the schematic variables of each result from index 0. *)
    fun prep_meta_eq ctxt thm =
      map Drule.zero_var_indexes (Simplifier.mksimps ctxt thm);
    58 
    59 
    (* A unifier found during search, together with the data needed to
       rewrite at that position. *)
    type match =
      ((indexname * (sort * typ)) list        (* type instantiations *)
        * (indexname * (typ * term)) list)    (* term instantiations *)
      * (string * typ) list                   (* fake named type abs env *)
      * (string * typ) list                   (* type abs env *)
      * term;                                 (* outer term *)

    (* What a search function receives. *)
    type searchinfo =
      theory
      * int         (* maxidx *)
      * Zipper.T;   (* focus term to search under *)

    (* Result of skipping over a sequence of match groups, where empty groups
       are never counted.  SkipMore k: the sequence ran out first, and k is
       the residual count with which to continue searching elsewhere.
       SkipSeq ss: the groups from the selected non-empty one onwards. *)
    datatype 'a skipseq =
        SkipMore of int
      | SkipSeq of 'a Seq.seq Seq.seq;
    78 
    (* Given a sequence of match groups, return SkipSeq of the groups starting
       at the m-th non-empty group (any m <= 1 selects the first non-empty
       group).  Empty groups met along the way are dropped.  If the sequence
       ends first, return SkipMore (m - number of non-empty groups seen). *)
    fun skipto_skipseq m s =
      let
        fun go remaining groups =
          (case Seq.pull groups of
            NONE => SkipMore remaining
          | SOME (grp, rest) =>
              (case Seq.pull grp of
                NONE => go remaining rest
              | SOME _ =>
                  if remaining > 1 then go (remaining - 1) rest
                  else SkipSeq (Seq.cons grp rest)))
      in go m s end;
    90 
    (* Build the outer term of a match.  With n = length Ts, the result is an
       abstraction named "fooabs" whose body is mkuptermfunc applied to
         Bound n $ Bound (n - 1) $ ... $ Bound 0,
       i.e. "fooabs" applied to the n local bounds.  The type of "fooabs" is
       (rev (map snd Ts)) ---> type_of t.
       NOTE(review): this assumes mkuptermfunc rebuilds the surrounding
       context, with the n binders of Ts, around its argument.  Under that
       assumption, "P lhs" with no local binders becomes
       "%fooabs. P fooabs" -- confirm against Zipper. *)
    fun mk_foo_match mkuptermfunc Ts t =
      let
        val bigtype = rev (map snd Ts) ---> Term.type_of t;
        val n = length Ts;
        fun apply_bounds 0 acc = acc
          | apply_bounds k acc = apply_bounds (k - 1) (acc $ Bound (k - 1));
      in Abs ("fooabs", bigtype, mkuptermfunc (apply_bounds n (Bound n))) end;
   107 
   (* Name of the Free that stands in for a bound variable called n. *)
   fun mk_fake_bound_name n = ":b_" ^ n;

   (* Replace the loose bound variables of t by Frees.
        Ts: (name, type) pairs of the binders outside t.
      The list is folded from its end towards its front.  Each name is made
      distinct from the names already produced (Name.variant_list), and its
      Free is named mk_fake_bound_name of the distinct name.
      Returns (fake Frees, renamed binder context, t with the Frees
      substituted via Term.subst_bounds).
      NOTE(review): the original author asked whether the order of Ts is
      right or should be reversed -- still unconfirmed. *)
   fun fakefree_badbounds Ts t =
     let
       fun step ((name, ty), (fakes, renamed, used)) =
         let val name' = singleton (Name.variant_list used) name in
           ((mk_fake_bound_name name', ty) :: fakes,
            (name', ty) :: renamed,
            name' :: used)
         end;
       val (fakes, renamed, _) = List.foldr step ([], [], []) Ts;
     in (fakes, renamed, Term.subst_bounds (map Free fakes, t)) end;
   122 
   (* Prepare the focus of zipper z for matching.  Bound variables of the
      focus whose abstractions lie outside it, in the zipper context, are
      replaced by fake Frees.  An outer context term is also built with the
      focus abstracted out, for use in rewriting with RWInst.rw.
      Returns (focus with fake Frees, (fake Frees, renamed binders, outer term)). *)
   fun prep_zipper_match z =
     let
       val zctxt = Zipper.ctxt z;
       val (fakes, binders, focus) =
         fakefree_badbounds (Zipper.C.nty_ctxt zctxt) (Zipper.trm z);
       val outer = mk_foo_match (Zipper.C.apply zctxt) binders focus;
     in
       (focus, (fakes, binders, outer))
     end;
   137 
   (* Unify a pair of terms: first their types, then the terms themselves.
        thry:     theory used for type and term unification;
        ix:       current maximum variable index;
        (t1, t2): the pair to unify (clean_unify_z passes the focus term
                  first and the pattern second).
      Returns a lazy sequence of (type instantiations, term instantiations),
      read off each unifier's type and term environments.  Nothing else from
      the environment (e.g. flex-flex pairs) is kept.
      A type-unification failure gives an empty sequence.  ListPair.UnequalLengths
      or TERM raised by the unifier ends the sequence instead of propagating.
      NOTE(review): the type instantiation found here is only used to seed the
      term unification environment; whether it should be used otherwise is an
      open question inherited from the original author. *)
   fun clean_unify thry ix (a as (t1, t2)) =
     let
       fun insts_of env =
         (Vartab.dest (Envir.type_env env), Vartab.dest (Envir.term_env env));

       (* lazily convert unifiers; a failure while pulling ends the sequence *)
       fun convert unifiers () =
         (case Seq.pull unifiers of
           SOME (env, more) => SOME (insts_of env, Seq.make (convert more))
         | NONE => NONE)
         handle ListPair.UnequalLengths => NONE
           | Term.TERM _ => NONE;

       fun unify_terms (tyenv, maxidx) =
         let
           val env0 =
             Envir.Envir {maxidx = maxidx, tenv = Vartab.empty, tyenv = tyenv};
           val unifiers =
             Unify.smash_unifiers thry [a] env0
               handle ListPair.UnequalLengths => Seq.empty
                 | Term.TERM _ => Seq.empty;
         in Seq.make (convert unifiers) end;

       val typ_result =
         SOME (Sign.typ_unify thry (Term.type_of t1, Term.type_of t2) (Vartab.empty, ix))
           handle Type.TUNIFY => NONE;
     in
       (case typ_result of
         NONE => Seq.empty
       | SOME res => unify_terms res)
     end;
   176 
   (* Unify the pattern pat against the focus of zipper z.
        sgn: theory;  ix: maximum variable index.
      Each instantiation is returned together with the fake Frees, the
      renamed outer binders and the outer context term produced by
      prep_zipper_match. *)
   fun clean_unify_z sgn ix pat z =
     let
       val (focus, (fakes, binders, outer)) = prep_zipper_match z;
       fun attach insts = (insts, fakes, binders, outer);
     in Seq.map attach (clean_unify sgn ix (focus, pat)) end;
   188 
   189 
   (* Leftmost leaf of a term: descend into the function part of applications
      and into the bodies of abstractions until neither applies. *)
   fun bot_left_leaf_of t =
     (case t of
       f $ _ => bot_left_leaf_of f
     | Abs (_, _, body) => bot_left_leaf_of body
     | leaf => leaf);
   193 
   (* A focus is worth matching against unless its leftmost leaf is a
      schematic variable: such terms unify trivially and uninterestingly. *)
   fun valid_match_start z =
     let val leaf = bot_left_leaf_of (Zipper.trm z) in
       (case leaf of
         Var _ => false
       | _ => true)
     end;
   200 
   (* Enumerate every position under a zipper via ZipperSearch.all_bl_ur.
      NOTE(review): the order is whatever all_bl_ur implements (the name
      suggests bottom-left to upper-right) -- confirm. *)
   fun search_lr_all z = ZipperSearch.all_bl_ur z;
   203 
   (* Lazily search positions under a zipper, keeping those satisfying
      validf.  For an application the expansion lists the function part, then
      the application itself, then the argument part.  An abstraction is
      listed before its body. *)
   fun search_lr_valid validf =
     let
       fun expand z =
         let
           val here = if validf z then [Zipper.Here z] else [];
           fun look move = Zipper.LookIn (move z);
         in
           (case Zipper.trm z of
             _ $ _ => look Zipper.move_down_left :: here @ [look Zipper.move_down_right]
           | Abs _ => here @ [look Zipper.move_down_abs]
           | _ => here)
         end;
     in Zipper.lzy_search expand end;
   217 
   (* Lazily search positions under a zipper, keeping those satisfying
      validf, with each position listed after its subterms.  An application
      lists its function part, then its argument part, then itself.  An
      abstraction lists its body, then itself. *)
   fun search_bt_valid validf =
     let
       fun expand z =
         let
           val here = if validf z then [Zipper.Here z] else [];
           fun look move = Zipper.LookIn (move z);
         in
           (case Zipper.trm z of
             _ $ _ => look Zipper.move_down_left :: look Zipper.move_down_right :: here
           | Abs _ => look Zipper.move_down_abs :: here
           | _ => here)
         end;
     in Zipper.lzy_search expand end;
   231 
   (* Generic unification search: enumerate positions under the focus zipper
      with the search f (via Zipper.limit_apply).  For each position, yield
      the lazy sequence of unifiers of lhs with the subterm there. *)
   fun searchf_unify_gen f (sgn, maxidx, z) lhs =
     Zipper.limit_apply f z
     |> Seq.map (clean_unify_z sgn maxidx lhs);

   (* every position *)
   val searchf_lr_unify_all = searchf_unify_gen search_lr_all;

   (* only positions whose leftmost leaf is not a schematic variable,
      searched with search_lr_valid *)
   val searchf_lr_unify_valid = searchf_unify_gen (search_lr_valid valid_match_start);

   (* same filter, searched with search_bt_valid (subterms first) *)
   val searchf_bt_unify_valid = searchf_unify_gen (search_bt_valid valid_match_start);
   242 
   (* Rewrite the conclusion with rule at match m, then resolve the result
      (rtac) against subgoal i of th.
        cfvs:     certified Frees that stand for the goal's parameters; they
                  are released again with IsaND.unfix_frees;
        conclthm: trivial theorem for the conclusion alone. *)
   fun apply_subst_in_concl ctxt i th (cfvs, conclthm) rule m =
     let
       val rewritten = RWInst.rw ctxt m rule conclthm;
       val unfixed = IsaND.unfix_frees cfvs rewritten;
       val r = RWInst.beta_eta_contract unfixed;
     in Tactic.rtac r i th end;
   253 
   (* Prepare substitution within the conclusion of subgoal i of gth using a
      meta-equation whose variable indexes are assumed to be zeroed.  gth's
      indexes are raised by 1 (presumably to keep clear of the rule's).  The
      subgoal's parameters are fixed as Frees, certified and reversed to give
      cfvs.  The conclusion becomes the trivial theorem conclthm.
      Returns ((cfvs, conclthm), (theory, maxidx, focus zipper)). *)
   fun prep_concl_subst ctxt i gth =
     let
       val th = Thm.incr_indexes 1 gth;
       val thy = Thm.theory_of_thm th;
       val cert = Thm.cterm_of thy;

       val (fixedbody, fvs) = IsaND.fix_alls_term ctxt i (Thm.prop_of th);
       val cfvs = rev (map cert fvs);

       val conclthm = Thm.trivial (cert (Logic.strip_imp_concl fixedbody));
       val maxidx = Thm.maxidx_of th;

       (* from the top of conclthm's proposition step down-left, then
          down-right (original notes: past "==>", onto Trueprop) *)
       val ft =
         Thm.prop_of conclthm
         |> Zipper.mktop
         |> Zipper.move_down_left
         |> Zipper.move_down_right;
     in
       ((cfvs, conclthm), (thy, maxidx, ft))
     end;
   279 
   (* Substitute in the conclusion of subgoal i of th using an object- or
      meta-level equation instepthm.  The equation is turned into
      meta-equations by prep_meta_eq.  For each of them, searchf proposes
      matches of its left-hand side, and every match is applied. *)
   fun eqsubst_tac' ctxt searchf instepthm i th =
     let
       val (concl_info, searchinfo) = prep_concl_subst ctxt i th;
       fun rewrite_with_thm r =
         let
           val (lhs, _) = Logic.dest_equals (Thm.concl_of r);
           val matches = searchf searchinfo lhs;
         in Seq.maps (apply_subst_in_concl ctxt i th concl_info r) matches end;
       val rules = Seq.of_list (prep_meta_eq ctxt instepthm);
     in Seq.maps rewrite_with_thm rules end;
   291 
   292 
   (* Remove duplicate subgoals of th with distinct_subgoals_tac, or return
      th unchanged if that tactic yields no result. *)
   fun distinct_subgoals th =
     (case SINGLE distinct_subgoals_tac th of
       SOME th' => th'
     | NONE => th);
   295 
   296 
   (* General substitution of multiple occurrences using one of the given
      theorems. *)

   (* Run srchf and keep the match groups from the occ-th non-empty one
      onwards (see skipto_skipseq), flattened into a single sequence.  The
      result is empty if there is no such group. *)
   fun skip_first_occs_search occ srchf sinfo lhs =
     (case skipto_skipseq occ (srchf sinfo lhs) of
       SkipSeq groups => Seq.flat groups
     | SkipMore _ => Seq.empty);
   304 
   (* Substitute in the conclusion of subgoal i using the equations thms.
      occL lists which occurrences, in search order, to rewrite.
      Backtracking also finds later occurrences, but all earlier ones are
      skipped; thus [0] just finds all rewrites.
      Occurrences are processed from the largest index to the smallest.  Each
      rewrite applies to the previous result, with the subgoal number
      adjusted for any premises added in between.  Duplicate subgoals are
      removed from each result.
      Like other tactics, this fails (empty sequence) when i is not a valid
      subgoal number of th, including i < 1. *)
   fun eqsubst_tac ctxt occL thms i th =
     let val nprems = Thm.nprems_of th in
       if i < 1 orelse nprems < i then Seq.empty
       else
         let
           val thmseq = Seq.of_list thms;
           fun apply_occ occ th =
             thmseq |> Seq.maps (fn r =>
               eqsubst_tac' ctxt
                 (skip_first_occs_search occ searchf_lr_unify_valid) r
                 (i + (Thm.nprems_of th - nprems)) th);
           (* largest occurrence first -- presumably so that rewriting one
              occurrence does not renumber those still to be rewritten *)
           val sortedoccL = Library.sort (rev_order o int_ord) occL;
         in
           Seq.map distinct_subgoals (Seq.EVERY (map apply_occ sortedoccL) th)
         end
     end;
   325 
   326 
   (* Proof method for substitution in the conclusion.  The Isar arguments
      inthms are used as equation steps: the first, then the second, etc. *)
   fun eqsubst_meth ctxt occL inthms =
     SIMPLE_METHOD' (fn i => eqsubst_tac ctxt occL inthms i);
   330 
   (* Apply a substitution inside assumption j of subgoal i of th, keeping the
      assumption in the same place.
        cfvs: certified Frees for the goal's parameters;
        pth:  trivial theorem for the assumption;
        rule: the meta-equation;  m: the match to rewrite at. *)
   fun apply_subst_in_asm ctxt i th rule ((cfvs, j, _, pth), m) =
     let
       (* bring premise j of subgoal i to the front *)
       val th2 = Thm.rotate_rule (j - 1) i th;
       val rewritten = RWInst.rw ctxt m rule pth;
       val pruned = Seq.hd (prune_params_tac rewritten);
       val old_asm_first = Thm.permute_prems 0 ~1 pruned;
       val unfixed = IsaND.unfix_frees cfvs old_asm_first;  (* release params *)
       val preelimrule = RWInst.beta_eta_contract unfixed;  (* normal form *)
       (* the new assumption is added at the back; rotating by ~j is meant to
          move it back to position j *)
       val restore = Thm.rotate_rule (~ j) (Thm.nprems_of rule + i);
     in Seq.map restore (Tactic.dtac preelimrule i th2) end;
   346 
   347 
   (* Prepare substitution within premise j of subgoal i of gth, using a
      meta-equation whose variable indexes are assumed to be zeroed.  gth's
      indexes are raised by 1 first.
      NOTE: this work is repeated for every assumption; it could be shared
      when searching several assumptions.
      Returns ((cfvs, j, number of premises of the assumption, trivial theorem
      for the assumption), (theory, maxidx, focus zipper)). *)
   fun prep_subst_in_asm ctxt i gth j =
     let
       val th = Thm.incr_indexes 1 gth;
       val thy = Thm.theory_of_thm th;
       val cert = Thm.cterm_of thy;

       val (fixedbody, fvs) = IsaND.fix_alls_term ctxt i (Thm.prop_of th);
       val cfvs = rev (map cert fvs);

       val asmt = nth (Logic.strip_imp_prems fixedbody) (j - 1);
       val asm_nprems = length (Logic.strip_imp_prems asmt);

       val pth = Thm.trivial (cert asmt);
       val maxidx = Thm.maxidx_of th;

       (* one step down-right from the top of pth's proposition
          (original note: "trueprop") *)
       val ft = Zipper.move_down_right (Zipper.mktop (Thm.prop_of pth));
     in ((cfvs, j, asm_nprems, pth), (thy, maxidx, ft)) end;
   377 
   (* Prepare substitution in every premise of subgoal i of gth, in order. *)
   fun prep_subst_in_asms ctxt i gth =
     let val n = length (Logic.prems_of_goal (Thm.prop_of gth) i) in
       map (prep_subst_in_asm ctxt i gth) (Library.upto (1, n))
     end;
   383 
   384 
   (* Substitute in an assumption of subgoal i of th using an object- or
      meta-level equation instepthm.  searchf receives the search info, the
      number of occurrences still to skip, and the left-hand side.
      Assumptions are searched in order, and the remaining skip count carries
      over from one assumption to the next.  Once a match is found, the later
      assumptions are searched as well, starting from their first
      occurrence. *)
   fun eqsubst_asm_tac' ctxt searchf skipocc instepthm i th =
     let
       val asmpreps = prep_subst_in_asms ctxt i th;
       fun rewrite_with_thm r =
         let
           val (lhs, _) = Logic.dest_equals (Thm.concl_of r);
           fun occ_search _ [] = Seq.empty
             | occ_search skip ((asminfo, searchinfo) :: rest) =
                 (case searchf searchinfo skip lhs of
                   SkipMore left => occ_search left rest
                 | SkipSeq groups =>
                     Seq.append
                       (Seq.map (fn m => (asminfo, m)) (Seq.flat groups))
                       (occ_search 1 rest));
           val found = occ_search skipocc asmpreps;
         in Seq.maps (apply_subst_in_asm ctxt i th r) found end;
       val rules = Seq.of_list (prep_meta_eq ctxt instepthm);
     in Seq.maps rewrite_with_thm rules end;
   404 
   405 
   (* Run searchf and skip to its occ-th non-empty match group.  Any
      shortfall is reported as SkipMore (see skipto_skipseq). *)
   fun skip_first_asm_occs_search searchf sinfo occ lhs =
     let val groups = searchf sinfo lhs in skipto_skipseq occ groups end;
   408 
   (* Substitute in the assumptions of subgoal i using the equations thms.
      occL lists which occurrences to rewrite, counted across the assumptions
      in order; [0] finds all rewrites.
      Occurrences are processed from the largest index to the smallest.  Each
      rewrite applies to the previous result, with the subgoal number
      adjusted for any premises added in between.  Duplicate subgoals are
      removed from each result.
      Like other tactics, this fails (empty sequence) when i is not a valid
      subgoal number of th, including i < 1. *)
   fun eqsubst_asm_tac ctxt occL thms i th =
     let val nprems = Thm.nprems_of th in
       if i < 1 orelse nprems < i then Seq.empty
       else
         let
           val thmseq = Seq.of_list thms;
           fun apply_occ occK th =
             thmseq |> Seq.maps (fn r =>
               eqsubst_asm_tac' ctxt
                 (skip_first_asm_occs_search searchf_lr_unify_valid) occK r
                 (i + (Thm.nprems_of th - nprems)) th);
           (* largest occurrence first -- presumably so that earlier rewrites
              do not renumber the occurrences still to be rewritten *)
           val sortedoccs = Library.sort (rev_order o int_ord) occL;
         in
           Seq.map distinct_subgoals (Seq.EVERY (map apply_occ sortedoccs) th)
         end
     end;
   425 
   (* Proof method for substitution in an assumption.  The Isar arguments
      inthms are used as equation steps: the first, then the second, etc. *)
   fun eqsubst_asm_meth ctxt occL inthms =
     SIMPLE_METHOD' (fn i => eqsubst_asm_tac ctxt occL inthms i);
   430 
   (* Register the "subst" proof method.  Its arguments are:
        - the "asm" mode flag (Args.mode): when given, substitute in an
          assumption instead of the conclusion;
        - an optional parenthesised list of occurrence numbers, default [0];
        - the theorems to use. *)
   val setup =
     Method.setup @{binding subst}
       (Args.mode "asm"
         -- Scan.lift (Scan.optional (Args.parens (Scan.repeat Parse.nat)) [0])
         -- Attrib.thms
         >> (fn ((asm, occL), inthms) => fn ctxt =>
               if asm then eqsubst_asm_meth ctxt occL inthms
               else eqsubst_meth ctxt occL inthms))
       "single-step substitution";
   441 
   442 end;