src/HOL/Library/rewrite.ML
author noschinl
Wed, 15 Apr 2015 15:10:01 +0200
changeset 60079 ef4fe30e9ef1
parent 60055 aa3d2a6dd99e
child 60088 0a064330a885
permissions -rw-r--r--
rewrite: add ML interface

(*  Title:      HOL/Library/rewrite.ML
    Author:     Christoph Traut, Lars Noschinski, TU Muenchen

This is a rewrite method that supports subterm-selection based on patterns.

The patterns accepted by rewrite are of the following form:
  <atom>    ::= <term> | "concl" | "asm" | "for" "(" <names> ")"
  <pattern> ::= ("in" <atom> | "at" <atom>) [<pattern>]
  <args>    ::= [<pattern>] ["to" <term>] <thms>

This syntax was clearly inspired by Gonthier's and Tassi's language of
patterns but has diverged significantly during its development.

We also allow introduction of identifiers for bound variables,
which can then be used to match arbitrary subterms inside abstractions.
*)

(* Public interface of the rewrite method. *)
signature REWRITE =
sig
  (* Pattern language for subterm selection (see the header comment).
     At and In separate atoms; Term carries a pattern term; Concl and Asm
     select the conclusion or the assumptions; For names quantified
     variables. *)
  datatype ('a, 'b) pattern = At | In | Term of 'a | Concl | Asm | For of 'b list

  (* mk_hole i T: the schematic variable (with the internal hole name and
     index i) that represents a hole of type T in a pattern term. *)
  val mk_hole: int -> typ -> term

  (* rewrite ctxt (patterns, to) thms ct: rewrite the subterms of ct selected
     by patterns with thms; if to is given, only rules whose right-hand side
     unifies with it are used.  The sequence holds one theorem for every
     selected position at which the rewrite conversion does not fail.
     A Term pattern carries its term together with the names and types of
     the variables bound by the abstractions enclosing its hole; For carries
     names with optional types. *)
  val rewrite: Proof.context
    -> (term * (string * typ) list, string * typ option) pattern list * term option
    -> thm list
    -> cterm
    -> thm Seq.seq

  (* Tactic version of rewrite, acting on subgoal i; every selected position
     yields an alternative result. *)
  val rewrite_tac: Proof.context
    -> (term * (string * typ) list, string * typ option) pattern list * term option
    -> thm list
    -> int
    -> tactic
end

structure Rewrite : REWRITE =
struct

(* Must agree with the datatype specification in the REWRITE signature.
   Term carries a pattern term, For a list of variable names; the remaining
   constructors are the keywords of the pattern syntax. *)
datatype ('a, 'b) pattern = At | In | Term of 'a | Concl | Asm | For of 'b list

(* Replace a Term pattern by the pattern f computes from its payload; every
   other pattern is rebuilt unchanged (so its first type parameter may
   change along with f's result). *)
fun map_term_pattern f pat =
  (case pat of
    Term x => f x
  | For names => For names
  | At => At
  | In => In
  | Concl => Concl
  | Asm => Asm)


(* Raised by unify_with_rhs when the right-hand side of a rewrite rule does
   not unify with the requested "to" term; inst_thm turns it into NONE. *)
exception NO_TO_MATCH

(* Run every tactic of a lazy sequence of tactics on the same goal state
   and concatenate the resulting sequences of new states. *)
fun SEQ_CONCAT (tacq : tactic Seq.seq) (st : thm) : thm Seq.seq =
  Seq.maps (fn tac => tac st) tacq

(* We rewrite subterms using rewrite conversions. These are conversions
   that also take a context and a list of identifiers for bound variables
   as parameters.  Each list entry pairs an identifier with the term that
   CConv.abs_cconv supplied for the corresponding bound variable (see
   abs_rewr_cconv). *)
type rewrite_conv = Proof.context -> (string * term) list -> conv

(* To apply such a rewrite conversion to a subterm of our goal, we use
   subterm positions, which are just functions that map a rewrite conversion,
   working on the top level, to a new rewrite conversion, working on
   a specific subterm.  Positions for nested subterms are built by
   composing elementary positions with "o".

   During substitution, we are traversing the goal to find subterms that
   we can rewrite. For each of these subterms, a subterm position is
   created and later used in creating a conversion that we use to try and
   rewrite this subterm. *)
type subterm_position = rewrite_conv -> rewrite_conv

(* A focusterm represents a subterm. It is a triple (tyenv, t, p), consisting
  of the type environment accumulated by the type matches performed while
  focusing, the subterm t itself, and its subterm position p. *)
type focusterm = Type.tyenv * term * subterm_position

(* Internal names.  dummyN is the base name for free variables standing for
   anonymous bound variables (ft_abs applies Name.internal to it once more);
   holeN is the name of the schematic variables that represent holes (see
   mk_hole and is_hole). *)
val dummyN = Name.internal "__dummy"
val holeN = Name.internal "_hole"

(* Turn a theorem into rewrite rules using the simplifier's mksimps and
   renumber the schematic variables of each rule starting from index 0. *)
fun prep_meta_eq ctxt thm =
  map Drule.zero_var_indexes (Simplifier.mksimps ctxt thm)


(* rewrite conversions *)

(* Subterm position for the body of an abstraction.  If ident is SOME name,
   the conversion applied to the body sees name bound to the term that
   CConv.abs_cconv supplies for the bound variable. *)
fun abs_rewr_cconv ident : subterm_position =
  fn rewr => fn ctxt => fn idents =>
    let
      fun with_ident ct =
        (case ident of
          SOME name => (name, Thm.term_of ct) :: idents
        | NONE => idents)
    in CConv.abs_cconv (fn (ct, ctxt') => rewr ctxt' (with_ident ct)) ctxt end

(* Elementary subterm positions: the function part and the argument of an
   application, the conclusion of an implication (used by ft_imp), and a
   premise (used by ft_assm). *)
val fun_rewr_cconv : subterm_position =
  fn rewr => fn ctxt => fn idents => CConv.fun_cconv (rewr ctxt idents)
val arg_rewr_cconv : subterm_position =
  fn rewr => fn ctxt => fn idents => CConv.arg_cconv (rewr ctxt idents)
val imp_rewr_cconv : subterm_position =
  fn rewr => fn ctxt => fn idents => CConv.concl_cconv 1 (rewr ctxt idents)
val with_prems_rewr_cconv : subterm_position =
  fn rewr => fn ctxt => fn idents => CConv.with_prems_cconv ~1 (rewr ctxt idents)


(* focus terms *)

(* Focus on the body of the function-typed focus term u by introducing a
   free variable for its bound variable.
   - s: optional identifier for that variable; when given it is used as the
     variable's name and recorded for the body by abs_rewr_cconv.
   - T: the variable's type; unless it is dummyT it is matched against the
     argument type of u, extending tyenv.
   If u is not an abstraction it is eta-expanded, both in the focus term and
   (via the eta_expand theorem) in the subterm position.
   Raises TERM if u has no function type, and TYPE if T does not match the
   argument type of u. *)
fun ft_abs ctxt (s,T) (tyenv, u, pos) =
  case try (fastype_of #> dest_funT) u of
    NONE => raise TERM ("ft_abs: no function type", [u])
  | SOME (U, _) =>
      let
        val tyenv' =
          if T = dummyT then tyenv
          else Sign.typ_match (Proof_Context.theory_of ctxt) (T, U) tyenv
        (* Anonymous variables get an internal dummy name. *)
        val x = Free (the_default (Name.internal dummyN) s, Envir.norm_type tyenv' T)
        val eta_expand_cconv = CConv.rewr_cconv @{thm eta_expand}
        fun eta_expand rewr ctxt bounds = eta_expand_cconv then_conv rewr ctxt bounds
        val (u', pos') =
          case u of
            Abs (_,_,t') => (subst_bound (x, t'), pos o abs_rewr_cconv s)
          | _ => (u $ x, pos o eta_expand o abs_rewr_cconv s)
      in (tyenv', u', pos') end
      (* Sign.typ_match reports a mismatch with Type.TYPE_MATCH (Pattern.MATCH
         is what Pattern.match raises); catch both so that callers get the
         intended TYPE error. *)
      handle Type.TYPE_MATCH => raise TYPE ("ft_abs: types don't match", [T,U], [u])
        | Pattern.MATCH => raise TYPE ("ft_abs: types don't match", [T,U], [u])

(* Focus on the function part of an application.  An abstraction whose body
   applies something to its bound variable is entered first (via ft_abs);
   any other term raises TERM. *)
fun ft_fun ctxt (ft as (tyenv, t, pos)) =
  (case t of
    l $ _ => (tyenv, l, pos o fun_rewr_cconv)
  | Abs (_, T, _ $ Bound 0) => ft_fun ctxt (ft_abs ctxt (NONE, T) ft)
  | _ => raise TERM ("ft_fun", [t]))

local

(* Shared implementation of ft_arg and ft_imp: focus on the argument of an
   application, extending the subterm position with step.  An abstraction
   whose body applies something to its bound variable is entered first (via
   ft_abs); any other term raises TERM. *)
fun ft_arg_gen step ctxt (ft as (tyenv, t, pos)) =
  (case t of
    _ $ r => (tyenv, r, pos o step)
  | Abs (_, T, _ $ Bound 0) => ft_arg_gen step ctxt (ft_abs ctxt (NONE, T) ft)
  | _ => raise TERM ("ft_arg", [t]))

in

(* Argument of an application / conclusion of an implication. *)
val ft_arg = ft_arg_gen arg_rewr_cconv
val ft_imp = ft_arg_gen imp_rewr_cconv

end

(* Move to B in !!x_1 ... x_n. B, introducing a free variable for each
   parameter.  Only explicit abstractions are entered, i.e. we do not
   eta-expand.  For an eta-contracted quantifier, and for any other term
   (including an unapplied Pure.all constant, which ft_arg cannot descend
   into), the focus term is returned unchanged. *)
fun ft_params ctxt (ft as (_, t, _) : focusterm) =
  case t of
    Const (@{const_name "Pure.all"}, _) $ Abs (_,T,_) =>
      (ft_params ctxt o ft_abs ctxt (NONE, T) o ft_arg ctxt) ft
  | _ => ft

(* Enter one meta-quantifier, naming its variable after ident (an optional
   name and an optional type).  Without a type, the quantifier's parameter
   type is used; a non-quantifier raises TERM. *)
fun ft_all ctxt (name, T_opt) (ft as (_, Const (@{const_name "Pure.all"}, allT) $ _, _) : focusterm) =
    let
      (* allT has the form (paramT => _) => _ *)
      val (predT, _) = dest_funT allT
      val (paramT, _) = dest_funT predT
      val ft' = ft_arg ctxt ft
    in ft_abs ctxt (name, the_default paramT T_opt) ft' end
  | ft_all _ _ (_, t, _) = raise TERM ("ft_all", [t])

(* For pattern: move below all leading meta-quantifiers of the focus term
   (!!x_1 ... x_n. B), introducing a free variable for each.  The
   identifiers in idents name the innermost (length idents) of these
   variables, in order, so the last identifier names x_n; the remaining
   outer quantifiers get anonymous variables.  Returns NONE if there are
   fewer quantifiers than identifiers. *)
fun ft_for ctxt idents (ft as (_, t, _) : focusterm) =
  let
    (* f walks the quantifier prefix outside-in, but consumes identifiers
       from the reversed list on the way back out, i.e. inside-out.  It
       returns the unused identifiers and the focus movement; in
       desc o ft_all ctxt x the current quantifier is entered first, then
       the inner ones. *)
    fun f rev_idents (Const (@{const_name "Pure.all"}, _) $ t) =
        let
         val (rev_idents', desc) = f rev_idents (case t of Abs (_,_,u) => u | _ => t)
        in
          case rev_idents' of
            [] => ([], desc o ft_all ctxt (NONE, NONE))
          | (x :: xs) => (xs , desc o ft_all ctxt x)
        end
      | f rev_idents _ = (rev_idents, I)
  in
    (* succeed only if every identifier was used *)
    case f (rev idents) t of
      ([], ft') => SOME (ft' ft)
    | _ => NONE
  end

(* Follow a chain of implications A_1 ==> ... ==> A_n ==> C down to its
   conclusion C; a non-implication is its own conclusion. *)
fun ft_concl ctxt (ft as (_, Const (@{const_name "Pure.imp"}, _) $ _ $ _, _) : focusterm) =
      ft_concl ctxt (ft_imp ctxt ft)
  | ft_concl _ (ft : focusterm) = ft

(* Focus on the premise A of an implication A ==> B; any other term raises
   TERM. *)
fun ft_assm _ (tyenv, t, pos) =
  (case t of
    Const (@{const_name "Pure.imp"}, _) $ prem $ _ => (tyenv, prem, pos o with_prems_rewr_cconv)
  | _ => raise TERM ("ft_assm", [t]))

(* Strip the object-logic judgment (e.g. Trueprop) if the focus term is one;
   otherwise stay where we are. *)
fun ft_judgment ctxt (ft as (_, t, _) : focusterm) =
  if not (Object_Logic.is_judgment ctxt t) then ft else ft_arg ctxt ft

(* Find all subterms that might be a valid point to apply a rule: the focus
   term itself followed by all of its subterms (function part before
   argument), as a lazy sequence.  Subterms whose head, after stripping
   applications and abstractions, is a schematic or a loose bound variable
   are skipped. *)
fun valid_match_points ctxt (ft : focusterm) =
  let
    fun rewritable (Var _) = false
      | rewritable (Bound _) = false
      | rewritable (f $ _) = rewritable f
      | rewritable (Abs (_, _, body)) = rewritable body
      | rewritable _ = true
    fun children (_, _ $ _, _) = [ft_fun ctxt, ft_arg ctxt]
      | children (_, Abs (_, T, _), _) = [ft_abs ctxt (NONE, T)]
      | children _ = []
    fun below node =
      Seq.maps (fn step => valid_match_points ctxt (step node)) (Seq.of_list (children node))
    val here = Seq.make (fn () => SOME (ft, below ft))
  in Seq.filter (fn (_, t, _) => rewritable t) here end

(* Holes in pattern terms are schematic variables with the reserved name
   holeN; the index distinguishes different holes. *)
fun mk_hole idx holeT = Var ((holeN, idx), holeT)

fun is_hole t =
  (case t of
    Var ((name, _), _) => name = holeN
  | _ => false)

(* The rewrite_HOLE constant is the surface syntax of a hole; hole_syntax
   replaces it by a hole variable. *)
fun is_hole_const t =
  (case t of
    Const (@{const_name rewrite_HOLE}, _) => true
  | _ => false)

(* Context transformer for reading pattern terms: switch to pattern mode and
   install a check phase that replaces each rewrite_HOLE constant by a fresh
   hole variable (numbered from 1, left to right across all checked terms)
   applied to the bound variables in scope.  This is a modified variant of
   Term.replace_hole. *)
val hole_syntax =
  let
    fun expand binders (Const (@{const_name rewrite_HOLE}, T)) idx =
          let val hole = mk_hole idx (binders ---> T)
          in (list_comb (hole, map_range Bound (length binders)), idx + 1) end
      | expand binders (Abs (x, T, body)) idx =
          expand (T :: binders) body idx |>> (fn body' => Abs (x, T, body'))
      | expand binders (f $ a) idx =
          let
            val (f', idx') = expand binders f idx
            val (a', idx'') = expand binders a idx'
          in (f' $ a', idx'') end
      | expand _ atom idx = (atom, idx)
    fun expand_all ts = fst (fold_map (expand []) ts 1)
  in
    Context.proof_map (Syntax_Phases.term_check 101 "hole_expansion" (K expand_all))
    #> Proof_Context.set_mode Proof_Context.mode_pattern
  end

(* Find a subterm of the focusterm matching the pattern: apply the patterns
   of pattern_list from the head of the list onwards, producing the lazy
   sequence of selected focus terms. *)
fun find_matches ctxt pattern_list =
  let
    (* Term pattern (t, off): match t against the focus term and, on success,
       move by off.  If t ends in schematic-variable arguments and the match
       fails, enter an abstraction (eta-expanding if necessary) whose
       argument type matches the type of such an argument, and retry. *)
    fun move_term ctxt (t, off) (ft : focusterm) =
      let
        val thy = Proof_Context.theory_of ctxt

        (* types of the trailing Var arguments of t *)
        val eta_expands =
          let val (_, ts) = strip_comb t
          in map fastype_of (snd (take_suffix is_Var ts)) end

        (* only the type environment of the match is kept *)
        fun do_match (tyenv, u, pos) =
          case try (Pattern.match thy (t,u)) (tyenv, Vartab.empty) of
            NONE => NONE
          | SOME (tyenv', _) => SOME (off (tyenv', u, pos))

        (* match T against the argument type of u (NONE if u has no
           function type) *)
        fun match_argT T u =
          let val (U, _) = dest_funT (fastype_of u)
          in try (Sign.typ_match thy (T,U)) end
          handle TYPE _ => K NONE

        fun desc [] ft = do_match ft
          | desc (T :: Ts) (ft as (tyenv , u, pos)) =
            case do_match ft of
              NONE =>
                (case match_argT T u tyenv of
                  NONE => NONE
                | SOME tyenv' => desc Ts (ft_abs ctxt (NONE, T) (tyenv', u, pos)))
            | SOME ft => SOME ft
      in desc eta_expands ft end

    (* All assumptions of A_1 ==> ... ==> A_n ==> C: A_1 first, then (via
       ft_imp) the assumptions of the remaining implication. *)
    fun move_assms ctxt (ft: focusterm) =
      let
        fun f () = case try (ft_assm ctxt) ft of
            NONE => NONE
          | SOME ft' => SOME (ft', move_assms ctxt (ft_imp ctxt ft))
      in Seq.make f end

    (* At strips the judgment, In enumerates all subterms, Asm and Concl move
       below the parameters first, For names parameters (dropping focus terms
       where that fails), Term keeps only the focus terms that match. *)
    fun apply_pat At = Seq.map (ft_judgment ctxt)
      | apply_pat In = Seq.maps (valid_match_points ctxt)
      | apply_pat Asm = Seq.maps (move_assms ctxt o ft_params ctxt)
      | apply_pat Concl = Seq.map (ft_concl ctxt o ft_params ctxt)
      | apply_pat (For idents) = Seq.map_filter ((ft_for ctxt (map (apfst SOME) idents)))
      | apply_pat (Term x) = Seq.map_filter ( (move_term ctxt x))

    fun apply_pats ft = ft
      |> Seq.single
      |> fold apply_pat pattern_list
  in
    apply_pats
  end

(* Instantiate all schematic (type) variables of thm according to env and
   normalize the result. *)
fun instantiate_normalize_env ctxt env thm =
  let
    val prop = Thm.prop_of thm
    val normT = Envir.norm_type (Envir.type_env env)
    fun cert_pair cert (x, y) = (cert ctxt x, cert ctxt y)
    val insts =
      map (fn (xi, T) => (Var (xi, normT T), Envir.norm_term env (Var (xi, T))))
        (Term.add_vars prop [])
      |> map (cert_pair Thm.cterm_of)
    val tyinsts =
      map (fn v => (TVar v, normT (TVar v))) (Term.add_tvars prop [])
      |> map (cert_pair Thm.ctyp_of)
  in Drule.instantiate_normalize (tyinsts, insts) thm end

(* Unify the right-hand side of the meta equation thm with to, extending env.
   Raises NO_TO_MATCH if they do not unify. *)
fun unify_with_rhs context to env thm =
  let
    val rhs = snd (Logic.dest_equals (Thm.concl_of thm))
  in
    Pattern.unify context (Logic.mk_term to, Logic.mk_term rhs) env
      handle Pattern.Unif => raise NO_TO_MATCH
  end

(* Without a "to" term the rule is used as it is; otherwise it is
   instantiated so that its right-hand side becomes an instance of to. *)
fun inst_thm_to (ctxt : Proof.context) (to_opt, env) thm =
  (case to_opt of
    NONE => thm
  | SOME to => instantiate_normalize_env ctxt (unify_with_rhs (Context.Proof ctxt) to env thm) thm)

(* Prepare rewrite rule thm for a position with bound-variable identifiers
   idents and type environment tyenv.  If a "to" term is given, identifiers
   occurring in it are replaced by the terms standing for the bound
   variables, and thm is instantiated so that its right-hand side unifies
   with it.  Returns NONE if that unification fails. *)
fun inst_thm ctxt idents (to, tyenv) thm =
  let
    (* Replace any identifiers with their corresponding bound variables. *)
    fun subst ((n1, s) :: ss) (t as Free (n2, _)) = if n1 = n2 then s else subst ss t
      | subst _ t = t
    val replace_idents = Term.map_aterms (subst idents)

    (* Shift the rule's schematic variables above those occurring in the
       type environment and in the "to" term, so that they cannot clash. *)
    val maxidx =
      Term.maxidx_typs (map (snd o snd) (Vartab.dest tyenv)) 0
      |> fold Term.maxidx_term (the_list to)
    val thm' = Thm.incr_indexes (maxidx + 1) thm

    (* Unification invents fresh (type) variables numbered above the
       environment's maxidx, so that bound must cover every variable it can
       see, including those of "to" and of the shifted rule. *)
    val env =
      Envir.Envir
        {maxidx = Int.max (maxidx, Thm.maxidx_of thm'), tenv = Vartab.empty, tyenv = tyenv}
  in SOME (inst_thm_to ctxt (Option.map replace_idents to, env) thm') end
  handle NO_TO_MATCH => NONE

local

(* Core of rewrite and rewrite_tac: the lazy sequence of conversions, one for
   each subterm of ct selected by pattern.  Each conversion, applied to ct,
   rewrites the selected subterm with thms (restricted to rules whose
   right-hand side unifies with to, if given); the conversions are not
   applied here. *)
fun rewrite_raw ctxt (pattern, to) thms ct =
  let
    (* Turn every Term (t, fixes) pattern into Term (t, off), where off moves
       the focus from a match of t to the first hole in t. *)
    fun interpret_term_patterns ctxt =
      let

        (* Locate the first hole in t, trying the function part before the
           argument.  A hole is a hole variable, possibly applied to
           arguments.  Every abstraction on the way to it consumes one entry
           of fixes to name its bound variable, the innermost abstraction
           taking the first entry.  Returns the unused fixes and the focus
           movement, or NONE if t contains no hole. *)
        fun descend_hole fixes (Abs (_, _, t)) =
            (case descend_hole fixes t of
              NONE => NONE
            | SOME (fix :: fixes', pos) => SOME (fixes', pos o ft_abs ctxt (apfst SOME fix))
            | SOME ([], _) => raise Match (* XXX -- check phases modified binding *))
          | descend_hole fixes (t as l $ r) =
            let val (f, _) = strip_comb t
            in
              if is_hole f
              then SOME (fixes, I)
              else
                (case descend_hole fixes l of
                  SOME (fixes', pos) => SOME (fixes', pos o ft_fun ctxt)
                | NONE =>
                  (case descend_hole fixes r of
                    SOME (fixes', pos) => SOME (fixes', pos o ft_arg ctxt)
                  | NONE => NONE))
            end
          | descend_hole fixes t =
            if is_hole t then SOME (fixes, I) else NONE

        (* fixes are listed outermost first, hence the reversal; a pattern
           without a hole keeps the focus at the match (I). *)
        fun f (t, fixes) = Term (t, (descend_hole (rev fixes) #> the_default ([], I) #> snd) t)

      in map (map_term_pattern f) end

    val pattern' = interpret_term_patterns ctxt pattern
    val matches = find_matches ctxt pattern' (Vartab.empty, Thm.term_of ct, I)

    val thms' = maps (prep_meta_eq ctxt) thms

    (* Rewrite with every rule that inst_thm can prepare for this position;
       rules rejected by inst_thm (NONE) are dropped. *)
    fun rewrite_conv insty ctxt bounds =
      CConv.rewrs_cconv (map_filter (inst_thm ctxt bounds insty) thms')

    (* Remove duplicate premises if distinct_subgoals_tac yields a result;
       otherwise keep the theorem as it is. *)
    fun distinct_prems th =
      case Seq.pull (distinct_subgoals_tac th) of
        NONE => th
      | SOME (th', _) => th'

    (* The conversion for one match: apply the rules at the matched position,
       starting with no bound-variable identifiers. *)
    fun conv ((tyenv, _, position) : focusterm) =
      distinct_prems o position (rewrite_conv (to, tyenv)) ctxt []

  in Seq.map (fn ft => conv ft) matches end

in

(* The results of rewriting ct at the positions selected by the pattern;
   positions where the conversion fails are skipped. *)
fun rewrite ctxt pat thms ct =
  Seq.map_filter (fn cv => try cv ct) (rewrite_raw ctxt pat thms ct)

(* Tactic for subgoal i with one alternative per position selected by the
   pattern.  When a pattern context is given, each rewrite result is
   exported from it into ctxt before it is used. *)
fun rewrite_export_tac ctxt (pat, pat_ctxt) thms =
  let
    val export =
      (case pat_ctxt of
        SOME inner => singleton (Proof_Context.export inner ctxt)
      | NONE => I)
    fun goal_tac (ct, i) =
      let val convs = rewrite_raw ctxt pat thms ct
      in SEQ_CONCAT (Seq.map (fn cv => CCONVERSION (export o cv) i) convs) end
  in CSUBGOAL goal_tac end

(* rewrite_export_tac without a separate pattern context, so nothing needs
   to be exported. *)
fun rewrite_tac ctxt pat thms = rewrite_export_tac ctxt (pat, NONE) thms

end

(* Method setup: parse the arguments of the rewrite method and register it. *)
val _ =
  Theory.setup
  let
    (* A fixes entry for name s, without type and without syntax. *)
    fun mk_fix s = (Binding.name s, NONE, NoSyn)

    (* Parse a sequence of ("at" | "in") <atom> pairs.  The flattened list is
       reversed, so that find_matches (which processes the list from its
       head) handles the last atom written first; append_default then puts
       Concl and In in front when no pattern was given or when the last atom
       written is a term. *)
    val raw_pattern : (string, binding * string option * mixfix) pattern list parser =
      let
        val sep = (Args.$$$ "at" >> K At) || (Args.$$$ "in" >> K In)
        val atom =  (Args.$$$ "asm" >> K Asm) ||
          (Args.$$$ "concl" >> K Concl) ||
          (Args.$$$ "for" |-- Args.parens (Scan.optional Parse.fixes []) >> For) ||
          (Parse.term >> Term)
        val sep_atom = sep -- atom >> (fn (s,a) => [s,a])

        fun append_default [] = [Concl, In]
          | append_default (ps as Term _ :: _) = Concl :: In :: ps
          | append_default ps = ps

      in Scan.repeat sep_atom >> (flat #> rev #> append_default) end

    (* Lift a token parser to one that also threads the generic context:
       after scanning, f post-processes the result in the proof context, and
       the context it returns is stored back. *)
    fun context_lift (scan : 'a parser) f = fn (context : Context.generic, toks) =>
      let
        val (r, toks') = scan toks
        val (r', context') = Context.map_proof_result (fn ctxt => f ctxt r) context
      in (r', (context', toks' : Token.T list)) end

    (* Add the fixes of a "for" pattern to ctxt, reading their optional types. *)
    fun read_fixes fixes ctxt =
      let fun read_typ (b, rawT, mx) = (b, Option.map (Syntax.read_typ ctxt) rawT, mx)
      in Proof_Context.add_fixes (map read_typ fixes) ctxt end

    (* Parse the pattern terms and add the "for" fixes.  The state threaded
       through prep is (n, ctxt): a counter for type inference parameters and
       the context with the fixes added so far. *)
    fun prep_pats ctxt (ps : (string, binding * string option * mixfix) pattern list) =
      let
        (* Look for the first rewrite_HOLE constant in t (function part
           before argument).  For every abstraction on the way to it, a
           variable named after the bound variable is fixed, and the
           abstraction is constrained to type U --> dummyT for a fresh type
           inference parameter U.  Returns the extended context, the counter,
           the fixed names with their U (outermost first) and the constrained
           term, or NONE if t contains no hole. *)
        fun add_constrs ctxt n (Abs (x, T, t)) =
            let
              val (x', ctxt') = yield_singleton Proof_Context.add_fixes (mk_fix x) ctxt
            in
              (case add_constrs ctxt' (n+1) t of
                NONE => NONE
              | SOME ((ctxt'', n', xs), t') =>
                  let
                    val U = Type_Infer.mk_param n []
                    val u = Type.constraint (U --> dummyT) (Abs (x, T, t'))
                  in SOME ((ctxt'', n', (x', U) :: xs), u) end)
            end
          | add_constrs ctxt n (l $ r) =
            (case add_constrs ctxt n l of
              SOME (c, l') => SOME (c, l' $ r)
            | NONE =>
              (case add_constrs ctxt n r of
                SOME (c, r') => SOME (c, l $ r')
              | NONE => NONE))
          | add_constrs ctxt n t =
            if is_hole_const t then SOME ((ctxt, n, []), t) else NONE

        fun prep (Term s) (n, ctxt) =
            let
              val t = Syntax.parse_term ctxt s
              val ((ctxt', n', bs), t') =
                the_default ((ctxt, n, []), t) (add_constrs ctxt (n+1) t)
            in (Term (t', bs), (n', ctxt')) end
          | prep (For ss) (n, ctxt) =
            let val (ns, ctxt') = read_fixes ss ctxt
            in (For ns, (n, ctxt')) end
          | prep At (n,ctxt) = (At, (n, ctxt))
          | prep In (n,ctxt) = (In, (n, ctxt))
          | prep Concl (n,ctxt) = (Concl, (n, ctxt))
          | prep Asm (n,ctxt) = (Asm, (n, ctxt))

        val (xs, (_, ctxt')) = fold_map prep ps (0, ctxt)

      in (xs, ctxt') end

    (* Process the parsed method arguments: prepare the patterns, evaluate
       the rewrite rules and check all terms. *)
    fun prep_args ctxt (((raw_pats, raw_to), raw_ths)) =
      let

        (* Type-check all pattern terms together with the "to" term, under
           hole_syntax (hole expansion, pattern mode).  For each Term
           pattern, its fixed variables are checked alongside it as
           constrained free variables; reinsert_pat then redistributes the
           checked terms.  For patterns get the context's default types for
           their names. *)
        fun check_terms ctxt ps to =
          let
            (* chop that raises Match if the list is too short *)
            fun safe_chop (0: int) xs = ([], xs)
              | safe_chop n (x :: xs) = chop (n - 1) xs |>> cons x
              | safe_chop _ _ = raise Match

            fun reinsert_pat _ (Term (_, cs)) (t :: ts) =
                let val (cs', ts') = safe_chop (length cs) ts
                in (Term (t, map dest_Free cs'), ts') end
              | reinsert_pat _ (Term _) [] = raise Match
              | reinsert_pat ctxt (For ss) ts =
                let val fixes = map (fn s => (s, Variable.default_type ctxt s)) ss
                in (For fixes, ts) end
              | reinsert_pat _ At ts = (At, ts)
              | reinsert_pat _ In ts = (In, ts)
              | reinsert_pat _ Concl ts = (Concl, ts)
              | reinsert_pat _ Asm ts = (Asm, ts)

            fun free_constr (s,T) = Type.constraint T (Free (s, dummyT))
            fun mk_free_constrs (Term (t, cs)) = t :: map free_constr cs
              | mk_free_constrs _ = []

            val ts = maps mk_free_constrs ps @ the_list to
              |> Syntax.check_terms (hole_syntax ctxt)
            val ctxt' = fold Variable.declare_term ts ctxt
            val (ps', (to', ts')) = fold_map (reinsert_pat ctxt') ps ts
              ||> (fn xs => case to of NONE => (NONE, xs) | SOME _ => (SOME (hd xs), tl xs))
            (* every checked term must have been consumed *)
            val _ = case ts' of (_ :: _) => raise Match | [] => ()
          in ((ps', to'), ctxt') end

        val (pats, ctxt') = prep_pats ctxt raw_pats

        val ths = Attrib.eval_thms ctxt' raw_ths
        (* "to" is only parsed here; check_terms checks it with the patterns *)
        val to = Option.map (Syntax.parse_term ctxt') raw_to

        val ((pats', to'), ctxt'') = check_terms ctxt' pats to

      (* NOTE(review): the pattern context passed on for exporting results is
         ctxt, the context before the pattern fixes were added, while ctxt''
         becomes the parser's context -- confirm this is intended. *)
      in ((pats', ths, (to', ctxt)), ctxt'') end

    val to_parser = Scan.option ((Args.$$$ "to") |-- Parse.term)

    val subst_parser =
      let val scan = raw_pattern -- to_parser -- Parse.xthms1
      in context_lift scan prep_args end
  in
    (* Results are exported from the pattern context to the method's
       context orig_ctxt. *)
    Method.setup @{binding rewrite} (subst_parser >>
      (fn (pattern, inthms, (to, pat_ctxt)) => fn orig_ctxt =>
        SIMPLE_METHOD' (rewrite_export_tac orig_ctxt ((pattern, to), SOME pat_ctxt) inthms)))
      "single-step rewriting, allowing subterm selection via patterns."
  end
end