--- a/src/HOL/HOL.thy Tue Apr 22 13:35:26 2008 +0200
+++ b/src/HOL/HOL.thy Tue Apr 22 22:00:25 2008 +0200
@@ -1659,8 +1659,6 @@
subsection {* Code generator basic setup -- see further @{text Code_Setup.thy} *}
-setup "CodeName.setup #> CodeTarget.setup #> Nbe.setup"
-
code_datatype Trueprop "prop"
code_datatype "TYPE('a\<Colon>{})"
@@ -1699,6 +1697,9 @@
setup {*
CodeUnit.add_const_alias @{thm equals_eq}
+ #> CodeName.setup
+ #> CodeTarget.setup
+ #> Nbe.setup @{sort eq} [(@{const_name eq_class.eq}, @{const_name "op ="})]
*}
lemma [code func]:
--- a/src/HOL/ex/NormalForm.thy Tue Apr 22 13:35:26 2008 +0200
+++ b/src/HOL/ex/NormalForm.thy Tue Apr 22 22:00:25 2008 +0200
@@ -58,19 +58,19 @@
lemma "exp (S(S Z)) (S(S(S(S Z)))) = exp (S(S(S(S Z)))) (S(S Z))" by normalization
lemma "(let ((x,y),(u,v)) = ((Z,Z),(Z,Z)) in add (add x y) (add u v)) = Z" by normalization
-lemma "split (%(x\<Colon>'a\<Colon>eq) y. x) (a, b) = a" by normalization rule
+lemma "split (%x y. x) (a, b) = a" by normalization rule
lemma "(%((x,y),(u,v)). add (add x y) (add u v)) ((Z,Z),(Z,Z)) = Z" by normalization
lemma "case Z of Z \<Rightarrow> True | S x \<Rightarrow> False" by normalization
lemma "[] @ [] = []" by normalization
-lemma "map f [x,y,z::'x] = [f x \<Colon> 'a\<Colon>eq, f y, f z]" by normalization rule+
-lemma "[a \<Colon> 'a\<Colon>eq, b, c] @ xs = a # b # c # xs" by normalization rule+
-lemma "[] @ xs = (xs \<Colon> 'a\<Colon>eq list)" by normalization rule
+lemma "map f [x,y,z::'x] = [f x, f y, f z]" by normalization rule+
+lemma "[a, b, c] @ xs = a # b # c # xs" by normalization rule+
+lemma "[] @ xs = xs" by normalization rule
lemma "map (%f. f True) [id, g, Not] = [True, g True, False]" by normalization rule+
lemma "map (%f. f True) ([id, g, Not] @ fs) = [True, g True, False] @ map (%f. f True) fs" by normalization rule+
-lemma "rev [a, b, c] = [c \<Colon> 'a\<Colon>eq, b, a]" by normalization rule+
-normal_form "rev (a#b#cs) = rev cs @ [b, a \<Colon> 'a\<Colon>eq]"
+lemma "rev [a, b, c] = [c, b, a]" by normalization rule+
+normal_form "rev (a#b#cs) = rev cs @ [b, a]"
normal_form "map (%F. F [a,b,c::'x]) (map map [f,g,h])"
normal_form "map (%F. F ([a,b,c] @ ds)) (map map ([f,g,h]@fs))"
normal_form "map (%F. F [Z,S Z,S(S Z)]) (map map [S,add (S Z),mul (S(S Z)),id])"
@@ -78,19 +78,19 @@
by normalization
normal_form "case xs of [] \<Rightarrow> True | x#xs \<Rightarrow> False"
normal_form "map (%x. case x of None \<Rightarrow> False | Some y \<Rightarrow> True) xs = P"
-lemma "let x = y in [x, x] = [y \<Colon> 'a\<Colon>eq, y]" by normalization rule+
-lemma "Let y (%x. [x,x]) = [y \<Colon> 'a\<Colon>eq, y]" by normalization rule+
+lemma "let x = y in [x, x] = [y, y]" by normalization rule+
+lemma "Let y (%x. [x,x]) = [y, y]" by normalization rule+
normal_form "case n of Z \<Rightarrow> True | S x \<Rightarrow> False"
lemma "(%(x,y). add x y) (S z,S z) = S (add z (S z))" by normalization rule+
normal_form "filter (%x. x) ([True,False,x]@xs)"
normal_form "filter Not ([True,False,x]@xs)"
-lemma "[x,y,z] @ [a,b,c] = [x, y, z, a, b ,c \<Colon> 'a\<Colon>eq]" by normalization rule+
-lemma "(%(xs, ys). xs @ ys) ([a, b, c], [d, e, f]) = [a, b, c, d, e, f \<Colon> 'a\<Colon>eq]" by normalization rule+
+lemma "[x,y,z] @ [a,b,c] = [x, y, z, a, b, c]" by normalization rule+
+lemma "(%(xs, ys). xs @ ys) ([a, b, c], [d, e, f]) = [a, b, c, d, e, f]" by normalization rule+
lemma "map (%x. case x of None \<Rightarrow> False | Some y \<Rightarrow> True) [None, Some ()] = [False, True]" by normalization
-lemma "last [a, b, c \<Colon> 'a\<Colon>eq] = c" by normalization rule
-lemma "last ([a, b, c \<Colon> 'a\<Colon>eq] @ xs) = (if null xs then c else last xs)"
+lemma "last [a, b, c] = c" by normalization rule
+lemma "last ([a, b, c] @ xs) = (if null xs then c else last xs)"
by normalization rule
lemma "(2::int) + 3 - 1 + (- k) * 2 = 4 + - k * 2" by normalization rule
@@ -111,10 +111,10 @@
lemma "(42::rat) / 1704 = 1 / 284 + 3 / 142" by normalization
normal_form "Suc 0 \<in> set ms"
-lemma "f = (f \<Colon> 'a\<Colon>eq)" by normalization rule+
-lemma "f x = (f x \<Colon> 'a\<Colon>eq)" by normalization rule+
-lemma "(f o g) x = (f (g x) \<Colon> 'a\<Colon>eq)" by normalization rule+
-lemma "(f o id) x = (f x \<Colon> 'a\<Colon>eq)" by normalization rule+
+lemma "f = f" by normalization rule+
+lemma "f x = f x" by normalization rule+
+lemma "(f o g) x = f (g x)" by normalization rule+
+lemma "(f o id) x = f x" by normalization rule+
normal_form "(\<lambda>x. x)"
(* Church numerals: *)
--- a/src/Tools/nbe.ML Tue Apr 22 13:35:26 2008 +0200
+++ b/src/Tools/nbe.ML Tue Apr 22 22:00:25 2008 +0200
@@ -23,7 +23,7 @@
val univs_ref: (unit -> Univ list -> Univ list) option ref
val trace: bool ref
- val setup: theory -> theory
+ val setup: class list -> (string * string) list -> theory -> theory
end;
structure Nbe: NBE =
@@ -327,7 +327,7 @@
val ts' = take_until is_dict ts;
val c = (the o CodeName.const_rev thy o the o Inttab.lookup idx_tab) idx;
val T = Code.default_typ thy c;
- val T' = map_type_tvar (fn ((v, i), S) => TypeInfer.param (typidx + i) (v, S)) T;
+ val T' = map_type_tvar (fn ((v, i), S) => TypeInfer.param (typidx + i) (v, [])) T;
val typidx' = typidx + maxidx_of_typ T' + 1;
in of_apps bounds (Term.Const (c, T'), ts') typidx' end
| of_univ bounds (Free (name, ts)) typidx =
@@ -373,20 +373,51 @@
|> term_of_univ thy idx_tab
end;
+(* trivial type classes *)
+
+structure Nbe_Triv_Classes = TheoryDataFun
+(
+ type T = class list * (string * string) list;
+ val empty = ([], []);
+ val copy = I;
+ val extend = I;
+ fun merge _ ((classes1, consts1), (classes2, consts2)) =
+ (Library.merge (op =) (classes1, classes2), Library.merge (op =) (consts1, consts2));
+)
+
+fun add_triv_classes thy =
+ let
+ val (trivs, _) = Nbe_Triv_Classes.get thy;
+ val inters = curry (Sorts.inter_sort (Sign.classes_of thy)) trivs;
+ fun map_sorts f = (map_types o map_atyps)
+ (fn TVar (v, sort) => TVar (v, f sort)
+ | TFree (v, sort) => TFree (v, f sort));
+ in map_sorts inters end;
+
+fun subst_triv_consts thy =
+ let
+ fun subst_const f = map_aterms (fn t as Term.Const (c, ty) => (case f c
+ of SOME c' => Term.Const (c', ty)
+ | NONE => t)
+ | t => t);
+ val (_, consts) = Nbe_Triv_Classes.get thy;
+ val subst_inst = perhaps (Option.map fst o AxClass.inst_of_param thy);
+ in map_aterms (subst_const (AList.lookup (op =) consts o subst_inst)) end;
+
(* evaluation with type reconstruction *)
-fun eval thy code t vs_ty_t deps =
+fun eval thy t code vs_ty_t deps =
let
val ty = type_of t;
- fun subst_Frees [] = I
- | subst_Frees inst =
- Term.map_aterms (fn (t as Term.Free (s, _)) => the_default t (AList.lookup (op =) inst s)
- | t => t);
- val anno_vars =
- subst_Frees (map (fn (s, T) => (s, Term.Free (s, T))) (Term.add_frees t []))
- #> subst_Vars (map (fn (ixn, T) => (ixn, Var (ixn, T))) (Term.add_vars t []))
- fun constrain t =
- singleton (Syntax.check_terms (ProofContext.init thy)) (TypeInfer.constrain ty t);
+ val type_free = AList.lookup (op =)
+ (map (fn (s, T) => (s, Term.Free (s, T))) (Term.add_frees t []));
+ val type_frees = Term.map_aterms
+ (fn (t as Term.Free (s, _)) => the_default t (type_free s) | t => t);
+ fun type_infer t = [(t, ty)]
+ |> TypeInfer.infer_types (Sign.pp thy) (Sign.tsig_of thy) I
+ (try (Type.strip_sorts o Sign.the_const_type thy)) (K NONE)
+ Name.context 0 NONE
+ |> fst |> the_single |> fst;
fun check_tvars t = if null (Term.term_tvars t) then t else
error ("Illegal schematic type variables in normalized term: "
^ setmp show_types true (Sign.string_of_term thy) t);
@@ -394,40 +425,39 @@
in
compile_eval thy code vs_ty_t deps
|> tracing (fn t => "Normalized:\n" ^ string_of_term t)
- |> anno_vars
+ |> subst_triv_consts thy
+ |> type_frees
|> tracing (fn t => "Vars typed:\n" ^ string_of_term t)
- |> constrain
+ |> type_infer
|> tracing (fn t => "Types inferred:\n" ^ string_of_term t)
+ |> check_tvars
|> tracing (fn t => "---\n")
- |> check_tvars
end;
(* evaluation oracle *)
-exception Norm of CodeThingol.code * term
+exception Norm of term * CodeThingol.code
* (CodeThingol.typscheme * CodeThingol.iterm) * string list;
-fun norm_oracle (thy, Norm (code, t, vs_ty_t, deps)) =
- Logic.mk_equals (t, eval thy code t vs_ty_t deps);
+fun norm_oracle (thy, Norm (t, code, vs_ty_t, deps)) =
+ Logic.mk_equals (t, eval thy t code vs_ty_t deps);
-fun norm_invoke thy code t vs_ty_t deps =
- Thm.invoke_oracle_i thy "HOL.norm" (thy, Norm (code, t, vs_ty_t, deps));
+fun norm_invoke thy t code vs_ty_t deps =
+ Thm.invoke_oracle_i thy "HOL.norm" (thy, Norm (t, code, vs_ty_t, deps));
(*FIXME get rid of hardwired theory name*)
fun norm_conv ct =
let
val thy = Thm.theory_of_cterm ct;
- fun conv code vs_ty_t deps ct =
- let
- val t = Thm.term_of ct;
- in norm_invoke thy code t vs_ty_t deps end;
- in CodePackage.evaluate_conv thy conv ct end;
+ fun evaluator' t code vs_ty_t deps = norm_invoke thy t code vs_ty_t deps;
+ fun evaluator t = (add_triv_classes thy t, evaluator' t);
+ in CodePackage.evaluate_conv thy evaluator ct end;
-fun norm_term thy =
+fun norm_term thy t =
let
- fun invoke code vs_ty_t deps t =
- eval thy code t vs_ty_t deps;
- in CodePackage.evaluate_term thy invoke #> Code.postprocess_term thy end;
+ fun evaluator' t code vs_ty_t deps = eval thy t code vs_ty_t deps;
+ fun evaluator t = (add_triv_classes thy t, evaluator' t);
+ in (Code.postprocess_term thy o CodePackage.evaluate_term thy evaluator) t end;
(* evaluation command *)
@@ -448,7 +478,9 @@
let val ctxt = Toplevel.context_of state
in norm_print_term ctxt modes (Syntax.read_term ctxt s) end;
-val setup = Theory.add_oracle ("norm", norm_oracle)
+fun setup nbe_classes nbe_consts =
+ Theory.add_oracle ("norm", norm_oracle)
+ #> Nbe_Triv_Classes.map (K (nbe_classes, nbe_consts));
local structure P = OuterParse and K = OuterKeyword in