--- a/src/HOL/Tools/primrec_package.ML Fri Oct 20 10:44:42 2006 +0200
+++ b/src/HOL/Tools/primrec_package.ML Fri Oct 20 10:44:47 2006 +0200
@@ -42,7 +42,7 @@
(* preprocessing of equations *)
-fun process_eqn thy (eq, rec_fns) =
+fun process_eqn thy eq rec_fns =
let
val (lhs, rhs) =
if null (term_vars eq) then
@@ -93,18 +93,20 @@
end
handle RecError s => primrec_eq_err thy s eq;
-fun process_fun thy descr rec_eqns ((i, fnameT as (fname, _)), (fnameTs, fnss)) =
+fun process_fun thy descr rec_eqns (i, fnameT as (fname, _)) (fnameTs, fnss) =
let
val (_, (tname, _, constrs)) = List.nth (descr, i);
(* substitute "fname ls x rs" by "y ls rs" for (x, (_, y)) in subs *)
- fun subst [] x = x
- | subst subs (fs, Abs (a, T, t)) =
- let val (fs', t') = subst subs (fs, t)
- in (fs', Abs (a, T, t')) end
- | subst subs (fs, t as (_ $ _)) =
- let val (f, ts) = strip_comb t;
+ fun subst [] t fs = (t, fs)
+ | subst subs (Abs (a, T, t)) fs =
+ fs
+ |> subst subs t
+ |-> (fn t' => pair (Abs (a, T, t')))
+ | subst subs (t as (_ $ _)) fs =
+ let
+ val (f, ts) = strip_comb t;
in
if is_Const f andalso dest_Const f mem map fst rec_eqns then
let
@@ -116,44 +118,41 @@
handle Empty => raise RecError ("not enough arguments\
\ in recursive application\nof function " ^ quote fname' ^ " on rhs");
val (x, xs) = strip_comb x'
- in
- (case AList.lookup (op =) subs x of
- NONE =>
- let
- val (fs', ts') = foldl_map (subst subs) (fs, ts)
- in (fs', list_comb (f, ts')) end
- | SOME (i', y) =>
- let
- val (fs', ts') = foldl_map (subst subs) (fs, xs @ ls @ rs);
- val fs'' = process_fun thy descr rec_eqns ((i', fnameT'), fs')
- in (fs'', list_comb (y, ts'))
- end)
+ in case AList.lookup (op =) subs x
+ of NONE =>
+ fs
+ |> fold_map (subst subs) ts
+ |-> (fn ts' => pair (list_comb (f, ts')))
+ | SOME (i', y) =>
+ fs
+ |> fold_map (subst subs) (xs @ ls @ rs)
+ ||> process_fun thy descr rec_eqns (i', fnameT')
+ |-> (fn ts' => pair (list_comb (y, ts')))
end
else
- let
- val (fs', f'::ts') = foldl_map (subst subs) (fs, f::ts)
- in (fs', list_comb (f', ts')) end
+ fs
+ |> fold_map (subst subs) (f :: ts)
+ |-> (fn (f'::ts') => pair (list_comb (f', ts')))
end
- | subst _ x = x;
+ | subst _ t fs = (t, fs);
(* translate rec equations into function arguments suitable for rec comb *)
- fun trans eqns ((cname, cargs), (fnameTs', fnss', fns)) =
+ fun trans eqns (cname, cargs) (fnameTs', fnss', fns) =
(case AList.lookup (op =) eqns cname of
NONE => (warning ("No equation for constructor " ^ quote cname ^
"\nin definition of function " ^ quote fname);
(fnameTs', fnss', (Const ("arbitrary", dummyT))::fns))
| SOME (ls, cargs', rs, rhs, eq) =>
let
- val recs = List.filter (is_rec_type o snd) (cargs' ~~ cargs);
+ val recs = filter (is_rec_type o snd) (cargs' ~~ cargs);
val rargs = map fst recs;
val subs = map (rpair dummyT o fst)
(rev (rename_wrt_term rhs rargs));
- val ((fnameTs'', fnss''), rhs') =
+ val (rhs', (fnameTs'', fnss'')) =
(subst (map (fn ((x, y), z) =>
(Free x, (body_index y, Free z)))
- (recs ~~ subs))
- ((fnameTs', fnss'), rhs))
+ (recs ~~ subs)) rhs (fnameTs', fnss'))
handle RecError s => primrec_eq_err thy s eq
in (fnameTs'', fnss'',
(list_abs_free (cargs' @ subs @ ls @ rs, rhs'))::fns)
@@ -166,8 +165,8 @@
else
let
val (_, _, eqns) = the (AList.lookup (op =) rec_eqns fnameT);
- val (fnameTs', fnss', fns) = foldr (trans eqns)
- ((i, fnameT)::fnameTs, fnss, []) constrs
+ val (fnameTs', fnss', fns) = fold_rev (trans eqns) constrs
+ ((i, fnameT)::fnameTs, fnss, [])
in
(fnameTs', (i, (fname, #1 (snd (hd eqns)), fns))::fnss')
end
@@ -179,7 +178,7 @@
(* prepare functions needed for definitions *)
-fun get_fns fns (((i : int, (tname, _, constrs)), rec_name), (fs, defs)) =
+fun get_fns fns ((i : int, (tname, _, constrs)), rec_name) (fs, defs) =
case AList.lookup (op =) fns i of
NONE =>
let
@@ -190,17 +189,17 @@
in
(dummy_fns @ fs, defs)
end
- | SOME (fname, ls, fs') => (fs' @ fs, (fname, ls, rec_name, tname)::defs);
+ | SOME (fname, ls, fs') => (fs' @ fs, (fname, ls, rec_name, tname) :: defs);
(* make definition *)
fun make_def thy fs (fname, ls, rec_name, tname) =
let
- val rhs = foldr (fn (T, t) => Abs ("", T, t))
+ val rhs = fold_rev (fn T => fn t => Abs ("", T, t))
+ ((map snd ls) @ [dummyT])
(list_comb (Const (rec_name, dummyT),
fs @ map Bound (0 ::(length ls downto 1))))
- ((map snd ls) @ [dummyT]);
val defpair = (Sign.base_name fname ^ "_" ^ Sign.base_name tname ^ "_def",
Logic.mk_equals (Const (fname, dummyT), rhs))
in Theory.inferT_axm thy defpair end;
@@ -234,7 +233,7 @@
let
val (eqns, atts) = split_list eqns_atts;
val dt_info = DatatypePackage.get_datatypes thy;
- val rec_eqns = foldr (process_eqn thy) [] (map snd eqns);
+ val rec_eqns = fold_rev (process_eqn thy o snd) eqns [] ;
val tnames = distinct (op =) (map (#1 o snd) rec_eqns);
val dts = find_dts dt_info tnames tnames;
val main_fns =
@@ -247,8 +246,8 @@
primrec_err ("datatypes " ^ commas_quote tnames ^ "\nare not mutually recursive")
else snd (hd dts);
val (fnameTs, fnss) =
- foldr (process_fun thy descr rec_eqns) ([], []) main_fns;
- val (fs, defs) = foldr (get_fns fnss) ([], []) (descr ~~ rec_names);
+ fold_rev (process_fun thy descr rec_eqns) main_fns ([], []);
+ val (fs, defs) = fold_rev (get_fns fnss) (descr ~~ rec_names) ([], []);
val defs' = map (make_def thy fs) defs;
val nameTs1 = map snd fnameTs;
val nameTs2 = map fst rec_eqns;