src/HOL/Library/rewrite.ML
author noschinl
Mon Apr 13 15:32:32 2015 +0200 (2015-04-13)
changeset 60053 0e9895ffab1d
parent 60052 616a17640229
child 60054 ef4878146485
permissions -rw-r--r--
rewrite: do not descend into conclusion of premise with asm pattern
     1 (*  Title:      HOL/Library/rewrite.ML
     2     Author:     Christoph Traut, Lars Noschinski, TU Muenchen
     3 
     4 This is a rewrite method that supports subterm-selection based on patterns.
     5 
     6 The patterns accepted by rewrite are of the following form:
     7   <atom>    ::= <term> | "concl" | "asm" | "for" "(" <names> ")"
     8   <pattern> ::= (in <atom> | at <atom>) [<pattern>]
     9   <args>    ::= [<pattern>] ["to" <term>] <thms>
    10 
    11 This syntax was clearly inspired by Gonthier's and Tassi's language of
    12 patterns but has diverged significantly during its development.
    13 
    14 We also allow introduction of identifiers for bound variables,
    15 which can then be used to match arbitrary subterms inside abstractions.
    16 *)
    17 
    (* Interface of the rewrite module.  The signature is empty, so nothing
       defined in the structure below is accessible by name from outside. *)
    18 signature REWRITE =
    19 sig
    20   (* FIXME proper ML interface!? *)
    21 end
    22 
    23 structure Rewrite : REWRITE =
    24 struct
    25 
    (* Elements of a rewrite pattern.  'a is the payload of a term pattern,
       'b the representation of a variable introduced by "for".
         At, In  -- the separators "at" and "in"
         Term x  -- a term pattern
         Concl   -- the atom "concl"
         Asm     -- the atom "asm"
         For xs  -- the atom "for (<names>)" *)
    26 datatype ('a, 'b) pattern = At | In | Term of 'a | Concl | Asm | For of 'b list
    27 
    (* Apply f to the payload of a Term pattern; f has to return a complete
       pattern itself.  All other constructors are rebuilt unchanged, so the
       type of the term payload may change. *)
    28 fun map_term_pattern f (Term x) = f x
    29   | map_term_pattern _ (For ss) = (For ss)
    30   | map_term_pattern _ At = At
    31   | map_term_pattern _ In = In
    32   | map_term_pattern _ Concl = Concl
    33   | map_term_pattern _ Asm = Asm
    34 
    35 
    (* Raised by unify_with_rhs when the "to" term cannot be unified with the
       right-hand side of a rule; inst_thm catches it and returns NONE. *)
    36 exception NO_TO_MATCH
    37 
    (* Combine a sequence of tactics into one tactic: every tactic is applied to
       the same goal state st and the result sequences are concatenated in the
       order of the tactics. *)
    38 fun SEQ_CONCAT (tacq : tactic Seq.seq) : tactic = fn st => Seq.maps (fn tac => tac st) tacq
    39 
    40 (* We rewrite subterms using rewrite conversions. These are conversions
    41    that also take a context and a list of identifiers for bound variables
    42    as parameters; each identifier is paired with the term it stands for. *)
    43 type rewrite_conv = Proof.context -> (string * term) list -> conv
    44 
    45 (* To apply such a rewrite conversion to a subterm of our goal, we use
    46    subterm positions, which are just functions that map a rewrite conversion,
    47    working on the top level, to a new rewrite conversion, working on
    48    a specific subterm.
    49 
    50    During substitution, we are traversing the goal to find subterms that
    51    we can rewrite. For each of these subterms, a subterm position is
    52    created and later used in creating a conversion that we use to try and
    53    rewrite this subterm. Positions are composed with "o"; "I" is the top level. *)
    54 type subterm_position = rewrite_conv -> rewrite_conv
    55 
    56 (* A focusterm represents a subterm. It is a triple (tyenv, t, p), consisting
    57   of a type environment tyenv, the subterm t itself and its subterm position p. *)
    58 type focusterm = Type.tyenv * term * subterm_position
    59 
    (* Internal names.  dummyN is the basis of the default name that ft_abs gives
       to the free variable replacing an unnamed bound variable; holeN is the
       name of the schematic variables that hole_syntax substitutes for
       rewrite_HOLE constants (and that is_hole recognizes). *)
    60 val dummyN = Name.internal "__dummy"
    61 val holeN = Name.internal "_hole"
    62 
    (* Turn a theorem into a list of rewrite rules: Simplifier.mksimps produces
       the rules, and Drule.zero_var_indexes is then applied to each of them. *)
    63 fun prep_meta_eq ctxt =
    64   Simplifier.mksimps ctxt #> map Drule.zero_var_indexes
    65 
    66 
    67 (* rewrite conversions *)
    68 
    (* Subterm position for the body of an abstraction.  The given rewrite
       conversion is run below the binder through CConv.abs_cconv.  If ident is
       SOME name, the pair (name, term of the cterm ct handed over by
       CConv.abs_cconv) is put in front of the identifier list before rewr is
       called; with NONE the list is passed on unchanged.
       NOTE(review): ct is presumably the variable that stands for the bound
       variable inside the body -- confirm against CConv.abs_cconv. *)
    69 fun abs_rewr_cconv ident : subterm_position =
    70   let
    71     fun add_ident NONE _ l = l
    72       | add_ident (SOME name) ct l = (name, Thm.term_of ct) :: l
    73     fun inner rewr ctxt idents =
    74       CConv.abs_cconv (fn (ct, ctxt) => rewr ctxt (add_ident ident ct idents)) ctxt
    75   in inner end
    76 
    (* Subterm positions built from the CConv combinators: the rewrite conversion
       is first applied to the context and the identifier list, and the
       resulting conversion is wrapped by
         fun_rewr_cconv -- CConv.fun_cconv      (used by ft_fun)
         arg_rewr_cconv -- CConv.arg_cconv      (used by ft_arg)
         imp_rewr_cconv -- CConv.concl_cconv 1  (used by ft_imp) *)
    77 val fun_rewr_cconv : subterm_position = fn rewr => CConv.fun_cconv oo rewr
    78 val arg_rewr_cconv : subterm_position = fn rewr => CConv.arg_cconv oo rewr
    79 val imp_rewr_cconv : subterm_position = fn rewr => CConv.concl_cconv 1 oo rewr
    80 
    81 
    82 (* focus terms *)
    83 
    (* Move the focus below one binder of the function-typed term u.
       (s, T) is an optional name and a type for the variable: a free variable x
       (named s, or a default internal name if s is NONE) takes the place of the
       bound variable.  If u is an abstraction, the new focus is its body with x
       substituted for the bound variable; otherwise the new focus is u $ x and
       the position first rewrites with thm eta_expand before going below the
       binder.  Unless T is dummyT, T is matched against the argument type U of
       u, which extends the type environment.
       Raises TERM if u has no function type, and TYPE if T does not match U. *)
    84 fun ft_abs ctxt (s,T) (tyenv, u, pos) =
    85   case try (fastype_of #> dest_funT) u of
    86     NONE => raise TERM ("ft_abs: no function type", [u])
    87   | SOME (U, _) =>
    88       let
    89         val tyenv' =
    90           if T = dummyT then tyenv
    91           else Sign.typ_match (Proof_Context.theory_of ctxt) (T, U) tyenv
    92         val x = Free (the_default (Name.internal dummyN) s, Envir.norm_type tyenv' T)
    93         val eta_expand_cconv = CConv.rewr_cconv @{thm eta_expand}
    94         fun eta_expand rewr ctxt bounds = eta_expand_cconv then_conv rewr ctxt bounds
    95         val (u', pos') =
    96           case u of
    97             Abs (_,_,t') => (subst_bound (x, t'), pos o abs_rewr_cconv s)
    98           | _ => (u $ x, pos o eta_expand o abs_rewr_cconv s)
    99       in (tyenv', u', pos') end
             (* Sign.typ_match signals a failed type match with Type.TYPE_MATCH
                (Pattern.MATCH is never raised here), so that is the exception
                which has to be translated into the TYPE error. *)
   100       handle Type.TYPE_MATCH => raise TYPE ("ft_abs: types don't match", [T,U], [u])
   101 
   (* Move the focus to the function part l of an application l $ r.  An
      abstraction whose body is an application to the bound variable
      (Abs (_, T, _ $ Bound 0)) is first opened with ft_abs and then treated
      like an application.  Any other term raises TERM "ft_fun". *)
   102 fun ft_fun _ (tyenv, l $ _, pos) = (tyenv, l, pos o fun_rewr_cconv)
   103   | ft_fun ctxt (ft as (_, Abs (_, T, _ $ Bound 0), _)) = (ft_fun ctxt o ft_abs ctxt (NONE, T)) ft
   104   | ft_fun _ (_, t, _) = raise TERM ("ft_fun", [t])
   105 
   (* Moving the focus to the argument part r of an application l $ r.
      ft_arg_gen is parameterized by the subterm position to compose with; as in
      ft_fun, an abstraction of the form Abs (_, T, _ $ Bound 0) is opened with
      ft_abs first, and any other term raises TERM "ft_arg".
        ft_arg -- uses arg_rewr_cconv
        ft_imp -- uses imp_rewr_cconv (the conclusion B of A ==> B is the
                  argument of the application (Pure.imp $ A) $ B) *)
   106 local
   107 
   108 fun ft_arg_gen cconv _ (tyenv, _ $ r, pos) = (tyenv, r, pos o cconv)
   109   | ft_arg_gen cconv ctxt (ft as (_, Abs (_, T, _ $ Bound 0), _)) = (ft_arg_gen cconv ctxt o ft_abs ctxt (NONE, T)) ft
   110   | ft_arg_gen _ _ (_, t, _) = raise TERM ("ft_arg", [t])
   111 
   112 in
   113 
   114 val ft_arg = ft_arg_gen arg_rewr_cconv
   115 val ft_imp = ft_arg_gen imp_rewr_cconv
   116 
   117 end
   118 
   119 (* Move to B in !!x_1 ... x_n. B. Do not eta-expand *)
       (* Each quantifier Pure.all $ Abs is passed by ft_arg followed by ft_abs
          with an unnamed variable; a term of any other shape is returned as is.
          NOTE(review): the second case matches a bare Pure.all constant, for
          which ft_arg raises TERM "ft_arg" (the term is neither an application
          nor an abstraction); possibly `Const (...) $ _` was intended --
          confirm before changing. *)
   120 fun ft_params ctxt (ft as (_, t, _) : focusterm) =
   121   case t of
   122     Const (@{const_name "Pure.all"}, _) $ Abs (_,T,_) =>
   123       (ft_params ctxt o ft_abs ctxt (NONE, T) o ft_arg ctxt) ft
   124   | Const (@{const_name "Pure.all"}, _) =>
   125       (ft_params ctxt o ft_arg ctxt) ft
   126   | _ => ft
   127 
   (* Move the focus below one Pure.all quantifier.  ident is a pair of an
      optional name and an optional type for the quantified variable; a missing
      type defaults to the argument type read off the type T of the Pure.all
      constant (T has the form (U => _) => _).  The quantifier is passed with
      ft_arg followed by ft_abs.  Raises TERM "ft_all" if the focus is not an
      application of Pure.all. *)
   128 fun ft_all ctxt ident (ft as (_, Const (@{const_name "Pure.all"}, T) $ _, _) : focusterm) =
   129     let
   130       val def_U = T |> dest_funT |> fst |> dest_funT |> fst
   131       val ident' = apsnd (the_default (def_U)) ident
   132     in (ft_abs ctxt ident' o ft_arg ctxt) ft end
   133   | ft_all _ _ (_, t, _) = raise TERM ("ft_all", [t])
   134 
   (* Move the focus below all leading Pure.all quantifiers of the focus term
      and attach the given identifiers to them.  The identifiers are assigned
      from the inside out: the last identifier goes to the innermost quantifier,
      and outer quantifiers left over are passed with (NONE, NONE).
      Returns SOME of the new focus term, or NONE if there are more identifiers
      than leading quantifiers. *)
   135 fun ft_for ctxt idents (ft as (_, t, _) : focusterm) =
   136   let
          (* f returns the identifiers not yet used (still reversed) together
             with the function that descends through the quantifiers seen. *)
   137     fun f rev_idents (Const (@{const_name "Pure.all"}, _) $ t) =
   138         let
   139          val (rev_idents', desc) = f rev_idents (case t of Abs (_,_,u) => u | _ => t)
   140         in
   141           case rev_idents' of
   142             [] => ([], desc o ft_all ctxt (NONE, NONE))
   143           | (x :: xs) => (xs , desc o ft_all ctxt x)
   144         end
   145       | f rev_idents _ = (rev_idents, I)
   146   in
   147     case f (rev idents) t of
   148       ([], ft') => SOME (ft' ft)
   149     | _ => NONE
   150   end
   151 
   (* Move the focus to the conclusion: as long as the focus term is a
      meta-implication A ==> B, step to B with ft_imp; a term that is not an
      implication is returned unchanged. *)
   152 fun ft_concl ctxt (ft as (_, t, _) : focusterm) =
   153   case t of
   154     (Const (@{const_name "Pure.imp"}, _) $ _) $ _ => (ft_concl ctxt o ft_imp ctxt) ft
   155   | _ => ft
   156 
   (* Move the focus to the first premise A of a meta-implication A ==> B, i.e.
      to the argument of the function part of (Pure.imp $ A) $ B.  Raises TERM
      "ft_assm" if the focus term is not a meta-implication. *)
   157 fun ft_assm ctxt (ft as (_, t, _) : focusterm) =
   158   case t of
   159     (Const (@{const_name "Pure.imp"}, _) $ _) $ _ => (ft_arg ctxt o ft_fun ctxt) ft
   160   | _ => raise TERM ("ft_assm", [t])
   161 
   (* If the focus term is a judgment of the object-logic (according to
      Object_Logic.is_judgment), move the focus to its argument; otherwise leave
      the focus term unchanged. *)
   162 fun ft_judgment ctxt (ft as (_, t, _) : focusterm) =
   163   if Object_Logic.is_judgment ctxt t
   164   then ft_arg ctxt ft
   165   else ft
   166 
   167 
   168 (* Return a lazy sequence of all subterms of the focusterm for which
   169    the condition holds. A term satisfying the condition comes before its subterms. *)
   170 fun find_subterms ctxt condition (ft as (_, t, _) : focusterm) =
   171   let
   172     val recurse = find_subterms ctxt condition
          (* Applications: results from the function part precede those from the
             argument part.  Abstractions are opened with an unnamed variable. *)
   173     val recursive_matches =
   174       case t of
   175         _ $ _ => Seq.append (ft |> ft_fun ctxt |> recurse) (ft |> ft_arg ctxt |> recurse)
   176       | Abs (_,T,_) => ft |> ft_abs ctxt (NONE, T) |> recurse
   177       | _ => Seq.empty
   178   in
   179     (* If the condition is met, then the current focusterm is part of the
   180        sequence of results. Otherwise, only the results of the recursive
   181        application are. *)
   182     if condition ft
   183     then Seq.cons ft recursive_matches
   184     else recursive_matches
   185   end
   186 
   187 (* Find all subterms that might be a valid point to apply a rule. *)
       (* A term is rejected if, after stripping arguments of applications and
          going into the bodies of abstractions, a schematic variable (Var) or a
          bound variable (Bound) is reached; every other term is accepted. *)
   188 fun valid_match_points ctxt =
   189   let
   190     fun is_valid (l $ _) = is_valid l
   191       | is_valid (Abs (_, _, a)) = is_valid a
   192       | is_valid (Var _) = false
   193       | is_valid (Bound _) = false
   194       | is_valid _ = true
   195   in
   196     find_subterms ctxt (#2 #> is_valid )
   197   end
   198 
   (* Recognize a hole after syntax checking: a schematic variable whose base
      name is holeN (any index, any type). *)
   199 fun is_hole (Var ((name, _), _)) = (name = holeN)
   200   | is_hole _ = false
   201 
   (* Recognize a hole before it has been expanded by hole_syntax: the constant
      rewrite_HOLE (at any type). *)
   202 fun is_hole_const (Const (@{const_name rewrite_HOLE}, _)) = true
   203   | is_hole_const _ = false
   204 
   (* Context transformation used for checking pattern terms.  It installs a
      term check (stage 101, named "hole_expansion") that replaces every
      occurrence of the constant rewrite_HOLE by a schematic variable named
      holeN, applied to all bound variables of the enclosing abstractions, and
      it switches the context to Proof_Context.mode_pattern.  The indexes of the
      introduced variables are consecutive, starting at 1, across all terms
      checked together. *)
   205 val hole_syntax =
   206   let
   207     (* Modified variant of Term.replace_hole *)
          (* Ts: types of the enclosing bound variables, innermost first;
             i: next free index.  Returns the new term and the updated index. *)
   208     fun replace_hole Ts (Const (@{const_name rewrite_HOLE}, T)) i =
   209           (list_comb (Var ((holeN, i), Ts ---> T), map_range Bound (length Ts)), i + 1)
   210       | replace_hole Ts (Abs (x, T, t)) i =
   211           let val (t', i') = replace_hole (T :: Ts) t i
   212           in (Abs (x, T, t'), i') end
   213       | replace_hole Ts (t $ u) i =
   214           let
   215             val (t', i') = replace_hole Ts t i
   216             val (u', i'') = replace_hole Ts u i'
   217           in (t' $ u', i'') end
   218       | replace_hole _ a i = (a, i)
   219     fun prep_holes ts = #1 (fold_map (replace_hole []) ts 1)
   220   in
   221     Context.proof_map (Syntax_Phases.term_check 101 "hole_expansion" (K prep_holes))
   222     #> Proof_Context.set_mode Proof_Context.mode_pattern
   223   end
   224 
   225 (* Find a subterm of the focusterm matching the pattern. *)
       (* pattern_list is applied element by element, in list order, to a
          sequence of focus terms that initially holds just the given one; the
          result is the sequence of all focus terms selected by the whole
          pattern. *)
   226 fun find_matches ctxt pattern_list =
   227   let
          (* Term pattern: t is the pattern term and off the function that moves
             from a matching term to the position of the hole.  Returns SOME of
             the moved focus term on a successful match, NONE otherwise. *)
   228     fun move_term ctxt (t, off) (ft : focusterm) =
   229       let
   230         val thy = Proof_Context.theory_of ctxt
   231 
              (* Types of the trailing schematic-variable arguments of t. *)
   232         val eta_expands =
   233           let val (_, ts) = strip_comb t
   234           in map fastype_of (snd (take_suffix is_Var ts)) end
   235 
   236         fun do_match (tyenv, u, pos) =
   237           case try (Pattern.match thy (t,u)) (tyenv, Vartab.empty) of
   238             NONE => NONE
   239           | SOME (tyenv', _) => SOME (off (tyenv', u, pos))
   240 
              (* Match T against the argument type of u; NONE if that fails or
                 if u has no function type. *)
   241         fun match_argT T u =
   242           let val (U, _) = dest_funT (fastype_of u)
   243           in try (Sign.typ_match thy (T,U)) end
   244           handle TYPE _ => K NONE
   245 
              (* Try to match; if that fails, open one more binder of u with
                 ft_abs (one per type in eta_expands) and try again. *)
   246         fun desc [] ft = do_match ft
   247           | desc (T :: Ts) (ft as (tyenv , u, pos)) =
   248             case do_match ft of
   249               NONE =>
   250                 (case match_argT T u tyenv of
   251                   NONE => NONE
   252                 | SOME tyenv' => desc Ts (ft_abs ctxt (NONE, T) (tyenv', u, pos)))
   253             | SOME ft => SOME ft
   254       in desc eta_expands ft end
   255 
          (* Sequence of the premises of the focus term: the first premise (via
             ft_assm), then the premises of the remaining implication reached
             with ft_imp.  The sequence ends as soon as ft_assm fails. *)
   256     fun move_assms ctxt (ft: focusterm) =
   257       let
   258         fun f () = case try (ft_assm ctxt) ft of
   259             NONE => NONE
   260           | SOME ft' => SOME (ft', move_assms ctxt (ft_imp ctxt ft))
   261       in Seq.make f end
   262 
          (* The effect of a single pattern element on a sequence of focus terms. *)
   263     fun apply_pat At = Seq.map (ft_judgment ctxt)
   264       | apply_pat In = Seq.maps (valid_match_points ctxt)
   265       | apply_pat Asm = Seq.maps (move_assms ctxt o ft_params ctxt)
   266       | apply_pat Concl = Seq.map (ft_concl ctxt o ft_params ctxt)
   267       | apply_pat (For idents) = Seq.map_filter ((ft_for ctxt (map (apfst SOME) idents)))
   268       | apply_pat (Term x) = Seq.map_filter ( (move_term ctxt x))
   269 
   270     fun apply_pats ft = ft
   271       |> Seq.single
   272       |> fold apply_pat pattern_list
   273   in
   274     apply_pats
   275   end
   276 
   (* Instantiate thm according to the environment env.  Every schematic term
      variable of the proposition of thm is mapped to its normal form under env
      (the variable on the left-hand side of the instantiation carries the
      normalized type), and every schematic type variable to its normalized
      type; both lists are certified in ctxt and handed to
      Drule.instantiate_normalize. *)
   277 fun instantiate_normalize_env ctxt env thm =
   278   let
   279     fun certs f = map (apply2 (f ctxt))
   280     val prop = Thm.prop_of thm
   281     val norm_type = Envir.norm_type o Envir.type_env
   282     val insts = Term.add_vars prop []
   283       |> map (fn x as (s,T) => (Var (s, norm_type env T), Envir.norm_term env (Var x)))
   284       |> certs Thm.cterm_of
   285     val tyinsts = Term.add_tvars prop []
   286       |> map (fn x => (TVar x, norm_type env (TVar x)))
   287       |> certs Thm.ctyp_of
   288   in Drule.instantiate_normalize (tyinsts, insts) thm end
   289 
   (* Unify the term `to` with the right-hand side of the conclusion of thm
      (which has to be a meta-equation) by pattern unification, starting from
      env.  Returns the extended environment; raises NO_TO_MATCH if
      Pattern.unify fails with Pattern.Unif. *)
   290 fun unify_with_rhs context to env thm =
   291   let
   292     val (_, rhs) = thm |> Thm.concl_of |> Logic.dest_equals
   293     val env' = Pattern.unify context (Logic.mk_term to, Logic.mk_term rhs) env
   294       handle Pattern.Unif => raise NO_TO_MATCH
   295   in env' end
   296 
   (* Instantiate thm such that its right-hand side agrees with the optional
      "to" term.  Without a "to" term the theorem is returned unchanged;
      otherwise it is instantiated with the environment obtained from
      unify_with_rhs (which may raise NO_TO_MATCH). *)
   297 fun inst_thm_to _ (NONE, _) thm = thm
   298   | inst_thm_to (ctxt : Proof.context) (SOME to, env) thm =
   299       instantiate_normalize_env ctxt (unify_with_rhs (Context.Proof ctxt) to env thm) thm
   300 
   (* Prepare the rule thm for use at one position.
        idents -- identifiers paired with the terms they stand for
        to     -- the optional "to" term
        tyenv  -- the type environment collected while matching the pattern
      The indexes of thm are raised above those occurring in tyenv and in `to`,
      free variables of `to` named like an identifier are replaced by the
      associated term, and the rule is instantiated with inst_thm_to.
      Returns SOME of the instantiated rule, or NONE if the "to" term does not
      unify with the right-hand side (NO_TO_MATCH). *)
   301 fun inst_thm ctxt idents (to, tyenv) thm =
   302   let
   303     (* Replace any identifiers with their corresponding bound variables. *)
   304     val maxidx = Term.maxidx_typs (map (snd o snd) (Vartab.dest tyenv)) 0
   305     val env = Envir.Envir {maxidx = maxidx, tenv = Vartab.empty, tyenv = tyenv}
   306     val replace_idents =
   307       let
              (* The first identifier in the list with a matching name wins. *)
   308         fun subst ((n1, s)::ss) (t as Free (n2, _)) = if n1 = n2 then s else subst ss t
   309           | subst _ t = t
   310       in Term.map_aterms (subst idents) end
   311 
          (* This maxidx shadows the one above and also covers the "to" term. *)
   312     val maxidx = Envir.maxidx_of env |> fold Term.maxidx_term (the_list to)
   313     val thm' = Thm.incr_indexes (maxidx + 1) thm
   314   in SOME (inst_thm_to ctxt (Option.map replace_idents to, env) thm') end
   315   handle NO_TO_MATCH => NONE
   316 
   317 (* Rewrite in subgoal i. *)
       (* pattern   -- the prepared pattern list for find_matches
          to        -- the optional "to" term
          orig_ctxt -- the context into which each result is exported from ctxt
          rules     -- the rewrite rules
          Every position selected by the pattern yields one tactic that applies
          the rules at that position of subgoal i; these tactics are combined
          with SEQ_CONCAT, so the results for all positions are concatenated. *)
   318 fun rewrite_goal_with_thm ctxt (pattern, (to, orig_ctxt)) rules = SUBGOAL (fn (t,i) =>
   319   let
          (* The search starts at the whole subgoal: empty type environment,
             identity position. *)
   320     val matches = find_matches ctxt pattern (Vartab.empty, t, I)
   321 
          (* Rules that cannot be instantiated to fit the "to" term are dropped. *)
   322     fun rewrite_conv insty ctxt bounds =
   323       CConv.rewrs_cconv (map_filter (inst_thm ctxt bounds insty) rules)
   324 
   325     val export = singleton (Proof_Context.export ctxt orig_ctxt)
   326 
          (* First result of distinct_subgoals_tac on th, or th itself if there
             is none. *)
   327     fun distinct_prems th =
   328       case Seq.pull (distinct_subgoals_tac th) of
   329         NONE => th
   330       | SOME (th', _) => th'
   331 
   332     fun tac (tyenv, _, position) = CCONVERSION
   333       (distinct_prems o export o position (rewrite_conv (to, tyenv)) ctxt []) i
   334   in
   335     SEQ_CONCAT (Seq.map tac matches)
   336   end)
   337 
   (* The rewrite tactic: the given theorems are turned into rewrite rules with
      prep_meta_eq, and the resulting rules are passed, together with the
      pattern, to rewrite_goal_with_thm. *)
   338 fun rewrite_tac ctxt pattern thms =
   339   let
   340     val thms' = maps (prep_meta_eq ctxt) thms
   341     val tac = rewrite_goal_with_thm ctxt pattern thms'
   342   in tac end
   343 
   (* Registration of the proof method "rewrite".  The argument parser reads
      the raw pattern, the optional "to" term and the theorems (subst_parser),
      and prep_args turns them into the prepared pattern list, the theorems and
      the checked "to" term, extending the context with the variables fixed by
      the pattern. *)
   344 val _ =
   345   Theory.setup
   346   let
   347     fun mk_fix s = (Binding.name s, NONE, NoSyn)
   348 
          (* Parser for a sequence of ("at" | "in") <atom> pairs.  The parsed
             list is flattened and reversed; an empty pattern becomes
             [Concl, In], and a reversed list that starts with a term pattern
             gets Concl and In prepended. *)
   349     val raw_pattern : (string, binding * string option * mixfix) pattern list parser =
   350       let
   351         val sep = (Args.$$$ "at" >> K At) || (Args.$$$ "in" >> K In)
   352         val atom =  (Args.$$$ "asm" >> K Asm) ||
   353           (Args.$$$ "concl" >> K Concl) ||
   354           (Args.$$$ "for" |-- Args.parens (Scan.optional Parse.fixes []) >> For) ||
   355           (Parse.term >> Term)
   356         val sep_atom = sep -- atom >> (fn (s,a) => [s,a])
   357 
   358         fun append_default [] = [Concl, In]
   359           | append_default (ps as Term _ :: _) = Concl :: In :: ps
   360           | append_default ps = ps
   361 
   362       in Scan.repeat sep_atom >> (flat #> rev #> append_default) end
   363 
          (* Turn a token parser into a context parser: the parse result r is
             post-processed by f, which may also update the proof context. *)
   364     fun context_lift (scan : 'a parser) f = fn (context : Context.generic, toks) =>
   365       let
   366         val (r, toks') = scan toks
   367         val (r', context') = Context.map_proof_result (fn ctxt => f ctxt r) context
   368       in (r', (context', toks' : Token.T list)) end
   369 
          (* Read the optional types of the fixes and add the fixes to ctxt. *)
   370     fun read_fixes fixes ctxt =
   371       let fun read_typ (b, rawT, mx) = (b, Option.map (Syntax.read_typ ctxt) rawT, mx)
   372       in Proof_Context.add_fixes (map read_typ fixes) ctxt end
   373 
          (* First preparation stage: term patterns are parsed (not yet
             checked), and the variables of "for" atoms are fixed. *)
   374     fun prep_pats ctxt (ps : (string, binding * string option * mixfix) pattern list) =
   375       let
              (* Search for a rewrite_HOLE constant (function part before
                 argument part).  For every abstraction on the way to the hole
                 found, a variable named after the bound variable is fixed and
                 the abstraction gets the type constraint U --> dummyT, where U
                 is a fresh type-inference parameter; the pairs (fixed name, U)
                 are collected.  NONE if the term contains no hole. *)
   376         fun add_constrs ctxt n (Abs (x, T, t)) =
   377             let
   378               val (x', ctxt') = yield_singleton Proof_Context.add_fixes (mk_fix x) ctxt
   379             in
   380               (case add_constrs ctxt' (n+1) t of
   381                 NONE => NONE
   382               | SOME ((ctxt'', n', xs), t') =>
   383                   let
   384                     val U = Type_Infer.mk_param n []
   385                     val u = Type.constraint (U --> dummyT) (Abs (x, T, t'))
   386                   in SOME ((ctxt'', n', (x', U) :: xs), u) end)
   387             end
   388           | add_constrs ctxt n (l $ r) =
   389             (case add_constrs ctxt n l of
   390               SOME (c, l') => SOME (c, l' $ r)
   391             | NONE =>
   392               (case add_constrs ctxt n r of
   393                 SOME (c, r') => SOME (c, l $ r')
   394               | NONE => NONE))
   395           | add_constrs ctxt n t =
   396             if is_hole_const t then SOME ((ctxt, n, []), t) else NONE
   397 
              (* The state (n, ctxt) threads the parameter counter and the
                 context through the pattern list. *)
   398         fun prep (Term s) (n, ctxt) =
   399             let
   400               val t = Syntax.parse_term ctxt s
   401               val ((ctxt', n', bs), t') =
   402                 the_default ((ctxt, n, []), t) (add_constrs ctxt (n+1) t)
   403             in (Term (t', bs), (n', ctxt')) end
   404           | prep (For ss) (n, ctxt) =
   405             let val (ns, ctxt') = read_fixes ss ctxt
   406             in (For ns, (n, ctxt')) end
   407           | prep At (n,ctxt) = (At, (n, ctxt))
   408           | prep In (n,ctxt) = (In, (n, ctxt))
   409           | prep Concl (n,ctxt) = (Concl, (n, ctxt))
   410           | prep Asm (n,ctxt) = (Asm, (n, ctxt))
   411 
   412         val (xs, (_, ctxt')) = fold_map prep ps (0, ctxt)
   413 
   414       in (xs, ctxt') end
   415 
          (* Full preparation of the parsed method arguments.  The result is
             ((patterns, theorems, (to, ctxt)), ctxt''): the original context
             ctxt is returned next to the "to" term, ctxt'' is the context
             extended by the pattern. *)
   416     fun prep_args ctxt (((raw_pats, raw_to), raw_ths)) =
   417       let
   418 
              (* Replace the list of fixes of each term pattern by the function
                 that moves the focus from a term matching the pattern to the
                 position of its hole. *)
   419         fun interpret_term_patterns ctxt =
   420           let
   421 
                  (* Returns the fixes not yet used and the function moving to
                     the hole, or NONE if there is no hole below.  An
                     application whose head is a hole counts as the hole. *)
   422             fun descend_hole fixes (Abs (_, _, t)) =
   423                 (case descend_hole fixes t of
   424                   NONE => NONE
   425                 | SOME (fix :: fixes', pos) => SOME (fixes', pos o ft_abs ctxt (apfst SOME fix))
   426                 | SOME ([], _) => raise Match (* XXX -- check phases modified binding *))
   427               | descend_hole fixes (t as l $ r) =
   428                 let val (f, _) = strip_comb t
   429                 in
   430                   if is_hole f
   431                   then SOME (fixes, I)
   432                   else
   433                     (case descend_hole fixes l of
   434                       SOME (fixes', pos) => SOME (fixes', pos o ft_fun ctxt)
   435                     | NONE =>
   436                       (case descend_hole fixes r of
   437                         SOME (fixes', pos) => SOME (fixes', pos o ft_arg ctxt)
   438                       | NONE => NONE))
   439                 end
   440               | descend_hole fixes t =
   441                 if is_hole t then SOME (fixes, I) else NONE
   442 
                  (* A pattern without a hole selects the matching term itself. *)
   443             fun f (t, fixes) = Term (t, (descend_hole (rev fixes) #> the_default ([], I) #> snd) t)
   444 
   445           in map (map_term_pattern f) end
   446 
              (* Check all pattern terms, the type constraints of their fixed
                 variables and the optional "to" term simultaneously (under
                 hole_syntax), then put the checked terms back into the
                 patterns.  Raises Match if the numbers of terms do not fit. *)
   447         fun check_terms ctxt ps to =
   448           let
   449             fun safe_chop (0: int) xs = ([], xs)
   450               | safe_chop n (x :: xs) = chop (n - 1) xs |>> cons x
   451               | safe_chop _ _ = raise Match
   452 
                  (* Consumes, for a term pattern, the checked pattern term and
                     one checked free variable per fix from the list ts. *)
   453             fun reinsert_pat _ (Term (_, cs)) (t :: ts) =
   454                 let val (cs', ts') = safe_chop (length cs) ts
   455                 in (Term (t, map dest_Free cs'), ts') end
   456               | reinsert_pat _ (Term _) [] = raise Match
   457               | reinsert_pat ctxt (For ss) ts =
   458                 let val fixes = map (fn s => (s, Variable.default_type ctxt s)) ss
   459                 in (For fixes, ts) end
   460               | reinsert_pat _ At ts = (At, ts)
   461               | reinsert_pat _ In ts = (In, ts)
   462               | reinsert_pat _ Concl ts = (Concl, ts)
   463               | reinsert_pat _ Asm ts = (Asm, ts)
   464 
   465             fun free_constr (s,T) = Type.constraint T (Free (s, dummyT))
   466             fun mk_free_constrs (Term (t, cs)) = t :: map free_constr cs
   467               | mk_free_constrs _ = []
   468 
   469             val ts = maps mk_free_constrs ps @ the_list to
   470               |> Syntax.check_terms (hole_syntax ctxt)
   471             val ctxt' = fold Variable.declare_term ts ctxt
                  (* What remains after the patterns is the "to" term, if any. *)
   472             val (ps', (to', ts')) = fold_map (reinsert_pat ctxt') ps ts
   473               ||> (fn xs => case to of NONE => (NONE, xs) | SOME _ => (SOME (hd xs), tl xs))
   474             val _ = case ts' of (_ :: _) => raise Match | [] => ()
   475           in ((ps', to'), ctxt') end
   476 
   477         val (pats, ctxt') = prep_pats ctxt raw_pats
   478 
   479         val ths = Attrib.eval_thms ctxt' raw_ths
   480         val to = Option.map (Syntax.parse_term ctxt') raw_to
   481 
   482         val ((pats', to'), ctxt'') = check_terms ctxt' pats to
   483         val pats'' = interpret_term_patterns ctxt'' pats'
   484 
   485       in ((pats'', ths, (to', ctxt)), ctxt'') end
   486 
   487     val to_parser = Scan.option ((Args.$$$ "to") |-- Parse.term)
   488 
   489     val subst_parser =
   490       let val scan = raw_pattern -- to_parser -- Parse.xthms1
   491       in context_lift scan prep_args end
   492   in
   493     Method.setup @{binding rewrite} (subst_parser >>
   494       (fn (pattern, inthms, inst) => fn ctxt =>
   495         SIMPLE_METHOD' (rewrite_tac ctxt (pattern, inst) inthms)))
   496       "single-step rewriting, allowing subterm selection via patterns."
   497   end
   498 end