prove more algebraic version of functorial properties; retain old properties for convenience
authorhaftmann
Tue, 21 Dec 2010 16:14:46 +0100
changeset 41371 35d2241c169c
parent 41366 ea73e74ec827
child 41372 551eb49a6e91
prove more algebraic version of functorial properties; retain old properties for convenience
src/HOL/Tools/type_lifting.ML
--- a/src/HOL/Tools/type_lifting.ML	Tue Dec 21 15:16:27 2010 +0100
+++ b/src/HOL/Tools/type_lifting.ML	Tue Dec 21 16:14:46 2010 +0100
@@ -17,6 +17,8 @@
 structure Type_Lifting : TYPE_LIFTING =
 struct
 
+val compN = "comp";
+val idN = "id";
 val compositionalityN = "compositionality";
 val identityN = "identity";
 
@@ -25,7 +27,7 @@
 (* bookkeeping *)
 
 type entry = { mapper: string, variances: (sort * (bool * bool)) list,
-  compositionality: thm, identity: thm };
+  comp: thm, id: thm };
 
 structure Data = Theory_Data(
   type T = entry Symtab.table
@@ -74,19 +76,14 @@
 
 (* mapper properties *)
 
-fun make_compositionality_prop variances (tyco, mapper) =
+fun make_comp_prop ctxt variances (tyco, mapper) =
   let
-    fun invents n k nctxt =
-      let
-        val names = Name.invents nctxt n k;
-      in (names, fold Name.declare names nctxt) end;
-    val (((vs3, vs2), vs1), _) = Name.context
-      |> invents Name.aT (length variances)
-      ||>> invents Name.aT (length variances)
-      ||>> invents Name.aT (length variances);
-    fun mk_Ts vs = map2 (fn v => fn (sort, _) => TFree (v, sort))
-      vs variances;
-    val (Ts1, Ts2, Ts3) = (mk_Ts vs1, mk_Ts vs2, mk_Ts vs3);
+    val sorts = map fst variances
+    val (((vs3, vs2), vs1), _) = ctxt
+      |> Variable.invent_types sorts
+      ||>> Variable.invent_types sorts
+      ||>> Variable.invent_types sorts
+    val (Ts1, Ts2, Ts3) = (map TFree vs1, map TFree vs2, map TFree vs3);
     fun mk_argT ((T, T'), (_, (co, contra))) =
       (if co then [(T --> T')] else [])
       @ (if contra then [(T' --> T)] else []);
@@ -94,40 +91,66 @@
       (if co then [false] else []) @ (if contra then [true] else [])) variances;
     val Ts21 = maps mk_argT ((Ts2 ~~ Ts1) ~~ variances);
     val Ts32 = maps mk_argT ((Ts3 ~~ Ts2) ~~ variances);
-    val ((names21, names32), nctxt) = Name.context
+    fun invents n k nctxt =
+      let
+        val names = Name.invents nctxt n k;
+      in (names, fold Name.declare names nctxt) end;
+    val ((names21, names32), nctxt) = Variable.names_of ctxt
       |> invents "f" (length Ts21)
       ||>> invents "f" (length Ts32);
     val T1 = Type (tyco, Ts1);
     val T2 = Type (tyco, Ts2);
     val T3 = Type (tyco, Ts3);
-    val x = Free (the_single (Name.invents nctxt (Long_Name.base_name tyco) 1), T3);
     val (args21, args32) = (names21 ~~ Ts21, names32 ~~ Ts32);
     val args31 = map2 (fn is_contra => fn ((f21, T21), (f32, T32)) =>
       if not is_contra then
-        Abs ("x", domain_type T32, Free (f21, T21) $ (Free (f32, T32) $ Bound 0))
+        HOLogic.mk_comp (Free (f21, T21), Free (f32, T32))
       else
-        Abs ("x", domain_type T21, Free (f32, T32) $ (Free (f21, T21) $ Bound 0))
+        HOLogic.mk_comp (Free (f32, T32), Free (f21, T21))
       ) contras (args21 ~~ args32)
     fun mk_mapper T T' args = list_comb (Const (mapper,
       map fastype_of args ---> T --> T'), args);
-    val lhs = mk_mapper T2 T1 (map Free args21) $
-      (mk_mapper T3 T2 (map Free args32) $ x);
-    val rhs = mk_mapper T3 T1 args31 $ x;
-  in (map Free (args21 @ args32) @ [x], (HOLogic.mk_Trueprop o HOLogic.mk_eq) (lhs, rhs)) end;
+    val lhs = HOLogic.mk_comp (mk_mapper T2 T1 (map Free args21), mk_mapper T3 T2 (map Free args32));
+    val rhs = mk_mapper T3 T1 args31;
+  in fold_rev Logic.all (map Free (args21 @ args32)) ((HOLogic.mk_Trueprop o HOLogic.mk_eq) (lhs, rhs)) end;
 
-fun make_identity_prop variances (tyco, mapper) =
+fun make_id_prop ctxt variances (tyco, mapper) =
   let
-    val vs = Name.invents Name.context Name.aT (length variances);
-    val Ts = map2 (fn v => fn (sort, _) => TFree (v, sort)) vs variances;
+    val (vs, ctxt') = Variable.invent_types (map fst variances) ctxt;
+    val Ts = map TFree vs;
     fun bool_num b = if b then 1 else 0;
     fun mk_argT (T, (_, (co, contra))) =
       replicate (bool_num co + bool_num contra) (T --> T)
     val Ts' = maps mk_argT (Ts ~~ variances)
     val T = Type (tyco, Ts);
-    val x = Free (Long_Name.base_name tyco, T);
     val lhs = list_comb (Const (mapper, Ts' ---> T --> T),
-      map (fn T => Abs ("x", domain_type T, Bound 0)) Ts') $ x;
-  in (x, (HOLogic.mk_Trueprop o HOLogic.mk_eq) (lhs, x)) end;
+      map (HOLogic.mk_id o domain_type) Ts');
+  in (HOLogic.mk_Trueprop o HOLogic.mk_eq) (lhs, HOLogic.mk_id T) end;
+
+val comp_apply = Simpdata.mk_eq @{thm o_apply};
+val id_def = Simpdata.mk_eq @{thm id_def};
+
+fun make_compositionality ctxt thm =
+  let
+    val ((_, [thm']), ctxt') = Variable.import false [thm] ctxt;
+    val thm'' = @{thm fun_cong} OF [thm'];
+    val thm''' =
+      (Conv.fconv_rule o Conv.arg_conv o Conv.arg1_conv o Conv.rewr_conv) comp_apply thm'';
+  in singleton (Variable.export ctxt' ctxt) thm''' end;
+
+fun args_conv k conv =
+  if k <= 0 then Conv.all_conv
+  else Conv.combination_conv (args_conv (k - 1) conv) conv;
+
+fun make_identity ctxt variances thm =
+  let
+    val ((_, [thm']), ctxt') = Variable.import false [thm] ctxt;
+    fun bool_num b = if b then 1 else 0;
+    val num_args = Integer.sum
+      (map (fn (_, (co, contra)) => bool_num co + bool_num contra) variances);
+    val thm'' =
+      (Conv.fconv_rule o Conv.arg_conv o Conv.arg1_conv o args_conv num_args o Conv.rewr_conv) id_def thm';
+  in singleton (Variable.export ctxt' ctxt) thm'' end;
 
 
 (* analyzing and registering mappers *)
@@ -177,7 +200,6 @@
     val (mapper, T) = case prep_term thy raw_t
      of Const cT => cT
       | t => error ("No constant: " ^ Syntax.string_of_term_global thy t);
-    val prfx = the_default (Long_Name.base_name mapper) some_prfx;
     val _ = Type.no_tvars T;
     fun add_tycos (Type (tyco, Ts)) = insert (op =) tyco #> fold add_tycos Ts
       | add_tycos _ = I;
@@ -186,24 +208,29 @@
       else case remove (op =) "fun" tycos
        of [tyco] => tyco
         | _ => error ("Bad number of type constructors: " ^ Syntax.string_of_typ_global thy T);
+    val prfx = the_default (Long_Name.base_name tyco) some_prfx;
     val variances = analyze_variances thy tyco T;
-    val compositionality_prop = uncurry (fold_rev Logic.all)
-      (make_compositionality_prop variances (tyco, mapper));
-    val identity_prop = uncurry Logic.all
-      (make_identity_prop variances (tyco, mapper));
+    val ctxt = ProofContext.init_global thy;
+    val comp_prop = make_comp_prop ctxt variances (tyco, mapper);
+    val id_prop = make_id_prop ctxt variances (tyco, mapper);
     val qualify = Binding.qualify true prfx o Binding.name;
-    fun after_qed [single_compositionality, single_identity] lthy =
+    fun after_qed [single_comp, single_id] lthy =
       lthy
-      |> Local_Theory.note ((qualify compositionalityN, []), single_compositionality)
-      ||>> Local_Theory.note ((qualify identityN, []), single_identity)
-      |-> (fn ((_, [compositionality]), (_, [identity])) =>
-          (Local_Theory.background_theory o Data.map)
+      |> Local_Theory.note ((qualify compN, []), single_comp)
+      ||>> Local_Theory.note ((qualify idN, []), single_id)
+      |-> (fn ((_, [comp]), (_, [id])) => fn lthy =>
+        lthy
+        |> Local_Theory.note ((qualify compositionalityN, []), [make_compositionality lthy comp])
+        |> snd
+        |> Local_Theory.note ((qualify identityN, []), [make_identity lthy variances id])
+        |> snd
+        |> (Local_Theory.background_theory o Data.map)
             (Symtab.update (tyco, { mapper = mapper, variances = variances,
-              compositionality = compositionality, identity = identity })));
+              comp = comp, id = id })));
   in
     thy
     |> Named_Target.theory_init
-    |> Proof.theorem NONE after_qed (map (fn t => [(t, [])]) [compositionality_prop, identity_prop])
+    |> Proof.theorem NONE after_qed (map (fn t => [(t, [])]) [comp_prop, id_prop])
   end
 
 val type_lifting = gen_type_lifting Sign.cert_term;