different handling of eq class for nbe
authorhaftmann
Tue, 22 Apr 2008 22:00:25 +0200
changeset 26739 947b6013e863
parent 26738 615e1a86787b
child 26740 6c8cd101f875
different handling of eq class for nbe
src/HOL/HOL.thy
src/HOL/ex/NormalForm.thy
src/Tools/nbe.ML
--- a/src/HOL/HOL.thy	Tue Apr 22 13:35:26 2008 +0200
+++ b/src/HOL/HOL.thy	Tue Apr 22 22:00:25 2008 +0200
@@ -1659,8 +1659,6 @@
 
 subsection {* Code generator basic setup -- see further @{text Code_Setup.thy} *}
 
-setup "CodeName.setup #> CodeTarget.setup #> Nbe.setup"
-
 code_datatype Trueprop "prop"
 
 code_datatype "TYPE('a\<Colon>{})"
@@ -1699,6 +1697,9 @@
 
 setup {*
   CodeUnit.add_const_alias @{thm equals_eq}
+  #> CodeName.setup
+  #> CodeTarget.setup
+  #> Nbe.setup @{sort eq} [(@{const_name eq_class.eq}, @{const_name "op ="})]
 *}
 
 lemma [code func]:
--- a/src/HOL/ex/NormalForm.thy	Tue Apr 22 13:35:26 2008 +0200
+++ b/src/HOL/ex/NormalForm.thy	Tue Apr 22 22:00:25 2008 +0200
@@ -58,19 +58,19 @@
 lemma "exp (S(S Z)) (S(S(S(S Z)))) = exp (S(S(S(S Z)))) (S(S Z))" by normalization
 
 lemma "(let ((x,y),(u,v)) = ((Z,Z),(Z,Z)) in add (add x y) (add u v)) = Z" by normalization
-lemma "split (%(x\<Colon>'a\<Colon>eq) y. x) (a, b) = a" by normalization rule
+lemma "split (%x y. x) (a, b) = a" by normalization rule
 lemma "(%((x,y),(u,v)). add (add x y) (add u v)) ((Z,Z),(Z,Z)) = Z" by normalization
 
 lemma "case Z of Z \<Rightarrow> True | S x \<Rightarrow> False" by normalization
 
 lemma "[] @ [] = []" by normalization
-lemma "map f [x,y,z::'x] = [f x \<Colon> 'a\<Colon>eq, f y, f z]" by normalization rule+
-lemma "[a \<Colon> 'a\<Colon>eq, b, c] @ xs = a # b # c # xs" by normalization rule+
-lemma "[] @ xs = (xs \<Colon> 'a\<Colon>eq list)" by normalization rule
+lemma "map f [x,y,z::'x] = [f x, f y, f z]" by normalization rule+
+lemma "[a, b, c] @ xs = a # b # c # xs" by normalization rule+
+lemma "[] @ xs = xs" by normalization rule
 lemma "map (%f. f True) [id, g, Not] = [True, g True, False]" by normalization rule+
 lemma "map (%f. f True) ([id, g, Not] @ fs) = [True, g True, False] @ map (%f. f True) fs" by normalization rule+
-lemma "rev [a, b, c] = [c \<Colon> 'a\<Colon>eq, b, a]" by normalization rule+
-normal_form "rev (a#b#cs) = rev cs @ [b, a \<Colon> 'a\<Colon>eq]"
+lemma "rev [a, b, c] = [c, b, a]" by normalization rule+
+normal_form "rev (a#b#cs) = rev cs @ [b, a]"
 normal_form "map (%F. F [a,b,c::'x]) (map map [f,g,h])"
 normal_form "map (%F. F ([a,b,c] @ ds)) (map map ([f,g,h]@fs))"
 normal_form "map (%F. F [Z,S Z,S(S Z)]) (map map [S,add (S Z),mul (S(S Z)),id])"
@@ -78,19 +78,19 @@
   by normalization
 normal_form "case xs of [] \<Rightarrow> True | x#xs \<Rightarrow> False"
 normal_form "map (%x. case x of None \<Rightarrow> False | Some y \<Rightarrow> True) xs = P"
-lemma "let x = y in [x, x] = [y \<Colon> 'a\<Colon>eq, y]" by normalization rule+
-lemma "Let y (%x. [x,x]) = [y \<Colon> 'a\<Colon>eq, y]" by normalization rule+
+lemma "let x = y in [x, x] = [y, y]" by normalization rule+
+lemma "Let y (%x. [x,x]) = [y, y]" by normalization rule+
 normal_form "case n of Z \<Rightarrow> True | S x \<Rightarrow> False"
 lemma "(%(x,y). add x y) (S z,S z) = S (add z (S z))" by normalization rule+
 normal_form "filter (%x. x) ([True,False,x]@xs)"
 normal_form "filter Not ([True,False,x]@xs)"
 
-lemma "[x,y,z] @ [a,b,c] = [x, y, z, a, b ,c \<Colon> 'a\<Colon>eq]" by normalization rule+
-lemma "(%(xs, ys). xs @ ys) ([a, b, c], [d, e, f]) = [a, b, c, d, e, f \<Colon> 'a\<Colon>eq]" by normalization rule+
+lemma "[x,y,z] @ [a,b,c] = [x, y, z, a, b, c]" by normalization rule+
+lemma "(%(xs, ys). xs @ ys) ([a, b, c], [d, e, f]) = [a, b, c, d, e, f]" by normalization rule+
 lemma "map (%x. case x of None \<Rightarrow> False | Some y \<Rightarrow> True) [None, Some ()] = [False, True]" by normalization
 
-lemma "last [a, b, c \<Colon> 'a\<Colon>eq] = c" by normalization rule
-lemma "last ([a, b, c \<Colon> 'a\<Colon>eq] @ xs) = (if null xs then c else last xs)"
+lemma "last [a, b, c] = c" by normalization rule
+lemma "last ([a, b, c] @ xs) = (if null xs then c else last xs)"
   by normalization rule
 
 lemma "(2::int) + 3 - 1 + (- k) * 2 = 4 + - k * 2" by normalization rule
@@ -111,10 +111,10 @@
 lemma "(42::rat) / 1704 = 1 / 284 + 3 / 142" by normalization
 normal_form "Suc 0 \<in> set ms"
 
-lemma "f = (f \<Colon> 'a\<Colon>eq)" by normalization rule+
-lemma "f x = (f x \<Colon> 'a\<Colon>eq)" by normalization rule+
-lemma "(f o g) x = (f (g x) \<Colon> 'a\<Colon>eq)" by normalization rule+
-lemma "(f o id) x = (f x \<Colon> 'a\<Colon>eq)" by normalization rule+
+lemma "f = f" by normalization rule+
+lemma "f x = f x" by normalization rule+
+lemma "(f o g) x = f (g x)" by normalization rule+
+lemma "(f o id) x = f x" by normalization rule+
 normal_form "(\<lambda>x. x)"
 
 (* Church numerals: *)
--- a/src/Tools/nbe.ML	Tue Apr 22 13:35:26 2008 +0200
+++ b/src/Tools/nbe.ML	Tue Apr 22 22:00:25 2008 +0200
@@ -23,7 +23,7 @@
   val univs_ref: (unit -> Univ list -> Univ list) option ref
   val trace: bool ref
 
-  val setup: theory -> theory
+  val setup: class list -> (string * string) list -> theory -> theory
 end;
 
 structure Nbe: NBE =
@@ -327,7 +327,7 @@
             val ts' = take_until is_dict ts;
             val c = (the o CodeName.const_rev thy o the o Inttab.lookup idx_tab) idx;
             val T = Code.default_typ thy c;
-            val T' = map_type_tvar (fn ((v, i), S) => TypeInfer.param (typidx + i) (v, S)) T;
+            val T' = map_type_tvar (fn ((v, i), S) => TypeInfer.param (typidx + i) (v, [])) T;
             val typidx' = typidx + maxidx_of_typ T' + 1;
           in of_apps bounds (Term.Const (c, T'), ts') typidx' end
       | of_univ bounds (Free (name, ts)) typidx =
@@ -373,20 +373,51 @@
     |> term_of_univ thy idx_tab
   end;
 
+(* trivial type classes *)
+
+structure Nbe_Triv_Classes = TheoryDataFun
+(
+  type T = class list * (string * string) list;
+  val empty = ([], []);
+  val copy = I;
+  val extend = I;
+  fun merge _ ((classes1, consts1), (classes2, consts2)) =
+    (Library.merge (op =) (classes1, classes2), Library.merge (op =) (consts1, consts2));
+)
+
+fun add_triv_classes thy =
+  let
+    val (trivs, _) = Nbe_Triv_Classes.get thy;
+    val inters = curry (Sorts.inter_sort (Sign.classes_of thy)) trivs;
+    fun map_sorts f = (map_types o map_atyps)
+      (fn TVar (v, sort) => TVar (v, f sort)
+        | TFree (v, sort) => TFree (v, f sort));
+  in map_sorts inters end;
+
+fun subst_triv_consts thy =
+  let
+    fun subst_const f = map_aterms (fn t as Term.Const (c, ty) => (case f c
+         of SOME c' => Term.Const (c', ty)
+          | NONE => t)
+      | t => t);
+    val (_, consts) = Nbe_Triv_Classes.get thy;
+    val subst_inst = perhaps (Option.map fst o AxClass.inst_of_param thy);
+  in map_aterms (subst_const (AList.lookup (op =) consts o subst_inst)) end;
+
 (* evaluation with type reconstruction *)
 
-fun eval thy code t vs_ty_t deps =
+fun eval thy t code vs_ty_t deps =
   let
     val ty = type_of t;
-    fun subst_Frees [] = I
-      | subst_Frees inst =
-          Term.map_aterms (fn (t as Term.Free (s, _)) => the_default t (AList.lookup (op =) inst s)
-                            | t => t);
-    val anno_vars =
-      subst_Frees (map (fn (s, T) => (s, Term.Free (s, T))) (Term.add_frees t []))
-      #> subst_Vars (map (fn (ixn, T) => (ixn, Var (ixn, T))) (Term.add_vars t []))
-    fun constrain t =
-      singleton (Syntax.check_terms (ProofContext.init thy)) (TypeInfer.constrain ty t);
+    val type_free = AList.lookup (op =)
+      (map (fn (s, T) => (s, Term.Free (s, T))) (Term.add_frees t []));
+    val type_frees = Term.map_aterms
+      (fn (t as Term.Free (s, _)) => the_default t (type_free s) | t => t);
+    fun type_infer t = [(t, ty)]
+      |> TypeInfer.infer_types (Sign.pp thy) (Sign.tsig_of thy) I
+           (try (Type.strip_sorts o Sign.the_const_type thy)) (K NONE)
+           Name.context 0 NONE
+      |> fst |> the_single |> fst;
     fun check_tvars t = if null (Term.term_tvars t) then t else
       error ("Illegal schematic type variables in normalized term: "
         ^ setmp show_types true (Sign.string_of_term thy) t);
@@ -394,40 +425,39 @@
   in
     compile_eval thy code vs_ty_t deps
     |> tracing (fn t => "Normalized:\n" ^ string_of_term t)
-    |> anno_vars
+    |> subst_triv_consts thy
+    |> type_frees
     |> tracing (fn t => "Vars typed:\n" ^ string_of_term t)
-    |> constrain
+    |> type_infer
     |> tracing (fn t => "Types inferred:\n" ^ string_of_term t)
+    |> check_tvars
     |> tracing (fn t => "---\n")
-    |> check_tvars
   end;
 
 (* evaluation oracle *)
 
-exception Norm of CodeThingol.code * term
+exception Norm of term * CodeThingol.code
   * (CodeThingol.typscheme * CodeThingol.iterm) * string list;
 
-fun norm_oracle (thy, Norm (code, t, vs_ty_t, deps)) =
-  Logic.mk_equals (t, eval thy code t vs_ty_t deps);
+fun norm_oracle (thy, Norm (t, code, vs_ty_t, deps)) =
+  Logic.mk_equals (t, eval thy t code vs_ty_t deps);
 
-fun norm_invoke thy code t vs_ty_t deps =
-  Thm.invoke_oracle_i thy "HOL.norm" (thy, Norm (code, t, vs_ty_t, deps));
+fun norm_invoke thy t code vs_ty_t deps =
+  Thm.invoke_oracle_i thy "HOL.norm" (thy, Norm (t, code, vs_ty_t, deps));
   (*FIXME get rid of hardwired theory name*)
 
 fun norm_conv ct =
   let
     val thy = Thm.theory_of_cterm ct;
-    fun conv code vs_ty_t deps ct =
-      let
-        val t = Thm.term_of ct;
-      in norm_invoke thy code t vs_ty_t deps end;
-  in CodePackage.evaluate_conv thy conv ct end;
+    fun evaluator' t code vs_ty_t deps = norm_invoke thy t code vs_ty_t deps;
+    fun evaluator t = (add_triv_classes thy t, evaluator' t);
+  in CodePackage.evaluate_conv thy evaluator ct end;
 
-fun norm_term thy =
+fun norm_term thy t =
   let
-    fun invoke code vs_ty_t deps t =
-      eval thy code t vs_ty_t deps;
-  in CodePackage.evaluate_term thy invoke #> Code.postprocess_term thy end;
+    fun evaluator' t code vs_ty_t deps = eval thy t code vs_ty_t deps;
+    fun evaluator t = (add_triv_classes thy t, evaluator' t);
+  in (Code.postprocess_term thy o CodePackage.evaluate_term thy evaluator) t end;
 
 (* evaluation command *)
 
@@ -448,7 +478,9 @@
   let val ctxt = Toplevel.context_of state
   in norm_print_term ctxt modes (Syntax.read_term ctxt s) end;
 
-val setup = Theory.add_oracle ("norm", norm_oracle)
+fun setup nbe_classes nbe_consts =
+  Theory.add_oracle ("norm", norm_oracle)
+  #> Nbe_Triv_Classes.map (K (nbe_classes, nbe_consts));
 
 local structure P = OuterParse and K = OuterKeyword in