rewrite: add ML interface
authornoschinl
Wed, 15 Apr 2015 15:10:01 +0200
changeset 60079 ef4fe30e9ef1
parent 60056 71c1b9b9e937
child 60080 2cd500d08c30
rewrite: add ML interface
src/HOL/Library/rewrite.ML
src/HOL/ex/Rewrite_Examples.thy
--- a/src/HOL/Library/rewrite.ML	Tue Apr 14 15:54:17 2015 +0200
+++ b/src/HOL/Library/rewrite.ML	Wed Apr 15 15:10:01 2015 +0200
@@ -17,7 +17,21 @@
 
 signature REWRITE =
 sig
-  (* FIXME proper ML interface!? *)
+  datatype ('a, 'b) pattern = At | In | Term of 'a | Concl | Asm | For of 'b list
+
+  val mk_hole: int -> typ -> term
+
+  val rewrite: Proof.context
+    -> (term * (string * typ) list, string * typ option) pattern list * term option
+    -> thm list
+    -> cterm
+    -> thm Seq.seq
+
+  val rewrite_tac: Proof.context
+    -> (term * (string * typ) list, string * typ option) pattern list * term option
+    -> thm list
+    -> int
+    -> tactic
 end
 
 structure Rewrite : REWRITE =
@@ -182,6 +196,8 @@
     |> Seq.filter (#2 #> is_valid)
   end
 
+fun mk_hole i T = Var ((holeN, i), T)
+
 fun is_hole (Var ((name, _), _)) = (name = holeN)
   | is_hole _ = false
 
@@ -192,7 +208,7 @@
   let
     (* Modified variant of Term.replace_hole *)
     fun replace_hole Ts (Const (@{const_name rewrite_HOLE}, T)) i =
-          (list_comb (Var ((holeN, i), Ts ---> T), map_range Bound (length Ts)), i + 1)
+          (list_comb (mk_hole i (Ts ---> T), map_range Bound (length Ts)), i + 1)
       | replace_hole Ts (Abs (x, T, t)) i =
           let val (t', i') = replace_hole (T :: Ts) t i
           in (Abs (x, T, t'), i') end
@@ -300,33 +316,76 @@
   in SOME (inst_thm_to ctxt (Option.map replace_idents to, env) thm') end
   handle NO_TO_MATCH => NONE
 
-(* Rewrite in subgoal i. *)
-fun rewrite_goal_with_thm ctxt (pattern, (to, orig_ctxt)) rules = SUBGOAL (fn (t,i) =>
+local
+
+fun rewrite_raw ctxt (pattern, to) thms ct =
   let
-    val matches = find_matches ctxt pattern (Vartab.empty, t, I)
+    fun interpret_term_patterns ctxt =
+      let
+    
+        fun descend_hole fixes (Abs (_, _, t)) =
+            (case descend_hole fixes t of
+              NONE => NONE
+            | SOME (fix :: fixes', pos) => SOME (fixes', pos o ft_abs ctxt (apfst SOME fix))
+            | SOME ([], _) => raise Match (* XXX -- check phases modified binding *))
+          | descend_hole fixes (t as l $ r) =
+            let val (f, _) = strip_comb t
+            in
+              if is_hole f
+              then SOME (fixes, I)
+              else
+                (case descend_hole fixes l of
+                  SOME (fixes', pos) => SOME (fixes', pos o ft_fun ctxt)
+                | NONE =>
+                  (case descend_hole fixes r of
+                    SOME (fixes', pos) => SOME (fixes', pos o ft_arg ctxt)
+                  | NONE => NONE))
+            end
+          | descend_hole fixes t =
+            if is_hole t then SOME (fixes, I) else NONE
+    
+        fun f (t, fixes) = Term (t, (descend_hole (rev fixes) #> the_default ([], I) #> snd) t)
+    
+      in map (map_term_pattern f) end
+
+    val pattern' = interpret_term_patterns ctxt pattern
+    val matches = find_matches ctxt pattern' (Vartab.empty, Thm.term_of ct, I)
+
+    val thms' = maps (prep_meta_eq ctxt) thms
 
     fun rewrite_conv insty ctxt bounds =
-      CConv.rewrs_cconv (map_filter (inst_thm ctxt bounds insty) rules)
-
-    val export = singleton (Proof_Context.export ctxt orig_ctxt)
+      CConv.rewrs_cconv (map_filter (inst_thm ctxt bounds insty) thms')
 
     fun distinct_prems th =
       case Seq.pull (distinct_subgoals_tac th) of
         NONE => th
       | SOME (th', _) => th'
 
-    fun tac (tyenv, _, position) = CCONVERSION
-      (distinct_prems o export o position (rewrite_conv (to, tyenv)) ctxt []) i
-  in
-    SEQ_CONCAT (Seq.map tac matches)
-  end)
+    fun conv ((tyenv, _, position) : focusterm) =
+      distinct_prems o position (rewrite_conv (to, tyenv)) ctxt []
+
+  in Seq.map (fn ft => conv ft) matches end
+
+in
+
+fun rewrite ctxt pat thms ct =
+  rewrite_raw ctxt pat thms ct |> Seq.map_filter (fn cv => try cv ct)
 
-fun rewrite_tac ctxt pattern thms =
+fun rewrite_export_tac ctxt (pat, pat_ctxt) thms =
   let
-    val thms' = maps (prep_meta_eq ctxt) thms
-    val tac = rewrite_goal_with_thm ctxt pattern thms'
+    val export = case pat_ctxt of
+        NONE => I
+      | SOME inner => singleton (Proof_Context.export inner ctxt)
+    val tac = CSUBGOAL (fn (ct, i) =>
+      rewrite_raw ctxt pat thms ct
+      |> Seq.map (fn cv => CCONVERSION (export o cv) i)
+      |> SEQ_CONCAT)
   in tac end
 
+fun rewrite_tac ctxt pat = rewrite_export_tac ctxt (pat, NONE)
+
+end
+
 val _ =
   Theory.setup
   let
@@ -402,34 +461,6 @@
     fun prep_args ctxt (((raw_pats, raw_to), raw_ths)) =
       let
 
-        fun interpret_term_patterns ctxt =
-          let
-
-            fun descend_hole fixes (Abs (_, _, t)) =
-                (case descend_hole fixes t of
-                  NONE => NONE
-                | SOME (fix :: fixes', pos) => SOME (fixes', pos o ft_abs ctxt (apfst SOME fix))
-                | SOME ([], _) => raise Match (* XXX -- check phases modified binding *))
-              | descend_hole fixes (t as l $ r) =
-                let val (f, _) = strip_comb t
-                in
-                  if is_hole f
-                  then SOME (fixes, I)
-                  else
-                    (case descend_hole fixes l of
-                      SOME (fixes', pos) => SOME (fixes', pos o ft_fun ctxt)
-                    | NONE =>
-                      (case descend_hole fixes r of
-                        SOME (fixes', pos) => SOME (fixes', pos o ft_arg ctxt)
-                      | NONE => NONE))
-                end
-              | descend_hole fixes t =
-                if is_hole t then SOME (fixes, I) else NONE
-
-            fun f (t, fixes) = Term (t, (descend_hole (rev fixes) #> the_default ([], I) #> snd) t)
-
-          in map (map_term_pattern f) end
-
         fun check_terms ctxt ps to =
           let
             fun safe_chop (0: int) xs = ([], xs)
@@ -466,9 +497,8 @@
         val to = Option.map (Syntax.parse_term ctxt') raw_to
 
         val ((pats', to'), ctxt'') = check_terms ctxt' pats to
-        val pats'' = interpret_term_patterns ctxt'' pats'
 
-      in ((pats'', ths, (to', ctxt)), ctxt'') end
+      in ((pats', ths, (to', ctxt)), ctxt'') end
 
     val to_parser = Scan.option ((Args.$$$ "to") |-- Parse.term)
 
@@ -477,8 +507,8 @@
       in context_lift scan prep_args end
   in
     Method.setup @{binding rewrite} (subst_parser >>
-      (fn (pattern, inthms, inst) => fn ctxt =>
-        SIMPLE_METHOD' (rewrite_tac ctxt (pattern, inst) inthms)))
+      (fn (pattern, inthms, (to, pat_ctxt)) => fn orig_ctxt =>
+        SIMPLE_METHOD' (rewrite_export_tac orig_ctxt ((pattern, to), SOME pat_ctxt) inthms)))
       "single-step rewriting, allowing subterm selection via patterns."
   end
 end
--- a/src/HOL/ex/Rewrite_Examples.thy	Tue Apr 14 15:54:17 2015 +0200
+++ b/src/HOL/ex/Rewrite_Examples.thy	Wed Apr 15 15:10:01 2015 +0200
@@ -199,5 +199,60 @@
 by (rewrite at "x + 1" in for (x) at asm add.commute)
    (rule assms)
 
+(* The rewrite method also has an ML interface *)
+lemma
+  assumes "\<And>a b. P ((a + 1) * (1 + b)) "
+  shows "\<And>a b :: nat. P ((a + 1) * (b + 1))"
+  apply (tactic \<open>
+    let
+      val (x, ctxt) = yield_singleton Variable.add_fixes "x" @{context}
+      (* Note that the pattern order is reversed *)
+      val pat = [
+        Rewrite.For [(x, SOME @{typ nat})],
+        Rewrite.In,
+        Rewrite.Term (@{const plus(nat)} $ Free (x, @{typ nat}) $ @{term "1 :: nat"}, [])]
+      val to = NONE
+    in Rewrite.rewrite_tac ctxt (pat, to) @{thms add.commute} 1 end
+  \<close>)
+  apply (fact assms)
+  done
+
+lemma
+  assumes "Q (\<lambda>b :: int. P (\<lambda>a. a + b) (\<lambda>a. a + b))"
+  shows "Q (\<lambda>b :: int. P (\<lambda>a. a + b) (\<lambda>a. b + a))"
+  apply (tactic \<open>
+    let
+      val (x, ctxt) = yield_singleton Variable.add_fixes "x" @{context}
+      val pat = [
+        Rewrite.Concl,
+        Rewrite.In,
+        Rewrite.Term (Free ("Q", (@{typ "int"} --> TVar (("'b",0), [])) --> @{typ bool})
+          $ Abs ("x", @{typ int}, Rewrite.mk_hole 1 (@{typ int} --> TVar (("'b",0), [])) $ Bound 0), [(x, @{typ int})]),
+        Rewrite.In,
+        Rewrite.Term (@{const plus(int)} $ Free (x, @{typ int}) $ Var (("c", 0), @{typ int}), [])
+        ]
+      val to = NONE
+    in Rewrite.rewrite_tac ctxt (pat, to) @{thms add.commute} 1 end
+  \<close>)
+  apply (fact assms)
+  done
+
+(* There is also conversion-like rewrite function: *)
+ML \<open>
+  val ct = @{cprop "Q (\<lambda>b :: int. P (\<lambda>a. a + b) (\<lambda>a. b + a))"}
+  val (x, ctxt) = yield_singleton Variable.add_fixes "x" @{context}
+  val pat = [
+    Rewrite.Concl,
+    Rewrite.In,
+    Rewrite.Term (Free ("Q", (@{typ "int"} --> TVar (("'b",0), [])) --> @{typ bool})
+      $ Abs ("x", @{typ int}, Rewrite.mk_hole 1 (@{typ int} --> TVar (("'b",0), [])) $ Bound 0), [(x, @{typ int})]),
+    Rewrite.In,
+    Rewrite.Term (@{const plus(int)} $ Free (x, @{typ int}) $ Var (("c", 0), @{typ int}), [])
+    ]
+  val to = NONE
+  val ct_ths = Rewrite.rewrite ctxt (pat, to) @{thms add.commute} ct
+    |> Seq.list_of
+\<close>
+
 end