fold cleanup
authorhaftmann
Fri, 20 Oct 2006 10:44:47 +0200
changeset 21064 9684dd7c81b5
parent 21063 3c5074f028c8
child 21065 42669b5bf98e
fold cleanup
src/HOL/Tools/primrec_package.ML
--- a/src/HOL/Tools/primrec_package.ML	Fri Oct 20 10:44:42 2006 +0200
+++ b/src/HOL/Tools/primrec_package.ML	Fri Oct 20 10:44:47 2006 +0200
@@ -42,7 +42,7 @@
 
 (* preprocessing of equations *)
 
-fun process_eqn thy (eq, rec_fns) = 
+fun process_eqn thy eq rec_fns = 
   let
     val (lhs, rhs) = 
       if null (term_vars eq) then
@@ -93,18 +93,20 @@
   end
   handle RecError s => primrec_eq_err thy s eq;
 
-fun process_fun thy descr rec_eqns ((i, fnameT as (fname, _)), (fnameTs, fnss)) =
+fun process_fun thy descr rec_eqns (i, fnameT as (fname, _)) (fnameTs, fnss) =
   let
     val (_, (tname, _, constrs)) = List.nth (descr, i);
 
     (* substitute "fname ls x rs" by "y ls rs" for (x, (_, y)) in subs *)
 
-    fun subst [] x = x
-      | subst subs (fs, Abs (a, T, t)) =
-          let val (fs', t') = subst subs (fs, t)
-          in (fs', Abs (a, T, t')) end
-      | subst subs (fs, t as (_ $ _)) =
-          let val (f, ts) = strip_comb t;
+    fun subst [] t fs = (t, fs)
+      | subst subs (Abs (a, T, t)) fs =
+          fs
+          |> subst subs t
+          |-> (fn t' => pair (Abs (a, T, t')))
+      | subst subs (t as (_ $ _)) fs =
+          let
+            val (f, ts) = strip_comb t;
           in
             if is_Const f andalso dest_Const f mem map fst rec_eqns then
               let
@@ -116,44 +118,41 @@
                   handle Empty => raise RecError ("not enough arguments\
                    \ in recursive application\nof function " ^ quote fname' ^ " on rhs");
                 val (x, xs) = strip_comb x'
-              in 
-                (case AList.lookup (op =) subs x of
-                    NONE =>
-                      let
-                        val (fs', ts') = foldl_map (subst subs) (fs, ts)
-                      in (fs', list_comb (f, ts')) end
-                  | SOME (i', y) =>
-                      let
-                        val (fs', ts') = foldl_map (subst subs) (fs, xs @ ls @ rs);
-                        val fs'' = process_fun thy descr rec_eqns ((i', fnameT'), fs')
-                      in (fs'', list_comb (y, ts'))
-                      end)
+              in case AList.lookup (op =) subs x
+               of NONE =>
+                    fs
+                    |> fold_map (subst subs) ts
+                    |-> (fn ts' => pair (list_comb (f, ts')))
+                | SOME (i', y) =>
+                    fs
+                    |> fold_map (subst subs) (xs @ ls @ rs)
+                    ||> process_fun thy descr rec_eqns (i', fnameT')
+                    |-> (fn ts' => pair (list_comb (y, ts')))
               end
             else
-              let
-                val (fs', f'::ts') = foldl_map (subst subs) (fs, f::ts)
-              in (fs', list_comb (f', ts')) end
+              fs
+              |> fold_map (subst subs) (f :: ts)
+              |-> (fn (f'::ts') => pair (list_comb (f', ts')))
           end
-      | subst _ x = x;
+      | subst _ t fs = (t, fs);
 
     (* translate rec equations into function arguments suitable for rec comb *)
 
-    fun trans eqns ((cname, cargs), (fnameTs', fnss', fns)) =
+    fun trans eqns (cname, cargs) (fnameTs', fnss', fns) =
       (case AList.lookup (op =) eqns cname of
           NONE => (warning ("No equation for constructor " ^ quote cname ^
             "\nin definition of function " ^ quote fname);
               (fnameTs', fnss', (Const ("arbitrary", dummyT))::fns))
         | SOME (ls, cargs', rs, rhs, eq) =>
             let
-              val recs = List.filter (is_rec_type o snd) (cargs' ~~ cargs);
+              val recs = filter (is_rec_type o snd) (cargs' ~~ cargs);
               val rargs = map fst recs;
               val subs = map (rpair dummyT o fst) 
                 (rev (rename_wrt_term rhs rargs));
-              val ((fnameTs'', fnss''), rhs') = 
+              val (rhs', (fnameTs'', fnss'')) = 
                   (subst (map (fn ((x, y), z) =>
                                (Free x, (body_index y, Free z)))
-                          (recs ~~ subs))
-                   ((fnameTs', fnss'), rhs))
+                          (recs ~~ subs)) rhs (fnameTs', fnss'))
                   handle RecError s => primrec_eq_err thy s eq
             in (fnameTs'', fnss'', 
                 (list_abs_free (cargs' @ subs @ ls @ rs, rhs'))::fns)
@@ -166,8 +165,8 @@
         else
           let
             val (_, _, eqns) = the (AList.lookup (op =) rec_eqns fnameT);
-            val (fnameTs', fnss', fns) = foldr (trans eqns)
-              ((i, fnameT)::fnameTs, fnss, []) constrs
+            val (fnameTs', fnss', fns) = fold_rev (trans eqns) constrs
+              ((i, fnameT)::fnameTs, fnss, []) 
           in
             (fnameTs', (i, (fname, #1 (snd (hd eqns)), fns))::fnss')
           end
@@ -179,7 +178,7 @@
 
 (* prepare functions needed for definitions *)
 
-fun get_fns fns (((i : int, (tname, _, constrs)), rec_name), (fs, defs)) =
+fun get_fns fns ((i : int, (tname, _, constrs)), rec_name) (fs, defs) =
   case AList.lookup (op =) fns i of
      NONE =>
        let
@@ -190,17 +189,17 @@
        in
          (dummy_fns @ fs, defs)
        end
-   | SOME (fname, ls, fs') => (fs' @ fs, (fname, ls, rec_name, tname)::defs);
+   | SOME (fname, ls, fs') => (fs' @ fs, (fname, ls, rec_name, tname) :: defs);
 
 
 (* make definition *)
 
 fun make_def thy fs (fname, ls, rec_name, tname) =
   let
-    val rhs = foldr (fn (T, t) => Abs ("", T, t)) 
+    val rhs = fold_rev (fn T => fn t => Abs ("", T, t))
+                    ((map snd ls) @ [dummyT])
                     (list_comb (Const (rec_name, dummyT),
                                 fs @ map Bound (0 ::(length ls downto 1))))
-                    ((map snd ls) @ [dummyT]);
     val defpair = (Sign.base_name fname ^ "_" ^ Sign.base_name tname ^ "_def",
                    Logic.mk_equals (Const (fname, dummyT), rhs))
   in Theory.inferT_axm thy defpair end;
@@ -234,7 +233,7 @@
   let
     val (eqns, atts) = split_list eqns_atts;
     val dt_info = DatatypePackage.get_datatypes thy;
-    val rec_eqns = foldr (process_eqn thy) [] (map snd eqns);
+    val rec_eqns = fold_rev (process_eqn thy o snd) eqns [] ;
     val tnames = distinct (op =) (map (#1 o snd) rec_eqns);
     val dts = find_dts dt_info tnames tnames;
     val main_fns = 
@@ -247,8 +246,8 @@
         primrec_err ("datatypes " ^ commas_quote tnames ^ "\nare not mutually recursive")
       else snd (hd dts);
     val (fnameTs, fnss) =
-      foldr (process_fun thy descr rec_eqns) ([], []) main_fns;
-    val (fs, defs) = foldr (get_fns fnss) ([], []) (descr ~~ rec_names);
+      fold_rev (process_fun thy descr rec_eqns) main_fns ([], []);
+    val (fs, defs) = fold_rev (get_fns fnss) (descr ~~ rec_names) ([], []);
     val defs' = map (make_def thy fs) defs;
     val nameTs1 = map snd fnameTs;
     val nameTs2 = map fst rec_eqns;