# HG changeset patch # User haftmann # Date 1208894425 -7200 # Node ID 947b6013e8637c85ba7e5f9ea9941edc48ed70c6 # Parent 615e1a86787b488571d245b37c79653a39538b0a different handling of eq class for nbe diff -r 615e1a86787b -r 947b6013e863 src/HOL/HOL.thy --- a/src/HOL/HOL.thy Tue Apr 22 13:35:26 2008 +0200 +++ b/src/HOL/HOL.thy Tue Apr 22 22:00:25 2008 +0200 @@ -1659,8 +1659,6 @@ subsection {* Code generator basic setup -- see further @{text Code_Setup.thy} *} -setup "CodeName.setup #> CodeTarget.setup #> Nbe.setup" - code_datatype Trueprop "prop" code_datatype "TYPE('a\{})" @@ -1699,6 +1697,9 @@ setup {* CodeUnit.add_const_alias @{thm equals_eq} + #> CodeName.setup + #> CodeTarget.setup + #> Nbe.setup @{sort eq} [(@{const_name eq_class.eq}, @{const_name "op ="})] *} lemma [code func]: diff -r 615e1a86787b -r 947b6013e863 src/HOL/ex/NormalForm.thy --- a/src/HOL/ex/NormalForm.thy Tue Apr 22 13:35:26 2008 +0200 +++ b/src/HOL/ex/NormalForm.thy Tue Apr 22 22:00:25 2008 +0200 @@ -58,19 +58,19 @@ lemma "exp (S(S Z)) (S(S(S(S Z)))) = exp (S(S(S(S Z)))) (S(S Z))" by normalization lemma "(let ((x,y),(u,v)) = ((Z,Z),(Z,Z)) in add (add x y) (add u v)) = Z" by normalization -lemma "split (%(x\'a\eq) y. x) (a, b) = a" by normalization rule +lemma "split (%x y. x) (a, b) = a" by normalization rule lemma "(%((x,y),(u,v)). add (add x y) (add u v)) ((Z,Z),(Z,Z)) = Z" by normalization lemma "case Z of Z \ True | S x \ False" by normalization lemma "[] @ [] = []" by normalization -lemma "map f [x,y,z::'x] = [f x \ 'a\eq, f y, f z]" by normalization rule+ -lemma "[a \ 'a\eq, b, c] @ xs = a # b # c # xs" by normalization rule+ -lemma "[] @ xs = (xs \ 'a\eq list)" by normalization rule +lemma "map f [x,y,z::'x] = [f x, f y, f z]" by normalization rule+ +lemma "[a, b, c] @ xs = a # b # c # xs" by normalization rule+ +lemma "[] @ xs = xs" by normalization rule lemma "map (%f. f True) [id, g, Not] = [True, g True, False]" by normalization rule+ lemma "map (%f. f True) ([id, g, Not] @ fs) = [True, g True, False] @ map (%f. f True) fs" by normalization rule+ -lemma "rev [a, b, c] = [c \ 'a\eq, b, a]" by normalization rule+ -normal_form "rev (a#b#cs) = rev cs @ [b, a \ 'a\eq]" +lemma "rev [a, b, c] = [c, b, a]" by normalization rule+ +normal_form "rev (a#b#cs) = rev cs @ [b, a]" normal_form "map (%F. F [a,b,c::'x]) (map map [f,g,h])" normal_form "map (%F. F ([a,b,c] @ ds)) (map map ([f,g,h]@fs))" normal_form "map (%F. F [Z,S Z,S(S Z)]) (map map [S,add (S Z),mul (S(S Z)),id])" @@ -78,19 +78,19 @@ by normalization normal_form "case xs of [] \ True | x#xs \ False" normal_form "map (%x. case x of None \ False | Some y \ True) xs = P" -lemma "let x = y in [x, x] = [y \ 'a\eq, y]" by normalization rule+ -lemma "Let y (%x. [x,x]) = [y \ 'a\eq, y]" by normalization rule+ +lemma "let x = y in [x, x] = [y, y]" by normalization rule+ +lemma "Let y (%x. [x,x]) = [y, y]" by normalization rule+ normal_form "case n of Z \ True | S x \ False" lemma "(%(x,y). add x y) (S z,S z) = S (add z (S z))" by normalization rule+ normal_form "filter (%x. x) ([True,False,x]@xs)" normal_form "filter Not ([True,False,x]@xs)" -lemma "[x,y,z] @ [a,b,c] = [x, y, z, a, b ,c \ 'a\eq]" by normalization rule+ -lemma "(%(xs, ys). xs @ ys) ([a, b, c], [d, e, f]) = [a, b, c, d, e, f \ 'a\eq]" by normalization rule+ +lemma "[x,y,z] @ [a,b,c] = [x, y, z, a, b, c]" by normalization rule+ +lemma "(%(xs, ys). xs @ ys) ([a, b, c], [d, e, f]) = [a, b, c, d, e, f]" by normalization rule+ lemma "map (%x. case x of None \ False | Some y \ True) [None, Some ()] = [False, True]" by normalization -lemma "last [a, b, c \ 'a\eq] = c" by normalization rule -lemma "last ([a, b, c \ 'a\eq] @ xs) = (if null xs then c else last xs)" +lemma "last [a, b, c] = c" by normalization rule +lemma "last ([a, b, c] @ xs) = (if null xs then c else last xs)" by normalization rule lemma "(2::int) + 3 - 1 + (- k) * 2 = 4 + - k * 2" by normalization rule @@ -111,10 +111,10 @@ lemma "(42::rat) / 1704 = 1 / 284 + 3 / 142" by normalization normal_form "Suc 0 \ set ms" -lemma "f = (f \ 'a\eq)" by normalization rule+ -lemma "f x = (f x \ 'a\eq)" by normalization rule+ -lemma "(f o g) x = (f (g x) \ 'a\eq)" by normalization rule+ -lemma "(f o id) x = (f x \ 'a\eq)" by normalization rule+ +lemma "f = f" by normalization rule+ +lemma "f x = f x" by normalization rule+ +lemma "(f o g) x = f (g x)" by normalization rule+ +lemma "(f o id) x = f x" by normalization rule+ normal_form "(\x. x)" (* Church numerals: *) diff -r 615e1a86787b -r 947b6013e863 src/Tools/nbe.ML --- a/src/Tools/nbe.ML Tue Apr 22 13:35:26 2008 +0200 +++ b/src/Tools/nbe.ML Tue Apr 22 22:00:25 2008 +0200 @@ -23,7 +23,7 @@ val univs_ref: (unit -> Univ list -> Univ list) option ref val trace: bool ref - val setup: theory -> theory + val setup: class list -> (string * string) list -> theory -> theory end; structure Nbe: NBE = @@ -327,7 +327,7 @@ val ts' = take_until is_dict ts; val c = (the o CodeName.const_rev thy o the o Inttab.lookup idx_tab) idx; val T = Code.default_typ thy c; - val T' = map_type_tvar (fn ((v, i), S) => TypeInfer.param (typidx + i) (v, S)) T; + val T' = map_type_tvar (fn ((v, i), S) => TypeInfer.param (typidx + i) (v, [])) T; val typidx' = typidx + maxidx_of_typ T' + 1; in of_apps bounds (Term.Const (c, T'), ts') typidx' end | of_univ bounds (Free (name, ts)) typidx = @@ -373,20 +373,51 @@ |> term_of_univ thy idx_tab end; +(* trivial type classes *) + +structure Nbe_Triv_Classes = TheoryDataFun +( + type T = class list * (string * string) list; + val empty = ([], []); + val copy = I; + val extend = I; + fun merge _ ((classes1, consts1), (classes2, consts2)) = + (Library.merge (op =) (classes1, classes2), Library.merge (op =) (consts1, consts2)); +) + +fun add_triv_classes thy = + let + val (trivs, _) = Nbe_Triv_Classes.get thy; + val inters = curry (Sorts.inter_sort (Sign.classes_of thy)) trivs; + fun map_sorts f = (map_types o map_atyps) + (fn TVar (v, sort) => TVar (v, f sort) + | TFree (v, sort) => TFree (v, f sort)); + in map_sorts inters end; + +fun subst_triv_consts thy = + let + fun subst_const f = map_aterms (fn t as Term.Const (c, ty) => (case f c + of SOME c' => Term.Const (c', ty) + | NONE => t) + | t => t); + val (_, consts) = Nbe_Triv_Classes.get thy; + val subst_inst = perhaps (Option.map fst o AxClass.inst_of_param thy); + in map_aterms (subst_const (AList.lookup (op =) consts o subst_inst)) end; + (* evaluation with type reconstruction *) -fun eval thy code t vs_ty_t deps = +fun eval thy t code vs_ty_t deps = let val ty = type_of t; - fun subst_Frees [] = I - | subst_Frees inst = - Term.map_aterms (fn (t as Term.Free (s, _)) => the_default t (AList.lookup (op =) inst s) - | t => t); - val anno_vars = - subst_Frees (map (fn (s, T) => (s, Term.Free (s, T))) (Term.add_frees t [])) - #> subst_Vars (map (fn (ixn, T) => (ixn, Var (ixn, T))) (Term.add_vars t [])) - fun constrain t = - singleton (Syntax.check_terms (ProofContext.init thy)) (TypeInfer.constrain ty t); + val type_free = AList.lookup (op =) + (map (fn (s, T) => (s, Term.Free (s, T))) (Term.add_frees t [])); + val type_frees = Term.map_aterms + (fn (t as Term.Free (s, _)) => the_default t (type_free s) | t => t); + fun type_infer t = [(t, ty)] + |> TypeInfer.infer_types (Sign.pp thy) (Sign.tsig_of thy) I + (try (Type.strip_sorts o Sign.the_const_type thy)) (K NONE) + Name.context 0 NONE + |> fst |> the_single |> fst; fun check_tvars t = if null (Term.term_tvars t) then t else error ("Illegal schematic type variables in normalized term: " ^ setmp show_types true (Sign.string_of_term thy) t); @@ -394,40 +425,39 @@ in compile_eval thy code vs_ty_t deps |> tracing (fn t => "Normalized:\n" ^ string_of_term t) - |> anno_vars + |> subst_triv_consts thy + |> type_frees |> tracing (fn t => "Vars typed:\n" ^ string_of_term t) - |> constrain + |> type_infer |> tracing (fn t => "Types inferred:\n" ^ string_of_term t) + |> check_tvars |> tracing (fn t => "---\n") - |> check_tvars end; (* evaluation oracle *) -exception Norm of CodeThingol.code * term +exception Norm of term * CodeThingol.code * (CodeThingol.typscheme * CodeThingol.iterm) * string list; -fun norm_oracle (thy, Norm (code, t, vs_ty_t, deps)) = - Logic.mk_equals (t, eval thy code t vs_ty_t deps); +fun norm_oracle (thy, Norm (t, code, vs_ty_t, deps)) = + Logic.mk_equals (t, eval thy t code vs_ty_t deps); -fun norm_invoke thy code t vs_ty_t deps = - Thm.invoke_oracle_i thy "HOL.norm" (thy, Norm (code, t, vs_ty_t, deps)); +fun norm_invoke thy t code vs_ty_t deps = + Thm.invoke_oracle_i thy "HOL.norm" (thy, Norm (t, code, vs_ty_t, deps)); (*FIXME get rid of hardwired theory name*) fun norm_conv ct = let val thy = Thm.theory_of_cterm ct; - fun conv code vs_ty_t deps ct = - let - val t = Thm.term_of ct; - in norm_invoke thy code t vs_ty_t deps end; - in CodePackage.evaluate_conv thy conv ct end; + fun evaluator' t code vs_ty_t deps = norm_invoke thy t code vs_ty_t deps; + fun evaluator t = (add_triv_classes thy t, evaluator' t); + in CodePackage.evaluate_conv thy evaluator ct end; -fun norm_term thy = +fun norm_term thy t = let - fun invoke code vs_ty_t deps t = - eval thy code t vs_ty_t deps; - in CodePackage.evaluate_term thy invoke #> Code.postprocess_term thy end; + fun evaluator' t code vs_ty_t deps = eval thy t code vs_ty_t deps; + fun evaluator t = (add_triv_classes thy t, evaluator' t); + in (Code.postprocess_term thy o CodePackage.evaluate_term thy evaluator) t end; (* evaluation command *) @@ -448,7 +478,9 @@ let val ctxt = Toplevel.context_of state in norm_print_term ctxt modes (Syntax.read_term ctxt s) end; -val setup = Theory.add_oracle ("norm", norm_oracle) +fun setup nbe_classes nbe_consts = + Theory.add_oracle ("norm", norm_oracle) + #> Nbe_Triv_Classes.map (K (nbe_classes, nbe_consts)); local structure P = OuterParse and K = OuterKeyword in