cleaned up variable desymbolification and argument expansion
authorhaftmann
Fri, 31 Jul 2009 09:34:05 +0200
changeset 32345 4da4fa060bb6
parent 32344 55ca0df19af5
child 32346 7d84fd5ef6ee
cleaned up variable desymbolification and argument expansion
src/HOL/Tools/recfun_codegen.ML
src/Pure/Isar/code.ML
src/Tools/Code/code_preproc.ML
--- a/src/HOL/Tools/recfun_codegen.ML	Thu Jul 30 15:21:31 2009 +0200
+++ b/src/HOL/Tools/recfun_codegen.ML	Fri Jul 31 09:34:05 2009 +0200
@@ -47,7 +47,7 @@
     val thms = Code.these_eqns thy c'
       |> map_filter (fn (thm, linear) => if linear then SOME thm else NONE)
       |> expand_eta thy
-      |> Code.norm_varnames thy
+      |> Code.desymbolize_all_vars thy
       |> map (rpair opt_name)
   in if null thms then NONE else SOME thms end;
 
--- a/src/Pure/Isar/code.ML	Thu Jul 30 15:21:31 2009 +0200
+++ b/src/Pure/Isar/code.ML	Fri Jul 31 09:34:05 2009 +0200
@@ -34,8 +34,7 @@
   val const_typ_eqn: theory -> thm -> string * typ
   val typscheme_eqn: theory -> thm -> (string * sort) list * typ
   val expand_eta: theory -> int -> thm -> thm
-  val norm_args: theory -> thm list -> thm list 
-  val norm_varnames: theory -> thm list -> thm list
+  val desymbolize_all_vars: theory -> thm list -> thm list
 
   (*executable code*)
   val add_datatype: (string * typ) list -> theory -> theory
@@ -135,104 +134,41 @@
     val l = if k = ~1
       then (length o fst o strip_abs) rhs
       else Int.max (0, k - length args);
-    val used = Name.make_context (map (fst o fst) (Term.add_vars lhs []));
-    fun get_name _ 0 = pair []
-      | get_name (Abs (v, ty, t)) k =
-          Name.variants [v]
-          ##>> get_name t (k - 1)
-          #>> (fn ([v'], vs') => (v', ty) :: vs')
-      | get_name t k = 
-          let
-            val (tys, _) = (strip_type o fastype_of) t
-          in case tys
-           of [] => raise TERM ("expand_eta", [t])
-            | ty :: _ =>
-                Name.variants [""]
-                #-> (fn [v] => get_name (t $ Var ((v, 0), ty)) (k - 1)
-                #>> (fn vs' => (v, ty) :: vs'))
-          end;
-    val (vs, _) = get_name rhs l used;
+    val (raw_vars, _) = Term.strip_abs_eta l rhs;
+    val vars = burrow_fst (Name.variant_list (map (fst o fst) (Term.add_vars lhs [])))
+      raw_vars;
     fun expand (v, ty) thm = Drule.fun_cong_rule thm
       (Thm.cterm_of thy (Var ((v, 0), ty)));
   in
     thm
-    |> fold expand vs
+    |> fold expand vars
     |> Conv.fconv_rule Drule.beta_eta_conversion
   end;
 
-fun norm_args thy thms =
-  let
-    val num_args_of = length o snd o strip_comb o fst o Logic.dest_equals;
-    val k = fold (curry Int.max o num_args_of o Thm.prop_of) thms 0;
-  in
-    thms
-    |> map (expand_eta thy k)
-    |> map (Conv.fconv_rule Drule.beta_eta_conversion)
-  end;
-
-fun canonical_tvars thy thm =
-  let
-    val ctyp = Thm.ctyp_of thy;
-    val purify_tvar = unprefix "'" #> Name.desymbolize false #> prefix "'";
-    fun tvars_subst_for thm = (fold_types o fold_atyps)
-      (fn TVar (v_i as (v, _), sort) => let
-            val v' = purify_tvar v
-          in if v = v' then I
-          else insert (op =) (v_i, (v', sort)) end
-        | _ => I) (prop_of thm) [];
-    fun mk_inst (v_i, (v', sort)) (maxidx, acc) =
-      let
-        val ty = TVar (v_i, sort)
-      in
-        (maxidx + 1, (ctyp ty, ctyp (TVar ((v', maxidx), sort))) :: acc)
-      end;
-    val maxidx = Thm.maxidx_of thm + 1;
-    val (_, inst) = fold mk_inst (tvars_subst_for thm) (maxidx + 1, []);
-  in Thm.instantiate (inst, []) thm end;
-
-fun canonical_vars thy thm =
+fun mk_desymbolization pre post cert vs =
   let
-    val cterm = Thm.cterm_of thy;
-    val purify_var = Name.desymbolize false;
-    fun vars_subst_for thm = fold_aterms
-      (fn Var (v_i as (v, _), ty) => let
-            val v' = purify_var v
-          in if v = v' then I
-          else insert (op =) (v_i, (v', ty)) end
-        | _ => I) (prop_of thm) [];
-    fun mk_inst (v_i as (v, i), (v', ty)) (maxidx, acc) =
-      let
-        val t = Var (v_i, ty)
-      in
-        (maxidx + 1, (cterm t, cterm (Var ((v', maxidx), ty))) :: acc)
-      end;
-    val maxidx = Thm.maxidx_of thm + 1;
-    val (_, inst) = fold mk_inst (vars_subst_for thm) (maxidx + 1, []);
-  in Thm.instantiate ([], inst) thm end;
+    val names = map (pre o fst o fst) vs
+      |> map (Name.desymbolize false)
+      |> Name.variant_list []
+      |> map post;
+    val subst_map = map_filter (fn (((v, i), x), v') =>
+      if v = v' andalso i = 0 then NONE
+      else SOME (((v, i), x), ((v', 0), x))) (vs ~~ names);
+  in (map o pairself) cert subst_map end;
 
-fun canonical_absvars thm =
+fun desymbolize_tvars thy thms =
   let
-    val t = Thm.plain_prop_of thm;
-    val purify_var = Name.desymbolize false;
-    val t' = Term.map_abs_vars purify_var t;
-  in Thm.rename_boundvars t t' thm end;
+    val tvs = fold (Term.add_tvars o Thm.prop_of) thms [];
+    val tvar_subst = mk_desymbolization (unprefix "'") (prefix "'") (Thm.ctyp_of thy o TVar) tvs;
+  in map (Thm.instantiate (tvar_subst, [])) thms end;
 
-fun norm_varnames thy thms =
+fun desymbolize_vars thy thm =
   let
-    fun burrow_thms f [] = []
-      | burrow_thms f thms =
-          thms
-          |> Conjunction.intr_balanced
-          |> f
-          |> Conjunction.elim_balanced (length thms)
-  in
-    thms
-    |> map (canonical_vars thy)
-    |> map canonical_absvars
-    |> map Drule.zero_var_indexes
-    |> burrow_thms (canonical_tvars thy)
-    |> Drule.zero_var_indexes_list
-  end;
+    val vs = Term.add_vars (Thm.prop_of thm) [];
+    val var_subst = mk_desymbolization I I (Thm.cterm_of thy o Var) vs;
+  in Thm.instantiate ([], var_subst) thm end;
+
+fun desymbolize_all_vars thy = desymbolize_tvars thy #> map (desymbolize_vars thy);
 
 
 (** data store **)
--- a/src/Tools/Code/code_preproc.ML	Thu Jul 30 15:21:31 2009 +0200
+++ b/src/Tools/Code/code_preproc.ML	Fri Jul 31 09:34:05 2009 +0200
@@ -129,6 +129,12 @@
   #> Logic.dest_equals
   #> snd;
 
+fun same_arity thy thms =
+  let
+    val num_args_of = length o snd o strip_comb o fst o Logic.dest_equals;
+    val k = fold (curry Int.max o num_args_of o Thm.prop_of) thms 0;
+  in map (Code.expand_eta thy k) thms end;
+
 fun preprocess thy c eqns =
   let
     val pre = (Simplifier.theory_context thy o #pre o the_thmproc) thy;
@@ -140,8 +146,8 @@
     |> (map o apfst) (rewrite_eqn pre)
     |> (map o apfst) (AxClass.unoverload thy)
     |> map (Code.assert_eqn thy)
-    |> burrow_fst (Code.norm_args thy)
-    |> burrow_fst (Code.norm_varnames thy)
+    |> burrow_fst (same_arity thy)
+    |> burrow_fst (Code.desymbolize_all_vars thy)
   end;
 
 fun preprocess_conv thy ct =