src/HOL/Library/rewrite.ML
author noschinl
Mon Apr 13 14:52:40 2015 +0200 (2015-04-13)
changeset 60052 616a17640229
parent 60051 2a1cab4c9c9d
child 60053 0e9895ffab1d
permissions -rw-r--r--
rewrite: with asm pattern, try all premises for rewriting, not only the first
wenzelm@59975
     1
(*  Title:      HOL/Library/rewrite.ML
wenzelm@59975
     2
    Author:     Christoph Traut, Lars Noschinski, TU Muenchen
noschinl@59739
     3
wenzelm@59975
     4
This is a rewrite method that supports subterm-selection based on patterns.
noschinl@59739
     5
wenzelm@59975
     6
The patterns accepted by rewrite are of the following form:
wenzelm@59975
     7
  <atom>    ::= <term> | "concl" | "asm" | "for" "(" <names> ")"
wenzelm@59975
     8
  <pattern> ::= ("in" <atom> | "at" <atom>) [<pattern>]
wenzelm@59975
     9
  <args>    ::= [<pattern>] ["to" <term>] <thms>
noschinl@59739
    10
wenzelm@59975
    11
This syntax was clearly inspired by Gonthier's and Tassi's language of
wenzelm@59975
    12
patterns but has diverged significantly during its development.
noschinl@59739
    13
wenzelm@59975
    14
We also allow introduction of identifiers for bound variables,
wenzelm@59975
    15
which can then be used to match arbitrary subterms inside abstractions.
noschinl@59739
    16
*)
noschinl@59739
    17
wenzelm@59975
    18
signature REWRITE =
sig
  (* Intentionally empty for now: the rewrite method is registered through
     the Method.setup call at the end of the structure and is not exported
     as an ML function. *)
  (* FIXME proper ML interface!? *)
end
noschinl@59739
    22
wenzelm@59975
    23
structure Rewrite : REWRITE =
noschinl@59739
    24
struct
noschinl@59739
    25
noschinl@59739
    26
(* Abstract syntax of rewrite patterns. 'a is the representation of term
   patterns, 'b the representation of identifiers introduced by "for". *)
datatype ('a, 'b) pattern = At | In | Term of 'a | Concl | Asm | For of 'b list

(* Transform the payload of a Term pattern with f. f returns a whole
   pattern, so the term representation may change. Every other constructor
   is rebuilt unchanged. *)
fun map_term_pattern f pat =
  (case pat of
    Term x => f x
  | For ss => For ss
  | At => At
  | In => In
  | Concl => Concl
  | Asm => Asm)


(* Raised by unify_with_rhs when the "to" term does not unify with the
   right-hand side of a rule. *)
exception NO_TO_MATCH


(* Run every tactic of the sequence on the goal state and concatenate the
   resulting state sequences in order. *)
fun SEQ_CONCAT (tacq : tactic Seq.seq) : tactic =
  fn st => tacq |> Seq.maps (fn tac => tac st)
noschinl@59739
    39
noschinl@59739
    40
(* Subterms are rewritten with rewrite conversions: ordinary conversions
   that are additionally parameterized by a proof context and by the
   identifiers for bound variables collected so far, given as
   (name, term) pairs. *)
type rewrite_conv = Proof.context -> (string * term) list -> conv

(* A subterm position lifts a rewrite conversion that works on the top
   level of a term to one that works on one specific subterm.

   While traversing the goal in search of subterms to rewrite, a subterm
   position is built up for every candidate subterm. It is later used to
   turn the actual rewrite conversion into a conversion on the whole goal
   that tries to rewrite exactly this subterm. *)
type subterm_position = rewrite_conv -> rewrite_conv
noschinl@59739
    55
noschinl@59739
    56
(* A focusterm represents a subterm. It is a triple (tyenv, t, p),
   consisting of the type environment tyenv accumulated while moving the
   focus and matching patterns, the subterm t itself, and its subterm
   position p. *)
type focusterm = Type.tyenv * term * subterm_position

(* Internal names. dummyN is used as the base name of the Free substituted
   for an anonymous bound variable in ft_abs; holeN names the schematic
   variables that replace hole constants in patterns (see hole_syntax). *)
val dummyN = Name.internal "__dummy"
val holeN = Name.internal "_hole"

(* Turn a theorem into meta-equations suitable for rewriting (via
   Simplifier.mksimps) and normalize their schematic variable indices. *)
fun prep_meta_eq ctxt =
  Simplifier.mksimps ctxt #> map Drule.zero_var_indexes
noschinl@59739
    65
noschinl@59739
    66
noschinl@59739
    67
(* rewrite conversions *)

(* Subterm position for the body of an abstraction. When ident is
   SOME name, the cterm of the bound variable supplied by CConv.abs_cconv
   is recorded as (name, term) in front of the identifier list passed on to
   the inner rewrite conversion. *)
fun abs_rewr_cconv ident : subterm_position =
  let
    fun extend_idents ct idents =
      (case ident of
        NONE => idents
      | SOME name => (name, Thm.term_of ct) :: idents)
  in
    fn rewr => fn ctxt => fn idents =>
      CConv.abs_cconv (fn (ct, ctxt') => rewr ctxt' (extend_idents ct idents)) ctxt
  end

local

(* Lift a conversional to a subterm position: the given cconv transformer
   is applied to the conversion obtained from the inner rewrite conversion. *)
fun lift_cconv cconv : subterm_position =
  fn rewr => fn ctxt => fn idents => cconv (rewr ctxt idents)

in

val fun_rewr_cconv : subterm_position = lift_cconv CConv.fun_cconv
val arg_rewr_cconv : subterm_position = lift_cconv CConv.arg_cconv
val imp_rewr_cconv : subterm_position = lift_cconv (CConv.concl_cconv 1)

end
noschinl@59739
    80
noschinl@59739
    81
noschinl@59739
    82
(* focus terms *)

(* Move the focus into the body of the function-typed subterm u.
   If u is an abstraction, its bound variable is replaced by a Free x.
   Otherwise u is eta-expanded to u $ x, and the position first applies
   the eta_expand rule.
   The Free is named s if given (abs_rewr_cconv then records it as an
   identifier); otherwise it is named after dummyN.
   Unless T is dummyT, T is matched against the domain type U of u,
   which extends tyenv.
   Raises TERM if u does not have a function type, and TYPE if T does
   not match U. *)
fun ft_abs ctxt (s,T) (tyenv, u, pos) =
  (case try (fastype_of #> dest_funT) u of
    NONE => raise TERM ("ft_abs: no function type", [u])
  | SOME (U, _) =>
      let
        val tyenv' =
          if T = dummyT then tyenv
          else Sign.typ_match (Proof_Context.theory_of ctxt) (T, U) tyenv
        val x = Free (the_default (Name.internal dummyN) s, Envir.norm_type tyenv' T)
        val eta_expand_cconv = CConv.rewr_cconv @{thm eta_expand}
        fun eta_expand rewr ctxt bounds = eta_expand_cconv then_conv rewr ctxt bounds
        val (u', pos') =
          case u of
            Abs (_,_,t') => (subst_bound (x, t'), pos o abs_rewr_cconv s)
          | _ => (u $ x, pos o eta_expand o abs_rewr_cconv s)
      in (tyenv', u', pos') end
      (* Sign.typ_match reports a mismatch with Type.TYPE_MATCH, not with
         Pattern.MATCH; translate it into the intended descriptive error. *)
      handle Type.TYPE_MATCH => raise TYPE ("ft_abs: types don't match", [T,U], [u]))
noschinl@59739
   101
noschinl@59739
   102
(* Focus on the function part l of an application l $ r. An abstraction of
   the form Abs (_, T, _ $ Bound 0) is entered with ft_abs first. Raises TERM
   for any other term. *)
fun ft_fun ctxt (ft as (tyenv, t, pos)) =
  (case t of
    l $ _ => (tyenv, l, pos o fun_rewr_cconv)
  | Abs (_, T, _ $ Bound 0) => ft_fun ctxt (ft_abs ctxt (NONE, T) ft)
  | _ => raise TERM ("ft_fun", [t]))
noschinl@59739
   105
noschinl@60050
   106
local

(* Shared implementation of ft_arg and ft_imp. Focus on the argument r of
   an application f $ r, extending the position with the subterm position
   cconv. An abstraction of the form Abs (_, T, _ $ Bound 0) is entered with
   ft_abs first. Raises TERM for any other term. *)
fun ft_arg_gen cconv ctxt (ft as (tyenv, t, pos)) =
  (case t of
    _ $ r => (tyenv, r, pos o cconv)
  | Abs (_, T, _ $ Bound 0) => ft_arg_gen cconv ctxt (ft_abs ctxt (NONE, T) ft)
  | _ => raise TERM ("ft_arg", [t]))

in

(* Focus on the argument of an application (position via arg_rewr_cconv). *)
val ft_arg = ft_arg_gen arg_rewr_cconv

(* Focus on the argument of an application, with the position built from
   imp_rewr_cconv (CConv.concl_cconv 1). It is used for B in A ==> B. *)
val ft_imp = ft_arg_gen imp_rewr_cconv

end
noschinl@59739
   118
noschinl@59739
   119
(* Move to B in !!x_1 ... x_n. B, entering each Pure.all abstraction.
   A Pure.all applied to a non-abstraction is not eta-expanded: the focus
   stays on that term. *)
fun ft_params ctxt (ft as (_, t, _) : focusterm) =
  (case t of
    Const (@{const_name "Pure.all"}, _) $ Abs (_, T, _) =>
      ft |> ft_arg ctxt |> ft_abs ctxt (NONE, T) |> ft_params ctxt
  (* NOTE(review): this pattern matches the bare constant, not an
     application, and ft_arg raises TERM on a bare Const. It may have been
     meant as Pure.all $ _; kept unchanged. *)
  | Const (@{const_name "Pure.all"}, _) =>
      ft |> ft_arg ctxt |> ft_params ctxt
  | _ => ft)
noschinl@59739
   127
noschinl@59739
   128
(* Enter one Pure.all binder. ident = (name option, type option). A
   missing type defaults to the domain of the binder's abstraction type,
   taken from the type of the Pure.all constant. Raises TERM if the focus is
   not a Pure.all application. *)
fun ft_all ctxt ident (ft as (_, t, _) : focusterm) =
  (case t of
    Const (@{const_name "Pure.all"}, T) $ _ =>
      let
        val (binderT, _) = dest_funT (fst (dest_funT T))
      in ft |> ft_arg ctxt |> ft_abs ctxt (apsnd (the_default binderT) ident) end
  | _ => raise TERM ("ft_all", [t]))
noschinl@59739
   134
noschinl@59739
   135
(* Enter the outermost Pure.all binders of the focused term, naming them
   with idents. The innermost binder is paired with the last identifier,
   the next one outwards with the one before it, and so on. Binders
   without an identifier are entered anonymously. Returns NONE if there
   are more identifiers than binders. *)
fun ft_for ctxt idents (ft as (_, t, _) : focusterm) =
  let
    (* Returns the identifiers not yet consumed, together with the
       composed focus-term transformation for this prefix of binders. *)
    fun bind_params pending (Const (@{const_name "Pure.all"}, _) $ body) =
          let
            val inner = (case body of Abs (_, _, u) => u | _ => body)
            val (pending', descend) = bind_params pending inner
          in
            (case pending' of
              [] => ([], descend o ft_all ctxt (NONE, NONE))
            | x :: rest => (rest, descend o ft_all ctxt x))
          end
      | bind_params pending _ = (pending, I)
  in
    (case bind_params (rev idents) t of
      ([], move) => SOME (move ft)
    | _ => NONE)
  end
noschinl@59739
   151
noschinl@59739
   152
(* Follow a chain A1 ==> ... ==> An ==> B to its final conclusion B.
   A term that is not an implication is left in focus unchanged. *)
fun ft_concl ctxt (ft as (_, t, _) : focusterm) =
  (case t of
    Const (@{const_name "Pure.imp"}, _) $ _ $ _ => ft |> ft_imp ctxt |> ft_concl ctxt
  | _ => ft)
noschinl@59739
   156
noschinl@59739
   157
(* Focus on the first premise A of A ==> B, and then on that premise's own
   conclusion (via ft_concl). Raises TERM if the focus is not an
   implication. *)
fun ft_assm ctxt (ft as (_, t, _) : focusterm) =
  (case t of
    Const (@{const_name "Pure.imp"}, _) $ _ $ _ =>
      ft |> ft_fun ctxt |> ft_arg ctxt |> ft_concl ctxt
  | _ => raise TERM ("ft_assm", [t]))
noschinl@59739
   161
noschinl@59739
   162
(* If Object_Logic.is_judgment holds for the focused term, move to its
   argument. Otherwise leave the focus unchanged. *)
fun ft_judgment ctxt (ft as (_, t, _) : focusterm) =
  if not (Object_Logic.is_judgment ctxt t) then ft
  else ft_arg ctxt ft
noschinl@59739
   166
noschinl@59739
   167
noschinl@59739
   168
(* Return a sequence of all subterms of the focusterm for which condition
   holds, in pre-order: the focusterm itself, then matches inside the
   function part of an application, then inside its argument, or inside
   the body of an abstraction.
   NOTE: the traversal is performed strictly when this function is called;
   only the result is packaged as a Seq.seq. *)
fun find_subterms ctxt condition (ft as (_, t, _) : focusterm) =
  let
    val go = find_subterms ctxt condition
    val below =
      (case t of
        _ $ _ => Seq.append (go (ft_fun ctxt ft)) (go (ft_arg ctxt ft))
      | Abs (_, T, _) => go (ft_abs ctxt (NONE, T) ft)
      | _ => Seq.empty)
  in
    if condition ft then Seq.cons ft below else below
  end
noschinl@59739
   186
noschinl@59739
   187
(* Find all subterms that might be a valid point to apply a rule: those
   whose head, looking through applications and abstraction bodies, is
   neither a schematic variable nor a bound variable. *)
fun valid_match_points ctxt =
  let
    fun head_ok (Abs (_, _, body)) = head_ok body
      | head_ok (f $ _) = head_ok f
      | head_ok (Var _) = false
      | head_ok (Bound _) = false
      | head_ok _ = true
  in
    find_subterms ctxt (fn (_, t, _) => head_ok t)
  end
noschinl@59739
   198
noschinl@59739
   199
(* A hole after hole_syntax expansion: a schematic variable named holeN. *)
fun is_hole t =
  (case t of
    Var ((name, _), _) => name = holeN
  | _ => false)

(* A hole before expansion: the rewrite_HOLE constant. *)
fun is_hole_const t =
  (case t of
    Const (@{const_name rewrite_HOLE}, _) => true
  | _ => false)
noschinl@59739
   204
noschinl@59739
   205
(* Context transformer used when checking patterns. It installs a term
   check (stage 101, "hole_expansion") that replaces every rewrite_HOLE
   constant by a schematic variable (holeN, i), applied to the bound
   variables in scope, with i counting up from 1. It then switches the
   context to pattern mode. *)
val hole_syntax =
  let
    (* Modified variant of Term.replace_hole. binders lists the types of
       the enclosing bound variables, innermost first. *)
    fun replace_hole binders t idx =
      (case t of
        Const (@{const_name rewrite_HOLE}, T) =>
          (list_comb (Var ((holeN, idx), binders ---> T), map_range Bound (length binders)),
            idx + 1)
      | Abs (x, T, body) =>
          let val (body', idx') = replace_hole (T :: binders) body idx
          in (Abs (x, T, body'), idx') end
      | f $ u =>
          let
            val (f', idx') = replace_hole binders f idx
            val (u', idx'') = replace_hole binders u idx'
          in (f' $ u', idx'') end
      | _ => (t, idx))

    fun prep_holes ts = fst (fold_map (replace_hole []) ts 1)
  in
    Context.proof_map (Syntax_Phases.term_check 101 "hole_expansion" (K prep_holes))
    #> Proof_Context.set_mode Proof_Context.mode_pattern
  end
noschinl@59739
   224
noschinl@59739
   225
(* Find the subterms of the focusterm selected by the pattern. The patterns
   in pattern_list are applied in list order, starting from a sequence that
   holds only the given focusterm. Each pattern maps the current sequence of
   focusterms to a new one. *)
fun find_matches ctxt pattern_list =
  let
    val thy = Proof_Context.theory_of ctxt

    (* Term pattern (t, off). Match t against the focused subterm and apply
       off to the match. If t has trailing schematic arguments, a failed
       match is retried on the eta-expanded subterm, one such argument at a
       time, provided the argument type matches the subterm's domain type. *)
    fun move_term (t, off) (ft : focusterm) =
      let
        val (_, args) = strip_comb t
        val trailing_var_types = map fastype_of (snd (take_suffix is_Var args))

        fun try_match (tyenv, u, pos) =
          Option.map (fn (tyenv', _) => off (tyenv', u, pos))
            (try (Pattern.match thy (t, u)) (tyenv, Vartab.empty))

        fun arg_type_match T u =
          let val (U, _) = dest_funT (fastype_of u)
          in try (Sign.typ_match thy (T, U)) end
          handle TYPE _ => K NONE

        fun descend [] ft = try_match ft
          | descend (T :: Ts) (ft as (tyenv, u, pos)) =
              (case try_match ft of
                SOME ft' => SOME ft'
              | NONE =>
                  (case arg_type_match T u tyenv of
                    NONE => NONE
                  | SOME tyenv' => descend Ts (ft_abs ctxt (NONE, T) (tyenv', u, pos))))
      in descend trailing_var_types ft end

    (* Lazily enumerate every premise of the focused implication chain,
       each one focused as by ft_assm. *)
    fun move_assms (ft : focusterm) =
      Seq.make (fn () =>
        (case try (ft_assm ctxt) ft of
          NONE => NONE
        | SOME ft' => SOME (ft', move_assms (ft_imp ctxt ft))))

    fun apply_pat pat =
      (case pat of
        At => Seq.map (ft_judgment ctxt)
      | In => Seq.maps (valid_match_points ctxt)
      | Asm => Seq.maps (move_assms o ft_params ctxt)
      | Concl => Seq.map (ft_concl ctxt o ft_params ctxt)
      | For idents => Seq.map_filter (ft_for ctxt (map (apfst SOME) idents))
      | Term x => Seq.map_filter (move_term x))
  in
    fn ft => fold apply_pat pattern_list (Seq.single ft)
  end
noschinl@59739
   276
noschinl@59739
   277
(* Instantiate every schematic term and type variable of thm according to
   env, and normalize the result with Drule.instantiate_normalize. *)
fun instantiate_normalize_env ctxt env thm =
  let
    val prop = Thm.prop_of thm
    val tyenv = Envir.type_env env
    fun cert_pairs cert = map (fn (a, b) => (cert ctxt a, cert ctxt b))
    val insts =
      Term.add_vars prop []
      |> map (fn (xi, T) => (Var (xi, Envir.norm_type tyenv T), Envir.norm_term env (Var (xi, T))))
      |> cert_pairs Thm.cterm_of
    val tyinsts =
      Term.add_tvars prop []
      |> map (fn v => (TVar v, Envir.norm_type tyenv (TVar v)))
      |> cert_pairs Thm.ctyp_of
  in
    Drule.instantiate_normalize (tyinsts, insts) thm
  end
noschinl@59739
   289
noschinl@59739
   290
(* Unify the term to with the right-hand side of the meta-equation thm,
   extending env. Raises NO_TO_MATCH if they do not unify. *)
fun unify_with_rhs context to env thm =
  let
    val rhs = thm |> Thm.concl_of |> Logic.dest_equals |> snd
  in
    Pattern.unify context (Logic.mk_term to, Logic.mk_term rhs) env
      handle Pattern.Unif => raise NO_TO_MATCH
  end

(* Without a "to" term, thm is returned unchanged. Otherwise thm is
   instantiated so that its right-hand side becomes the "to" term. *)
fun inst_thm_to (ctxt : Proof.context) (to_opt, env) thm =
  (case to_opt of
    NONE => thm
  | SOME to =>
      instantiate_normalize_env ctxt (unify_with_rhs (Context.Proof ctxt) to env thm) thm)
noschinl@59739
   300
noschinl@59739
   301
(* Prepare the rule thm for the current match.
   - Build an environment from the matching type environment tyenv.
   - Shift the indices of thm above all indices used by it and by to.
   - If a "to" term is given, replace its identifiers by their terms and
     instantiate thm so that its right-hand side is that term.
   Returns NONE if the "to" term cannot be unified with the rule. *)
fun inst_thm ctxt idents (to, tyenv) thm =
  let
    val tyenv_maxidx = Term.maxidx_typs (map (snd o snd) (Vartab.dest tyenv)) 0
    val env = Envir.Envir {maxidx = tyenv_maxidx, tenv = Vartab.empty, tyenv = tyenv}

    (* Replace any identifiers with their corresponding bound variables. *)
    fun lookup ((n1, s) :: rest) (t as Free (n2, _)) = if n1 = n2 then s else lookup rest t
      | lookup _ t = t
    val replace_idents = Term.map_aterms (lookup idents)

    val maxidx = fold Term.maxidx_term (the_list to) (Envir.maxidx_of env)
    val shifted = Thm.incr_indexes (maxidx + 1) thm
  in
    SOME (inst_thm_to ctxt (Option.map replace_idents to, env) shifted)
  end
  handle NO_TO_MATCH => NONE
noschinl@59739
   316
noschinl@59739
   317
(* Rewrite in subgoal i. For every subterm selected by the pattern, try the
   rules (instantiated by inst_thm) at that subterm's position. The result
   is exported back to orig_ctxt, and duplicate premises are removed.
   All resulting goal states are enumerated in match order. *)
fun rewrite_goal_with_thm ctxt (pattern, (to, orig_ctxt)) rules = SUBGOAL (fn (goal, i) =>
  let
    val matches = find_matches ctxt pattern (Vartab.empty, goal, I)

    fun rewrite_conv insty ctxt' bounds =
      CConv.rewrs_cconv (map_filter (inst_thm ctxt' bounds insty) rules)

    val export = singleton (Proof_Context.export ctxt orig_ctxt)

    fun distinct_prems th =
      Seq.pull (distinct_subgoals_tac th) |> Option.map fst |> the_default th

    fun tac (tyenv, _, position) =
      CCONVERSION (distinct_prems o export o position (rewrite_conv (to, tyenv)) ctxt []) i
  in
    matches |> Seq.map tac |> SEQ_CONCAT
  end)
noschinl@59739
   337
noschinl@59739
   338
(* Turn thms into meta-equations and rewrite the selected subgoal with them. *)
fun rewrite_tac ctxt pattern thms =
  rewrite_goal_with_thm ctxt pattern (maps (prep_meta_eq ctxt) thms)
noschinl@59739
   343
wenzelm@59975
   344
(* Parser and method setup for the rewrite method. *)
val _ =
  Theory.setup
  let
    fun mk_fix s = (Binding.name s, NONE, NoSyn)

    (* Parse a sequence of "at"/"in" atoms. The flattened list is reversed,
       so the rightmost atom is applied first by find_matches. An empty
       pattern, or one starting with a term after reversal, gets the
       default prefix "in concl". *)
    val raw_pattern : (string, binding * string option * mixfix) pattern list parser =
      let
        val sep = (Args.$$$ "at" >> K At) || (Args.$$$ "in" >> K In)
        val atom =  (Args.$$$ "asm" >> K Asm) ||
          (Args.$$$ "concl" >> K Concl) ||
          (Args.$$$ "for" |-- Args.parens (Scan.optional Parse.fixes []) >> For) ||
          (Parse.term >> Term)
        val sep_atom = sep -- atom >> (fn (s,a) => [s,a])

        fun append_default [] = [Concl, In]
          | append_default (ps as Term _ :: _) = Concl :: In :: ps
          | append_default ps = ps

      in Scan.repeat sep_atom >> (flat #> rev #> append_default) end

    (* Run a token parser and post-process its result with f in the proof
       context carried by the generic context. *)
    fun context_lift (scan : 'a parser) f = fn (context : Context.generic, toks) =>
      let
        val (r, toks') = scan toks
        val (r', context') = Context.map_proof_result (fn ctxt => f ctxt r) context
      in (r', (context', toks' : Token.T list)) end

    fun read_fixes fixes ctxt =
      let fun read_typ (b, rawT, mx) = (b, Option.map (Syntax.read_typ ctxt) rawT, mx)
      in Proof_Context.add_fixes (map read_typ fixes) ctxt end

    (* Parse term patterns and fix the names of abstractions that enclose a
       hole, so that they can later be referred to as identifiers. *)
    fun prep_pats ctxt (ps : (string, binding * string option * mixfix) pattern list) =
      let
        fun add_constrs ctxt n (Abs (x, T, t)) =
            let
              val (x', ctxt') = yield_singleton Proof_Context.add_fixes (mk_fix x) ctxt
            in
              (case add_constrs ctxt' (n+1) t of
                NONE => NONE
              | SOME ((ctxt'', n', xs), t') =>
                  let
                    val U = Type_Infer.mk_param n []
                    val u = Type.constraint (U --> dummyT) (Abs (x, T, t'))
                  in SOME ((ctxt'', n', (x', U) :: xs), u) end)
            end
          | add_constrs ctxt n (l $ r) =
            (case add_constrs ctxt n l of
              SOME (c, l') => SOME (c, l' $ r)
            | NONE =>
              (case add_constrs ctxt n r of
                SOME (c, r') => SOME (c, l $ r')
              | NONE => NONE))
          | add_constrs ctxt n t =
            if is_hole_const t then SOME ((ctxt, n, []), t) else NONE

        fun prep (Term s) (n, ctxt) =
            let
              val t = Syntax.parse_term ctxt s
              val ((ctxt', n', bs), t') =
                the_default ((ctxt, n, []), t) (add_constrs ctxt (n+1) t)
            in (Term (t', bs), (n', ctxt')) end
          | prep (For ss) (n, ctxt) =
            let val (ns, ctxt') = read_fixes ss ctxt
            in (For ns, (n, ctxt')) end
          | prep At (n,ctxt) = (At, (n, ctxt))
          | prep In (n,ctxt) = (In, (n, ctxt))
          | prep Concl (n,ctxt) = (Concl, (n, ctxt))
          | prep Asm (n,ctxt) = (Asm, (n, ctxt))

        val (xs, (_, ctxt')) = fold_map prep ps (0, ctxt)

      in (xs, ctxt') end

    fun prep_args ctxt (((raw_pats, raw_to), raw_ths)) =
      let

        (* Turn each checked term pattern into (term, off), where off moves
           a matched focusterm to the hole inside the pattern. *)
        fun interpret_term_patterns ctxt =
          let

            fun descend_hole fixes (Abs (_, _, t)) =
                (case descend_hole fixes t of
                  NONE => NONE
                | SOME (fix :: fixes', pos) => SOME (fixes', pos o ft_abs ctxt (apfst SOME fix))
                | SOME ([], _) => raise Match (* XXX -- check phases modified binding *))
              | descend_hole fixes (t as l $ r) =
                let val (f, _) = strip_comb t
                in
                  if is_hole f
                  then SOME (fixes, I)
                  else
                    (case descend_hole fixes l of
                      SOME (fixes', pos) => SOME (fixes', pos o ft_fun ctxt)
                    | NONE =>
                      (case descend_hole fixes r of
                        SOME (fixes', pos) => SOME (fixes', pos o ft_arg ctxt)
                      | NONE => NONE))
                end
              | descend_hole fixes t =
                if is_hole t then SOME (fixes, I) else NONE

            fun f (t, fixes) = Term (t, (descend_hole (rev fixes) #> the_default ([], I) #> snd) t)

          in map (map_term_pattern f) end

        (* Type-check all pattern terms, the constrained identifiers and the
           "to" term together, then distribute the checked terms back. *)
        fun check_terms ctxt ps to =
          let
            (* Split off exactly n elements. Unlike the library chop, which
               silently returns fewer elements for a short list, a short list
               raises Match at every step. *)
            fun safe_chop (0: int) xs = ([], xs)
              | safe_chop n (x :: xs) = safe_chop (n - 1) xs |>> cons x
              | safe_chop _ _ = raise Match

            fun reinsert_pat _ (Term (_, cs)) (t :: ts) =
                let val (cs', ts') = safe_chop (length cs) ts
                in (Term (t, map dest_Free cs'), ts') end
              | reinsert_pat _ (Term _) [] = raise Match
              | reinsert_pat ctxt (For ss) ts =
                let val fixes = map (fn s => (s, Variable.default_type ctxt s)) ss
                in (For fixes, ts) end
              | reinsert_pat _ At ts = (At, ts)
              | reinsert_pat _ In ts = (In, ts)
              | reinsert_pat _ Concl ts = (Concl, ts)
              | reinsert_pat _ Asm ts = (Asm, ts)

            fun free_constr (s,T) = Type.constraint T (Free (s, dummyT))
            fun mk_free_constrs (Term (t, cs)) = t :: map free_constr cs
              | mk_free_constrs _ = []

            val ts = maps mk_free_constrs ps @ the_list to
              |> Syntax.check_terms (hole_syntax ctxt)
            val ctxt' = fold Variable.declare_term ts ctxt
            val (ps', (to', ts')) = fold_map (reinsert_pat ctxt') ps ts
              ||> (fn xs => case to of NONE => (NONE, xs) | SOME _ => (SOME (hd xs), tl xs))
            (* every checked term must have been consumed *)
            val _ = case ts' of (_ :: _) => raise Match | [] => ()
          in ((ps', to'), ctxt') end

        val (pats, ctxt') = prep_pats ctxt raw_pats

        val ths = Attrib.eval_thms ctxt' raw_ths
        val to = Option.map (Syntax.parse_term ctxt') raw_to

        val ((pats', to'), ctxt'') = check_terms ctxt' pats to
        val pats'' = interpret_term_patterns ctxt'' pats'

      in ((pats'', ths, (to', ctxt)), ctxt'') end

    val to_parser = Scan.option ((Args.$$$ "to") |-- Parse.term)

    val subst_parser =
      let val scan = raw_pattern -- to_parser -- Parse.xthms1
      in context_lift scan prep_args end
  in
    Method.setup @{binding rewrite} (subst_parser >>
      (fn (pattern, inthms, inst) => fn ctxt =>
        SIMPLE_METHOD' (rewrite_tac ctxt (pattern, inst) inthms)))
      "single-step rewriting, allowing subterm selection via patterns."
  end
noschinl@59739
   498
end