merged
authorhaftmann
Wed, 22 Dec 2010 22:21:14 +0100
changeset 41391 b71bcdb568c0
parent 41386 9400026a82f5 (current diff)
parent 41390 207ee8f8a19c (diff)
child 41392 d1ff42a70f77
merged
--- a/src/HOL/Tools/type_lifting.ML	Wed Dec 22 20:08:40 2010 +0100
+++ b/src/HOL/Tools/type_lifting.ML	Wed Dec 22 22:21:14 2010 +0100
@@ -6,12 +6,12 @@
 
 signature TYPE_LIFTING =
 sig
-  val find_atomic: theory -> typ -> (typ * (bool * bool)) list
-  val construct_mapper: theory -> (string * bool -> term)
+  val find_atomic: Proof.context -> typ -> (typ * (bool * bool)) list
+  val construct_mapper: Proof.context -> (string * bool -> term)
     -> bool -> typ -> typ -> term
-  val type_lifting: string option -> term -> theory -> Proof.state
+  val type_lifting: string option -> term -> local_theory -> Proof.state
   type entry
-  val entries: theory -> entry Symtab.table
+  val entries: Proof.context -> entry list Symtab.table
 end;
 
 structure Type_Lifting : TYPE_LIFTING =
@@ -26,24 +26,27 @@
 
 (* bookkeeping *)
 
-type entry = { mapper: string, variances: (sort * (bool * bool)) list,
+type entry = { mapper: term, variances: (sort * (bool * bool)) list,
   comp: thm, id: thm };
 
-structure Data = Theory_Data(
-  type T = entry Symtab.table
+structure Data = Generic_Data(
+  type T = entry list Symtab.table
   val empty = Symtab.empty
   fun merge (xy : T * T) = Symtab.merge (K true) xy
   val extend = I
 );
 
-val entries = Data.get;
+val entries = Data.get o Context.Proof;
 
 
 (* type analysis *)
 
-fun find_atomic thy T =
+fun term_with_typ ctxt T t = Envir.subst_term_types
+  (Type.typ_match (ProofContext.tsig_of ctxt) (fastype_of t, T) Vartab.empty) t;
+
+fun find_atomic ctxt T =
   let
-    val variances_of = Option.map #variances o Symtab.lookup (Data.get thy);
+    val variances_of = Option.map #variances o try hd o Symtab.lookup_list (entries ctxt);
     fun add_variance is_contra T =
       AList.map_default (op =) (T, (false, false))
         ((if is_contra then apsnd else apfst) (K true));
@@ -56,26 +59,29 @@
       | analyze is_contra T = add_variance is_contra T;
   in analyze false T [] end;
 
-fun construct_mapper thy atomic =
+fun construct_mapper ctxt atomic =
   let
-    val lookup = the o Symtab.lookup (Data.get thy);
+    val lookup = hd o Symtab.lookup_list (entries ctxt);
     fun constructs is_contra (_, (co, contra)) T T' =
       (if co then [construct is_contra T T'] else [])
       @ (if contra then [construct (not is_contra) T T'] else [])
     and construct is_contra (T as Type (tyco, Ts)) (T' as Type (_, Ts')) =
           let
-            val { mapper, variances, ... } = lookup tyco;
+            val { mapper = raw_mapper, variances, ... } = lookup tyco;
             val args = maps (fn (arg_pattern, (T, T')) =>
               constructs is_contra arg_pattern T T')
                 (variances ~~ (Ts ~~ Ts'));
             val (U, U') = if is_contra then (T', T) else (T, T');
-          in list_comb (Const (mapper, map fastype_of args ---> U --> U'), args) end
+            val mapper = term_with_typ ctxt (map fastype_of args ---> U --> U') raw_mapper;
+          in list_comb (mapper, args) end
       | construct is_contra (TFree (v, _)) (TFree _) = atomic (v, is_contra);
   in construct end;
 
 
 (* mapper properties *)
 
+val compositionality_ss = Simplifier.add_simp (Simpdata.mk_eq @{thm comp_def}) HOL_basic_ss;
+
 fun make_comp_prop ctxt variances (tyco, mapper) =
   let
     val sorts = map fst variances
@@ -108,11 +114,22 @@
       else
         HOLogic.mk_comp (Free (f32, T32), Free (f21, T21))
       ) contras (args21 ~~ args32)
-    fun mk_mapper T T' args = list_comb (Const (mapper,
-      map fastype_of args ---> T --> T'), args);
-    val lhs = HOLogic.mk_comp (mk_mapper T2 T1 (map Free args21), mk_mapper T3 T2 (map Free args32));
-    val rhs = mk_mapper T3 T1 args31;
-  in fold_rev Logic.all (map Free (args21 @ args32)) ((HOLogic.mk_Trueprop o HOLogic.mk_eq) (lhs, rhs)) end;
+    fun mk_mapper T T' args = list_comb (term_with_typ ctxt (map fastype_of args ---> T --> T') mapper, args);
+    val mapper21 = mk_mapper T2 T1 (map Free args21);
+    val mapper32 = mk_mapper T3 T2 (map Free args32);
+    val mapper31 = mk_mapper T3 T1 args31;
+    val eq1 = (HOLogic.mk_Trueprop o HOLogic.mk_eq) (HOLogic.mk_comp (mapper21, mapper32), mapper31);
+    val x = Free (the_single (Name.invents nctxt (Long_Name.base_name tyco) 1), T3)
+    val eq2 = (HOLogic.mk_Trueprop o HOLogic.mk_eq) (mapper21 $ (mapper32 $ x), mapper31 $ x);
+    val comp_prop = fold_rev Logic.all (map Free (args21 @ args32)) eq1;
+    val compositionality_prop = fold_rev Logic.all (map Free (args21 @ args32) @ [x]) eq2;
+    fun prove_compositionality ctxt comp_thm = Skip_Proof.prove ctxt [] [] compositionality_prop
+      (K (ALLGOALS (Method.insert_tac [@{thm fun_cong} OF [comp_thm]]
+        THEN' Simplifier.asm_lr_simp_tac compositionality_ss
+        THEN_ALL_NEW (Goal.assume_rule_tac ctxt))));
+  in (comp_prop, prove_compositionality) end;
+
+val identity_ss = Simplifier.add_simp (Simpdata.mk_eq @{thm id_def}) HOL_basic_ss;
 
 fun make_id_prop ctxt variances (tyco, mapper) =
   let
@@ -120,37 +137,17 @@
     val Ts = map TFree vs;
     fun bool_num b = if b then 1 else 0;
     fun mk_argT (T, (_, (co, contra))) =
-      replicate (bool_num co + bool_num contra) (T --> T)
-    val Ts' = maps mk_argT (Ts ~~ variances)
+      replicate (bool_num co + bool_num contra) T
+    val arg_Ts = maps mk_argT (Ts ~~ variances)
     val T = Type (tyco, Ts);
-    val lhs = list_comb (Const (mapper, Ts' ---> T --> T),
-      map (HOLogic.id_const o domain_type) Ts');
-  in (HOLogic.mk_Trueprop o HOLogic.mk_eq) (lhs, HOLogic.id_const T) end;
-
-val comp_apply = Simpdata.mk_eq @{thm o_apply};
-val id_def = Simpdata.mk_eq @{thm id_def};
-
-fun make_compositionality ctxt thm =
-  let
-    val ((_, [thm']), ctxt') = Variable.import false [thm] ctxt;
-    val thm'' = @{thm fun_cong} OF [thm'];
-    val thm''' =
-      (Conv.fconv_rule o Conv.arg_conv o Conv.arg1_conv o Conv.rewr_conv) comp_apply thm'';
-  in singleton (Variable.export ctxt' ctxt) thm''' end;
-
-fun args_conv k conv =
-  if k <= 0 then Conv.all_conv
-  else Conv.combination_conv (args_conv (k - 1) conv) conv;
-
-fun make_identity ctxt variances thm =
-  let
-    val ((_, [thm']), ctxt') = Variable.import false [thm] ctxt;
-    fun bool_num b = if b then 1 else 0;
-    val num_args = Integer.sum
-      (map (fn (_, (co, contra)) => bool_num co + bool_num contra) variances);
-    val thm'' =
-      (Conv.fconv_rule o Conv.arg_conv o Conv.arg1_conv o args_conv num_args o Conv.rewr_conv) id_def thm';
-  in singleton (Variable.export ctxt' ctxt) thm'' end;
+    val head = term_with_typ ctxt (map (fn T => T --> T) arg_Ts ---> T --> T) mapper;
+    val lhs1 = list_comb (head, map (HOLogic.id_const) arg_Ts);
+    val lhs2 = list_comb (head, map (fn arg_T => Abs ("x", arg_T, Bound 0)) arg_Ts);
+    val rhs = HOLogic.id_const T;
+    val (id_prop, identity_prop) = pairself (HOLogic.mk_Trueprop o HOLogic.mk_eq o rpair rhs) (lhs1, lhs2);
+    fun prove_identity ctxt id_thm = Skip_Proof.prove ctxt [] [] identity_prop
+      (K (ALLGOALS (Method.insert_tac [id_thm] THEN' Simplifier.asm_lr_simp_tac identity_ss)));
+  in (id_prop, prove_identity) end;
 
 
 (* analyzing and registering mappers *)
@@ -170,13 +167,14 @@
         val (Ts'', T'') = split_last Ts';
       in (Ts'', T'', T') end;
 
-fun analyze_variances thy tyco T =
+fun analyze_variances ctxt tyco T =
   let
-    fun bad_typ () = error ("Bad mapper type: " ^ Syntax.string_of_typ_global thy T);
+    fun bad_typ () = error ("Bad mapper type: " ^ Syntax.string_of_typ ctxt T);
     val (Ts, T1, T2) = split_mapper_typ tyco T
       handle List.Empty => bad_typ ();
     val _ = pairself
       ((fn tyco' => if tyco' = tyco then () else bad_typ ()) o fst o dest_Type) (T1, T2)
+      handle TYPE _ => bad_typ ();
     val (vs1, vs2) = pairself (map dest_TFree o snd o dest_Type) (T1, T2)
       handle TYPE _ => bad_typ ();
     val _ = if has_duplicates (eq_fst (op =)) (vs1 @ vs2)
@@ -185,7 +183,7 @@
       let
         val coT = TFree var1 --> TFree var2;
         val contraT = TFree var2 --> TFree var1;
-        val sort = Sign.inter_sort thy (sort1, sort2);
+        val sort = Sign.inter_sort (ProofContext.theory_of ctxt) (sort1, sort2);
       in
         consume (op =) coT
         ##>> consume (op =) contraT
@@ -195,50 +193,61 @@
     val _ = if null left_variances then () else bad_typ ();
   in variances end;
 
-fun gen_type_lifting prep_term some_prfx raw_t thy =
+fun gen_type_lifting prep_term some_prfx raw_mapper lthy =
   let
-    val (mapper, T) = case prep_term thy raw_t
-     of Const cT => cT
-      | t => error ("No constant: " ^ Syntax.string_of_term_global thy t);
+    val input_mapper = prep_term lthy raw_mapper;
+    val T = fastype_of input_mapper;
     val _ = Type.no_tvars T;
+    val mapper = singleton (Variable.polymorphic lthy) input_mapper;
+    val _ = if null (Term.add_tfreesT (fastype_of mapper) []) then ()
+      else error ("Illegal locally fixed variables in type: " ^ Syntax.string_of_typ lthy T);
     fun add_tycos (Type (tyco, Ts)) = insert (op =) tyco #> fold add_tycos Ts
       | add_tycos _ = I;
     val tycos = add_tycos T [];
     val tyco = if tycos = ["fun"] then "fun"
       else case remove (op =) "fun" tycos
        of [tyco] => tyco
-        | _ => error ("Bad number of type constructors: " ^ Syntax.string_of_typ_global thy T);
+        | _ => error ("Bad number of type constructors: " ^ Syntax.string_of_typ lthy T);
     val prfx = the_default (Long_Name.base_name tyco) some_prfx;
-    val variances = analyze_variances thy tyco T;
-    val ctxt = ProofContext.init_global thy;
-    val comp_prop = make_comp_prop ctxt variances (tyco, mapper);
-    val id_prop = make_id_prop ctxt variances (tyco, mapper);
+    val variances = analyze_variances lthy tyco T;
+    val (comp_prop, prove_compositionality) = make_comp_prop lthy variances (tyco, mapper);
+    val (id_prop, prove_identity) = make_id_prop lthy variances (tyco, mapper);
     val qualify = Binding.qualify true prfx o Binding.name;
-    fun after_qed [single_comp, single_id] lthy =
+    fun mapper_declaration comp_thm id_thm phi context =
+      let
+        val typ_instance = Type.typ_instance (ProofContext.tsig_of (Context.proof_of context));
+        val mapper' = Morphism.term phi mapper;
+        val T_T' = pairself fastype_of (mapper, mapper');
+      in if typ_instance T_T' andalso typ_instance (swap T_T')
+        then (Data.map o Symtab.cons_list) (tyco,
+          { mapper = mapper', variances = variances,
+            comp = Morphism.thm phi comp_thm, id = Morphism.thm phi id_thm }) context
+        else context
+      end;
+    fun after_qed [single_comp_thm, single_id_thm] lthy =
       lthy
-      |> Local_Theory.note ((qualify compN, []), single_comp)
-      ||>> Local_Theory.note ((qualify idN, []), single_id)
-      |-> (fn ((_, [comp]), (_, [id])) => fn lthy =>
+      |> Local_Theory.note ((qualify compN, []), single_comp_thm)
+      ||>> Local_Theory.note ((qualify idN, []), single_id_thm)
+      |-> (fn ((_, [comp_thm]), (_, [id_thm])) => fn lthy =>
         lthy
-        |> Local_Theory.note ((qualify compositionalityN, []), [make_compositionality lthy comp])
+        |> Local_Theory.note ((qualify compositionalityN, []),
+            [prove_compositionality lthy comp_thm])
         |> snd
-        |> Local_Theory.note ((qualify identityN, []), [make_identity lthy variances id])
+        |> Local_Theory.note ((qualify identityN, []),
+            [prove_identity lthy id_thm])
         |> snd
-        |> (Local_Theory.background_theory o Data.map)
-            (Symtab.update (tyco, { mapper = mapper, variances = variances,
-              comp = comp, id = id })));
+        |> Local_Theory.declaration false (mapper_declaration comp_thm id_thm))
   in
-    thy
-    |> Named_Target.theory_init
+    lthy
     |> Proof.theorem NONE after_qed (map (fn t => [(t, [])]) [comp_prop, id_prop])
   end
 
-val type_lifting = gen_type_lifting Sign.cert_term;
-val type_lifting_cmd = gen_type_lifting Syntax.read_term_global;
+val type_lifting = gen_type_lifting Syntax.check_term;
+val type_lifting_cmd = gen_type_lifting Syntax.read_term;
 
-val _ =
-  Outer_Syntax.command "type_lifting" "register operations managing the functorial structure of a type" Keyword.thy_goal
-    (Scan.option (Parse.name --| Parse.$$$ ":") -- Parse.term
-      >> (fn (prfx, t) => Toplevel.print o (Toplevel.theory_to_proof (type_lifting_cmd prfx t))));
+val _ = Outer_Syntax.local_theory_to_proof "type_lifting"
+  "register operations managing the functorial structure of a type"
+  Keyword.thy_goal (Scan.option (Parse.name --| Parse.$$$ ":") -- Parse.term
+    >> (fn (prfx, t) => type_lifting_cmd prfx t));
 
 end;