src/HOL/Library/rewrite.ML
author wenzelm
Wed Apr 08 21:24:27 2015 +0200 (2015-04-08)
changeset 59975 da10875adf8e
parent 59970 e9f73d87d904
child 60050 dc6ac152d864
permissions -rw-r--r--
more standard Isabelle/ML tool setup;
proper file headers;
tuned whitespace;
     1 (*  Title:      HOL/Library/rewrite.ML
     2     Author:     Christoph Traut, Lars Noschinski, TU Muenchen
     3 
     4 This is a rewrite method that supports subterm-selection based on patterns.
     5 
     6 The patterns accepted by rewrite are of the following form:
     7   <atom>    ::= <term> | "concl" | "asm" | "for" "(" <names> ")"
     8   <pattern> ::= (in <atom> | at <atom>) [<pattern>]
     9   <args>    ::= [<pattern>] ["to" <term>] <thms>
    10 
    11 This syntax was clearly inspired by Gonthier's and Tassi's language of
    12 patterns but has diverged significantly during its development.
    13 
    14 We also allow introduction of identifiers for bound variables,
    15 which can then be used to match arbitrary subterms inside abstractions.
    16 *)
    17 
(* Public interface of structure Rewrite.  It is currently empty, so the
   structure's definitions are not visible outside it; the tool is used
   through the "rewrite" proof method registered at the end of the
   structure. *)
signature REWRITE =
sig
  (* FIXME proper ML interface!? *)
end
    22 
    23 structure Rewrite : REWRITE =
    24 struct
    25 
    26 datatype ('a, 'b) pattern = At | In | Term of 'a | Concl | Asm | For of 'b list
    27 
    28 fun map_term_pattern f (Term x) = f x
    29   | map_term_pattern _ (For ss) = (For ss)
    30   | map_term_pattern _ At = At
    31   | map_term_pattern _ In = In
    32   | map_term_pattern _ Concl = Concl
    33   | map_term_pattern _ Asm = Asm
    34 
    35 
    36 exception NO_TO_MATCH
    37 
    38 fun SEQ_CONCAT (tacq : tactic Seq.seq) : tactic = fn st => Seq.maps (fn tac => tac st) tacq
    39 
    40 (* We rewrite subterms using rewrite conversions. These are conversions
    41    that also take a context and a list of identifiers for bound variables
    42    as parameters. *)
    43 type rewrite_conv = Proof.context -> (string * term) list -> conv
    44 
    45 (* To apply such a rewrite conversion to a subterm of our goal, we use
    46    subterm positions, which are just functions that map a rewrite conversion,
    47    working on the top level, to a new rewrite conversion, working on
    48    a specific subterm.
    49 
    50    During substitution, we are traversing the goal to find subterms that
    51    we can rewrite. For each of these subterms, a subterm position is
    52    created and later used in creating a conversion that we use to try and
    53    rewrite this subterm. *)
    54 type subterm_position = rewrite_conv -> rewrite_conv
    55 
    56 (* A focusterm represents a subterm. It is a tuple (t, p), consisting
    57   of the subterm t itself and its subterm position p. *)
    58 type focusterm = Type.tyenv * term * subterm_position
    59 
    60 val dummyN = Name.internal "__dummy"
    61 val holeN = Name.internal "_hole"
    62 
    63 fun prep_meta_eq ctxt =
    64   Simplifier.mksimps ctxt #> map Drule.zero_var_indexes
    65 
    66 
    67 (* rewrite conversions *)
    68 
    69 fun abs_rewr_cconv ident : subterm_position =
    70   let
    71     fun add_ident NONE _ l = l
    72       | add_ident (SOME name) ct l = (name, Thm.term_of ct) :: l
    73     fun inner rewr ctxt idents =
    74       CConv.abs_cconv (fn (ct, ctxt) => rewr ctxt (add_ident ident ct idents)) ctxt
    75   in inner end
    76 
    77 val fun_rewr_cconv : subterm_position = fn rewr => CConv.fun_cconv oo rewr
    78 val arg_rewr_cconv : subterm_position = fn rewr => CConv.arg_cconv oo rewr
    79 
    80 
    81 (* focus terms *)
    82 
    83 fun ft_abs ctxt (s,T) (tyenv, u, pos) =
    84   case try (fastype_of #> dest_funT) u of
    85     NONE => raise TERM ("ft_abs: no function type", [u])
    86   | SOME (U, _) =>
    87       let
    88         val tyenv' =
    89           if T = dummyT then tyenv
    90           else Sign.typ_match (Proof_Context.theory_of ctxt) (T, U) tyenv
    91         val x = Free (the_default (Name.internal dummyN) s, Envir.norm_type tyenv' T)
    92         val eta_expand_cconv = CConv.rewr_cconv @{thm eta_expand}
    93         fun eta_expand rewr ctxt bounds = eta_expand_cconv then_conv rewr ctxt bounds
    94         val (u', pos') =
    95           case u of
    96             Abs (_,_,t') => (subst_bound (x, t'), pos o abs_rewr_cconv s)
    97           | _ => (u $ x, pos o eta_expand o abs_rewr_cconv s)
    98       in (tyenv', u', pos') end
    99       handle Pattern.MATCH => raise TYPE ("ft_abs: types don't match", [T,U], [u])
   100 
   101 fun ft_fun _ (tyenv, l $ _, pos) = (tyenv, l, pos o fun_rewr_cconv)
   102   | ft_fun ctxt (ft as (_, Abs (_, T, _ $ Bound 0), _)) = (ft_fun ctxt o ft_abs ctxt (NONE, T)) ft
   103   | ft_fun _ (_, t, _) = raise TERM ("ft_fun", [t])
   104 
   105 fun ft_arg _ (tyenv, _ $ r, pos) = (tyenv, r, pos o arg_rewr_cconv)
   106   | ft_arg ctxt (ft as (_, Abs (_, T, _ $ Bound 0), _)) = (ft_arg ctxt o ft_abs ctxt (NONE, T)) ft
   107   | ft_arg _ (_, t, _) = raise TERM ("ft_arg", [t])
   108 
   109 (* Move to B in !!x_1 ... x_n. B. Do not eta-expand *)
   110 fun ft_params ctxt (ft as (_, t, _) : focusterm) =
   111   case t of
   112     Const (@{const_name "Pure.all"}, _) $ Abs (_,T,_) =>
   113       (ft_params ctxt o ft_abs ctxt (NONE, T) o ft_arg ctxt) ft
   114   | Const (@{const_name "Pure.all"}, _) =>
   115       (ft_params ctxt o ft_arg ctxt) ft
   116   | _ => ft
   117 
   118 fun ft_all ctxt ident (ft as (_, Const (@{const_name "Pure.all"}, T) $ _, _) : focusterm) =
   119     let
   120       val def_U = T |> dest_funT |> fst |> dest_funT |> fst
   121       val ident' = apsnd (the_default (def_U)) ident
   122     in (ft_abs ctxt ident' o ft_arg ctxt) ft end
   123   | ft_all _ _ (_, t, _) = raise TERM ("ft_all", [t])
   124 
   125 fun ft_for ctxt idents (ft as (_, t, _) : focusterm) =
   126   let
   127     fun f rev_idents (Const (@{const_name "Pure.all"}, _) $ t) =
   128         let
   129          val (rev_idents', desc) = f rev_idents (case t of Abs (_,_,u) => u | _ => t)
   130         in
   131           case rev_idents' of
   132             [] => ([], desc o ft_all ctxt (NONE, NONE))
   133           | (x :: xs) => (xs , desc o ft_all ctxt x)
   134         end
   135       | f rev_idents _ = (rev_idents, I)
   136   in
   137     case f (rev idents) t of
   138       ([], ft') => SOME (ft' ft)
   139     | _ => NONE
   140   end
   141 
   142 fun ft_concl ctxt (ft as (_, t, _) : focusterm) =
   143   case t of
   144     (Const (@{const_name "Pure.imp"}, _) $ _) $ _ => (ft_concl ctxt o ft_arg ctxt) ft
   145   | _ => ft
   146 
   147 fun ft_assm ctxt (ft as (_, t, _) : focusterm) =
   148   case t of
   149     (Const (@{const_name "Pure.imp"}, _) $ _) $ _ => (ft_concl ctxt o ft_arg ctxt o ft_fun ctxt) ft
   150   | _ => raise TERM ("ft_assm", [t])
   151 
   152 fun ft_judgment ctxt (ft as (_, t, _) : focusterm) =
   153   if Object_Logic.is_judgment ctxt t
   154   then ft_arg ctxt ft
   155   else ft
   156 
   157 
   158 (* Return a lazy sequenze of all subterms of the focusterm for which
   159    the condition holds. *)
   160 fun find_subterms ctxt condition (ft as (_, t, _) : focusterm) =
   161   let
   162     val recurse = find_subterms ctxt condition
   163     val recursive_matches =
   164       case t of
   165         _ $ _ => Seq.append (ft |> ft_fun ctxt |> recurse) (ft |> ft_arg ctxt |> recurse)
   166       | Abs (_,T,_) => ft |> ft_abs ctxt (NONE, T) |> recurse
   167       | _ => Seq.empty
   168   in
   169     (* If the condition is met, then the current focusterm is part of the
   170        sequence of results. Otherwise, only the results of the recursive
   171        application are. *)
   172     if condition ft
   173     then Seq.cons ft recursive_matches
   174     else recursive_matches
   175   end
   176 
   177 (* Find all subterms that might be a valid point to apply a rule. *)
   178 fun valid_match_points ctxt =
   179   let
   180     fun is_valid (l $ _) = is_valid l
   181       | is_valid (Abs (_, _, a)) = is_valid a
   182       | is_valid (Var _) = false
   183       | is_valid (Bound _) = false
   184       | is_valid _ = true
   185   in
   186     find_subterms ctxt (#2 #> is_valid )
   187   end
   188 
   189 fun is_hole (Var ((name, _), _)) = (name = holeN)
   190   | is_hole _ = false
   191 
   192 fun is_hole_const (Const (@{const_name rewrite_HOLE}, _)) = true
   193   | is_hole_const _ = false
   194 
   195 val hole_syntax =
   196   let
   197     (* Modified variant of Term.replace_hole *)
   198     fun replace_hole Ts (Const (@{const_name rewrite_HOLE}, T)) i =
   199           (list_comb (Var ((holeN, i), Ts ---> T), map_range Bound (length Ts)), i + 1)
   200       | replace_hole Ts (Abs (x, T, t)) i =
   201           let val (t', i') = replace_hole (T :: Ts) t i
   202           in (Abs (x, T, t'), i') end
   203       | replace_hole Ts (t $ u) i =
   204           let
   205             val (t', i') = replace_hole Ts t i
   206             val (u', i'') = replace_hole Ts u i'
   207           in (t' $ u', i'') end
   208       | replace_hole _ a i = (a, i)
   209     fun prep_holes ts = #1 (fold_map (replace_hole []) ts 1)
   210   in
   211     Context.proof_map (Syntax_Phases.term_check 101 "hole_expansion" (K prep_holes))
   212     #> Proof_Context.set_mode Proof_Context.mode_pattern
   213   end
   214 
   215 (* Find a subterm of the focusterm matching the pattern. *)
   216 fun find_matches ctxt pattern_list =
   217   let
   218     fun move_term ctxt (t, off) (ft : focusterm) =
   219       let
   220         val thy = Proof_Context.theory_of ctxt
   221 
   222         val eta_expands =
   223           let val (_, ts) = strip_comb t
   224           in map fastype_of (snd (take_suffix is_Var ts)) end
   225 
   226         fun do_match (tyenv, u, pos) =
   227           case try (Pattern.match thy (t,u)) (tyenv, Vartab.empty) of
   228             NONE => NONE
   229           | SOME (tyenv', _) => SOME (off (tyenv', u, pos))
   230 
   231         fun match_argT T u =
   232           let val (U, _) = dest_funT (fastype_of u)
   233           in try (Sign.typ_match thy (T,U)) end
   234           handle TYPE _ => K NONE
   235 
   236         fun desc [] ft = do_match ft
   237           | desc (T :: Ts) (ft as (tyenv , u, pos)) =
   238             case do_match ft of
   239               NONE =>
   240                 (case match_argT T u tyenv of
   241                   NONE => NONE
   242                 | SOME tyenv' => desc Ts (ft_abs ctxt (NONE, T) (tyenv', u, pos)))
   243             | SOME ft => SOME ft
   244       in desc eta_expands ft end
   245 
   246     fun seq_unfold f ft =
   247       case f ft of
   248         NONE => Seq.empty
   249       | SOME ft' => Seq.cons ft' (seq_unfold f ft')
   250 
   251     fun apply_pat At = Seq.map (ft_judgment ctxt)
   252       | apply_pat In = Seq.maps (valid_match_points ctxt)
   253       | apply_pat Asm = Seq.maps (seq_unfold (try (ft_assm ctxt)) o ft_params ctxt)
   254       | apply_pat Concl = Seq.map (ft_concl ctxt o ft_params ctxt)
   255       | apply_pat (For idents) = Seq.map_filter ((ft_for ctxt (map (apfst SOME) idents)))
   256       | apply_pat (Term x) = Seq.map_filter ( (move_term ctxt x))
   257 
   258     fun apply_pats ft = ft
   259       |> Seq.single
   260       |> fold apply_pat pattern_list
   261   in
   262     apply_pats
   263   end
   264 
   265 fun instantiate_normalize_env ctxt env thm =
   266   let
   267     fun certs f = map (apply2 (f ctxt))
   268     val prop = Thm.prop_of thm
   269     val norm_type = Envir.norm_type o Envir.type_env
   270     val insts = Term.add_vars prop []
   271       |> map (fn x as (s,T) => (Var (s, norm_type env T), Envir.norm_term env (Var x)))
   272       |> certs Thm.cterm_of
   273     val tyinsts = Term.add_tvars prop []
   274       |> map (fn x => (TVar x, norm_type env (TVar x)))
   275       |> certs Thm.ctyp_of
   276   in Drule.instantiate_normalize (tyinsts, insts) thm end
   277 
   278 fun unify_with_rhs context to env thm =
   279   let
   280     val (_, rhs) = thm |> Thm.concl_of |> Logic.dest_equals
   281     val env' = Pattern.unify context (Logic.mk_term to, Logic.mk_term rhs) env
   282       handle Pattern.Unif => raise NO_TO_MATCH
   283   in env' end
   284 
   285 fun inst_thm_to _ (NONE, _) thm = thm
   286   | inst_thm_to (ctxt : Proof.context) (SOME to, env) thm =
   287       instantiate_normalize_env ctxt (unify_with_rhs (Context.Proof ctxt) to env thm) thm
   288 
   289 fun inst_thm ctxt idents (to, tyenv) thm =
   290   let
   291     (* Replace any identifiers with their corresponding bound variables. *)
   292     val maxidx = Term.maxidx_typs (map (snd o snd) (Vartab.dest tyenv)) 0
   293     val env = Envir.Envir {maxidx = maxidx, tenv = Vartab.empty, tyenv = tyenv}
   294     val replace_idents =
   295       let
   296         fun subst ((n1, s)::ss) (t as Free (n2, _)) = if n1 = n2 then s else subst ss t
   297           | subst _ t = t
   298       in Term.map_aterms (subst idents) end
   299 
   300     val maxidx = Envir.maxidx_of env |> fold Term.maxidx_term (map_filter I [to])
   301     val thm' = Thm.incr_indexes (maxidx + 1) thm
   302   in SOME (inst_thm_to ctxt (Option.map replace_idents to, env) thm') end
   303   handle NO_TO_MATCH => NONE
   304 
   305 (* Rewrite in subgoal i. *)
   306 fun rewrite_goal_with_thm ctxt (pattern, (to, orig_ctxt)) rules = SUBGOAL (fn (t,i) =>
   307   let
   308     val matches = find_matches ctxt pattern (Vartab.empty, t, I)
   309 
   310     fun rewrite_conv insty ctxt bounds =
   311       CConv.rewrs_cconv (map_filter (inst_thm ctxt bounds insty) rules)
   312 
   313     val export = singleton (Proof_Context.export ctxt orig_ctxt)
   314 
   315     fun distinct_prems th =
   316       case Seq.pull (distinct_subgoals_tac th) of
   317         NONE => th
   318       | SOME (th', _) => th'
   319 
   320     fun tac (tyenv, _, position) = CCONVERSION
   321       (distinct_prems o export o position (rewrite_conv (to, tyenv)) ctxt []) i
   322   in
   323     SEQ_CONCAT (Seq.map tac matches)
   324   end)
   325 
   326 fun rewrite_tac ctxt pattern thms =
   327   let
   328     val thms' = maps (prep_meta_eq ctxt) thms
   329     val tac = rewrite_goal_with_thm ctxt pattern thms'
   330   in tac end
   331 
(* Register the "rewrite" proof method.  Its arguments are parsed as
   patterns, an optional "to" term and a list of rules.  They are prepared
   (parsed, constrained and checked together) in the proof context, and the
   method runs rewrite_tac. *)
val _ =
  Theory.setup
  let
    (* Fix specification for a variable named s, without type or syntax. *)
    fun mk_fix s = (Binding.name s, NONE, NoSyn)

    (* Parse a sequence of "at" <atom> / "in" <atom> pairs.  The flattened
       list is reversed, so the last written pair comes first.  An empty
       pattern defaults to [Concl, In].  If the reversed list starts with a
       term, [Concl, In] is put in front of it. *)
    val raw_pattern : (string, binding * string option * mixfix) pattern list parser =
      let
        val sep = (Args.$$$ "at" >> K At) || (Args.$$$ "in" >> K In)
        val atom =  (Args.$$$ "asm" >> K Asm) ||
          (Args.$$$ "concl" >> K Concl) ||
          (Args.$$$ "for" |-- Args.parens (Scan.optional Parse.fixes []) >> For) ||
          (Parse.term >> Term)
        val sep_atom = sep -- atom >> (fn (s,a) => [s,a])

        fun append_default [] = [Concl, In]
          | append_default (ps as Term _ :: _) = Concl :: In :: ps
          | append_default ps = ps

      in Scan.repeat sep_atom >> (flat #> rev #> append_default) end

    (* Lift a token parser to one on (Context.generic, tokens).  After
       scanning, f post-processes the result in the proof context, and the
       context it returns is threaded on. *)
    fun context_lift (scan : 'a parser) f = fn (context : Context.generic, toks) =>
      let
        val (r, toks') = scan toks
        val (r', context') = Context.map_proof_result (fn ctxt => f ctxt r) context
      in (r', (context', toks' : Token.T list)) end

    (* Read the optional types of "for" fixes and add them to the context. *)
    fun read_fixes fixes ctxt =
      let fun read_typ (b, rawT, mx) = (b, Option.map (Syntax.read_typ ctxt) rawT, mx)
      in Proof_Context.add_fixes (map read_typ fixes) ctxt end

    (* Parse term patterns and add the "for" fixes to the context.  For a
       term containing a hole, every abstraction on the way to the first
       hole (left to right) gets its bound name fixed in the context, and its
       argument type constrained to a fresh type parameter (counter n).  The
       resulting Term payload is (term, [(fixed name, type parameter)]),
       outermost binder first. *)
    fun prep_pats ctxt (ps : (string, binding * string option * mixfix) pattern list) =
      let
        fun add_constrs ctxt n (Abs (x, T, t)) =
            let
              val (x', ctxt') = yield_singleton Proof_Context.add_fixes (mk_fix x) ctxt
            in
              (case add_constrs ctxt' (n+1) t of
                NONE => NONE
              | SOME ((ctxt'', n', xs), t') =>
                  let
                    val U = Type_Infer.mk_param n []
                    val u = Type.constraint (U --> dummyT) (Abs (x, T, t'))
                  in SOME ((ctxt'', n', (x', U) :: xs), u) end)
            end
          | add_constrs ctxt n (l $ r) =
            (case add_constrs ctxt n l of
              SOME (c, l') => SOME (c, l' $ r)
            | NONE =>
              (case add_constrs ctxt n r of
                SOME (c, r') => SOME (c, l $ r')
              | NONE => NONE))
          | add_constrs ctxt n t =
            if is_hole_const t then SOME ((ctxt, n, []), t) else NONE

        fun prep (Term s) (n, ctxt) =
            let
              val t = Syntax.parse_term ctxt s
              val ((ctxt', n', bs), t') =
                the_default ((ctxt, n, []), t) (add_constrs ctxt (n+1) t)
            in (Term (t', bs), (n', ctxt')) end
          | prep (For ss) (n, ctxt) =
            let val (ns, ctxt') = read_fixes ss ctxt
            in (For ns, (n, ctxt')) end
          | prep At (n,ctxt) = (At, (n, ctxt))
          | prep In (n,ctxt) = (In, (n, ctxt))
          | prep Concl (n,ctxt) = (Concl, (n, ctxt))
          | prep Asm (n,ctxt) = (Asm, (n, ctxt))

        val (xs, (_, ctxt')) = fold_map prep ps (0, ctxt)

      in (xs, ctxt') end

    (* Full preparation of the method arguments.  Returns the interpreted
       patterns, the evaluated rules and (checked "to" term, original
       context), together with the extended context. *)
    fun prep_args ctxt (((raw_pats, raw_to), raw_ths)) =
      let

        (* Turn each checked Term (t, fixes) into Term (t, pos), where pos
           moves a focusterm from the top of t down to its hole.  Binders on
           the way are named after fixes: rev fixes puts the innermost binder
           first, which descend_hole consumes first.  Terms without a hole
           get pos = I. *)
        fun interpret_term_patterns ctxt =
          let

            fun descend_hole fixes (Abs (_, _, t)) =
                (case descend_hole fixes t of
                  NONE => NONE
                | SOME (fix :: fixes', pos) => SOME (fixes', pos o ft_abs ctxt (apfst SOME fix))
                | SOME ([], _) => raise Match (* XXX -- check phases modified binding *))
              | descend_hole fixes (t as l $ r) =
                let val (f, _) = strip_comb t
                in
                  if is_hole f
                  then SOME (fixes, I)
                  else
                    (case descend_hole fixes l of
                      SOME (fixes', pos) => SOME (fixes', pos o ft_fun ctxt)
                    | NONE =>
                      (case descend_hole fixes r of
                        SOME (fixes', pos) => SOME (fixes', pos o ft_arg ctxt)
                      | NONE => NONE))
                end
              | descend_hole fixes t =
                if is_hole t then SOME (fixes, I) else NONE

            fun f (t, fixes) = Term (t, (descend_hole (rev fixes) #> the_default ([], I) #> snd) t)

          in map (map_term_pattern f) end

        (* Check all term patterns, the type-constrained free variables for
           their binders, and the "to" term together, in the hole_syntax
           context.  The checked terms are then redistributed over the
           patterns.  Raises Match if the checked list does not fit the
           patterns. *)
        fun check_terms ctxt ps to =
          let
            (* Split off the first n elements; raises Match when n > 0 and
               the list is empty. *)
            fun safe_chop (0: int) xs = ([], xs)
              | safe_chop n (x :: xs) = chop (n - 1) xs |>> cons x
              | safe_chop _ _ = raise Match

            fun reinsert_pat _ (Term (_, cs)) (t :: ts) =
                let val (cs', ts') = safe_chop (length cs) ts
                in (Term (t, map dest_Free cs'), ts') end
              | reinsert_pat _ (Term _) [] = raise Match
              | reinsert_pat ctxt (For ss) ts =
                let val fixes = map (fn s => (s, Variable.default_type ctxt s)) ss
                in (For fixes, ts) end
              | reinsert_pat _ At ts = (At, ts)
              | reinsert_pat _ In ts = (In, ts)
              | reinsert_pat _ Concl ts = (Concl, ts)
              | reinsert_pat _ Asm ts = (Asm, ts)

            fun free_constr (s,T) = Type.constraint T (Free (s, dummyT))
            fun mk_free_constrs (Term (t, cs)) = t :: map free_constr cs
              | mk_free_constrs _ = []

            val ts = maps mk_free_constrs ps @ map_filter I [to]
              |> Syntax.check_terms (hole_syntax ctxt)
            val ctxt' = fold Variable.declare_term ts ctxt
            val (ps', (to', ts')) = fold_map (reinsert_pat ctxt') ps ts
              ||> (fn xs => case to of NONE => (NONE, xs) | SOME _ => (SOME (hd xs), tl xs))
            (* Every checked term must have been consumed. *)
            val _ = case ts' of (_ :: _) => raise Match | [] => ()
          in ((ps', to'), ctxt') end

        val (pats, ctxt') = prep_pats ctxt raw_pats

        val ths = Attrib.eval_thms ctxt' raw_ths
        val to = Option.map (Syntax.parse_term ctxt') raw_to

        val ((pats', to'), ctxt'') = check_terms ctxt' pats to
        val pats'' = interpret_term_patterns ctxt'' pats'

      in ((pats'', ths, (to', ctxt)), ctxt'') end

    (* Optional "to" <term>. *)
    val to_parser = Scan.option ((Args.$$$ "to") |-- Parse.term)

    val subst_parser =
      let val scan = raw_pattern -- to_parser -- Parse.xthms1
      in context_lift scan prep_args end
  in
    Method.setup @{binding rewrite} (subst_parser >>
      (fn (pattern, inthms, inst) => fn ctxt =>
        SIMPLE_METHOD' (rewrite_tac ctxt (pattern, inst) inthms)))
      "single-step rewriting, allowing subterm selection via patterns."
  end
   486 end