src/Tools/eqsubst.ML
author wenzelm
Wed May 29 18:25:11 2013 +0200 (2013-05-29)
changeset 52223 5bb6ae8acb87
parent 51717 9e7d1c139569
child 52234 6ffcce211047
permissions -rw-r--r--
tuned signature -- more explicit flags for low-level Thm.bicompose;
(*  Title:      Tools/eqsubst.ML
    Author:     Lucas Dixon, University of Edinburgh

A proof method to perform a substitution using an equation.
*)
     6 
     7 signature EQSUBST =
     8 sig
     9   (* a type abbreviation for match information *)
    10   type match =
    11        ((indexname * (sort * typ)) list (* type instantiations *)
    12         * (indexname * (typ * term)) list) (* term instantiations *)
    13        * (string * typ) list (* fake named type abs env *)
    14        * (string * typ) list (* type abs env *)
    15        * term (* outer term *)
    16 
    17   type searchinfo =
    18        theory
    19        * int (* maxidx *)
    20        * Zipper.T (* focusterm to search under *)
    21 
    22   datatype 'a skipseq = SkipMore of int | SkipSeq of 'a Seq.seq Seq.seq
    23 
    24   val skip_first_asm_occs_search :
    25      ('a -> 'b -> 'c Seq.seq Seq.seq) ->
    26      'a -> int -> 'b -> 'c skipseq
    27   val skip_first_occs_search :
    28      int -> ('a -> 'b -> 'c Seq.seq Seq.seq) -> 'a -> 'b -> 'c Seq.seq
    29   val skipto_skipseq : int -> 'a Seq.seq Seq.seq -> 'a skipseq
    30 
    31   (* tactics *)
    32   val eqsubst_asm_tac :
    33      Proof.context ->
    34      int list -> thm list -> int -> tactic
    35   val eqsubst_asm_tac' :
    36      Proof.context ->
    37      (searchinfo -> int -> term -> match skipseq) ->
    38      int -> thm -> int -> tactic
    39   val eqsubst_tac :
    40      Proof.context ->
    41      int list -> (* list of occurences to rewrite, use [0] for any *)
    42      thm list -> int -> tactic
    43   val eqsubst_tac' :
    44      Proof.context -> (* proof context *)
    45      (searchinfo -> term -> match Seq.seq) (* search function *)
    46      -> thm (* equation theorem to rewrite with *)
    47      -> int (* subgoal number in goal theorem *)
    48      -> thm (* goal theorem *)
    49      -> thm Seq.seq (* rewritten goal theorem *)
    50 
    51   (* search for substitutions *)
    52   val valid_match_start : Zipper.T -> bool
    53   val search_lr_all : Zipper.T -> Zipper.T Seq.seq
    54   val search_lr_valid : (Zipper.T -> bool) -> Zipper.T -> Zipper.T Seq.seq
    55   val searchf_lr_unify_all :
    56      searchinfo -> term -> match Seq.seq Seq.seq
    57   val searchf_lr_unify_valid :
    58      searchinfo -> term -> match Seq.seq Seq.seq
    59   val searchf_bt_unify_valid :
    60      searchinfo -> term -> match Seq.seq Seq.seq
    61 
    62   val setup : theory -> theory
    63 
    64 end;
    65 
    66 structure EqSubst: EQSUBST =
    67 struct
    68 
    69 (* changes object "=" to meta "==" which prepares a given rewrite rule *)
    70 fun prep_meta_eq ctxt =
    71   Simplifier.mksimps ctxt #> map Drule.zero_var_indexes;
    72 
    73 
    74   (* a type abriviation for match information *)
    75   type match =
    76        ((indexname * (sort * typ)) list (* type instantiations *)
    77         * (indexname * (typ * term)) list) (* term instantiations *)
    78        * (string * typ) list (* fake named type abs env *)
    79        * (string * typ) list (* type abs env *)
    80        * term (* outer term *)
    81 
    82   type searchinfo =
    83        theory
    84        * int (* maxidx *)
    85        * Zipper.T (* focusterm to search under *)
    86 
    87 
    88 (* skipping non-empty sub-sequences but when we reach the end
    89    of the seq, remembering how much we have left to skip. *)
    90 datatype 'a skipseq = SkipMore of int
    91   | SkipSeq of 'a Seq.seq Seq.seq;
    92 (* given a seqseq, skip the first m non-empty seq's, note deficit *)
    93 fun skipto_skipseq m s =
    94     let
    95       fun skip_occs n sq =
    96           case Seq.pull sq of
    97             NONE => SkipMore n
    98           | SOME (h,t) =>
    99             (case Seq.pull h of NONE => skip_occs n t
   100              | SOME _ => if n <= 1 then SkipSeq (Seq.cons h t)
   101                          else skip_occs (n - 1) t)
   102     in (skip_occs m s) end;
   103 
   104 (* note: outerterm is the taget with the match replaced by a bound
   105          variable : ie: "P lhs" beocmes "%x. P x"
   106          insts is the types of instantiations of vars in lhs
   107          and typinsts is the type instantiations of types in the lhs
   108          Note: Final rule is the rule lifted into the ontext of the
   109          taget thm. *)
   110 fun mk_foo_match mkuptermfunc Ts t =
   111     let
   112       val ty = Term.type_of t
   113       val bigtype = (rev (map snd Ts)) ---> ty
   114       fun mk_foo 0 t = t
   115         | mk_foo i t = mk_foo (i - 1) (t $ (Bound (i - 1)))
   116       val num_of_bnds = (length Ts)
   117       (* foo_term = "fooabs y0 ... yn" where y's are local bounds *)
   118       val foo_term = mk_foo num_of_bnds (Bound num_of_bnds)
   119     in Abs("fooabs", bigtype, mkuptermfunc foo_term) end;
   120 
   121 (* T is outer bound vars, n is number of locally bound vars *)
   122 (* THINK: is order of Ts correct...? or reversed? *)
   123 fun mk_fake_bound_name n = ":b_" ^ n;
   124 fun fakefree_badbounds Ts t =
   125     let val (FakeTs,Ts,newnames) =
   126             List.foldr (fn ((n,ty),(FakeTs,Ts,usednames)) =>
   127                            let val newname = singleton (Name.variant_list usednames) n
   128                            in ((mk_fake_bound_name newname,ty)::FakeTs,
   129                                (newname,ty)::Ts,
   130                                newname::usednames) end)
   131                        ([],[],[])
   132                        Ts
   133     in (FakeTs, Ts, Term.subst_bounds (map Free FakeTs, t)) end;
   134 
   135 (* before matching we need to fake the bound vars that are missing an
   136 abstraction. In this function we additionally construct the
   137 abstraction environment, and an outer context term (with the focus
   138 abstracted out) for use in rewriting with RWInst.rw *)
   139 fun prep_zipper_match z =
   140     let
   141       val t = Zipper.trm z
   142       val c = Zipper.ctxt z
   143       val Ts = Zipper.C.nty_ctxt c
   144       val (FakeTs', Ts', t') = fakefree_badbounds Ts t
   145       val absterm = mk_foo_match (Zipper.C.apply c) Ts' t'
   146     in
   147       (t', (FakeTs', Ts', absterm))
   148     end;
   149 
   150 (* Unification with exception handled *)
   151 (* given theory, max var index, pat, tgt; returns Seq of instantiations *)
   152 fun clean_unify thry ix (a as (pat, tgt)) =
   153     let
   154       (* type info will be re-derived, maybe this can be cached
   155          for efficiency? *)
   156       val pat_ty = Term.type_of pat;
   157       val tgt_ty = Term.type_of tgt;
   158       (* is it OK to ignore the type instantiation info?
   159          or should I be using it? *)
   160       val typs_unify =
   161           SOME (Sign.typ_unify thry (pat_ty, tgt_ty) (Vartab.empty, ix))
   162             handle Type.TUNIFY => NONE;
   163     in
   164       case typs_unify of
   165         SOME (typinsttab, ix2) =>
   166         let
   167       (* is it right to throw away the flexes?
   168          or should I be using them somehow? *)
   169           fun mk_insts env =
   170             (Vartab.dest (Envir.type_env env),
   171              Vartab.dest (Envir.term_env env));
   172           val initenv =
   173             Envir.Envir {maxidx = ix2, tenv = Vartab.empty, tyenv = typinsttab};
   174           val useq = Unify.smash_unifiers thry [a] initenv
   175               handle ListPair.UnequalLengths => Seq.empty
   176                    | Term.TERM _ => Seq.empty;
   177           fun clean_unify' useq () =
   178               (case (Seq.pull useq) of
   179                  NONE => NONE
   180                | SOME (h,t) => SOME (mk_insts h, Seq.make (clean_unify' t)))
   181               handle ListPair.UnequalLengths => NONE
   182                    | Term.TERM _ => NONE
   183         in
   184           (Seq.make (clean_unify' useq))
   185         end
   186       | NONE => Seq.empty
   187     end;
   188 
   189 (* Unification for zippers *)
   190 (* Note: Ts is a modified version of the original names of the outer
   191 bound variables. New names have been introduced to make sure they are
   192 unique w.r.t all names in the term and each other. usednames' is
   193 oldnames + new names. *)
   194 (* ix = max var index *)
   195 fun clean_unify_z sgn ix pat z =
   196     let val (t, (FakeTs, Ts,absterm)) = prep_zipper_match z in
   197     Seq.map (fn insts => (insts, FakeTs, Ts, absterm))
   198             (clean_unify sgn ix (t, pat)) end;
   199 
   200 
   201 fun bot_left_leaf_of (l $ r) = bot_left_leaf_of l
   202   | bot_left_leaf_of (Abs(s,ty,t)) = bot_left_leaf_of t
   203   | bot_left_leaf_of x = x;
   204 
   205 (* Avoid considering replacing terms which have a var at the head as
   206    they always succeed trivially, and uninterestingly. *)
   207 fun valid_match_start z =
   208     (case bot_left_leaf_of (Zipper.trm z) of
   209       Var _ => false
   210       | _ => true);
   211 
   212 (* search from top, left to right, then down *)
   213 val search_lr_all = ZipperSearch.all_bl_ur;
   214 
   215 (* search from top, left to right, then down *)
   216 fun search_lr_valid validf =
   217     let
   218       fun sf_valid_td_lr z =
   219           let val here = if validf z then [Zipper.Here z] else [] in
   220             case Zipper.trm z
   221              of _ $ _ => [Zipper.LookIn (Zipper.move_down_left z)]
   222                          @ here
   223                          @ [Zipper.LookIn (Zipper.move_down_right z)]
   224               | Abs _ => here @ [Zipper.LookIn (Zipper.move_down_abs z)]
   225               | _ => here
   226           end;
   227     in Zipper.lzy_search sf_valid_td_lr end;
   228 
   229 (* search from bottom to top, left to right *)
   230 
   231 fun search_bt_valid validf =
   232     let
   233       fun sf_valid_td_lr z =
   234           let val here = if validf z then [Zipper.Here z] else [] in
   235             case Zipper.trm z
   236              of _ $ _ => [Zipper.LookIn (Zipper.move_down_left z),
   237                           Zipper.LookIn (Zipper.move_down_right z)] @ here
   238               | Abs _ => [Zipper.LookIn (Zipper.move_down_abs z)] @ here
   239               | _ => here
   240           end;
   241     in Zipper.lzy_search sf_valid_td_lr end;
   242 
   243 fun searchf_unify_gen f (sgn, maxidx, z) lhs =
   244     Seq.map (clean_unify_z sgn maxidx lhs)
   245             (Zipper.limit_apply f z);
   246 
   247 (* search all unifications *)
   248 val searchf_lr_unify_all =
   249     searchf_unify_gen search_lr_all;
   250 
   251 (* search only for 'valid' unifiers (non abs subterms and non vars) *)
   252 val searchf_lr_unify_valid =
   253     searchf_unify_gen (search_lr_valid valid_match_start);
   254 
   255 val searchf_bt_unify_valid =
   256     searchf_unify_gen (search_bt_valid valid_match_start);
   257 
   258 (* apply a substitution in the conclusion of the theorem th *)
   259 (* cfvs are certified free var placeholders for goal params *)
   260 (* conclthm is a theorem of for just the conclusion *)
   261 (* m is instantiation/match information *)
   262 (* rule is the equation for substitution *)
   263 fun apply_subst_in_concl ctxt i th (cfvs, conclthm) rule m =
   264     (RWInst.rw ctxt m rule conclthm)
   265       |> IsaND.unfix_frees cfvs
   266       |> RWInst.beta_eta_contract
   267       |> (fn r => Tactic.rtac r i th);
   268 
   269 (* substitute within the conclusion of goal i of gth, using a meta
   270 equation rule. Note that we assume rule has var indicies zero'd *)
   271 fun prep_concl_subst ctxt i gth =
   272     let
   273       val th = Thm.incr_indexes 1 gth;
   274       val tgt_term = Thm.prop_of th;
   275 
   276       val sgn = Thm.theory_of_thm th;
   277       val ctermify = Thm.cterm_of sgn;
   278       val trivify = Thm.trivial o ctermify;
   279 
   280       val (fixedbody, fvs) = IsaND.fix_alls_term ctxt i tgt_term;
   281       val cfvs = rev (map ctermify fvs);
   282 
   283       val conclterm = Logic.strip_imp_concl fixedbody;
   284       val conclthm = trivify conclterm;
   285       val maxidx = Thm.maxidx_of th;
   286       val ft = ((Zipper.move_down_right (* ==> *)
   287                  o Zipper.move_down_left (* Trueprop *)
   288                  o Zipper.mktop
   289                  o Thm.prop_of) conclthm)
   290     in
   291       ((cfvs, conclthm), (sgn, maxidx, ft))
   292     end;
   293 
   294 (* substitute using an object or meta level equality *)
   295 fun eqsubst_tac' ctxt searchf instepthm i th =
   296     let
   297       val (cvfsconclthm, searchinfo) = prep_concl_subst ctxt i th;
   298       val stepthms = Seq.of_list (prep_meta_eq ctxt instepthm);
   299       fun rewrite_with_thm r =
   300           let val (lhs,_) = Logic.dest_equals (Thm.concl_of r);
   301           in searchf searchinfo lhs
   302              |> Seq.maps (apply_subst_in_concl ctxt i th cvfsconclthm r) end;
   303     in stepthms |> Seq.maps rewrite_with_thm end;
   304 
   305 
   306 (* distinct subgoals *)
   307 fun distinct_subgoals th =
   308   the_default th (SINGLE distinct_subgoals_tac th);
   309 
   310 (* General substitution of multiple occurances using one of
   311    the given theorems*)
   312 
   313 
   314 fun skip_first_occs_search occ srchf sinfo lhs =
   315     case (skipto_skipseq occ (srchf sinfo lhs)) of
   316       SkipMore _ => Seq.empty
   317     | SkipSeq ss => Seq.flat ss;
   318 
   319 (* The occL is a list of integers indicating which occurence
   320 w.r.t. the search order, to rewrite. Backtracking will also find later
   321 occurences, but all earlier ones are skipped. Thus you can use [0] to
   322 just find all rewrites. *)
   323 
   324 fun eqsubst_tac ctxt occL thms i th =
   325     let val nprems = Thm.nprems_of th in
   326       if nprems < i then Seq.empty else
   327       let val thmseq = (Seq.of_list thms)
   328         fun apply_occ occ th =
   329             thmseq |> Seq.maps
   330                     (fn r => eqsubst_tac'
   331                                ctxt
   332                                (skip_first_occs_search
   333                                   occ searchf_lr_unify_valid) r
   334                                  (i + ((Thm.nprems_of th) - nprems))
   335                                  th);
   336         val sortedoccL =
   337             Library.sort (Library.rev_order o Library.int_ord) occL;
   338       in
   339         Seq.map distinct_subgoals (Seq.EVERY (map apply_occ sortedoccL) th)
   340       end
   341     end;
   342 
   343 
   344 (* inthms are the given arguments in Isar, and treated as eqstep with
   345    the first one, then the second etc *)
   346 fun eqsubst_meth ctxt occL inthms =
   347     SIMPLE_METHOD' (eqsubst_tac ctxt occL inthms);
   348 
   349 (* apply a substitution inside assumption j, keeps asm in the same place *)
   350 fun apply_subst_in_asm ctxt i th rule ((cfvs, j, ngoalprems, pth),m) =
   351     let
   352       val th2 = Thm.rotate_rule (j - 1) i th; (* put premice first *)
   353       val preelimrule =
   354           (RWInst.rw ctxt m rule pth)
   355             |> (Seq.hd o prune_params_tac)
   356             |> Thm.permute_prems 0 ~1 (* put old asm first *)
   357             |> IsaND.unfix_frees cfvs (* unfix any global params *)
   358             |> RWInst.beta_eta_contract; (* normal form *)
   359   (*    val elimrule =
   360           preelimrule
   361             |> Tactic.make_elim (* make into elim rule *)
   362             |> Thm.lift_rule (th2, i); (* lift into context *)
   363    *)
   364     in
   365       (* ~j because new asm starts at back, thus we subtract 1 *)
   366       Seq.map (Thm.rotate_rule (~j) ((Thm.nprems_of rule) + i))
   367       (Tactic.dtac preelimrule i th2)
   368     end;
   369 
   370 
   371 (* prepare to substitute within the j'th premise of subgoal i of gth,
   372 using a meta-level equation. Note that we assume rule has var indicies
   373 zero'd. Note that we also assume that premt is the j'th premice of
   374 subgoal i of gth. Note the repetition of work done for each
   375 assumption, i.e. this can be made more efficient for search over
   376 multiple assumptions.  *)
   377 fun prep_subst_in_asm ctxt i gth j =
   378     let
   379       val th = Thm.incr_indexes 1 gth;
   380       val tgt_term = Thm.prop_of th;
   381 
   382       val sgn = Thm.theory_of_thm th;
   383       val ctermify = Thm.cterm_of sgn;
   384       val trivify = Thm.trivial o ctermify;
   385 
   386       val (fixedbody, fvs) = IsaND.fix_alls_term ctxt i tgt_term;
   387       val cfvs = rev (map ctermify fvs);
   388 
   389       val asmt = nth (Logic.strip_imp_prems fixedbody) (j - 1);
   390       val asm_nprems = length (Logic.strip_imp_prems asmt);
   391 
   392       val pth = trivify asmt;
   393       val maxidx = Thm.maxidx_of th;
   394 
   395       val ft = ((Zipper.move_down_right (* trueprop *)
   396                  o Zipper.mktop
   397                  o Thm.prop_of) pth)
   398     in ((cfvs, j, asm_nprems, pth), (sgn, maxidx, ft)) end;
   399 
   400 (* prepare subst in every possible assumption *)
   401 fun prep_subst_in_asms ctxt i gth =
   402     map (prep_subst_in_asm ctxt i gth)
   403         ((fn l => Library.upto (1, length l))
   404            (Logic.prems_of_goal (Thm.prop_of gth) i));
   405 
   406 
   407 (* substitute in an assumption using an object or meta level equality *)
   408 fun eqsubst_asm_tac' ctxt searchf skipocc instepthm i th =
   409     let
   410       val asmpreps = prep_subst_in_asms ctxt i th;
   411       val stepthms = Seq.of_list (prep_meta_eq ctxt instepthm);
   412       fun rewrite_with_thm r =
   413           let val (lhs,_) = Logic.dest_equals (Thm.concl_of r)
   414             fun occ_search occ [] = Seq.empty
   415               | occ_search occ ((asminfo, searchinfo)::moreasms) =
   416                 (case searchf searchinfo occ lhs of
   417                    SkipMore i => occ_search i moreasms
   418                  | SkipSeq ss =>
   419                    Seq.append (Seq.map (Library.pair asminfo) (Seq.flat ss))
   420                                (occ_search 1 moreasms))
   421                               (* find later substs also *)
   422           in
   423             occ_search skipocc asmpreps |> Seq.maps (apply_subst_in_asm ctxt i th r)
   424           end;
   425     in stepthms |> Seq.maps rewrite_with_thm end;
   426 
   427 
   428 fun skip_first_asm_occs_search searchf sinfo occ lhs =
   429     skipto_skipseq occ (searchf sinfo lhs);
   430 
   431 fun eqsubst_asm_tac ctxt occL thms i th =
   432     let val nprems = Thm.nprems_of th
   433     in
   434       if nprems < i then Seq.empty else
   435       let val thmseq = (Seq.of_list thms)
   436         fun apply_occ occK th =
   437             thmseq |> Seq.maps
   438                     (fn r =>
   439                         eqsubst_asm_tac' ctxt (skip_first_asm_occs_search
   440                                             searchf_lr_unify_valid) occK r
   441                                          (i + ((Thm.nprems_of th) - nprems))
   442                                          th);
   443         val sortedoccs =
   444             Library.sort (Library.rev_order o Library.int_ord) occL
   445       in
   446         Seq.map distinct_subgoals
   447                 (Seq.EVERY (map apply_occ sortedoccs) th)
   448       end
   449     end;
   450 
   451 (* inthms are the given arguments in Isar, and treated as eqstep with
   452    the first one, then the second etc *)
   453 fun eqsubst_asm_meth ctxt occL inthms =
   454     SIMPLE_METHOD' (eqsubst_asm_tac ctxt occL inthms);
   455 
   456 (* combination method that takes a flag (true indicates that subst
   457    should be done to an assumption, false = apply to the conclusion of
   458    the goal) as well as the theorems to use *)
   459 val setup =
   460   Method.setup @{binding subst}
   461     (Args.mode "asm" -- Scan.lift (Scan.optional (Args.parens (Scan.repeat Parse.nat)) [0]) --
   462         Attrib.thms >>
   463       (fn ((asm, occL), inthms) => fn ctxt =>
   464         (if asm then eqsubst_asm_meth else eqsubst_meth) ctxt occL inthms))
   465     "single-step substitution";
   466 
   467 end;