author haftmann Fri, 20 Oct 2006 10:44:47 +0200 changeset 21064 9684dd7c81b5 parent 21063 3c5074f028c8 child 21065 42669b5bf98e
fold cleanup
```--- a/src/HOL/Tools/primrec_package.ML	Fri Oct 20 10:44:42 2006 +0200
+++ b/src/HOL/Tools/primrec_package.ML	Fri Oct 20 10:44:47 2006 +0200
@@ -42,7 +42,7 @@

(* preprocessing of equations *)

-fun process_eqn thy (eq, rec_fns) =
+fun process_eqn thy eq rec_fns =
let
val (lhs, rhs) =
if null (term_vars eq) then
@@ -93,18 +93,20 @@
end
handle RecError s => primrec_eq_err thy s eq;

-fun process_fun thy descr rec_eqns ((i, fnameT as (fname, _)), (fnameTs, fnss)) =
+fun process_fun thy descr rec_eqns (i, fnameT as (fname, _)) (fnameTs, fnss) =
let
val (_, (tname, _, constrs)) = List.nth (descr, i);

(* substitute "fname ls x rs" by "y ls rs" for (x, (_, y)) in subs *)

-    fun subst [] x = x
-      | subst subs (fs, Abs (a, T, t)) =
-          let val (fs', t') = subst subs (fs, t)
-          in (fs', Abs (a, T, t')) end
-      | subst subs (fs, t as (_ \$ _)) =
-          let val (f, ts) = strip_comb t;
+    fun subst [] t fs = (t, fs)
+      | subst subs (Abs (a, T, t)) fs =
+          fs
+          |> subst subs t
+          |-> (fn t' => pair (Abs (a, T, t')))
+      | subst subs (t as (_ \$ _)) fs =
+          let
+            val (f, ts) = strip_comb t;
in
if is_Const f andalso dest_Const f mem map fst rec_eqns then
let
@@ -116,44 +118,41 @@
handle Empty => raise RecError ("not enough arguments\
\ in recursive application\nof function " ^ quote fname' ^ " on rhs");
val (x, xs) = strip_comb x'
-              in
-                (case AList.lookup (op =) subs x of
-                    NONE =>
-                      let
-                        val (fs', ts') = foldl_map (subst subs) (fs, ts)
-                      in (fs', list_comb (f, ts')) end
-                  | SOME (i', y) =>
-                      let
-                        val (fs', ts') = foldl_map (subst subs) (fs, xs @ ls @ rs);
-                        val fs'' = process_fun thy descr rec_eqns ((i', fnameT'), fs')
-                      in (fs'', list_comb (y, ts'))
-                      end)
+              in case AList.lookup (op =) subs x
+               of NONE =>
+                    fs
+                    |> fold_map (subst subs) ts
+                    |-> (fn ts' => pair (list_comb (f, ts')))
+                | SOME (i', y) =>
+                    fs
+                    |> fold_map (subst subs) (xs @ ls @ rs)
+                    ||> process_fun thy descr rec_eqns (i', fnameT')
+                    |-> (fn ts' => pair (list_comb (y, ts')))
end
else
-              let
-                val (fs', f'::ts') = foldl_map (subst subs) (fs, f::ts)
-              in (fs', list_comb (f', ts')) end
+              fs
+              |> fold_map (subst subs) (f :: ts)
+              |-> (fn (f'::ts') => pair (list_comb (f', ts')))
end
-      | subst _ x = x;
+      | subst _ t fs = (t, fs);

(* translate rec equations into function arguments suitable for rec comb *)

-    fun trans eqns ((cname, cargs), (fnameTs', fnss', fns)) =
+    fun trans eqns (cname, cargs) (fnameTs', fnss', fns) =
(case AList.lookup (op =) eqns cname of
NONE => (warning ("No equation for constructor " ^ quote cname ^
"\nin definition of function " ^ quote fname);
(fnameTs', fnss', (Const ("arbitrary", dummyT))::fns))
| SOME (ls, cargs', rs, rhs, eq) =>
let
-              val recs = List.filter (is_rec_type o snd) (cargs' ~~ cargs);
+              val recs = filter (is_rec_type o snd) (cargs' ~~ cargs);
val rargs = map fst recs;
val subs = map (rpair dummyT o fst)
(rev (rename_wrt_term rhs rargs));
-              val ((fnameTs'', fnss''), rhs') =
+              val (rhs', (fnameTs'', fnss'')) =
(subst (map (fn ((x, y), z) =>
(Free x, (body_index y, Free z)))
-                          (recs ~~ subs))
-                   ((fnameTs', fnss'), rhs))
+                          (recs ~~ subs)) rhs (fnameTs', fnss'))
handle RecError s => primrec_eq_err thy s eq
in (fnameTs'', fnss'',
(list_abs_free (cargs' @ subs @ ls @ rs, rhs'))::fns)
@@ -166,8 +165,8 @@
else
let
val (_, _, eqns) = the (AList.lookup (op =) rec_eqns fnameT);
-            val (fnameTs', fnss', fns) = foldr (trans eqns)
-              ((i, fnameT)::fnameTs, fnss, []) constrs
+            val (fnameTs', fnss', fns) = fold_rev (trans eqns) constrs
+              ((i, fnameT)::fnameTs, fnss, [])
in
(fnameTs', (i, (fname, #1 (snd (hd eqns)), fns))::fnss')
end
@@ -179,7 +178,7 @@

(* prepare functions needed for definitions *)

-fun get_fns fns (((i : int, (tname, _, constrs)), rec_name), (fs, defs)) =
+fun get_fns fns ((i : int, (tname, _, constrs)), rec_name) (fs, defs) =
case AList.lookup (op =) fns i of
NONE =>
let
@@ -190,17 +189,17 @@
in
(dummy_fns @ fs, defs)
end
-   | SOME (fname, ls, fs') => (fs' @ fs, (fname, ls, rec_name, tname)::defs);
+   | SOME (fname, ls, fs') => (fs' @ fs, (fname, ls, rec_name, tname) :: defs);

(* make definition *)

fun make_def thy fs (fname, ls, rec_name, tname) =
let
-    val rhs = foldr (fn (T, t) => Abs ("", T, t))
+    val rhs = fold_rev (fn T => fn t => Abs ("", T, t))
+                    ((map snd ls) @ [dummyT])
(list_comb (Const (rec_name, dummyT),
fs @ map Bound (0 ::(length ls downto 1))))
-                    ((map snd ls) @ [dummyT]);
val defpair = (Sign.base_name fname ^ "_" ^ Sign.base_name tname ^ "_def",
Logic.mk_equals (Const (fname, dummyT), rhs))
in Theory.inferT_axm thy defpair end;
@@ -234,7 +233,7 @@
let
val (eqns, atts) = split_list eqns_atts;
val dt_info = DatatypePackage.get_datatypes thy;
-    val rec_eqns = foldr (process_eqn thy) [] (map snd eqns);
+    val rec_eqns = fold_rev (process_eqn thy o snd) eqns [] ;
val tnames = distinct (op =) (map (#1 o snd) rec_eqns);
val dts = find_dts dt_info tnames tnames;
val main_fns =
@@ -247,8 +246,8 @@
primrec_err ("datatypes " ^ commas_quote tnames ^ "\nare not mutually recursive")
else snd (hd dts);
val (fnameTs, fnss) =
-      foldr (process_fun thy descr rec_eqns) ([], []) main_fns;
-    val (fs, defs) = foldr (get_fns fnss) ([], []) (descr ~~ rec_names);
+      fold_rev (process_fun thy descr rec_eqns) main_fns ([], []);
+    val (fs, defs) = fold_rev (get_fns fnss) (descr ~~ rec_names) ([], []);
val defs' = map (make_def thy fs) defs;
val nameTs1 = map snd fnameTs;
val nameTs2 = map fst rec_eqns;```