# HG changeset patch # User huffman # Date 1258585260 28800 # Node ID 7a1518c42c5657eb6e543dfd01b5faf9827a1fc6 # Parent e11e05b32548310fee6e97fe83082ec83eb4eacd cleaned up; factored out fixed-point definition code diff -r e11e05b32548 -r 7a1518c42c56 src/HOLCF/Tools/Domain/domain_isomorphism.ML --- a/src/HOLCF/Tools/Domain/domain_isomorphism.ML Wed Nov 18 12:41:43 2009 -0800 +++ b/src/HOLCF/Tools/Domain/domain_isomorphism.ML Wed Nov 18 15:01:00 2009 -0800 @@ -15,6 +15,14 @@ structure Domain_Isomorphism :> DOMAIN_ISOMORPHISM = struct +val beta_ss = + HOL_basic_ss + addsimps simp_thms + addsimps [@{thm beta_cfun}] + addsimprocs [@{simproc cont_proc}]; + +val beta_tac = simp_tac beta_ss; + (******************************************************************************) (******************************* building types *******************************) (******************************************************************************) @@ -79,6 +87,9 @@ val mk_trp = HOLogic.mk_Trueprop; +val mk_fst = HOLogic.mk_fst; +val mk_snd = HOLogic.mk_snd; + fun mk_cont t = let val T = Term.fastype_of t in Const(@{const_name cont}, T --> HOLogic.boolT) $ t end; @@ -90,6 +101,79 @@ fun mk_Rep_of T = Const (@{const_name Rep_of}, Term.itselfT T --> deflT) $ Logic.mk_type T; +(* splits a cterm into the right and lefthand sides of equality *) +fun dest_eqs t = HOLogic.dest_eq (HOLogic.dest_Trueprop t); + +fun mk_eqs (t, u) = HOLogic.mk_Trueprop (HOLogic.mk_eq (t, u)); + +(******************************************************************************) +(*************** fixed-point definitions and unfolding theorems ***************) +(******************************************************************************) + +fun add_fixdefs + (spec : (binding * term) list) + (thy : theory) : thm list * theory = + let + val binds = map fst spec; + val (lhss, rhss) = ListPair.unzip (map (dest_eqs o snd) spec); + val functional = lambda_tuple lhss (mk_tuple rhss); + val fixpoint = mk_fix (mk_cabs functional); + + (* project components of fixpoint *) + fun mk_projs (x::[]) t = [(x, t)] + | mk_projs (x::xs) t = (x, mk_fst t) :: mk_projs xs (mk_snd t); + val projs = mk_projs lhss fixpoint; + + (* convert parameters to lambda abstractions *) + fun mk_eqn (lhs, rhs) = + case lhs of + Const (@{const_name Rep_CFun}, _) $ f $ (x as Free _) => + mk_eqn (f, big_lambda x rhs) + | Const _ => Logic.mk_equals (lhs, rhs) + | _ => raise TERM ("lhs not of correct form", [lhs, rhs]); + val eqns = map mk_eqn projs; + + (* register constant definitions *) + val (fixdef_thms, thy2) = + (PureThy.add_defs false o map Thm.no_attributes) + (map (Binding.suffix_name "_def") binds ~~ eqns) thy; + + (* prove applied version of definitions *) + fun prove_proj (lhs, rhs) = + let + val tac = rewrite_goals_tac fixdef_thms THEN beta_tac 1; + val goal = Logic.mk_equals (lhs, rhs); + in Goal.prove_global thy2 [] [] goal (K tac) end; + val proj_thms = map prove_proj projs; + + (* mk_tuple lhss == fixpoint *) + fun pair_equalI (thm1, thm2) = @{thm Pair_equalI} OF [thm1, thm2]; + val tuple_fixdef_thm = foldr1 pair_equalI proj_thms; + + val cont_thm = + Goal.prove_global thy2 [] [] (mk_trp (mk_cont functional)) + (K (beta_tac 1)); + val tuple_unfold_thm = + (@{thm def_cont_fix_eq} OF [tuple_fixdef_thm, cont_thm]) + |> LocalDefs.unfold (ProofContext.init thy2) @{thms split_conv}; + + fun mk_unfold_thms [] thm = [] + | mk_unfold_thms (n::[]) thm = [(n, thm)] + | mk_unfold_thms (n::ns) thm = let + val thmL = thm RS @{thm Pair_eqD1}; + val thmR = thm RS @{thm Pair_eqD2}; + in (n, thmL) :: mk_unfold_thms ns thmR end; + val unfold_binds = map (Binding.suffix_name "_unfold") binds; + + (* register unfold theorems *) + val (unfold_thms, thy3) = + (PureThy.add_thms o map (Thm.no_attributes o apsnd Drule.standard)) + (mk_unfold_thms unfold_binds tuple_unfold_thm) thy2; + in + (unfold_thms, thy3) + end; + + (******************************************************************************) fun typ_of_dtyp @@ -130,7 +214,7 @@ fun defl_of (TFree (a, _)) = free a | defl_of (TVar _) = error ("defl_of_typ: TVar") | defl_of (T as Type (c, Ts)) = - case Symtab.lookup defl_tab c of + case Symtab.lookup tab c of SOME t => Library.foldl mk_capply (t, map defl_of Ts) | NONE => if is_closed_typ T then mk_Rep_of T @@ -200,114 +284,49 @@ sorts : (string * sort) list) = fold_map (prep_dom tmp_thy) doms_raw []; + (* domain equations *) + fun mk_dom_eqn (vs, tbind, mx, rhs) = + let fun arg v = TFree (v, the (AList.lookup (op =) sorts v)); + in (Type (Sign.full_name tmp_thy tbind, map arg vs), rhs) end; + val dom_eqns = map mk_dom_eqn doms; + + (* check for valid type parameters *) val (tyvars, _, _, _)::_ = doms; - val (new_doms, types_syntax) = ListPair.unzip (map (fn (tvs, tname, mx, _) => + val new_doms = map (fn (tvs, tname, mx, _) => let val full_tname = Sign.full_name tmp_thy tname in (case duplicates (op =) tvs of [] => - if eq_set (op =) (tyvars, tvs) then ((full_tname, tvs), (tname, mx)) + if eq_set (op =) (tyvars, tvs) then (full_tname, tvs) else error ("Mutually recursive domains must have same type parameters") | dups => error ("Duplicate parameter(s) for domain " ^ quote (Binding.str_of tname) ^ " : " ^ commas dups)) - end) doms); + end) doms; val dom_names = map fst new_doms; - val dtyps = - map (fn (vs, t, mx, rhs) => DatatypeAux.dtyp_of_typ new_doms rhs) doms; - - fun unprime a = Library.unprefix "'" a; - fun free_defl a = Free (a, deflT); - - val (ts, rs) = - let - val used = map unprime tyvars; - val i = length doms; - val ns = map (fn i => "r" ^ ML_Syntax.print_int i) (1 upto i); - val ns' = Name.variant_list used ns; - in (map free_defl used, map free_defl ns') end; - - val defls = - map (defl_of_dtyp new_doms sorts (free_defl o unprime) (nth rs)) dtyps; - val functional = lambda_tuple rs (mk_tuple defls); - val fixpoint = mk_fix (mk_cabs functional); - - fun projs t (_::[]) = [t] - | projs t (_::xs) = HOLogic.mk_fst t :: projs (HOLogic.mk_snd t) xs; - fun typ_eqn ((tvs, tbind, mx, _), t) = + (* declare type combinator constants *) + fun declare_typ_const (vs, tbind, mx, rhs) thy = let - val typ_type = Library.foldr cfunT (map (K deflT) tvs, deflT); + val typ_type = Library.foldr cfunT (map (K deflT) vs, deflT); val typ_bind = Binding.suffix_name "_typ" tbind; - val typ_name = Sign.full_name tmp_thy typ_bind; - val typ_const = Const (typ_name, typ_type); - val args = map (free_defl o unprime) tvs; - val typ_rhs = big_lambdas args t; - val typ_eqn = Logic.mk_equals (typ_const, typ_rhs); - val typ_beta = Logic.mk_equals - (Library.foldl mk_capply (typ_const, args), t); - val typ_syn = (typ_bind, typ_type, NoSyn); - val typ_def = (Binding.suffix_name "_def" typ_bind, typ_eqn); in - ((typ_syn, typ_def), (typ_beta, typ_const)) + Sign.declare_const ((typ_bind, typ_type), NoSyn) thy end; - val ((typ_syns, typ_defs), (typ_betas, typ_consts)) = - map typ_eqn (doms ~~ projs fixpoint doms) - |> ListPair.unzip - |> apfst ListPair.unzip - |> apsnd ListPair.unzip; - val (typ_def_thms, thy2) = - thy - |> Sign.add_consts_i typ_syns - |> (PureThy.add_defs false o map Thm.no_attributes) typ_defs; + val (typ_consts, thy2) = fold_map declare_typ_const doms thy; - val beta_ss = HOL_basic_ss - addsimps simp_thms - addsimps [@{thm beta_cfun}] - addsimprocs [@{simproc cont_proc}]; - val beta_tac = rewrite_goals_tac typ_def_thms THEN simp_tac beta_ss 1; - val typ_beta_thms = - map (fn t => Goal.prove_global thy2 [] [] t (K beta_tac)) typ_betas; - - fun pair_equalI (thm1, thm2) = @{thm Pair_equalI} OF [thm1, thm2]; - val tuple_typ_thm = Drule.standard (foldr1 pair_equalI typ_beta_thms); - - val tuple_cont_thm = - Goal.prove_global thy2 [] [] (mk_trp (mk_cont functional)) - (K (simp_tac beta_ss 1)); - val tuple_unfold_thm = - (@{thm def_cont_fix_eq} OF [tuple_typ_thm, tuple_cont_thm]) - |> LocalDefs.unfold (ProofContext.init thy2) @{thms split_conv}; + (* defining equations for type combinators *) + val defl_tab1 = defl_tab; (* FIXME: use theory data *) + val defl_tab2 = + Symtab.make (map (fst o dest_Type o fst) dom_eqns ~~ typ_consts); + val defl_tab' = Symtab.merge (K true) (defl_tab1, defl_tab2); + fun free a = Free (Library.unprefix "'" a, deflT); + fun mk_defl_spec (lhs, rhs) = + mk_eqs (defl_of_typ defl_tab' free lhs, defl_of_typ defl_tab' free rhs); + val defl_specs = map mk_defl_spec dom_eqns; - fun typ_unfold_eqn ((tvs, tbind, mx, _), t) = - let - val typ_type = Library.foldr cfunT (map (K deflT) tvs, deflT); - val typ_bind = Binding.suffix_name "_typ" tbind; - val typ_name = Sign.full_name tmp_thy typ_bind; - val typ_const = Const (typ_name, typ_type); - val args = map (free_defl o unprime) tvs; - val typ_rhs = big_lambdas args t; - val typ_eqn = Logic.mk_equals (typ_const, typ_rhs); - val typ_beta = Logic.mk_equals - (Library.foldl mk_capply (typ_const, args), t); - val typ_syn = (typ_bind, typ_type, NoSyn); - val typ_def = (Binding.suffix_name "_def" typ_bind, typ_eqn); - in - ((typ_syn, typ_def), (typ_beta, typ_const)) - end; - - val typ_unfold_names = - map (Binding.suffix_name "_typ_unfold" o #2) doms; - fun unfolds [] thm = [] - | unfolds (n::[]) thm = [(n, thm)] - | unfolds (n::ns) thm = let - val thmL = thm RS @{thm Pair_eqD1}; - val thmR = thm RS @{thm Pair_eqD2}; - in (n, thmL) :: unfolds ns thmR end; - val typ_unfold_thms = - map (apsnd Drule.standard) (unfolds typ_unfold_names tuple_unfold_thm); - - val (_, thy3) = thy2 - |> (PureThy.add_thms o map Thm.no_attributes) typ_unfold_thms; + (* register recursive definition of type combinators *) + val typ_binds = map (Binding.suffix_name "_typ" o #2) doms; + val (typ_unfold_thms, thy3) = add_fixdefs (typ_binds ~~ defl_specs) thy2; fun make_repdef ((vs, tbind, mx, _), typ_const) thy = let