--- a/src/HOLCF/Tools/Domain/domain_constructors.ML Sat Feb 27 10:12:47 2010 -0800
+++ b/src/HOLCF/Tools/Domain/domain_constructors.ML Sat Feb 27 14:04:46 2010 -0800
@@ -12,6 +12,7 @@
-> typ * (binding * (bool * binding option * typ) list * mixfix) list
-> term * term
-> thm * thm
+ -> thm
-> theory
-> { con_consts : term list,
con_betas : thm list,
@@ -23,6 +24,7 @@
injects : thm list,
dist_les : thm list,
dist_eqs : thm list,
+ cases : thm list,
sel_rews : thm list
} * theory;
end;
@@ -267,6 +269,12 @@
val simple_ss : simpset = HOL_basic_ss addsimps simp_thms;
+val beta_ss =
+ HOL_basic_ss
+ addsimps simp_thms
+ addsimps [@{thm beta_cfun}]
+ addsimprocs [@{simproc cont_proc}];
+
fun define_consts
(specs : (binding * term * mixfix) list)
(thy : theory)
@@ -571,6 +579,76 @@
end;
(******************************************************************************)
+(**************** definition and theorems for case combinator *****************)
+(******************************************************************************)
+
+fun add_case_combinator
+ (spec : (term * (bool * typ) list) list)
+ (lhsT : typ)
+ (dname : string)
+ (case_def : thm)
+ (con_betas : thm list)
+ (casedist : thm)
+ (iso_locale : thm)
+ (thy : theory) =
+ let
+
+ (* prove rep/abs rules *)
+ val rep_strict = iso_locale RS @{thm iso.rep_strict};
+ val abs_inverse = iso_locale RS @{thm iso.abs_iso};
+
+ (* calculate function arguments of case combinator *)
+ val resultT = TVar (("'a",0), @{sort pcpo});
+ val fTs = map (fn (_, args) => map snd args -->> resultT) spec;
+ val fns = Datatype_Prop.indexify_names (map (K "f") spec);
+ val fs = map Free (fns ~~ fTs);
+ val caseT = fTs -->> (lhsT ->> resultT);
+
+ (* TODO: move definition of case combinator here *)
+ val case_bind = Binding.name (dname ^ "_when");
+ val case_const = Const (Sign.full_name thy case_bind, caseT);
+ val case_app = list_ccomb (case_const, fs);
+
+ (* prove beta reduction rule for case combinator *)
+ val case_beta = beta_of_def thy case_def;
+
+ (* prove strictness of case combinator *)
+ val case_strict =
+ let
+ val defs = [case_beta, mk_meta_eq rep_strict];
+ val lhs = case_app ` mk_bottom lhsT;
+ val goal = mk_trp (mk_eq (lhs, mk_bottom resultT));
+ val tacs = [resolve_tac @{thms sscase1 ssplit1 strictify1} 1];
+ in prove thy defs goal (K tacs) end;
+
+ (* prove rewrites for case combinator *)
+ local
+ fun one_case (con, args) f =
+ let
+ val Ts = map snd args;
+ val ns = Name.variant_list fns (Datatype_Prop.make_tnames Ts);
+ val vs = map Free (ns ~~ Ts);
+ val nonlazy = map snd (filter_out (fst o fst) (args ~~ vs));
+ val assms = map (mk_trp o mk_defined) nonlazy;
+ val lhs = case_app ` list_ccomb (con, vs);
+ val rhs = list_ccomb (f, vs);
+ val concl = mk_trp (mk_eq (lhs, rhs));
+ val goal = Logic.list_implies (assms, concl);
+ val defs = case_beta :: con_betas;
+ val rules1 = @{thms sscase2 sscase3 ssplit2 fup2 ID1};
+ val rules2 = @{thms con_defined_iff_rules};
+ val rules = abs_inverse :: rules1 @ rules2;
+ val tacs = [asm_simp_tac (beta_ss addsimps rules) 1];
+ in prove thy defs goal (K tacs) end;
+ in
+ val case_apps = map2 one_case spec fs;
+ end
+
+ in
+ (case_strict :: case_apps, thy)
+ end
+
+(******************************************************************************)
(************** definitions and theorems for selector functions ***************)
(******************************************************************************)
@@ -722,6 +800,7 @@
spec : (binding * (bool * binding option * typ) list * mixfix) list)
(rep_const : term, abs_const : term)
(rep_iso_thm : thm, abs_iso_thm : thm)
+ (case_def : thm)
(thy : theory) =
let
@@ -741,7 +820,18 @@
in
add_constructors con_spec abs_const iso_locale thy
end;
- val {con_consts, con_betas, ...} = con_result;
+ val {con_consts, con_betas, casedist, ...} = con_result;
+
+ (* define case combinator *)
+ val (cases : thm list, thy) =
+ let
+ fun prep_arg (lazy, sel, T) = (lazy, T);
+ fun prep_con c (b, args, mx) = (c, map prep_arg args);
+ val case_spec = map2 prep_con con_consts spec;
+ in
+ add_case_combinator case_spec lhsT dname
+ case_def con_betas casedist iso_locale thy
+ end;
(* TODO: enable this earlier *)
val thy = Sign.add_path dname thy;
@@ -762,13 +852,14 @@
{ con_consts = con_consts,
con_betas = con_betas,
exhaust = #exhaust con_result,
- casedist = #casedist con_result,
+ casedist = casedist,
con_compacts = #con_compacts con_result,
con_rews = #con_rews con_result,
inverts = #inverts con_result,
injects = #injects con_result,
dist_les = #dist_les con_result,
dist_eqs = #dist_eqs con_result,
+ cases = cases,
sel_rews = sel_thms };
in
(result, thy)