# HG changeset patch # User huffman # Date 1118861438 -7200 # Node ID 57c35ede00b9db0f48d721ca15b4db33948bd74b # Parent f2ab5797bbd0b4460241dbb6f3550244047e8726 fixrec package now handles mutually-recursive definitions diff -r f2ab5797bbd0 -r 57c35ede00b9 src/HOLCF/Fixrec.thy --- a/src/HOLCF/Fixrec.thy Wed Jun 15 14:59:25 2005 +0200 +++ b/src/HOLCF/Fixrec.thy Wed Jun 15 20:50:38 2005 +0200 @@ -134,6 +134,27 @@ "match_up\\ = \" by (simp_all add: match_up_def) +subsection {* Mutual recursion *} + +text {* + The following rules are used to prove unfolding theorems from + fixed-point definitions of mutually recursive functions. +*} + +lemma cpair_equalI: "\x \ cfst\p; y \ csnd\p\ \ \ p" +by (simp add: surjective_pairing_Cprod2) + +lemma cpair_eqD1: " = \ x = x'" +by simp + +lemma cpair_eqD2: " = \ y = y'" +by simp + +ML {* +val cpair_equalI = thm "cpair_equalI"; +val cpair_eqD1 = thm "cpair_eqD1"; +val cpair_eqD2 = thm "cpair_eqD2"; +*} subsection {* Intitializing the fixrec package *} diff -r f2ab5797bbd0 -r 57c35ede00b9 src/HOLCF/fixrec_package.ML --- a/src/HOLCF/fixrec_package.ML Wed Jun 15 14:59:25 2005 +0200 +++ b/src/HOLCF/fixrec_package.ML Wed Jun 15 20:50:38 2005 +0200 @@ -7,7 +7,7 @@ signature FIXREC_PACKAGE = sig - val add_fixrec: string list -> theory -> theory + val add_fixrec: string list list -> theory -> theory val add_fixpat: string * string -> theory -> theory end; @@ -27,6 +27,8 @@ ("o"::"p"::" "::rest) => implode rest | _ => con; +val mk_trp = HOLogic.mk_Trueprop; + (* splits a cterm into the right and lefthand sides of equality *) fun dest_eqs (Const ("==", _)$lhs$rhs) = (lhs, rhs) | dest_eqs (Const ("Trueprop", _)$(Const ("op =", _)$lhs$rhs)) = (lhs,rhs) @@ -37,57 +39,92 @@ fun %%: s = Const(s,dummyT); infix 0 ==; fun S == T = %%:"==" $ S $ T; infix 1 ===; fun S === T = %%:"op =" $ S $ T; -infix 9 ` ; fun f ` x = %%:"Rep_CFun" $ f $ x; +infix 9 ` ; fun f ` x = %%:"Cfun.Rep_CFun" $ f $ x; (* infers the type of a term *) -fun infer t thy = #1 (Sign.infer_types (Sign.pp (sign_of thy)) (sign_of thy) (K NONE) (K NONE) [] true ([t],dummyT)); +(* similar to Theory.inferT_axm, but allows any type *) +fun infer sg t = + fst (Sign.infer_types (Sign.pp sg) sg (K NONE) (K NONE) [] true ([t],dummyT)); + +(* The next few functions build continuous lambda abstractions *) + +(* Similar to Term.lambda, but allows abstraction over constants *) +fun lambda' (v as Free (x, T)) t = Abs (x, T, abstract_over (v, t)) + | lambda' (v as Var ((x, _), T)) t = Abs (x, T, abstract_over (v, t)) + | lambda' (v as Const (x, T)) t = Abs (Sign.base_name x, T, abstract_over (v, t)) + | lambda' v t = raise TERM ("lambda'", [v, t]); + +(* builds the expression (LAM v. rhs) *) +fun big_lambda v rhs = %%:"Cfun.Abs_CFun"$(lambda' v rhs); + +(* builds the expression (LAM v1 v2 .. vn. rhs) *) +fun big_lambdas [] rhs = rhs + | big_lambdas (v::vs) rhs = big_lambda v (big_lambdas vs rhs); + +(* builds the expression (LAM . rhs) *) +fun lambda_ctuple [] rhs = big_lambda (%:"unit") rhs + | lambda_ctuple (v::[]) rhs = big_lambda v rhs + | lambda_ctuple (v::vs) rhs = + %%:"Cprod.csplit"`(big_lambda v (lambda_ctuple vs rhs)); + +(* builds the expression *) +fun mk_ctuple [] = %%:"UU" +| mk_ctuple (t::[]) = t +| mk_ctuple (t::ts) = %%:"Cprod.cpair"`t`(mk_ctuple ts); (*************************************************************************) -(************ fixed-point definitions and unfolding theorems *************) +(************************ fixed-point definitions ************************) (*************************************************************************) -fun func1 (lhs as Const(name,T), rhs) = +fun add_fixdefs eqs thy = let - val basename = Sign.base_name name; - val funcT = T ->> T; - val functional = Const ("Cfun.Abs_CFun", (T --> T) --> funcT) $ - Abs (basename, T, abstract_over (lhs,rhs)); - val fix_type = funcT ->> T; - val fix_const = Const ("Fix.fix", fix_type); - val func_type = fix_type --> funcT --> T; - val rhs' = Const ("Cfun.Rep_CFun",func_type)$fix_const$functional; + val (lhss,rhss) = ListPair.unzip (map dest_eqs eqs); + val fixpoint = %%:"Fix.fix"`lambda_ctuple lhss (mk_ctuple rhss); + + fun one_def (l as Const(n,T)) r = + let val b = Sign.base_name n in (b, (b^"_fixdef", l == r)) end + | one_def _ _ = sys_error "fixdefs: lhs not of correct form"; + fun defs [] _ = [] + | defs (l::[]) r = [one_def l r] + | defs (l::ls) r = one_def l (%%:"Cprod.cfst"`r) :: defs ls (%%:"Cprod.csnd"`r); + val (names, pre_fixdefs) = ListPair.unzip (defs lhss fixpoint); + + val fixdefs = map (inferT_axm (sign_of thy)) pre_fixdefs; + val (thy', fixdef_thms) = + PureThy.add_defs_i false (map Thm.no_attributes fixdefs) thy; + val ctuple_fixdef_thm = foldr1 (fn (x,y) => cpair_equalI OF [x,y]) fixdef_thms; + + fun mk_cterm t = let val sg' = sign_of thy' in cterm_of sg' (infer sg' t) end; + val ctuple_unfold_ct = mk_cterm (mk_trp (mk_ctuple lhss === mk_ctuple rhss)); + val ctuple_unfold_thm = prove_goalw_cterm [] ctuple_unfold_ct + (fn _ => [rtac (ctuple_fixdef_thm RS fix_eq2 RS trans) 1, + simp_tac (simpset_of thy') 1]); + val ctuple_induct_thm = + (space_implode "_" names ^ "_induct", [ctuple_fixdef_thm RS def_fix_ind]); + + fun unfolds [] thm = [] + | unfolds (n::[]) thm = [(n^"_unfold", [thm])] + | unfolds (n::ns) thm = let + val thmL = thm RS cpair_eqD1; + val thmR = thm RS cpair_eqD2; + in (n^"_unfold", [thmL]) :: unfolds ns thmR end; + val unfold_thmss = unfolds names ctuple_unfold_thm; + val thmss = ctuple_induct_thm :: unfold_thmss; + val (thy'', _) = PureThy.add_thmss (map Thm.no_attributes thmss) thy'; in - (name, (basename^"_fixdef", equals T $ lhs $ rhs')) - end - | func1 t = sys_error "func1: not of correct form"; + (thy'', names, fixdef_thms, List.concat (map snd unfold_thmss)) + end; (*************************************************************************) (*********** monadic notation and pattern matching compilation ***********) (*************************************************************************) -(* these 3 functions strip off parameters and destruct constructors *) -(* -fun strip_cpair (Const("Cfun.Rep_CFun",_) $ - (Const("Cfun.Rep_CFun",_) $ Const("Cprod.cpair",_) $ b) $ r) = - b :: strip_cpair r - | strip_cpair c = [c]; -*) -fun big_lambda v rhs = %%:"Cfun.Abs_CFun"$(lambda v rhs); - -fun big_lambdas [] rhs = rhs - | big_lambdas (v::vs) rhs = big_lambda v (big_lambdas vs rhs); +fun add_names (Const(a,_), bs) = Sign.base_name a ins_string bs + | add_names (Free(a,_) , bs) = a ins_string bs + | add_names (f $ u , bs) = add_names (f, add_names(u, bs)) + | add_names (Abs(a,_,t), bs) = add_names (t, a ins_string bs) + | add_names (_ , bs) = bs; -(* builds a big lamdba expression with a tuple *) -fun lambda_tuple [] rhs = big_lambda (%:"unit") rhs - | lambda_tuple [v] rhs = big_lambda v rhs - | lambda_tuple (v::vs) rhs = - %%:"Cprod.csplit"`(big_lambda v (lambda_tuple vs rhs)); - -fun add_names (Const(a,_), bs) = NameSpace.base a ins_string bs - | add_names (Free(a,_), bs) = a ins_string bs - | add_names (f$u, bs) = add_names (f, add_names(u, bs)) - | add_names (Abs(a,_,t), bs) = add_names(t,a ins_string bs) - | add_names (_, bs) = bs; fun add_terms ts xs = foldr add_names xs ts; (* builds a monadic term for matching a constructor pattern *) @@ -104,8 +141,8 @@ fun result_type (Type("Cfun.->",[_,T])) (x::xs) = result_type T xs | result_type T _ = T; val v = Free(n, result_type T vs); - val m = "match_"^(extern_name(NameSpace.base c)); - val k = lambda_tuple vs rhs; + val m = "match_"^(extern_name(Sign.base_name c)); + val k = lambda_ctuple vs rhs; in (%%:"Fixrec.bind"`(%%:m`v)`k, v, n::taken) end; @@ -159,14 +196,14 @@ else sys_error "FIXREC: all equations must have the same arity"; val rhs = fatbar arity mats; in - HOLogic.mk_Trueprop (%%:cname === rhs) + mk_trp (%%:cname === rhs) end; (*************************************************************************) (********************** Proving associated theorems **********************) (*************************************************************************) -fun prove_thm thy unfold_thm ct = +fun prove_rew thy unfold_thm ct = let val ss = simpset_of thy; val thm = prove_goalw_cterm [] ct @@ -175,18 +212,24 @@ in thm end; (* this proves that each equation is a theorem *) -fun prove_list thy unfold_thm [] = [] - | prove_list thy unfold_thm (x::xs) = - prove_thm thy unfold_thm x :: prove_list thy unfold_thm xs; +fun prove_rews thy (unfold_thm,cts) = map (prove_rew thy unfold_thm) cts; + +(* proves the pattern matching equations as theorems, using unfold *) +fun make_simps names unfold_thms ctss thy = + let + val thm_names = map (fn name => name^"_rews") names; + val rew_thmss = ListPair.map (prove_rews thy) (unfold_thms, ctss); + val thmss = ListPair.zip (thm_names, rew_thmss); + in + (#1 o PureThy.add_thmss (map Thm.no_attributes thmss)) thy + end; (* this proves the def without fix is a theorem, this uses the fixpoint def *) +(* fun make_simp name eqs ct fixdef_thm thy' = let - val basename = NameSpace.base name; val ss = simpset_of thy'; val eq_thm = fixdef_thm RS fix_eq2; - val unfold_thm = prove_goalw_cterm [] ct - (fn _ => [(rtac (eq_thm RS trans) 1) THEN (simp_tac ss 1)]); val ind_thm = fixdef_thm RS def_fix_ind; val rew_thms = prove_list thy' unfold_thm eqs; val thmss = @@ -196,25 +239,21 @@ in (#1 o PureThy.add_thmss (map Thm.no_attributes thmss)) thy' end; - +*) (*************************************************************************) (************************* Main fixrec function **************************) (*************************************************************************) (* this calls the main processing function and then returns the new state *) -fun add_fixrec strs thy = +fun add_fixrec strss thy = let val sg = sign_of thy; - val cts = map (Thm.read_cterm sg o rpair propT) strs; - val eqs = map term_of cts; - val funcc = infer (compile_pats eqs) thy; - val _ = print_cterm (cterm_of sg funcc); - val (name', fixdef_name_term) = func1 (dest_eqs funcc); - val (thy', [fixdef_thm]) = - PureThy.add_defs_i false [Thm.no_attributes fixdef_name_term] thy; - val ct = cterm_of (sign_of thy') funcc; + val ctss = map (map (Thm.read_cterm sg o rpair propT)) strss; + val tss = map (map term_of) ctss; + val ts' = map (fn ts => infer sg (compile_pats ts)) tss; + val (thy', names, fixdef_thms, unfold_thms) = add_fixdefs ts' thy; in - make_simp name' cts ct fixdef_thm thy' + make_simps names unfold_thms ctss thy' end; (*************************************************************************) @@ -226,7 +265,7 @@ val sign = sign_of thy; val t = term_of (Thm.read_cterm sign (pat, dummyT)); val T = fastype_of t; - val eq = HOLogic.mk_Trueprop (HOLogic.eq_const T $ t $ Var (("x",0),T)); + val eq = mk_trp (HOLogic.eq_const T $ t $ Var (("x",0),T)); fun head_const (Const ("Cfun.Rep_CFun",_) $ f $ t) = head_const f | head_const (Const (c,_)) = c | head_const _ = sys_error "FIXPAT: function is not declared as constant in theory"; @@ -247,7 +286,7 @@ local structure P = OuterParse and K = OuterSyntax.Keyword in -val fixrec_decl = (*P.and_list1*) (Scan.repeat1 P.prop); +val fixrec_decl = P.and_list1 (Scan.repeat1 P.prop); (* this builds a parser for a new keyword, fixrec, whose functionality is defined by add_fixrec *)