proper eta expansion in recfun_codegen.ML; no eta expansion at all in code_thingol.ML
--- a/src/HOL/Tools/recfun_codegen.ML Tue Aug 11 10:05:53 2009 +0200
+++ b/src/HOL/Tools/recfun_codegen.ML Tue Aug 11 10:43:43 2009 +0200
@@ -14,6 +14,8 @@
open Codegen;
+val const_of = dest_Const o head_of o fst o Logic.dest_equals;
+
structure ModuleData = TheoryDataFun
(
type T = string Symtab.table;
@@ -31,36 +33,32 @@
|> ModuleData.map (Symtab.update (fst (Code.const_typ_eqn thy thm'), module_name))
end;
-fun retrieve_equations thy (c, T) = if c = @{const_name "op ="} then NONE else
- let
- val c' = AxClass.unoverload_const thy (c, T);
- val opt_name = Symtab.lookup (ModuleData.get thy) c';
- val raw_thms = Code.these_eqns thy c'
- |> map_filter (fn (thm, linear) => if linear then SOME thm else NONE);
- in if null raw_thms then NONE else
- raw_thms
- |> Code_Thingol.clean_thms thy (snd (Code.const_typ_eqn thy (hd raw_thms)))
- |> map (rpair opt_name)
- |> SOME
- end;
+fun avoid_value thy [thm] =
+ let val (_, T) = Code.const_typ_eqn thy thm
+ in if null (Term.add_tvarsT T []) orelse (null o fst o strip_type) T
+ then [thm]
+ else [Code_Thingol.expand_eta thy 1 thm]
+ end
+ | avoid_value thy thms = thms;
-val dest_eqn = Logic.dest_equals;
-val const_of = dest_Const o head_of o fst o dest_eqn;
-
-fun get_equations thy defs (s, T) =
- (case retrieve_equations thy (s, T) of
- NONE => ([], "")
- | SOME thms =>
- let val thms' = filter (fn (thm, _) => is_instance T
- (snd (const_of (prop_of thm)))) thms
- in if null thms' then ([], "")
- else (preprocess thy (map fst thms'),
- case snd (snd (split_last thms')) of
- NONE => (case get_defn thy defs s T of
- NONE => Codegen.thyname_of_const thy s
- | SOME ((_, (thyname, _)), _) => thyname)
- | SOME thyname => thyname)
- end);
+fun get_equations thy defs (raw_c, T) = if raw_c = @{const_name "op ="} then ([], "") else
+ let
+ val c = AxClass.unoverload_const thy (raw_c, T);
+ val raw_thms = Code.these_eqns thy c
+ |> map_filter (fn (thm, linear) => if linear then SOME thm else NONE)
+ |> filter (is_instance T o snd o const_of o prop_of);
+ val module_name = case Symtab.lookup (ModuleData.get thy) c
+ of SOME module_name => module_name
+ | NONE => case get_defn thy defs c T
+ of SOME ((_, (thyname, _)), _) => thyname
+ | NONE => Codegen.thyname_of_const thy c;
+ in if null raw_thms then ([], "") else
+ raw_thms
+ |> preprocess thy
+ |> avoid_value thy
+ |> Code_Thingol.clean_thms thy
+ |> rpair module_name
+ end;
fun mk_suffix thy defs (s, T) = (case get_defn thy defs s T of
SOME (_, SOME i) => " def" ^ string_of_int i | _ => "");
@@ -74,7 +72,7 @@
fun add_rec_funs thy defs dep module eqs gr =
let
fun dest_eq t = (fst (const_of t) ^ mk_suffix thy defs (const_of t),
- dest_eqn (rename_term t));
+ Logic.dest_equals (rename_term t));
val eqs' = map dest_eq eqs;
val (dname, _) :: _ = eqs';
val (s, T) = const_of (hd eqs);
--- a/src/Tools/Code/code_thingol.ML Tue Aug 11 10:05:53 2009 +0200
+++ b/src/Tools/Code/code_thingol.ML Tue Aug 11 10:43:43 2009 +0200
@@ -79,7 +79,8 @@
val is_cons: program -> string -> bool
val contr_classparam_typs: program -> string -> itype option list
- val clean_thms: theory -> typ -> thm list -> thm list
+ val expand_eta: theory -> int -> thm -> thm
+ val clean_thms: theory -> thm list -> thm list
val read_const_exprs: theory -> string list -> string list * string list
val consts_program: theory -> string list -> string list * (naming * program)
val cached_program: theory -> naming * program
@@ -403,12 +404,6 @@
val k = fold (curry Int.max o num_args_of o Thm.prop_of) thms 0;
in map (expand_eta thy k) thms end;
-fun avoid_value thy ty [thm] =
- if null (Term.add_tfreesT ty []) orelse (null o fst o strip_type) ty
- then [thm]
- else [expand_eta thy 1 thm]
- | avoid_value thy _ thms = thms;
-
fun mk_desymbolization pre post mk vs =
let
val names = map (pre o fst o fst) vs
@@ -434,8 +429,7 @@
fun desymbolize_all_vars thy = desymbolize_tvars thy #> map (desymbolize_vars thy);
-fun clean_thms thy ty =
- same_arity thy #> avoid_value thy ty #> desymbolize_all_vars thy;
+fun clean_thms thy = same_arity thy #> desymbolize_all_vars thy;
(** statements, abstract programs **)
@@ -563,7 +557,7 @@
fun stmt_fun ((vs, ty), eqns) =
fold_map (translate_tyvar_sort thy algbr funcgr) vs
##>> translate_typ thy algbr funcgr ty
- ##>> fold_map (translate_eqn thy algbr funcgr) (burrow_fst (clean_thms thy ty) eqns)
+ ##>> fold_map (translate_eqn thy algbr funcgr) (burrow_fst (clean_thms thy) eqns)
#>> (fn info => Fun (c, info));
val stmt_const = case Code.get_datatype_of_constr thy c
of SOME tyco => stmt_datatypecons tyco