src/HOL/Library/reflection.ML
changeset 51723 da12e44b2d65
parent 51717 9e7d1c139569
child 51724 80f9906ede19
--- a/src/HOL/Library/reflection.ML	Sat Apr 20 20:57:49 2013 +0200
+++ b/src/HOL/Library/reflection.ML	Sun Apr 21 10:41:18 2013 +0200
@@ -6,11 +6,17 @@
 
 signature REFLECTION =
 sig
-  val genreify_tac: Proof.context -> thm list -> term option -> int -> tactic
-  val reflection_tac: Proof.context -> thm list -> thm list -> term option -> int -> tactic
+  val gen_reify: Proof.context -> thm list -> term -> thm
+  val gen_reify_tac: Proof.context -> thm list -> term option -> int -> tactic
   val gen_reflection_tac: Proof.context -> (cterm -> thm)
     -> thm list -> thm list -> term option -> int -> tactic
-  val genreif : Proof.context -> thm list -> term -> thm
+  val get_default: Proof.context -> { reification_eqs: thm list, correctness_thms: thm list }
+  val add_reification_eq: attribute
+  val del_reification_eq: attribute
+  val add_correctness_thm: attribute
+  val del_correctness_thm: attribute
+  val default_reify_tac: Proof.context -> thm list -> term option -> int -> tactic
+  val default_reflection_tac: Proof.context -> thm list -> thm list -> term option -> int -> tactic
 end;
 
 structure Reflection : REFLECTION =
@@ -69,11 +75,9 @@
   in  (vs', cong') end;
  (* congs is a list of pairs (P,th) where th is a theorem for *)
         (* [| f p1 = A1; ...; f pn = An|] ==> f (C p1 .. pn) = P *)
+
 val FWD = curry (op OF);
 
-
-exception REIF of string;
-
 fun dest_listT (Type (@{type_name "list"}, [T])) = T;
 
 fun rearrange congs =
@@ -84,7 +88,7 @@
     val (yes,no) = List.partition P congs
   in no @ yes end
 
-fun genreif ctxt raw_eqs t =
+fun gen_reify ctxt eqs t =
   let
     fun index_of t bds =
       let
@@ -154,7 +158,7 @@
                  map (fn ((vn,vi),(tT,t)) => (cert(Var ((vn,vi),tT)), cert t)) invs)
               val ctyenv = map (fn ((vn,vi),(s,ty)) => (certy (TVar((vn,vi),s)), certy ty)) (Vartab.dest tyenv)
             in ((fts ~~ (replicate (length fts) ctxt),
-                 Library.apfst (FWD (Drule.instantiate_normalize (ctyenv, its) cong))), bds)
+                 apfst (FWD (Drule.instantiate_normalize (ctyenv, its) cong))), bds)
             end handle Pattern.MATCH => decomp_genreif da congs (t,ctxt) bds))
       end;
 
@@ -233,40 +237,32 @@
   (* Generic reification procedure: *)
   (* creates all needed cong rules and then just uses the theorem synthesis *)
 
-    fun mk_congs ctxt raw_eqs =
+    fun mk_congs ctxt eqs =
       let
-        val fs = fold_rev (fn eq =>
-                           insert (op =) (eq |> prop_of |> HOLogic.dest_Trueprop
-                           |> HOLogic.dest_eq |> fst |> strip_comb
-                           |> fst)) raw_eqs []
-        val tys = fold_rev (fn f => fold (insert (op =)) (f |> fastype_of |> binder_types |> tl)
-                            ) fs []
-        val (vs, ctxt') = Variable.variant_fixes (replicate (length tys) "vs") ctxt
-        val thy = Proof_Context.theory_of ctxt'
-        val cert = cterm_of thy
-        val vstys = map (fn (t,v) => (t,SOME (cert (Free(v,t)))))
-                    (tys ~~ vs)
-        val is_Var = can dest_Var
-        fun insteq eq vs =
+        val fs = fold_rev (fn eq => insert (op =) (eq |> prop_of |> HOLogic.dest_Trueprop
+          |> HOLogic.dest_eq |> fst |> strip_comb
+          |> fst)) eqs [];
+        val tys = fold_rev (fn f => fold (insert (op =)) (f |> fastype_of |> binder_types |> tl)) fs [];
+        val (vs, ctxt') = Variable.variant_fixes (replicate (length tys) "vs") ctxt;
+        val thy = Proof_Context.theory_of ctxt';
+        val cert = cterm_of thy;
+        val vstys = map (fn (t, v) => (t, SOME (cert (Free (v, t))))) (tys ~~ vs);
+        fun prep_eq eq =
           let
-            val subst = map (fn (v as Var(_, t)) => (cert v, (the o the) (AList.lookup (op =) vstys t)))
-                        (filter is_Var vs)
-          in Thm.instantiate ([],subst) eq
-          end
+            val (_, _ :: vs) = eq |> prop_of |> HOLogic.dest_Trueprop
+              |> HOLogic.dest_eq |> fst |> strip_comb;
+            val subst = map (fn (v as Var (_, t)) =>
+              (cert v, (the o the) (AList.lookup (op =) vstys t))) (filter is_Var vs);
+          in Thm.instantiate ([], subst) eq end;
+        val (ps, congs) = map_split (mk_congeq ctxt' fs o prep_eq) eqs;
+        val bds = AList.make (K ([], [])) tys;
+      in (ps ~~ Variable.export ctxt' ctxt congs, bds) end
 
-        val bds = AList.make (fn _ => ([],[])) tys
-        val eqs = map (fn eq => eq |> prop_of |> HOLogic.dest_Trueprop
-                                   |> HOLogic.dest_eq |> fst |> strip_comb |> snd |> tl
-                                   |> (insteq eq)) raw_eqs
-        val (ps,congs) = split_list (map (mk_congeq ctxt' fs) eqs)
-      in (ps ~~ (Variable.export ctxt' ctxt congs), bds)
-      end
-
-    val (congs, bds) = mk_congs ctxt raw_eqs
+    val (congs, bds) = mk_congs ctxt eqs
     val congs = rearrange congs
-    val (th, bds) = divide_and_conquer' (decomp_genreif (mk_decompatom raw_eqs) congs) (t,ctxt) bds
-    fun is_listVar (Var (_,t)) = can dest_listT t
-         | is_listVar _ = false
+    val (th, bds) = divide_and_conquer' (decomp_genreif (mk_decompatom eqs) congs) (t,ctxt) bds
+    fun is_listVar (Var (_, t)) = can dest_listT t
+      | is_listVar _ = false
     val vars = th |> prop_of |> HOLogic.dest_Trueprop |> HOLogic.dest_eq |> snd
                   |> strip_comb |> snd |> filter is_listVar
     val cert = cterm_of (Proof_Context.theory_of ctxt)
@@ -276,29 +272,28 @@
     val t' = (fst o HOLogic.dest_eq o HOLogic.dest_Trueprop o prop_of) th'
     val th'' = Goal.prove ctxt [] [] (HOLogic.mk_Trueprop (HOLogic.mk_eq (t, t')))
                (fn _ => simp_tac ctxt 1)
-  in FWD trans [th'',th']
+  in FWD trans [th'',th'] end
+
+fun gen_reflect ctxt conv corr_thms eqs t =
+  let
+    val reify_thm = gen_reify ctxt eqs t;
+    fun try_corr thm =
+      SOME (FWD trans [reify_thm, thm RS sym]) handle THM _ => NONE;
+    val thm = case get_first try_corr corr_thms
+     of NONE => error "No suitable correctness theorem found"
+      | SOME thm => thm;
+    val ft = (Thm.dest_arg1 o Thm.dest_arg o Thm.dest_arg o cprop_of) thm;
+    val rth = conv ft;
+  in
+    thm
+    |> simplify (put_simpset HOL_basic_ss ctxt addsimps [rth])
+    |> simplify (put_simpset HOL_basic_ss ctxt addsimps eqs addsimps @{thms nth_Cons_0 nth_Cons_Suc})
   end
 
-
-fun genreflect ctxt conv corr_thms raw_eqs t =
-  let
-    val reifth = genreif ctxt raw_eqs t
-    fun trytrans [] = error "No suitable correctness theorem found"
-      | trytrans (th::ths) =
-           (FWD trans [reifth, th RS sym] handle THM _ => trytrans ths)
-    val th = trytrans corr_thms
-    val ft = (Thm.dest_arg1 o Thm.dest_arg o Thm.dest_arg o cprop_of) th
-    val rth = conv ft
-  in
-    simplify
-      (put_simpset HOL_basic_ss ctxt addsimps raw_eqs addsimps @{thms nth_Cons_0 nth_Cons_Suc})
-      (simplify (put_simpset HOL_basic_ss ctxt addsimps [rth]) th)
-  end
-
-fun genreify_tac ctxt eqs to = SUBGOAL (fn (goal, i) =>
+fun gen_reify_tac ctxt eqs to = SUBGOAL (fn (goal, i) =>
   let
     val t = (case to of NONE => HOLogic.dest_Trueprop goal | SOME x => x)
-    val th = genreif ctxt eqs t RS ssubst
+    val th = gen_reify ctxt eqs t RS ssubst
   in rtac th i end);
 
     (* Reflection calls reification and uses the correctness *)
@@ -306,11 +301,50 @@
 fun gen_reflection_tac ctxt conv corr_thms raw_eqs to = SUBGOAL (fn (goal, i) =>
   let
     val t = (case to of NONE => HOLogic.dest_Trueprop goal | SOME x => x)
-    val th = genreflect ctxt conv corr_thms raw_eqs t RS ssubst
+    val th = gen_reflect ctxt conv corr_thms raw_eqs t RS ssubst
   in rtac th i THEN TRY (rtac TrueI i) end);  (* FIXME THEN_ALL_NEW !? *)
 
-fun reflection_tac ctxt = gen_reflection_tac ctxt
-  (Code_Evaluation.dynamic_conv (Proof_Context.theory_of ctxt));
-  (*FIXME why Code_Evaluation.dynamic_conv?  very specific...*)
+structure Data = Generic_Data
+(
+  type T = thm list * thm list;
+  val empty = ([], []);
+  val extend = I;
+  fun merge ((ths1, rths1), (ths2, rths2)) =
+    (Thm.merge_thms (ths1, ths2), Thm.merge_thms (rths1, rths2));
+);
+
+fun get_default ctxt =
+  let
+    val (reification_eqs, correctness_thms) = Data.get (Context.Proof ctxt);
+  in { reification_eqs = reification_eqs, correctness_thms = correctness_thms } end;
+
+val add_reification_eq = Thm.declaration_attribute (Data.map o apfst o Thm.add_thm);
+val del_reification_eq = Thm.declaration_attribute (Data.map o apfst o Thm.del_thm);
+val add_correctness_thm = Thm.declaration_attribute (Data.map o apsnd o Thm.add_thm);
+val del_correctness_thm = Thm.declaration_attribute (Data.map o apsnd o Thm.del_thm);
+
+val _ = Context.>> (Context.map_theory
+  (Attrib.setup @{binding reify}
+    (Attrib.add_del add_reification_eq del_reification_eq) "declare reification equations" #>
+  Attrib.setup @{binding reflection}
+    (Attrib.add_del add_correctness_thm del_correctness_thm) "declare reflection correctness theorems"));
+
+fun default_reify_tac ctxt user_eqs =
+  let
+    val { reification_eqs = default_eqs, correctness_thms = _ } =
+      get_default ctxt;
+    val eqs = user_eqs @ default_eqs; (*FIXME fold update?*)
+  in gen_reify_tac ctxt eqs end;
+
+fun default_reflection_tac ctxt user_thms user_eqs =
+  let
+    val { reification_eqs = default_eqs, correctness_thms = default_thms } =
+      get_default ctxt;
+    val corr_thms = user_thms @ default_thms; (*FIXME fold update?*)
+    val eqs = user_eqs @ default_eqs; (*FIXME fold update?*)
+    val conv = Code_Evaluation.dynamic_conv (Proof_Context.theory_of ctxt);
+      (*FIXME why Code_Evaluation.dynamic_conv? very specific*)
+  in gen_reflection_tac ctxt conv corr_thms eqs end;
+
 
 end