proper eta expansion in recfun_codegen.ML; no eta expansion at all in code_thingol.ML
authorhaftmann
Tue, 11 Aug 2009 10:43:43 +0200
changeset 32358 98c00ee9e786
parent 32357 84a6d701e36f
child 32359 bc1e123295f5
proper eta expansion in recfun_codegen.ML; no eta expansion at all in code_thingol.ML
src/HOL/Tools/recfun_codegen.ML
src/Tools/Code/code_thingol.ML
--- a/src/HOL/Tools/recfun_codegen.ML	Tue Aug 11 10:05:53 2009 +0200
+++ b/src/HOL/Tools/recfun_codegen.ML	Tue Aug 11 10:43:43 2009 +0200
@@ -14,6 +14,8 @@
 
 open Codegen;
 
+val const_of = dest_Const o head_of o fst o Logic.dest_equals;
+
 structure ModuleData = TheoryDataFun
 (
   type T = string Symtab.table;
@@ -31,36 +33,32 @@
     |> ModuleData.map (Symtab.update (fst (Code.const_typ_eqn thy thm'), module_name))
   end;
 
-fun retrieve_equations thy (c, T) = if c = @{const_name "op ="} then NONE else
-  let
-    val c' = AxClass.unoverload_const thy (c, T);
-    val opt_name = Symtab.lookup (ModuleData.get thy) c';
-    val raw_thms = Code.these_eqns thy c'
-      |> map_filter (fn (thm, linear) => if linear then SOME thm else NONE);
-  in if null raw_thms then NONE else
-    raw_thms
-    |> Code_Thingol.clean_thms thy (snd (Code.const_typ_eqn thy (hd raw_thms)))
-    |> map (rpair opt_name)
-    |> SOME
-  end;
+fun avoid_value thy [thm] =
+      let val (_, T) = Code.const_typ_eqn thy thm
+      in if null (Term.add_tvarsT T []) orelse (null o fst o strip_type) T
+        then [thm]
+        else [Code_Thingol.expand_eta thy 1 thm]
+      end
+  | avoid_value thy thms = thms;
 
-val dest_eqn = Logic.dest_equals;
-val const_of = dest_Const o head_of o fst o dest_eqn;
-
-fun get_equations thy defs (s, T) =
-  (case retrieve_equations thy (s, T) of
-     NONE => ([], "")
-   | SOME thms => 
-       let val thms' = filter (fn (thm, _) => is_instance T
-           (snd (const_of (prop_of thm)))) thms
-       in if null thms' then ([], "")
-         else (preprocess thy (map fst thms'),
-           case snd (snd (split_last thms')) of
-               NONE => (case get_defn thy defs s T of
-                   NONE => Codegen.thyname_of_const thy s
-                 | SOME ((_, (thyname, _)), _) => thyname)
-             | SOME thyname => thyname)
-       end);
+fun get_equations thy defs (raw_c, T) = if raw_c = @{const_name "op ="} then ([], "") else
+  let
+    val c = AxClass.unoverload_const thy (raw_c, T);
+    val raw_thms = Code.these_eqns thy c
+      |> map_filter (fn (thm, linear) => if linear then SOME thm else NONE)
+      |> filter (is_instance T o snd o const_of o prop_of);
+    val module_name = case Symtab.lookup (ModuleData.get thy) c
+     of SOME module_name => module_name
+      | NONE => case get_defn thy defs c T
+         of SOME ((_, (thyname, _)), _) => thyname
+          | NONE => Codegen.thyname_of_const thy c;
+  in if null raw_thms then ([], "") else
+    raw_thms
+    |> preprocess thy
+    |> avoid_value thy
+    |> Code_Thingol.clean_thms thy
+    |> rpair module_name
+  end;
 
 fun mk_suffix thy defs (s, T) = (case get_defn thy defs s T of
   SOME (_, SOME i) => " def" ^ string_of_int i | _ => "");
@@ -74,7 +72,7 @@
 fun add_rec_funs thy defs dep module eqs gr =
   let
     fun dest_eq t = (fst (const_of t) ^ mk_suffix thy defs (const_of t),
-      dest_eqn (rename_term t));
+      Logic.dest_equals (rename_term t));
     val eqs' = map dest_eq eqs;
     val (dname, _) :: _ = eqs';
     val (s, T) = const_of (hd eqs);
--- a/src/Tools/Code/code_thingol.ML	Tue Aug 11 10:05:53 2009 +0200
+++ b/src/Tools/Code/code_thingol.ML	Tue Aug 11 10:43:43 2009 +0200
@@ -79,7 +79,8 @@
   val is_cons: program -> string -> bool
   val contr_classparam_typs: program -> string -> itype option list
 
-  val clean_thms: theory -> typ -> thm list -> thm list
+  val expand_eta: theory -> int -> thm -> thm
+  val clean_thms: theory -> thm list -> thm list
   val read_const_exprs: theory -> string list -> string list * string list
   val consts_program: theory -> string list -> string list * (naming * program)
   val cached_program: theory -> naming * program
@@ -403,12 +404,6 @@
     val k = fold (curry Int.max o num_args_of o Thm.prop_of) thms 0;
   in map (expand_eta thy k) thms end;
 
-fun avoid_value thy ty [thm] =
-      if null (Term.add_tfreesT ty []) orelse (null o fst o strip_type) ty
-      then [thm]
-      else [expand_eta thy 1 thm]
-  | avoid_value thy _ thms = thms;
-
 fun mk_desymbolization pre post mk vs =
   let
     val names = map (pre o fst o fst) vs
@@ -434,8 +429,7 @@
 
 fun desymbolize_all_vars thy = desymbolize_tvars thy #> map (desymbolize_vars thy);
 
-fun clean_thms thy ty =
-  same_arity thy #> avoid_value thy ty #> desymbolize_all_vars thy;
+fun clean_thms thy = same_arity thy #> desymbolize_all_vars thy;
 
 
 (** statements, abstract programs **)
@@ -563,7 +557,7 @@
     fun stmt_fun ((vs, ty), eqns) =
       fold_map (translate_tyvar_sort thy algbr funcgr) vs
       ##>> translate_typ thy algbr funcgr ty
-      ##>> fold_map (translate_eqn thy algbr funcgr) (burrow_fst (clean_thms thy ty) eqns)
+      ##>> fold_map (translate_eqn thy algbr funcgr) (burrow_fst (clean_thms thy) eqns)
       #>> (fn info => Fun (c, info));
     val stmt_const = case Code.get_datatype_of_constr thy c
      of SOME tyco => stmt_datatypecons tyco