src/HOL/Tools/BNF/bnf_comp.ML
changeset 55854 ee270328a781
parent 55853 e776a4b813d1
child 55855 98ad5680173a
--- a/src/HOL/Tools/BNF/bnf_comp.ML	Mon Mar 03 12:48:19 2014 +0100
+++ b/src/HOL/Tools/BNF/bnf_comp.ML	Mon Mar 03 12:48:20 2014 +0100
@@ -6,6 +6,8 @@
 Composition of bounded natural functors.
 *)
 
+val inline_ref = Unsynchronized.ref true;
+
 signature BNF_COMP =
 sig
   val ID_bnf: BNF_Def.bnf
@@ -628,18 +630,45 @@
   let val rho = Vartab.fold (cons o apsnd snd) (Sign.typ_match thy (repT, repU) Vartab.empty) [];
   in Term.typ_subst_TVars rho absT end;
 
-fun mk_repT (t as Type (C, Ts)) repT (u as Type (C', Us)) =
-    if C = C' andalso length Ts = length Us then Term.typ_subst_atomic (Ts ~~ Us) repT
-    else raise Term.TYPE ("mk_repT", [t, repT, u], [])
-  | mk_repT t repT u =  raise Term.TYPE ("mk_repT", [t, repT, u], []);
+fun mk_repT absT repT absU =
+  if absT = repT then absU
+  else
+    (case (absT, absU) of
+      (Type (C, Ts), Type (C', Us)) =>
+        if C = C' then Term.typ_subst_atomic (Ts ~~ Us) repT
+        else raise Term.TYPE ("mk_repT", [absT, repT, absT], [])
+    | _ => raise Term.TYPE ("mk_repT", [absT, repT, absT], []));
 
-fun mk_abs_or_rep getT (Type (_, Us)) abs =
-  let val Ts = snd (dest_Type (getT (fastype_of abs)))
-  in Term.subst_atomic_types (Ts ~~ Us) abs end;
+fun mk_abs_or_rep _ absU (Const (@{const_name id_abs}, _)) =
+    Const (@{const_name id_abs}, absU --> absU)
+  | mk_abs_or_rep _ absU (Const (@{const_name id_rep}, _)) =
+    Const (@{const_name id_rep}, absU --> absU)
+  | mk_abs_or_rep getT (Type (_, Us)) abs =
+    let val Ts = snd (dest_Type (getT (fastype_of abs)))
+    in Term.subst_atomic_types (Ts ~~ Us) abs end;
 
 val mk_abs = mk_abs_or_rep range_type;
 val mk_rep = mk_abs_or_rep domain_type;
 
+val smart_max_inline_type_size = 5; (*FUDGE*)
+
+fun maybe_typedef (b, As, mx) set opt_morphs tac =
+  let
+    val repT = HOLogic.dest_setT (fastype_of set);
+    val inline = Term.size_of_typ repT <= smart_max_inline_type_size;
+  in
+    if inline then
+      pair (repT,
+        (@{const_name id_rep}, @{const_name id_abs}, @{thm type_definition_id_rep_abs_UNIV},
+         @{thm type_definition.Abs_inverse[OF type_definition_id_rep_abs_UNIV]},
+         @{thm type_definition.Abs_inject[OF type_definition_id_rep_abs_UNIV]}))
+    else
+      typedef (b, As, mx) set opt_morphs tac
+      #>> (fn (T_name, ({Rep_name, Abs_name, ...},
+          {type_definition, Abs_inverse, Abs_inject, ...}) : Typedef.info) =>
+        (Type (T_name, map TFree As), (Rep_name, Abs_name, type_definition, Abs_inverse, Abs_inject)))
+  end;
+
 fun seal_bnf qualify (unfold_set : unfold_set) b Ds bnf lthy =
   let
     val live = live_of_bnf bnf;
@@ -670,25 +699,23 @@
     fun unfold_all ctxt = unfold_sets ctxt o unfold_maps ctxt o unfold_rels ctxt;
 
     val repTA = mk_T_of_bnf Ds As bnf;
-    val repTB = mk_T_of_bnf Ds Bs bnf;
     val T_bind = qualify b;
     val TA_params = Term.add_tfreesT repTA [];
-    val TB_params = Term.add_tfreesT repTB [];
-    val ((T_name, (T_glob_info, T_loc_info)), lthy) =
-      typedef (T_bind, TA_params, NoSyn)
+    val ((TA, (Rep_name, Abs_name, type_definition, Abs_inverse, Abs_inject)), lthy) =
+      maybe_typedef (T_bind, TA_params, NoSyn)
         (HOLogic.mk_UNIV repTA) NONE (EVERY' [rtac exI, rtac UNIV_I] 1) lthy;
-    val TA = Type (T_name, map TFree TA_params);
-    val TB = Type (T_name, map TFree TB_params);
-    val RepA = Const (#Rep_name T_glob_info, TA --> repTA);
-    val RepB = Const (#Rep_name T_glob_info, TB --> repTB);
-    val AbsA = Const (#Abs_name T_glob_info, repTA --> TA);
-    val AbsB = Const (#Abs_name T_glob_info, repTB --> TB);
-    val typedef_thm = #type_definition T_loc_info;
-    val Abs_inject' = #Abs_inject T_loc_info OF @{thms UNIV_I UNIV_I};
-    val Abs_inverse' = #Abs_inverse T_loc_info OF @{thms UNIV_I};
+
+    val repTB = mk_T_of_bnf Ds Bs bnf;
+    val TB = Term.typ_subst_atomic (As ~~ Bs) TA;
+    val RepA = Const (Rep_name, TA --> repTA);
+    val RepB = Const (Rep_name, TB --> repTB);
+    val AbsA = Const (Abs_name, repTA --> TA);
+    val AbsB = Const (Abs_name, repTB --> TB);
+    val Abs_inject' = Abs_inject OF @{thms UNIV_I UNIV_I};
+    val Abs_inverse' = Abs_inverse OF @{thms UNIV_I};
 
     val absT_info = {absT = TA, repT = repTA, abs = AbsA, rep = RepA, abs_inject = Abs_inject',
-      abs_inverse = Abs_inverse', type_definition = typedef_thm};
+      abs_inverse = Abs_inverse', type_definition = type_definition};
 
     val bnf_map = fold_rev Term.absfree fs' (HOLogic.mk_comp (HOLogic.mk_comp (AbsB,
       Term.list_comb (expand_maps (mk_map_of_bnf Ds As Bs bnf), fs)), RepA));
@@ -722,15 +749,16 @@
       (@{thm Cinfinite_cong} OF [bd_ordIso, bd_Cinfinite_of_bnf bnf]) RS conjunct1;
 
     fun map_id0_tac ctxt =
-      rtac (@{thm type_copy_map_id0} OF [typedef_thm, unfold_maps ctxt (map_id0_of_bnf bnf)]) 1;
+      rtac (@{thm type_copy_map_id0} OF [type_definition, unfold_maps ctxt (map_id0_of_bnf bnf)]) 1;
     fun map_comp0_tac ctxt =
-      rtac (@{thm type_copy_map_comp0} OF [typedef_thm, unfold_maps ctxt (map_comp0_of_bnf bnf)]) 1;
+      rtac (@{thm type_copy_map_comp0} OF
+        [type_definition, unfold_maps ctxt (map_comp0_of_bnf bnf)]) 1;
     fun map_cong0_tac ctxt =
       EVERY' (rtac @{thm type_copy_map_cong0} :: rtac (unfold_all ctxt (map_cong0_of_bnf bnf)) ::
         map (fn i => EVERY' [select_prem_tac live (dtac meta_spec) i, etac meta_mp,
           etac (o_apply RS equalityD2 RS set_mp)]) (1 upto live)) 1;
     fun set_map0_tac thm ctxt =
-      rtac (@{thm type_copy_set_map0} OF [typedef_thm, unfold_all ctxt thm]) 1;
+      rtac (@{thm type_copy_set_map0} OF [type_definition, unfold_all ctxt thm]) 1;
     val set_bd_tacs = map (fn thm => fn ctxt => rtac (@{thm ordLeq_ordIso_trans} OF
         [unfold_sets ctxt thm, bd_ordIso] RS @{thm type_copy_set_bd}) 1)
       (set_bd_of_bnf bnf);
@@ -738,8 +766,9 @@
       rtac (unfold_rels ctxt (le_rel_OO_of_bnf bnf) RS @{thm vimage2p_relcompp_mono}) 1;
     fun rel_OO_Grp_tac ctxt =
       (rtac (unfold_all ctxt (rel_OO_Grp_of_bnf bnf) RS @{thm vimage2p_cong} RS trans) THEN'
-      SELECT_GOAL (unfold_thms_tac ctxt [o_apply, typedef_thm RS @{thm type_copy_vimage2p_Grp_Rep},
-        typedef_thm RS @{thm vimage2p_relcompp_converse}]) THEN' rtac refl) 1;
+      SELECT_GOAL (unfold_thms_tac ctxt [o_apply,
+        type_definition RS @{thm type_copy_vimage2p_Grp_Rep},
+        type_definition RS @{thm vimage2p_relcompp_converse}]) THEN' rtac refl) 1;
 
     val tacs = zip_axioms map_id0_tac map_comp0_tac map_cong0_tac
       (map set_map0_tac (set_map0_of_bnf bnf)) (K (rtac bd_card_order 1)) (K (rtac bd_cinfinite 1))
@@ -750,7 +779,7 @@
           (AbsA $ Term.list_comb (t, map Bound (0 upto length I - 1))))
       (mk_wits_of_bnf (replicate nwits Ds) (replicate nwits As) bnf);
 
-    fun wit_tac ctxt = ALLGOALS (dtac (typedef_thm RS @{thm type_copy_wit})) THEN
+    fun wit_tac ctxt = ALLGOALS (dtac (type_definition RS @{thm type_copy_wit})) THEN
       mk_simple_wit_tac (map (unfold_all ctxt) (wit_thms_of_bnf bnf));
 
     val (bnf', lthy') =