# HG changeset patch # User haftmann # Date 1161333887 -7200 # Node ID 9684dd7c81b5df4a826fc6d76e3e4d95adbe91b5 # Parent 3c5074f028c86154ebc162a02ce667218e3b2790 fold cleanup diff -r 3c5074f028c8 -r 9684dd7c81b5 src/HOL/Tools/primrec_package.ML --- a/src/HOL/Tools/primrec_package.ML Fri Oct 20 10:44:42 2006 +0200 +++ b/src/HOL/Tools/primrec_package.ML Fri Oct 20 10:44:47 2006 +0200 @@ -42,7 +42,7 @@ (* preprocessing of equations *) -fun process_eqn thy (eq, rec_fns) = +fun process_eqn thy eq rec_fns = let val (lhs, rhs) = if null (term_vars eq) then @@ -93,18 +93,20 @@ end handle RecError s => primrec_eq_err thy s eq; -fun process_fun thy descr rec_eqns ((i, fnameT as (fname, _)), (fnameTs, fnss)) = +fun process_fun thy descr rec_eqns (i, fnameT as (fname, _)) (fnameTs, fnss) = let val (_, (tname, _, constrs)) = List.nth (descr, i); (* substitute "fname ls x rs" by "y ls rs" for (x, (_, y)) in subs *) - fun subst [] x = x - | subst subs (fs, Abs (a, T, t)) = - let val (fs', t') = subst subs (fs, t) - in (fs', Abs (a, T, t')) end - | subst subs (fs, t as (_ $ _)) = - let val (f, ts) = strip_comb t; + fun subst [] t fs = (t, fs) + | subst subs (Abs (a, T, t)) fs = + fs + |> subst subs t + |-> (fn t' => pair (Abs (a, T, t'))) + | subst subs (t as (_ $ _)) fs = + let + val (f, ts) = strip_comb t; in if is_Const f andalso dest_Const f mem map fst rec_eqns then let @@ -116,44 +118,41 @@ handle Empty => raise RecError ("not enough arguments\ \ in recursive application\nof function " ^ quote fname' ^ " on rhs"); val (x, xs) = strip_comb x' - in - (case AList.lookup (op =) subs x of - NONE => - let - val (fs', ts') = foldl_map (subst subs) (fs, ts) - in (fs', list_comb (f, ts')) end - | SOME (i', y) => - let - val (fs', ts') = foldl_map (subst subs) (fs, xs @ ls @ rs); - val fs'' = process_fun thy descr rec_eqns ((i', fnameT'), fs') - in (fs'', list_comb (y, ts')) - end) + in case AList.lookup (op =) subs x + of NONE => + fs + |> fold_map (subst subs) ts + |-> (fn ts' => pair (list_comb (f, ts'))) + | SOME (i', y) => + fs + |> fold_map (subst subs) (xs @ ls @ rs) + ||> process_fun thy descr rec_eqns (i', fnameT') + |-> (fn ts' => pair (list_comb (y, ts'))) end else - let - val (fs', f'::ts') = foldl_map (subst subs) (fs, f::ts) - in (fs', list_comb (f', ts')) end + fs + |> fold_map (subst subs) (f :: ts) + |-> (fn (f'::ts') => pair (list_comb (f', ts'))) end - | subst _ x = x; + | subst _ t fs = (t, fs); (* translate rec equations into function arguments suitable for rec comb *) - fun trans eqns ((cname, cargs), (fnameTs', fnss', fns)) = + fun trans eqns (cname, cargs) (fnameTs', fnss', fns) = (case AList.lookup (op =) eqns cname of NONE => (warning ("No equation for constructor " ^ quote cname ^ "\nin definition of function " ^ quote fname); (fnameTs', fnss', (Const ("arbitrary", dummyT))::fns)) | SOME (ls, cargs', rs, rhs, eq) => let - val recs = List.filter (is_rec_type o snd) (cargs' ~~ cargs); + val recs = filter (is_rec_type o snd) (cargs' ~~ cargs); val rargs = map fst recs; val subs = map (rpair dummyT o fst) (rev (rename_wrt_term rhs rargs)); - val ((fnameTs'', fnss''), rhs') = + val (rhs', (fnameTs'', fnss'')) = (subst (map (fn ((x, y), z) => (Free x, (body_index y, Free z))) - (recs ~~ subs)) - ((fnameTs', fnss'), rhs)) + (recs ~~ subs)) rhs (fnameTs', fnss')) handle RecError s => primrec_eq_err thy s eq in (fnameTs'', fnss'', (list_abs_free (cargs' @ subs @ ls @ rs, rhs'))::fns) @@ -166,8 +165,8 @@ else let val (_, _, eqns) = the (AList.lookup (op =) rec_eqns fnameT); - val (fnameTs', fnss', fns) = foldr (trans eqns) - ((i, fnameT)::fnameTs, fnss, []) constrs + val (fnameTs', fnss', fns) = fold_rev (trans eqns) constrs + ((i, fnameT)::fnameTs, fnss, []) in (fnameTs', (i, (fname, #1 (snd (hd eqns)), fns))::fnss') end @@ -179,7 +178,7 @@ (* prepare functions needed for definitions *) -fun get_fns fns (((i : int, (tname, _, constrs)), rec_name), (fs, defs)) = +fun get_fns fns ((i : int, (tname, _, constrs)), rec_name) (fs, defs) = case AList.lookup (op =) fns i of NONE => let @@ -190,17 +189,17 @@ in (dummy_fns @ fs, defs) end - | SOME (fname, ls, fs') => (fs' @ fs, (fname, ls, rec_name, tname)::defs); + | SOME (fname, ls, fs') => (fs' @ fs, (fname, ls, rec_name, tname) :: defs); (* make definition *) fun make_def thy fs (fname, ls, rec_name, tname) = let - val rhs = foldr (fn (T, t) => Abs ("", T, t)) + val rhs = fold_rev (fn T => fn t => Abs ("", T, t)) + ((map snd ls) @ [dummyT]) (list_comb (Const (rec_name, dummyT), fs @ map Bound (0 ::(length ls downto 1)))) - ((map snd ls) @ [dummyT]); val defpair = (Sign.base_name fname ^ "_" ^ Sign.base_name tname ^ "_def", Logic.mk_equals (Const (fname, dummyT), rhs)) in Theory.inferT_axm thy defpair end; @@ -234,7 +233,7 @@ let val (eqns, atts) = split_list eqns_atts; val dt_info = DatatypePackage.get_datatypes thy; - val rec_eqns = foldr (process_eqn thy) [] (map snd eqns); + val rec_eqns = fold_rev (process_eqn thy o snd) eqns [] ; val tnames = distinct (op =) (map (#1 o snd) rec_eqns); val dts = find_dts dt_info tnames tnames; val main_fns = @@ -247,8 +246,8 @@ primrec_err ("datatypes " ^ commas_quote tnames ^ "\nare not mutually recursive") else snd (hd dts); val (fnameTs, fnss) = - foldr (process_fun thy descr rec_eqns) ([], []) main_fns; - val (fs, defs) = foldr (get_fns fnss) ([], []) (descr ~~ rec_names); + fold_rev (process_fun thy descr rec_eqns) main_fns ([], []); + val (fs, defs) = fold_rev (get_fns fnss) (descr ~~ rec_names) ([], []); val defs' = map (make_def thy fs) defs; val nameTs1 = map snd fnameTs; val nameTs2 = map fst rec_eqns;