moved 'trivial classes' to foundation of code generator
authorhaftmann
Thu, 24 Apr 2008 16:53:04 +0200
changeset 26747 f32fa5f5bdd1
parent 26746 b010007e9d31
child 26748 4d51ddd6aa5c
moved 'trivial classes' to foundation of code generator
src/HOL/HOL.thy
src/Pure/Isar/code_unit.ML
src/Tools/nbe.ML
--- a/src/HOL/HOL.thy	Thu Apr 24 11:38:42 2008 +0200
+++ b/src/HOL/HOL.thy	Thu Apr 24 16:53:04 2008 +0200
@@ -1699,7 +1699,7 @@
   CodeUnit.add_const_alias @{thm equals_eq}
   #> CodeName.setup
   #> CodeTarget.setup
-  #> Nbe.setup @{sort eq} [(@{const_name eq_class.eq}, @{const_name "op ="})]
+  #> Nbe.setup
 *}
 
 lemma [code func]:
--- a/src/Pure/Isar/code_unit.ML	Thu Apr 24 11:38:42 2008 +0200
+++ b/src/Pure/Isar/code_unit.ML	Thu Apr 24 16:53:04 2008 +0200
@@ -17,9 +17,13 @@
   val inst_thm: sort Vartab.table -> thm -> thm
   val constrain_thm: sort -> thm -> thm
 
-  (*constants*)
+  (*constant aliasses*)
   val add_const_alias: thm -> theory -> theory
   val subst_alias: theory -> string -> string
+  val resubst_alias: theory -> string -> string
+  val triv_classes: theory -> class list
+
+  (*constants*)
   val string_of_typ: theory -> typ -> string
   val string_of_const: theory -> string -> string
   val no_args: theory -> string -> int
@@ -213,11 +217,13 @@
 
 structure ConstAlias = TheoryDataFun
 (
-  type T = ((string * string) * thm) list;
-  val empty = [];
+  type T = ((string * string) * thm) list * class list;
+  val empty = ([], []);
   val copy = I;
   val extend = I;
-  fun merge _ = Library.merge (eq_snd Thm.eq_thm_prop);
+  fun merge _ ((alias1, classes1), (alias2, classes2)) =
+    (Library.merge (eq_snd Thm.eq_thm_prop) (alias1, alias2),
+      Library.merge (op =) (classes1, classes2));
 );
 
 fun add_const_alias thm =
@@ -230,17 +236,35 @@
     val c_c' = case try (pairself (AxClass.unoverload_const thy o dest_Const)) lhs_rhs
      of SOME c_c' => c_c'
       | _ => error ("Not an equation with two constants: " ^ Display.string_of_thm thm);
-  in ConstAlias.map (cons (c_c', thm)) end;
+    val some_class = the_list (AxClass.class_of_param thy (snd c_c'));
+  in
+    ConstAlias.map (fn (alias, classes) =>
+      ((c_c', thm) :: alias, fold (insert (op =)) some_class classes))
+  end;
 
 fun rew_alias thm =
   let
     val thy = Thm.theory_of_thm thm;
-  in rewrite_head (map snd (ConstAlias.get thy)) thm end;
+  in rewrite_head ((map snd o fst o ConstAlias.get) thy) thm end;
 
 fun subst_alias thy c = ConstAlias.get thy
+  |> fst
   |> get_first (fn ((c', c''), _) => if c = c' then SOME c'' else NONE)
   |> the_default c;
 
+fun resubst_alias thy =
+  let
+    val alias = fst (ConstAlias.get thy);
+    val subst_inst_param = Option.map fst o AxClass.inst_of_param thy;
+    fun subst_alias c =
+      get_first (fn ((c', c''), _) => if c = c'' then SOME c' else NONE) alias;
+  in
+    perhaps subst_inst_param
+    #> perhaps subst_alias
+  end;
+
+val triv_classes = snd o ConstAlias.get;
+
 (* reading constants as terms *)
 
 fun check_bare_const thy t = case try dest_Const t
--- a/src/Tools/nbe.ML	Thu Apr 24 11:38:42 2008 +0200
+++ b/src/Tools/nbe.ML	Thu Apr 24 16:53:04 2008 +0200
@@ -23,7 +23,7 @@
   val univs_ref: (unit -> Univ list -> Univ list) option ref
   val trace: bool ref
 
-  val setup: class list -> (string * string) list -> theory -> theory
+  val setup: theory -> theory
 end;
 
 structure Nbe: NBE =
@@ -373,41 +373,13 @@
     |> term_of_univ thy idx_tab
   end;
 
-(* trivial type classes *)
-
-structure Nbe_Triv_Classes = TheoryDataFun
-(
-  type T = class list * (string * string) list;
-  val empty = ([], []);
-  val copy = I;
-  val extend = I;
-  fun merge _ ((classes1, consts1), (classes2, consts2)) =
-    (Library.merge (op =) (classes1, classes2), Library.merge (op =) (consts1, consts2));
-)
-
-fun add_triv_classes thy =
-  let
-    val (trivs, _) = Nbe_Triv_Classes.get thy;
-    val inters = curry (Sorts.inter_sort (Sign.classes_of thy)) trivs;
-    fun map_sorts f = (map_types o map_atyps)
-      (fn TVar (v, sort) => TVar (v, f sort)
-        | TFree (v, sort) => TFree (v, f sort));
-  in map_sorts inters end;
-
-fun subst_triv_consts thy =
-  let
-    fun subst_const f = map_aterms (fn t as Term.Const (c, ty) => (case f c
-         of SOME c' => Term.Const (c', ty)
-          | NONE => t)
-      | t => t);
-    val (_, consts) = Nbe_Triv_Classes.get thy;
-    val subst_inst = perhaps (Option.map fst o AxClass.inst_of_param thy);
-  in map_aterms (subst_const (AList.lookup (op =) consts o subst_inst)) end;
-
 (* evaluation with type reconstruction *)
 
 fun eval thy t code vs_ty_t deps =
   let
+    fun subst_const f = map_aterms (fn t as Term.Const (c, ty) => Term.Const (f c, ty)
+      | t => t);
+    val subst_triv_consts = subst_const (CodeUnit.resubst_alias thy);
     val ty = type_of t;
     val type_free = AList.lookup (op =)
       (map (fn (s, T) => (s, Term.Free (s, T))) (Term.add_frees t []));
@@ -425,7 +397,7 @@
   in
     compile_eval thy code vs_ty_t deps
     |> tracing (fn t => "Normalized:\n" ^ string_of_term t)
-    |> subst_triv_consts thy
+    |> subst_triv_consts
     |> type_frees
     |> tracing (fn t => "Vars typed:\n" ^ string_of_term t)
     |> type_infer
@@ -446,6 +418,15 @@
   Thm.invoke_oracle_i thy "HOL.norm" (thy, Norm (t, code, vs_ty_t, deps));
   (*FIXME get rid of hardwired theory name*)
 
+fun add_triv_classes thy =
+  let
+    val inters = curry (Sorts.inter_sort (Sign.classes_of thy))
+      (CodeUnit.triv_classes thy);
+    fun map_sorts f = (map_types o map_atyps)
+      (fn TVar (v, sort) => TVar (v, f sort)
+        | TFree (v, sort) => TFree (v, f sort));
+  in map_sorts inters end;
+
 fun norm_conv ct =
   let
     val thy = Thm.theory_of_cterm ct;
@@ -478,9 +459,7 @@
   let val ctxt = Toplevel.context_of state
   in norm_print_term ctxt modes (Syntax.read_term ctxt s) end;
 
-fun setup nbe_classes nbe_consts =
-  Theory.add_oracle ("norm", norm_oracle)
-  #> Nbe_Triv_Classes.map (K (nbe_classes, nbe_consts));
+val setup = Theory.add_oracle ("norm", norm_oracle);
 
 local structure P = OuterParse and K = OuterKeyword in