--- a/src/HOL/Tools/primrec_package.ML Fri Jul 21 14:45:43 2006 +0200
+++ b/src/HOL/Tools/primrec_package.ML Fri Jul 21 14:46:27 2006 +0200
@@ -8,14 +8,15 @@
signature PRIMREC_PACKAGE =
sig
val quiet_mode: bool ref
+ val mk_combdefs: theory -> term list -> (string * term) list
val add_primrec: string -> ((bstring * string) * Attrib.src list) list
- -> theory -> theory * thm list
+ -> theory -> thm list * theory
val add_primrec_unchecked: string -> ((bstring * string) * Attrib.src list) list
- -> theory -> theory * thm list
+ -> theory -> thm list * theory
val add_primrec_i: string -> ((bstring * term) * attribute list) list
- -> theory -> theory * thm list
+ -> theory -> thm list * theory
val add_primrec_unchecked_i: string -> ((bstring * term) * attribute list) list
- -> theory -> theory * thm list
+ -> theory -> thm list * theory
end;
structure PrimrecPackage : PRIMREC_PACKAGE =
@@ -26,8 +27,8 @@
exception RecError of string;
fun primrec_err s = error ("Primrec definition error:\n" ^ s);
-fun primrec_eq_err sign s eq =
- primrec_err (s ^ "\nin\n" ^ quote (Sign.string_of_term sign eq));
+fun primrec_eq_err thy s eq =
+ primrec_err (s ^ "\nin\n" ^ quote (Sign.string_of_term thy eq));
(* messages *)
@@ -38,13 +39,13 @@
(* preprocessing of equations *)
-fun process_eqn sign (eq, rec_fns) =
+fun process_eqn thy (eq, rec_fns) =
let
val (lhs, rhs) =
- if null (term_vars eq) then
- HOLogic.dest_eq (HOLogic.dest_Trueprop eq)
- handle TERM _ => raise RecError "not a proper equation"
- else raise RecError "illegal schematic variable(s)";
+ if null (term_vars eq) then
+ HOLogic.dest_eq (HOLogic.dest_Trueprop eq)
+ handle TERM _ => raise RecError "not a proper equation"
+ else raise RecError "illegal schematic variable(s)";
val (recfun, args) = strip_comb lhs;
val fnameT = dest_Const recfun handle TERM _ =>
@@ -61,9 +62,8 @@
val (tname, _) = dest_Type (body_type T) handle TYPE _ =>
raise RecError "cannot determine datatype associated with function"
- val (ls, cargs, rs) = (map dest_Free ls',
- map dest_Free cargs',
- map dest_Free rs')
+ val (ls, cargs, rs) =
+ (map dest_Free ls', map dest_Free cargs', map dest_Free rs')
handle TERM _ => raise RecError "illegal argument in pattern";
val lfrees = ls @ rs @ cargs;
@@ -88,9 +88,9 @@
AList.update (op =) (fnameT, (tname, rpos, (cname, (ls, cargs, rs, rhs, eq))::eqns))
rec_fns)
end
- handle RecError s => primrec_eq_err sign s eq;
+ handle RecError s => primrec_eq_err thy s eq;
-fun process_fun sign descr rec_eqns ((i, fnameT as (fname, _)), (fnameTs, fnss)) =
+fun process_fun thy descr rec_eqns ((i, fnameT as (fname, _)), (fnameTs, fnss)) =
let
val (_, (tname, _, constrs)) = List.nth (descr, i);
@@ -122,7 +122,7 @@
| SOME (i', y) =>
let
val (fs', ts') = foldl_map (subst subs) (fs, xs @ ls @ rs);
- val fs'' = process_fun sign descr rec_eqns ((i', fnameT'), fs')
+ val fs'' = process_fun thy descr rec_eqns ((i', fnameT'), fs')
in (fs'', list_comb (y, ts'))
end)
end
@@ -145,15 +145,15 @@
val recs = List.filter (is_rec_type o snd) (cargs' ~~ cargs);
val rargs = map fst recs;
val subs = map (rpair dummyT o fst)
- (rev (rename_wrt_term rhs rargs));
+ (rev (rename_wrt_term rhs rargs));
val ((fnameTs'', fnss''), rhs') =
- (subst (map (fn ((x, y), z) =>
- (Free x, (body_index y, Free z)))
- (recs ~~ subs))
- ((fnameTs', fnss'), rhs))
- handle RecError s => primrec_eq_err sign s eq
+ (subst (map (fn ((x, y), z) =>
+ (Free x, (body_index y, Free z)))
+ (recs ~~ subs))
+ ((fnameTs', fnss'), rhs))
+ handle RecError s => primrec_eq_err thy s eq
in (fnameTs'', fnss'',
- (list_abs_free (cargs' @ subs @ ls @ rs, rhs'))::fns)
+ (list_abs_free (cargs' @ subs @ ls @ rs, rhs'))::fns)
end)
in (case AList.lookup (op =) fnameTs i of
@@ -176,7 +176,7 @@
(* prepare functions needed for definitions *)
-fun get_fns fns (((i, (tname, _, constrs)), rec_name), (fs, defs)) =
+fun get_fns fns (((i : int, (tname, _, constrs)), rec_name), (fs, defs)) =
case AList.lookup (op =) fns i of
NONE =>
let
@@ -192,15 +192,15 @@
(* make definition *)
-fun make_def sign fs (fname, ls, rec_name, tname) =
+fun make_def thy fs (fname, ls, rec_name, tname) =
let
val rhs = foldr (fn (T, t) => Abs ("", T, t))
- (list_comb (Const (rec_name, dummyT),
- fs @ map Bound (0 ::(length ls downto 1))))
- ((map snd ls) @ [dummyT]);
+ (list_comb (Const (rec_name, dummyT),
+ fs @ map Bound (0 ::(length ls downto 1))))
+ ((map snd ls) @ [dummyT]);
val defpair = (Sign.base_name fname ^ "_" ^ Sign.base_name tname ^ "_def",
- Logic.mk_equals (Const (fname, dummyT), rhs))
- in Theory.inferT_axm sign defpair end;
+ Logic.mk_equals (Const (fname, dummyT), rhs))
+ in Theory.inferT_axm thy defpair end;
(* find datatypes which contain all datatypes in tnames' *)
@@ -225,12 +225,10 @@
|> RuleCases.save induction
end;
-fun gen_primrec_i unchecked alt_name eqns_atts thy =
+fun mk_defs thy eqns =
let
- val (eqns, atts) = split_list eqns_atts;
- val sg = Theory.sign_of thy;
val dt_info = DatatypePackage.get_datatypes thy;
- val rec_eqns = foldr (process_eqn sg) [] (map snd eqns);
+ val rec_eqns = foldr (process_eqn thy) [] eqns;
val tnames = distinct (op =) (map (#1 o snd) rec_eqns);
val dts = find_dts dt_info tnames tnames;
val main_fns =
@@ -242,10 +240,20 @@
if null dts then
primrec_err ("datatypes " ^ commas_quote tnames ^ "\nare not mutually recursive")
else snd (hd dts);
- val (fnameTs, fnss) = foldr (process_fun sg descr rec_eqns)
- ([], []) main_fns;
+ val (fnameTs, fnss) =
+ foldr (process_fun thy descr rec_eqns) ([], []) main_fns;
val (fs, defs) = foldr (get_fns fnss) ([], []) (descr ~~ rec_names);
- val defs' = map (make_def sg fs) defs;
+ val defs' = map (make_def thy fs) defs;
+ in (fnameTs, rec_eqns, rec_rewrites, dts, defs, defs') end;
+
+fun mk_combdefs thy =
+ #6 o mk_defs thy o map (ObjectLogic.ensure_propT thy);
+
+fun gen_primrec_i unchecked alt_name eqns_atts thy =
+ let
+ val (eqns, atts) = split_list eqns_atts;
+ val (fnameTs, rec_eqns, rec_rewrites, dts, defs, defs') =
+ mk_defs thy (map snd eqns);
val nameTs1 = map snd fnameTs;
val nameTs2 = map fst rec_eqns;
val primrec_name =
@@ -261,26 +269,25 @@
commas_quote (map fst nameTs1) ^ " ...");
val simps = map (fn (_, t) => Goal.prove_global thy' [] [] t
(fn _ => EVERY [rewrite_goals_tac rewrites, rtac refl 1])) eqns;
- val (simps', thy'') = PureThy.add_thms ((map fst eqns ~~ simps) ~~ atts) thy';
- val thy''' = thy''
- |> (snd o PureThy.add_thmss [(("simps", simps'),
- [Simplifier.simp_add, RecfunCodegen.add NONE])])
- |> (snd o PureThy.add_thms [(("induct", prepare_induct (#2 (hd dts)) rec_eqns), [])])
- |> Theory.parent_path
+ val (simps', thy'') = thy' |> PureThy.add_thms ((map fst eqns ~~ simps) ~~ atts);
in
- (thy''', simps')
+ thy''
+ |> (snd o PureThy.add_thmss [(("simps", simps'),
+ [Simplifier.simp_add, RecfunCodegen.add NONE])])
+ |> (snd o PureThy.add_thms [(("induct", prepare_induct (#2 (hd dts)) rec_eqns), [])])
+ |> Theory.parent_path
+ |> pair simps'
end;
fun gen_primrec unchecked alt_name eqns thy =
let
- val sign = Theory.sign_of thy;
val ((names, strings), srcss) = apfst split_list (split_list eqns);
val atts = map (map (Attrib.attribute thy)) srcss;
- val eqn_ts = map (fn s => term_of (Thm.read_cterm sign (s, propT))
+ val eqn_ts = map (fn s => term_of (Thm.read_cterm thy (s, propT))
handle ERROR msg => cat_error msg ("The error(s) above occurred for " ^ s)) strings;
val rec_ts = map (fn eq => head_of (fst (HOLogic.dest_eq (HOLogic.dest_Trueprop eq)))
- handle TERM _ => primrec_eq_err sign "not a proper equation" eq) eqn_ts;
- val (_, eqn_ts') = InductivePackage.unify_consts (sign_of thy) rec_ts eqn_ts
+ handle TERM _ => primrec_eq_err thy "not a proper equation" eq) eqn_ts;
+ val (_, eqn_ts') = InductivePackage.unify_consts thy rec_ts eqn_ts
in
gen_primrec_i unchecked alt_name (names ~~ eqn_ts' ~~ atts) thy
end;
@@ -306,7 +313,7 @@
val primrecP =
OuterSyntax.command "primrec" "define primitive recursive functions on datatypes" K.thy_decl
(primrec_decl >> (fn ((unchecked, alt_name), eqns) =>
- Toplevel.theory (#1 o
+ Toplevel.theory (snd o
(if unchecked then add_primrec_unchecked else add_primrec) alt_name
(map P.triple_swap eqns))));