moved all technical processing of code equations to code_thingol.ML
authorhaftmann
Mon, 10 Aug 2009 12:24:49 +0200
changeset 32353 0ac26087464b
parent 32352 4839a704939a
child 32354 bb40e900e1f3
moved all technical processing of code equations to code_thingol.ML
src/HOL/Tools/recfun_codegen.ML
src/Tools/Code/code_preproc.ML
src/Tools/Code/code_thingol.ML
--- a/src/HOL/Tools/recfun_codegen.ML	Mon Aug 10 12:24:47 2009 +0200
+++ b/src/HOL/Tools/recfun_codegen.ML	Mon Aug 10 12:24:49 2009 +0200
@@ -31,25 +31,18 @@
     |> ModuleData.map (Symtab.update (fst (Code.const_typ_eqn thy thm'), module_name))
   end;
 
-fun expand_eta thy [] = []
-  | expand_eta thy (thms as thm :: _) =
-      let
-        val (_, ty) = Code.const_typ_eqn thy thm;
-      in if null (Term.add_tvarsT ty []) orelse (null o fst o strip_type) ty
-        then thms
-        else map (Code.expand_eta thy 1) thms
-      end;
-
 fun retrieve_equations thy (c, T) = if c = @{const_name "op ="} then NONE else
   let
     val c' = AxClass.unoverload_const thy (c, T);
     val opt_name = Symtab.lookup (ModuleData.get thy) c';
-    val thms = Code.these_eqns thy c'
-      |> map_filter (fn (thm, linear) => if linear then SOME thm else NONE)
-      |> expand_eta thy
-      |> Code.desymbolize_all_vars thy
-      |> map (rpair opt_name)
-  in if null thms then NONE else SOME thms end;
+    val raw_thms = Code.these_eqns thy c'
+      |> map_filter (fn (thm, linear) => if linear then SOME thm else NONE);
+  in if null raw_thms then NONE else
+    raw_thms
+    |> Code_Thingol.clean_thms thy (snd (Code.const_typ_eqn thy (hd raw_thms)))
+    |> map (rpair opt_name)
+    |> SOME
+  end;
 
 val dest_eqn = Logic.dest_equals;
 val const_of = dest_Const o head_of o fst o dest_eqn;
--- a/src/Tools/Code/code_preproc.ML	Mon Aug 10 12:24:47 2009 +0200
+++ b/src/Tools/Code/code_preproc.ML	Mon Aug 10 12:24:49 2009 +0200
@@ -132,12 +132,6 @@
   #> Logic.dest_equals
   #> snd;
 
-fun same_arity thy thms =
-  let
-    val num_args_of = length o snd o strip_comb o fst o Logic.dest_equals;
-    val k = fold (curry Int.max o num_args_of o Thm.prop_of) thms 0;
-  in map (Code.expand_eta thy k) thms end;
-
 fun preprocess thy c eqns =
   let
     val pre = (Simplifier.theory_context thy o #pre o the_thmproc) thy;
@@ -149,7 +143,6 @@
     |> (map o apfst) (rewrite_eqn pre)
     |> (map o apfst) (AxClass.unoverload thy)
     |> map (Code.assert_eqn thy)
-    |> burrow_fst (same_arity thy)
   end;
 
 fun preprocess_conv thy ct =
--- a/src/Tools/Code/code_thingol.ML	Mon Aug 10 12:24:47 2009 +0200
+++ b/src/Tools/Code/code_thingol.ML	Mon Aug 10 12:24:49 2009 +0200
@@ -79,6 +79,7 @@
   val is_cons: program -> string -> bool
   val contr_classparam_typs: program -> string -> itype option list
 
+  val clean_thms: theory -> typ -> thm list -> thm list
   val read_const_exprs: theory -> string list -> string list * string list
   val consts_program: theory -> string list -> string list * (naming * program)
   val cached_program: theory -> naming * program
@@ -376,6 +377,67 @@
 end; (* local *)
 
 
+(** technical transformations of code equations **)
+
+fun expand_eta thy k thm =
+  let
+    val (lhs, rhs) = (Logic.dest_equals o Thm.plain_prop_of) thm;
+    val (head, args) = strip_comb lhs;
+    val l = if k = ~1
+      then (length o fst o strip_abs) rhs
+      else Int.max (0, k - length args);
+    val (raw_vars, _) = Term.strip_abs_eta l rhs;
+    val vars = burrow_fst (Name.variant_list (map (fst o fst) (Term.add_vars lhs [])))
+      raw_vars;
+    fun expand (v, ty) thm = Drule.fun_cong_rule thm
+      (Thm.cterm_of thy (Var ((v, 0), ty)));
+  in
+    thm
+    |> fold expand vars
+    |> Conv.fconv_rule Drule.beta_eta_conversion
+  end;
+
+fun same_arity thy thms =
+  let
+    val num_args_of = length o snd o strip_comb o fst o Logic.dest_equals;
+    val k = fold (curry Int.max o num_args_of o Thm.prop_of) thms 0;
+  in map (expand_eta thy k) thms end;
+
+fun avoid_value thy ty [thm] =
+      if null (Term.add_tfreesT ty []) orelse (null o fst o strip_type) ty
+      then [thm]
+      else [expand_eta thy 1 thm]
+  | avoid_value thy _ thms = thms;
+
+fun mk_desymbolization pre post mk vs =
+  let
+    val names = map (pre o fst o fst) vs
+      |> map (Name.desymbolize false)
+      |> Name.variant_list []
+      |> map post;
+  in map_filter (fn (((v, i), x), v') =>
+    if v = v' andalso i = 0 then NONE
+    else SOME (((v, i), x), mk ((v', 0), x))) (vs ~~ names)
+  end;
+
+fun desymbolize_tvars thy thms =
+  let
+    val tvs = fold (Term.add_tvars o Thm.prop_of) thms [];
+    val tvar_subst = mk_desymbolization (unprefix "'") (prefix "'") TVar tvs;
+  in map (Thm.certify_instantiate (tvar_subst, [])) thms end;
+
+fun desymbolize_vars thy thm =
+  let
+    val vs = Term.add_vars (Thm.prop_of thm) [];
+    val var_subst = mk_desymbolization I I Var vs;
+  in Thm.certify_instantiate ([], var_subst) thm end;
+
+fun desymbolize_all_vars thy = desymbolize_tvars thy #> map (desymbolize_vars thy);
+
+fun clean_thms thy ty =
+  same_arity thy #> avoid_value thy ty #> desymbolize_all_vars thy;
+
+
 (** statements, abstract programs **)
 
 type typscheme = (vname * sort) list * itype;
@@ -498,17 +560,11 @@
     fun stmt_classparam class =
       ensure_class thy algbr funcgr class
       #>> (fn class => Classparam (c, class));
-    fun stmt_fun ((vs, ty), raw_eqns) =
-      let
-        val eqns = if null (Term.add_tfreesT ty []) orelse (null o fst o strip_type) ty
-          then raw_eqns
-          else (map o apfst) (Code.expand_eta thy 1) raw_eqns;
-      in
-        fold_map (translate_tyvar_sort thy algbr funcgr) vs
-        ##>> translate_typ thy algbr funcgr ty
-        ##>> translate_eqns thy algbr funcgr eqns
-        #>> (fn info => Fun (c, info))
-      end;
+    fun stmt_fun ((vs, ty), eqns) =
+      fold_map (translate_tyvar_sort thy algbr funcgr) vs
+      ##>> translate_typ thy algbr funcgr ty
+      ##>> fold_map (translate_eqn thy algbr funcgr) (burrow_fst (clean_thms thy ty) eqns)
+      #>> (fn info => Fun (c, info));
     val stmt_const = case Code.get_datatype_of_constr thy c
      of SOME tyco => stmt_datatypecons tyco
       | NONE => (case AxClass.class_of_param thy c
@@ -597,9 +653,6 @@
             translate_term thy algbr funcgr thm t'
             ##>> fold_map (translate_term thy algbr funcgr thm) ts
             #>> (fn (t, ts) => t `$$ ts)
-and translate_eqns thy algbr funcgr eqns =
-  fold_map (translate_eqn thy algbr funcgr)
-    (burrow_fst (Code.desymbolize_all_vars thy) eqns)
 and translate_eqn thy algbr funcgr (thm, proper) =
   let
     val (args, rhs) = (apfst (snd o strip_comb) o Logic.dest_equals