src/HOL/Library/rewrite.ML
author wenzelm
Wed Apr 08 19:39:08 2015 +0200 (2015-04-08)
changeset 59970 e9f73d87d904
parent 59739 4ed50ebf5d36
child 59975 da10875adf8e
permissions -rw-r--r--
proper context for Object_Logic operations;
     1 (* Author: Christoph Traut, Lars Noschinski
     2 
     3   This is a rewrite method that supports subterm selection based on patterns.
     4 
     5   The patterns accepted by rewrite are of the following form:
     6     <atom>    ::= <term> | "concl" | "asm" | "for" "(" <names> ")"
     7     <pattern> ::= (in <atom> | at <atom>) [<pattern>]
     8     <args>    ::= [<pattern>] ["to" <term>] <thms>
     9 
    10   This syntax was clearly inspired by Gonthier's and Tassi's language of
    11   patterns but has diverged significantly during its development.
    12 
    13   We also allow introduction of identifiers for bound variables,
    14   which can then be used to match arbitrary subterms inside abstractions.
    15 *)
    16 
    17 signature REWRITE1 = sig
    18   val setup : theory -> theory
    19 end
    20 
    21 structure Rewrite : REWRITE1 =
    22 struct
    23 
    24 datatype ('a, 'b) pattern = At | In | Term of 'a | Concl | Asm | For of 'b list
    25 
    26 fun map_term_pattern f (Term x) = f x
    27   | map_term_pattern _ (For ss) = (For ss)
    28   | map_term_pattern _ At = At
    29   | map_term_pattern _ In = In
    30   | map_term_pattern _ Concl = Concl
    31   | map_term_pattern _ Asm = Asm
    32 
    33 
    34 exception NO_TO_MATCH
    35 
    36 fun SEQ_CONCAT (tacq : tactic Seq.seq) : tactic = fn st => Seq.maps (fn tac => tac st) tacq
    37 
    38 (* We rewrite subterms using rewrite conversions. These are conversions
    39    that also take a context and a list of identifiers for bound variables
    40    as parameters. *)
    41 type rewrite_conv = Proof.context -> (string * term) list -> conv
    42 
    43 (* To apply such a rewrite conversion to a subterm of our goal, we use
    44    subterm positions, which are just functions that map a rewrite conversion,
    45    working on the top level, to a new rewrite conversion, working on
    46    a specific subterm.
    47 
    48    During substitution, we are traversing the goal to find subterms that
    49    we can rewrite. For each of these subterms, a subterm position is
    50    created and later used in creating a conversion that we use to try and
    51    rewrite this subterm. *)
    52 type subterm_position = rewrite_conv -> rewrite_conv
    53 
    54 (* A focusterm represents a subterm. It is a tuple (t, p), consisting
    55   of the subterm t itself and its subterm position p. *)
    56 type focusterm = Type.tyenv * term * subterm_position
    57 
    58 val dummyN = Name.internal "__dummy"
    59 val holeN = Name.internal "_hole"
    60 
    61 fun prep_meta_eq ctxt =
    62   Simplifier.mksimps ctxt #> map Drule.zero_var_indexes
    63 
    64 
    65 (* rewrite conversions *)
    66 
    67 fun abs_rewr_cconv ident : subterm_position =
    68   let
    69     fun add_ident NONE _ l = l
    70       | add_ident (SOME name) ct l = (name, Thm.term_of ct) :: l
    71     fun inner rewr ctxt idents = CConv.abs_cconv (fn (ct, ctxt) => rewr ctxt (add_ident ident ct idents)) ctxt
    72   in inner end
    73 val fun_rewr_cconv : subterm_position = fn rewr => CConv.fun_cconv oo rewr
    74 val arg_rewr_cconv : subterm_position = fn rewr => CConv.arg_cconv oo rewr
    75 
    76 
    77 (* focus terms *)
    78 
    79 fun ft_abs ctxt (s,T) (tyenv, u, pos) =
    80   case try (fastype_of #> dest_funT) u of
    81     NONE => raise TERM ("ft_abs: no function type", [u])
    82   | SOME (U, _) =>
    83   let
    84     val tyenv' = if T = dummyT then tyenv else Sign.typ_match (Proof_Context.theory_of ctxt) (T, U) tyenv
    85     val x = Free (the_default (Name.internal dummyN) s, Envir.norm_type tyenv' T)
    86     val eta_expand_cconv = CConv.rewr_cconv @{thm eta_expand}
    87     fun eta_expand rewr ctxt bounds = eta_expand_cconv then_conv rewr ctxt bounds
    88     val (u', pos') =
    89       case u of
    90         Abs (_,_,t') => (subst_bound (x, t'), pos o abs_rewr_cconv s)
    91       | _ => (u $ x, pos o eta_expand o abs_rewr_cconv s)
    92   in (tyenv', u', pos') end
    93   handle Pattern.MATCH => raise TYPE ("ft_abs: types don't match", [T,U], [u])
    94 
    95 fun ft_fun _ (tyenv, l $ _, pos) = (tyenv, l, pos o fun_rewr_cconv)
    96   | ft_fun ctxt (ft as (_, Abs (_, T, _ $ Bound 0), _)) = (ft_fun ctxt o ft_abs ctxt (NONE, T)) ft
    97   | ft_fun _ (_, t, _) = raise TERM ("ft_fun", [t])
    98 
    99 fun ft_arg _ (tyenv, _ $ r, pos) = (tyenv, r, pos o arg_rewr_cconv)
   100   | ft_arg ctxt (ft as (_, Abs (_, T, _ $ Bound 0), _)) = (ft_arg ctxt o ft_abs ctxt (NONE, T)) ft
   101   | ft_arg _ (_, t, _) = raise TERM ("ft_arg", [t])
   102 
   103 (* Move to B in !!x_1 ... x_n. B. Do not eta-expand *)
   104 fun ft_params ctxt (ft as (_, t, _) : focusterm) =
   105   case t of
   106     Const (@{const_name "Pure.all"}, _) $ Abs (_,T,_) =>
   107       (ft_params ctxt o ft_abs ctxt (NONE, T) o ft_arg ctxt) ft
   108   | Const (@{const_name "Pure.all"}, _) =>
   109       (ft_params ctxt o ft_arg ctxt) ft
   110   | _ => ft
   111 
   112 fun ft_all ctxt ident (ft as (_, Const (@{const_name "Pure.all"}, T) $ _, _) : focusterm) =
   113     let
   114       val def_U = T |> dest_funT |> fst |> dest_funT |> fst
   115       val ident' = apsnd (the_default (def_U)) ident
   116     in (ft_abs ctxt ident' o ft_arg ctxt) ft end
   117   | ft_all _ _ (_, t, _) = raise TERM ("ft_all", [t])
   118 
   119 fun ft_for ctxt idents (ft as (_, t, _) : focusterm) =
   120   let
   121     fun f rev_idents (Const (@{const_name "Pure.all"}, _) $ t) =
   122         let
   123          val (rev_idents', desc) = f rev_idents (case t of Abs (_,_,u) => u | _ => t)
   124         in
   125           case rev_idents' of
   126             [] => ([], desc o ft_all ctxt (NONE, NONE))
   127           | (x :: xs) => (xs , desc o ft_all ctxt x)
   128         end
   129       | f rev_idents _ = (rev_idents, I)
   130   in case f (rev idents) t of
   131       ([], ft') => SOME (ft' ft)
   132     | _ => NONE
   133   end
   134 
   135 fun ft_concl ctxt (ft as (_, t, _) : focusterm) =
   136   case t of
   137     (Const (@{const_name "Pure.imp"}, _) $ _) $ _ => (ft_concl ctxt o ft_arg ctxt) ft
   138   | _ => ft
   139 
   140 fun ft_assm ctxt (ft as (_, t, _) : focusterm) =
   141   case t of
   142     (Const (@{const_name "Pure.imp"}, _) $ _) $ _ => (ft_concl ctxt o ft_arg ctxt o ft_fun ctxt) ft
   143   | _ => raise TERM ("ft_assm", [t])
   144 
   145 fun ft_judgment ctxt (ft as (_, t, _) : focusterm) =
   146   if Object_Logic.is_judgment ctxt t
   147   then ft_arg ctxt ft
   148   else ft
   149 
   150 
   151 (* Return a lazy sequenze of all subterms of the focusterm for which
   152    the condition holds. *)
   153 fun find_subterms ctxt condition (ft as (_, t, _) : focusterm) =
   154   let
   155     val recurse = find_subterms ctxt condition
   156     val recursive_matches = case t of
   157         _ $ _ => Seq.append (ft |> ft_fun ctxt |> recurse) (ft |> ft_arg ctxt |> recurse)
   158       | Abs (_,T,_) => ft |> ft_abs ctxt (NONE, T) |> recurse
   159       | _ => Seq.empty
   160   in
   161     (* If the condition is met, then the current focusterm is part of the
   162        sequence of results. Otherwise, only the results of the recursive
   163        application are. *)
   164     if condition ft
   165     then Seq.cons ft recursive_matches
   166     else recursive_matches
   167   end
   168 
   169 (* Find all subterms that might be a valid point to apply a rule. *)
   170 fun valid_match_points ctxt =
   171   let
   172     fun is_valid (l $ _) = is_valid l
   173       | is_valid (Abs (_, _, a)) = is_valid a
   174       | is_valid (Var _) = false
   175       | is_valid (Bound _) = false
   176       | is_valid _ = true
   177   in
   178     find_subterms ctxt (#2 #> is_valid )
   179   end
   180 
   181 fun is_hole (Var ((name, _), _)) = (name = holeN)
   182   | is_hole _ = false
   183 
   184 fun is_hole_const (Const (@{const_name rewrite_HOLE}, _)) = true
   185   | is_hole_const _ = false
   186 
   187 val hole_syntax =
   188   let
   189     (* Modified variant of Term.replace_hole *)
   190     fun replace_hole Ts (Const (@{const_name rewrite_HOLE}, T)) i =
   191           (list_comb (Var ((holeN, i), Ts ---> T), map_range Bound (length Ts)), i + 1)
   192       | replace_hole Ts (Abs (x, T, t)) i =
   193           let val (t', i') = replace_hole (T :: Ts) t i
   194           in (Abs (x, T, t'), i') end
   195       | replace_hole Ts (t $ u) i =
   196           let
   197             val (t', i') = replace_hole Ts t i
   198             val (u', i'') = replace_hole Ts u i'
   199           in (t' $ u', i'') end
   200       | replace_hole _ a i = (a, i)
   201     fun prep_holes ts = #1 (fold_map (replace_hole []) ts 1)
   202   in
   203     Context.proof_map (Syntax_Phases.term_check 101 "hole_expansion" (K prep_holes))
   204     #> Proof_Context.set_mode Proof_Context.mode_pattern
   205   end
   206 
   207 (* Find a subterm of the focusterm matching the pattern. *)
   208 fun find_matches ctxt pattern_list =
   209   let
   210     fun move_term ctxt (t, off) (ft : focusterm) =
   211       let
   212         val thy = Proof_Context.theory_of ctxt
   213 
   214         val eta_expands =
   215           let val (_, ts) = strip_comb t
   216           in map fastype_of (snd (take_suffix is_Var ts)) end
   217 
   218         fun do_match (tyenv, u, pos) =
   219           case try (Pattern.match thy (t,u)) (tyenv, Vartab.empty) of
   220             NONE => NONE
   221           | SOME (tyenv', _) => SOME (off (tyenv', u, pos))
   222 
   223         fun match_argT T u =
   224           let val (U, _) = dest_funT (fastype_of u)
   225           in try (Sign.typ_match thy (T,U)) end
   226           handle TYPE _ => K NONE
   227 
   228         fun desc [] ft = do_match ft
   229           | desc (T :: Ts) (ft as (tyenv , u, pos)) =
   230             case do_match ft of
   231               NONE =>
   232                 (case match_argT T u tyenv of
   233                   NONE => NONE
   234                 | SOME tyenv' => desc Ts (ft_abs ctxt (NONE, T) (tyenv', u, pos)))
   235             | SOME ft => SOME ft
   236       in desc eta_expands ft end
   237 
   238     fun seq_unfold f ft =
   239       case f ft of
   240         NONE => Seq.empty
   241       | SOME ft' => Seq.cons ft' (seq_unfold f ft')
   242 
   243     fun apply_pat At = Seq.map (ft_judgment ctxt)
   244       | apply_pat In = Seq.maps (valid_match_points ctxt)
   245       | apply_pat Asm = Seq.maps (seq_unfold (try (ft_assm ctxt)) o ft_params ctxt)
   246       | apply_pat Concl = Seq.map (ft_concl ctxt o ft_params ctxt)
   247       | apply_pat (For idents) = Seq.map_filter ((ft_for ctxt (map (apfst SOME) idents)))
   248       | apply_pat (Term x) = Seq.map_filter ( (move_term ctxt x))
   249 
   250     fun apply_pats ft = ft
   251       |> Seq.single
   252       |> fold apply_pat pattern_list
   253 
   254   in
   255     apply_pats
   256   end
   257 
   258 fun instantiate_normalize_env ctxt env thm =
   259   let
   260     fun certs f = map (apply2 (f ctxt))
   261     val prop = Thm.prop_of thm
   262     val norm_type = Envir.norm_type o Envir.type_env
   263     val insts = Term.add_vars prop []
   264       |> map (fn x as (s,T) => (Var (s, norm_type env T), Envir.norm_term env (Var x)))
   265       |> certs Thm.cterm_of
   266     val tyinsts = Term.add_tvars prop []
   267       |> map (fn x => (TVar x, norm_type env (TVar x)))
   268       |> certs Thm.ctyp_of
   269   in Drule.instantiate_normalize (tyinsts, insts) thm end
   270 
   271 fun unify_with_rhs context to env thm =
   272   let
   273     val (_, rhs) = thm |> Thm.concl_of |> Logic.dest_equals
   274     val env' = Pattern.unify context (Logic.mk_term to, Logic.mk_term rhs) env
   275       handle Pattern.Unif => raise NO_TO_MATCH
   276   in env' end
   277 
   278 fun inst_thm_to _ (NONE, _) thm = thm
   279   | inst_thm_to (ctxt : Proof.context) (SOME to, env) thm =
   280       instantiate_normalize_env ctxt (unify_with_rhs (Context.Proof ctxt) to env thm) thm
   281 
   282 fun inst_thm ctxt idents (to, tyenv) thm =
   283   let
   284     (* Replace any identifiers with their corresponding bound variables. *)
   285     val maxidx = Term.maxidx_typs (map (snd o snd) (Vartab.dest tyenv)) 0
   286     val env = Envir.Envir {maxidx = maxidx, tenv = Vartab.empty, tyenv = tyenv}
   287     val replace_idents =
   288       let
   289         fun subst ((n1, s)::ss) (t as Free (n2, _)) = if n1 = n2 then s else subst ss t
   290           | subst _ t = t
   291       in Term.map_aterms (subst idents) end
   292 
   293     val maxidx = Envir.maxidx_of env |> fold Term.maxidx_term (map_filter I [to])
   294     val thm' = Thm.incr_indexes (maxidx + 1) thm
   295   in SOME (inst_thm_to ctxt (Option.map replace_idents to, env) thm') end
   296   handle NO_TO_MATCH => NONE
   297 
   298 (* Rewrite in subgoal i. *)
   299 fun rewrite_goal_with_thm ctxt (pattern, (to, orig_ctxt)) rules = SUBGOAL (fn (t,i) =>
   300   let
   301     val matches = find_matches ctxt pattern (Vartab.empty, t, I)
   302 
   303     fun rewrite_conv insty ctxt bounds =
   304       CConv.rewrs_cconv (map_filter (inst_thm ctxt bounds insty) rules)
   305 
   306     val export = singleton (Proof_Context.export ctxt orig_ctxt)
   307 
   308     fun distinct_prems th =
   309       case Seq.pull (distinct_subgoals_tac th) of
   310         NONE => th
   311       | SOME (th', _) => th'
   312 
   313     fun tac (tyenv, _, position) = CCONVERSION
   314       (distinct_prems o export o position (rewrite_conv (to, tyenv)) ctxt []) i
   315   in
   316     SEQ_CONCAT (Seq.map tac matches)
   317   end)
   318 
   319 fun rewrite_tac ctxt pattern thms =
   320   let
   321     val thms' = maps (prep_meta_eq ctxt) thms
   322     val tac = rewrite_goal_with_thm ctxt pattern thms'
   323   in tac end
   324 
   325 val setup =
   326   let
   327 
   328     fun mk_fix s = (Binding.name s, NONE, NoSyn)
   329 
   330     val raw_pattern : (string, binding * string option * mixfix) pattern list parser =
   331       let
   332         val sep = (Args.$$$ "at" >> K At) || (Args.$$$ "in" >> K In)
   333         val atom =  (Args.$$$ "asm" >> K Asm) ||
   334           (Args.$$$ "concl" >> K Concl) ||
   335           (Args.$$$ "for" |-- Args.parens (Scan.optional Parse.fixes []) >> For) ||
   336           (Parse.term >> Term)
   337         val sep_atom = sep -- atom >> (fn (s,a) => [s,a])
   338 
   339         fun append_default [] = [Concl, In]
   340           | append_default (ps as Term _ :: _) = Concl :: In :: ps
   341           | append_default ps = ps
   342 
   343       in Scan.repeat sep_atom >> (flat #> rev #> append_default) end
   344 
   345     fun ctxt_lift (scan : 'a parser) f = fn (ctxt : Context.generic, toks) =>
   346       let
   347         val (r, toks') = scan toks
   348         val (r', ctxt') = Context.map_proof_result (fn ctxt => f ctxt r) ctxt
   349       in (r', (ctxt', toks' : Token.T list))end
   350 
   351     fun read_fixes fixes ctxt =
   352       let fun read_typ (b, rawT, mx) = (b, Option.map (Syntax.read_typ ctxt) rawT, mx)
   353       in Proof_Context.add_fixes (map read_typ fixes) ctxt end
   354 
   355     fun prep_pats ctxt (ps : (string, binding * string option * mixfix) pattern list) =
   356       let
   357 
   358         fun add_constrs ctxt n (Abs (x, T, t)) =
   359             let
   360               val (x', ctxt') = yield_singleton Proof_Context.add_fixes (mk_fix x) ctxt
   361             in
   362               (case add_constrs ctxt' (n+1) t of
   363                 NONE => NONE
   364               | SOME ((ctxt'', n', xs), t') =>
   365                   let
   366                     val U = Type_Infer.mk_param n []
   367                     val u = Type.constraint (U --> dummyT) (Abs (x, T, t'))
   368                   in SOME ((ctxt'', n', (x', U) :: xs), u) end)
   369             end
   370           | add_constrs ctxt n (l $ r) =
   371             (case add_constrs ctxt n l of
   372               SOME (c, l') => SOME (c, l' $ r)
   373             | NONE =>
   374               (case add_constrs ctxt n r of
   375                 SOME (c, r') => SOME (c, l $ r')
   376               | NONE => NONE))
   377           | add_constrs ctxt n t =
   378             if is_hole_const t then SOME ((ctxt, n, []), t) else NONE
   379 
   380         fun prep (Term s) (n, ctxt) =
   381             let
   382               val t = Syntax.parse_term ctxt s
   383               val ((ctxt', n', bs), t') =
   384                 the_default ((ctxt, n, []), t) (add_constrs ctxt (n+1) t)
   385             in (Term (t', bs), (n', ctxt')) end
   386           | prep (For ss) (n, ctxt) =
   387             let val (ns, ctxt') = read_fixes ss ctxt
   388             in (For ns, (n, ctxt')) end
   389           | prep At (n,ctxt) = (At, (n, ctxt))
   390           | prep In (n,ctxt) = (In, (n, ctxt))
   391           | prep Concl (n,ctxt) = (Concl, (n, ctxt))
   392           | prep Asm (n,ctxt) = (Asm, (n, ctxt))
   393 
   394         val (xs, (_, ctxt')) = fold_map prep ps (0, ctxt)
   395 
   396       in (xs, ctxt') end
   397 
   398     fun prep_args ctxt (((raw_pats, raw_to), raw_ths)) =
   399       let
   400 
   401         fun interpret_term_patterns ctxt =
   402           let
   403 
   404             fun descend_hole fixes (Abs (_, _, t)) =
   405                 (case descend_hole fixes t of
   406                   NONE => NONE
   407                 | SOME (fix :: fixes', pos) => SOME (fixes', pos o ft_abs ctxt (apfst SOME fix))
   408                 | SOME ([], _) => raise Match (* XXX -- check phases modified binding *))
   409               | descend_hole fixes (t as l $ r) =
   410                 let val (f, _) = strip_comb t
   411                 in
   412                   if is_hole f
   413                   then SOME (fixes, I)
   414                   else
   415                     (case descend_hole fixes l of
   416                       SOME (fixes', pos) => SOME (fixes', pos o ft_fun ctxt)
   417                     | NONE =>
   418                       (case descend_hole fixes r of
   419                         SOME (fixes', pos) => SOME (fixes', pos o ft_arg ctxt)
   420                       | NONE => NONE))
   421                 end
   422               | descend_hole fixes t =
   423                 if is_hole t then SOME (fixes, I) else NONE
   424 
   425             fun f (t, fixes) = Term (t, (descend_hole (rev fixes) #> the_default ([], I) #> snd) t)
   426 
   427           in map (map_term_pattern f) end
   428 
   429         fun check_terms ctxt ps to =
   430           let
   431             fun safe_chop (0: int) xs = ([], xs)
   432               | safe_chop n (x :: xs) = chop (n - 1) xs |>> cons x
   433               | safe_chop _ _ = raise Match
   434 
   435             fun reinsert_pat _ (Term (_, cs)) (t :: ts) =
   436                 let val (cs', ts') = safe_chop (length cs) ts
   437                 in (Term (t, map dest_Free cs'), ts') end
   438               | reinsert_pat _ (Term _) [] = raise Match
   439               | reinsert_pat ctxt (For ss) ts =
   440                 let val fixes = map (fn s => (s, Variable.default_type ctxt s)) ss
   441                 in (For fixes, ts) end
   442               | reinsert_pat _ At ts = (At, ts)
   443               | reinsert_pat _ In ts = (In, ts)
   444               | reinsert_pat _ Concl ts = (Concl, ts)
   445               | reinsert_pat _ Asm ts = (Asm, ts)
   446 
   447             fun free_constr (s,T) = Type.constraint T (Free (s, dummyT))
   448             fun mk_free_constrs (Term (t, cs)) = t :: map free_constr cs
   449               | mk_free_constrs _ = []
   450 
   451             val ts = maps mk_free_constrs ps @ map_filter I [to]
   452               |> Syntax.check_terms (hole_syntax ctxt)
   453             val ctxt' = fold Variable.declare_term ts ctxt
   454             val (ps', (to', ts')) = fold_map (reinsert_pat ctxt') ps ts
   455               ||> (fn xs => case to of NONE => (NONE, xs) | SOME _ => (SOME (hd xs), tl xs))
   456             val _ = case ts' of (_ :: _) => raise Match | [] => ()
   457           in ((ps', to'), ctxt') end
   458 
   459         val (pats, ctxt') = prep_pats ctxt raw_pats
   460 
   461         val ths = Attrib.eval_thms ctxt' raw_ths
   462         val to = Option.map (Syntax.parse_term ctxt') raw_to
   463 
   464         val ((pats', to'), ctxt'') = check_terms ctxt' pats to
   465         val pats'' = interpret_term_patterns ctxt'' pats'
   466 
   467       in ((pats'', ths, (to', ctxt)), ctxt'') end
   468 
   469     val to_parser = Scan.option ((Args.$$$ "to") |-- Parse.term)
   470 
   471     val subst_parser =
   472       let val scan = raw_pattern -- to_parser -- Parse.xthms1
   473       in ctxt_lift scan prep_args end
   474   in
   475     Method.setup @{binding rewrite} (subst_parser >>
   476       (fn (pattern, inthms, inst) => fn ctxt => SIMPLE_METHOD' (rewrite_tac ctxt (pattern, inst) inthms)))
   477       "single-step rewriting, allowing subterm selection via patterns."
   478   end
   479 end