--- a/src/HOLCF/Tools/fixrec_package.ML Fri Feb 27 18:34:20 2009 -0800
+++ b/src/HOLCF/Tools/fixrec_package.ML Fri Feb 27 19:05:46 2009 -0800
@@ -9,8 +9,12 @@
val legacy_infer_term: theory -> term -> term
val legacy_infer_prop: theory -> term -> term
- val add_fixrec: bool -> (Attrib.binding * string) list list -> theory -> theory
- val add_fixrec_i: bool -> ((binding * attribute list) * term) list list -> theory -> theory
+ val add_fixrec: bool -> (binding * string option * mixfix) list
+ -> (Attrib.binding * string) list -> local_theory -> local_theory
+
+ val add_fixrec_i: bool -> (binding * typ option * mixfix) list
+ -> (Attrib.binding * term) list -> local_theory -> local_theory
+
val add_fixpat: Attrib.binding * string list -> theory -> theory
val add_fixpat_i: (binding * attribute list) * term list -> theory -> theory
val add_matchers: (string * string) list -> theory -> theory
@@ -166,30 +170,34 @@
(************* fixed-point definitions and unfolding theorems ************)
(*************************************************************************)
-fun add_fixdefs eqs thy =
+fun add_fixdefs
+ (fixes : ((binding * typ) * mixfix) list)
+ (spec : (Attrib.binding * term) list)
+ (lthy : local_theory) =
let
- val (lhss,rhss) = ListPair.unzip (map dest_eqs eqs);
+ val names = map (Binding.base_name o fst o fst) fixes;
+ val all_names = space_implode "_" names;
+ val (lhss,rhss) = ListPair.unzip (map (dest_eqs o snd) spec);
val fixpoint = mk_fix (lambda_ctuple lhss (mk_ctuple rhss));
- fun one_def (l as Const(n,T)) r =
- let val b = Sign.base_name n in (b, (b^"_def", l == r)) end
+ fun one_def (l as Free(n,_)) r =
+ let val b = Sign.base_name n
+ in ((Binding.name (b^"_def"), []), r) end
| one_def _ _ = fixrec_err "fixdefs: lhs not of correct form";
fun defs [] _ = []
| defs (l::[]) r = [one_def l r]
| defs (l::ls) r = one_def l (mk_cfst r) :: defs ls (mk_csnd r);
- val (names, fixdefs) = ListPair.unzip (defs lhss fixpoint);
-
- val (fixdef_thms, thy') =
- PureThy.add_defs false (map (Thm.no_attributes o apfst Binding.name) fixdefs) thy;
- val ctuple_fixdef_thm = foldr1 (fn (x,y) => @{thm cpair_equalI} OF [x,y]) fixdef_thms;
-
- val ctuple_unfold = mk_trp (mk_ctuple lhss === mk_ctuple rhss);
- val ctuple_unfold_thm = Goal.prove_global thy' [] [] ctuple_unfold
- (fn _ => EVERY [rtac (ctuple_fixdef_thm RS fix_eq2 RS trans) 1,
- simp_tac (simpset_of thy') 1]);
- val ctuple_induct_thm =
- (space_implode "_" names ^ "_induct", ctuple_fixdef_thm RS def_fix_ind);
-
+ val fixdefs = defs lhss fixpoint;
+ val define_all = fold_map (LocalTheory.define Thm.definitionK);
+ val (fixdef_thms : (term * (string * thm)) list, lthy') = lthy
+ |> define_all (map (apfst fst) fixes ~~ fixdefs);
+ fun cpair_equalI (thm1, thm2) = @{thm cpair_equalI} OF [thm1, thm2];
+ val ctuple_fixdef_thm = foldr1 cpair_equalI (map (snd o snd) fixdef_thms);
+ val ctuple_induct_thm = ctuple_fixdef_thm RS def_fix_ind;
+ val ctuple_unfold_thm =
+ Goal.prove lthy' [] [] (mk_trp (mk_ctuple lhss === mk_ctuple rhss))
+ (fn _ => EVERY [rtac (ctuple_fixdef_thm RS fix_eq2 RS trans) 1,
+ simp_tac (local_simpset_of lthy') 1]);
fun unfolds [] thm = []
| unfolds (n::[]) thm = [(n^"_unfold", thm)]
| unfolds (n::ns) thm = let
@@ -197,10 +205,12 @@
val thmR = thm RS @{thm cpair_eqD2};
in (n^"_unfold", thmL) :: unfolds ns thmR end;
val unfold_thms = unfolds names ctuple_unfold_thm;
- val thms = ctuple_induct_thm :: unfold_thms;
- val (_, thy'') = PureThy.add_thms (map (Thm.no_attributes o apfst Binding.name) thms) thy';
+ fun mk_note (n, thm) = ((Binding.name n, []), [thm]);
+ val (thmss, lthy'') = lthy'
+ |> fold_map (LocalTheory.note Thm.theoremK o mk_note)
+ ((all_names ^ "_induct", ctuple_induct_thm) :: unfold_thms);
in
- (thy'', names, fixdef_thms, map snd unfold_thms)
+ (lthy'', names, fixdef_thms, map snd unfold_thms)
end;
(*************************************************************************)
@@ -260,11 +270,17 @@
| Const(@{const_name Rep_CFun}, _)$f$x =>
let val (rhs', v, taken') = pre_build match_name x rhs [] taken;
in building match_name f rhs' (v::vs) taken' end
- | Const(name,_) => (pat, length vs, big_lambdas vs rhs)
- | _ => fixrec_err "function is not declared as constant in theory";
+ | Free(_,_) => ((pat, length vs), big_lambdas vs rhs)
+ | Const(_,_) => ((pat, length vs), big_lambdas vs rhs)
+ | _ => fixrec_err ("function is not declared as constant in theory: "
+ ^ ML_Syntax.print_term pat);
-fun match_eq match_name eq =
- let val (lhs,rhs) = dest_eqs eq;
+fun strip_alls t =
+ if Logic.is_all t then strip_alls (snd (Logic.dest_all t)) else t;
+
+fun match_eq match_name eq =
+ let
+ val (lhs,rhs) = dest_eqs (Logic.strip_imp_concl (strip_alls eq));
in
building match_name lhs (mk_return rhs) [] (taken_names eq)
end;
@@ -288,15 +304,11 @@
reLAM (rev Ts, dest_maybeT U) (mk_run msum)
end;
-fun unzip3 [] = ([],[],[])
- | unzip3 ((x,y,z)::ts) =
- let val (xs,ys,zs) = unzip3 ts
- in (x::xs, y::ys, z::zs) end;
-
(* this is the pattern-matching compiler function *)
-fun compile_pats match_name eqs =
+fun compile_pats match_name eqs =
let
- val ((n::names),(a::arities),mats) = unzip3 (map (match_eq match_name) eqs);
+ val (((n::names),(a::arities)),mats) =
+ apfst ListPair.unzip (ListPair.unzip (map (match_eq match_name) eqs));
val cname = if forall (fn x => n=x) names then n
else fixrec_err "all equations in block must define the same function";
val arity = if forall (fn x => a=x) arities then a
@@ -311,11 +323,13 @@
(*************************************************************************)
(* proves a block of pattern matching equations as theorems, using unfold *)
-fun make_simps thy (unfold_thm, eqns) =
+fun make_simps lthy (unfold_thm, eqns : (Attrib.binding * term) list) =
let
- val tacs = [rtac (unfold_thm RS @{thm ssubst_lhs}) 1, asm_simp_tac (simpset_of thy) 1];
- fun prove_term t = Goal.prove_global thy [] [] t (K (EVERY tacs));
- fun prove_eqn ((name, eqn_t), atts) = ((name, prove_term eqn_t), atts);
+ val tacs =
+ [rtac (unfold_thm RS @{thm ssubst_lhs}) 1,
+ asm_simp_tac (local_simpset_of lthy) 1];
+ fun prove_term t = Goal.prove lthy [] [] t (K (EVERY tacs));
+ fun prove_eqn (bind, eqn_t) = (bind, prove_term eqn_t);
in
map prove_eqn eqns
end;
@@ -324,48 +338,77 @@
(************************* Main fixrec function **************************)
(*************************************************************************)
-fun gen_add_fixrec prep_prop prep_attrib strict blocks thy =
+local
+(* code adapted from HOL/Tools/primrec_package.ML *)
+
+fun prepare_spec prep_spec ctxt raw_fixes raw_spec =
+ let
+ val ((fixes, spec), _) = prep_spec
+ raw_fixes (map (single o apsnd single) raw_spec) ctxt
+ in (fixes, map (apsnd the_single) spec) end;
+
+fun gen_fixrec
+ (set_group : bool)
+ (prep_spec : (binding * 'a option * mixfix) list ->
+ (Attrib.binding * 'b list) list list ->
+ Proof.context ->
+ (((binding * typ) * mixfix) list * (Attrib.binding * term list) list)
+ * Proof.context
+ )
+ (strict : bool)
+ raw_fixes
+ raw_spec
+ (lthy : local_theory) =
let
- val eqns = List.concat blocks;
- val lengths = map length blocks;
-
- val ((bindings, srcss), strings) = apfst split_list (split_list eqns);
- val names = map Binding.base_name bindings;
- val atts = map (map (prep_attrib thy)) srcss;
- val eqn_ts = map (prep_prop thy) strings;
- val rec_ts = map (fn eq => chead_of (fst (dest_eqs (Logic.strip_imp_concl eq)))
- handle TERM _ => fixrec_eq_err thy "not a proper equation" eq) eqn_ts;
- val (_, eqn_ts') = OldPrimrecPackage.unify_consts thy rec_ts eqn_ts;
-
- fun unconcat [] _ = []
- | unconcat (n::ns) xs = List.take (xs,n) :: unconcat ns (List.drop (xs,n));
- val matcher_tab = FixrecMatchData.get thy;
+ val (fixes : ((binding * typ) * mixfix) list,
+ spec : (Attrib.binding * term) list) =
+ prepare_spec prep_spec lthy raw_fixes raw_spec;
+ val chead_of_spec =
+ chead_of o fst o dest_eqs o Logic.strip_imp_concl o strip_alls o snd;
+ fun name_of (Free (n, _)) = n
+ | name_of t = fixrec_err ("unknown term");
+ val all_names = map (name_of o chead_of_spec) spec;
+ val names = distinct (op =) all_names;
+ fun block_of_name n =
+ map_filter
+ (fn (m,eq) => if m = n then SOME eq else NONE)
+ (all_names ~~ spec);
+ val blocks = map block_of_name names;
+
+ val matcher_tab = FixrecMatchData.get (ProofContext.theory_of lthy);
fun match_name c =
- case Symtab.lookup matcher_tab c of SOME m => m
- | NONE => fixrec_err ("unknown pattern constructor: " ^ c);
+ case Symtab.lookup matcher_tab c of SOME m => m
+ | NONE => fixrec_err ("unknown pattern constructor: " ^ c);
- val pattern_blocks = unconcat lengths (map Logic.strip_imp_concl eqn_ts');
- val compiled_ts =
- map (compile_pats match_name) pattern_blocks;
- val (thy', cnames, fixdef_thms, unfold_thms) = add_fixdefs compiled_ts thy;
+ val matches = map (compile_pats match_name) (map (map snd) blocks);
+ val spec' = map (pair Attrib.empty_binding) matches;
+ val (lthy', cnames, fixdef_thms, unfold_thms) =
+ add_fixdefs fixes spec' lthy;
in
if strict then let (* only prove simp rules if strict = true *)
- val eqn_blocks = unconcat lengths ((names ~~ eqn_ts') ~~ atts);
- val simps = maps (make_simps thy') (unfold_thms ~~ eqn_blocks);
- val (simp_thms, thy'') = PureThy.add_thms ((map o apfst o apfst) Binding.name simps) thy';
-
- val simp_names = map (fn name => name^"_simps") cnames;
- val simp_attribute = rpair [Simplifier.simp_add];
- val simps' = map simp_attribute (simp_names ~~ unconcat lengths simp_thms);
+ val simps : (Attrib.binding * thm) list list =
+ map (make_simps lthy') (unfold_thms ~~ blocks);
+ fun mk_bind n : Attrib.binding =
+ (Binding.name (n ^ "_simps"),
+ [Attrib.internal (K Simplifier.simp_add)]);
+ val simps1 : (Attrib.binding * thm list) list =
+ map (fn (n,xs) => (mk_bind n, map snd xs)) (names ~~ simps);
+ val simps2 : (Attrib.binding * thm list) list =
+ map (apsnd (fn thm => [thm])) (List.concat simps);
+ val (_, lthy'') = lthy'
+ |> fold_map (LocalTheory.note Thm.theoremK) (simps1 @ simps2);
in
- (snd o PureThy.add_thmss ((map o apfst o apfst) Binding.name simps')) thy''
+ lthy''
end
- else thy'
+ else lthy'
end;
-val add_fixrec = gen_add_fixrec Syntax.read_prop_global Attrib.attribute;
-val add_fixrec_i = gen_add_fixrec Sign.cert_prop (K I);
+in
+val add_fixrec_i = gen_fixrec false Specification.check_specification;
+val add_fixrec = gen_fixrec true Specification.read_specification;
+
+end; (* local *)
(*************************************************************************)
(******************************** Fixpat *********************************)
@@ -401,17 +444,34 @@
local structure P = OuterParse and K = OuterKeyword in
-val fixrec_eqn = SpecParse.opt_thm_name ":" -- P.prop;
-
+(* bool parser *)
val fixrec_strict = P.opt_keyword "permissive" >> not;
-val fixrec_decl = fixrec_strict -- P.and_list1 (Scan.repeat1 fixrec_eqn);
+fun pipe_error t = P.!!! (Scan.fail_with (K
+ (cat_lines ["Equations must be separated by " ^ quote "|", quote t])));
+
+(* (Attrib.binding * string) parser *)
+val statement = SpecParse.opt_thm_name ":" -- P.prop --| Scan.ahead
+ ((P.term :-- pipe_error) || Scan.succeed ("",""));
+
+(* ((Attrib.binding * string) list) parser *)
+val statements = P.enum1 "|" statement;
+
+(* (((xstring option * bool) * (Binding.binding * string option * Mixfix.mixfix) list)
+ * (Attrib.binding * string) list) parser *)
+val fixrec_decl =
+ P.opt_target -- fixrec_strict -- P.fixes --| P.$$$ "where" -- statements;
(* this builds a parser for a new keyword, fixrec, whose functionality
is defined by add_fixrec *)
val _ =
- OuterSyntax.command "fixrec" "define recursive functions (HOLCF)" K.thy_decl
- (fixrec_decl >> (Toplevel.theory o uncurry add_fixrec));
+ let
+ val desc = "define recursive functions (HOLCF)";
+ fun fixrec (((opt_target, strict), raw_fixes), raw_spec) =
+ Toplevel.local_theory opt_target (add_fixrec strict raw_fixes raw_spec);
+ in
+ OuterSyntax.command "fixrec" desc K.thy_decl (fixrec_decl >> fixrec)
+ end;
(* fixpat parser *)
val fixpat_decl = SpecParse.opt_thm_name ":" -- Scan.repeat1 P.prop;
@@ -419,7 +479,7 @@
val _ =
OuterSyntax.command "fixpat" "define rewrites for fixrec functions" K.thy_decl
(fixpat_decl >> (Toplevel.theory o add_fixpat));
-
+
end; (* local structure *)
val setup = FixrecMatchData.init;
--- a/src/HOLCF/ex/Fixrec_ex.thy Fri Feb 27 18:34:20 2009 -0800
+++ b/src/HOLCF/ex/Fixrec_ex.thy Fri Feb 27 19:05:46 2009 -0800
@@ -1,5 +1,4 @@
(* Title: HOLCF/ex/Fixrec_ex.thy
- ID: $Id$
Author: Brian Huffman
*)
@@ -19,18 +18,18 @@
text {* typical usage is with lazy constructors *}
-consts down :: "'a u \<rightarrow> 'a"
-fixrec "down\<cdot>(up\<cdot>x) = x"
+fixrec down :: "'a u \<rightarrow> 'a"
+where "down\<cdot>(up\<cdot>x) = x"
text {* with strict constructors, rewrite rules may require side conditions *}
-consts from_sinl :: "'a \<oplus> 'b \<rightarrow> 'a"
-fixrec "x \<noteq> \<bottom> \<Longrightarrow> from_sinl\<cdot>(sinl\<cdot>x) = x"
+fixrec from_sinl :: "'a \<oplus> 'b \<rightarrow> 'a"
+where "x \<noteq> \<bottom> \<Longrightarrow> from_sinl\<cdot>(sinl\<cdot>x) = x"
text {* lifting can turn a strict constructor into a lazy one *}
-consts from_sinl_up :: "'a u \<oplus> 'b \<rightarrow> 'a"
-fixrec "from_sinl_up\<cdot>(sinl\<cdot>(up\<cdot>x)) = x"
+fixrec from_sinl_up :: "'a u \<oplus> 'b \<rightarrow> 'a"
+where "from_sinl_up\<cdot>(sinl\<cdot>(up\<cdot>x)) = x"
subsection {* fixpat examples *}
@@ -41,13 +40,13 @@
text {* zip function for lazy lists *}
-consts lzip :: "'a llist \<rightarrow> 'b llist \<rightarrow> ('a \<times> 'b) llist"
-
text {* notice that the patterns are not exhaustive *}
fixrec
+ lzip :: "'a llist \<rightarrow> 'b llist \<rightarrow> ('a \<times> 'b) llist"
+where
"lzip\<cdot>(lCons\<cdot>x\<cdot>xs)\<cdot>(lCons\<cdot>y\<cdot>ys) = lCons\<cdot><x,y>\<cdot>(lzip\<cdot>xs\<cdot>ys)"
- "lzip\<cdot>lNil\<cdot>lNil = lNil"
+| "lzip\<cdot>lNil\<cdot>lNil = lNil"
text {* fixpat is useful for producing strictness theorems *}
text {* note that pattern matching is done in left-to-right order *}
@@ -68,8 +67,6 @@
text {* another zip function for lazy lists *}
-consts lzip2 :: "'a llist \<rightarrow> 'b llist \<rightarrow> ('a \<times> 'b) llist"
-
text {*
Notice that this version has overlapping patterns.
The second equation cannot be proved as a theorem
@@ -77,8 +74,10 @@
*}
fixrec (permissive)
+ lzip2 :: "'a llist \<rightarrow> 'b llist \<rightarrow> ('a \<times> 'b) llist"
+where
"lzip2\<cdot>(lCons\<cdot>x\<cdot>xs)\<cdot>(lCons\<cdot>y\<cdot>ys) = lCons\<cdot><x,y>\<cdot>(lzip\<cdot>xs\<cdot>ys)"
- "lzip2\<cdot>xs\<cdot>ys = lNil"
+| "lzip2\<cdot>xs\<cdot>ys = lNil"
text {*
Usually fixrec tries to prove all equations as theorems.
@@ -105,21 +104,20 @@
domain 'a tree = Leaf (lazy 'a) | Branch (lazy "'a forest")
and 'a forest = Empty | Trees (lazy "'a tree") "'a forest"
-consts
- map_tree :: "('a \<rightarrow> 'b) \<rightarrow> ('a tree \<rightarrow> 'b tree)"
- map_forest :: "('a \<rightarrow> 'b) \<rightarrow> ('a forest \<rightarrow> 'b forest)"
-
text {*
To define mutually recursive functions, separate the equations
for each function using the keyword "and".
*}
fixrec
- "map_tree\<cdot>f\<cdot>(Leaf\<cdot>x) = Leaf\<cdot>(f\<cdot>x)"
- "map_tree\<cdot>f\<cdot>(Branch\<cdot>ts) = Branch\<cdot>(map_forest\<cdot>f\<cdot>ts)"
+ map_tree :: "('a \<rightarrow> 'b) \<rightarrow> ('a tree \<rightarrow> 'b tree)"
and
- "map_forest\<cdot>f\<cdot>Empty = Empty"
- "ts \<noteq> \<bottom> \<Longrightarrow>
+ map_forest :: "('a \<rightarrow> 'b) \<rightarrow> ('a forest \<rightarrow> 'b forest)"
+where
+ "map_tree\<cdot>f\<cdot>(Leaf\<cdot>x) = Leaf\<cdot>(f\<cdot>x)"
+| "map_tree\<cdot>f\<cdot>(Branch\<cdot>ts) = Branch\<cdot>(map_forest\<cdot>f\<cdot>ts)"
+| "map_forest\<cdot>f\<cdot>Empty = Empty"
+| "ts \<noteq> \<bottom> \<Longrightarrow>
map_forest\<cdot>f\<cdot>(Trees\<cdot>t\<cdot>ts) = Trees\<cdot>(map_tree\<cdot>f\<cdot>t)\<cdot>(map_forest\<cdot>f\<cdot>ts)"
fixpat map_tree_strict [simp]: "map_tree\<cdot>f\<cdot>\<bottom>"