infer variances of user-given mapper operation; proper thm storing
authorhaftmann
Wed, 17 Nov 2010 17:27:25 +0100
changeset 40594 fae1da97bb5e
parent 40587 5206d19038c7
child 40595 448520778e38
infer variances of user-given mapper operation; proper thm storing
src/HOL/Tools/functorial_mappers.ML
--- a/src/HOL/Tools/functorial_mappers.ML	Wed Nov 17 12:24:58 2010 +0100
+++ b/src/HOL/Tools/functorial_mappers.ML	Wed Nov 17 17:27:25 2010 +0100
@@ -19,6 +19,9 @@
 structure Functorial_Mappers : FUNCTORIAL_MAPPERS =
 struct
 
+val concatenateN = "concatenate";
+val identityN = "identity";
+
 (** functorial mappers and their properties **)
 
 (* bookkeeping *)
@@ -133,8 +136,10 @@
 
 fun register tyco mapper variances raw_concatenate raw_identity thy =
   let
-    val (_, concatenate_prop) = make_concatenate_prop variances (tyco, mapper);
-    val (_, identity_prop) = make_identity_prop variances (tyco, mapper);
+    val concatenate_prop = uncurry Logic.all
+      (make_concatenate_prop variances (tyco, mapper));
+    val identity_prop = uncurry Logic.all
+      (make_identity_prop variances (tyco, mapper));
     val concatenate = Goal.prove_global thy
       (Term.add_free_names concatenate_prop []) [] concatenate_prop
       (K (ALLGOALS (ProofContext.fact_tac [raw_concatenate])));
@@ -148,6 +153,9 @@
         concatenate = concatenate, identity = identity }))
   end;
 
+fun consume eq x [] = (false, [])
+  | consume eq x (ys as z :: zs) = if eq (x, z) then (true, zs) else (false, ys);
+
 fun split_mapper_typ "fun" T =
       let
         val (Ts', T') = strip_type T;
@@ -171,7 +179,19 @@
       handle TYPE _ => bad_typ ();
     val _ = if has_duplicates (eq_fst (op =)) (vs1 @ vs2)
       then bad_typ () else ();
-  in [] end;
+    fun check_variance_pair (var1 as (v1, sort1), var2 as (v2, sort2)) =
+      let
+        val coT = TFree var1 --> TFree var2;
+        val contraT = TFree var2 --> TFree var1;
+        val sort = Sign.inter_sort thy (sort1, sort2);
+      in
+        consume (op =) coT
+        ##>> consume (op =) contraT
+        #>> pair sort
+      end;
+    val (variances, left_variances) = fold_map check_variance_pair (vs1 ~~ vs2) Ts;
+    val _ = if null left_variances then () else bad_typ ();
+  in variances end;
 
 fun gen_type_mapper prep_term raw_t thy =
   let
@@ -187,13 +207,19 @@
        of [tyco] => tyco
         | _ => error ("Bad number of type constructors: " ^ Syntax.string_of_typ_global thy T);
     val variances = analyze_variances thy tyco T;
-    val (_, concatenate_prop) = make_concatenate_prop variances (tyco, mapper);
-    val (_, identity_prop) = make_identity_prop variances (tyco, mapper);
-    fun after_qed [[concatenate], [identity]] lthy =
+    val concatenate_prop = uncurry Logic.all
+      (make_concatenate_prop variances (tyco, mapper));
+    val identity_prop = uncurry Logic.all
+      (make_identity_prop variances (tyco, mapper));
+    val qualify = Binding.qualify true (Long_Name.base_name mapper) o Binding.name;
+    fun after_qed [single_concatenate, single_identity] lthy =
       lthy
-      |> (Local_Theory.background_theory o Data.map)
-          (Symtab.update (tyco, { mapper = mapper, variances = variances,
-            concatenate = concatenate, identity = identity }));
+      |> Local_Theory.note ((qualify concatenateN, []), single_concatenate)
+      ||>> Local_Theory.note ((qualify identityN, []), single_identity)
+      |-> (fn ((_, [concatenate]), (_, [identity])) =>
+          (Local_Theory.background_theory o Data.map)
+            (Symtab.update (tyco, { mapper = mapper, variances = variances,
+              concatenate = concatenate, identity = identity })));
   in
     thy
     |> Named_Target.theory_init