--- a/src/HOL/Tools/recfun_codegen.ML Mon Aug 10 12:24:47 2009 +0200
+++ b/src/HOL/Tools/recfun_codegen.ML Mon Aug 10 12:24:49 2009 +0200
@@ -31,25 +31,18 @@
|> ModuleData.map (Symtab.update (fst (Code.const_typ_eqn thy thm'), module_name))
end;
-fun expand_eta thy [] = []
- | expand_eta thy (thms as thm :: _) =
- let
- val (_, ty) = Code.const_typ_eqn thy thm;
- in if null (Term.add_tvarsT ty []) orelse (null o fst o strip_type) ty
- then thms
- else map (Code.expand_eta thy 1) thms
- end;
-
fun retrieve_equations thy (c, T) = if c = @{const_name "op ="} then NONE else
let
val c' = AxClass.unoverload_const thy (c, T);
val opt_name = Symtab.lookup (ModuleData.get thy) c';
- val thms = Code.these_eqns thy c'
- |> map_filter (fn (thm, linear) => if linear then SOME thm else NONE)
- |> expand_eta thy
- |> Code.desymbolize_all_vars thy
- |> map (rpair opt_name)
- in if null thms then NONE else SOME thms end;
+ val raw_thms = Code.these_eqns thy c'
+ |> map_filter (fn (thm, linear) => if linear then SOME thm else NONE);
+ in if null raw_thms then NONE else
+ raw_thms
+ |> Code_Thingol.clean_thms thy (snd (Code.const_typ_eqn thy (hd raw_thms)))
+ |> map (rpair opt_name)
+ |> SOME
+ end;
val dest_eqn = Logic.dest_equals;
val const_of = dest_Const o head_of o fst o dest_eqn;
--- a/src/Tools/Code/code_preproc.ML Mon Aug 10 12:24:47 2009 +0200
+++ b/src/Tools/Code/code_preproc.ML Mon Aug 10 12:24:49 2009 +0200
@@ -132,12 +132,6 @@
#> Logic.dest_equals
#> snd;
-fun same_arity thy thms =
- let
- val num_args_of = length o snd o strip_comb o fst o Logic.dest_equals;
- val k = fold (curry Int.max o num_args_of o Thm.prop_of) thms 0;
- in map (Code.expand_eta thy k) thms end;
-
fun preprocess thy c eqns =
let
val pre = (Simplifier.theory_context thy o #pre o the_thmproc) thy;
@@ -149,7 +143,6 @@
|> (map o apfst) (rewrite_eqn pre)
|> (map o apfst) (AxClass.unoverload thy)
|> map (Code.assert_eqn thy)
- |> burrow_fst (same_arity thy)
end;
fun preprocess_conv thy ct =
--- a/src/Tools/Code/code_thingol.ML Mon Aug 10 12:24:47 2009 +0200
+++ b/src/Tools/Code/code_thingol.ML Mon Aug 10 12:24:49 2009 +0200
@@ -79,6 +79,7 @@
val is_cons: program -> string -> bool
val contr_classparam_typs: program -> string -> itype option list
+ val clean_thms: theory -> typ -> thm list -> thm list
val read_const_exprs: theory -> string list -> string list * string list
val consts_program: theory -> string list -> string list * (naming * program)
val cached_program: theory -> naming * program
@@ -376,6 +377,67 @@
end; (* local *)
+(** technical transformations of code equations **)
+
+fun expand_eta thy k thm =
+ let
+ val (lhs, rhs) = (Logic.dest_equals o Thm.plain_prop_of) thm;
+ val (head, args) = strip_comb lhs;
+ val l = if k = ~1
+ then (length o fst o strip_abs) rhs
+ else Int.max (0, k - length args);
+ val (raw_vars, _) = Term.strip_abs_eta l rhs;
+ val vars = burrow_fst (Name.variant_list (map (fst o fst) (Term.add_vars lhs [])))
+ raw_vars;
+ fun expand (v, ty) thm = Drule.fun_cong_rule thm
+ (Thm.cterm_of thy (Var ((v, 0), ty)));
+ in
+ thm
+ |> fold expand vars
+ |> Conv.fconv_rule Drule.beta_eta_conversion
+ end;
+
+fun same_arity thy thms =
+ let
+ val num_args_of = length o snd o strip_comb o fst o Logic.dest_equals;
+ val k = fold (curry Int.max o num_args_of o Thm.prop_of) thms 0;
+ in map (expand_eta thy k) thms end;
+
+fun avoid_value thy ty [thm] =
+ if null (Term.add_tfreesT ty []) orelse (null o fst o strip_type) ty
+ then [thm]
+ else [expand_eta thy 1 thm]
+ | avoid_value thy _ thms = thms;
+
+fun mk_desymbolization pre post mk vs =
+ let
+ val names = map (pre o fst o fst) vs
+ |> map (Name.desymbolize false)
+ |> Name.variant_list []
+ |> map post;
+ in map_filter (fn (((v, i), x), v') =>
+ if v = v' andalso i = 0 then NONE
+ else SOME (((v, i), x), mk ((v', 0), x))) (vs ~~ names)
+ end;
+
+fun desymbolize_tvars thy thms =
+ let
+ val tvs = fold (Term.add_tvars o Thm.prop_of) thms [];
+ val tvar_subst = mk_desymbolization (unprefix "'") (prefix "'") TVar tvs;
+ in map (Thm.certify_instantiate (tvar_subst, [])) thms end;
+
+fun desymbolize_vars thy thm =
+ let
+ val vs = Term.add_vars (Thm.prop_of thm) [];
+ val var_subst = mk_desymbolization I I Var vs;
+ in Thm.certify_instantiate ([], var_subst) thm end;
+
+fun desymbolize_all_vars thy = desymbolize_tvars thy #> map (desymbolize_vars thy);
+
+fun clean_thms thy ty =
+ same_arity thy #> avoid_value thy ty #> desymbolize_all_vars thy;
+
+
(** statements, abstract programs **)
type typscheme = (vname * sort) list * itype;
@@ -498,17 +560,11 @@
fun stmt_classparam class =
ensure_class thy algbr funcgr class
#>> (fn class => Classparam (c, class));
- fun stmt_fun ((vs, ty), raw_eqns) =
- let
- val eqns = if null (Term.add_tfreesT ty []) orelse (null o fst o strip_type) ty
- then raw_eqns
- else (map o apfst) (Code.expand_eta thy 1) raw_eqns;
- in
- fold_map (translate_tyvar_sort thy algbr funcgr) vs
- ##>> translate_typ thy algbr funcgr ty
- ##>> translate_eqns thy algbr funcgr eqns
- #>> (fn info => Fun (c, info))
- end;
+ fun stmt_fun ((vs, ty), eqns) =
+ fold_map (translate_tyvar_sort thy algbr funcgr) vs
+ ##>> translate_typ thy algbr funcgr ty
+ ##>> fold_map (translate_eqn thy algbr funcgr) (burrow_fst (clean_thms thy ty) eqns)
+ #>> (fn info => Fun (c, info));
val stmt_const = case Code.get_datatype_of_constr thy c
of SOME tyco => stmt_datatypecons tyco
| NONE => (case AxClass.class_of_param thy c
@@ -597,9 +653,6 @@
translate_term thy algbr funcgr thm t'
##>> fold_map (translate_term thy algbr funcgr thm) ts
#>> (fn (t, ts) => t `$$ ts)
-and translate_eqns thy algbr funcgr eqns =
- fold_map (translate_eqn thy algbr funcgr)
- (burrow_fst (Code.desymbolize_all_vars thy) eqns)
and translate_eqn thy algbr funcgr (thm, proper) =
let
val (args, rhs) = (apfst (snd o strip_comb) o Logic.dest_equals