src/HOL/Library/rewrite.ML
author noschinl
Tue Apr 14 08:42:16 2015 +0200 (2015-04-14)
changeset 60055 aa3d2a6dd99e
parent 60054 ef4878146485
child 60079 ef4fe30e9ef1
permissions -rw-r--r--
rewrite: tuned code, no semantic changes
     1 (*  Title:      HOL/Library/rewrite.ML
     2     Author:     Christoph Traut, Lars Noschinski, TU Muenchen
     3 
     4 This is a rewrite method that supports subterm-selection based on patterns.
     5 
     6 The patterns accepted by rewrite are of the following form:
     7   <atom>    ::= <term> | "concl" | "asm" | "for" "(" <names> ")"
  <pattern> ::= ("in" <atom> | "at" <atom>) [<pattern>]
  <args>    ::= [<pattern>] ["to" <term>] <thms>
    10 
    11 This syntax was clearly inspired by Gonthier's and Tassi's language of
    12 patterns but has diverged significantly during its development.
    13 
    14 We also allow introduction of identifiers for bound variables,
    15 which can then be used to match arbitrary subterms inside abstractions.
    16 *)
    17 
(* Signature of the Rewrite structure. It is empty: the structure is
   ascribed to it, so none of the ML functions defined inside are exported,
   and the functionality is only reachable through the "rewrite" proof
   method registered by the Theory.setup call at the end of the structure. *)
signature REWRITE =
sig
  (* FIXME proper ML interface!? *)
end
    22 
    23 structure Rewrite : REWRITE =
    24 struct
    25 
    26 datatype ('a, 'b) pattern = At | In | Term of 'a | Concl | Asm | For of 'b list
    27 
    28 fun map_term_pattern f (Term x) = f x
    29   | map_term_pattern _ (For ss) = (For ss)
    30   | map_term_pattern _ At = At
    31   | map_term_pattern _ In = In
    32   | map_term_pattern _ Concl = Concl
    33   | map_term_pattern _ Asm = Asm
    34 
    35 
    36 exception NO_TO_MATCH
    37 
    38 fun SEQ_CONCAT (tacq : tactic Seq.seq) : tactic = fn st => Seq.maps (fn tac => tac st) tacq
    39 
    40 (* We rewrite subterms using rewrite conversions. These are conversions
    41    that also take a context and a list of identifiers for bound variables
    42    as parameters. *)
    43 type rewrite_conv = Proof.context -> (string * term) list -> conv
    44 
    45 (* To apply such a rewrite conversion to a subterm of our goal, we use
    46    subterm positions, which are just functions that map a rewrite conversion,
    47    working on the top level, to a new rewrite conversion, working on
    48    a specific subterm.
    49 
    50    During substitution, we are traversing the goal to find subterms that
    51    we can rewrite. For each of these subterms, a subterm position is
    52    created and later used in creating a conversion that we use to try and
    53    rewrite this subterm. *)
    54 type subterm_position = rewrite_conv -> rewrite_conv
    55 
    56 (* A focusterm represents a subterm. It is a tuple (t, p), consisting
    57   of the subterm t itself and its subterm position p. *)
    58 type focusterm = Type.tyenv * term * subterm_position
    59 
    60 val dummyN = Name.internal "__dummy"
    61 val holeN = Name.internal "_hole"
    62 
    63 fun prep_meta_eq ctxt =
    64   Simplifier.mksimps ctxt #> map Drule.zero_var_indexes
    65 
    66 
    67 (* rewrite conversions *)
    68 
    69 fun abs_rewr_cconv ident : subterm_position =
    70   let
    71     fun add_ident NONE _ l = l
    72       | add_ident (SOME name) ct l = (name, Thm.term_of ct) :: l
    73     fun inner rewr ctxt idents =
    74       CConv.abs_cconv (fn (ct, ctxt) => rewr ctxt (add_ident ident ct idents)) ctxt
    75   in inner end
    76 
    77 val fun_rewr_cconv : subterm_position = fn rewr => CConv.fun_cconv oo rewr
    78 val arg_rewr_cconv : subterm_position = fn rewr => CConv.arg_cconv oo rewr
    79 val imp_rewr_cconv : subterm_position = fn rewr => CConv.concl_cconv 1 oo rewr
    80 val with_prems_rewr_cconv : subterm_position = fn rewr => CConv.with_prems_cconv ~1 oo rewr
    81 
    82 
    83 (* focus terms *)
    84 
    85 fun ft_abs ctxt (s,T) (tyenv, u, pos) =
    86   case try (fastype_of #> dest_funT) u of
    87     NONE => raise TERM ("ft_abs: no function type", [u])
    88   | SOME (U, _) =>
    89       let
    90         val tyenv' =
    91           if T = dummyT then tyenv
    92           else Sign.typ_match (Proof_Context.theory_of ctxt) (T, U) tyenv
    93         val x = Free (the_default (Name.internal dummyN) s, Envir.norm_type tyenv' T)
    94         val eta_expand_cconv = CConv.rewr_cconv @{thm eta_expand}
    95         fun eta_expand rewr ctxt bounds = eta_expand_cconv then_conv rewr ctxt bounds
    96         val (u', pos') =
    97           case u of
    98             Abs (_,_,t') => (subst_bound (x, t'), pos o abs_rewr_cconv s)
    99           | _ => (u $ x, pos o eta_expand o abs_rewr_cconv s)
   100       in (tyenv', u', pos') end
   101       handle Pattern.MATCH => raise TYPE ("ft_abs: types don't match", [T,U], [u])
   102 
   103 fun ft_fun _ (tyenv, l $ _, pos) = (tyenv, l, pos o fun_rewr_cconv)
   104   | ft_fun ctxt (ft as (_, Abs (_, T, _ $ Bound 0), _)) = (ft_fun ctxt o ft_abs ctxt (NONE, T)) ft
   105   | ft_fun _ (_, t, _) = raise TERM ("ft_fun", [t])
   106 
   107 local
   108 
   109 fun ft_arg_gen cconv _ (tyenv, _ $ r, pos) = (tyenv, r, pos o cconv)
   110   | ft_arg_gen cconv ctxt (ft as (_, Abs (_, T, _ $ Bound 0), _)) = (ft_arg_gen cconv ctxt o ft_abs ctxt (NONE, T)) ft
   111   | ft_arg_gen _ _ (_, t, _) = raise TERM ("ft_arg", [t])
   112 
   113 in
   114 
   115 val ft_arg = ft_arg_gen arg_rewr_cconv
   116 val ft_imp = ft_arg_gen imp_rewr_cconv
   117 
   118 end
   119 
   120 (* Move to B in !!x_1 ... x_n. B. Do not eta-expand *)
   121 fun ft_params ctxt (ft as (_, t, _) : focusterm) =
   122   case t of
   123     Const (@{const_name "Pure.all"}, _) $ Abs (_,T,_) =>
   124       (ft_params ctxt o ft_abs ctxt (NONE, T) o ft_arg ctxt) ft
   125   | Const (@{const_name "Pure.all"}, _) =>
   126       (ft_params ctxt o ft_arg ctxt) ft
   127   | _ => ft
   128 
   129 fun ft_all ctxt ident (ft as (_, Const (@{const_name "Pure.all"}, T) $ _, _) : focusterm) =
   130     let
   131       val def_U = T |> dest_funT |> fst |> dest_funT |> fst
   132       val ident' = apsnd (the_default (def_U)) ident
   133     in (ft_abs ctxt ident' o ft_arg ctxt) ft end
   134   | ft_all _ _ (_, t, _) = raise TERM ("ft_all", [t])
   135 
   136 fun ft_for ctxt idents (ft as (_, t, _) : focusterm) =
   137   let
   138     fun f rev_idents (Const (@{const_name "Pure.all"}, _) $ t) =
   139         let
   140          val (rev_idents', desc) = f rev_idents (case t of Abs (_,_,u) => u | _ => t)
   141         in
   142           case rev_idents' of
   143             [] => ([], desc o ft_all ctxt (NONE, NONE))
   144           | (x :: xs) => (xs , desc o ft_all ctxt x)
   145         end
   146       | f rev_idents _ = (rev_idents, I)
   147   in
   148     case f (rev idents) t of
   149       ([], ft') => SOME (ft' ft)
   150     | _ => NONE
   151   end
   152 
   153 fun ft_concl ctxt (ft as (_, t, _) : focusterm) =
   154   case t of
   155     (Const (@{const_name "Pure.imp"}, _) $ _) $ _ => (ft_concl ctxt o ft_imp ctxt) ft
   156   | _ => ft
   157 
   158 fun ft_assm _ (tyenv, (Const (@{const_name "Pure.imp"}, _) $ l) $ _, pos) =
   159       (tyenv, l, pos o with_prems_rewr_cconv)
   160   | ft_assm _ (_, t, _) = raise TERM ("ft_assm", [t])
   161 
   162 fun ft_judgment ctxt (ft as (_, t, _) : focusterm) =
   163   if Object_Logic.is_judgment ctxt t
   164   then ft_arg ctxt ft
   165   else ft
   166 
   167 (* Find all subterms that might be a valid point to apply a rule. *)
   168 fun valid_match_points ctxt (ft : focusterm) =
   169   let
   170     fun descend (_, _ $ _, _) = [ft_fun ctxt, ft_arg ctxt]
   171       | descend (_, Abs (_, T, _), _) = [ft_abs ctxt (NONE, T)]
   172       | descend _ = []
   173     fun subseq ft =
   174       descend ft |> Seq.of_list |> Seq.maps (fn f => ft |> f |> valid_match_points ctxt)
   175     fun is_valid (l $ _) = is_valid l
   176       | is_valid (Abs (_, _, a)) = is_valid a
   177       | is_valid (Var _) = false
   178       | is_valid (Bound _) = false
   179       | is_valid _ = true
   180   in
   181     Seq.make (fn () => SOME (ft, subseq ft))
   182     |> Seq.filter (#2 #> is_valid)
   183   end
   184 
   185 fun is_hole (Var ((name, _), _)) = (name = holeN)
   186   | is_hole _ = false
   187 
   188 fun is_hole_const (Const (@{const_name rewrite_HOLE}, _)) = true
   189   | is_hole_const _ = false
   190 
   191 val hole_syntax =
   192   let
   193     (* Modified variant of Term.replace_hole *)
   194     fun replace_hole Ts (Const (@{const_name rewrite_HOLE}, T)) i =
   195           (list_comb (Var ((holeN, i), Ts ---> T), map_range Bound (length Ts)), i + 1)
   196       | replace_hole Ts (Abs (x, T, t)) i =
   197           let val (t', i') = replace_hole (T :: Ts) t i
   198           in (Abs (x, T, t'), i') end
   199       | replace_hole Ts (t $ u) i =
   200           let
   201             val (t', i') = replace_hole Ts t i
   202             val (u', i'') = replace_hole Ts u i'
   203           in (t' $ u', i'') end
   204       | replace_hole _ a i = (a, i)
   205     fun prep_holes ts = #1 (fold_map (replace_hole []) ts 1)
   206   in
   207     Context.proof_map (Syntax_Phases.term_check 101 "hole_expansion" (K prep_holes))
   208     #> Proof_Context.set_mode Proof_Context.mode_pattern
   209   end
   210 
   211 (* Find a subterm of the focusterm matching the pattern. *)
   212 fun find_matches ctxt pattern_list =
   213   let
   214     fun move_term ctxt (t, off) (ft : focusterm) =
   215       let
   216         val thy = Proof_Context.theory_of ctxt
   217 
   218         val eta_expands =
   219           let val (_, ts) = strip_comb t
   220           in map fastype_of (snd (take_suffix is_Var ts)) end
   221 
   222         fun do_match (tyenv, u, pos) =
   223           case try (Pattern.match thy (t,u)) (tyenv, Vartab.empty) of
   224             NONE => NONE
   225           | SOME (tyenv', _) => SOME (off (tyenv', u, pos))
   226 
   227         fun match_argT T u =
   228           let val (U, _) = dest_funT (fastype_of u)
   229           in try (Sign.typ_match thy (T,U)) end
   230           handle TYPE _ => K NONE
   231 
   232         fun desc [] ft = do_match ft
   233           | desc (T :: Ts) (ft as (tyenv , u, pos)) =
   234             case do_match ft of
   235               NONE =>
   236                 (case match_argT T u tyenv of
   237                   NONE => NONE
   238                 | SOME tyenv' => desc Ts (ft_abs ctxt (NONE, T) (tyenv', u, pos)))
   239             | SOME ft => SOME ft
   240       in desc eta_expands ft end
   241 
   242     fun move_assms ctxt (ft: focusterm) =
   243       let
   244         fun f () = case try (ft_assm ctxt) ft of
   245             NONE => NONE
   246           | SOME ft' => SOME (ft', move_assms ctxt (ft_imp ctxt ft))
   247       in Seq.make f end
   248 
   249     fun apply_pat At = Seq.map (ft_judgment ctxt)
   250       | apply_pat In = Seq.maps (valid_match_points ctxt)
   251       | apply_pat Asm = Seq.maps (move_assms ctxt o ft_params ctxt)
   252       | apply_pat Concl = Seq.map (ft_concl ctxt o ft_params ctxt)
   253       | apply_pat (For idents) = Seq.map_filter ((ft_for ctxt (map (apfst SOME) idents)))
   254       | apply_pat (Term x) = Seq.map_filter ( (move_term ctxt x))
   255 
   256     fun apply_pats ft = ft
   257       |> Seq.single
   258       |> fold apply_pat pattern_list
   259   in
   260     apply_pats
   261   end
   262 
   263 fun instantiate_normalize_env ctxt env thm =
   264   let
   265     fun certs f = map (apply2 (f ctxt))
   266     val prop = Thm.prop_of thm
   267     val norm_type = Envir.norm_type o Envir.type_env
   268     val insts = Term.add_vars prop []
   269       |> map (fn x as (s,T) => (Var (s, norm_type env T), Envir.norm_term env (Var x)))
   270       |> certs Thm.cterm_of
   271     val tyinsts = Term.add_tvars prop []
   272       |> map (fn x => (TVar x, norm_type env (TVar x)))
   273       |> certs Thm.ctyp_of
   274   in Drule.instantiate_normalize (tyinsts, insts) thm end
   275 
   276 fun unify_with_rhs context to env thm =
   277   let
   278     val (_, rhs) = thm |> Thm.concl_of |> Logic.dest_equals
   279     val env' = Pattern.unify context (Logic.mk_term to, Logic.mk_term rhs) env
   280       handle Pattern.Unif => raise NO_TO_MATCH
   281   in env' end
   282 
   283 fun inst_thm_to _ (NONE, _) thm = thm
   284   | inst_thm_to (ctxt : Proof.context) (SOME to, env) thm =
   285       instantiate_normalize_env ctxt (unify_with_rhs (Context.Proof ctxt) to env thm) thm
   286 
   287 fun inst_thm ctxt idents (to, tyenv) thm =
   288   let
   289     (* Replace any identifiers with their corresponding bound variables. *)
   290     val maxidx = Term.maxidx_typs (map (snd o snd) (Vartab.dest tyenv)) 0
   291     val env = Envir.Envir {maxidx = maxidx, tenv = Vartab.empty, tyenv = tyenv}
   292     val replace_idents =
   293       let
   294         fun subst ((n1, s)::ss) (t as Free (n2, _)) = if n1 = n2 then s else subst ss t
   295           | subst _ t = t
   296       in Term.map_aterms (subst idents) end
   297 
   298     val maxidx = Envir.maxidx_of env |> fold Term.maxidx_term (the_list to)
   299     val thm' = Thm.incr_indexes (maxidx + 1) thm
   300   in SOME (inst_thm_to ctxt (Option.map replace_idents to, env) thm') end
   301   handle NO_TO_MATCH => NONE
   302 
   303 (* Rewrite in subgoal i. *)
   304 fun rewrite_goal_with_thm ctxt (pattern, (to, orig_ctxt)) rules = SUBGOAL (fn (t,i) =>
   305   let
   306     val matches = find_matches ctxt pattern (Vartab.empty, t, I)
   307 
   308     fun rewrite_conv insty ctxt bounds =
   309       CConv.rewrs_cconv (map_filter (inst_thm ctxt bounds insty) rules)
   310 
   311     val export = singleton (Proof_Context.export ctxt orig_ctxt)
   312 
   313     fun distinct_prems th =
   314       case Seq.pull (distinct_subgoals_tac th) of
   315         NONE => th
   316       | SOME (th', _) => th'
   317 
   318     fun tac (tyenv, _, position) = CCONVERSION
   319       (distinct_prems o export o position (rewrite_conv (to, tyenv)) ctxt []) i
   320   in
   321     SEQ_CONCAT (Seq.map tac matches)
   322   end)
   323 
   324 fun rewrite_tac ctxt pattern thms =
   325   let
   326     val thms' = maps (prep_meta_eq ctxt) thms
   327     val tac = rewrite_goal_with_thm ctxt pattern thms'
   328   in tac end
   329 
   330 val _ =
   331   Theory.setup
   332   let
   333     fun mk_fix s = (Binding.name s, NONE, NoSyn)
   334 
   335     val raw_pattern : (string, binding * string option * mixfix) pattern list parser =
   336       let
   337         val sep = (Args.$$$ "at" >> K At) || (Args.$$$ "in" >> K In)
   338         val atom =  (Args.$$$ "asm" >> K Asm) ||
   339           (Args.$$$ "concl" >> K Concl) ||
   340           (Args.$$$ "for" |-- Args.parens (Scan.optional Parse.fixes []) >> For) ||
   341           (Parse.term >> Term)
   342         val sep_atom = sep -- atom >> (fn (s,a) => [s,a])
   343 
   344         fun append_default [] = [Concl, In]
   345           | append_default (ps as Term _ :: _) = Concl :: In :: ps
   346           | append_default ps = ps
   347 
   348       in Scan.repeat sep_atom >> (flat #> rev #> append_default) end
   349 
   350     fun context_lift (scan : 'a parser) f = fn (context : Context.generic, toks) =>
   351       let
   352         val (r, toks') = scan toks
   353         val (r', context') = Context.map_proof_result (fn ctxt => f ctxt r) context
   354       in (r', (context', toks' : Token.T list)) end
   355 
   356     fun read_fixes fixes ctxt =
   357       let fun read_typ (b, rawT, mx) = (b, Option.map (Syntax.read_typ ctxt) rawT, mx)
   358       in Proof_Context.add_fixes (map read_typ fixes) ctxt end
   359 
   360     fun prep_pats ctxt (ps : (string, binding * string option * mixfix) pattern list) =
   361       let
   362         fun add_constrs ctxt n (Abs (x, T, t)) =
   363             let
   364               val (x', ctxt') = yield_singleton Proof_Context.add_fixes (mk_fix x) ctxt
   365             in
   366               (case add_constrs ctxt' (n+1) t of
   367                 NONE => NONE
   368               | SOME ((ctxt'', n', xs), t') =>
   369                   let
   370                     val U = Type_Infer.mk_param n []
   371                     val u = Type.constraint (U --> dummyT) (Abs (x, T, t'))
   372                   in SOME ((ctxt'', n', (x', U) :: xs), u) end)
   373             end
   374           | add_constrs ctxt n (l $ r) =
   375             (case add_constrs ctxt n l of
   376               SOME (c, l') => SOME (c, l' $ r)
   377             | NONE =>
   378               (case add_constrs ctxt n r of
   379                 SOME (c, r') => SOME (c, l $ r')
   380               | NONE => NONE))
   381           | add_constrs ctxt n t =
   382             if is_hole_const t then SOME ((ctxt, n, []), t) else NONE
   383 
   384         fun prep (Term s) (n, ctxt) =
   385             let
   386               val t = Syntax.parse_term ctxt s
   387               val ((ctxt', n', bs), t') =
   388                 the_default ((ctxt, n, []), t) (add_constrs ctxt (n+1) t)
   389             in (Term (t', bs), (n', ctxt')) end
   390           | prep (For ss) (n, ctxt) =
   391             let val (ns, ctxt') = read_fixes ss ctxt
   392             in (For ns, (n, ctxt')) end
   393           | prep At (n,ctxt) = (At, (n, ctxt))
   394           | prep In (n,ctxt) = (In, (n, ctxt))
   395           | prep Concl (n,ctxt) = (Concl, (n, ctxt))
   396           | prep Asm (n,ctxt) = (Asm, (n, ctxt))
   397 
   398         val (xs, (_, ctxt')) = fold_map prep ps (0, ctxt)
   399 
   400       in (xs, ctxt') end
   401 
   402     fun prep_args ctxt (((raw_pats, raw_to), raw_ths)) =
   403       let
   404 
   405         fun interpret_term_patterns ctxt =
   406           let
   407 
   408             fun descend_hole fixes (Abs (_, _, t)) =
   409                 (case descend_hole fixes t of
   410                   NONE => NONE
   411                 | SOME (fix :: fixes', pos) => SOME (fixes', pos o ft_abs ctxt (apfst SOME fix))
   412                 | SOME ([], _) => raise Match (* XXX -- check phases modified binding *))
   413               | descend_hole fixes (t as l $ r) =
   414                 let val (f, _) = strip_comb t
   415                 in
   416                   if is_hole f
   417                   then SOME (fixes, I)
   418                   else
   419                     (case descend_hole fixes l of
   420                       SOME (fixes', pos) => SOME (fixes', pos o ft_fun ctxt)
   421                     | NONE =>
   422                       (case descend_hole fixes r of
   423                         SOME (fixes', pos) => SOME (fixes', pos o ft_arg ctxt)
   424                       | NONE => NONE))
   425                 end
   426               | descend_hole fixes t =
   427                 if is_hole t then SOME (fixes, I) else NONE
   428 
   429             fun f (t, fixes) = Term (t, (descend_hole (rev fixes) #> the_default ([], I) #> snd) t)
   430 
   431           in map (map_term_pattern f) end
   432 
   433         fun check_terms ctxt ps to =
   434           let
   435             fun safe_chop (0: int) xs = ([], xs)
   436               | safe_chop n (x :: xs) = chop (n - 1) xs |>> cons x
   437               | safe_chop _ _ = raise Match
   438 
   439             fun reinsert_pat _ (Term (_, cs)) (t :: ts) =
   440                 let val (cs', ts') = safe_chop (length cs) ts
   441                 in (Term (t, map dest_Free cs'), ts') end
   442               | reinsert_pat _ (Term _) [] = raise Match
   443               | reinsert_pat ctxt (For ss) ts =
   444                 let val fixes = map (fn s => (s, Variable.default_type ctxt s)) ss
   445                 in (For fixes, ts) end
   446               | reinsert_pat _ At ts = (At, ts)
   447               | reinsert_pat _ In ts = (In, ts)
   448               | reinsert_pat _ Concl ts = (Concl, ts)
   449               | reinsert_pat _ Asm ts = (Asm, ts)
   450 
   451             fun free_constr (s,T) = Type.constraint T (Free (s, dummyT))
   452             fun mk_free_constrs (Term (t, cs)) = t :: map free_constr cs
   453               | mk_free_constrs _ = []
   454 
   455             val ts = maps mk_free_constrs ps @ the_list to
   456               |> Syntax.check_terms (hole_syntax ctxt)
   457             val ctxt' = fold Variable.declare_term ts ctxt
   458             val (ps', (to', ts')) = fold_map (reinsert_pat ctxt') ps ts
   459               ||> (fn xs => case to of NONE => (NONE, xs) | SOME _ => (SOME (hd xs), tl xs))
   460             val _ = case ts' of (_ :: _) => raise Match | [] => ()
   461           in ((ps', to'), ctxt') end
   462 
   463         val (pats, ctxt') = prep_pats ctxt raw_pats
   464 
   465         val ths = Attrib.eval_thms ctxt' raw_ths
   466         val to = Option.map (Syntax.parse_term ctxt') raw_to
   467 
   468         val ((pats', to'), ctxt'') = check_terms ctxt' pats to
   469         val pats'' = interpret_term_patterns ctxt'' pats'
   470 
   471       in ((pats'', ths, (to', ctxt)), ctxt'') end
   472 
   473     val to_parser = Scan.option ((Args.$$$ "to") |-- Parse.term)
   474 
   475     val subst_parser =
   476       let val scan = raw_pattern -- to_parser -- Parse.xthms1
   477       in context_lift scan prep_args end
   478   in
   479     Method.setup @{binding rewrite} (subst_parser >>
   480       (fn (pattern, inthms, inst) => fn ctxt =>
   481         SIMPLE_METHOD' (rewrite_tac ctxt (pattern, inst) inthms)))
   482       "single-step rewriting, allowing subterm selection via patterns."
   483   end
   484 end