--- a/src/HOLCF/fixrec_package.ML Wed Jun 15 14:59:25 2005 +0200
+++ b/src/HOLCF/fixrec_package.ML Wed Jun 15 20:50:38 2005 +0200
@@ -7,7 +7,7 @@
signature FIXREC_PACKAGE =
sig
- val add_fixrec: string list -> theory -> theory
+ val add_fixrec: string list list -> theory -> theory
val add_fixpat: string * string -> theory -> theory
end;
@@ -27,6 +27,8 @@
("o"::"p"::" "::rest) => implode rest
| _ => con;
+val mk_trp = HOLogic.mk_Trueprop;
+
(* splits a cterm into the right and lefthand sides of equality *)
fun dest_eqs (Const ("==", _)$lhs$rhs) = (lhs, rhs)
| dest_eqs (Const ("Trueprop", _)$(Const ("op =", _)$lhs$rhs)) = (lhs,rhs)
@@ -37,57 +39,92 @@
fun %%: s = Const(s,dummyT);
infix 0 ==; fun S == T = %%:"==" $ S $ T;
infix 1 ===; fun S === T = %%:"op =" $ S $ T;
-infix 9 ` ; fun f ` x = %%:"Rep_CFun" $ f $ x;
+infix 9 ` ; fun f ` x = %%:"Cfun.Rep_CFun" $ f $ x;
(* infers the type of a term *)
-fun infer t thy = #1 (Sign.infer_types (Sign.pp (sign_of thy)) (sign_of thy) (K NONE) (K NONE) [] true ([t],dummyT));
+(* similar to Theory.inferT_axm, but allows any type *)
+fun infer sg t =
+ fst (Sign.infer_types (Sign.pp sg) sg (K NONE) (K NONE) [] true ([t],dummyT));
+
+(* The next few functions build continuous lambda abstractions *)
+
+(* Similar to Term.lambda, but allows abstraction over constants *)
+fun lambda' (v as Free (x, T)) t = Abs (x, T, abstract_over (v, t))
+ | lambda' (v as Var ((x, _), T)) t = Abs (x, T, abstract_over (v, t))
+ | lambda' (v as Const (x, T)) t = Abs (Sign.base_name x, T, abstract_over (v, t))
+ | lambda' v t = raise TERM ("lambda'", [v, t]);
+
+(* builds the expression (LAM v. rhs) *)
+fun big_lambda v rhs = %%:"Cfun.Abs_CFun"$(lambda' v rhs);
+
+(* builds the expression (LAM v1 v2 .. vn. rhs) *)
+fun big_lambdas [] rhs = rhs
+ | big_lambdas (v::vs) rhs = big_lambda v (big_lambdas vs rhs);
+
+(* builds the expression (LAM <v1,v2,..,vn>. rhs) *)
+fun lambda_ctuple [] rhs = big_lambda (%:"unit") rhs
+ | lambda_ctuple (v::[]) rhs = big_lambda v rhs
+ | lambda_ctuple (v::vs) rhs =
+ %%:"Cprod.csplit"`(big_lambda v (lambda_ctuple vs rhs));
+
+(* builds the expression <v1,v2,..,vn> *)
+fun mk_ctuple [] = %%:"UU"
+| mk_ctuple (t::[]) = t
+| mk_ctuple (t::ts) = %%:"Cprod.cpair"`t`(mk_ctuple ts);
(*************************************************************************)
-(************ fixed-point definitions and unfolding theorems *************)
+(************************ fixed-point definitions ************************)
(*************************************************************************)
-fun func1 (lhs as Const(name,T), rhs) =
+fun add_fixdefs eqs thy =
let
- val basename = Sign.base_name name;
- val funcT = T ->> T;
- val functional = Const ("Cfun.Abs_CFun", (T --> T) --> funcT) $
- Abs (basename, T, abstract_over (lhs,rhs));
- val fix_type = funcT ->> T;
- val fix_const = Const ("Fix.fix", fix_type);
- val func_type = fix_type --> funcT --> T;
- val rhs' = Const ("Cfun.Rep_CFun",func_type)$fix_const$functional;
+ val (lhss,rhss) = ListPair.unzip (map dest_eqs eqs);
+ val fixpoint = %%:"Fix.fix"`lambda_ctuple lhss (mk_ctuple rhss);
+
+ fun one_def (l as Const(n,T)) r =
+ let val b = Sign.base_name n in (b, (b^"_fixdef", l == r)) end
+ | one_def _ _ = sys_error "fixdefs: lhs not of correct form";
+ fun defs [] _ = []
+ | defs (l::[]) r = [one_def l r]
+ | defs (l::ls) r = one_def l (%%:"Cprod.cfst"`r) :: defs ls (%%:"Cprod.csnd"`r);
+ val (names, pre_fixdefs) = ListPair.unzip (defs lhss fixpoint);
+
+ val fixdefs = map (inferT_axm (sign_of thy)) pre_fixdefs;
+ val (thy', fixdef_thms) =
+ PureThy.add_defs_i false (map Thm.no_attributes fixdefs) thy;
+ val ctuple_fixdef_thm = foldr1 (fn (x,y) => cpair_equalI OF [x,y]) fixdef_thms;
+
+ fun mk_cterm t = let val sg' = sign_of thy' in cterm_of sg' (infer sg' t) end;
+ val ctuple_unfold_ct = mk_cterm (mk_trp (mk_ctuple lhss === mk_ctuple rhss));
+ val ctuple_unfold_thm = prove_goalw_cterm [] ctuple_unfold_ct
+ (fn _ => [rtac (ctuple_fixdef_thm RS fix_eq2 RS trans) 1,
+ simp_tac (simpset_of thy') 1]);
+ val ctuple_induct_thm =
+ (space_implode "_" names ^ "_induct", [ctuple_fixdef_thm RS def_fix_ind]);
+
+ fun unfolds [] thm = []
+ | unfolds (n::[]) thm = [(n^"_unfold", [thm])]
+ | unfolds (n::ns) thm = let
+ val thmL = thm RS cpair_eqD1;
+ val thmR = thm RS cpair_eqD2;
+ in (n^"_unfold", [thmL]) :: unfolds ns thmR end;
+ val unfold_thmss = unfolds names ctuple_unfold_thm;
+ val thmss = ctuple_induct_thm :: unfold_thmss;
+ val (thy'', _) = PureThy.add_thmss (map Thm.no_attributes thmss) thy';
in
- (name, (basename^"_fixdef", equals T $ lhs $ rhs'))
- end
- | func1 t = sys_error "func1: not of correct form";
+ (thy'', names, fixdef_thms, List.concat (map snd unfold_thmss))
+ end;
(*************************************************************************)
(*********** monadic notation and pattern matching compilation ***********)
(*************************************************************************)
-(* these 3 functions strip off parameters and destruct constructors *)
-(*
-fun strip_cpair (Const("Cfun.Rep_CFun",_) $
- (Const("Cfun.Rep_CFun",_) $ Const("Cprod.cpair",_) $ b) $ r) =
- b :: strip_cpair r
- | strip_cpair c = [c];
-*)
-fun big_lambda v rhs = %%:"Cfun.Abs_CFun"$(lambda v rhs);
-
-fun big_lambdas [] rhs = rhs
- | big_lambdas (v::vs) rhs = big_lambda v (big_lambdas vs rhs);
+fun add_names (Const(a,_), bs) = Sign.base_name a ins_string bs
+ | add_names (Free(a,_) , bs) = a ins_string bs
+ | add_names (f $ u , bs) = add_names (f, add_names(u, bs))
+ | add_names (Abs(a,_,t), bs) = add_names (t, a ins_string bs)
+ | add_names (_ , bs) = bs;
-(* builds a big lamdba expression with a tuple *)
-fun lambda_tuple [] rhs = big_lambda (%:"unit") rhs
- | lambda_tuple [v] rhs = big_lambda v rhs
- | lambda_tuple (v::vs) rhs =
- %%:"Cprod.csplit"`(big_lambda v (lambda_tuple vs rhs));
-
-fun add_names (Const(a,_), bs) = NameSpace.base a ins_string bs
- | add_names (Free(a,_), bs) = a ins_string bs
- | add_names (f$u, bs) = add_names (f, add_names(u, bs))
- | add_names (Abs(a,_,t), bs) = add_names(t,a ins_string bs)
- | add_names (_, bs) = bs;
fun add_terms ts xs = foldr add_names xs ts;
(* builds a monadic term for matching a constructor pattern *)
@@ -104,8 +141,8 @@
fun result_type (Type("Cfun.->",[_,T])) (x::xs) = result_type T xs
| result_type T _ = T;
val v = Free(n, result_type T vs);
- val m = "match_"^(extern_name(NameSpace.base c));
- val k = lambda_tuple vs rhs;
+ val m = "match_"^(extern_name(Sign.base_name c));
+ val k = lambda_ctuple vs rhs;
in
(%%:"Fixrec.bind"`(%%:m`v)`k, v, n::taken)
end;
@@ -159,14 +196,14 @@
else sys_error "FIXREC: all equations must have the same arity";
val rhs = fatbar arity mats;
in
- HOLogic.mk_Trueprop (%%:cname === rhs)
+ mk_trp (%%:cname === rhs)
end;
(*************************************************************************)
(********************** Proving associated theorems **********************)
(*************************************************************************)
-fun prove_thm thy unfold_thm ct =
+fun prove_rew thy unfold_thm ct =
let
val ss = simpset_of thy;
val thm = prove_goalw_cterm [] ct
@@ -175,18 +212,24 @@
in thm end;
(* this proves that each equation is a theorem *)
-fun prove_list thy unfold_thm [] = []
- | prove_list thy unfold_thm (x::xs) =
- prove_thm thy unfold_thm x :: prove_list thy unfold_thm xs;
+fun prove_rews thy (unfold_thm,cts) = map (prove_rew thy unfold_thm) cts;
+
+(* proves the pattern matching equations as theorems, using unfold *)
+fun make_simps names unfold_thms ctss thy =
+ let
+ val thm_names = map (fn name => name^"_rews") names;
+ val rew_thmss = ListPair.map (prove_rews thy) (unfold_thms, ctss);
+ val thmss = ListPair.zip (thm_names, rew_thmss);
+ in
+ (#1 o PureThy.add_thmss (map Thm.no_attributes thmss)) thy
+ end;
(* this proves the def without fix is a theorem, this uses the fixpoint def *)
+(*
fun make_simp name eqs ct fixdef_thm thy' =
let
- val basename = NameSpace.base name;
val ss = simpset_of thy';
val eq_thm = fixdef_thm RS fix_eq2;
- val unfold_thm = prove_goalw_cterm [] ct
- (fn _ => [(rtac (eq_thm RS trans) 1) THEN (simp_tac ss 1)]);
val ind_thm = fixdef_thm RS def_fix_ind;
val rew_thms = prove_list thy' unfold_thm eqs;
val thmss =
@@ -196,25 +239,21 @@
in
(#1 o PureThy.add_thmss (map Thm.no_attributes thmss)) thy'
end;
-
+*)
(*************************************************************************)
(************************* Main fixrec function **************************)
(*************************************************************************)
(* this calls the main processing function and then returns the new state *)
-fun add_fixrec strs thy =
+fun add_fixrec strss thy =
let
val sg = sign_of thy;
- val cts = map (Thm.read_cterm sg o rpair propT) strs;
- val eqs = map term_of cts;
- val funcc = infer (compile_pats eqs) thy;
- val _ = print_cterm (cterm_of sg funcc);
- val (name', fixdef_name_term) = func1 (dest_eqs funcc);
- val (thy', [fixdef_thm]) =
- PureThy.add_defs_i false [Thm.no_attributes fixdef_name_term] thy;
- val ct = cterm_of (sign_of thy') funcc;
+ val ctss = map (map (Thm.read_cterm sg o rpair propT)) strss;
+ val tss = map (map term_of) ctss;
+ val ts' = map (fn ts => infer sg (compile_pats ts)) tss;
+ val (thy', names, fixdef_thms, unfold_thms) = add_fixdefs ts' thy;
in
- make_simp name' cts ct fixdef_thm thy'
+ make_simps names unfold_thms ctss thy'
end;
(*************************************************************************)
@@ -226,7 +265,7 @@
val sign = sign_of thy;
val t = term_of (Thm.read_cterm sign (pat, dummyT));
val T = fastype_of t;
- val eq = HOLogic.mk_Trueprop (HOLogic.eq_const T $ t $ Var (("x",0),T));
+ val eq = mk_trp (HOLogic.eq_const T $ t $ Var (("x",0),T));
fun head_const (Const ("Cfun.Rep_CFun",_) $ f $ t) = head_const f
| head_const (Const (c,_)) = c
| head_const _ = sys_error "FIXPAT: function is not declared as constant in theory";
@@ -247,7 +286,7 @@
local structure P = OuterParse and K = OuterSyntax.Keyword in
-val fixrec_decl = (*P.and_list1*) (Scan.repeat1 P.prop);
+val fixrec_decl = P.and_list1 (Scan.repeat1 P.prop);
(* this builds a parser for a new keyword, fixrec, whose functionality
is defined by add_fixrec *)