# HG changeset patch # User haftmann # Date 1196950209 -3600 # Node ID ea6b11021e79d8e31a871415cb1bfda8a6ecd7b6 # Parent 8d3b7c27049bebaee8aa27a7c99c16fc24931f4a added new primrec package diff -r 8d3b7c27049b -r ea6b11021e79 NEWS --- a/NEWS Thu Dec 06 12:58:01 2007 +0100 +++ b/NEWS Thu Dec 06 15:10:09 2007 +0100 @@ -20,6 +20,11 @@ *** HOL *** +* New primrec package. Specification syntax conforms in style to + definition/function/.... The "primrec" command distinguished old-style + and new-style specifications by syntax. The old primrec package is + now named OldPrimrecPackage. + * Library/Multiset: {#a, b, c#} abbreviates {#a#} + {#b#} + {#c#}. * Constants "card", "internal_split", "option_map" now with authentic diff -r 8d3b7c27049b -r ea6b11021e79 src/HOL/Inductive.thy --- a/src/HOL/Inductive.thy Thu Dec 06 12:58:01 2007 +0100 +++ b/src/HOL/Inductive.thy Thu Dec 06 15:10:09 2007 +0100 @@ -17,6 +17,7 @@ ("Tools/datatype_abs_proofs.ML") ("Tools/datatype_case.ML") ("Tools/datatype_package.ML") + ("Tools/old_primrec_package.ML") ("Tools/primrec_package.ML") ("Tools/datatype_codegen.ML") begin @@ -328,6 +329,7 @@ use "Tools/datatype_case.ML" use "Tools/datatype_package.ML" setup DatatypePackage.setup +use "Tools/old_primrec_package.ML" use "Tools/primrec_package.ML" use "Tools/datatype_codegen.ML" diff -r 8d3b7c27049b -r ea6b11021e79 src/HOL/IsaMakefile --- a/src/HOL/IsaMakefile Thu Dec 06 12:58:01 2007 +0100 +++ b/src/HOL/IsaMakefile Thu Dec 06 15:10:09 2007 +0100 @@ -132,6 +132,7 @@ Tools/inductive_package.ML Tools/inductive_realizer.ML \ Tools/inductive_set_package.ML Tools/lin_arith.ML Tools/meson.ML \ Tools/metis_tools.ML Tools/numeral.ML Tools/numeral_syntax.ML \ + Tools/old_primrec_package.ML \ Tools/polyhash.ML Tools/primrec_package.ML Tools/prop_logic.ML \ Tools/recdef_package.ML Tools/recfun_codegen.ML \ Tools/record_package.ML Tools/refute.ML Tools/refute_isar.ML \ diff -r 8d3b7c27049b -r ea6b11021e79 src/HOL/Library/Eval.thy --- a/src/HOL/Library/Eval.thy Thu Dec 06 12:58:01 2007 +0100 +++ b/src/HOL/Library/Eval.thy Thu Dec 06 15:10:09 2007 +0100 @@ -151,7 +151,7 @@ thy |> Instance.instantiate (tycos, sorts, @{sort term_of}) (pair ()) ((K o K) (Class.intro_classes_tac [])) - |> PrimrecPackage.gen_primrec thy_note thy_def "" defs + |> OldPrimrecPackage.gen_primrec thy_note thy_def "" defs |> snd | NONE => thy; in DatatypePackage.interpretation interpretator end diff -r 8d3b7c27049b -r ea6b11021e79 src/HOL/Nominal/nominal_atoms.ML --- a/src/HOL/Nominal/nominal_atoms.ML Thu Dec 06 12:58:01 2007 +0100 +++ b/src/HOL/Nominal/nominal_atoms.ML Thu Dec 06 15:10:09 2007 +0100 @@ -166,7 +166,7 @@ thy |> Sign.add_consts_i [("swap_" ^ ak_name, swapT, NoSyn)] |> PureThy.add_defs_unchecked_i true [((name, def2),[])] |> snd - |> PrimrecPackage.add_primrec_unchecked_i "" [(("", def1),[])] + |> OldPrimrecPackage.add_primrec_unchecked_i "" [(("", def1),[])] end) ak_names_types thy2; (* declares a permutation function for every atom-kind acting *) @@ -194,7 +194,7 @@ Const (swap_name, swapT) $ x $ (Const (qu_prm_name, prmT) $ xs $ a))); in thy |> Sign.add_consts_i [(prm_name, mk_permT T --> T --> T, NoSyn)] - |> PrimrecPackage.add_primrec_unchecked_i "" [(("", def1), []),(("", def2), [])] + |> OldPrimrecPackage.add_primrec_unchecked_i "" [(("", def1), []),(("", def2), [])] end) ak_names_types thy3; (* defines permutation functions for all combinations of atom-kinds; *) diff -r 8d3b7c27049b -r ea6b11021e79 src/HOL/Nominal/nominal_package.ML --- a/src/HOL/Nominal/nominal_package.ML Thu Dec 06 12:58:01 2007 +0100 +++ b/src/HOL/Nominal/nominal_package.ML Thu Dec 06 15:10:09 2007 +0100 @@ -332,7 +332,7 @@ val (perm_simps, thy2) = thy1 |> Sign.add_consts_i (map (fn (s, T) => (Sign.base_name s, T, NoSyn)) (List.drop (perm_names_types, length new_type_names))) |> - PrimrecPackage.add_primrec_unchecked_i "" perm_eqs; + OldPrimrecPackage.add_primrec_unchecked_i "" perm_eqs; (**** prove that permutation functions introduced by unfolding are ****) (**** equivalent to already existing permutation functions ****) diff -r 8d3b7c27049b -r ea6b11021e79 src/HOL/Nominal/nominal_primrec.ML --- a/src/HOL/Nominal/nominal_primrec.ML Thu Dec 06 12:58:01 2007 +0100 +++ b/src/HOL/Nominal/nominal_primrec.ML Thu Dec 06 15:10:09 2007 +0100 @@ -387,7 +387,7 @@ val rec_ts = map (fn eq => head_of (fst (HOLogic.dest_eq (HOLogic.dest_Trueprop (Logic.strip_imp_concl eq)))) handle TERM _ => primrec_eq_err thy "not a proper equation" eq) eqn_ts; - val (_, eqn_ts') = PrimrecPackage.unify_consts thy rec_ts eqn_ts + val (_, eqn_ts') = OldPrimrecPackage.unify_consts thy rec_ts eqn_ts in gen_primrec_i note def alt_name (Option.map (map (Syntax.read_term_global thy)) invs) diff -r 8d3b7c27049b -r ea6b11021e79 src/HOL/Tools/old_primrec_package.ML --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/HOL/Tools/old_primrec_package.ML Thu Dec 06 15:10:09 2007 +0100 @@ -0,0 +1,362 @@ +(* Title: HOL/Tools/primrec_package.ML + ID: $Id$ + Author: Stefan Berghofer, TU Muenchen and Norbert Voelker, FernUni Hagen + +Package for defining functions on datatypes by primitive recursion. +*) + +signature OLD_PRIMREC_PACKAGE = +sig + val quiet_mode: bool ref + val unify_consts: theory -> term list -> term list -> term list * term list + val add_primrec: string -> ((bstring * string) * Attrib.src list) list + -> theory -> thm list * theory + val add_primrec_unchecked: string -> ((bstring * string) * Attrib.src list) list + -> theory -> thm list * theory + val add_primrec_i: string -> ((bstring * term) * attribute list) list + -> theory -> thm list * theory + val add_primrec_unchecked_i: string -> ((bstring * term) * attribute list) list + -> theory -> thm list * theory + (* FIXME !? *) + val gen_primrec: ((bstring * attribute list) * thm list -> theory -> (bstring * thm list) * theory) + -> ((bstring * attribute list) * term -> theory -> (bstring * thm) * theory) + -> string -> ((bstring * attribute list) * term) list + -> theory -> thm list * theory; +end; + +structure OldPrimrecPackage : OLD_PRIMREC_PACKAGE = +struct + +open DatatypeAux; + +exception RecError of string; + +fun primrec_err s = error ("Primrec definition error:\n" ^ s); +fun primrec_eq_err thy s eq = + primrec_err (s ^ "\nin\n" ^ quote (Sign.string_of_term thy eq)); + + +(* messages *) + +val quiet_mode = ref false; +fun message s = if ! quiet_mode then () else writeln s; + + +(*the following code ensures that each recursive set always has the + same type in all introduction rules*) +fun unify_consts thy cs intr_ts = + (let + val add_term_consts_2 = fold_aterms (fn Const c => insert (op =) c | _ => I); + fun varify (t, (i, ts)) = + let val t' = map_types (Logic.incr_tvar (i + 1)) (snd (Type.varify [] t)) + in (maxidx_of_term t', t'::ts) end; + val (i, cs') = foldr varify (~1, []) cs; + val (i', intr_ts') = foldr varify (i, []) intr_ts; + val rec_consts = fold add_term_consts_2 cs' []; + val intr_consts = fold add_term_consts_2 intr_ts' []; + fun unify (cname, cT) = + let val consts = map snd (filter (fn (c, _) => c = cname) intr_consts) + in fold (Sign.typ_unify thy) ((replicate (length consts) cT) ~~ consts) end; + val (env, _) = fold unify rec_consts (Vartab.empty, i'); + val subst = Type.freeze o map_types (Envir.norm_type env) + + in (map subst cs', map subst intr_ts') + end) handle Type.TUNIFY => + (warning "Occurrences of recursive constant have non-unifiable types"; (cs, intr_ts)); + + +(* preprocessing of equations *) + +fun process_eqn thy eq rec_fns = + let + val (lhs, rhs) = + if null (term_vars eq) then + HOLogic.dest_eq (HOLogic.dest_Trueprop eq) + handle TERM _ => raise RecError "not a proper equation" + else raise RecError "illegal schematic variable(s)"; + + val (recfun, args) = strip_comb lhs; + val fnameT = dest_Const recfun handle TERM _ => + raise RecError "function is not declared as constant in theory"; + + val (ls', rest) = take_prefix is_Free args; + val (middle, rs') = take_suffix is_Free rest; + val rpos = length ls'; + + val (constr, cargs') = if null middle then raise RecError "constructor missing" + else strip_comb (hd middle); + val (cname, T) = dest_Const constr + handle TERM _ => raise RecError "ill-formed constructor"; + val (tname, _) = dest_Type (body_type T) handle TYPE _ => + raise RecError "cannot determine datatype associated with function" + + val (ls, cargs, rs) = + (map dest_Free ls', map dest_Free cargs', map dest_Free rs') + handle TERM _ => raise RecError "illegal argument in pattern"; + val lfrees = ls @ rs @ cargs; + + fun check_vars _ [] = () + | check_vars s vars = raise RecError (s ^ commas_quote (map fst vars)) + in + if length middle > 1 then + raise RecError "more than one non-variable in pattern" + else + (check_vars "repeated variable names in pattern: " (duplicates (op =) lfrees); + check_vars "extra variables on rhs: " + (map dest_Free (term_frees rhs) \\ lfrees); + case AList.lookup (op =) rec_fns fnameT of + NONE => + (fnameT, (tname, rpos, [(cname, (ls, cargs, rs, rhs, eq))]))::rec_fns + | SOME (_, rpos', eqns) => + if AList.defined (op =) eqns cname then + raise RecError "constructor already occurred as pattern" + else if rpos <> rpos' then + raise RecError "position of recursive argument inconsistent" + else + AList.update (op =) (fnameT, (tname, rpos, (cname, (ls, cargs, rs, rhs, eq))::eqns)) + rec_fns) + end + handle RecError s => primrec_eq_err thy s eq; + +fun process_fun thy descr rec_eqns (i, fnameT as (fname, _)) (fnameTs, fnss) = + let + val (_, (tname, _, constrs)) = List.nth (descr, i); + + (* substitute "fname ls x rs" by "y ls rs" for (x, (_, y)) in subs *) + + fun subst [] t fs = (t, fs) + | subst subs (Abs (a, T, t)) fs = + fs + |> subst subs t + |-> (fn t' => pair (Abs (a, T, t'))) + | subst subs (t as (_ $ _)) fs = + let + val (f, ts) = strip_comb t; + in + if is_Const f andalso dest_Const f mem map fst rec_eqns then + let + val fnameT' as (fname', _) = dest_Const f; + val (_, rpos, _) = the (AList.lookup (op =) rec_eqns fnameT'); + val ls = Library.take (rpos, ts); + val rest = Library.drop (rpos, ts); + val (x', rs) = (hd rest, tl rest) + handle Empty => raise RecError ("not enough arguments\ + \ in recursive application\nof function " ^ quote fname' ^ " on rhs"); + val (x, xs) = strip_comb x' + in case AList.lookup (op =) subs x + of NONE => + fs + |> fold_map (subst subs) ts + |-> (fn ts' => pair (list_comb (f, ts'))) + | SOME (i', y) => + fs + |> fold_map (subst subs) (xs @ ls @ rs) + ||> process_fun thy descr rec_eqns (i', fnameT') + |-> (fn ts' => pair (list_comb (y, ts'))) + end + else + fs + |> fold_map (subst subs) (f :: ts) + |-> (fn (f'::ts') => pair (list_comb (f', ts'))) + end + | subst _ t fs = (t, fs); + + (* translate rec equations into function arguments suitable for rec comb *) + + fun trans eqns (cname, cargs) (fnameTs', fnss', fns) = + (case AList.lookup (op =) eqns cname of + NONE => (warning ("No equation for constructor " ^ quote cname ^ + "\nin definition of function " ^ quote fname); + (fnameTs', fnss', (Const ("HOL.undefined", dummyT))::fns)) + | SOME (ls, cargs', rs, rhs, eq) => + let + val recs = filter (is_rec_type o snd) (cargs' ~~ cargs); + val rargs = map fst recs; + val subs = map (rpair dummyT o fst) + (rev (rename_wrt_term rhs rargs)); + val (rhs', (fnameTs'', fnss'')) = + (subst (map (fn ((x, y), z) => + (Free x, (body_index y, Free z))) + (recs ~~ subs)) rhs (fnameTs', fnss')) + handle RecError s => primrec_eq_err thy s eq + in (fnameTs'', fnss'', + (list_abs_free (cargs' @ subs @ ls @ rs, rhs'))::fns) + end) + + in (case AList.lookup (op =) fnameTs i of + NONE => + if exists (equal fnameT o snd) fnameTs then + raise RecError ("inconsistent functions for datatype " ^ quote tname) + else + let + val (_, _, eqns) = the (AList.lookup (op =) rec_eqns fnameT); + val (fnameTs', fnss', fns) = fold_rev (trans eqns) constrs + ((i, fnameT)::fnameTs, fnss, []) + in + (fnameTs', (i, (fname, #1 (snd (hd eqns)), fns))::fnss') + end + | SOME fnameT' => + if fnameT = fnameT' then (fnameTs, fnss) + else raise RecError ("inconsistent functions for datatype " ^ quote tname)) + end; + + +(* prepare functions needed for definitions *) + +fun get_fns fns ((i : int, (tname, _, constrs)), rec_name) (fs, defs) = + case AList.lookup (op =) fns i of + NONE => + let + val dummy_fns = map (fn (_, cargs) => Const ("HOL.undefined", + replicate ((length cargs) + (length (List.filter is_rec_type cargs))) + dummyT ---> HOLogic.unitT)) constrs; + val _ = warning ("No function definition for datatype " ^ quote tname) + in + (dummy_fns @ fs, defs) + end + | SOME (fname, ls, fs') => (fs' @ fs, (fname, ls, rec_name, tname) :: defs); + + +(* make definition *) + +fun make_def thy fs (fname, ls, rec_name, tname) = + let + val rhs = fold_rev (fn T => fn t => Abs ("", T, t)) + ((map snd ls) @ [dummyT]) + (list_comb (Const (rec_name, dummyT), + fs @ map Bound (0 ::(length ls downto 1)))) + val def_name = Sign.base_name fname ^ "_" ^ Sign.base_name tname ^ "_def"; + val def_prop = + singleton (Syntax.check_terms (ProofContext.init thy)) + (Logic.mk_equals (Const (fname, dummyT), rhs)); + in (def_name, def_prop) end; + + +(* find datatypes which contain all datatypes in tnames' *) + +fun find_dts (dt_info : datatype_info Symtab.table) _ [] = [] + | find_dts dt_info tnames' (tname::tnames) = + (case Symtab.lookup dt_info tname of + NONE => primrec_err (quote tname ^ " is not a datatype") + | SOME dt => + if tnames' subset (map (#1 o snd) (#descr dt)) then + (tname, dt)::(find_dts dt_info tnames' tnames) + else find_dts dt_info tnames' tnames); + +fun prepare_induct ({descr, induction, ...}: datatype_info) rec_eqns = + let + fun constrs_of (_, (_, _, cs)) = + map (fn (cname:string, (_, cargs, _, _, _)) => (cname, map fst cargs)) cs; + val params_of = these o AList.lookup (op =) (List.concat (map constrs_of rec_eqns)); + in + induction + |> RuleCases.rename_params (map params_of (List.concat (map (map #1 o #3 o #2) descr))) + |> RuleCases.save induction + end; + +local + +fun gen_primrec_i note def alt_name eqns_atts thy = + let + val (eqns, atts) = split_list eqns_atts; + val dt_info = DatatypePackage.get_datatypes thy; + val rec_eqns = fold_rev (process_eqn thy o snd) eqns [] ; + val tnames = distinct (op =) (map (#1 o snd) rec_eqns); + val dts = find_dts dt_info tnames tnames; + val main_fns = + map (fn (tname, {index, ...}) => + (index, + (fst o the o find_first (fn f => (#1 o snd) f = tname)) rec_eqns)) + dts; + val {descr, rec_names, rec_rewrites, ...} = + if null dts then + primrec_err ("datatypes " ^ commas_quote tnames ^ "\nare not mutually recursive") + else snd (hd dts); + val (fnameTs, fnss) = + fold_rev (process_fun thy descr rec_eqns) main_fns ([], []); + val (fs, defs) = fold_rev (get_fns fnss) (descr ~~ rec_names) ([], []); + val defs' = map (make_def thy fs) defs; + val nameTs1 = map snd fnameTs; + val nameTs2 = map fst rec_eqns; + val _ = if gen_eq_set (op =) (nameTs1, nameTs2) then () + else primrec_err ("functions " ^ commas_quote (map fst nameTs2) ^ + "\nare not mutually recursive"); + val primrec_name = + if alt_name = "" then (space_implode "_" (map (Sign.base_name o #1) defs)) else alt_name; + val (defs_thms', thy') = + thy + |> Sign.add_path primrec_name + |> fold_map def (map (fn (name, t) => ((name, []), t)) defs'); + val rewrites = (map mk_meta_eq rec_rewrites) @ map snd defs_thms'; + val _ = message ("Proving equations for primrec function(s) " ^ + commas_quote (map fst nameTs1) ^ " ..."); + val simps = map (fn (_, t) => Goal.prove_global thy' [] [] t + (fn _ => EVERY [rewrite_goals_tac rewrites, rtac refl 1])) eqns; + val (simps', thy'') = + thy' + |> fold_map note ((map fst eqns ~~ atts) ~~ map single simps); + val simps'' = maps snd simps'; + in + thy'' + |> note (("simps", [Simplifier.simp_add, RecfunCodegen.add_default]), simps'') + |> snd + |> note (("induct", []), [prepare_induct (#2 (hd dts)) rec_eqns]) + |> snd + |> Sign.parent_path + |> pair simps'' + end; + +fun gen_primrec note def alt_name eqns thy = + let + val ((names, strings), srcss) = apfst split_list (split_list eqns); + val atts = map (map (Attrib.attribute thy)) srcss; + val eqn_ts = map (fn s => Syntax.read_prop_global thy s + handle ERROR msg => cat_error msg ("The error(s) above occurred for " ^ s)) strings; + val rec_ts = map (fn eq => head_of (fst (HOLogic.dest_eq (HOLogic.dest_Trueprop eq))) + handle TERM _ => primrec_eq_err thy "not a proper equation" eq) eqn_ts; + val (_, eqn_ts') = unify_consts thy rec_ts eqn_ts + in + gen_primrec_i note def alt_name (names ~~ eqn_ts' ~~ atts) thy + end; + +fun thy_note ((name, atts), thms) = + PureThy.add_thmss [((name, thms), atts)] #-> (fn [thms] => pair (name, thms)); +fun thy_def false ((name, atts), t) = + PureThy.add_defs_i false [((name, t), atts)] #-> (fn [thm] => pair (name, thm)) + | thy_def true ((name, atts), t) = + PureThy.add_defs_unchecked_i false [((name, t), atts)] #-> (fn [thm] => pair (name, thm)); + +in + +val add_primrec = gen_primrec thy_note (thy_def false); +val add_primrec_unchecked = gen_primrec thy_note (thy_def true); +val add_primrec_i = gen_primrec_i thy_note (thy_def false); +val add_primrec_unchecked_i = gen_primrec_i thy_note (thy_def true); +fun gen_primrec note def alt_name specs = + gen_primrec_i note def alt_name (map (fn ((name, t), atts) => ((name, atts), t)) specs); + +end; + + +(* see primrecr_package.ML (* outer syntax *) + +local structure P = OuterParse and K = OuterKeyword in + +val opt_unchecked_name = + Scan.optional (P.$$$ "(" |-- P.!!! + (((P.$$$ "unchecked" >> K true) -- Scan.optional P.name "" || + P.name >> pair false) --| P.$$$ ")")) (false, ""); + +val primrec_decl = + opt_unchecked_name -- Scan.repeat1 (SpecParse.opt_thm_name ":" -- P.prop); + +val _ = + OuterSyntax.command "primrec" "define primitive recursive functions on datatypes" K.thy_decl + (primrec_decl >> (fn ((unchecked, alt_name), eqns) => + Toplevel.theory (snd o + (if unchecked then add_primrec_unchecked else add_primrec) alt_name + (map P.triple_swap eqns)))); + +end;*) + +end; diff -r 8d3b7c27049b -r ea6b11021e79 src/HOL/Tools/primrec_package.ML --- a/src/HOL/Tools/primrec_package.ML Thu Dec 06 12:58:01 2007 +0100 +++ b/src/HOL/Tools/primrec_package.ML Thu Dec 06 15:10:09 2007 +0100 @@ -1,27 +1,15 @@ (* Title: HOL/Tools/primrec_package.ML ID: $Id$ - Author: Stefan Berghofer, TU Muenchen and Norbert Voelker, FernUni Hagen + Author: Stefan Berghofer, TU Muenchen; Norbert Voelker, FernUni Hagen; + Florian Haftmann, TU Muenchen Package for defining functions on datatypes by primitive recursion. *) signature PRIMREC_PACKAGE = sig - val quiet_mode: bool ref - val unify_consts: theory -> term list -> term list -> term list * term list - val add_primrec: string -> ((bstring * string) * Attrib.src list) list - -> theory -> thm list * theory - val add_primrec_unchecked: string -> ((bstring * string) * Attrib.src list) list - -> theory -> thm list * theory - val add_primrec_i: string -> ((bstring * term) * attribute list) list - -> theory -> thm list * theory - val add_primrec_unchecked_i: string -> ((bstring * term) * attribute list) list - -> theory -> thm list * theory - (* FIXME !? *) - val gen_primrec: ((bstring * attribute list) * thm list -> theory -> (bstring * thm list) * theory) - -> ((bstring * attribute list) * term -> theory -> (bstring * thm) * theory) - -> string -> ((bstring * attribute list) * term) list - -> theory -> thm list * theory; + val add_primrec: (string * typ option * mixfix) list -> + ((bstring * Attrib.src list) * term) list -> local_theory -> thm list * local_theory end; structure PrimrecPackage : PRIMREC_PACKAGE = @@ -29,98 +17,71 @@ open DatatypeAux; -exception RecError of string; - -fun primrec_err s = error ("Primrec definition error:\n" ^ s); -fun primrec_eq_err thy s eq = - primrec_err (s ^ "\nin\n" ^ quote (Sign.string_of_term thy eq)); - - -(* messages *) - -val quiet_mode = ref false; -fun message s = if ! quiet_mode then () else writeln s; - +exception PrimrecError of string * term option; -(*the following code ensures that each recursive set always has the - same type in all introduction rules*) -fun unify_consts thy cs intr_ts = - (let - val add_term_consts_2 = fold_aterms (fn Const c => insert (op =) c | _ => I); - fun varify (t, (i, ts)) = - let val t' = map_types (Logic.incr_tvar (i + 1)) (snd (Type.varify [] t)) - in (maxidx_of_term t', t'::ts) end; - val (i, cs') = foldr varify (~1, []) cs; - val (i', intr_ts') = foldr varify (i, []) intr_ts; - val rec_consts = fold add_term_consts_2 cs' []; - val intr_consts = fold add_term_consts_2 intr_ts' []; - fun unify (cname, cT) = - let val consts = map snd (filter (fn (c, _) => c = cname) intr_consts) - in fold (Sign.typ_unify thy) ((replicate (length consts) cT) ~~ consts) end; - val (env, _) = fold unify rec_consts (Vartab.empty, i'); - val subst = Type.freeze o map_types (Envir.norm_type env) +fun primrec_error msg = raise PrimrecError (msg, NONE); +fun primrec_error_eqn msg eqn = raise PrimrecError (msg, SOME eqn); - in (map subst cs', map subst intr_ts') - end) handle Type.TUNIFY => - (warning "Occurrences of recursive constant have non-unifiable types"; (cs, intr_ts)); +fun message s = if ! Toplevel.debug then () else writeln s; (* preprocessing of equations *) -fun process_eqn thy eq rec_fns = +fun process_eqn is_fixed is_const spec rec_fns = let - val (lhs, rhs) = - if null (term_vars eq) then - HOLogic.dest_eq (HOLogic.dest_Trueprop eq) - handle TERM _ => raise RecError "not a proper equation" - else raise RecError "illegal schematic variable(s)"; - + val vars = strip_qnt_vars "all" spec; + val body = strip_qnt_body "all" spec; + val eqn = curry subst_bounds (map Free (rev vars)) body; + val (lhs, rhs) = HOLogic.dest_eq (HOLogic.dest_Trueprop eqn) + handle TERM _ => primrec_error "not a proper equation"; val (recfun, args) = strip_comb lhs; - val fnameT = dest_Const recfun handle TERM _ => - raise RecError "function is not declared as constant in theory"; + val fname = case recfun of Free (v, _) => if is_fixed v then v + else primrec_error "illegal head of function equation" + | _ => primrec_error "illegal head of function equation"; val (ls', rest) = take_prefix is_Free args; val (middle, rs') = take_suffix is_Free rest; val rpos = length ls'; - val (constr, cargs') = if null middle then raise RecError "constructor missing" + val (constr, cargs') = if null middle then primrec_error "constructor missing" else strip_comb (hd middle); val (cname, T) = dest_Const constr - handle TERM _ => raise RecError "ill-formed constructor"; + handle TERM _ => primrec_error "ill-formed constructor"; val (tname, _) = dest_Type (body_type T) handle TYPE _ => - raise RecError "cannot determine datatype associated with function" + primrec_error "cannot determine datatype associated with function" val (ls, cargs, rs) = (map dest_Free ls', map dest_Free cargs', map dest_Free rs') - handle TERM _ => raise RecError "illegal argument in pattern"; + handle TERM _ => primrec_error "illegal argument in pattern"; val lfrees = ls @ rs @ cargs; fun check_vars _ [] = () - | check_vars s vars = raise RecError (s ^ commas_quote (map fst vars)) + | check_vars s vars = primrec_error (s ^ commas_quote (map fst vars)) eqn; in if length middle > 1 then - raise RecError "more than one non-variable in pattern" + primrec_error "more than one non-variable in pattern" else (check_vars "repeated variable names in pattern: " (duplicates (op =) lfrees); check_vars "extra variables on rhs: " - (map dest_Free (term_frees rhs) \\ lfrees); - case AList.lookup (op =) rec_fns fnameT of + (map dest_Free (term_frees rhs) |> subtract (op =) lfrees + |> filter_out (is_const o fst) |> filter_out (is_fixed o fst)); + case AList.lookup (op =) rec_fns fname of NONE => - (fnameT, (tname, rpos, [(cname, (ls, cargs, rs, rhs, eq))]))::rec_fns + (fname, (tname, rpos, [(cname, (ls, cargs, rs, rhs, eqn))]))::rec_fns | SOME (_, rpos', eqns) => if AList.defined (op =) eqns cname then - raise RecError "constructor already occurred as pattern" + primrec_error "constructor already occurred as pattern" else if rpos <> rpos' then - raise RecError "position of recursive argument inconsistent" + primrec_error "position of recursive argument inconsistent" else - AList.update (op =) (fnameT, (tname, rpos, (cname, (ls, cargs, rs, rhs, eq))::eqns)) + AList.update (op =) + (fname, (tname, rpos, (cname, (ls, cargs, rs, rhs, eqn))::eqns)) rec_fns) - end - handle RecError s => primrec_eq_err thy s eq; + end handle PrimrecError (msg, NONE) => primrec_error_eqn msg spec; -fun process_fun thy descr rec_eqns (i, fnameT as (fname, _)) (fnameTs, fnss) = +fun process_fun descr eqns (i, fname) (fnames, fnss) = let - val (_, (tname, _, constrs)) = List.nth (descr, i); + val (_, (tname, _, constrs)) = nth descr i; (* substitute "fname ls x rs" by "y ls rs" for (x, (_, y)) in subs *) @@ -133,14 +94,13 @@ let val (f, ts) = strip_comb t; in - if is_Const f andalso dest_Const f mem map fst rec_eqns then + if is_Free f + andalso member (fn ((v, _), (w, _)) => v = w) eqns (dest_Free f) then let - val fnameT' as (fname', _) = dest_Const f; - val (_, rpos, _) = the (AList.lookup (op =) rec_eqns fnameT'); - val ls = Library.take (rpos, ts); - val rest = Library.drop (rpos, ts); - val (x', rs) = (hd rest, tl rest) - handle Empty => raise RecError ("not enough arguments\ + val (fname', _) = dest_Free f; + val (_, rpos, _) = the (AList.lookup (op =) eqns fname'); + val (ls, x' :: rs) = chop rpos ts + handle Empty => primrec_error ("not enough arguments\ \ in recursive application\nof function " ^ quote fname' ^ " on rhs"); val (x, xs) = strip_comb x' in case AList.lookup (op =) subs x @@ -151,7 +111,7 @@ | SOME (i', y) => fs |> fold_map (subst subs) (xs @ ls @ rs) - ||> process_fun thy descr rec_eqns (i', fnameT') + ||> process_fun descr eqns (i', fname') |-> (fn ts' => pair (list_comb (y, ts'))) end else @@ -163,41 +123,39 @@ (* translate rec equations into function arguments suitable for rec comb *) - fun trans eqns (cname, cargs) (fnameTs', fnss', fns) = + fun trans eqns (cname, cargs) (fnames', fnss', fns) = (case AList.lookup (op =) eqns cname of NONE => (warning ("No equation for constructor " ^ quote cname ^ "\nin definition of function " ^ quote fname); - (fnameTs', fnss', (Const ("HOL.undefined", dummyT))::fns)) + (fnames', fnss', (Const ("HOL.undefined", dummyT))::fns)) | SOME (ls, cargs', rs, rhs, eq) => let val recs = filter (is_rec_type o snd) (cargs' ~~ cargs); val rargs = map fst recs; val subs = map (rpair dummyT o fst) (rev (rename_wrt_term rhs rargs)); - val (rhs', (fnameTs'', fnss'')) = - (subst (map (fn ((x, y), z) => - (Free x, (body_index y, Free z))) - (recs ~~ subs)) rhs (fnameTs', fnss')) - handle RecError s => primrec_eq_err thy s eq - in (fnameTs'', fnss'', + val (rhs', (fnames'', fnss'')) = (subst (map2 (fn (x, y) => fn z => + (Free x, (body_index y, Free z))) recs subs) rhs (fnames', fnss')) + handle PrimrecError (s, NONE) => primrec_error_eqn s eq + in (fnames'', fnss'', (list_abs_free (cargs' @ subs @ ls @ rs, rhs'))::fns) end) - in (case AList.lookup (op =) fnameTs i of + in (case AList.lookup (op =) fnames i of NONE => - if exists (equal fnameT o snd) fnameTs then - raise RecError ("inconsistent functions for datatype " ^ quote tname) + if exists (fn (_, v) => fname = v) fnames then + primrec_error ("inconsistent functions for datatype " ^ quote tname) else let - val (_, _, eqns) = the (AList.lookup (op =) rec_eqns fnameT); - val (fnameTs', fnss', fns) = fold_rev (trans eqns) constrs - ((i, fnameT)::fnameTs, fnss, []) + val (_, _, eqns) = the (AList.lookup (op =) eqns fname); + val (fnames', fnss', fns) = fold_rev (trans eqns) constrs + ((i, fname)::fnames, fnss, []) in - (fnameTs', (i, (fname, #1 (snd (hd eqns)), fns))::fnss') + (fnames', (i, (fname, #1 (snd (hd eqns)), fns))::fnss') end - | SOME fnameT' => - if fnameT = fnameT' then (fnameTs, fnss) - else raise RecError ("inconsistent functions for datatype " ^ quote tname)) + | SOME fname' => + if fname = fname' then (fnames, fnss) + else primrec_error ("inconsistent functions for datatype " ^ quote tname)) end; @@ -219,17 +177,17 @@ (* make definition *) -fun make_def thy fs (fname, ls, rec_name, tname) = +fun make_def ctxt fixes fs (fname, ls, rec_name, tname) = let - val rhs = fold_rev (fn T => fn t => Abs ("", T, t)) + val raw_rhs = fold_rev (fn T => fn t => Abs ("", T, t)) ((map snd ls) @ [dummyT]) (list_comb (Const (rec_name, dummyT), fs @ map Bound (0 ::(length ls downto 1)))) - val def_name = Sign.base_name fname ^ "_" ^ Sign.base_name tname ^ "_def"; - val def_prop = - singleton (Syntax.check_terms (ProofContext.init thy)) - (Logic.mk_equals (Const (fname, dummyT), rhs)); - in (def_name, def_prop) end; + val def_name = Thm.def_name (Sign.base_name fname); + val rhs = singleton (Syntax.check_terms ctxt) raw_rhs; + val SOME mfx = get_first + (fn ((v, _), mfx) => if v = fname then SOME mfx else NONE) fixes; + in ((fname, mfx), ((def_name, []), rhs)) end; (* find datatypes which contain all datatypes in tnames' *) @@ -237,103 +195,87 @@ fun find_dts (dt_info : datatype_info Symtab.table) _ [] = [] | find_dts dt_info tnames' (tname::tnames) = (case Symtab.lookup dt_info tname of - NONE => primrec_err (quote tname ^ " is not a datatype") + NONE => primrec_error (quote tname ^ " is not a datatype") | SOME dt => if tnames' subset (map (#1 o snd) (#descr dt)) then (tname, dt)::(find_dts dt_info tnames' tnames) else find_dts dt_info tnames' tnames); -fun prepare_induct ({descr, induction, ...}: datatype_info) rec_eqns = + +(* adapted induction rule *) + +fun prepare_induct ({descr, induction, ...}: datatype_info) eqns = let fun constrs_of (_, (_, _, cs)) = map (fn (cname:string, (_, cargs, _, _, _)) => (cname, map fst cargs)) cs; - val params_of = these o AList.lookup (op =) (List.concat (map constrs_of rec_eqns)); + val params_of = these o AList.lookup (op =) (List.concat (map constrs_of eqns)); in induction - |> RuleCases.rename_params (map params_of (List.concat (map (map #1 o #3 o #2) descr))) + |> RuleCases.rename_params (map params_of (maps (map #1 o #3 o #2) descr)) |> RuleCases.save induction end; + +(* primrec definition *) + local -fun gen_primrec_i note def alt_name eqns_atts thy = +fun prepare_spec prep_spec ctxt raw_fixes raw_spec = + let + val ((fixes, spec), _) = prep_spec raw_fixes [(map o apsnd) single raw_spec] ctxt + in (fixes, (map o apsnd) the_single spec) end; + +fun prove_spec ctxt rec_rewrites defs = + let + val rewrites = map mk_meta_eq rec_rewrites @ map (snd o snd) defs; + fun tac _ = EVERY [rewrite_goals_tac rewrites, rtac refl 1]; + val _ = message "Proving equations for primrec function"; + in map (fn (name_attr, t) => (name_attr, [Goal.prove ctxt [] [] t tac])) end; + +fun gen_primrec prep_spec raw_fixes raw_spec lthy = let - val (eqns, atts) = split_list eqns_atts; - val dt_info = DatatypePackage.get_datatypes thy; - val rec_eqns = fold_rev (process_eqn thy o snd) eqns [] ; - val tnames = distinct (op =) (map (#1 o snd) rec_eqns); - val dts = find_dts dt_info tnames tnames; - val main_fns = - map (fn (tname, {index, ...}) => - (index, - (fst o the o find_first (fn f => (#1 o snd) f = tname)) rec_eqns)) - dts; + val (fixes, spec) = prepare_spec prep_spec lthy raw_fixes raw_spec; + val eqns = fold_rev (process_eqn (member (op =) (map (fst o fst) fixes)) + (Variable.is_const lthy) o snd) spec []; + val tnames = distinct (op =) (map (#1 o snd) eqns); + val dts = find_dts (DatatypePackage.get_datatypes + (ProofContext.theory_of lthy)) tnames tnames; + val main_fns = map (fn (tname, {index, ...}) => + (index, (fst o the o find_first (fn (_, x) => #1 x = tname)) eqns)) dts; val {descr, rec_names, rec_rewrites, ...} = - if null dts then - primrec_err ("datatypes " ^ commas_quote tnames ^ "\nare not mutually recursive") + if null dts then primrec_error + ("datatypes " ^ commas_quote tnames ^ "\nare not mutually recursive") else snd (hd dts); - val (fnameTs, fnss) = - fold_rev (process_fun thy descr rec_eqns) main_fns ([], []); + val (fnames, fnss) = fold_rev (process_fun descr eqns) main_fns ([], []); val (fs, defs) = fold_rev (get_fns fnss) (descr ~~ rec_names) ([], []); - val defs' = map (make_def thy fs) defs; - val nameTs1 = map snd fnameTs; - val nameTs2 = map fst rec_eqns; + val nameTs1 = map snd fnames; + val nameTs2 = map fst eqns; val _ = if gen_eq_set (op =) (nameTs1, nameTs2) then () - else primrec_err ("functions " ^ commas_quote (map fst nameTs2) ^ - "\nare not mutually recursive"); - val primrec_name = - if alt_name = "" then (space_implode "_" (map (Sign.base_name o #1) defs)) else alt_name; - val (defs_thms', thy') = - thy - |> Sign.add_path primrec_name - |> fold_map def (map (fn (name, t) => ((name, []), t)) defs'); - val rewrites = (map mk_meta_eq rec_rewrites) @ map snd defs_thms'; - val _ = message ("Proving equations for primrec function(s) " ^ - commas_quote (map fst nameTs1) ^ " ..."); - val simps = map (fn (_, t) => Goal.prove_global thy' [] [] t - (fn _ => EVERY [rewrite_goals_tac rewrites, rtac refl 1])) eqns; - val (simps', thy'') = - thy' - |> fold_map note ((map fst eqns ~~ atts) ~~ map single simps); - val simps'' = maps snd simps'; + else primrec_error ("functions " ^ commas_quote nameTs2 ^ + "\nare not mutually recursive"); + val qualify = NameSpace.qualified + (space_implode "_" (map (Sign.base_name o #1) defs)); + val simp_atts = [Attrib.internal (K Simplifier.simp_add), + Code.add_default_func_attr (*FIXME*)]; in - thy'' - |> note (("simps", [Simplifier.simp_add, RecfunCodegen.add_default]), simps'') - |> snd - |> note (("induct", []), [prepare_induct (#2 (hd dts)) rec_eqns]) - |> snd - |> Sign.parent_path - |> pair simps'' - end; - -fun gen_primrec note def alt_name eqns thy = - let - val ((names, strings), srcss) = apfst split_list (split_list eqns); - val atts = map (map (Attrib.attribute thy)) srcss; - val eqn_ts = map (fn s => Syntax.read_prop_global thy s - handle ERROR msg => cat_error msg ("The error(s) above occurred for " ^ s)) strings; - val rec_ts = map (fn eq => head_of (fst (HOLogic.dest_eq (HOLogic.dest_Trueprop eq))) - handle TERM _ => primrec_eq_err thy "not a proper equation" eq) eqn_ts; - val (_, eqn_ts') = unify_consts thy rec_ts eqn_ts - in - gen_primrec_i note def alt_name (names ~~ eqn_ts' ~~ atts) thy - end; - -fun thy_note ((name, atts), thms) = - PureThy.add_thmss [((name, thms), atts)] #-> (fn [thms] => pair (name, thms)); -fun thy_def false ((name, atts), t) = - PureThy.add_defs_i false [((name, t), atts)] #-> (fn [thm] => pair (name, thm)) - | thy_def true ((name, atts), t) = - PureThy.add_defs_unchecked_i false [((name, t), atts)] #-> (fn [thm] => pair (name, thm)); + lthy + |> fold_map (LocalTheory.define Thm.definitionK o make_def lthy fixes fs) defs + |-> (fn defs => `(fn ctxt => prove_spec ctxt rec_rewrites defs spec)) + |-> (fn simps => fold_map (LocalTheory.note Thm.theoremK) simps) + |-> (fn simps' => LocalTheory.note Thm.theoremK + ((qualify "simps", simp_atts), maps snd simps')) + ||>> LocalTheory.note Thm.theoremK + ((qualify "induct", []), [prepare_induct (#2 (hd dts)) eqns]) + |>> (snd o fst) + end handle PrimrecError (msg, some_eqn) => + error ("Primrec definition error:\n" ^ msg ^ (case some_eqn + of SOME eqn => "\nin\n" ^ quote (Syntax.string_of_term lthy eqn) + | NONE => "")); in -val add_primrec = gen_primrec thy_note (thy_def false); -val add_primrec_unchecked = gen_primrec thy_note (thy_def true); -val add_primrec_i = gen_primrec_i thy_note (thy_def false); -val add_primrec_unchecked_i = gen_primrec_i thy_note (thy_def true); -fun gen_primrec note def alt_name specs = - gen_primrec_i note def alt_name (map (fn ((name, t), atts) => ((name, atts), t)) specs); +val add_primrec = gen_primrec Specification.check_specification; +val add_primrec_cmd = gen_primrec Specification.read_specification; end; @@ -347,15 +289,27 @@ (((P.$$$ "unchecked" >> K true) -- Scan.optional P.name "" || P.name >> pair false) --| P.$$$ ")")) (false, ""); -val primrec_decl = +val old_primrec_decl = opt_unchecked_name -- Scan.repeat1 (SpecParse.opt_thm_name ":" -- P.prop); +fun pipe_error t = P.!!! (Scan.fail_with (K + (cat_lines ["Equations must be separated by " ^ quote "|", quote t]))); + +val statement = SpecParse.opt_thm_name ":" -- P.prop --| Scan.ahead + ((P.term :-- pipe_error) || Scan.succeed ("","")); + +val statements = P.enum1 "|" statement; + +val primrec_decl = P.opt_target -- P.fixes --| P.$$$ "where" -- statements; + val _ = OuterSyntax.command "primrec" "define primitive recursive functions on datatypes" K.thy_decl - (primrec_decl >> (fn ((unchecked, alt_name), eqns) => + ((primrec_decl >> (fn ((opt_target, raw_fixes), raw_spec) => + Toplevel.local_theory opt_target (add_primrec_cmd raw_fixes raw_spec #> snd))) + || (old_primrec_decl >> (fn ((unchecked, alt_name), eqns) => Toplevel.theory (snd o - (if unchecked then add_primrec_unchecked else add_primrec) alt_name - (map P.triple_swap eqns)))); + (if unchecked then OldPrimrecPackage.add_primrec_unchecked else OldPrimrecPackage.add_primrec) alt_name + (map P.triple_swap eqns))))); end; diff -r 8d3b7c27049b -r ea6b11021e79 src/HOLCF/Tools/fixrec_package.ML --- a/src/HOLCF/Tools/fixrec_package.ML Thu Dec 06 12:58:01 2007 +0100 +++ b/src/HOLCF/Tools/fixrec_package.ML Thu Dec 06 15:10:09 2007 +0100 @@ -231,7 +231,7 @@ val eqn_ts = map (prep_prop thy) strings; val rec_ts = map (fn eq => chead_of (fst (dest_eqs (Logic.strip_imp_concl eq))) handle TERM _ => fixrec_eq_err thy "not a proper equation" eq) eqn_ts; - val (_, eqn_ts') = PrimrecPackage.unify_consts thy rec_ts eqn_ts; + val (_, eqn_ts') = OldPrimrecPackage.unify_consts thy rec_ts eqn_ts; fun unconcat [] _ = [] | unconcat (n::ns) xs = List.take (xs,n) :: unconcat ns (List.drop (xs,n));