--- a/src/HOLCF/Tools/Domain/domain_isomorphism.ML Wed Nov 18 12:41:43 2009 -0800
+++ b/src/HOLCF/Tools/Domain/domain_isomorphism.ML Wed Nov 18 15:01:00 2009 -0800
@@ -15,6 +15,14 @@
structure Domain_Isomorphism :> DOMAIN_ISOMORPHISM =
struct
+val beta_ss =
+ HOL_basic_ss
+ addsimps simp_thms
+ addsimps [@{thm beta_cfun}]
+ addsimprocs [@{simproc cont_proc}];
+
+val beta_tac = simp_tac beta_ss;
+
(******************************************************************************)
(******************************* building types *******************************)
(******************************************************************************)
@@ -79,6 +87,9 @@
val mk_trp = HOLogic.mk_Trueprop;
+val mk_fst = HOLogic.mk_fst;
+val mk_snd = HOLogic.mk_snd;
+
fun mk_cont t =
let val T = Term.fastype_of t
in Const(@{const_name cont}, T --> HOLogic.boolT) $ t end;
@@ -90,6 +101,79 @@
fun mk_Rep_of T =
Const (@{const_name Rep_of}, Term.itselfT T --> deflT) $ Logic.mk_type T;
+(* splits a cterm into the right and lefthand sides of equality *)
+fun dest_eqs t = HOLogic.dest_eq (HOLogic.dest_Trueprop t);
+
+fun mk_eqs (t, u) = HOLogic.mk_Trueprop (HOLogic.mk_eq (t, u));
+
+(******************************************************************************)
+(*************** fixed-point definitions and unfolding theorems ***************)
+(******************************************************************************)
+
+fun add_fixdefs
+ (spec : (binding * term) list)
+ (thy : theory) : thm list * theory =
+ let
+ val binds = map fst spec;
+ val (lhss, rhss) = ListPair.unzip (map (dest_eqs o snd) spec);
+ val functional = lambda_tuple lhss (mk_tuple rhss);
+ val fixpoint = mk_fix (mk_cabs functional);
+
+ (* project components of fixpoint *)
+ fun mk_projs (x::[]) t = [(x, t)]
+ | mk_projs (x::xs) t = (x, mk_fst t) :: mk_projs xs (mk_snd t);
+ val projs = mk_projs lhss fixpoint;
+
+ (* convert parameters to lambda abstractions *)
+ fun mk_eqn (lhs, rhs) =
+ case lhs of
+ Const (@{const_name Rep_CFun}, _) $ f $ (x as Free _) =>
+ mk_eqn (f, big_lambda x rhs)
+ | Const _ => Logic.mk_equals (lhs, rhs)
+ | _ => raise TERM ("lhs not of correct form", [lhs, rhs]);
+ val eqns = map mk_eqn projs;
+
+ (* register constant definitions *)
+ val (fixdef_thms, thy2) =
+ (PureThy.add_defs false o map Thm.no_attributes)
+ (map (Binding.suffix_name "_def") binds ~~ eqns) thy;
+
+ (* prove applied version of definitions *)
+ fun prove_proj (lhs, rhs) =
+ let
+ val tac = rewrite_goals_tac fixdef_thms THEN beta_tac 1;
+ val goal = Logic.mk_equals (lhs, rhs);
+ in Goal.prove_global thy2 [] [] goal (K tac) end;
+ val proj_thms = map prove_proj projs;
+
+ (* mk_tuple lhss == fixpoint *)
+ fun pair_equalI (thm1, thm2) = @{thm Pair_equalI} OF [thm1, thm2];
+ val tuple_fixdef_thm = foldr1 pair_equalI proj_thms;
+
+ val cont_thm =
+ Goal.prove_global thy2 [] [] (mk_trp (mk_cont functional))
+ (K (beta_tac 1));
+ val tuple_unfold_thm =
+ (@{thm def_cont_fix_eq} OF [tuple_fixdef_thm, cont_thm])
+ |> LocalDefs.unfold (ProofContext.init thy2) @{thms split_conv};
+
+ fun mk_unfold_thms [] thm = []
+ | mk_unfold_thms (n::[]) thm = [(n, thm)]
+ | mk_unfold_thms (n::ns) thm = let
+ val thmL = thm RS @{thm Pair_eqD1};
+ val thmR = thm RS @{thm Pair_eqD2};
+ in (n, thmL) :: mk_unfold_thms ns thmR end;
+ val unfold_binds = map (Binding.suffix_name "_unfold") binds;
+
+ (* register unfold theorems *)
+ val (unfold_thms, thy3) =
+ (PureThy.add_thms o map (Thm.no_attributes o apsnd Drule.standard))
+ (mk_unfold_thms unfold_binds tuple_unfold_thm) thy2;
+ in
+ (unfold_thms, thy3)
+ end;
+
+
(******************************************************************************)
fun typ_of_dtyp
@@ -130,7 +214,7 @@
fun defl_of (TFree (a, _)) = free a
| defl_of (TVar _) = error ("defl_of_typ: TVar")
| defl_of (T as Type (c, Ts)) =
- case Symtab.lookup defl_tab c of
+ case Symtab.lookup tab c of
SOME t => Library.foldl mk_capply (t, map defl_of Ts)
| NONE => if is_closed_typ T
then mk_Rep_of T
@@ -200,114 +284,49 @@
sorts : (string * sort) list) =
fold_map (prep_dom tmp_thy) doms_raw [];
+ (* domain equations *)
+ fun mk_dom_eqn (vs, tbind, mx, rhs) =
+ let fun arg v = TFree (v, the (AList.lookup (op =) sorts v));
+ in (Type (Sign.full_name tmp_thy tbind, map arg vs), rhs) end;
+ val dom_eqns = map mk_dom_eqn doms;
+
+ (* check for valid type parameters *)
val (tyvars, _, _, _)::_ = doms;
- val (new_doms, types_syntax) = ListPair.unzip (map (fn (tvs, tname, mx, _) =>
+ val new_doms = map (fn (tvs, tname, mx, _) =>
let val full_tname = Sign.full_name tmp_thy tname
in
(case duplicates (op =) tvs of
[] =>
- if eq_set (op =) (tyvars, tvs) then ((full_tname, tvs), (tname, mx))
+ if eq_set (op =) (tyvars, tvs) then (full_tname, tvs)
else error ("Mutually recursive domains must have same type parameters")
| dups => error ("Duplicate parameter(s) for domain " ^ quote (Binding.str_of tname) ^
" : " ^ commas dups))
- end) doms);
+ end) doms;
val dom_names = map fst new_doms;
- val dtyps =
- map (fn (vs, t, mx, rhs) => DatatypeAux.dtyp_of_typ new_doms rhs) doms;
-
- fun unprime a = Library.unprefix "'" a;
- fun free_defl a = Free (a, deflT);
-
- val (ts, rs) =
- let
- val used = map unprime tyvars;
- val i = length doms;
- val ns = map (fn i => "r" ^ ML_Syntax.print_int i) (1 upto i);
- val ns' = Name.variant_list used ns;
- in (map free_defl used, map free_defl ns') end;
-
- val defls =
- map (defl_of_dtyp new_doms sorts (free_defl o unprime) (nth rs)) dtyps;
- val functional = lambda_tuple rs (mk_tuple defls);
- val fixpoint = mk_fix (mk_cabs functional);
-
- fun projs t (_::[]) = [t]
- | projs t (_::xs) = HOLogic.mk_fst t :: projs (HOLogic.mk_snd t) xs;
- fun typ_eqn ((tvs, tbind, mx, _), t) =
+ (* declare type combinator constants *)
+ fun declare_typ_const (vs, tbind, mx, rhs) thy =
let
- val typ_type = Library.foldr cfunT (map (K deflT) tvs, deflT);
+ val typ_type = Library.foldr cfunT (map (K deflT) vs, deflT);
val typ_bind = Binding.suffix_name "_typ" tbind;
- val typ_name = Sign.full_name tmp_thy typ_bind;
- val typ_const = Const (typ_name, typ_type);
- val args = map (free_defl o unprime) tvs;
- val typ_rhs = big_lambdas args t;
- val typ_eqn = Logic.mk_equals (typ_const, typ_rhs);
- val typ_beta = Logic.mk_equals
- (Library.foldl mk_capply (typ_const, args), t);
- val typ_syn = (typ_bind, typ_type, NoSyn);
- val typ_def = (Binding.suffix_name "_def" typ_bind, typ_eqn);
in
- ((typ_syn, typ_def), (typ_beta, typ_const))
+ Sign.declare_const ((typ_bind, typ_type), NoSyn) thy
end;
- val ((typ_syns, typ_defs), (typ_betas, typ_consts)) =
- map typ_eqn (doms ~~ projs fixpoint doms)
- |> ListPair.unzip
- |> apfst ListPair.unzip
- |> apsnd ListPair.unzip;
- val (typ_def_thms, thy2) =
- thy
- |> Sign.add_consts_i typ_syns
- |> (PureThy.add_defs false o map Thm.no_attributes) typ_defs;
+ val (typ_consts, thy2) = fold_map declare_typ_const doms thy;
- val beta_ss = HOL_basic_ss
- addsimps simp_thms
- addsimps [@{thm beta_cfun}]
- addsimprocs [@{simproc cont_proc}];
- val beta_tac = rewrite_goals_tac typ_def_thms THEN simp_tac beta_ss 1;
- val typ_beta_thms =
- map (fn t => Goal.prove_global thy2 [] [] t (K beta_tac)) typ_betas;
-
- fun pair_equalI (thm1, thm2) = @{thm Pair_equalI} OF [thm1, thm2];
- val tuple_typ_thm = Drule.standard (foldr1 pair_equalI typ_beta_thms);
-
- val tuple_cont_thm =
- Goal.prove_global thy2 [] [] (mk_trp (mk_cont functional))
- (K (simp_tac beta_ss 1));
- val tuple_unfold_thm =
- (@{thm def_cont_fix_eq} OF [tuple_typ_thm, tuple_cont_thm])
- |> LocalDefs.unfold (ProofContext.init thy2) @{thms split_conv};
+ (* defining equations for type combinators *)
+ val defl_tab1 = defl_tab; (* FIXME: use theory data *)
+ val defl_tab2 =
+ Symtab.make (map (fst o dest_Type o fst) dom_eqns ~~ typ_consts);
+ val defl_tab' = Symtab.merge (K true) (defl_tab1, defl_tab2);
+ fun free a = Free (Library.unprefix "'" a, deflT);
+ fun mk_defl_spec (lhs, rhs) =
+ mk_eqs (defl_of_typ defl_tab' free lhs, defl_of_typ defl_tab' free rhs);
+ val defl_specs = map mk_defl_spec dom_eqns;
- fun typ_unfold_eqn ((tvs, tbind, mx, _), t) =
- let
- val typ_type = Library.foldr cfunT (map (K deflT) tvs, deflT);
- val typ_bind = Binding.suffix_name "_typ" tbind;
- val typ_name = Sign.full_name tmp_thy typ_bind;
- val typ_const = Const (typ_name, typ_type);
- val args = map (free_defl o unprime) tvs;
- val typ_rhs = big_lambdas args t;
- val typ_eqn = Logic.mk_equals (typ_const, typ_rhs);
- val typ_beta = Logic.mk_equals
- (Library.foldl mk_capply (typ_const, args), t);
- val typ_syn = (typ_bind, typ_type, NoSyn);
- val typ_def = (Binding.suffix_name "_def" typ_bind, typ_eqn);
- in
- ((typ_syn, typ_def), (typ_beta, typ_const))
- end;
-
- val typ_unfold_names =
- map (Binding.suffix_name "_typ_unfold" o #2) doms;
- fun unfolds [] thm = []
- | unfolds (n::[]) thm = [(n, thm)]
- | unfolds (n::ns) thm = let
- val thmL = thm RS @{thm Pair_eqD1};
- val thmR = thm RS @{thm Pair_eqD2};
- in (n, thmL) :: unfolds ns thmR end;
- val typ_unfold_thms =
- map (apsnd Drule.standard) (unfolds typ_unfold_names tuple_unfold_thm);
-
- val (_, thy3) = thy2
- |> (PureThy.add_thms o map Thm.no_attributes) typ_unfold_thms;
+ (* register recursive definition of type combinators *)
+ val typ_binds = map (Binding.suffix_name "_typ" o #2) doms;
+ val (typ_unfold_thms, thy3) = add_fixdefs (typ_binds ~~ defl_specs) thy2;
fun make_repdef ((vs, tbind, mx, _), typ_const) thy =
let