src/HOL/Tools/reflection.ML
changeset 52275 9b4c04da53b1
parent 52274 35a2668ac3b0
child 52276 329c41438154
--- a/src/HOL/Tools/reflection.ML	Fri May 31 09:30:32 2013 +0200
+++ b/src/HOL/Tools/reflection.ML	Fri May 31 09:30:32 2013 +0200
@@ -6,12 +6,10 @@
 
 signature REFLECTION =
 sig
-  val reify: Proof.context -> thm list -> term -> thm
+  val reify: Proof.context -> thm list -> conv
   val reify_tac: Proof.context -> thm list -> term option -> int -> tactic
-  val reflect: Proof.context -> (cterm -> thm)
-    -> thm list -> thm list -> term -> thm
-  val reflection_tac: Proof.context -> (cterm -> thm)
-    -> thm list -> thm list -> term option -> int -> tactic
+  val reflect: Proof.context -> thm list -> thm list -> conv -> conv
+  val reflection_tac: Proof.context -> thm list -> thm list -> conv -> term option -> int -> tactic
   val get_default: Proof.context -> { reification_eqs: thm list, correctness_thms: thm list }
   val add_reification_eq: attribute
   val del_reification_eq: attribute
@@ -24,10 +22,24 @@
 structure Reflection : REFLECTION =
 struct
 
+fun dest_listT (Type (@{type_name "list"}, [T])) = T;
+
 val FWD = curry (op OF);
 
-fun dest_listT (Type (@{type_name "list"}, [T])) = T;
+fun rewrite_with ctxt eqs = Simplifier.rewrite (put_simpset HOL_basic_ss ctxt addsimps eqs);
+
+val pure_subst = @{lemma "x == y ==> PROP P y ==> PROP P x" by simp}
 
+fun lift_conv ctxt conv some_t = Subgoal.FOCUS (fn { context, concl, ... } =>
+  let
+    val ct = case some_t
+     of NONE => Thm.dest_arg concl
+      | SOME t => Thm.cterm_of (Proof_Context.theory_of ctxt) t
+    val thm = conv ct;
+  in
+    if Thm.is_reflexive thm then no_tac
+    else ALLGOALS (rtac (pure_subst OF [thm]))
+  end) ctxt;
 
 (* Make a congruence rule out of a defining equation for the interpretation
 
@@ -95,7 +107,10 @@
     val (yes, no) = List.partition P congs;
   in no @ yes end;
 
-fun reify ctxt eqs t =
+fun dereify ctxt eqs =
+  rewrite_with ctxt (eqs @ @{thms nth_Cons_0 nth_Cons_Suc});
+
+fun reify ctxt eqs ct =
   let
     fun index_of t bds =
       let
@@ -125,40 +140,43 @@
     (* da is the decomposition for atoms, ie. it returns ([],g) where g
        returns the right instance f (AtC n) = t , where AtC is the Atoms
        constructor and n is the number of the atom corresponding to t *)
-    fun decomp_reify da cgns (t, ctxt) bds =
+    fun decomp_reify da cgns (ct, ctxt) bds =
       let
         val thy = Proof_Context.theory_of ctxt;
         val cert = cterm_of thy;
         val certT = ctyp_of thy;
-        fun tryabsdecomp (s, ctxt) bds =
-          (case s of
+        fun tryabsdecomp (ct, ctxt) bds =
+          (case Thm.term_of ct of
             Abs (_, xT, ta) =>
               let
                 val ([raw_xn], ctxt') = Variable.variant_fixes ["x"] ctxt;
                 val (xn, ta) = Syntax_Trans.variant_abs (raw_xn, xT, ta);  (* FIXME !? *)
                 val x = Free (xn, xT);
+                val cx = cert x;
+                val cta = cert ta;
                 val bds = (case AList.lookup Type.could_unify bds (HOLogic.listT xT) of
                     NONE => error "tryabsdecomp: Type not found in the Environement"
-                  | SOME (bsT, atsT) => AList.update Type.could_unify (HOLogic.listT xT, (x :: bsT, atsT)) bds);
-               in (([(ta, ctxt')],
+                  | SOME (bsT, atsT) => AList.update Type.could_unify (HOLogic.listT xT,
+                      (x :: bsT, atsT)) bds);
+               in (([(cta, ctxt')],
                     fn ([th], bds) =>
-                      (hd (Variable.export ctxt' ctxt [(Thm.forall_intr (cert x) th) COMP allI]),
+                      (hd (Variable.export ctxt' ctxt [(Thm.forall_intr cx th) COMP allI]),
                        let
                          val (bsT, asT) = the (AList.lookup Type.could_unify bds (HOLogic.listT xT));
                        in
-                         AList.update Type.could_unify (HOLogic.listT xT,(tl bsT, asT)) bds
+                         AList.update Type.could_unify (HOLogic.listT xT, (tl bsT, asT)) bds
                        end)),
                    bds)
                end
-           | _ => da (s, ctxt) bds)
+           | _ => da (ct, ctxt) bds)
       in
         (case cgns of
-          [] => tryabsdecomp (t, ctxt) bds
+          [] => tryabsdecomp (ct, ctxt) bds
         | ((vns, cong) :: congs) =>
             (let
               val (tyenv, tmenv) =
                 Pattern.match thy
-                  ((fst o HOLogic.dest_eq o HOLogic.dest_Trueprop) (concl_of cong), t)
+                  ((fst o HOLogic.dest_eq o HOLogic.dest_Trueprop) (concl_of cong), Thm.term_of ct)
                   (Vartab.empty, Vartab.empty);
               val (fnvs, invs) = List.partition (fn ((vn, _),_) => member (op =) vns vn) (Vartab.dest tmenv);
               val (fts, its) =
@@ -166,15 +184,15 @@
                  map (fn ((vn, vi), (tT, t)) => (cert (Var ((vn, vi), tT)), cert t)) invs);
               val ctyenv = map (fn ((vn, vi), (s, ty)) => (certT (TVar((vn, vi), s)), certT ty)) (Vartab.dest tyenv);
             in
-              ((fts ~~ replicate (length fts) ctxt,
+              ((map cert fts ~~ replicate (length fts) ctxt,
                  apfst (FWD (Drule.instantiate_normalize (ctyenv, its) cong))), bds)
-            end handle Pattern.MATCH => decomp_reify da congs (t,ctxt) bds))
+            end handle Pattern.MATCH => decomp_reify da congs (ct, ctxt) bds))
       end;
 
  (* looks for the atoms equation and instantiates it with the right number *)
-    fun mk_decompatom eqs (t, ctxt) bds = (([], fn (_, bds) =>
+    fun mk_decompatom eqs (ct, ctxt) bds = (([], fn (_, bds) =>
       let
-        val tT = fastype_of t;
+        val tT = fastype_of (Thm.term_of ct);
         fun isat eq =
           let
             val rhs = eq |> prop_of |> HOLogic.dest_Trueprop |> HOLogic.dest_eq |> snd;
@@ -184,12 +202,11 @@
               andalso Type.could_unify (fastype_of rhs, tT)
           end;
 
-        fun get_nths t acc =
-          case t of
-            Const(@{const_name "List.nth"}, _) $ vs $ n => insert (fn ((a, _), (b, _)) => a aconv b) (t, (vs, n)) acc
-          | t1 $ t2 => get_nths t1 (get_nths t2 acc)
-          | Abs (_ ,_ ,t') => get_nths t' acc
-          | _ => acc;
+        fun get_nths (t as (Const (@{const_name "List.nth"}, _) $ vs $ n)) =
+              AList.update (op aconv) (t, (vs, n))
+          | get_nths (t1 $ t2) = get_nths t1 #> get_nths t2
+          | get_nths (Abs (_, _, t')) = get_nths t'
+          | get_nths _ = I;
 
         fun tryeqs [] bds = error "Can not find the atoms equation"
           | tryeqs (eq :: eqs) bds = ((
@@ -207,7 +224,7 @@
                 val xns_map = fst (split_list nths) ~~ xns;
                 val subst = map (fn (nt, xn) => (nt, Var ((xn, 0), fastype_of nt))) xns_map;
                 val rhs_P = subst_free subst rhs;
-                val (tyenv, tmenv) = Pattern.match thy (rhs_P, t) (Vartab.empty, Vartab.empty);
+                val (tyenv, tmenv) = Pattern.match thy (rhs_P, Thm.term_of ct) (Vartab.empty, Vartab.empty);
                 val sbst = Envir.subst_term (tyenv, tmenv);
                 val sbsT = Envir.subst_type tyenv;
                 val subst_ty = map (fn (n, (s, t)) =>
@@ -236,7 +253,7 @@
                 val substt =
                   let
                     val ih = Drule.cterm_rule (Thm.instantiate (subst_ty, []));
-                  in map (fn (v, t) => (ih v, ih t)) (subst_ns @ subst_vs @ cts) end;
+                  in map (pairself ih) (subst_ns @ subst_vs @ cts) end;
                 val th = (Drule.instantiate_normalize (subst_ty, substt) eq) RS sym;
               in (hd (Variable.export ctxt'' ctxt [th]), bds) end)
               handle Pattern.MATCH => tryeqs eqs bds)
@@ -268,49 +285,38 @@
 
     val (congs, bds) = mk_congs ctxt eqs;
     val congs = rearrange congs;
-    val (th, bds) = divide_and_conquer' (decomp_reify (mk_decompatom eqs) congs) (t,ctxt) bds;
+    val (th, bds') = apfst mk_eq (divide_and_conquer' (decomp_reify (mk_decompatom eqs) congs) (ct, ctxt) bds);
     fun is_list_var (Var (_, t)) = can dest_listT t
       | is_list_var _ = false;
-    val vars = th |> prop_of |> HOLogic.dest_Trueprop |> HOLogic.dest_eq |> snd
+    val vars = th |> prop_of |> Logic.dest_equals |> snd
       |> strip_comb |> snd |> filter is_list_var;
     val cert = cterm_of (Proof_Context.theory_of ctxt);
-    val cvs = map (fn (v as Var(_, t)) => (cert v,
-      the (AList.lookup Type.could_unify bds t) |> snd |> HOLogic.mk_list (dest_listT t) |> cert)) vars;
-    val th' = Drule.instantiate_normalize ([], cvs) th;
-    val t' = (fst o HOLogic.dest_eq o HOLogic.dest_Trueprop o prop_of) th';
-    val th'' = Goal.prove ctxt [] [] (HOLogic.mk_Trueprop (HOLogic.mk_eq (t, t')))
-      (fn _ => simp_tac ctxt 1)
-  in FWD trans [th'', th'] end;
+    val vs = map (fn v as Var (_, T) =>
+      (v, the (AList.lookup Type.could_unify bds' T) |> snd |> HOLogic.mk_list (dest_listT T))) vars;
+    val th' = Drule.instantiate_normalize ([], (map o pairself) cert vs) th;
+    val th'' = Thm.symmetric (dereify ctxt [] (Thm.lhs_of th'));
+  in Thm.transitive th'' th' end;
 
-fun reflect ctxt conv corr_thms eqs t =
+fun subst_correctness corr_thms ct =
+  Conv.rewrs_conv (map (Thm.symmetric o mk_eq) corr_thms) ct
+    handle CTERM _ => error "No suitable correctness theorem found";
+
+fun first_arg_conv conv =
   let
-    val reify_thm = reify ctxt eqs t;
-    fun try_corr corr_thm =
-      SOME (FWD trans [reify_thm, corr_thm RS sym]) handle THM _ => NONE;
-    val refl_thm = case get_first try_corr corr_thms
-     of NONE => error "No suitable correctness theorem found"
-      | SOME thm => thm;
-    val rhs = (Thm.dest_arg o Thm.dest_arg o cprop_of) refl_thm;
-    val number_of_args = (length o snd o strip_comb o term_of) rhs;
-    val reified = Thm.dest_arg (fold_range (K Thm.dest_fun) (number_of_args - 1) rhs);
-    val evaluated = conv reified;
-  in
-    refl_thm
-    |> simplify (put_simpset HOL_basic_ss ctxt addsimps [evaluated])
-    |> simplify (put_simpset HOL_basic_ss ctxt addsimps eqs addsimps @{thms nth_Cons_0 nth_Cons_Suc})
-  end;
+    fun conv' ct =
+      if can Thm.dest_comb (fst (Thm.dest_comb ct))
+      then Conv.combination_conv conv' Conv.all_conv ct
+      else Conv.combination_conv Conv.all_conv conv ct;
+  in conv' end;
 
-fun tac_of_thm mk_thm to = SUBGOAL (fn (goal, i) =>
-  let
-    val t = (case to of NONE => HOLogic.dest_Trueprop goal | SOME t => t)
-    val thm = mk_thm t RS ssubst;
-  in rtac thm i end);
- 
-fun reify_tac ctxt eqs = tac_of_thm (reify ctxt eqs);
+fun reflect ctxt corr_thms eqs conv =
+  (reify ctxt eqs) then_conv (subst_correctness corr_thms)
+  then_conv (first_arg_conv conv) then_conv (dereify ctxt eqs);
 
-(*Reflection calls reification and uses the correctness theorem assumed to be the head of the list*)
-fun reflection_tac ctxt conv corr_thms eqs =
-  tac_of_thm (reflect ctxt conv corr_thms eqs);
+fun reify_tac ctxt eqs = lift_conv ctxt (reify ctxt eqs);
+
+fun reflection_tac ctxt corr_thms eqs conv =
+  lift_conv ctxt (reflect ctxt corr_thms eqs conv);
 
 structure Data = Generic_Data
 (
@@ -352,7 +358,7 @@
     val eqs = fold Thm.add_thm user_eqs default_eqs; 
     val conv = Code_Evaluation.dynamic_conv (Proof_Context.theory_of ctxt);
       (*FIXME why Code_Evaluation.dynamic_conv? very specific*)
-  in reflection_tac ctxt conv corr_thms eqs end;
+  in reflection_tac ctxt corr_thms eqs conv end;
 
 
 end