--- a/src/HOLCF/Tools/Domain/domain_isomorphism.ML Wed Oct 27 11:06:53 2010 -0700
+++ b/src/HOLCF/Tools/Domain/domain_isomorphism.ML Wed Oct 27 11:10:36 2010 -0700
@@ -27,8 +27,9 @@
val domain_isomorphism_cmd :
(string list * binding * mixfix * string * (binding * binding) option) list
-> theory -> theory
+
val add_type_constructor :
- (string * term * string) -> theory -> theory
+ (string * string * string * bool list) -> theory -> theory
val setup : theory -> theory
end;
@@ -44,14 +45,19 @@
val beta_tac = simp_tac beta_ss;
+fun is_cpo thy T = Sign.of_sort thy (T, @{sort cpo});
+fun is_bifinite thy T = Sign.of_sort thy (T, @{sort bifinite});
+
(******************************************************************************)
(******************************** theory data *********************************)
(******************************************************************************)
structure DeflData = Theory_Data
(
- (* terms like "foo_defl" *)
- type T = term Symtab.table;
+ (* constant names like "foo_defl" *)
+ (* list indicates which type arguments correspond to deflation parameters *)
+ (* alternatively, which type arguments allow indirect recursion *)
+ type T = (string * bool list) Symtab.table;
val empty = Symtab.empty;
val extend = I;
fun merge data = Symtab.merge (K true) data;
@@ -76,9 +82,9 @@
);
fun add_type_constructor
- (tname, defl_const, map_name) =
- DeflData.map (Symtab.insert (K true) (tname, defl_const))
- #> Domain_Take_Proofs.add_map_function (tname, map_name)
+ (tname, defl_name, map_name, flags) =
+ DeflData.map (Symtab.insert (K true) (tname, (defl_name, flags)))
+ #> Domain_Take_Proofs.add_map_function (tname, map_name, flags)
val setup =
RepData.setup #> MapIdData.setup #> IsodeflData.setup
@@ -91,15 +97,11 @@
open HOLCF_Library;
infixr 6 ->>;
-infix -->>;
+infixr -->>;
val udomT = @{typ udom};
val deflT = @{typ "defl"};
-fun mapT (T as Type (_, Ts)) =
- (map (fn T => T ->> T) Ts) -->> (T ->> T)
- | mapT T = T ->> T;
-
fun mk_DEFL T =
Const (@{const_name defl}, Term.itselfT T --> deflT) $ Logic.mk_type T;
@@ -155,6 +157,8 @@
case lhs of
Const (@{const_name Rep_CFun}, _) $ f $ (x as Free _) =>
mk_eqn (f, big_lambda x rhs)
+ | f $ Const (@{const_name TYPE}, T) =>
+ mk_eqn (f, Abs ("t", T, rhs))
| Const _ => Logic.mk_equals (lhs, rhs)
| _ => raise TERM ("lhs not of correct form", [lhs, rhs]);
val eqns = map mk_eqn projs;
@@ -204,17 +208,30 @@
(****************** deflation combinators and map functions *******************)
(******************************************************************************)
+fun mk_defl_type (flags : bool list) (Ts : typ list) =
+ map (Term.itselfT o snd) (filter_out fst (flags ~~ Ts)) --->
+ map (K deflT) (filter I flags) -->> deflT;
+
fun defl_of_typ
- (tab : term Symtab.table)
+ (thy : theory)
+ (tab : (string * bool list) Symtab.table)
(T : typ) : term =
let
fun is_closed_typ (Type (_, Ts)) = forall is_closed_typ Ts
+ | is_closed_typ (TFree (n, s)) = not (Sign.subsort thy (s, @{sort bifinite}))
| is_closed_typ _ = false;
fun defl_of (TFree (a, _)) = Free (Library.unprefix "'" a, deflT)
| defl_of (TVar _) = error ("defl_of_typ: TVar")
| defl_of (T as Type (c, Ts)) =
case Symtab.lookup tab c of
- SOME t => list_ccomb (t, map defl_of Ts)
+ SOME (s, flags) =>
+ let
+ val defl_const = Const (s, mk_defl_type flags Ts);
+ val type_args = map (Logic.mk_type o snd) (filter_out fst (flags ~~ Ts));
+ val defl_args = map (defl_of o snd) (filter fst (flags ~~ Ts));
+ in
+ list_ccomb (list_comb (defl_const, type_args), defl_args)
+ end
| NONE => if is_closed_typ T
then mk_DEFL T
else error ("defl_of_typ: type variable under unsupported type constructor " ^ c);
@@ -258,6 +275,10 @@
val dom_eqns = map (fn x => (#absT x, #repT x)) iso_infos;
val rep_abs_consts = map (fn x => (#rep_const x, #abs_const x)) iso_infos;
+ fun mapT (T as Type (_, Ts)) =
+ (map (fn T => T ->> T) (filter (is_cpo thy) Ts)) -->> (T ->> T)
+ | mapT T = T ->> T;
+
(* declare map functions *)
fun declare_map_const (tbind, (lhsT, rhsT)) thy =
let
@@ -274,7 +295,7 @@
fun unprime a = Library.unprefix "'" a;
fun mapvar T = Free (unprime (fst (dest_TFree T)), T ->> T);
fun map_lhs (map_const, lhsT) =
- (lhsT, list_ccomb (map_const, map mapvar (snd (dest_Type lhsT))));
+ (lhsT, list_ccomb (map_const, map mapvar (filter (is_cpo thy) (snd (dest_Type lhsT)))));
val tab1 = map map_lhs (map_consts ~~ map fst dom_eqns);
val Ts = (snd o dest_Type o fst o hd) dom_eqns;
val tab = (Ts ~~ map mapvar Ts) @ tab1;
@@ -304,9 +325,9 @@
fun mk_goal (map_const, (lhsT, rhsT)) =
let
val (_, Ts) = dest_Type lhsT;
- val map_term = list_ccomb (map_const, map mk_f Ts);
+ val map_term = list_ccomb (map_const, map mk_f (filter (is_cpo thy) Ts));
in mk_deflation map_term end;
- val assms = (map mk_assm o snd o dest_Type o fst o hd) dom_eqns;
+ val assms = (map mk_assm o filter (is_cpo thy) o snd o dest_Type o fst o hd) dom_eqns;
val goals = map mk_goal (map_consts ~~ dom_eqns);
val goal = mk_trp (foldr1 HOLogic.mk_conj goals);
val start_thms =
@@ -346,11 +367,15 @@
(* register map functions in theory data *)
local
+ fun register_map ((dname, map_name), args) =
+ Domain_Take_Proofs.add_map_function (dname, map_name, args);
val dnames = map (fst o dest_Type o fst) dom_eqns;
val map_names = map (fst o dest_Const) map_consts;
+ fun args (T, _) = case T of Type (_, Ts) => map (is_cpo thy) Ts | _ => [];
+ val argss = map args dom_eqns;
in
val thy =
- fold Domain_Take_Proofs.add_map_function (dnames ~~ map_names) thy;
+ fold register_map (dnames ~~ map_names ~~ argss) thy;
end;
(* register deflation theorems *)
@@ -433,25 +458,35 @@
val dbinds = map (fn (_, dbind, _, _, _) => dbind) doms;
val morphs = map (fn (_, _, _, _, morphs) => morphs) doms;
+ (* determine deflation combinator arguments *)
+ fun defl_flags (vs, tbind, mx, rhs, morphs) =
+ let fun argT v = TFree (v, the (AList.lookup (op =) sorts v));
+ in map (is_bifinite thy o argT) vs end;
+ val defl_flagss = map defl_flags doms;
+
(* declare deflation combinator constants *)
fun declare_defl_const (vs, tbind, mx, rhs, morphs) thy =
let
- val defl_type = map (K deflT) vs -->> deflT;
+ fun argT v = TFree (v, the (AList.lookup (op =) sorts v));
+ val Ts = map argT vs;
+ val flags = map (is_bifinite thy) Ts;
+ val defl_type = mk_defl_type flags Ts;
val defl_bind = Binding.suffix_name "_defl" tbind;
in
Sign.declare_const ((defl_bind, defl_type), NoSyn) thy
end;
val (defl_consts, thy) = fold_map declare_defl_const doms thy;
+ val defl_names = map (fst o dest_Const) defl_consts;
(* defining equations for type combinators *)
+ val dnames = map (fst o dest_Type o fst) dom_eqns;
val defl_tab1 = DeflData.get thy;
- val defl_tab2 =
- Symtab.make (map (fst o dest_Type o fst) dom_eqns ~~ defl_consts);
+ val defl_tab2 = Symtab.make (dnames ~~ (defl_names ~~ defl_flagss));
val defl_tab' = Symtab.merge (K true) (defl_tab1, defl_tab2);
val thy = DeflData.put defl_tab' thy;
fun mk_defl_spec (lhsT, rhsT) =
- mk_eqs (defl_of_typ defl_tab' lhsT,
- defl_of_typ defl_tab' rhsT);
+ mk_eqs (defl_of_typ thy defl_tab' lhsT,
+ defl_of_typ thy defl_tab' rhsT);
val defl_specs = map mk_defl_spec dom_eqns;
(* register recursive definition of deflation combinators *)
@@ -462,9 +497,11 @@
(* define types using deflation combinators *)
fun make_repdef ((vs, tbind, mx, _, _), defl_const) thy =
let
- fun tfree a = TFree (a, the (AList.lookup (op =) sorts a))
- val reps = map (mk_DEFL o tfree) vs;
- val defl = list_ccomb (defl_const, reps);
+ fun tfree a = TFree (a, the (AList.lookup (op =) sorts a));
+ val Ts = map tfree vs;
+ val type_args = map Logic.mk_type (filter_out (is_bifinite thy) Ts);
+ val defl_args = map mk_DEFL (filter (is_bifinite thy) Ts);
+ val defl = list_ccomb (list_comb (defl_const, type_args), defl_args);
val ((_, _, _, {DEFL, ...}), thy) =
Repdef.add_repdef false NONE (tbind, map (rpair dummyS) vs, mx) defl NONE thy;
in
@@ -560,10 +597,12 @@
fun mk_goal ((map_const, defl_const), (T, rhsT)) =
let
val (_, Ts) = dest_Type T;
- val map_term = list_ccomb (map_const, map mk_f Ts);
- val defl_term = list_ccomb (defl_const, map mk_d Ts);
+ val map_term = list_ccomb (map_const, map mk_f (filter (is_cpo thy) Ts));
+ val type_args = map Logic.mk_type (filter_out (is_bifinite thy) Ts);
+ val defl_args = map mk_d (filter (is_bifinite thy) Ts);
+ val defl_term = list_ccomb (list_comb (defl_const, type_args), defl_args);
in isodefl_const T $ map_term $ defl_term end;
- val assms = (map mk_assm o snd o dest_Type o fst o hd) dom_eqns;
+ val assms = (map mk_assm o filter (is_cpo thy) o snd o dest_Type o fst o hd) dom_eqns;
val goals = map mk_goal (map_consts ~~ defl_consts ~~ dom_eqns);
val goal = mk_trp (foldr1 HOLogic.mk_conj goals);
val start_thms =
@@ -610,7 +649,8 @@
(((map_const, (lhsT, _)), DEFL_thm), isodefl_thm) =
let
val Ts = snd (dest_Type lhsT);
- val lhs = list_ccomb (map_const, map mk_ID Ts);
+ fun is_cpo T = Sign.of_sort thy (T, @{sort cpo});
+ val lhs = list_ccomb (map_const, map mk_ID (filter is_cpo Ts));
val goal = mk_eqs (lhs, mk_ID lhsT);
val tac = EVERY
[rtac @{thm isodefl_DEFL_imp_ID} 1,
@@ -640,8 +680,9 @@
val lub_take_lemma =
let
val lhs = mk_tuple (map mk_lub take_consts);
+ fun is_cpo T = Sign.of_sort thy (T, @{sort cpo});
fun mk_map_ID (map_const, (lhsT, rhsT)) =
- list_ccomb (map_const, map mk_ID (snd (dest_Type lhsT)));
+ list_ccomb (map_const, map mk_ID (filter is_cpo (snd (dest_Type lhsT))));
val rhs = mk_tuple (map mk_map_ID (map_consts ~~ dom_eqns));
val goal = mk_trp (mk_eq (lhs, rhs));
val map_ID_thms = MapIdData.get (ProofContext.init_global thy);