src/HOL/Tools/Lifting/lifting_term.ML
changeset 47504 aa1b8a59017f
parent 47386 09c5160ba550
child 47698 18202d3d5832
--- a/src/HOL/Tools/Lifting/lifting_term.ML	Tue Apr 17 11:03:08 2012 +0200
+++ b/src/HOL/Tools/Lifting/lifting_term.ML	Tue Apr 17 14:56:38 2012 +0200
@@ -7,8 +7,9 @@
 signature LIFTING_TERM =
 sig
   exception QUOT_THM of typ * typ * Pretty.T
+  exception CHECK_RTY of typ * typ
 
-  val prove_quot_theorem: Proof.context -> typ * typ -> thm
+  val prove_quot_thm: Proof.context -> typ * typ -> thm
 
   val absrep_fun: Proof.context -> typ * typ -> term
 
@@ -30,6 +31,29 @@
 
 exception QUOT_THM_INTERNAL of Pretty.T
 exception QUOT_THM of typ * typ * Pretty.T
+exception CHECK_RTY of typ * typ
+
+(* matches a type pattern with a type *)
+fun match ctxt err ty_pat ty =
+  let
+    val thy = Proof_Context.theory_of ctxt
+  in
+    Sign.typ_match thy (ty_pat, ty) Vartab.empty
+      handle Type.TYPE_MATCH => err ctxt ty_pat ty
+  end
+
+fun equiv_match_err ctxt ty_pat ty =
+  let
+    val ty_pat_str = Syntax.string_of_typ ctxt ty_pat
+    val ty_str = Syntax.string_of_typ ctxt ty
+  in
+    raise QUOT_THM_INTERNAL (Pretty.block
+      [Pretty.str ("The quotient type " ^ quote ty_str),
+       Pretty.brk 1,
+       Pretty.str ("and the quotient type pattern " ^ quote ty_pat_str),
+       Pretty.brk 1,
+       Pretty.str "don't match."])
+  end
 
 fun get_quot_thm ctxt s =
   let
@@ -55,7 +79,11 @@
        Pretty.str "found."]))
   end
 
-exception NOT_IMPL of string
+fun is_id_quot thm = (prop_of thm = prop_of @{thm identity_quotient})
+
+infix 0 MRSL
+
+fun ants MRSL thm = fold (fn rl => fn thm => rl RS thm) ants thm
 
 fun dest_Quotient (Const (@{const_name Quotient}, _) $ rel $ abs $ rep $ cr)
       = (rel, abs, rep, cr)
@@ -92,64 +120,92 @@
   else
     ()
 
-fun quotient_tac ctxt = SUBGOAL (fn (t, i) =>
-  let
-    val (_, abs, _, _) = dest_Quotient (HOLogic.dest_Trueprop t)
-    val (rty, qty) = Term.dest_funT (fastype_of abs)
-  in
-    case (rty, qty) of
-      (Type (s, _), Type (s', _)) =>
+fun prove_schematic_quot_thm ctxt (rty, qty) =
+  (case (rty, qty) of
+    (Type (s, tys), Type (s', tys')) =>
       if s = s'
       then
         let
-          val thm1 = SOME @{thm identity_quotient}
-          val thm2 = try (get_rel_quot_thm ctxt) s
+          val args = map (prove_schematic_quot_thm ctxt) (tys ~~ tys')
         in
-          resolve_tac (map_filter I [thm1, thm2]) i
+          if forall is_id_quot args
+          then
+            @{thm identity_quotient}
+          else
+            args MRSL (get_rel_quot_thm ctxt s)
         end
       else
         let
           val quot_thm = get_quot_thm ctxt s'
-          val (Type (rs, _), _) = quot_thm_rty_qty quot_thm
+          val (Type (rs, rtys), qty_pat) = quot_thm_rty_qty quot_thm
           val _ = check_raw_types (s, rs) s'
+          val qtyenv = match ctxt equiv_match_err qty_pat qty
+          val rtys' = map (Envir.subst_type qtyenv) rtys
+          val args = map (prove_schematic_quot_thm ctxt) (tys ~~ rtys')
         in
-          resolve_tac [quot_thm, quot_thm RSN (2, @{thm Quotient_compose})] i
+          if forall is_id_quot args
+          then
+            quot_thm
+          else
+            let
+              val rel_quot_thm = args MRSL (get_rel_quot_thm ctxt s)
+            in
+              [rel_quot_thm, quot_thm] MRSL @{thm Quotient_compose}
+           end
         end
-    | (_, Type (s, _)) =>
-      let
-        val thm1 = try (get_quot_thm ctxt) s
-        val thm2 = SOME @{thm identity_quotient}
-        val thm3 = try (get_rel_quot_thm ctxt) s
-      in
-        resolve_tac (map_filter I [thm1, thm2, thm3]) i
-      end
-  | _ => rtac @{thm identity_quotient} i
+    | (_, Type (s', tys')) => 
+      (case try (get_quot_thm ctxt) s' of
+        SOME quot_thm => 
+          let
+            val rty_pat = (fst o quot_thm_rty_qty) quot_thm
+          in
+            prove_schematic_quot_thm ctxt (rty_pat, qty)
+          end
+        | NONE =>
+          let
+            val rty_pat = Type (s', map (fn _ => TFree ("a",[])) tys')
+          in
+            prove_schematic_quot_thm ctxt (rty_pat, qty)
+          end)
+    | _ => @{thm identity_quotient})
     handle QUOT_THM_INTERNAL pretty_msg => raise QUOT_THM (rty, qty, pretty_msg)
-  end)
 
-fun prove_quot_theorem ctxt (rty, qty) =
+fun force_qty_type thy qty quot_thm =
   let
-    val relT = [rty, rty] ---> HOLogic.boolT
-    val absT = rty --> qty
-    val repT = qty --> rty
-    val crT = [rty, qty] ---> HOLogic.boolT
-    val quotT = [relT, absT, repT, crT] ---> HOLogic.boolT
-    val rel = Var (("R", 0), relT)
-    val abs = Var (("Abs", 0), absT)
-    val rep = Var (("Rep", 0), repT)
-    val cr = Var (("T", 0), crT)
-    val quot = Const (@{const_name Quotient}, quotT)
-    val goal = HOLogic.Trueprop $ (quot $ rel $ abs $ rep $ cr)
-    val cgoal = Thm.cterm_of (Proof_Context.theory_of ctxt) goal
-    val tac = REPEAT (quotient_tac ctxt 1)
+    val (_, qty_schematic) = quot_thm_rty_qty quot_thm
+    val match_env = Sign.typ_match thy (qty_schematic, qty) Vartab.empty
+    fun prep_ty thy (x, (S, ty)) =
+      (ctyp_of thy (TVar (x, S)), ctyp_of thy ty)
+    val ty_inst = Vartab.fold (cons o (prep_ty thy)) match_env []
   in
-    Goal.prove_internal [] cgoal (K tac)
+    Thm.instantiate (ty_inst, []) quot_thm
+  end
+
+fun check_rty_type ctxt rty quot_thm =
+  let  
+    val thy = Proof_Context.theory_of ctxt
+    val (rty_forced, _) = quot_thm_rty_qty quot_thm
+    val rty_schematic = Logic.type_map (singleton (Variable.polymorphic ctxt)) rty
+    val _ = Sign.typ_match thy (rty_schematic, rty_forced) Vartab.empty
+      handle Type.TYPE_MATCH => raise CHECK_RTY (rty_schematic, rty_forced)
+  in
+    ()
+  end
+
+fun prove_quot_thm ctxt (rty, qty) =
+  let
+    val thy = Proof_Context.theory_of ctxt
+    val schematic_quot_thm = prove_schematic_quot_thm ctxt (rty, qty)
+    val quot_thm = force_qty_type thy qty schematic_quot_thm
+    val _ = check_rty_type ctxt rty quot_thm
+  in
+    quot_thm
   end
 
 fun absrep_fun ctxt (rty, qty) =
-  quot_thm_abs (prove_quot_theorem ctxt (rty, qty))
+  quot_thm_abs (prove_quot_thm ctxt (rty, qty))
 
 fun equiv_relation ctxt (rty, qty) =
-  quot_thm_rel (prove_quot_theorem ctxt (rty, qty))
+  quot_thm_rel (prove_quot_thm ctxt (rty, qty))
 
 end;