Lifting: support a type variable as a raw type
authorkuncar
Wed, 12 Feb 2014 18:32:55 +0100
changeset 55454 6ea67a791108
parent 55453 0b070d098d1a
child 55455 2cf404a469be
Lifting: support a type variable as a raw type
src/HOL/Tools/Lifting/lifting_term.ML
src/HOL/Tools/Lifting/lifting_util.ML
--- a/src/HOL/Tools/Lifting/lifting_term.ML	Thu Feb 13 13:16:17 2014 +0100
+++ b/src/HOL/Tools/Lifting/lifting_term.ML	Wed Feb 12 18:32:55 2014 +0100
@@ -184,11 +184,13 @@
           rel_quot_thm_prems
       end
 
-fun instantiate_rtys ctxt rty (qty as Type (qty_name, _)) =
+fun rty_is_TVar ctxt qty = (is_TVar o fst o quot_thm_rty_qty o get_quot_thm ctxt o Tname) qty
+
+fun instantiate_rtys ctxt (rty, (qty as Type (qty_name, _))) =
   let
     val quot_thm = get_quot_thm ctxt qty_name
-    val ((rty_pat as Type (_, rty_pat_tys)), qty_pat) = quot_thm_rty_qty quot_thm
-    
+    val (rty_pat, qty_pat) = quot_thm_rty_qty quot_thm
+
     fun inst_rty (Type (s, tys), Type (s', tys')) = 
         if s = s' then Type (s', map inst_rty (tys ~~ tys'))
         else raise QUOT_THM_INTERNAL (Pretty.block 
@@ -206,59 +208,66 @@
       | inst_rty ((TFree _), rty) = rty
       | inst_rty (_, _) = error "check_raw_types: we should not be here"
 
-    val (Type (_, rtys')) = inst_rty (rty_pat, rty)
     val qtyenv = match ctxt equiv_match_err qty_pat qty
   in
-    (rtys', map (Envir.subst_type qtyenv) rty_pat_tys)
+    (inst_rty (rty_pat, rty), Envir.subst_type qtyenv rty_pat)
   end
-  | instantiate_rtys _ _ _ = error "instantiate_rtys: not Type"
+  | instantiate_rtys _ _ = error "instantiate_rtys: not Type"
 
 fun prove_schematic_quot_thm ctxt (rty, qty) =
-  (case (rty, qty) of
-    (Type (s, tys), Type (s', tys')) =>
-      if s = s'
-      then
-        let
-          val args = map (prove_schematic_quot_thm ctxt) (zip_Tvars ctxt s tys tys')
-        in
-          if forall is_id_quot args
-          then
-            @{thm identity_quotient}
-          else
-            args MRSL (get_rel_quot_thm ctxt s)
-        end
-      else
-        let
-          val (rtys, rtys') = instantiate_rtys ctxt rty qty
-          val args = map (prove_schematic_quot_thm ctxt) (rtys ~~ rtys')
-        in
-          if forall is_id_quot args
-          then
-            get_quot_thm ctxt s'
-          else
+  let
+    fun lifting_step (rty, qty) =
+      let
+        val (rty', rtyq) = instantiate_rtys ctxt (rty, qty)
+        val (rty's, rtyqs) = if rty_is_TVar ctxt qty then ([rty'],[rtyq]) 
+          else (Targs rty', Targs rtyq) 
+        val args = map (prove_schematic_quot_thm ctxt) (rty's ~~ rtyqs)
+      in
+        if forall is_id_quot args
+        then
+          get_quot_thm ctxt (Tname qty)
+        else
+          let
+            val quot_thm = get_quot_thm ctxt (Tname qty)
+            val rel_quot_thm = if rty_is_TVar ctxt qty then the_single args else
+              args MRSL (get_rel_quot_thm ctxt (Tname rty))
+          in
+            [rel_quot_thm, quot_thm] MRSL @{thm Quotient_compose}
+         end
+      end
+  in
+    (case (rty, qty) of
+      (Type (s, tys), Type (s', tys')) =>
+        if s = s'
+        then
+          let
+            val args = map (prove_schematic_quot_thm ctxt) (zip_Tvars ctxt s tys tys')
+          in
+            if forall is_id_quot args
+            then
+              @{thm identity_quotient}
+            else
+              args MRSL (get_rel_quot_thm ctxt s)
+          end
+        else
+          lifting_step (rty, qty)
+      | (_, Type (s', tys')) => 
+        (case try (get_quot_thm ctxt) s' of
+          SOME quot_thm => 
             let
-              val quot_thm = get_quot_thm ctxt s'
-              val rel_quot_thm = args MRSL (get_rel_quot_thm ctxt s)
+              val rty_pat = (fst o quot_thm_rty_qty) quot_thm
             in
-              [rel_quot_thm, quot_thm] MRSL @{thm Quotient_compose}
-           end
-        end
-    | (_, Type (s', tys')) => 
-      (case try (get_quot_thm ctxt) s' of
-        SOME quot_thm => 
-          let
-            val rty_pat = (fst o quot_thm_rty_qty) quot_thm
-          in
-            prove_schematic_quot_thm ctxt (rty_pat, qty)
-          end
-        | NONE =>
-          let
-            val rty_pat = Type (s', map (fn _ => TFree ("a",[])) tys')
-          in
-            prove_schematic_quot_thm ctxt (rty_pat, qty)
-          end)
-    | _ => @{thm identity_quotient})
-    handle QUOT_THM_INTERNAL pretty_msg => raise QUOT_THM (rty, qty, pretty_msg)
+              lifting_step (rty_pat, qty)              
+            end
+          | NONE =>
+            let                                               
+              val rty_pat = Type (s', map (fn _ => TFree ("a",[])) tys')
+            in
+              prove_schematic_quot_thm ctxt (rty_pat, qty)
+            end)
+      | _ => @{thm identity_quotient})
+  end
+  handle QUOT_THM_INTERNAL pretty_msg => raise QUOT_THM (rty, qty, pretty_msg)
 
 fun force_qty_type thy qty quot_thm =
   let
@@ -496,12 +505,10 @@
           | [_, trans_rel] =>
             let
               val (rty', qty) = (relation_types o fastype_of) trans_rel
-              val r = (fst o dest_Type) rty' 
-              val q = (fst o dest_Type) qty
             in
-              if r = q then
+              if same_type_constrs (rty', qty) then
                 let
-                  val distr_rules = get_rel_distr_rules ctxt r (head_of tm)
+                  val distr_rules = get_rel_distr_rules ctxt ((fst o dest_Type) rty') (head_of tm)
                   val distr_rule = get_first (prove_extra_assms ctxt ctm) distr_rules
                 in
                   case distr_rule of
@@ -511,7 +518,7 @@
                 end
               else
                 let 
-                  val pcrel_def = get_pcrel_def ctxt q
+                  val pcrel_def = get_pcrel_def ctxt ((fst o dest_Type) qty)
                   val pcrel_const = (head_of o fst o Logic.dest_equals o prop_of) pcrel_def
                 in
                   if same_constants pcrel_const (head_of trans_rel) then
@@ -546,33 +553,41 @@
       let
         val (rty, qty) = (relation_types o fastype_of) (term_of ctm)
       in
-        case (rty, qty) of
-          (Type (r, rargs), Type (q, qargs)) =>
-            if r = q then
-              if forall op= (rargs ~~ qargs) then
-                Conv.all_conv ctm
-              else
-                all_args_conv parametrize_relation_conv ctm
-            else
-              if forall op= (op~~ (instantiate_rtys ctxt rty qty)) then
-                let
-                  val pcr_cr_eq = (Thm.symmetric o mk_meta_eq) (get_pcr_cr_eq ctxt q)
-                in
-                  Conv.rewr_conv pcr_cr_eq ctm
-                end
-                handle QUOT_THM_INTERNAL _ => Conv.all_conv ctm
-              else
-                (let 
-                  val pcrel_def = Thm.symmetric (get_pcrel_def ctxt q)
-                in
-                  (Conv.rewr_conv pcrel_def then_conv all_args_conv parametrize_relation_conv) ctm
-                end
-                handle QUOT_THM_INTERNAL _ => 
-                  (Conv.arg1_conv (all_args_conv parametrize_relation_conv)) ctm)
-          | _ => Conv.all_conv ctm
+        if same_type_constrs (rty, qty) then
+          if forall op= (Targs rty ~~ Targs qty) then
+            Conv.all_conv ctm
+          else
+            all_args_conv parametrize_relation_conv ctm
+        else
+          if is_Type qty then
+            let
+              val q = (fst o dest_Type) qty
+            in
+              let
+                val (rty', rtyq) = instantiate_rtys ctxt (rty, qty)
+                val (rty's, rtyqs) = if rty_is_TVar ctxt qty then ([rty'],[rtyq]) 
+                  else (Targs rty', Targs rtyq)
+              in
+                if forall op= (rty's ~~ rtyqs) then
+                  let
+                    val pcr_cr_eq = (Thm.symmetric o mk_meta_eq) (get_pcr_cr_eq ctxt q)
+                  in      
+                    Conv.rewr_conv pcr_cr_eq ctm
+                  end
+                  handle QUOT_THM_INTERNAL _ => Conv.all_conv ctm
+                else
+                  (let 
+                    val pcrel_def = Thm.symmetric (get_pcrel_def ctxt q)
+                  in
+                    (Conv.rewr_conv pcrel_def then_conv all_args_conv parametrize_relation_conv) ctm
+                  end
+                  handle QUOT_THM_INTERNAL _ => 
+                    (Conv.arg1_conv (all_args_conv parametrize_relation_conv)) ctm)
+              end  
+            end
+          else Conv.all_conv ctm
       end
     in
       Conv.fconv_rule (HOLogic.Trueprop_conv (Conv.fun2_conv parametrize_relation_conv)) thm
     end
-
 end
--- a/src/HOL/Tools/Lifting/lifting_util.ML	Thu Feb 13 13:16:17 2014 +0100
+++ b/src/HOL/Tools/Lifting/lifting_util.ML	Wed Feb 12 18:32:55 2014 +0100
@@ -24,6 +24,9 @@
   val strip_args: int -> term -> term
   val all_args_conv: conv -> conv
   val is_Type: typ -> bool
+  val same_type_constrs: typ * typ -> bool
+  val Targs: typ -> typ list
+  val Tname: typ -> string
   val is_fun_rel: term -> bool
   val relation_types: typ -> typ * typ
   val mk_HOL_eq: thm -> thm
@@ -98,6 +101,15 @@
 fun is_Type (Type _) = true
   | is_Type _ = false
 
+fun same_type_constrs (Type (r, _), Type (q, _)) = (r = q)
+  | same_type_constrs _ = false
+
+fun Targs (Type (_, args)) = args
+  | Targs _ = []
+
+fun Tname (Type (name, _)) = name
+  | Tname _ = ""
+
 fun is_fun_rel (Const (@{const_name "fun_rel"}, _) $ _ $ _) = true
   | is_fun_rel _ = false