(* Title: HOL/Library/rewrite.ML
Author: Christoph Traut, Lars Noschinski, TU Muenchen
This is a rewrite method that supports subterm-selection based on patterns.
The patterns accepted by rewrite are of the following form:
<atom> ::= <term> | "concl" | "asm" | "for" "(" <names> ")"
<pattern> ::= (in <atom> | at <atom>) [<pattern>]
<args> ::= [<pattern>] ["to" <term>] <thms>
This syntax was clearly inspired by Gonthier's and Tassi's language of
patterns but has diverged significantly during its development.
We also allow introduction of identifiers for bound variables,
which can then be used to match arbitrary subterms inside abstractions.
*)
(* Public interface of the pattern-based rewrite method. *)
signature REWRITE =
sig
(* Elements of a subterm-selection pattern: the separators At and In, a term
pattern carrying a payload of type 'a, the conclusion (Concl), an assumption
(Asm), and a list of "for"-introduced identifiers of type 'b. *)
datatype ('a, 'b) pattern = At | In | Term of 'a | Concl | Asm | For of 'b list
(* mk_hole i T is the schematic variable with index i and type T that
represents a hole in a term pattern. *)
val mk_hole: int -> typ -> term
(* rewrite ctxt (pats, to) thms ct: the sequence of theorems obtained by
applying, to ct, the conversions that rewrite with thms at the positions
selected by pats. Conversions that fail on ct are dropped. If "to" is given,
it is unified with the right-hand side of each rule before the rule is used.
Each term pattern is paired with a list of (name, type) identifiers. *)
val rewrite: Proof.context
-> (term * (string * typ) list, string * typ option) pattern list * term option
-> thm list
-> cterm
-> thm Seq.seq
(* rewrite_tac ctxt (pats, to) thms i: the same rewriting, packaged as a
tactic acting on subgoal i. *)
val rewrite_tac: Proof.context
-> (term * (string * typ) list, string * typ option) pattern list * term option
-> thm list
-> int
-> tactic
end
structure Rewrite : REWRITE =
struct
(* Pattern elements: At and In are the separators, Term carries a term pattern
of type 'a, Concl and Asm select conclusion and assumptions, and For carries
the identifiers (of type 'b) introduced for bound variables. *)
datatype ('a, 'b) pattern = At | In | Term of 'a | Concl | Asm | For of 'b list
(* Replace the payload of a Term element by the pattern element computed by f;
   every other element is rebuilt unchanged (which lets the payload type of
   the result differ from that of the argument). *)
fun map_term_pattern f pat =
  (case pat of
    Term x => f x
  | For ss => For ss
  | At => At
  | In => In
  | Concl => Concl
  | Asm => Asm)
(* Raised by unify_with_rhs when the "to" term does not unify with the
right-hand side of a rewrite rule; inst_thm turns it into NONE. *)
exception NO_TO_MATCH
(* Combine a sequence of tactics into a single tactic: every tactic is applied
   to the same goal state and the resulting sequences are concatenated. *)
fun SEQ_CONCAT (tacq : tactic Seq.seq) : tactic =
  fn st => tacq |> Seq.maps (fn tac => tac st)
(* We rewrite subterms using rewrite conversions. These are conversions
that also take a context and a list of identifiers for bound variables
as parameters. Each identifier is a name paired with a term; the pairs
are recorded by abs_rewr_cconv when it passes below an abstraction. *)
type rewrite_conv = Proof.context -> (string * term) list -> conv
(* To apply such a rewrite conversion to a subterm of our goal, we use
subterm positions. A subterm position is a function that maps a rewrite
conversion working on the top level to a new rewrite conversion working
on a specific subterm.
During substitution, we traverse the goal to find subterms that
we can rewrite. For each of these subterms, a subterm position is
created and later used to build the conversion with which we try to
rewrite this subterm. *)
type subterm_position = rewrite_conv -> rewrite_conv
(* A focusterm represents a subterm. It is a triple (tyenv, t, p), consisting
of the type environment accumulated by the type matching done so far, the
subterm t itself, and its subterm position p. *)
type focusterm = Type.tyenv * term * subterm_position
(* Internal names: dummyN is the basis of the default name that ft_abs gives
to the free variable it introduces when no identifier is supplied; holeN is
the base name of the schematic variables that represent holes (see mk_hole
and is_hole). *)
val dummyN = Name.internal "__dummy"
val holeN = Name.internal "_hole"
(* Turn a theorem into the list of rewrite rules produced by the simplifier's
   mksimps for ctxt, with the schematic variable indexes of every rule reset
   by Drule.zero_var_indexes. *)
fun prep_meta_eq ctxt =
  map Drule.zero_var_indexes o Simplifier.mksimps ctxt
(* rewrite conversions *)
(* Subterm position for the body of an abstraction.  If an identifier name is
   given, it is recorded in the identifier list together with the term of the
   cterm that CConv.abs_cconv hands to its callback. *)
fun abs_rewr_cconv ident : subterm_position =
  fn rewr => fn ctxt => fn idents =>
    let
      fun extended ct =
        (case ident of
          NONE => idents
        | SOME name => (name, Thm.term_of ct) :: idents)
    in
      CConv.abs_cconv (fn (ct, ctxt') => rewr ctxt' (extended ct)) ctxt
    end
(* Subterm positions for the function part of an application, the argument
   part of an application, the conclusion behind one premise, and a premise
   (via CConv.with_prems_cconv ~1), respectively.  Each one wraps the
   conversion obtained from the given rewrite conversion in the corresponding
   CConv combinator. *)
val fun_rewr_cconv : subterm_position =
  fn rewr => fn ctxt => fn idents => CConv.fun_cconv (rewr ctxt idents)
val arg_rewr_cconv : subterm_position =
  fn rewr => fn ctxt => fn idents => CConv.arg_cconv (rewr ctxt idents)
val imp_rewr_cconv : subterm_position =
  fn rewr => fn ctxt => fn idents => CConv.concl_cconv 1 (rewr ctxt idents)
val with_prems_rewr_cconv : subterm_position =
  fn rewr => fn ctxt => fn idents => CConv.with_prems_cconv ~1 (rewr ctxt idents)
(* focus terms *)
(* ft_abs ctxt (s, T) (tyenv, u, pos): move the focus below one abstraction.

   u must have a function type, otherwise TERM is raised.  Unless T is dummyT,
   T is matched against the argument type U of u, extending tyenv.  The bound
   variable is replaced by a free variable named s (or a default internal
   name if s is NONE) whose type is T normalized with the new type
   environment.  If u is not an abstraction, it is applied to the free
   variable and the position first eta-expands the term.

   Raises TYPE if T does not match U.  Sign.typ_match signals a failed match
   with Type.TYPE_MATCH (Pattern.MATCH is raised by term matching only), so
   that is the exception translated here; Pattern.MATCH is still handled for
   backward compatibility. *)
fun ft_abs ctxt (s,T) (tyenv, u, pos) =
  case try (fastype_of #> dest_funT) u of
    NONE => raise TERM ("ft_abs: no function type", [u])
  | SOME (U, _) =>
      let
        val tyenv' =
          if T = dummyT then tyenv
          else Sign.typ_match (Proof_Context.theory_of ctxt) (T, U) tyenv
        val x = Free (the_default (Name.internal dummyN) s, Envir.norm_type tyenv' T)
        val eta_expand_cconv = CConv.rewr_cconv @{thm eta_expand}
        fun eta_expand rewr ctxt bounds = eta_expand_cconv then_conv rewr ctxt bounds
        val (u', pos') =
          case u of
            Abs (_,_,t') => (subst_bound (x, t'), pos o abs_rewr_cconv s)
          | _ => (u $ x, pos o eta_expand o abs_rewr_cconv s)
      in (tyenv', u', pos') end
      handle Type.TYPE_MATCH => raise TYPE ("ft_abs: types don't match", [T,U], [u])
           | Pattern.MATCH => raise TYPE ("ft_abs: types don't match", [T,U], [u])
(* Move the focus to the function part of an application.  An abstraction
   whose body is an application to the bound variable is first opened with
   ft_abs.  Raises TERM for any other term. *)
fun ft_fun ctxt (ft as (tyenv, t, pos)) =
  (case t of
    l $ _ => (tyenv, l, pos o fun_rewr_cconv)
  | Abs (_, T, _ $ Bound 0) => ft_fun ctxt (ft_abs ctxt (NONE, T) ft)
  | _ => raise TERM ("ft_fun", [t]))
local
  (* Move the focus to the argument part of an application, using cconv as the
     position step.  An abstraction whose body is an application to the bound
     variable is first opened with ft_abs.  Raises TERM for any other term. *)
  fun ft_arg_gen cconv ctxt (ft as (tyenv, t, pos)) =
    (case t of
      _ $ r => (tyenv, r, pos o cconv)
    | Abs (_, T, _ $ Bound 0) => ft_arg_gen cconv ctxt (ft_abs ctxt (NONE, T) ft)
    | _ => raise TERM ("ft_arg", [t]))
in
  (* argument of an arbitrary application *)
  val ft_arg = ft_arg_gen arg_rewr_cconv
  (* right-hand argument, positioned with the conclusion conversion *)
  val ft_imp = ft_arg_gen imp_rewr_cconv
end
(* Move to B in !!x_1 ... x_n. B. Do not eta-expand.

   For a quantifier applied to an abstraction, the focus moves into the
   argument and below the abstraction, and the process repeats.  For a
   quantifier applied to anything else (an eta-contracted body), the focus
   moves into the argument as it is.  Previously this second case matched the
   bare constant Pure.all, for which ft_arg always raises TERM; it now matches
   the application, and a bare constant is returned unchanged like any other
   term. *)
fun ft_params ctxt (ft as (_, t, _) : focusterm) =
  case t of
    Const (@{const_name "Pure.all"}, _) $ Abs (_,T,_) =>
      (ft_params ctxt o ft_abs ctxt (NONE, T) o ft_arg ctxt) ft
  | Const (@{const_name "Pure.all"}, _) $ _ =>
      (ft_params ctxt o ft_arg ctxt) ft
  | _ => ft
(* Move the focus below one meta-level universal quantifier, introducing the
   identifier ident = (name option, type option) for the bound variable.  A
   missing type defaults to the argument type of the quantified predicate, as
   read off the type of the Pure.all constant.  Raises TERM if the focus is
   not a quantifier application. *)
fun ft_all ctxt ident (ft : focusterm) =
  (case ft of
    (_, Const (@{const_name "Pure.all"}, T) $ _, _) =>
      let
        val (predT, _) = dest_funT T
        val (def_U, _) = dest_funT predT
        val ident' = apsnd (the_default def_U) ident
      in ft_abs ctxt ident' (ft_arg ctxt ft) end
  | (_, t, _) => raise TERM ("ft_all", [t]))
(* Move the focus below the leading meta-level quantifiers of the focused
   term, attaching the given identifiers to the innermost quantifiers (the
   last identifier goes to the innermost one); outer quantifiers without an
   identifier are passed anonymously.  Returns NONE if there are more
   identifiers than quantifiers. *)
fun ft_for ctxt idents (ft as (_, t, _) : focusterm) =
  let
    (* walk pending tm = (identifiers still unused, descent through the
       quantifiers of tm); pending is ordered innermost first *)
    fun walk pending (Const (@{const_name "Pure.all"}, _) $ arg) =
          let
            val body = (case arg of Abs (_, _, b) => b | _ => arg)
            val (pending', inner) = walk pending body
            val (ident, remaining) =
              (case pending' of
                [] => ((NONE, NONE), [])
              | x :: xs => (x, xs))
          in (remaining, inner o ft_all ctxt ident) end
      | walk pending _ = (pending, I)
    val (leftover, descend) = walk (rev idents) t
  in
    if null leftover then SOME (descend ft) else NONE
  end
(* Move the focus past all leading meta-implications to the conclusion. *)
fun ft_concl ctxt (ft as (_, t, _) : focusterm) =
  (case t of
    Const (@{const_name "Pure.imp"}, _) $ _ $ _ => ft_concl ctxt (ft_imp ctxt ft)
  | _ => ft)
(* Move the focus to the first premise of a meta-implication.  Raises TERM if
   the focused term is not a meta-implication. *)
fun ft_assm _ (tyenv, t, pos) =
  (case t of
    Const (@{const_name "Pure.imp"}, _) $ prem $ _ =>
      (tyenv, prem, pos o with_prems_rewr_cconv)
  | _ => raise TERM ("ft_assm", [t]))
(* If the focused term is an object-logic judgment, move the focus to its
   argument; otherwise leave the focus where it is. *)
fun ft_judgment ctxt (ft as (_, t, _) : focusterm) =
  if not (Object_Logic.is_judgment ctxt t) then ft
  else ft_arg ctxt ft
(* Find all subterms that might be a valid point to apply a rule.  The result
   is a lazy sequence containing the focus itself followed by the match points
   below it; terms whose head (looking through applications and abstraction
   bodies) is a schematic or bound variable are filtered out. *)
fun valid_match_points ctxt (ft : focusterm) =
  let
    (* the ways of moving one step down from a focus *)
    fun moves (_, _ $ _, _) = [ft_fun ctxt, ft_arg ctxt]
      | moves (_, Abs (_, T, _), _) = [ft_abs ctxt (NONE, T)]
      | moves _ = []
    fun below ft' =
      Seq.maps (fn move => valid_match_points ctxt (move ft')) (Seq.of_list (moves ft'))
    fun rewritable tm =
      (case tm of
        f $ _ => rewritable f
      | Abs (_, _, body) => rewritable body
      | Var _ => false
      | Bound _ => false
      | _ => true)
  in
    Seq.filter (fn (_, tm, _) => rewritable tm)
      (Seq.make (fn () => SOME (ft, below ft)))
  end
(* The schematic variable with index i and type T that represents a hole. *)
fun mk_hole i T = Var ((holeN, i), T)
(* Is the term a hole variable, i.e. a schematic variable named holeN? *)
fun is_hole t =
  (case t of
    Var ((name, _), _) => name = holeN
  | _ => false)
(* Is the term the rewrite_HOLE constant, as it occurs before hole expansion? *)
fun is_hole_const t =
  (case t of
    Const (@{const_name rewrite_HOLE}, _) => true
  | _ => false)
(* Context transformation used when checking pattern terms: it installs a term
   check (stage 101, "hole_expansion") that replaces every rewrite_HOLE
   constant by a numbered hole variable applied to all enclosing bound
   variables, and switches the context to pattern mode. *)
val hole_syntax =
  let
    (* Modified variant of Term.replace_hole: Ts holds the types of the
       enclosing binders, i is the index for the next hole. *)
    fun replace_hole Ts tm i =
      (case tm of
        Const (@{const_name rewrite_HOLE}, T) =>
          (list_comb (mk_hole i (Ts ---> T), map_range Bound (length Ts)), i + 1)
      | Abs (x, T, body) =>
          let val (body', j) = replace_hole (T :: Ts) body i
          in (Abs (x, T, body'), j) end
      | f $ a =>
          let
            val (f', j) = replace_hole Ts f i
            val (a', k) = replace_hole Ts a j
          in (f' $ a', k) end
      | _ => (tm, i))
    (* holes are numbered consecutively, starting from 1, across all terms *)
    fun prep_holes ts = fst (fold_map (replace_hole []) ts 1)
    val install_check =
      Context.proof_map (Syntax_Phases.term_check 101 "hole_expansion" (K prep_holes))
  in
    install_check #> Proof_Context.set_mode Proof_Context.mode_pattern
  end
(* Find a subterm of the focusterm matching the pattern.  The elements of
   pattern_list are applied in list order to the sequence of candidate
   focusterms, starting from the singleton sequence of the given focus. *)
fun find_matches ctxt pattern_list =
  let
    (* Match the term pattern t against the focus and, on success, continue
       with the descent function off.  If the match fails, the focus is opened
       with ft_abs once for each trailing schematic-variable argument of t (as
       long as the argument types match) and matching is retried. *)
    fun move_term ctxt (t, off) (ft : focusterm) =
      let
        val thy = Proof_Context.theory_of ctxt

        val eta_expands =
          strip_comb t |> snd |> take_suffix is_Var |> snd |> map fastype_of

        fun do_match (tyenv, u, pos) =
          Option.map (fn (tyenv', _) => off (tyenv', u, pos))
            (try (Pattern.match thy (t, u)) (tyenv, Vartab.empty))

        (* type matcher for T against the argument type of u; constantly NONE
           if u has no function type *)
        fun match_argT T u =
          try (Sign.typ_match thy (T, #1 (dest_funT (fastype_of u))))
          handle TYPE _ => K NONE

        fun desc Ts (ft as (tyenv, u, pos)) =
          (case do_match ft of
            SOME ft' => SOME ft'
          | NONE =>
              (case Ts of
                [] => NONE
              | T :: Ts' =>
                  (case match_argT T u tyenv of
                    NONE => NONE
                  | SOME tyenv' => desc Ts' (ft_abs ctxt (NONE, T) (tyenv', u, pos)))))
      in desc eta_expands ft end

    (* lazily enumerate the premises of nested meta-implications *)
    fun move_assms ctxt (ft : focusterm) =
      Seq.make (fn () =>
        Option.map (fn prem => (prem, move_assms ctxt (ft_imp ctxt ft)))
          (try (ft_assm ctxt) ft))

    fun apply_pat pat =
      (case pat of
        At => Seq.map (ft_judgment ctxt)
      | In => Seq.maps (valid_match_points ctxt)
      | Asm => Seq.maps (move_assms ctxt o ft_params ctxt)
      | Concl => Seq.map (ft_concl ctxt o ft_params ctxt)
      | For idents => Seq.map_filter (ft_for ctxt (map (apfst SOME) idents))
      | Term x => Seq.map_filter (move_term ctxt x))
  in
    fn ft => fold apply_pat pattern_list (Seq.single ft)
  end
(* Instantiate the schematic term and type variables of thm according to the
   environment env and normalize the result with
   Drule.instantiate_normalize. *)
fun instantiate_normalize_env ctxt env thm =
  let
    val prop = Thm.prop_of thm
    val norm_typ = Envir.norm_type (Envir.type_env env)

    fun term_inst (x as (xi, T)) =
      apply2 (Thm.cterm_of ctxt) (Var (xi, norm_typ T), Envir.norm_term env (Var x))
    fun type_inst v =
      apply2 (Thm.ctyp_of ctxt) (TVar v, norm_typ (TVar v))

    val insts = map term_inst (Term.add_vars prop [])
    val tyinsts = map type_inst (Term.add_tvars prop [])
  in
    Drule.instantiate_normalize (tyinsts, insts) thm
  end
(* Unify the term "to" with the right-hand side of the conclusion of thm
   (which must be a meta-equation), extending env.  Raises NO_TO_MATCH if
   pattern unification fails. *)
fun unify_with_rhs context to env thm =
  let
    val rhs = #2 (Logic.dest_equals (Thm.concl_of thm))
  in
    Pattern.unify context (Logic.mk_term to, Logic.mk_term rhs) env
      handle Pattern.Unif => raise NO_TO_MATCH
  end
(* Without a "to" term the theorem is returned unchanged.  Otherwise the
   theorem is instantiated with the environment that results from unifying
   the "to" term with its right-hand side. *)
fun inst_thm_to (ctxt : Proof.context) (to_opt, env) thm =
  (case to_opt of
    NONE => thm
  | SOME to =>
      instantiate_normalize_env ctxt
        (unify_with_rhs (Context.Proof ctxt) to env thm) thm)
(* Prepare a rewrite rule for use at a match: the indexes of thm are raised
   above those of the type environment and of the "to" term, and the rule is
   instantiated so that its right-hand side fits the "to" term (if any).
   Returns NONE if the "to" term does not unify with the right-hand side. *)
fun inst_thm ctxt idents (to, tyenv) thm =
  let
    val ty_maxidx = Term.maxidx_typs (map (snd o snd) (Vartab.dest tyenv)) 0
    val env = Envir.Envir {maxidx = ty_maxidx, tenv = Vartab.empty, tyenv = tyenv}

    (* Replace any identifiers with their corresponding bound variables. *)
    fun lookup_ident [] t = t
      | lookup_ident ((name, s) :: rest) (t as Free (n, _)) =
          if name = n then s else lookup_ident rest t
      | lookup_ident _ t = t
    val replace_idents = Term.map_aterms (lookup_ident idents)

    val maxidx = fold Term.maxidx_term (the_list to) (Envir.maxidx_of env)
    val thm' = Thm.incr_indexes (maxidx + 1) thm
  in
    SOME (inst_thm_to ctxt (Option.map replace_idents to, env) thm')
  end
  handle NO_TO_MATCH => NONE
local
(* rewrite_raw ctxt (pattern, to) thms ct: the sequence of conversions, one
   for each position of ct selected by the pattern, that rewrite at that
   position with the rules derived from thms and then remove duplicate
   premises from the resulting theorem. *)
fun rewrite_raw ctxt (pattern, to) thms ct =
  let
    (* Replace the identifier list of every term pattern by a function that
       moves a focus on a term matching the pattern to the position of the
       hole (the identity if the pattern contains no hole). *)
    fun interpret_term_patterns ctxt =
      let
        fun descend_hole fixes tm =
          (case tm of
            Abs (_, _, body) =>
              (case descend_hole fixes body of
                NONE => NONE
              | SOME (fix :: fixes', pos) =>
                  SOME (fixes', pos o ft_abs ctxt (apfst SOME fix))
              | SOME ([], _) => raise Match (* XXX -- check phases modified binding *))
          | l $ r =>
              if is_hole (#1 (strip_comb tm)) then SOME (fixes, I)
              else
                (case descend_hole fixes l of
                  SOME (fixes', pos) => SOME (fixes', pos o ft_fun ctxt)
                | NONE =>
                    (case descend_hole fixes r of
                      SOME (fixes', pos) => SOME (fixes', pos o ft_arg ctxt)
                    | NONE => NONE))
          | _ => if is_hole tm then SOME (fixes, I) else NONE)

        fun to_descent (t, fixes) =
          Term (t, snd (the_default ([], I) (descend_hole (rev fixes) t)))
      in map (map_term_pattern to_descent) end

    val pattern' = interpret_term_patterns ctxt pattern
    val matches = find_matches ctxt pattern' (Vartab.empty, Thm.term_of ct, I)
    val thms' = maps (prep_meta_eq ctxt) thms

    (* try all rules that can be instantiated for this match *)
    fun rewrite_conv insty ctxt bounds =
      CConv.rewrs_cconv (map_filter (inst_thm ctxt bounds insty) thms')

    (* first result of distinct_subgoals_tac, or the theorem itself *)
    fun distinct_prems th =
      (case Seq.pull (distinct_subgoals_tac th) of
        SOME (th', _) => th'
      | NONE => th)

    fun conv ((tyenv, _, position) : focusterm) =
      distinct_prems o position (rewrite_conv (to, tyenv)) ctxt []
  in
    Seq.map conv matches
  end
in
(* Apply every conversion produced by rewrite_raw to ct and keep the results
   of those that succeed. *)
fun rewrite ctxt pat thms ct =
  Seq.map_filter (fn cv => try cv ct) (rewrite_raw ctxt pat thms ct)
(* Tactic that rewrites a subgoal with every conversion produced by
   rewrite_raw and concatenates the outcomes.  If a pattern context is given,
   the theorem produced by each conversion is exported from that context into
   ctxt before it is used. *)
fun rewrite_export_tac ctxt (pat, pat_ctxt) thms =
  let
    val export =
      (case pat_ctxt of
        SOME inner => singleton (Proof_Context.export inner ctxt)
      | NONE => I)
    fun goal_tac (ct, i) =
      SEQ_CONCAT
        (Seq.map (fn cv => CCONVERSION (export o cv) i) (rewrite_raw ctxt pat thms ct))
  in
    CSUBGOAL goal_tac
  end
(* The rewrite tactic without a separate pattern context (nothing is
   exported). *)
fun rewrite_tac ctxt pat thms = rewrite_export_tac ctxt (pat, NONE) thms
end
val _ =
Theory.setup
let
(* Build a fix specification for the name s: a binding without a type
constraint and without mixfix syntax. *)
fun mk_fix s = (Binding.name s, NONE, NoSyn)
(* Parser for the pattern part of the method arguments: a possibly empty
   series of separator/atom pairs.  The parsed elements are flattened and
   reversed; an empty pattern becomes [Concl, In], and a pattern whose
   reversed list starts with a term gets Concl and In put in front of it. *)
val raw_pattern : (string, binding * string option * mixfix) pattern list parser =
  let
    val separator = (Args.$$$ "at" >> K At) || (Args.$$$ "in" >> K In)
    val selector =
      (Args.$$$ "asm" >> K Asm) ||
      (Args.$$$ "concl" >> K Concl) ||
      (Args.$$$ "for" |-- Args.parens (Scan.optional Parse.fixes []) >> For) ||
      (Parse.term >> Term)
    val step = separator -- selector >> (fn (s, a) => [s, a])
    fun with_default pats =
      (case pats of
        [] => [Concl, In]
      | Term _ :: _ => Concl :: In :: pats
      | _ => pats)
  in
    Scan.repeat step >> (with_default o rev o flat)
  end
(* Turn a token parser into a context parser: the parsed value is
   post-processed by f inside the proof context of the generic context, and
   the context updated by f is passed on together with the remaining
   tokens. *)
fun context_lift (scan : 'a parser) f (context : Context.generic, toks) =
  let
    val (parsed, rest) = scan toks
    val (result, context') =
      Context.map_proof_result (fn ctxt => f ctxt parsed) context
  in
    (result, (context', rest : Token.T list))
  end
(* Read the optional type of every fix in ctxt and add the fixes to the
   context. *)
fun read_fixes fixes ctxt =
  Proof_Context.add_fixes
    (map (fn (b, rawT, mx) => (b, Option.map (Syntax.read_typ ctxt) rawT, mx)) fixes)
    ctxt
(* prep_pats ctxt ps: parse the term patterns in ps and read the fixes of the
"for" patterns. A pair (n, ctxt) of a counter and the proof context is
threaded through the pattern list from left to right, starting with
(0, ctxt). Returns the prepared patterns together with the final context. *)
fun prep_pats ctxt (ps : (string, binding * string option * mixfix) pattern list) =
let
(* add_constrs ctxt n t looks for the hole constant in t. It returns NONE if
there is none. Otherwise, for every abstraction on the path to the hole, the
bound variable name is fixed in the context and the abstraction is wrapped in
a type constraint U --> dummyT, where U is the type-inference parameter
numbered n. The result carries the extended context, a counter, the list of
(fixed name, U) pairs from the outermost abstraction inwards, and the
constrained term. *)
fun add_constrs ctxt n (Abs (x, T, t)) =
let
val (x', ctxt') = yield_singleton Proof_Context.add_fixes (mk_fix x) ctxt
in
(case add_constrs ctxt' (n+1) t of
NONE => NONE
| SOME ((ctxt'', n', xs), t') =>
let
val U = Type_Infer.mk_param n []
val u = Type.constraint (U --> dummyT) (Abs (x, T, t'))
in SOME ((ctxt'', n', (x', U) :: xs), u) end)
end
(* For an application, the function part is searched first and the argument
part only if the function part contains no hole constant. *)
| add_constrs ctxt n (l $ r) =
(case add_constrs ctxt n l of
SOME (c, l') => SOME (c, l' $ r)
| NONE =>
(case add_constrs ctxt n r of
SOME (c, r') => SOME (c, l $ r')
| NONE => NONE))
| add_constrs ctxt n t =
if is_hole_const t then SOME ((ctxt, n, []), t) else NONE
(* A term pattern is parsed (not yet checked) and annotated by add_constrs;
if it contains no hole constant, it is kept as parsed, with an empty list of
fixes. *)
fun prep (Term s) (n, ctxt) =
let
val t = Syntax.parse_term ctxt s
val ((ctxt', n', bs), t') =
the_default ((ctxt, n, []), t) (add_constrs ctxt (n+1) t)
in (Term (t', bs), (n', ctxt')) end
(* The fixes of a "for" pattern are added to the context; the pattern then
carries the names returned by read_fixes. *)
| prep (For ss) (n, ctxt) =
let val (ns, ctxt') = read_fixes ss ctxt
in (For ns, (n, ctxt')) end
| prep At (n,ctxt) = (At, (n, ctxt))
| prep In (n,ctxt) = (In, (n, ctxt))
| prep Concl (n,ctxt) = (Concl, (n, ctxt))
| prep Asm (n,ctxt) = (Asm, (n, ctxt))
val (xs, (_, ctxt')) = fold_map prep ps (0, ctxt)
in (xs, ctxt') end
(* prep_args ctxt ((raw_pats, raw_to), raw_ths): prepare the parsed method
   arguments.  The patterns are prepared with prep_pats, the theorems are
   evaluated and the optional "to" term is parsed in the resulting context,
   and all terms are then checked together by check_terms.  The result is
   ((patterns, theorems, (to, ctxt)), final context), where ctxt is the
   context that prep_args was called with. *)
fun prep_args ctxt (((raw_pats, raw_to), raw_ths)) =
  let
    (* Check all pattern terms, the type-constrained frees for their fixes and
       the "to" term simultaneously, then put the checked terms back into the
       patterns. *)
    fun check_terms ctxt ps to =
      let
        (* Split off exactly n leading elements and raise Match if the list is
           too short.  The recursion goes through safe_chop itself: the plain
           chop used here before silently returned a shorter prefix when the
           list ran out after the first element. *)
        fun safe_chop (0: int) xs = ([], xs)
          | safe_chop n (x :: xs) = safe_chop (n - 1) xs |>> cons x
          | safe_chop _ _ = raise Match
        (* A term pattern consumes its own checked term followed by one
           checked free variable for each of its fixes. *)
        fun reinsert_pat _ (Term (_, cs)) (t :: ts) =
              let val (cs', ts') = safe_chop (length cs) ts
              in (Term (t, map dest_Free cs'), ts') end
          | reinsert_pat _ (Term _) [] = raise Match
          | reinsert_pat ctxt (For ss) ts =
              let val fixes = map (fn s => (s, Variable.default_type ctxt s)) ss
              in (For fixes, ts) end
          | reinsert_pat _ At ts = (At, ts)
          | reinsert_pat _ In ts = (In, ts)
          | reinsert_pat _ Concl ts = (Concl, ts)
          | reinsert_pat _ Asm ts = (Asm, ts)
        fun free_constr (s,T) = Type.constraint T (Free (s, dummyT))
        fun mk_free_constrs (Term (t, cs)) = t :: map free_constr cs
          | mk_free_constrs _ = []
        val ts = maps mk_free_constrs ps @ the_list to
          |> Syntax.check_terms (hole_syntax ctxt)
        val ctxt' = fold Variable.declare_term ts ctxt
        (* whatever the patterns leave over starts with the "to" term, if any *)
        val (ps', (to', ts')) = fold_map (reinsert_pat ctxt') ps ts
          ||> (fn xs => case to of NONE => (NONE, xs) | SOME _ => (SOME (hd xs), tl xs))
        (* every checked term must have been consumed *)
        val _ = case ts' of (_ :: _) => raise Match | [] => ()
      in ((ps', to'), ctxt') end
    val (pats, ctxt') = prep_pats ctxt raw_pats
    val ths = Attrib.eval_thms ctxt' raw_ths
    val to = Option.map (Syntax.parse_term ctxt') raw_to
    val ((pats', to'), ctxt'') = check_terms ctxt' pats to
  in ((pats', ths, (to', ctxt)), ctxt'') end
(* optional clause: the keyword "to" followed by a term *)
val to_parser = Scan.option (Args.$$$ "to" |-- Parse.term)
(* Context parser for the complete method arguments: pattern, optional "to"
   clause and a non-empty list of theorems, prepared by prep_args. *)
val subst_parser =
  context_lift (raw_pattern -- to_parser -- Parse.xthms1) prep_args
in
(* Register the proof method "rewrite": the parsed arguments are turned into
   a rewrite tactic whose results are exported from the pattern context. *)
Method.setup @{binding rewrite}
  (subst_parser >> (fn (pats, rules, (to, pat_ctxt)) => fn orig_ctxt =>
    let
      val tac = rewrite_export_tac orig_ctxt ((pats, to), SOME pat_ctxt) rules
    in SIMPLE_METHOD' tac end))
  "single-step rewriting, allowing subterm selection via patterns."
end
end