merged
authorblanchet
Tue, 17 Apr 2012 15:25:43 +0200
changeset 47507 d52da3e7aa4c
parent 47506 da72e05849ef (current diff)
parent 47504 aa1b8a59017f (diff)
child 47508 85c6268b4071
merged
--- a/src/HOL/Tools/Lifting/lifting_def.ML	Tue Apr 17 13:54:31 2012 +0200
+++ b/src/HOL/Tools/Lifting/lifting_def.ML	Tue Apr 17 15:25:43 2012 +0200
@@ -6,8 +6,6 @@
 
 signature LIFTING_DEF =
 sig
-  exception FORCE_RTY of typ * term
-
   val add_lift_def:
     (binding * mixfix) -> typ -> term -> thm -> local_theory -> local_theory
 
@@ -28,8 +26,6 @@
 
 fun ants MRSL thm = fold (fn rl => fn thm => rl RS thm) ants thm
 
-exception FORCE_RTY of typ * term
-
 fun get_body_types (Type ("fun", [_, U]), Type ("fun", [_, V])) = get_body_types (U, V)
   | get_body_types (U, V)  = (U, V)
 
@@ -42,7 +38,6 @@
     val rhs_schematic = singleton (Variable.polymorphic ctxt) rhs
     val rty_schematic = fastype_of rhs_schematic
     val match = Sign.typ_match thy (rty_schematic, rty) Vartab.empty
-      handle Type.TYPE_MATCH => raise FORCE_RTY (rty, rhs)
   in
     Envir.subst_term_types match rhs_schematic
   end
@@ -92,7 +87,7 @@
     val ty_args = get_binder_types (rty, qty)
     fun disch_arg args_ty thm = 
       let
-        val quot_thm = Lifting_Term.prove_quot_theorem ctxt args_ty
+        val quot_thm = Lifting_Term.prove_quot_thm ctxt args_ty
       in
         [quot_thm, thm] MRSL @{thm apply_rsp''}
       end
@@ -114,7 +109,7 @@
 fun generate_code_cert ctxt def_thm rsp_thm (rty, qty) =
   let
     val thy = Proof_Context.theory_of ctxt
-    val quot_thm = Lifting_Term.prove_quot_theorem ctxt (get_body_types (rty, qty))
+    val quot_thm = Lifting_Term.prove_quot_thm ctxt (get_body_types (rty, qty))
     val fun_rel = prove_rel ctxt rsp_thm (rty, qty)
     val abs_rep_thm = [quot_thm, fun_rel] MRSL @{thm Quotient_rep_abs}
     val abs_rep_eq = 
@@ -134,7 +129,7 @@
 
 fun define_code_cert code_eqn_thm_name def_thm rsp_thm (rty, qty) lthy = 
   let
-    val quot_thm = Lifting_Term.prove_quot_theorem lthy (get_body_types (rty, qty))
+    val quot_thm = Lifting_Term.prove_quot_thm lthy (get_body_types (rty, qty))
   in
     if can_generate_code_cert quot_thm then
       let
@@ -170,7 +165,7 @@
 fun add_lift_def var qty rhs rsp_thm lthy =
   let
     val rty = fastype_of rhs
-    val quotient_thm = Lifting_Term.prove_quot_theorem lthy (rty, qty)
+    val quotient_thm = Lifting_Term.prove_quot_thm lthy (rty, qty)
     val absrep_trm =  Lifting_Term.quot_thm_abs quotient_thm
     val rty_forced = (domain_type o fastype_of) absrep_trm
     val forced_rhs = force_rty_type lthy rty_forced rhs
@@ -246,9 +241,8 @@
 
 fun lift_def_cmd (raw_var, rhs_raw) lthy =
   let
-    val ((binding, SOME qty, mx), ctxt) = yield_singleton Proof_Context.read_vars raw_var lthy 
-    val rhs' = (Syntax.check_term ctxt o Syntax.parse_term ctxt) rhs_raw
-    val rhs = singleton (Variable.polymorphic ctxt) rhs'
+    val ((binding, SOME qty, mx), lthy') = yield_singleton Proof_Context.read_vars raw_var lthy 
+    val rhs = (Syntax.check_term lthy' o Syntax.parse_term lthy') rhs_raw
  
     fun try_to_prove_refl thm = 
       let
@@ -265,12 +259,11 @@
           | _ => NONE
       end
 
-    val quot_thm = Lifting_Term.prove_quot_theorem lthy (fastype_of rhs, qty)
-    val rsp_rel = Lifting_Term.quot_thm_rel quot_thm
+    val rsp_rel = Lifting_Term.equiv_relation lthy' (fastype_of rhs, qty)
     val rty_forced = (domain_type o fastype_of) rsp_rel;
-    val forced_rhs = force_rty_type ctxt rty_forced rhs;
+    val forced_rhs = force_rty_type lthy' rty_forced rhs;
     val internal_rsp_tm = HOLogic.mk_Trueprop (rsp_rel $ forced_rhs $ forced_rhs)
-    val readable_rsp_thm_eq = mk_readable_rsp_thm_eq internal_rsp_tm lthy
+    val readable_rsp_thm_eq = mk_readable_rsp_thm_eq internal_rsp_tm lthy'
     val maybe_proven_rsp_thm = try_to_prove_refl readable_rsp_thm_eq
     val (readable_rsp_tm, _) = Logic.dest_implies (prop_of readable_rsp_thm_eq)
   
@@ -279,7 +272,7 @@
         val internal_rsp_thm =
           case thm_list of
             [] => the maybe_proven_rsp_thm
-          | [[thm]] => Goal.prove ctxt [] [] internal_rsp_tm 
+          | [[thm]] => Goal.prove lthy [] [] internal_rsp_tm 
             (fn _ => rtac readable_rsp_thm_eq 1 THEN Proof_Context.fact_tac [thm] 1)
       in
         add_lift_def (binding, mx) qty rhs internal_rsp_thm lthy
@@ -287,8 +280,8 @@
 
   in
     case maybe_proven_rsp_thm of
-      SOME _ => Proof.theorem NONE after_qed [] ctxt
-      | NONE =>  Proof.theorem NONE after_qed [[(readable_rsp_tm,[])]] ctxt
+      SOME _ => Proof.theorem NONE after_qed [] lthy'
+      | NONE =>  Proof.theorem NONE after_qed [[(readable_rsp_tm,[])]] lthy'
   end
 
 fun quot_thm_err ctxt (rty, qty) pretty_msg =
@@ -306,21 +299,23 @@
     error error_msg
   end
 
-fun force_rty_err ctxt rty rhs =
+fun check_rty_err ctxt (rty_schematic, rty_forced) (raw_var, rhs_raw) =
   let
+    val (_, ctxt') = yield_singleton Proof_Context.read_vars raw_var ctxt 
+    val rhs = (Syntax.check_term ctxt' o Syntax.parse_term ctxt') rhs_raw
     val error_msg = cat_lines
        ["Lifting failed for the following term:",
         Pretty.string_of (Pretty.block
          [Pretty.str "Term:", Pretty.brk 2, Syntax.pretty_term ctxt rhs]),
         Pretty.string_of (Pretty.block
-         [Pretty.str "Type:", Pretty.brk 2, Syntax.pretty_typ ctxt (fastype_of rhs)]),
+         [Pretty.str "Type:", Pretty.brk 2, Syntax.pretty_typ ctxt rty_schematic]),
         "",
         (Pretty.string_of (Pretty.block
          [Pretty.str "Reason:", 
           Pretty.brk 2, 
           Pretty.str "The type of the term cannot be instancied to",
           Pretty.brk 1,
-          Pretty.quote (Syntax.pretty_typ ctxt rty),
+          Pretty.quote (Syntax.pretty_typ ctxt rty_forced),
           Pretty.str "."]))]
     in
       error error_msg
@@ -329,7 +324,8 @@
 fun lift_def_cmd_with_err_handling (raw_var, rhs_raw) lthy =
   (lift_def_cmd (raw_var, rhs_raw) lthy
     handle Lifting_Term.QUOT_THM (rty, qty, msg) => quot_thm_err lthy (rty, qty) msg)
-    handle FORCE_RTY (rty, rhs) => force_rty_err lthy rty rhs
+    handle Lifting_Term.CHECK_RTY (rty_schematic, rty_forced) => 
+      check_rty_err lthy (rty_schematic, rty_forced) (raw_var, rhs_raw)
 
 (* parser and command *)
 val liftdef_parser =
--- a/src/HOL/Tools/Lifting/lifting_term.ML	Tue Apr 17 13:54:31 2012 +0200
+++ b/src/HOL/Tools/Lifting/lifting_term.ML	Tue Apr 17 15:25:43 2012 +0200
@@ -7,8 +7,9 @@
 signature LIFTING_TERM =
 sig
   exception QUOT_THM of typ * typ * Pretty.T
+  exception CHECK_RTY of typ * typ
 
-  val prove_quot_theorem: Proof.context -> typ * typ -> thm
+  val prove_quot_thm: Proof.context -> typ * typ -> thm
 
   val absrep_fun: Proof.context -> typ * typ -> term
 
@@ -30,6 +31,29 @@
 
 exception QUOT_THM_INTERNAL of Pretty.T
 exception QUOT_THM of typ * typ * Pretty.T
+exception CHECK_RTY of typ * typ
+
+(* matches a type pattern with a type *)
+fun match ctxt err ty_pat ty =
+  let
+    val thy = Proof_Context.theory_of ctxt
+  in
+    Sign.typ_match thy (ty_pat, ty) Vartab.empty
+      handle Type.TYPE_MATCH => err ctxt ty_pat ty
+  end
+
+fun equiv_match_err ctxt ty_pat ty =
+  let
+    val ty_pat_str = Syntax.string_of_typ ctxt ty_pat
+    val ty_str = Syntax.string_of_typ ctxt ty
+  in
+    raise QUOT_THM_INTERNAL (Pretty.block
+      [Pretty.str ("The quotient type " ^ quote ty_str),
+       Pretty.brk 1,
+       Pretty.str ("and the quotient type pattern " ^ quote ty_pat_str),
+       Pretty.brk 1,
+       Pretty.str "don't match."])
+  end
 
 fun get_quot_thm ctxt s =
   let
@@ -55,7 +79,11 @@
        Pretty.str "found."]))
   end
 
-exception NOT_IMPL of string
+fun is_id_quot thm = (prop_of thm = prop_of @{thm identity_quotient})
+
+infix 0 MRSL
+
+fun ants MRSL thm = fold (fn rl => fn thm => rl RS thm) ants thm
 
 fun dest_Quotient (Const (@{const_name Quotient}, _) $ rel $ abs $ rep $ cr)
       = (rel, abs, rep, cr)
@@ -92,64 +120,92 @@
   else
     ()
 
-fun quotient_tac ctxt = SUBGOAL (fn (t, i) =>
-  let
-    val (_, abs, _, _) = dest_Quotient (HOLogic.dest_Trueprop t)
-    val (rty, qty) = Term.dest_funT (fastype_of abs)
-  in
-    case (rty, qty) of
-      (Type (s, _), Type (s', _)) =>
+fun prove_schematic_quot_thm ctxt (rty, qty) =
+  (case (rty, qty) of
+    (Type (s, tys), Type (s', tys')) =>
       if s = s'
       then
         let
-          val thm1 = SOME @{thm identity_quotient}
-          val thm2 = try (get_rel_quot_thm ctxt) s
+          val args = map (prove_schematic_quot_thm ctxt) (tys ~~ tys')
         in
-          resolve_tac (map_filter I [thm1, thm2]) i
+          if forall is_id_quot args
+          then
+            @{thm identity_quotient}
+          else
+            args MRSL (get_rel_quot_thm ctxt s)
         end
       else
         let
           val quot_thm = get_quot_thm ctxt s'
-          val (Type (rs, _), _) = quot_thm_rty_qty quot_thm
+          val (Type (rs, rtys), qty_pat) = quot_thm_rty_qty quot_thm
           val _ = check_raw_types (s, rs) s'
+          val qtyenv = match ctxt equiv_match_err qty_pat qty
+          val rtys' = map (Envir.subst_type qtyenv) rtys
+          val args = map (prove_schematic_quot_thm ctxt) (tys ~~ rtys')
         in
-          resolve_tac [quot_thm, quot_thm RSN (2, @{thm Quotient_compose})] i
+          if forall is_id_quot args
+          then
+            quot_thm
+          else
+            let
+              val rel_quot_thm = args MRSL (get_rel_quot_thm ctxt s)
+            in
+              [rel_quot_thm, quot_thm] MRSL @{thm Quotient_compose}
+           end
         end
-    | (_, Type (s, _)) =>
-      let
-        val thm1 = try (get_quot_thm ctxt) s
-        val thm2 = SOME @{thm identity_quotient}
-        val thm3 = try (get_rel_quot_thm ctxt) s
-      in
-        resolve_tac (map_filter I [thm1, thm2, thm3]) i
-      end
-  | _ => rtac @{thm identity_quotient} i
+    | (_, Type (s', tys')) => 
+      (case try (get_quot_thm ctxt) s' of
+        SOME quot_thm => 
+          let
+            val rty_pat = (fst o quot_thm_rty_qty) quot_thm
+          in
+            prove_schematic_quot_thm ctxt (rty_pat, qty)
+          end
+        | NONE =>
+          let
+            val rty_pat = Type (s', map (fn _ => TFree ("a",[])) tys')
+          in
+            prove_schematic_quot_thm ctxt (rty_pat, qty)
+          end)
+    | _ => @{thm identity_quotient})
     handle QUOT_THM_INTERNAL pretty_msg => raise QUOT_THM (rty, qty, pretty_msg)
-  end)
 
-fun prove_quot_theorem ctxt (rty, qty) =
+fun force_qty_type thy qty quot_thm =
   let
-    val relT = [rty, rty] ---> HOLogic.boolT
-    val absT = rty --> qty
-    val repT = qty --> rty
-    val crT = [rty, qty] ---> HOLogic.boolT
-    val quotT = [relT, absT, repT, crT] ---> HOLogic.boolT
-    val rel = Var (("R", 0), relT)
-    val abs = Var (("Abs", 0), absT)
-    val rep = Var (("Rep", 0), repT)
-    val cr = Var (("T", 0), crT)
-    val quot = Const (@{const_name Quotient}, quotT)
-    val goal = HOLogic.Trueprop $ (quot $ rel $ abs $ rep $ cr)
-    val cgoal = Thm.cterm_of (Proof_Context.theory_of ctxt) goal
-    val tac = REPEAT (quotient_tac ctxt 1)
+    val (_, qty_schematic) = quot_thm_rty_qty quot_thm
+    val match_env = Sign.typ_match thy (qty_schematic, qty) Vartab.empty
+    fun prep_ty thy (x, (S, ty)) =
+      (ctyp_of thy (TVar (x, S)), ctyp_of thy ty)
+    val ty_inst = Vartab.fold (cons o (prep_ty thy)) match_env []
   in
-    Goal.prove_internal [] cgoal (K tac)
+    Thm.instantiate (ty_inst, []) quot_thm
+  end
+
+fun check_rty_type ctxt rty quot_thm =
+  let  
+    val thy = Proof_Context.theory_of ctxt
+    val (rty_forced, _) = quot_thm_rty_qty quot_thm
+    val rty_schematic = Logic.type_map (singleton (Variable.polymorphic ctxt)) rty
+    val _ = Sign.typ_match thy (rty_schematic, rty_forced) Vartab.empty
+      handle Type.TYPE_MATCH => raise CHECK_RTY (rty_schematic, rty_forced)
+  in
+    ()
+  end
+
+fun prove_quot_thm ctxt (rty, qty) =
+  let
+    val thy = Proof_Context.theory_of ctxt
+    val schematic_quot_thm = prove_schematic_quot_thm ctxt (rty, qty)
+    val quot_thm = force_qty_type thy qty schematic_quot_thm
+    val _ = check_rty_type ctxt rty quot_thm
+  in
+    quot_thm
   end
 
 fun absrep_fun ctxt (rty, qty) =
-  quot_thm_abs (prove_quot_theorem ctxt (rty, qty))
+  quot_thm_abs (prove_quot_thm ctxt (rty, qty))
 
 fun equiv_relation ctxt (rty, qty) =
-  quot_thm_rel (prove_quot_theorem ctxt (rty, qty))
+  quot_thm_rel (prove_quot_thm ctxt (rty, qty))
 
 end;
--- a/src/HOL/Tools/Quotient/quotient_def.ML	Tue Apr 17 13:54:31 2012 +0200
+++ b/src/HOL/Tools/Quotient/quotient_def.ML	Tue Apr 17 15:25:43 2012 +0200
@@ -82,7 +82,7 @@
     val ty_args = get_binder_types (rty, qty)
     fun disch_arg args_ty thm = 
       let
-        val quot_thm = Quotient_Term.prove_quot_theorem ctxt args_ty
+        val quot_thm = Quotient_Term.prove_quot_thm ctxt args_ty
       in
         [quot_thm, thm] MRSL @{thm apply_rspQ3''}
       end
@@ -97,7 +97,7 @@
 
 fun generate_code_cert ctxt def_thm rsp_thm (rty, qty) =
   let
-    val quot_thm = Quotient_Term.prove_quot_theorem ctxt (get_body_types (rty, qty))
+    val quot_thm = Quotient_Term.prove_quot_thm ctxt (get_body_types (rty, qty))
     val fun_rel = prove_rel ctxt rsp_thm (rty, qty)
     val abs_rep_thm = [quot_thm, fun_rel] MRSL @{thm Quotient3_rep_abs}
     val abs_rep_eq = 
@@ -117,7 +117,7 @@
 
 fun define_code_cert code_eqn_thm_name def_thm rsp_thm (rty, qty) lthy = 
   let
-    val quot_thm = Quotient_Term.prove_quot_theorem lthy (get_body_types (rty, qty))
+    val quot_thm = Quotient_Term.prove_quot_thm lthy (get_body_types (rty, qty))
   in
     if Quotient_Type.can_generate_code_cert quot_thm then
       let
--- a/src/HOL/Tools/Quotient/quotient_term.ML	Tue Apr 17 13:54:31 2012 +0200
+++ b/src/HOL/Tools/Quotient/quotient_term.ML	Tue Apr 17 15:25:43 2012 +0200
@@ -21,7 +21,7 @@
   val equiv_relation_chk: Proof.context -> typ * typ -> term
 
   val get_rel_from_quot_thm: thm -> term
-  val prove_quot_theorem: Proof.context -> typ * typ -> thm
+  val prove_quot_thm: Proof.context -> typ * typ -> thm
 
   val regularize_trm: Proof.context -> term * term -> term
   val regularize_trm_chk: Proof.context -> term * term -> term
@@ -379,7 +379,7 @@
     else raise NOT_IMPL "nested quotients: not implemented yet"
   end
 
-fun prove_quot_theorem ctxt (rty, qty) =
+fun prove_quot_thm ctxt (rty, qty) =
   if rty = qty
   then @{thm identity_quotient3}
   else
@@ -388,7 +388,7 @@
         if s = s'
         then
           let
-            val args = map (prove_quot_theorem ctxt) (tys ~~ tys')
+            val args = map (prove_quot_thm ctxt) (tys ~~ tys')
           in
             args MRSL (get_rel_quot_thm ctxt s)
           end
@@ -397,7 +397,7 @@
             val (Type (_, rtys), qty_pat) = get_rty_qty ctxt s'
             val qtyenv = match ctxt equiv_match_err qty_pat qty
             val rtys' = map (Envir.subst_type qtyenv) rtys
-            val args = map (prove_quot_theorem ctxt) (tys ~~ rtys')
+            val args = map (prove_quot_thm ctxt) (tys ~~ rtys')
             val quot_thm = get_quot_thm ctxt s'
           in
             if forall is_id_quot args