proper eta expansion in recfun_codegen.ML; no eta expansion at all in code_thingol.ML
authorhaftmann
Tue Aug 11 10:43:43 2009 +0200 (2009-08-11)
changeset 3235898c00ee9e786
parent 32357 84a6d701e36f
child 32359 bc1e123295f5
proper eta expansion in recfun_codegen.ML; no eta expansion at all in code_thingol.ML
src/HOL/Tools/recfun_codegen.ML
src/Tools/Code/code_thingol.ML
     1.1 --- a/src/HOL/Tools/recfun_codegen.ML	Tue Aug 11 10:05:53 2009 +0200
     1.2 +++ b/src/HOL/Tools/recfun_codegen.ML	Tue Aug 11 10:43:43 2009 +0200
     1.3 @@ -14,6 +14,8 @@
     1.4  
     1.5  open Codegen;
     1.6  
     1.7 +val const_of = dest_Const o head_of o fst o Logic.dest_equals;
     1.8 +
     1.9  structure ModuleData = TheoryDataFun
    1.10  (
    1.11    type T = string Symtab.table;
    1.12 @@ -31,36 +33,32 @@
    1.13      |> ModuleData.map (Symtab.update (fst (Code.const_typ_eqn thy thm'), module_name))
    1.14    end;
    1.15  
    1.16 -fun retrieve_equations thy (c, T) = if c = @{const_name "op ="} then NONE else
    1.17 -  let
    1.18 -    val c' = AxClass.unoverload_const thy (c, T);
    1.19 -    val opt_name = Symtab.lookup (ModuleData.get thy) c';
    1.20 -    val raw_thms = Code.these_eqns thy c'
    1.21 -      |> map_filter (fn (thm, linear) => if linear then SOME thm else NONE);
    1.22 -  in if null raw_thms then NONE else
    1.23 -    raw_thms
    1.24 -    |> Code_Thingol.clean_thms thy (snd (Code.const_typ_eqn thy (hd raw_thms)))
    1.25 -    |> map (rpair opt_name)
    1.26 -    |> SOME
    1.27 -  end;
    1.28 +fun avoid_value thy [thm] =
    1.29 +      let val (_, T) = Code.const_typ_eqn thy thm
    1.30 +      in if null (Term.add_tvarsT T []) orelse (null o fst o strip_type) T
    1.31 +        then [thm]
    1.32 +        else [Code_Thingol.expand_eta thy 1 thm]
    1.33 +      end
    1.34 +  | avoid_value thy thms = thms;
    1.35  
    1.36 -val dest_eqn = Logic.dest_equals;
    1.37 -val const_of = dest_Const o head_of o fst o dest_eqn;
    1.38 -
    1.39 -fun get_equations thy defs (s, T) =
    1.40 -  (case retrieve_equations thy (s, T) of
    1.41 -     NONE => ([], "")
    1.42 -   | SOME thms => 
    1.43 -       let val thms' = filter (fn (thm, _) => is_instance T
    1.44 -           (snd (const_of (prop_of thm)))) thms
    1.45 -       in if null thms' then ([], "")
    1.46 -         else (preprocess thy (map fst thms'),
    1.47 -           case snd (snd (split_last thms')) of
    1.48 -               NONE => (case get_defn thy defs s T of
    1.49 -                   NONE => Codegen.thyname_of_const thy s
    1.50 -                 | SOME ((_, (thyname, _)), _) => thyname)
    1.51 -             | SOME thyname => thyname)
    1.52 -       end);
    1.53 +fun get_equations thy defs (raw_c, T) = if raw_c = @{const_name "op ="} then ([], "") else
    1.54 +  let
    1.55 +    val c = AxClass.unoverload_const thy (raw_c, T);
    1.56 +    val raw_thms = Code.these_eqns thy c
    1.57 +      |> map_filter (fn (thm, linear) => if linear then SOME thm else NONE)
    1.58 +      |> filter (is_instance T o snd o const_of o prop_of);
    1.59 +    val module_name = case Symtab.lookup (ModuleData.get thy) c
    1.60 +     of SOME module_name => module_name
    1.61 +      | NONE => case get_defn thy defs c T
    1.62 +         of SOME ((_, (thyname, _)), _) => thyname
    1.63 +          | NONE => Codegen.thyname_of_const thy c;
    1.64 +  in if null raw_thms then ([], "") else
    1.65 +    raw_thms
    1.66 +    |> preprocess thy
    1.67 +    |> avoid_value thy
    1.68 +    |> Code_Thingol.clean_thms thy
    1.69 +    |> rpair module_name
    1.70 +  end;
    1.71  
    1.72  fun mk_suffix thy defs (s, T) = (case get_defn thy defs s T of
    1.73    SOME (_, SOME i) => " def" ^ string_of_int i | _ => "");
    1.74 @@ -74,7 +72,7 @@
    1.75  fun add_rec_funs thy defs dep module eqs gr =
    1.76    let
    1.77      fun dest_eq t = (fst (const_of t) ^ mk_suffix thy defs (const_of t),
    1.78 -      dest_eqn (rename_term t));
    1.79 +      Logic.dest_equals (rename_term t));
    1.80      val eqs' = map dest_eq eqs;
    1.81      val (dname, _) :: _ = eqs';
    1.82      val (s, T) = const_of (hd eqs);
     2.1 --- a/src/Tools/Code/code_thingol.ML	Tue Aug 11 10:05:53 2009 +0200
     2.2 +++ b/src/Tools/Code/code_thingol.ML	Tue Aug 11 10:43:43 2009 +0200
     2.3 @@ -79,7 +79,8 @@
     2.4    val is_cons: program -> string -> bool
     2.5    val contr_classparam_typs: program -> string -> itype option list
     2.6  
     2.7 -  val clean_thms: theory -> typ -> thm list -> thm list
     2.8 +  val expand_eta: theory -> int -> thm -> thm
     2.9 +  val clean_thms: theory -> thm list -> thm list
    2.10    val read_const_exprs: theory -> string list -> string list * string list
    2.11    val consts_program: theory -> string list -> string list * (naming * program)
    2.12    val cached_program: theory -> naming * program
    2.13 @@ -403,12 +404,6 @@
    2.14      val k = fold (curry Int.max o num_args_of o Thm.prop_of) thms 0;
    2.15    in map (expand_eta thy k) thms end;
    2.16  
    2.17 -fun avoid_value thy ty [thm] =
    2.18 -      if null (Term.add_tfreesT ty []) orelse (null o fst o strip_type) ty
    2.19 -      then [thm]
    2.20 -      else [expand_eta thy 1 thm]
    2.21 -  | avoid_value thy _ thms = thms;
    2.22 -
    2.23  fun mk_desymbolization pre post mk vs =
    2.24    let
    2.25      val names = map (pre o fst o fst) vs
    2.26 @@ -434,8 +429,7 @@
    2.27  
    2.28  fun desymbolize_all_vars thy = desymbolize_tvars thy #> map (desymbolize_vars thy);
    2.29  
    2.30 -fun clean_thms thy ty =
    2.31 -  same_arity thy #> avoid_value thy ty #> desymbolize_all_vars thy;
    2.32 +fun clean_thms thy = same_arity thy #> desymbolize_all_vars thy;
    2.33  
    2.34  
    2.35  (** statements, abstract programs **)
    2.36 @@ -563,7 +557,7 @@
    2.37      fun stmt_fun ((vs, ty), eqns) =
    2.38        fold_map (translate_tyvar_sort thy algbr funcgr) vs
    2.39        ##>> translate_typ thy algbr funcgr ty
    2.40 -      ##>> fold_map (translate_eqn thy algbr funcgr) (burrow_fst (clean_thms thy ty) eqns)
    2.41 +      ##>> fold_map (translate_eqn thy algbr funcgr) (burrow_fst (clean_thms thy) eqns)
    2.42        #>> (fn info => Fun (c, info));
    2.43      val stmt_const = case Code.get_datatype_of_constr thy c
    2.44       of SOME tyco => stmt_datatypecons tyco