src/HOL/Library/rewrite.ML
author wenzelm
Wed Mar 08 10:50:59 2017 +0100 (2017-03-08)
changeset 65151 a7394aa4d21c
parent 63285 e9c777bfd78c
child 69593 3dda49e08b9d
permissions -rw-r--r--
tuned proofs;
     1 (*  Title:      HOL/Library/rewrite.ML
     2     Author:     Christoph Traut, Lars Noschinski, TU Muenchen
     3 
     4 This is a rewrite method that supports subterm-selection based on patterns.
     5 
     6 The patterns accepted by rewrite are of the following form:
     7   <atom>    ::= <term> | "concl" | "asm" | "for" "(" <names> ")"
     8   <pattern> ::= (in <atom> | at <atom>) [<pattern>]
      9   <args>    ::= [<pattern>] ["to" <term>] <thms>
    10 
    11 This syntax was clearly inspired by Gonthier's and Tassi's language of
    12 patterns but has diverged significantly during its development.
    13 
    14 We also allow introduction of identifiers for bound variables,
    15 which can then be used to match arbitrary subterms inside abstractions.
    16 *)
    17 
    18 infix 1 then_pconv;
    19 infix 0 else_pconv;
    20 
    21 signature REWRITE =
    22 sig
    23   type patconv = Proof.context -> Type.tyenv * (string * term) list -> cconv
    24   val then_pconv: patconv * patconv -> patconv
    25   val else_pconv: patconv * patconv -> patconv
    26   val abs_pconv:  patconv -> string option * typ -> patconv (*XXX*)
    27   val fun_pconv: patconv -> patconv
    28   val arg_pconv: patconv -> patconv
    29   val imp_pconv: patconv -> patconv
    30   val params_pconv: patconv -> patconv
    31   val forall_pconv: patconv -> string option * typ option -> patconv
    32   val all_pconv: patconv
    33   val for_pconv: patconv -> (string option * typ option) list -> patconv
    34   val concl_pconv: patconv -> patconv
    35   val asm_pconv: patconv -> patconv
    36   val asms_pconv: patconv -> patconv
    37   val judgment_pconv: patconv -> patconv
    38   val in_pconv: patconv -> patconv
    39   val match_pconv: patconv -> term * (string option * typ) list -> patconv
    40   val rewrs_pconv: term option -> thm list -> patconv
    41 
    42   datatype ('a, 'b) pattern = At | In | Term of 'a | Concl | Asm | For of 'b list
    43 
    44   val mk_hole: int -> typ -> term
    45 
    46   val rewrite_conv: Proof.context
    47     -> (term * (string * typ) list, string * typ option) pattern list * term option
    48     -> thm list
    49     -> conv
    50 end
    51 
    52 structure Rewrite : REWRITE =
    53 struct
    54 
    55 datatype ('a, 'b) pattern = At | In | Term of 'a | Concl | Asm | For of 'b list
    56 
    57 exception NO_TO_MATCH
    58 
    59 val holeN = Name.internal "_hole"
    60 
    61 fun prep_meta_eq ctxt = Simplifier.mksimps ctxt #> map Drule.zero_var_indexes
    62 
    63 
    64 (* holes *)
    65 
    66 fun mk_hole i T = Var ((holeN, i), T)
    67 
    68 fun is_hole (Var ((name, _), _)) = (name = holeN)
    69   | is_hole _ = false
    70 
    71 fun is_hole_const (Const (@{const_name rewrite_HOLE}, _)) = true
    72   | is_hole_const _ = false
    73 
    74 val hole_syntax =
    75   let
    76     (* Modified variant of Term.replace_hole *)
    77     fun replace_hole Ts (Const (@{const_name rewrite_HOLE}, T)) i =
    78           (list_comb (mk_hole i (Ts ---> T), map_range Bound (length Ts)), i + 1)
    79       | replace_hole Ts (Abs (x, T, t)) i =
    80           let val (t', i') = replace_hole (T :: Ts) t i
    81           in (Abs (x, T, t'), i') end
    82       | replace_hole Ts (t $ u) i =
    83           let
    84             val (t', i') = replace_hole Ts t i
    85             val (u', i'') = replace_hole Ts u i'
    86           in (t' $ u', i'') end
    87       | replace_hole _ a i = (a, i)
    88     fun prep_holes ts = #1 (fold_map (replace_hole []) ts 1)
    89   in
    90     Context.proof_map (Syntax_Phases.term_check 101 "hole_expansion" (K prep_holes))
    91     #> Proof_Context.set_mode Proof_Context.mode_pattern
    92   end
    93 
    94 
    95 (* pattern conversions *)
    96 
    97 type patconv = Proof.context -> Type.tyenv * (string * term) list -> cterm -> thm
    98 
    99 fun (cv1 then_pconv cv2) ctxt tytenv ct = (cv1 ctxt tytenv then_conv cv2 ctxt tytenv) ct
   100 
   101 fun (cv1 else_pconv cv2) ctxt tytenv ct = (cv1 ctxt tytenv else_conv cv2 ctxt tytenv) ct
   102 
   103 fun raw_abs_pconv cv ctxt tytenv ct =
   104   case Thm.term_of ct of
   105     Abs _ => CConv.abs_cconv (fn (x, ctxt') => cv x ctxt' tytenv) ctxt ct
   106   | t => raise TERM ("raw_abs_pconv", [t])
   107 
   108 fun raw_fun_pconv cv ctxt tytenv ct =
   109   case Thm.term_of ct of
   110     _ $ _ => CConv.fun_cconv (cv ctxt tytenv) ct
   111   | t => raise TERM ("raw_fun_pconv", [t])
   112 
   113 fun raw_arg_pconv cv ctxt tytenv ct =
   114   case Thm.term_of ct of
   115     _ $ _ => CConv.arg_cconv (cv ctxt tytenv) ct
   116   | t => raise TERM ("raw_arg_pconv", [t])
   117 
   118 fun abs_pconv cv (s,T) ctxt (tyenv, ts) ct =
   119   let val u = Thm.term_of ct
   120   in
   121     case try (fastype_of #> dest_funT) u of
   122       NONE => raise TERM ("abs_pconv: no function type", [u])
   123     | SOME (U, _) =>
   124         let
   125           val tyenv' =
   126             if T = dummyT then tyenv
   127             else Sign.typ_match (Proof_Context.theory_of ctxt) (T, U) tyenv
   128           val eta_expand_cconv =
   129             case u of
   130               Abs _=> Thm.reflexive
   131             | _ => CConv.rewr_cconv @{thm eta_expand}
   132           fun add_ident NONE _ l = l
   133             | add_ident (SOME name) ct l = (name, Thm.term_of ct) :: l
   134           val abs_cv = CConv.abs_cconv (fn (ct, ctxt) => cv ctxt (tyenv', add_ident s ct ts)) ctxt
   135         in (eta_expand_cconv then_conv abs_cv) ct end
   136         handle Pattern.MATCH => raise TYPE ("abs_pconv: types don't match", [T,U], [u])
   137   end
   138 
   139 fun fun_pconv cv ctxt tytenv ct =
   140   case Thm.term_of ct of
   141     _ $ _ => CConv.fun_cconv (cv ctxt tytenv) ct
   142   | Abs (_, T, _ $ Bound 0) => abs_pconv (fun_pconv cv) (NONE, T) ctxt tytenv ct
   143   | t => raise TERM ("fun_pconv", [t])
   144 
   145 local
   146 
   147 fun arg_pconv_gen cv0 cv ctxt tytenv ct =
   148   case Thm.term_of ct of
   149     _ $ _ => cv0 (cv ctxt tytenv) ct
   150   | Abs (_, T, _ $ Bound 0) => abs_pconv (arg_pconv_gen cv0 cv) (NONE, T) ctxt tytenv ct
   151   | t => raise TERM ("arg_pconv_gen", [t])
   152 
   153 in
   154 
   155 fun arg_pconv ctxt = arg_pconv_gen CConv.arg_cconv ctxt
   156 fun imp_pconv ctxt = arg_pconv_gen (CConv.concl_cconv 1) ctxt
   157 
   158 end
   159 
   160 (* Move to B in !!x_1 ... x_n. B. Do not eta-expand *)
   161 fun params_pconv cv ctxt tytenv ct =
   162   let val pconv =
   163     case Thm.term_of ct of
   164       Const (@{const_name "Pure.all"}, _) $ Abs _ => (raw_arg_pconv o raw_abs_pconv) (fn _ => params_pconv cv)
   165     | Const (@{const_name "Pure.all"}, _) => raw_arg_pconv (params_pconv cv)
   166     | _ => cv
   167   in pconv ctxt tytenv ct end
   168 
   169 fun forall_pconv cv ident ctxt tytenv ct =
   170   case Thm.term_of ct of
   171     Const (@{const_name "Pure.all"}, T) $ _ =>
   172       let
   173         val def_U = T |> dest_funT |> fst |> dest_funT |> fst
   174         val ident' = apsnd (the_default (def_U)) ident
   175       in arg_pconv (abs_pconv cv ident') ctxt tytenv ct end
   176   | t => raise TERM ("forall_pconv", [t])
   177 
   178 fun all_pconv _ _ = Thm.reflexive
   179 
   180 fun for_pconv cv idents ctxt tytenv ct =
   181   let
   182     fun f rev_idents (Const (@{const_name "Pure.all"}, _) $ t) =
   183         let val (rev_idents', cv') = f rev_idents (case t of Abs (_,_,u) => u | _ => t)
   184         in
   185           case rev_idents' of
   186             [] => ([], forall_pconv cv' (NONE, NONE))
   187           | (x :: xs) => (xs, forall_pconv cv' x)
   188         end
   189       | f rev_idents _ = (rev_idents, cv)
   190   in
   191     case f (rev idents) (Thm.term_of ct) of
   192       ([], cv') => cv' ctxt tytenv ct
   193     | _ => raise CTERM ("for_pconv", [ct])
   194   end
   195 
   196 fun concl_pconv cv ctxt tytenv ct =
   197   case Thm.term_of ct of
   198     (Const (@{const_name "Pure.imp"}, _) $ _) $ _ => imp_pconv (concl_pconv cv) ctxt tytenv ct
   199   | _ => cv ctxt tytenv ct
   200 
   201 fun asm_pconv cv ctxt tytenv ct =
   202   case Thm.term_of ct of
   203     (Const (@{const_name "Pure.imp"}, _) $ _) $ _ => CConv.with_prems_cconv ~1 (cv ctxt tytenv) ct
   204   | t => raise TERM ("asm_pconv", [t])
   205 
   206 fun asms_pconv cv ctxt tytenv ct =
   207   case Thm.term_of ct of
   208     (Const (@{const_name "Pure.imp"}, _) $ _) $ _ =>
   209       ((CConv.with_prems_cconv ~1 oo cv) else_pconv imp_pconv (asms_pconv cv)) ctxt tytenv ct
   210   | t => raise TERM ("asms_pconv", [t])
   211 
   212 fun judgment_pconv cv ctxt tytenv ct =
   213   if Object_Logic.is_judgment ctxt (Thm.term_of ct)
   214   then arg_pconv cv ctxt tytenv ct
   215   else cv ctxt tytenv ct
   216 
   217 fun in_pconv cv ctxt tytenv ct =
   218   (cv else_pconv 
   219    raw_fun_pconv (in_pconv cv) else_pconv
   220    raw_arg_pconv (in_pconv cv) else_pconv
   221    raw_abs_pconv (fn _  => in_pconv cv))
   222   ctxt tytenv ct
   223 
   224 fun replace_idents idents t =
   225   let
   226     fun subst ((n1, s)::ss) (t as Free (n2, _)) = if n1 = n2 then s else subst ss t
   227       | subst _ t = t
   228   in Term.map_aterms (subst idents) t end
   229 
   230 fun match_pconv cv (t,fixes) ctxt (tyenv, env_ts) ct =
   231   let
   232     val t' = replace_idents env_ts t
   233     val thy = Proof_Context.theory_of ctxt
   234     val u = Thm.term_of ct
   235 
   236     fun descend_hole fixes (Abs (_, _, t)) =
   237         (case descend_hole fixes t of
   238           NONE => NONE
   239         | SOME (fix :: fixes', pos) => SOME (fixes', abs_pconv pos fix)
   240         | SOME ([], _) => raise Match (* less fixes than abstractions on path to hole *))
   241       | descend_hole fixes (t as l $ r) =
   242         let val (f, _) = strip_comb t
   243         in
   244           if is_hole f
   245           then SOME (fixes, cv)
   246           else
   247             (case descend_hole fixes l of
   248               SOME (fixes', pos) => SOME (fixes', fun_pconv pos)
   249             | NONE =>
   250               (case descend_hole fixes r of
   251                 SOME (fixes', pos) => SOME (fixes', arg_pconv pos)
   252               | NONE => NONE))
   253         end
   254       | descend_hole fixes t =
   255         if is_hole t then SOME (fixes, cv) else NONE
   256 
   257     val to_hole = descend_hole (rev fixes) #> the_default ([], cv) #> snd
   258   in
   259     case try (Pattern.match thy (apply2 Logic.mk_term (t',u))) (tyenv, Vartab.empty) of
   260       NONE => raise TERM ("match_pconv: Does not match pattern", [t, t',u])
   261     | SOME (tyenv', _) => to_hole t ctxt (tyenv', env_ts) ct
   262   end
   263 
   264 fun rewrs_pconv to thms ctxt (tyenv, env_ts) =
   265   let
   266     fun instantiate_normalize_env ctxt env thm =
   267       let
   268         val prop = Thm.prop_of thm
   269         val norm_type = Envir.norm_type o Envir.type_env
   270         val insts = Term.add_vars prop []
   271           |> map (fn x as (s, T) =>
   272               ((s, norm_type env T), Thm.cterm_of ctxt (Envir.norm_term env (Var x))))
   273         val tyinsts = Term.add_tvars prop []
   274           |> map (fn x => (x, Thm.ctyp_of ctxt (norm_type env (TVar x))))
   275       in Drule.instantiate_normalize (tyinsts, insts) thm end
   276     
   277     fun unify_with_rhs context to env thm =
   278       let
   279         val (_, rhs) = thm |> Thm.concl_of |> Logic.dest_equals
   280         val env' = Pattern.unify context (Logic.mk_term to, Logic.mk_term rhs) env
   281           handle Pattern.Unif => raise NO_TO_MATCH
   282       in env' end
   283     
   284     fun inst_thm_to _ (NONE, _) thm = thm
   285       | inst_thm_to (ctxt : Proof.context) (SOME to, env) thm =
   286           instantiate_normalize_env ctxt (unify_with_rhs (Context.Proof ctxt) to env thm) thm
   287     
   288     fun inst_thm ctxt idents (to, tyenv) thm =
   289       let
   290         (* Replace any identifiers with their corresponding bound variables. *)
   291         val maxidx = Term.maxidx_typs (map (snd o snd) (Vartab.dest tyenv)) 0
   292         val env = Envir.Envir {maxidx = maxidx, tenv = Vartab.empty, tyenv = tyenv}
   293         val maxidx = Envir.maxidx_of env |> fold Term.maxidx_term (the_list to)
   294         val thm' = Thm.incr_indexes (maxidx + 1) thm
   295       in SOME (inst_thm_to ctxt (Option.map (replace_idents idents) to, env) thm') end
   296       handle NO_TO_MATCH => NONE
   297     
   298   in CConv.rewrs_cconv (map_filter (inst_thm ctxt env_ts (to, tyenv)) thms) end
   299 
   300 fun rewrite_conv ctxt (pattern, to) thms ct =
   301   let
   302     fun apply_pat At = judgment_pconv
   303       | apply_pat In = in_pconv
   304       | apply_pat Asm = params_pconv o asms_pconv
   305       | apply_pat Concl = params_pconv o concl_pconv
   306       | apply_pat (For idents) = (fn cv => for_pconv cv (map (apfst SOME) idents))
   307       | apply_pat (Term x) = (fn cv => match_pconv cv (apsnd (map (apfst SOME)) x))
   308 
   309     val cv = fold_rev apply_pat pattern
   310 
   311     fun distinct_prems th =
   312       case Seq.pull (distinct_subgoals_tac th) of
   313         NONE => th
   314       | SOME (th', _) => th'
   315 
   316     val rewrite = rewrs_pconv to (maps (prep_meta_eq ctxt) thms)
   317   in cv rewrite ctxt (Vartab.empty, []) ct |> distinct_prems end
   318 
   319 fun rewrite_export_tac ctxt (pat, pat_ctxt) thms =
   320   let
   321     val export = case pat_ctxt of
   322         NONE => I
   323       | SOME inner => singleton (Proof_Context.export inner ctxt)
   324   in CCONVERSION (export o rewrite_conv ctxt pat thms) end
   325 
   326 val _ =
   327   Theory.setup
   328   let
   329     fun mk_fix s = (Binding.name s, NONE, NoSyn)
   330 
   331     val raw_pattern : (string, binding * string option * mixfix) pattern list parser =
   332       let
   333         val sep = (Args.$$$ "at" >> K At) || (Args.$$$ "in" >> K In)
   334         val atom =  (Args.$$$ "asm" >> K Asm) ||
   335           (Args.$$$ "concl" >> K Concl) ||
   336           (Args.$$$ "for" |-- Args.parens (Scan.optional Parse.vars []) >> For) ||
   337           (Parse.term >> Term)
   338         val sep_atom = sep -- atom >> (fn (s,a) => [s,a])
   339 
   340         fun append_default [] = [Concl, In]
   341           | append_default (ps as Term _ :: _) = Concl :: In :: ps
   342           | append_default [For x, In] = [For x, Concl, In]
   343           | append_default (For x :: (ps as In :: Term _:: _)) = For x :: Concl :: ps
   344           | append_default ps = ps
   345 
   346       in Scan.repeats sep_atom >> (rev #> append_default) end
   347 
   348     fun context_lift (scan : 'a parser) f = fn (context : Context.generic, toks) =>
   349       let
   350         val (r, toks') = scan toks
   351         val (r', context') = Context.map_proof_result (fn ctxt => f ctxt r) context
   352       in (r', (context', toks' : Token.T list)) end
   353 
   354     fun read_fixes fixes ctxt =
   355       let fun read_typ (b, rawT, mx) = (b, Option.map (Syntax.read_typ ctxt) rawT, mx)
   356       in Proof_Context.add_fixes (map read_typ fixes) ctxt end
   357 
   358     fun prep_pats ctxt (ps : (string, binding * string option * mixfix) pattern list) =
   359       let
   360         fun add_constrs ctxt n (Abs (x, T, t)) =
   361             let
   362               val (x', ctxt') = yield_singleton Proof_Context.add_fixes (mk_fix x) ctxt
   363             in
   364               (case add_constrs ctxt' (n+1) t of
   365                 NONE => NONE
   366               | SOME ((ctxt'', n', xs), t') =>
   367                   let
   368                     val U = Type_Infer.mk_param n []
   369                     val u = Type.constraint (U --> dummyT) (Abs (x, T, t'))
   370                   in SOME ((ctxt'', n', (x', U) :: xs), u) end)
   371             end
   372           | add_constrs ctxt n (l $ r) =
   373             (case add_constrs ctxt n l of
   374               SOME (c, l') => SOME (c, l' $ r)
   375             | NONE =>
   376               (case add_constrs ctxt n r of
   377                 SOME (c, r') => SOME (c, l $ r')
   378               | NONE => NONE))
   379           | add_constrs ctxt n t =
   380             if is_hole_const t then SOME ((ctxt, n, []), t) else NONE
   381 
   382         fun prep (Term s) (n, ctxt) =
   383             let
   384               val t = Syntax.parse_term ctxt s
   385               val ((ctxt', n', bs), t') =
   386                 the_default ((ctxt, n, []), t) (add_constrs ctxt (n+1) t)
   387             in (Term (t', bs), (n', ctxt')) end
   388           | prep (For ss) (n, ctxt) =
   389             let val (ns, ctxt') = read_fixes ss ctxt
   390             in (For ns, (n, ctxt')) end
   391           | prep At (n,ctxt) = (At, (n, ctxt))
   392           | prep In (n,ctxt) = (In, (n, ctxt))
   393           | prep Concl (n,ctxt) = (Concl, (n, ctxt))
   394           | prep Asm (n,ctxt) = (Asm, (n, ctxt))
   395 
   396         val (xs, (_, ctxt')) = fold_map prep ps (0, ctxt)
   397 
   398       in (xs, ctxt') end
   399 
   400     fun prep_args ctxt (((raw_pats, raw_to), raw_ths)) =
   401       let
   402 
   403         fun check_terms ctxt ps to =
   404           let
   405             fun safe_chop (0: int) xs = ([], xs)
   406               | safe_chop n (x :: xs) = chop (n - 1) xs |>> cons x
   407               | safe_chop _ _ = raise Match
   408 
   409             fun reinsert_pat _ (Term (_, cs)) (t :: ts) =
   410                 let val (cs', ts') = safe_chop (length cs) ts
   411                 in (Term (t, map dest_Free cs'), ts') end
   412               | reinsert_pat _ (Term _) [] = raise Match
   413               | reinsert_pat ctxt (For ss) ts =
   414                 let val fixes = map (fn s => (s, Variable.default_type ctxt s)) ss
   415                 in (For fixes, ts) end
   416               | reinsert_pat _ At ts = (At, ts)
   417               | reinsert_pat _ In ts = (In, ts)
   418               | reinsert_pat _ Concl ts = (Concl, ts)
   419               | reinsert_pat _ Asm ts = (Asm, ts)
   420 
   421             fun free_constr (s,T) = Type.constraint T (Free (s, dummyT))
   422             fun mk_free_constrs (Term (t, cs)) = t :: map free_constr cs
   423               | mk_free_constrs _ = []
   424 
   425             val ts = maps mk_free_constrs ps @ the_list to
   426               |> Syntax.check_terms (hole_syntax ctxt)
   427             val ctxt' = fold Variable.declare_term ts ctxt
   428             val (ps', (to', ts')) = fold_map (reinsert_pat ctxt') ps ts
   429               ||> (fn xs => case to of NONE => (NONE, xs) | SOME _ => (SOME (hd xs), tl xs))
   430             val _ = case ts' of (_ :: _) => raise Match | [] => ()
   431           in ((ps', to'), ctxt') end
   432 
   433         val (pats, ctxt') = prep_pats ctxt raw_pats
   434 
   435         val ths = Attrib.eval_thms ctxt' raw_ths
   436         val to = Option.map (Syntax.parse_term ctxt') raw_to
   437 
   438         val ((pats', to'), ctxt'') = check_terms ctxt' pats to
   439 
   440       in ((pats', ths, (to', ctxt)), ctxt'') end
   441 
   442     val to_parser = Scan.option ((Args.$$$ "to") |-- Parse.term)
   443 
   444     val subst_parser =
   445       let val scan = raw_pattern -- to_parser -- Parse.thms1
   446       in context_lift scan prep_args end
   447   in
   448     Method.setup @{binding rewrite} (subst_parser >>
   449       (fn (pattern, inthms, (to, pat_ctxt)) => fn orig_ctxt =>
   450         SIMPLE_METHOD' (rewrite_export_tac orig_ctxt ((pattern, to), SOME pat_ctxt) inthms)))
   451       "single-step rewriting, allowing subterm selection via patterns."
   452   end
   453 end