# HG changeset patch # User huffman # Date 1119554843 -7200 # Node ID 0774e9bcdb6c5c741eec3072b5a60ce0384197c1 # Parent 7abf8a713613ec448985689af49ed7bc095e9905 New features: permissive option for fixrec to skip proofs of equations; side conditions for fixrec equations (for definedness); fixpat theorem names apply to entire group of theorems; improved error messages diff -r 7abf8a713613 -r 0774e9bcdb6c src/HOLCF/fixrec_package.ML --- a/src/HOLCF/fixrec_package.ML Thu Jun 23 21:17:26 2005 +0200 +++ b/src/HOLCF/fixrec_package.ML Thu Jun 23 21:27:23 2005 +0200 @@ -7,15 +7,23 @@ signature FIXREC_PACKAGE = sig - val add_fixrec: ((string * Attrib.src list) * string) list list -> theory -> theory - val add_fixrec_i: ((string * theory attribute list) * term) list list -> theory -> theory - val add_fixpat: ((string * Attrib.src list) * string) list -> theory -> theory - val add_fixpat_i: ((string * theory attribute list) * term) list -> theory -> theory + val add_fixrec: bool -> ((string * Attrib.src list) * string) list list + -> theory -> theory + val add_fixrec_i: bool -> ((string * theory attribute list) * term) list list + -> theory -> theory + val add_fixpat: (string * Attrib.src list) * string list + -> theory -> theory + val add_fixpat_i: (string * theory attribute list) * term list + -> theory -> theory end; structure FixrecPackage: FIXREC_PACKAGE = struct +fun fixrec_err s = error ("fixrec definition error:\n" ^ s); +fun fixrec_eq_err sign s eq = + fixrec_err (s ^ "\nin\n" ^ quote (Sign.string_of_term sign eq)); + (* ->> is taken from holcf_logic.ML *) (* TODO: fix dependencies so we can import HOLCFLogic here *) infixr 6 ->>; @@ -29,10 +37,12 @@ val mk_trp = HOLogic.mk_Trueprop; (* splits a cterm into the right and lefthand sides of equality *) +fun dest_eqs t = HOLogic.dest_eq (HOLogic.dest_Trueprop t); +(* fun dest_eqs (Const ("==", _)$lhs$rhs) = (lhs, rhs) | dest_eqs (Const ("Trueprop", _)$(Const ("op =", _)$lhs$rhs)) = (lhs,rhs) - | dest_eqs t = sys_error (Sign.string_of_term (sign_of (the_context())) t); - + | dest_eqs t = fixrec_err (Sign.string_of_term (sign_of (the_context())) t); +*) (* similar to Thm.head_of, but for continuous application *) fun chead_of (Const("Cfun.Rep_CFun",_)$f$t) = chead_of f | chead_of u = u; @@ -83,8 +93,8 @@ val fixpoint = %%:"Fix.fix"`lambda_ctuple lhss (mk_ctuple rhss); fun one_def (l as Const(n,T)) r = - let val b = Sign.base_name n in (b, (b^"_fixdef", l == r)) end - | one_def _ _ = sys_error "fixdefs: lhs not of correct form"; + let val b = Sign.base_name n in (b, (b^"_def", l == r)) end + | one_def _ _ = fixrec_err "fixdefs: lhs not of correct form"; fun defs [] _ = [] | defs (l::[]) r = [one_def l r] | defs (l::ls) r = one_def l (%%:"Cprod.cfst"`r) :: defs ls (%%:"Cprod.csnd"`r); @@ -146,7 +156,9 @@ val k = lambda_ctuple vs rhs; in (%%:"Fixrec.bind"`(%%:m`v)`k, v, n::taken) - end; + end + | Free(n,_) => fixrec_err ("expected constructor, found free variable " ^ quote n) + | _ => fixrec_err "pre_build: invalid pattern"; (* builds a monadic term for matching a function definition pattern *) (* returns (name, arity, matcher) *) @@ -157,15 +169,12 @@ | Const("Cfun.Rep_CFun", _)$f$x => let val (rhs', v, taken') = pre_build x rhs [] taken; in building f rhs' (v::vs) taken' end - | Const(_,_) => (pat, length vs, big_lambdas vs rhs) - | _ => sys_error "function is not declared as constant in theory"; + | Const(name,_) => (name, length vs, big_lambdas vs rhs) + | _ => fixrec_err "function is not declared as constant in theory"; fun match_eq eq = - let - val (lhs,rhs) = dest_eqs eq; - val (Const(name,_), arity, term) = - building lhs (%%:"Fixrec.return"`rhs) [] (add_terms [eq] []); - in (name, arity, term) end; + let val (lhs,rhs) = dest_eqs eq; + in building lhs (%%:"Fixrec.return"`rhs) [] (add_terms [eq] []) end; (* returns the sum (using +++) of the terms in ms *) (* also applies "run" to the result! *) @@ -173,9 +182,9 @@ let fun unLAM 0 t = t | unLAM n (_$Abs(_,_,t)) = unLAM (n-1) t - | unLAM _ _ = sys_error "FIXREC: internal error, not enough LAMs"; + | unLAM _ _ = fixrec_err "fatbar: internal error, not enough LAMs"; fun reLAM 0 t = t - | reLAM n t = reLAM (n-1) (%%:"Abs_CFun" $ Abs("",dummyT,t)); + | reLAM n t = reLAM (n-1) (%%:"Cfun.Abs_CFun" $ Abs("",dummyT,t)); fun mplus (x,y) = %%:"Fixrec.mplus"`x`y; val msum = foldr1 mplus (map (unLAM arity) ms); in @@ -192,9 +201,9 @@ let val ((n::names),(a::arities),mats) = unzip3 (map match_eq eqs); val cname = if forall (fn x => n=x) names then n - else sys_error "FIXREC: all equations must define the same function"; + else fixrec_err "all equations in block must define the same function"; val arity = if forall (fn x => a=x) arities then a - else sys_error "FIXREC: all equations must have the same arity"; + else fixrec_err "all equations in block must have the same arity"; val rhs = fatbar arity mats; in mk_trp (%%:cname === rhs) @@ -204,45 +213,52 @@ (********************** Proving associated theorems **********************) (*************************************************************************) -fun prove_simp thy unfold_thm t = +(* proves a block of pattern matching equations as theorems, using unfold *) +fun make_simps thy (unfold_thm, eqns) = let - val ss = simpset_of thy; - val ct = cterm_of thy t; - val thm = prove_goalw_cterm [] ct - (fn _ => [rtac (unfold_thm RS ssubst_lhs) 1, simp_tac ss 1]); - in thm end; - -(* this proves that each equation is a theorem *) -fun prove_simps thy (unfold_thm,ts) = map (prove_simp thy unfold_thm) ts; - -(* proves the pattern matching equations as theorems, using unfold *) -fun make_simps cnames unfold_thms namess attss tss thy = - let - val thm_names = map (fn name => name^"_simps") cnames; - val rew_thmss = map (prove_simps thy) (unfold_thms ~~ tss); - val thms = (List.concat namess ~~ List.concat rew_thmss) ~~ List.concat attss; - val (thy',_) = PureThy.add_thms thms thy; - val thmss = thm_names ~~ rew_thmss; - val simp_attribute = rpair [Simplifier.simp_add_global]; + fun tacsf prems = + [rtac (unfold_thm RS ssubst_lhs) 1, simp_tac (simpset_of thy addsimps prems) 1]; + fun prove_term t = prove_goalw_cterm [] (cterm_of thy t) tacsf; + fun prove_eqn ((name, eqn_t), atts) = ((name, prove_term eqn_t), atts); in - (#1 o PureThy.add_thmss (map simp_attribute thmss)) thy' + map prove_eqn eqns end; (*************************************************************************) (************************* Main fixrec function **************************) (*************************************************************************) -(* this calls the main processing function and then returns the new state *) -fun gen_add_fixrec prep_prop prep_attrib blocks thy = +fun gen_add_fixrec prep_prop prep_attrib strict blocks thy = let - fun split_list2 xss = split_list (map split_list xss); - val ((namess, srcsss), strss) = apfst split_list2 (split_list2 blocks); - val attss = map (map (map (prep_attrib thy))) srcsss; - val tss = map (map (prep_prop thy)) strss; - val ts' = map (infer thy o compile_pats) tss; - val (thy', cnames, fixdef_thms, unfold_thms) = add_fixdefs ts' thy; + val eqns = List.concat blocks; + val lengths = map length blocks; + + val sign = sign_of thy; + val ((names, srcss), strings) = apfst split_list (split_list eqns); + val atts = map (map (prep_attrib thy)) srcss; + val eqn_ts = map (prep_prop thy) strings; + val rec_ts = map (fn eq => chead_of (fst (dest_eqs (Logic.strip_imp_concl eq))) + handle TERM _ => fixrec_eq_err sign "not a proper equation" eq) eqn_ts; + val (_, eqn_ts') = InductivePackage.unify_consts sign rec_ts eqn_ts; + + fun unconcat [] _ = [] + | unconcat (n::ns) xs = List.take (xs,n) :: unconcat ns (List.drop (xs,n)); + val pattern_blocks = unconcat lengths (map Logic.strip_imp_concl eqn_ts'); + val compiled_ts = map (infer sign o compile_pats) pattern_blocks; + val (thy', cnames, fixdef_thms, unfold_thms) = add_fixdefs compiled_ts thy; in - make_simps cnames unfold_thms namess attss tss thy' + if strict then let (* only prove simp rules if strict = true *) + val eqn_blocks = unconcat lengths ((names ~~ eqn_ts') ~~ atts); + val simps = List.concat (map (make_simps thy') (unfold_thms ~~ eqn_blocks)); + val (thy'', simp_thms) = PureThy.add_thms simps thy'; + + val simp_names = map (fn name => name^"_simps") cnames; + val simp_attribute = rpair [Simplifier.simp_add_global]; + val simps' = map simp_attribute (simp_names ~~ unconcat lengths simp_thms); + in + (#1 o PureThy.add_thmss simps') thy'' + end + else thy' end; val add_fixrec = gen_add_fixrec Sign.read_prop Attrib.global_attribute; @@ -253,25 +269,25 @@ (******************************** Fixpat *********************************) (*************************************************************************) -fun fix_pat prep_term thy pat = +fun fix_pat thy t = let - val t = prep_term thy pat; val T = fastype_of t; val eq = mk_trp (HOLogic.eq_const T $ t $ Var (("x",0),T)); val cname = case chead_of t of Const(c,_) => c | _ => - sys_error "FIXPAT: function is not declared as constant in theory"; + fixrec_err "function is not declared as constant in theory"; val unfold_thm = PureThy.get_thm thy (Name (cname^"_unfold")); - val rew = prove_goalw_cterm [] (cterm_of thy eq) + val simp = prove_goalw_cterm [] (cterm_of thy eq) (fn _ => [stac unfold_thm 1, simp_tac (simpset_of thy) 1]); - in rew end; + in simp end; -fun gen_add_fixpat prep_term prep_attrib pats thy = +fun gen_add_fixpat prep_term prep_attrib ((name, srcs), strings) thy = let - val ((names, srcss), strings) = apfst ListPair.unzip (ListPair.unzip pats); - val atts = map (map (prep_attrib thy)) srcss; - val simps = map (fix_pat prep_term thy) strings; - val (thy', _) = PureThy.add_thms ((names ~~ simps) ~~ atts) thy; - in thy' end; + val atts = map (prep_attrib thy) srcs; + val ts = map (prep_term thy) strings; + val simps = map (fix_pat thy) ts; + in + (#1 o PureThy.add_thmss [((name, simps), atts)]) thy + end; val add_fixpat = gen_add_fixpat Sign.read_term Attrib.global_attribute; val add_fixpat_i = gen_add_fixpat Sign.cert_term (K I); @@ -283,25 +299,27 @@ local structure P = OuterParse and K = OuterSyntax.Keyword in -val fixrec_decl = P.and_list1 (Scan.repeat1 (P.opt_thm_name ":" -- P.prop)); +val fixrec_eqn = P.opt_thm_name ":" -- P.prop; + +val fixrec_strict = + Scan.optional (P.$$$ "(" -- P.!!! (P.$$$ "permissive" -- P.$$$ ")") >> K false) true; + +val fixrec_decl = fixrec_strict -- P.and_list1 (Scan.repeat1 fixrec_eqn); (* this builds a parser for a new keyword, fixrec, whose functionality is defined by add_fixrec *) val fixrecP = OuterSyntax.command "fixrec" "define recursive functions (HOLCF)" K.thy_decl - (fixrec_decl >> (Toplevel.theory o add_fixrec)); - -(* this adds the parser for fixrec to the syntax *) -val _ = OuterSyntax.add_parsers [fixrecP]; + (fixrec_decl >> (Toplevel.theory o uncurry add_fixrec)); (* fixpat parser *) -val fixpat_decl = Scan.repeat1 (P.opt_thm_name ":" -- P.prop); +val fixpat_decl = P.opt_thm_name ":" -- Scan.repeat1 P.prop; val fixpatP = OuterSyntax.command "fixpat" "define rewrites for fixrec functions" K.thy_decl (fixpat_decl >> (Toplevel.theory o add_fixpat)); -val _ = OuterSyntax.add_parsers [fixpatP]; +val _ = OuterSyntax.add_parsers [fixrecP, fixpatP]; end; (* local structure *)