fold cleanup
authorhaftmann
Fri Oct 20 10:44:47 2006 +0200 (2006-10-20)
changeset 210649684dd7c81b5
parent 21063 3c5074f028c8
child 21065 42669b5bf98e
fold cleanup
src/HOL/Tools/primrec_package.ML
     1.1 --- a/src/HOL/Tools/primrec_package.ML	Fri Oct 20 10:44:42 2006 +0200
     1.2 +++ b/src/HOL/Tools/primrec_package.ML	Fri Oct 20 10:44:47 2006 +0200
     1.3 @@ -42,7 +42,7 @@
     1.4  
     1.5  (* preprocessing of equations *)
     1.6  
     1.7 -fun process_eqn thy (eq, rec_fns) = 
     1.8 +fun process_eqn thy eq rec_fns = 
     1.9    let
    1.10      val (lhs, rhs) = 
    1.11        if null (term_vars eq) then
    1.12 @@ -93,18 +93,20 @@
    1.13    end
    1.14    handle RecError s => primrec_eq_err thy s eq;
    1.15  
    1.16 -fun process_fun thy descr rec_eqns ((i, fnameT as (fname, _)), (fnameTs, fnss)) =
    1.17 +fun process_fun thy descr rec_eqns (i, fnameT as (fname, _)) (fnameTs, fnss) =
    1.18    let
    1.19      val (_, (tname, _, constrs)) = List.nth (descr, i);
    1.20  
    1.21      (* substitute "fname ls x rs" by "y ls rs" for (x, (_, y)) in subs *)
    1.22  
    1.23 -    fun subst [] x = x
    1.24 -      | subst subs (fs, Abs (a, T, t)) =
    1.25 -          let val (fs', t') = subst subs (fs, t)
    1.26 -          in (fs', Abs (a, T, t')) end
    1.27 -      | subst subs (fs, t as (_ $ _)) =
    1.28 -          let val (f, ts) = strip_comb t;
    1.29 +    fun subst [] t fs = (t, fs)
    1.30 +      | subst subs (Abs (a, T, t)) fs =
    1.31 +          fs
    1.32 +          |> subst subs t
    1.33 +          |-> (fn t' => pair (Abs (a, T, t')))
    1.34 +      | subst subs (t as (_ $ _)) fs =
    1.35 +          let
    1.36 +            val (f, ts) = strip_comb t;
    1.37            in
    1.38              if is_Const f andalso dest_Const f mem map fst rec_eqns then
    1.39                let
    1.40 @@ -116,44 +118,41 @@
    1.41                    handle Empty => raise RecError ("not enough arguments\
    1.42                     \ in recursive application\nof function " ^ quote fname' ^ " on rhs");
    1.43                  val (x, xs) = strip_comb x'
    1.44 -              in 
    1.45 -                (case AList.lookup (op =) subs x of
    1.46 -                    NONE =>
    1.47 -                      let
    1.48 -                        val (fs', ts') = foldl_map (subst subs) (fs, ts)
    1.49 -                      in (fs', list_comb (f, ts')) end
    1.50 -                  | SOME (i', y) =>
    1.51 -                      let
    1.52 -                        val (fs', ts') = foldl_map (subst subs) (fs, xs @ ls @ rs);
    1.53 -                        val fs'' = process_fun thy descr rec_eqns ((i', fnameT'), fs')
    1.54 -                      in (fs'', list_comb (y, ts'))
    1.55 -                      end)
    1.56 +              in case AList.lookup (op =) subs x
    1.57 +               of NONE =>
    1.58 +                    fs
    1.59 +                    |> fold_map (subst subs) ts
    1.60 +                    |-> (fn ts' => pair (list_comb (f, ts')))
    1.61 +                | SOME (i', y) =>
    1.62 +                    fs
    1.63 +                    |> fold_map (subst subs) (xs @ ls @ rs)
    1.64 +                    ||> process_fun thy descr rec_eqns (i', fnameT')
    1.65 +                    |-> (fn ts' => pair (list_comb (y, ts')))
    1.66                end
    1.67              else
    1.68 -              let
    1.69 -                val (fs', f'::ts') = foldl_map (subst subs) (fs, f::ts)
    1.70 -              in (fs', list_comb (f', ts')) end
    1.71 +              fs
    1.72 +              |> fold_map (subst subs) (f :: ts)
    1.73 +              |-> (fn (f'::ts') => pair (list_comb (f', ts')))
    1.74            end
    1.75 -      | subst _ x = x;
    1.76 +      | subst _ t fs = (t, fs);
    1.77  
    1.78      (* translate rec equations into function arguments suitable for rec comb *)
    1.79  
    1.80 -    fun trans eqns ((cname, cargs), (fnameTs', fnss', fns)) =
    1.81 +    fun trans eqns (cname, cargs) (fnameTs', fnss', fns) =
    1.82        (case AList.lookup (op =) eqns cname of
    1.83            NONE => (warning ("No equation for constructor " ^ quote cname ^
    1.84              "\nin definition of function " ^ quote fname);
    1.85                (fnameTs', fnss', (Const ("arbitrary", dummyT))::fns))
    1.86          | SOME (ls, cargs', rs, rhs, eq) =>
    1.87              let
    1.88 -              val recs = List.filter (is_rec_type o snd) (cargs' ~~ cargs);
    1.89 +              val recs = filter (is_rec_type o snd) (cargs' ~~ cargs);
    1.90                val rargs = map fst recs;
    1.91                val subs = map (rpair dummyT o fst) 
    1.92                  (rev (rename_wrt_term rhs rargs));
    1.93 -              val ((fnameTs'', fnss''), rhs') = 
    1.94 +              val (rhs', (fnameTs'', fnss'')) = 
    1.95                    (subst (map (fn ((x, y), z) =>
    1.96                                 (Free x, (body_index y, Free z)))
    1.97 -                          (recs ~~ subs))
    1.98 -                   ((fnameTs', fnss'), rhs))
    1.99 +                          (recs ~~ subs)) rhs (fnameTs', fnss'))
   1.100                    handle RecError s => primrec_eq_err thy s eq
   1.101              in (fnameTs'', fnss'', 
   1.102                  (list_abs_free (cargs' @ subs @ ls @ rs, rhs'))::fns)
   1.103 @@ -166,8 +165,8 @@
   1.104          else
   1.105            let
   1.106              val (_, _, eqns) = the (AList.lookup (op =) rec_eqns fnameT);
   1.107 -            val (fnameTs', fnss', fns) = foldr (trans eqns)
   1.108 -              ((i, fnameT)::fnameTs, fnss, []) constrs
   1.109 +            val (fnameTs', fnss', fns) = fold_rev (trans eqns) constrs
   1.110 +              ((i, fnameT)::fnameTs, fnss, []) 
   1.111            in
   1.112              (fnameTs', (i, (fname, #1 (snd (hd eqns)), fns))::fnss')
   1.113            end
   1.114 @@ -179,7 +178,7 @@
   1.115  
   1.116  (* prepare functions needed for definitions *)
   1.117  
   1.118 -fun get_fns fns (((i : int, (tname, _, constrs)), rec_name), (fs, defs)) =
   1.119 +fun get_fns fns ((i : int, (tname, _, constrs)), rec_name) (fs, defs) =
   1.120    case AList.lookup (op =) fns i of
   1.121       NONE =>
   1.122         let
   1.123 @@ -190,17 +189,17 @@
   1.124         in
   1.125           (dummy_fns @ fs, defs)
   1.126         end
   1.127 -   | SOME (fname, ls, fs') => (fs' @ fs, (fname, ls, rec_name, tname)::defs);
   1.128 +   | SOME (fname, ls, fs') => (fs' @ fs, (fname, ls, rec_name, tname) :: defs);
   1.129  
   1.130  
   1.131  (* make definition *)
   1.132  
   1.133  fun make_def thy fs (fname, ls, rec_name, tname) =
   1.134    let
   1.135 -    val rhs = foldr (fn (T, t) => Abs ("", T, t)) 
   1.136 +    val rhs = fold_rev (fn T => fn t => Abs ("", T, t))
   1.137 +                    ((map snd ls) @ [dummyT])
   1.138                      (list_comb (Const (rec_name, dummyT),
   1.139                                  fs @ map Bound (0 ::(length ls downto 1))))
   1.140 -                    ((map snd ls) @ [dummyT]);
   1.141      val defpair = (Sign.base_name fname ^ "_" ^ Sign.base_name tname ^ "_def",
   1.142                     Logic.mk_equals (Const (fname, dummyT), rhs))
   1.143    in Theory.inferT_axm thy defpair end;
   1.144 @@ -234,7 +233,7 @@
   1.145    let
   1.146      val (eqns, atts) = split_list eqns_atts;
   1.147      val dt_info = DatatypePackage.get_datatypes thy;
   1.148 -    val rec_eqns = foldr (process_eqn thy) [] (map snd eqns);
   1.149 +    val rec_eqns = fold_rev (process_eqn thy o snd) eqns [] ;
   1.150      val tnames = distinct (op =) (map (#1 o snd) rec_eqns);
   1.151      val dts = find_dts dt_info tnames tnames;
   1.152      val main_fns = 
   1.153 @@ -247,8 +246,8 @@
   1.154          primrec_err ("datatypes " ^ commas_quote tnames ^ "\nare not mutually recursive")
   1.155        else snd (hd dts);
   1.156      val (fnameTs, fnss) =
   1.157 -      foldr (process_fun thy descr rec_eqns) ([], []) main_fns;
   1.158 -    val (fs, defs) = foldr (get_fns fnss) ([], []) (descr ~~ rec_names);
   1.159 +      fold_rev (process_fun thy descr rec_eqns) main_fns ([], []);
   1.160 +    val (fs, defs) = fold_rev (get_fns fnss) (descr ~~ rec_names) ([], []);
   1.161      val defs' = map (make_def thy fs) defs;
   1.162      val nameTs1 = map snd fnameTs;
   1.163      val nameTs2 = map fst rec_eqns;