# HG changeset patch # User haftmann # Date 1222327688 -7200 # Node ID 715163ec93c0340a619d9995f139203a8fbb354d # Parent 46a0dc9b51bb99e224433c91cd34b8f7ed364aee non left-linear equations for nbe diff -r 46a0dc9b51bb -r 715163ec93c0 NEWS --- a/NEWS Thu Sep 25 09:28:07 2008 +0200 +++ b/NEWS Thu Sep 25 09:28:08 2008 +0200 @@ -66,7 +66,10 @@ *** HOL *** -* HOL/Main: command "value" now integrates different evaluation +* Normalization by evaluation now allows non-leftlinear equations. +Declare with attribute [code nbe]. + +* Command "value" now integrates different evaluation mechanisms. The result of the first successful evaluation mechanism is printed. In square brackets a particular named evaluation mechanisms may be specified (currently, [SML], [code] or [nbe]). See diff -r 46a0dc9b51bb -r 715163ec93c0 src/HOL/Tools/datatype_codegen.ML --- a/src/HOL/Tools/datatype_codegen.ML Thu Sep 25 09:28:07 2008 +0200 +++ b/src/HOL/Tools/datatype_codegen.ML Thu Sep 25 09:28:08 2008 +0200 @@ -433,9 +433,9 @@ let val vs' = (map o apsnd) (curry (Sorts.inter_sort (Sign.classes_of thy)) [HOLogic.class_eq]) vs; - fun add_def tyco lthy = + fun add_def dtco lthy = let - val ty = Type (tyco, map TFree vs'); + val ty = Type (dtco, map TFree vs'); fun mk_side const_name = Const (const_name, ty --> ty --> HOLogic.boolT) $ Free ("x", ty) $ Free ("y", ty); val def = HOLogic.mk_Trueprop (HOLogic.mk_eq @@ -454,9 +454,15 @@ |> map (MetaSimplifier.rewrite_rule [Thm.transfer thy @{thm equals_eq}]); fun add_eq_thms dtco thy = let + val ty = Type (dtco, map TFree vs'); val thy_ref = Theory.check_thy thy; val const = AxClass.param_of_inst thy (@{const_name eq_class.eq}, dtco); - val get_thms = (fn () => get_eq' (Theory.deref thy_ref) dtco |> rev); + val eq_refl = @{thm HOL.eq_refl} + |> Thm.instantiate + ([pairself (Thm.ctyp_of thy) (TVar (("'a", 0), @{sort eq}), Logic.varifyT ty)], []) + |> Simpdata.mk_eq; + fun get_thms () = (eq_refl, false) + :: rev (map (rpair true) (get_eq' (Theory.deref thy_ref) dtco)); in Code.add_funcl (const, Susp.delay get_thms) thy end; diff -r 46a0dc9b51bb -r 715163ec93c0 src/HOL/Tools/typecopy_package.ML --- a/src/HOL/Tools/typecopy_package.ML Thu Sep 25 09:28:07 2008 +0200 +++ b/src/HOL/Tools/typecopy_package.ML Thu Sep 25 09:28:08 2008 +0200 @@ -124,10 +124,10 @@ let val SOME { constr, proj_def, inject, vs, ... } = get_info thy tyco; val vs' = (map o apsnd) (curry (Sorts.inter_sort (Sign.classes_of thy)) [HOLogic.class_eq]) vs; - val ty = Logic.unvarifyT (Sign.the_const_type thy constr); + val ty = Type (tyco, map TFree vs'); + val ty_constr = Logic.unvarifyT (Sign.the_const_type thy constr); fun add_def tyco lthy = let - val ty = Type (tyco, map TFree vs'); fun mk_side const_name = Const (const_name, ty --> ty --> HOLogic.boolT) $ Free ("x", ty) $ Free ("y", ty); val def = HOLogic.mk_Trueprop (HOLogic.mk_eq @@ -140,16 +140,23 @@ in (thm', lthy') end; fun tac thms = Class.intro_classes_tac [] THEN ALLGOALS (ProofContext.fact_tac thms); - fun add_eq_thm thy = + fun add_eq_thms thy = let val eq = inject |> Code_Unit.constrain_thm [HOLogic.class_eq] |> Simpdata.mk_eq |> MetaSimplifier.rewrite_rule [Thm.transfer thy @{thm equals_eq}]; - in Code.add_func eq thy end; + val eq_refl = @{thm HOL.eq_refl} + |> Thm.instantiate + ([pairself (Thm.ctyp_of thy) (TVar (("'a", 0), @{sort eq}), Logic.varifyT ty)], []); + in + thy + |> Code.add_func eq + |> Code.add_nonlinear_func eq_refl + end; in thy - |> Code.add_datatype [(constr, ty)] + |> Code.add_datatype [(constr, ty_constr)] |> Code.add_func proj_def |> TheoryTarget.instantiation ([tyco], vs', [HOLogic.class_eq]) |> add_def tyco @@ -157,7 +164,7 @@ #> LocalTheory.exit #> ProofContext.theory_of #> Code.del_func thm - #> add_eq_thm) + #> add_eq_thms) end; val setup = diff -r 46a0dc9b51bb -r 715163ec93c0 src/HOL/ex/NormalForm.thy --- a/src/HOL/ex/NormalForm.thy Thu Sep 25 09:28:07 2008 +0200 +++ b/src/HOL/ex/NormalForm.thy Thu Sep 25 09:28:08 2008 +0200 @@ -8,17 +8,30 @@ imports Main "~~/src/HOL/Real/Rational" begin +lemma [code nbe]: + "x = x \ True" by rule+ + +lemma [code nbe]: + "eq_class.eq (x::bool) x \ True" unfolding eq by rule+ + +lemma [code nbe]: + "eq_class.eq (x::nat) x \ True" unfolding eq by rule+ + lemma "True" by normalization lemma "p \ True" by normalization -declare disj_assoc [code func] -lemma "((P | Q) | R) = (P | (Q | R))" by normalization rule +declare disj_assoc [code nbe] +lemma "((P | Q) | R) = (P | (Q | R))" by normalization declare disj_assoc [code func del] -lemma "0 + (n::nat) = n" by normalization rule -lemma "0 + Suc n = Suc n" by normalization rule -lemma "Suc n + Suc m = n + Suc (Suc m)" by normalization rule +lemma "0 + (n::nat) = n" by normalization +lemma "0 + Suc n = Suc n" by normalization +lemma "Suc n + Suc m = n + Suc (Suc m)" by normalization lemma "~((0::nat) < (0::nat))" by normalization datatype n = Z | S n + +lemma [code nbe]: + "eq_class.eq (x::n) x \ True" unfolding eq by rule+ + consts add :: "n \ n \ n" add2 :: "n \ n \ n" @@ -40,9 +53,9 @@ lemma [code]: "add2 n Z = n" by(induct n) auto -lemma "add2 (add2 n m) k = add2 n (add2 m k)" by normalization rule -lemma "add2 (add2 (S n) (S m)) (S k) = S(S(S(add2 n (add2 m k))))" by normalization rule -lemma "add2 (add2 (S n) (add2 (S m) Z)) (S k) = S(S(S(add2 n (add2 m k))))" by normalization rule +lemma "add2 (add2 n m) k = add2 n (add2 m k)" by normalization +lemma "add2 (add2 (S n) (S m)) (S k) = S(S(S(add2 n (add2 m k))))" by normalization +lemma "add2 (add2 (S n) (add2 (S m) Z)) (S k) = S(S(S(add2 n (add2 m k))))" by normalization primrec "mul Z = (%n. Z)" @@ -59,18 +72,22 @@ lemma "exp (S(S Z)) (S(S(S(S Z)))) = exp (S(S(S(S Z)))) (S(S Z))" by normalization lemma "(let ((x,y),(u,v)) = ((Z,Z),(Z,Z)) in add (add x y) (add u v)) = Z" by normalization -lemma "split (%x y. x) (a, b) = a" by normalization rule +lemma "split (%x y. x) (a, b) = a" by normalization lemma "(%((x,y),(u,v)). add (add x y) (add u v)) ((Z,Z),(Z,Z)) = Z" by normalization lemma "case Z of Z \ True | S x \ False" by normalization lemma "[] @ [] = []" by normalization -lemma "map f [x,y,z::'x] = [f x, f y, f z]" by normalization rule+ -lemma "[a, b, c] @ xs = a # b # c # xs" by normalization rule+ -lemma "[] @ xs = xs" by normalization rule -lemma "map (%f. f True) [id, g, Not] = [True, g True, False]" by normalization rule+ +lemma "map f [x,y,z::'x] = [f x, f y, f z]" by normalization +lemma "[a, b, c] @ xs = a # b # c # xs" by normalization +lemma "[] @ xs = xs" by normalization +lemma "map (%f. f True) [id, g, Not] = [True, g True, False]" by normalization + +lemma [code nbe]: + "eq_class.eq (x :: 'a\eq list) x \ True" unfolding eq by rule+ + lemma "map (%f. f True) ([id, g, Not] @ fs) = [True, g True, False] @ map (%f. f True) fs" by normalization rule+ -lemma "rev [a, b, c] = [c, b, a]" by normalization rule+ +lemma "rev [a, b, c] = [c, b, a]" by normalization normal_form "rev (a#b#cs) = rev cs @ [b, a]" normal_form "map (%F. F [a,b,c::'x]) (map map [f,g,h])" normal_form "map (%F. F ([a,b,c] @ ds)) (map map ([f,g,h]@fs))" @@ -79,21 +96,24 @@ by normalization normal_form "case xs of [] \ True | x#xs \ False" normal_form "map (%x. case x of None \ False | Some y \ True) xs = P" -lemma "let x = y in [x, x] = [y, y]" by normalization rule+ -lemma "Let y (%x. [x,x]) = [y, y]" by normalization rule+ +lemma "let x = y in [x, x] = [y, y]" by normalization +lemma "Let y (%x. [x,x]) = [y, y]" by normalization normal_form "case n of Z \ True | S x \ False" -lemma "(%(x,y). add x y) (S z,S z) = S (add z (S z))" by normalization rule+ +lemma "(%(x,y). add x y) (S z,S z) = S (add z (S z))" by normalization normal_form "filter (%x. x) ([True,False,x]@xs)" normal_form "filter Not ([True,False,x]@xs)" -lemma "[x,y,z] @ [a,b,c] = [x, y, z, a, b, c]" by normalization rule+ -lemma "(%(xs, ys). xs @ ys) ([a, b, c], [d, e, f]) = [a, b, c, d, e, f]" by normalization rule+ +lemma "[x,y,z] @ [a,b,c] = [x, y, z, a, b, c]" by normalization +lemma "(%(xs, ys). xs @ ys) ([a, b, c], [d, e, f]) = [a, b, c, d, e, f]" by normalization lemma "map (%x. case x of None \ False | Some y \ True) [None, Some ()] = [False, True]" by normalization -lemma "last [a, b, c] = c" by normalization rule -lemma "last ([a, b, c] @ xs) = last (c # xs)" by normalization rule +lemma "last [a, b, c] = c" by normalization +lemma "last ([a, b, c] @ xs) = last (c # xs)" by normalization -lemma "(2::int) + 3 - 1 + (- k) * 2 = 4 + - k * 2" by normalization rule +lemma [code nbe]: + "eq_class.eq (x :: int) x \ True" unfolding eq by rule+ + +lemma "(2::int) + 3 - 1 + (- k) * 2 = 4 + - k * 2" by normalization lemma "(-4::int) * 2 = -8" by normalization lemma "abs ((-4::int) + 2 * 1) = 2" by normalization lemma "(2::int) + 3 = 5" by normalization @@ -111,10 +131,10 @@ lemma "(42::rat) / 1704 = 1 / 284 + 3 / 142" by normalization normal_form "Suc 0 \ set ms" -lemma "f = f" by normalization rule+ -lemma "f x = f x" by normalization rule+ -lemma "(f o g) x = f (g x)" by normalization rule+ -lemma "(f o id) x = f x" by normalization rule+ +lemma "f = f" by normalization +lemma "f x = f x" by normalization +lemma "(f o g) x = f (g x)" by normalization +lemma "(f o id) x = f x" by normalization normal_form "(\x. x)" (* Church numerals: *) diff -r 46a0dc9b51bb -r 715163ec93c0 src/Pure/Isar/code.ML --- a/src/Pure/Isar/code.ML Thu Sep 25 09:28:07 2008 +0200 +++ b/src/Pure/Isar/code.ML Thu Sep 25 09:28:08 2008 +0200 @@ -9,12 +9,13 @@ signature CODE = sig val add_func: thm -> theory -> theory + val add_nonlinear_func: thm -> theory -> theory val add_liberal_func: thm -> theory -> theory val add_default_func: thm -> theory -> theory val add_default_func_attr: Attrib.src val del_func: thm -> theory -> theory val del_funcs: string -> theory -> theory - val add_funcl: string * thm list Susp.T -> theory -> theory + val add_funcl: string * (thm * bool) list Susp.T -> theory -> theory val map_pre: (MetaSimplifier.simpset -> MetaSimplifier.simpset) -> theory -> theory val map_post: (MetaSimplifier.simpset -> MetaSimplifier.simpset) -> theory -> theory val add_inline: thm -> theory -> theory @@ -34,7 +35,7 @@ val coregular_algebra: theory -> Sorts.algebra val operational_algebra: theory -> (sort -> sort) * Sorts.algebra - val these_funcs: theory -> string -> thm list + val these_funcs: theory -> string -> (thm * bool) list val get_datatype: theory -> string -> ((string * sort) list * (string * typ list) list) val get_datatype_of_constr: theory -> string -> string option val get_case_data: theory -> string -> (int * string list) option @@ -115,41 +116,38 @@ (** logical and syntactical specification of executable code **) -(* defining equations with default flag and lazy theorems *) +(* defining equations with linear flag, default flag and lazy theorems *) fun pretty_lthms ctxt r = case Susp.peek r - of SOME thms => map (ProofContext.pretty_thm ctxt) thms + of SOME thms => map (ProofContext.pretty_thm ctxt o fst) thms | NONE => [Pretty.str "[...]"]; fun certificate thy f r = case Susp.peek r - of SOME thms => (Susp.value o f thy) thms + of SOME thms => (Susp.value o burrow_fst (f thy)) thms | NONE => let val thy_ref = Theory.check_thy thy; - in Susp.delay (fn () => (f (Theory.deref thy_ref) o Susp.force) r) end; + in Susp.delay (fn () => (burrow_fst (f (Theory.deref thy_ref)) o Susp.force) r) end; -fun add_drop_redundant verbose thm thms = +fun add_drop_redundant (thm, linear) thms = let - fun warn thm' = (if verbose - then warning ("Code generator: dropping redundant defining equation\n" ^ Display.string_of_thm thm') - else (); true); val thy = Thm.theory_of_thm thm; val args_of = snd o strip_comb o fst o Logic.dest_equals o Thm.plain_prop_of; val args = args_of thm; - fun matches [] _ = true - | matches (Var _ :: xs) [] = matches xs [] - | matches (_ :: _) [] = false - | matches (x :: xs) (y :: ys) = Pattern.matches thy (x, y) andalso matches xs ys; - fun drop thm' = matches args (args_of thm') andalso warn thm'; - in thm :: filter_out drop thms end; + fun matches_args args' = length args <= length args' andalso + Pattern.matchess thy (args, curry Library.take (length args) args'); + fun drop (thm', _) = if matches_args (args_of thm') then + (warning ("Code generator: dropping redundant defining equation\n" ^ Display.string_of_thm thm'); true) + else false; + in (thm, linear) :: filter_out drop thms end; -fun add_thm _ thm (false, thms) = (false, Susp.value (add_drop_redundant true thm (Susp.force thms))) - | add_thm true thm (true, thms) = (true, Susp.value (Susp.force thms @ [thm])) +fun add_thm _ thm (false, thms) = (false, Susp.map_force (add_drop_redundant thm) thms) + | add_thm true thm (true, thms) = (true, Susp.map_force (fn thms => thms @ [thm]) thms) | add_thm false thm (true, thms) = (false, Susp.value [thm]); fun add_lthms lthms _ = (false, lthms); -fun del_thm thm = apsnd (Susp.value o remove Thm.eq_thm_prop thm o Susp.force); +fun del_thm thm = (apsnd o Susp.map_force) (remove (eq_fst Thm.eq_thm_prop) (thm, true)); fun merge_defthms ((true, _), defthms2) = defthms2 | merge_defthms (defthms1 as (false, _), (true, _)) = defthms1 @@ -173,7 +171,7 @@ (* specification data *) datatype spec = Spec of { - funcs: (bool * thm list Susp.T) Symtab.table, + funcs: (bool * (thm * bool) list Susp.T) Symtab.table, dtyps: ((string * sort) list * (string * typ list) list) Symtab.table, cases: (int * string list) Symtab.table * unit Symtab.table }; @@ -479,7 +477,7 @@ val funcs = classparams |> map_filter (fn c => try (AxClass.param_of_inst thy) (c, tyco)) |> map (Symtab.lookup ((the_funcs o the_exec) thy)) - |> (map o Option.map) (Susp.force o snd) + |> (map o Option.map) (map fst o Susp.force o snd) |> maps these |> map (Thm.transfer thy); fun sorts_of [Type (_, tys)] = map (snd o dest_TVar) tys @@ -600,9 +598,10 @@ | NONE => check_typ_fun (c, thm); in check_typ (const_of_func thy thm, thm) end; -val mk_func = Code_Unit.error_thm (assert_func_typ o Code_Unit.mk_func); -val mk_liberal_func = Code_Unit.warning_thm (assert_func_typ o Code_Unit.mk_func); -val mk_default_func = Code_Unit.try_thm (assert_func_typ o Code_Unit.mk_func); +fun mk_func linear = Code_Unit.error_thm (assert_func_typ o Code_Unit.mk_func linear); +val mk_liberal_func = Code_Unit.warning_thm (assert_func_typ o Code_Unit.mk_func true); +val mk_syntactic_func = Code_Unit.warning_thm (assert_func_typ o Code_Unit.mk_func false); +val mk_default_func = Code_Unit.try_thm (assert_func_typ o Code_Unit.mk_func true); end; (*local*) @@ -641,8 +640,8 @@ val is_undefined = Symtab.defined o snd o the_cases o the_exec; -fun gen_add_func strict default thm thy = - case (if strict then SOME o mk_func else mk_liberal_func) thm +fun gen_add_func linear strict default thm thy = + case (if strict then SOME o mk_func linear else mk_liberal_func) thm of SOME func => let val c = const_of_func thy func; @@ -656,15 +655,16 @@ else (); in (map_exec_purge (SOME [c]) o map_funcs) (Symtab.map_default - (c, (true, Susp.value [])) (add_thm default func)) thy + (c, (true, Susp.value [])) (add_thm default (func, linear))) thy end | NONE => thy; -val add_func = gen_add_func true false; -val add_liberal_func = gen_add_func false false; -val add_default_func = gen_add_func false true; +val add_func = gen_add_func true true false; +val add_liberal_func = gen_add_func true false false; +val add_default_func = gen_add_func true false true; +val add_nonlinear_func = gen_add_func false true false; -fun del_func thm thy = case mk_liberal_func thm +fun del_func thm thy = case mk_syntactic_func thm of SOME func => let val c = const_of_func thy func; in map_exec_purge (SOME [c]) (map_funcs @@ -762,6 +762,7 @@ in TypeInterpretation.init #> add_del_attribute ("func", (add_func, del_func)) + #> add_simple_attribute ("nbe", add_nonlinear_func) #> add_del_attribute ("inline", (add_inline, del_inline)) #> add_del_attribute ("post", (add_post, del_post)) end)); @@ -801,7 +802,7 @@ thms |> apply_functrans thy |> map (Code_Unit.rewrite_func pre) - (*FIXME - must check gere: rewrite rule, defining equation, proper constant *) + (*FIXME - must check here: rewrite rule, defining equation, proper constant *) |> map (AxClass.unoverload thy) |> common_typ_funcs end; @@ -847,24 +848,24 @@ Symtab.lookup ((the_funcs o the_exec) thy) const |> Option.map (Susp.force o snd) |> these - |> map (Thm.transfer thy); + |> (map o apfst) (Thm.transfer thy); in fun these_funcs thy const = let fun drop_refl thy = filter_out (is_equal o Term.fast_term_ord o Logic.dest_equals - o ObjectLogic.drop_judgment thy o Thm.plain_prop_of); + o ObjectLogic.drop_judgment thy o Thm.plain_prop_of o fst); in get_funcs thy const - |> preprocess thy + |> burrow_fst (preprocess thy) |> drop_refl thy end; fun default_typ thy c = case default_typ_proto thy c of SOME ty => Code_Unit.typscheme thy (c, ty) | NONE => (case get_funcs thy c - of thm :: _ => snd (Code_Unit.head_func (AxClass.unoverload thy thm)) + of (thm, _) :: _ => snd (Code_Unit.head_func (AxClass.unoverload thy thm)) | [] => Code_Unit.typscheme thy (c, Sign.the_const_type thy c)); end; (*local*) diff -r 46a0dc9b51bb -r 715163ec93c0 src/Tools/code/code_funcgr.ML --- a/src/Tools/code/code_funcgr.ML Thu Sep 25 09:28:07 2008 +0200 +++ b/src/Tools/code/code_funcgr.ML Thu Sep 25 09:28:08 2008 +0200 @@ -9,7 +9,7 @@ signature CODE_FUNCGR = sig type T - val funcs: T -> string -> thm list + val funcs: T -> string -> (thm * bool) list val typ: T -> string -> (string * sort) list * typ val all: T -> string list val pretty: theory -> T -> Pretty.T @@ -24,7 +24,7 @@ (** the graph type **) -type T = (((string * sort) list * typ) * thm list) Graph.T; +type T = (((string * sort) list * typ) * (thm * bool) list) Graph.T; fun funcs funcgr = these o Option.map snd o try (Graph.get_node funcgr); @@ -41,7 +41,7 @@ |> map (fn (s, thms) => (Pretty.block o Pretty.fbreaks) ( Pretty.str s - :: map Display.pretty_thm thms + :: map (Display.pretty_thm o fst) thms )) |> Pretty.chunks; @@ -57,7 +57,7 @@ | consts_of (const, thms as _ :: _) = let fun the_const (c, _) = if c = const then I else insert (op =) c - in fold_consts the_const thms [] end; + in fold_consts the_const (map fst thms) [] end; fun insts_of thy algebra tys sorts = let @@ -100,19 +100,19 @@ fun resort_funcss thy algebra funcgr = let val typ_funcgr = try (fst o Graph.get_node funcgr); - val resort_dep = apsnd (resort_thms thy algebra typ_funcgr); + val resort_dep = (apsnd o burrow_fst) (resort_thms thy algebra typ_funcgr); fun resort_rec typ_of (c, []) = (true, (c, [])) - | resort_rec typ_of (c, thms as thm :: _) = if is_some (AxClass.inst_of_param thy c) + | resort_rec typ_of (c, thms as (thm, _) :: _) = if is_some (AxClass.inst_of_param thy c) then (true, (c, thms)) else let val (_, (vs, ty)) = Code_Unit.head_func thm; - val thms' as thm' :: _ = resort_thms thy algebra typ_of thms + val thms' as (thm', _) :: _ = burrow_fst (resort_thms thy algebra typ_of) thms val (_, (vs', ty')) = Code_Unit.head_func thm'; (*FIXME simplify check*) in (Sign.typ_equiv thy (ty, ty'), (c, thms')) end; fun resort_recs funcss = let fun typ_of c = case these (AList.lookup (op =) funcss c) - of thm :: _ => (SOME o snd o Code_Unit.head_func) thm + of (thm, _) :: _ => (SOME o snd o Code_Unit.head_func) thm | [] => NONE; val (unchangeds, funcss') = split_list (map (resort_rec typ_of) funcss); val unchanged = fold (fn x => fn y => x andalso y) unchangeds true; @@ -158,8 +158,8 @@ |> pair (SOME const) else let val thms = Code.these_funcs thy const - |> Code_Unit.norm_args - |> Code_Unit.norm_varnames Code_Name.purify_tvar Code_Name.purify_var; + |> burrow_fst Code_Unit.norm_args + |> burrow_fst (Code_Unit.norm_varnames Code_Name.purify_tvar Code_Name.purify_var); val rhs = consts_of (const, thms); in auxgr @@ -182,14 +182,14 @@ |> resort_funcss thy algebra funcgr |> filter_out (can (Graph.get_node funcgr) o fst); fun typ_func c [] = Code.default_typ thy c - | typ_func c (thms as thm :: _) = (snd o Code_Unit.head_func) thm; + | typ_func c (thms as (thm, _) :: _) = (snd o Code_Unit.head_func) thm; fun add_funcs (const, thms) = Graph.new_node (const, (typ_func const thms, thms)); fun add_deps (funcs as (const, thms)) funcgr = let val deps = consts_of funcs; val insts = instances_of_consts thy algebra funcgr - (fold_consts (insert (op =)) thms []); + (fold_consts (insert (op =)) (map fst thms) []); in funcgr |> ensure_consts thy algebra insts @@ -278,16 +278,15 @@ (** diagnostic commands **) -fun code_depgr thy [] = make thy [] - | code_depgr thy consts = - let - val gr = make thy consts; - val select = Graph.all_succs gr consts; - in - gr - |> Graph.subgraph (member (op =) select) - |> Graph.map_nodes ((apsnd o map) (AxClass.overload thy)) - end; +fun code_depgr thy consts = + let + val gr = make thy consts; + val select = Graph.all_succs gr consts; + in + gr + |> not (null consts) ? Graph.subgraph (member (op =) select) + |> Graph.map_nodes ((apsnd o map o apfst) (AxClass.overload thy)) + end; fun code_thms thy = Pretty.writeln o pretty thy o code_depgr thy; diff -r 46a0dc9b51bb -r 715163ec93c0 src/Tools/code/code_haskell.ML --- a/src/Tools/code/code_haskell.ML Thu Sep 25 09:28:07 2008 +0200 +++ b/src/Tools/code/code_haskell.ML Thu Sep 25 09:28:08 2008 +0200 @@ -153,10 +153,11 @@ ) ] end - | pr_stmt (name, Code_Thingol.Fun ((vs, ty), eqs)) = + | pr_stmt (name, Code_Thingol.Fun ((vs, ty), raw_eqs)) = let + val eqs = filter (snd o snd) raw_eqs; val tyvars = intro_vars (map fst vs) init_syms; - fun pr_eq ((ts, t), thm) = + fun pr_eq ((ts, t), (thm, _)) = let val consts = map_filter (fn c => if (is_some o syntax_const) c @@ -248,7 +249,7 @@ | pr_stmt (_, Code_Thingol.Classinst ((class, (tyco, vs)), (_, classparam_insts))) = let val tyvars = intro_vars (map fst vs) init_syms; - fun pr_instdef ((classparam, c_inst), thm) = + fun pr_instdef ((classparam, c_inst), (thm, _)) = semicolon [ (str o classparam_name class) classparam, str "=", diff -r 46a0dc9b51bb -r 715163ec93c0 src/Tools/code/code_ml.ML --- a/src/Tools/code/code_ml.ML Thu Sep 25 09:28:07 2008 +0200 +++ b/src/Tools/code/code_ml.ML Thu Sep 25 09:28:08 2008 +0200 @@ -27,12 +27,12 @@ val target_OCaml = "OCaml"; datatype ml_stmt = - MLFuns of (string * (typscheme * ((iterm list * iterm) * thm) list)) list + MLFuns of (string * (typscheme * ((iterm list * iterm) * (thm * bool)) list)) list | MLDatas of (string * ((vname * sort) list * (string * itype list) list)) list | MLClass of string * (vname * ((class * string) list * (string * itype) list)) | MLClassinst of string * ((class * (string * (vname * sort) list)) * ((class * (string * (string * dict list list))) list - * ((string * const) * thm) list)); + * ((string * const) * (thm * bool)) list)); fun stmt_names_of (MLFuns fs) = map fst fs | stmt_names_of (MLDatas ds) = map fst ds @@ -192,7 +192,7 @@ val vs_dict = filter_out (null o snd) vs; val shift = if null eqs' then I else map (Pretty.block o single o Pretty.block o single); - fun pr_eq definer ((ts, t), thm) = + fun pr_eq definer ((ts, t), (thm, _)) = let val consts = map_filter (fn c => if (is_some o syntax_const) c @@ -299,7 +299,7 @@ str "=", pr_dicts NOBR [DictConst dss] ]; - fun pr_classparam ((classparam, c_inst), thm) = + fun pr_classparam ((classparam, c_inst), (thm, _)) = concat [ (str o pr_label_classparam) classparam, str "=", @@ -453,7 +453,7 @@ in map (lookup_var vars') fished3 end; fun pr_stmt (MLFuns (funns as funn :: funns')) = let - fun pr_eq ((ts, t), thm) = + fun pr_eq ((ts, t), (thm, _)) = let val consts = map_filter (fn c => if (is_some o syntax_const) c @@ -481,7 +481,7 @@ @@ str exc_str ) end - | pr_eqs _ _ [((ts, t), thm)] = + | pr_eqs _ _ [((ts, t), (thm, _))] = let val consts = map_filter (fn c => if (is_some o syntax_const) c @@ -611,7 +611,7 @@ str "=", pr_dicts NOBR [DictConst dss] ]; - fun pr_classparam_inst ((classparam, c_inst), thm) = + fun pr_classparam_inst ((classparam, c_inst), (thm, _)) = concat [ (str o deresolve) classparam, str "=", @@ -727,7 +727,7 @@ fold_map (fn (name, Code_Thingol.Fun stmt) => map_nsp_fun_yield (mk_name_stmt false name) #>> - rpair (name, stmt) + rpair (name, stmt |> apsnd (filter (snd o snd))) | (name, _) => error ("Function block containing illegal statement: " ^ labelled_name name) ) stmts @@ -895,7 +895,8 @@ val _ = if Code_Thingol.contains_dictvar t then error "Term to be evaluated constains free dictionaries" else (); val program' = program - |> Graph.new_node (Code_Name.value_name, Code_Thingol.Fun (([], ty), [(([], t), Drule.dummy_thm)])) + |> Graph.new_node (Code_Name.value_name, + Code_Thingol.Fun (([], ty), [(([], t), (Drule.dummy_thm, true))])) |> fold (curry Graph.add_edge Code_Name.value_name) deps; val (value_code, [value_name']) = ml_code_of thy program' [Code_Name.value_name]; val sml_code = "let\n" ^ value_code ^ "\nin " ^ value_name' diff -r 46a0dc9b51bb -r 715163ec93c0 src/Tools/code/code_thingol.ML --- a/src/Tools/code/code_thingol.ML Thu Sep 25 09:28:07 2008 +0200 +++ b/src/Tools/code/code_thingol.ML Thu Sep 25 09:28:08 2008 +0200 @@ -62,7 +62,7 @@ datatype stmt = NoStmt - | Fun of typscheme * ((iterm list * iterm) * thm) list + | Fun of typscheme * ((iterm list * iterm) * (thm * bool)) list | Datatype of (vname * sort) list * (string * itype list) list | Datatypecons of string | Class of vname * ((class * string) list * (string * itype) list) @@ -70,7 +70,7 @@ | Classparam of class | Classinst of (class * (string * (vname * sort) list)) * ((class * (string * (string * dict list list))) list - * ((string * const) * thm) list); + * ((string * const) * (thm * bool)) list); type program = stmt Graph.T; val empty_funs: program -> string list; val map_terms_bottom_up: (iterm -> iterm) -> iterm -> iterm; @@ -258,7 +258,7 @@ type typscheme = (vname * sort) list * itype; datatype stmt = NoStmt - | Fun of typscheme * ((iterm list * iterm) * thm) list + | Fun of typscheme * ((iterm list * iterm) * (thm * bool)) list | Datatype of (vname * sort) list * (string * itype list) list | Datatypecons of string | Class of vname * ((class * string) list * (string * itype) list) @@ -266,7 +266,7 @@ | Classparam of class | Classinst of (class * (string * (vname * sort) list)) * ((class * (string * (string * dict list list))) list - * ((string * const) * thm) list); + * ((string * const) * (thm * bool)) list); type program = stmt Graph.T; @@ -423,14 +423,14 @@ fold_map (ensure_classrel thy algbr funcgr) classrels #>> (fn classrels => DictVar (classrels, (unprefix "'" v, (k, length sort)))) in fold_map mk_dict typargs end -and exprgen_eq thy algbr funcgr thm = +and exprgen_eq thy algbr funcgr (thm, linear) = let val (args, rhs) = (apfst (snd o strip_comb) o Logic.dest_equals o Logic.unvarify o prop_of) thm; in fold_map (exprgen_term thy algbr funcgr (SOME thm)) args ##>> exprgen_term thy algbr funcgr (SOME thm) rhs - #>> rpair thm + #>> rpair (thm, linear) end and ensure_inst thy (algbr as (_, algebra)) funcgr (class, tyco) = let @@ -457,7 +457,7 @@ in ensure_const thy algbr funcgr c ##>> exprgen_const thy algbr funcgr (SOME thm) c_ty - #>> (fn (c, IConst c_inst) => ((c, c_inst), thm)) + #>> (fn (c, IConst c_inst) => ((c, c_inst), (thm, true))) end; val stmt_inst = ensure_class thy algbr funcgr class @@ -485,7 +485,7 @@ val ty = Logic.unvarifyT raw_ty; val thms = if (null o Term.typ_tfrees) ty orelse (null o fst o strip_type) ty then raw_thms - else map (Code_Unit.expand_eta 1) raw_thms; + else (map o apfst) (Code_Unit.expand_eta 1) raw_thms; in trns |> fold_map (exprgen_tyvar_sort thy algbr funcgr) vs @@ -642,7 +642,7 @@ fold_map (exprgen_tyvar_sort thy algbr funcgr) vs ##>> exprgen_typ thy algbr funcgr ty ##>> exprgen_term thy algbr funcgr NONE t - #>> (fn ((vs, ty), t) => Fun ((vs, ty), [(([], t), Drule.dummy_thm)])); + #>> (fn ((vs, ty), t) => Fun ((vs, ty), [(([], t), (Drule.dummy_thm, true))])); fun term_value (dep, program1) = let val Fun ((vs, ty), [(([], t), _)]) = diff -r 46a0dc9b51bb -r 715163ec93c0 src/Tools/nbe.ML --- a/src/Tools/nbe.ML Thu Sep 25 09:28:07 2008 +0200 +++ b/src/Tools/nbe.ML Thu Sep 25 09:28:08 2008 +0200 @@ -15,10 +15,11 @@ | Free of string * Univ list (*free (uninterpreted) variables*) | DFree of string * int (*free (uninterpreted) dictionary parameters*) | BVar of int * Univ list - | Abs of (int * (Univ list -> Univ)) * Univ list; + | Abs of (int * (Univ list -> Univ)) * Univ list val apps: Univ -> Univ list -> Univ (*explicit applications*) val abss: int -> (Univ list -> Univ) -> Univ (*abstractions as closures*) + val same: Univ -> Univ -> bool val univs_ref: (unit -> Univ list -> Univ list) option ref val trace: bool ref @@ -63,6 +64,13 @@ | Abs of (int * (Univ list -> Univ)) * Univ list (*abstractions as closures*); +fun same (Const (k, xs)) (Const (l, ys)) = k = l andalso sames xs ys + | same (Free (s, xs)) (Free (t, ys)) = s = t andalso sames xs ys + | same (DFree (s, k)) (DFree (t, l)) = s = t andalso k = l + | same (BVar (k, xs)) (BVar (l, ys)) = k = l andalso sames xs ys + | same _ _ = false +and sames xs ys = length xs = length ys andalso forall (uncurry same) (xs ~~ ys); + (* constructor functions *) fun abss n f = Abs ((n, f), []); @@ -92,6 +100,11 @@ fun ml_Let d e = "let\n" ^ d ^ " in " ^ e ^ " end"; fun ml_as v t = "(" ^ v ^ " as " ^ t ^ ")"; +fun ml_and [] = "true" + | ml_and [x] = x + | ml_and xs = "(" ^ space_implode " andalso " xs ^ ")"; +fun ml_if b x y = "(if " ^ b ^ " then " ^ x ^ " else " ^ y ^ ")"; + fun ml_list es = "[" ^ commas es ^ "]"; fun ml_fundefs ([(name, [([], e)])]) = @@ -113,11 +126,12 @@ val univs_ref = ref (NONE : (unit -> Univ list -> Univ list) option); local - val prefix = "Nbe."; - val name_ref = prefix ^ "univs_ref"; - val name_const = prefix ^ "Const"; - val name_abss = prefix ^ "abss"; - val name_apps = prefix ^ "apps"; + val prefix = "Nbe."; + val name_ref = prefix ^ "univs_ref"; + val name_const = prefix ^ "Const"; + val name_abss = prefix ^ "abss"; + val name_apps = prefix ^ "apps"; + val name_same = prefix ^ "same"; in val univs_cookie = (name_ref, univs_ref); @@ -141,6 +155,8 @@ fun nbe_abss 0 f = f `$` ml_list [] | nbe_abss n f = name_abss `$$` [string_of_int n, f]; +fun nbe_same v1 v2 = "(" ^ name_same ^ " " ^ nbe_bound v1 ^ " " ^ nbe_bound v2 ^ ")"; + end; open Basic_Code_Thingol; @@ -173,34 +189,62 @@ | assemble_idict (DictVar (supers, (v, (n, _)))) = fold_rev (fn super => assemble_constapp super [] o single) supers (nbe_dict v n); - fun assemble_iterm match_cont constapp = + fun assemble_iterm constapp = let - fun of_iterm t = + fun of_iterm match_cont t = let val (t', ts) = Code_Thingol.unfold_app t - in of_iapp t' (fold_rev (cons o of_iterm) ts []) end - and of_iapp (IConst (c, (dss, _))) ts = constapp c dss ts - | of_iapp (IVar v) ts = nbe_apps (nbe_bound v) ts - | of_iapp ((v, _) `|-> t) ts = - nbe_apps (nbe_abss 1 (ml_abs (ml_list [nbe_bound v]) (of_iterm t))) ts - | of_iapp (ICase (((t, _), cs), t0)) ts = - nbe_apps (ml_cases (of_iterm t) (map (pairself of_iterm) cs - @ [("_", case match_cont of SOME s => s | NONE => of_iterm t0)])) ts + in of_iapp match_cont t' (fold_rev (cons o of_iterm NONE) ts []) end + and of_iapp match_cont (IConst (c, (dss, _))) ts = constapp c dss ts + | of_iapp match_cont (IVar v) ts = nbe_apps (nbe_bound v) ts + | of_iapp match_cont ((v, _) `|-> t) ts = + nbe_apps (nbe_abss 1 (ml_abs (ml_list [nbe_bound v]) (of_iterm NONE t))) ts + | of_iapp match_cont (ICase (((t, _), cs), t0)) ts = + nbe_apps (ml_cases (of_iterm NONE t) + (map (fn (p, t) => (of_iterm NONE p, of_iterm match_cont t)) cs + @ [("_", case match_cont of SOME s => s | NONE => of_iterm NONE t0)])) ts in of_iterm end; + fun subst_nonlin_vars args = + let + val vs = (fold o Code_Thingol.fold_varnames) + (fn v => AList.map_default (op =) (v, 0) (curry (op +) 1)) args []; + val names = Name.make_context (map fst vs); + fun declare v k ctxt = let val vs = Name.invents ctxt v k + in (vs, fold Name.declare vs ctxt) end; + val (vs_renames, _) = fold_map (fn (v, k) => if k > 1 + then declare v (k - 1) #>> (fn vs => (v, vs)) + else pair (v, [])) vs names; + val samepairs = maps (fn (v, vs) => map (pair v) vs) vs_renames; + fun subst_vars (t as IConst _) samepairs = (t, samepairs) + | subst_vars (t as IVar v) samepairs = (case AList.lookup (op =) samepairs v + of SOME v' => (IVar v', AList.delete (op =) v samepairs) + | NONE => (t, samepairs)) + | subst_vars (t1 `$ t2) samepairs = samepairs + |> subst_vars t1 + ||>> subst_vars t2 + |>> (op `$) + | subst_vars (ICase (_, t)) samepairs = subst_vars t samepairs; + val (args', _) = fold_map subst_vars args samepairs; + in (samepairs, args') end; + fun assemble_eqn c dicts default_args (i, (args, rhs)) = let val is_eval = (c = ""); val default_rhs = nbe_apps_local (i+1) c (dicts @ default_args); val match_cont = if is_eval then NONE else SOME default_rhs; - val assemble_arg = assemble_iterm NONE - (fn c => fn _ => fn ts => nbe_apps_constr idx_of c ts); - val assemble_rhs = assemble_iterm match_cont assemble_constapp; + val assemble_arg = assemble_iterm + (fn c => fn _ => fn ts => nbe_apps_constr idx_of c ts) NONE; + val assemble_rhs = assemble_iterm assemble_constapp match_cont ; + val (samepairs, args') = subst_nonlin_vars args; + val s_args = map assemble_arg args'; + val s_rhs = if null samepairs then assemble_rhs rhs + else ml_if (ml_and (map (uncurry nbe_same) samepairs)) + (assemble_rhs rhs) default_rhs; val eqns = if is_eval then - [([ml_list (rev (dicts @ map assemble_arg args))], assemble_rhs rhs)] + [([ml_list (rev (dicts @ s_args))], s_rhs)] else - [([ml_list (rev (dicts @ map2 ml_as default_args - (map assemble_arg args)))], assemble_rhs rhs), + [([ml_list (rev (dicts @ map2 ml_as default_args s_args))], s_rhs), ([ml_list (rev (dicts @ default_args))], default_rhs)] in (nbe_fun i c, eqns) end;