--- a/src/HOLCF/fixrec_package.ML Thu Jun 23 21:17:26 2005 +0200
+++ b/src/HOLCF/fixrec_package.ML Thu Jun 23 21:27:23 2005 +0200
@@ -7,15 +7,23 @@
signature FIXREC_PACKAGE =
sig
- val add_fixrec: ((string * Attrib.src list) * string) list list -> theory -> theory
- val add_fixrec_i: ((string * theory attribute list) * term) list list -> theory -> theory
- val add_fixpat: ((string * Attrib.src list) * string) list -> theory -> theory
- val add_fixpat_i: ((string * theory attribute list) * term) list -> theory -> theory
+ val add_fixrec: bool -> ((string * Attrib.src list) * string) list list
+ -> theory -> theory
+ val add_fixrec_i: bool -> ((string * theory attribute list) * term) list list
+ -> theory -> theory
+ val add_fixpat: (string * Attrib.src list) * string list
+ -> theory -> theory
+ val add_fixpat_i: (string * theory attribute list) * term list
+ -> theory -> theory
end;
structure FixrecPackage: FIXREC_PACKAGE =
struct
+fun fixrec_err s = error ("fixrec definition error:\n" ^ s);
+fun fixrec_eq_err sign s eq =
+ fixrec_err (s ^ "\nin\n" ^ quote (Sign.string_of_term sign eq));
+
(* ->> is taken from holcf_logic.ML *)
(* TODO: fix dependencies so we can import HOLCFLogic here *)
infixr 6 ->>;
@@ -29,10 +37,12 @@
val mk_trp = HOLogic.mk_Trueprop;
(* splits a cterm into the right and lefthand sides of equality *)
+fun dest_eqs t = HOLogic.dest_eq (HOLogic.dest_Trueprop t);
+(*
fun dest_eqs (Const ("==", _)$lhs$rhs) = (lhs, rhs)
| dest_eqs (Const ("Trueprop", _)$(Const ("op =", _)$lhs$rhs)) = (lhs,rhs)
- | dest_eqs t = sys_error (Sign.string_of_term (sign_of (the_context())) t);
-
+ | dest_eqs t = fixrec_err (Sign.string_of_term (sign_of (the_context())) t);
+*)
(* similar to Thm.head_of, but for continuous application *)
fun chead_of (Const("Cfun.Rep_CFun",_)$f$t) = chead_of f
| chead_of u = u;
@@ -83,8 +93,8 @@
val fixpoint = %%:"Fix.fix"`lambda_ctuple lhss (mk_ctuple rhss);
fun one_def (l as Const(n,T)) r =
- let val b = Sign.base_name n in (b, (b^"_fixdef", l == r)) end
- | one_def _ _ = sys_error "fixdefs: lhs not of correct form";
+ let val b = Sign.base_name n in (b, (b^"_def", l == r)) end
+ | one_def _ _ = fixrec_err "fixdefs: lhs not of correct form";
fun defs [] _ = []
| defs (l::[]) r = [one_def l r]
| defs (l::ls) r = one_def l (%%:"Cprod.cfst"`r) :: defs ls (%%:"Cprod.csnd"`r);
@@ -146,7 +156,9 @@
val k = lambda_ctuple vs rhs;
in
(%%:"Fixrec.bind"`(%%:m`v)`k, v, n::taken)
- end;
+ end
+ | Free(n,_) => fixrec_err ("expected constructor, found free variable " ^ quote n)
+ | _ => fixrec_err "pre_build: invalid pattern";
(* builds a monadic term for matching a function definition pattern *)
(* returns (name, arity, matcher) *)
@@ -157,15 +169,12 @@
| Const("Cfun.Rep_CFun", _)$f$x =>
let val (rhs', v, taken') = pre_build x rhs [] taken;
in building f rhs' (v::vs) taken' end
- | Const(_,_) => (pat, length vs, big_lambdas vs rhs)
- | _ => sys_error "function is not declared as constant in theory";
+ | Const(name,_) => (name, length vs, big_lambdas vs rhs)
+ | _ => fixrec_err "function is not declared as constant in theory";
fun match_eq eq =
- let
- val (lhs,rhs) = dest_eqs eq;
- val (Const(name,_), arity, term) =
- building lhs (%%:"Fixrec.return"`rhs) [] (add_terms [eq] []);
- in (name, arity, term) end;
+ let val (lhs,rhs) = dest_eqs eq;
+ in building lhs (%%:"Fixrec.return"`rhs) [] (add_terms [eq] []) end;
(* returns the sum (using +++) of the terms in ms *)
(* also applies "run" to the result! *)
@@ -173,9 +182,9 @@
let
fun unLAM 0 t = t
| unLAM n (_$Abs(_,_,t)) = unLAM (n-1) t
- | unLAM _ _ = sys_error "FIXREC: internal error, not enough LAMs";
+ | unLAM _ _ = fixrec_err "fatbar: internal error, not enough LAMs";
fun reLAM 0 t = t
- | reLAM n t = reLAM (n-1) (%%:"Abs_CFun" $ Abs("",dummyT,t));
+ | reLAM n t = reLAM (n-1) (%%:"Cfun.Abs_CFun" $ Abs("",dummyT,t));
fun mplus (x,y) = %%:"Fixrec.mplus"`x`y;
val msum = foldr1 mplus (map (unLAM arity) ms);
in
@@ -192,9 +201,9 @@
let
val ((n::names),(a::arities),mats) = unzip3 (map match_eq eqs);
val cname = if forall (fn x => n=x) names then n
- else sys_error "FIXREC: all equations must define the same function";
+ else fixrec_err "all equations in block must define the same function";
val arity = if forall (fn x => a=x) arities then a
- else sys_error "FIXREC: all equations must have the same arity";
+ else fixrec_err "all equations in block must have the same arity";
val rhs = fatbar arity mats;
in
mk_trp (%%:cname === rhs)
@@ -204,45 +213,52 @@
(********************** Proving associated theorems **********************)
(*************************************************************************)
-fun prove_simp thy unfold_thm t =
+(* proves a block of pattern matching equations as theorems, using unfold *)
+fun make_simps thy (unfold_thm, eqns) =
let
- val ss = simpset_of thy;
- val ct = cterm_of thy t;
- val thm = prove_goalw_cterm [] ct
- (fn _ => [rtac (unfold_thm RS ssubst_lhs) 1, simp_tac ss 1]);
- in thm end;
-
-(* this proves that each equation is a theorem *)
-fun prove_simps thy (unfold_thm,ts) = map (prove_simp thy unfold_thm) ts;
-
-(* proves the pattern matching equations as theorems, using unfold *)
-fun make_simps cnames unfold_thms namess attss tss thy =
- let
- val thm_names = map (fn name => name^"_simps") cnames;
- val rew_thmss = map (prove_simps thy) (unfold_thms ~~ tss);
- val thms = (List.concat namess ~~ List.concat rew_thmss) ~~ List.concat attss;
- val (thy',_) = PureThy.add_thms thms thy;
- val thmss = thm_names ~~ rew_thmss;
- val simp_attribute = rpair [Simplifier.simp_add_global];
+ fun tacsf prems =
+ [rtac (unfold_thm RS ssubst_lhs) 1, simp_tac (simpset_of thy addsimps prems) 1];
+ fun prove_term t = prove_goalw_cterm [] (cterm_of thy t) tacsf;
+ fun prove_eqn ((name, eqn_t), atts) = ((name, prove_term eqn_t), atts);
in
- (#1 o PureThy.add_thmss (map simp_attribute thmss)) thy'
+ map prove_eqn eqns
end;
(*************************************************************************)
(************************* Main fixrec function **************************)
(*************************************************************************)
-(* this calls the main processing function and then returns the new state *)
-fun gen_add_fixrec prep_prop prep_attrib blocks thy =
+fun gen_add_fixrec prep_prop prep_attrib strict blocks thy =
let
- fun split_list2 xss = split_list (map split_list xss);
- val ((namess, srcsss), strss) = apfst split_list2 (split_list2 blocks);
- val attss = map (map (map (prep_attrib thy))) srcsss;
- val tss = map (map (prep_prop thy)) strss;
- val ts' = map (infer thy o compile_pats) tss;
- val (thy', cnames, fixdef_thms, unfold_thms) = add_fixdefs ts' thy;
+ val eqns = List.concat blocks;
+ val lengths = map length blocks;
+
+ val sign = sign_of thy;
+ val ((names, srcss), strings) = apfst split_list (split_list eqns);
+ val atts = map (map (prep_attrib thy)) srcss;
+ val eqn_ts = map (prep_prop thy) strings;
+ val rec_ts = map (fn eq => chead_of (fst (dest_eqs (Logic.strip_imp_concl eq)))
+ handle TERM _ => fixrec_eq_err sign "not a proper equation" eq) eqn_ts;
+ val (_, eqn_ts') = InductivePackage.unify_consts sign rec_ts eqn_ts;
+
+ fun unconcat [] _ = []
+ | unconcat (n::ns) xs = List.take (xs,n) :: unconcat ns (List.drop (xs,n));
+ val pattern_blocks = unconcat lengths (map Logic.strip_imp_concl eqn_ts');
+ val compiled_ts = map (infer sign o compile_pats) pattern_blocks;
+ val (thy', cnames, fixdef_thms, unfold_thms) = add_fixdefs compiled_ts thy;
in
- make_simps cnames unfold_thms namess attss tss thy'
+ if strict then let (* only prove simp rules if strict = true *)
+ val eqn_blocks = unconcat lengths ((names ~~ eqn_ts') ~~ atts);
+ val simps = List.concat (map (make_simps thy') (unfold_thms ~~ eqn_blocks));
+ val (thy'', simp_thms) = PureThy.add_thms simps thy';
+
+ val simp_names = map (fn name => name^"_simps") cnames;
+ val simp_attribute = rpair [Simplifier.simp_add_global];
+ val simps' = map simp_attribute (simp_names ~~ unconcat lengths simp_thms);
+ in
+ (#1 o PureThy.add_thmss simps') thy''
+ end
+ else thy'
end;
val add_fixrec = gen_add_fixrec Sign.read_prop Attrib.global_attribute;
@@ -253,25 +269,25 @@
(******************************** Fixpat *********************************)
(*************************************************************************)
-fun fix_pat prep_term thy pat =
+fun fix_pat thy t =
let
- val t = prep_term thy pat;
val T = fastype_of t;
val eq = mk_trp (HOLogic.eq_const T $ t $ Var (("x",0),T));
val cname = case chead_of t of Const(c,_) => c | _ =>
- sys_error "FIXPAT: function is not declared as constant in theory";
+ fixrec_err "function is not declared as constant in theory";
val unfold_thm = PureThy.get_thm thy (Name (cname^"_unfold"));
- val rew = prove_goalw_cterm [] (cterm_of thy eq)
+ val simp = prove_goalw_cterm [] (cterm_of thy eq)
(fn _ => [stac unfold_thm 1, simp_tac (simpset_of thy) 1]);
- in rew end;
+ in simp end;
-fun gen_add_fixpat prep_term prep_attrib pats thy =
+fun gen_add_fixpat prep_term prep_attrib ((name, srcs), strings) thy =
let
- val ((names, srcss), strings) = apfst ListPair.unzip (ListPair.unzip pats);
- val atts = map (map (prep_attrib thy)) srcss;
- val simps = map (fix_pat prep_term thy) strings;
- val (thy', _) = PureThy.add_thms ((names ~~ simps) ~~ atts) thy;
- in thy' end;
+ val atts = map (prep_attrib thy) srcs;
+ val ts = map (prep_term thy) strings;
+ val simps = map (fix_pat thy) ts;
+ in
+ (#1 o PureThy.add_thmss [((name, simps), atts)]) thy
+ end;
val add_fixpat = gen_add_fixpat Sign.read_term Attrib.global_attribute;
val add_fixpat_i = gen_add_fixpat Sign.cert_term (K I);
@@ -283,25 +299,27 @@
local structure P = OuterParse and K = OuterSyntax.Keyword in
-val fixrec_decl = P.and_list1 (Scan.repeat1 (P.opt_thm_name ":" -- P.prop));
+val fixrec_eqn = P.opt_thm_name ":" -- P.prop;
+
+val fixrec_strict =
+ Scan.optional (P.$$$ "(" -- P.!!! (P.$$$ "permissive" -- P.$$$ ")") >> K false) true;
+
+val fixrec_decl = fixrec_strict -- P.and_list1 (Scan.repeat1 fixrec_eqn);
(* this builds a parser for a new keyword, fixrec, whose functionality
is defined by add_fixrec *)
val fixrecP =
OuterSyntax.command "fixrec" "define recursive functions (HOLCF)" K.thy_decl
- (fixrec_decl >> (Toplevel.theory o add_fixrec));
-
-(* this adds the parser for fixrec to the syntax *)
-val _ = OuterSyntax.add_parsers [fixrecP];
+ (fixrec_decl >> (Toplevel.theory o uncurry add_fixrec));
(* fixpat parser *)
-val fixpat_decl = Scan.repeat1 (P.opt_thm_name ":" -- P.prop);
+val fixpat_decl = P.opt_thm_name ":" -- Scan.repeat1 P.prop;
val fixpatP =
OuterSyntax.command "fixpat" "define rewrites for fixrec functions" K.thy_decl
(fixpat_decl >> (Toplevel.theory o add_fixpat));
-val _ = OuterSyntax.add_parsers [fixpatP];
+val _ = OuterSyntax.add_parsers [fixrecP, fixpatP];
end; (* local structure *)