--- a/NEWS Thu Sep 25 09:28:07 2008 +0200
+++ b/NEWS Thu Sep 25 09:28:08 2008 +0200
@@ -66,7 +66,10 @@
*** HOL ***
-* HOL/Main: command "value" now integrates different evaluation
+* Normalization by evaluation now allows non-leftlinear equations.
+Declare with attribute [code nbe].
+
+* Command "value" now integrates different evaluation
mechanisms. The result of the first successful evaluation mechanism
is printed. In square brackets a particular named evaluation
mechanisms may be specified (currently, [SML], [code] or [nbe]). See
--- a/src/HOL/Tools/datatype_codegen.ML Thu Sep 25 09:28:07 2008 +0200
+++ b/src/HOL/Tools/datatype_codegen.ML Thu Sep 25 09:28:08 2008 +0200
@@ -433,9 +433,9 @@
let
val vs' = (map o apsnd)
(curry (Sorts.inter_sort (Sign.classes_of thy)) [HOLogic.class_eq]) vs;
- fun add_def tyco lthy =
+ fun add_def dtco lthy =
let
- val ty = Type (tyco, map TFree vs');
+ val ty = Type (dtco, map TFree vs');
fun mk_side const_name = Const (const_name, ty --> ty --> HOLogic.boolT)
$ Free ("x", ty) $ Free ("y", ty);
val def = HOLogic.mk_Trueprop (HOLogic.mk_eq
@@ -454,9 +454,15 @@
|> map (MetaSimplifier.rewrite_rule [Thm.transfer thy @{thm equals_eq}]);
fun add_eq_thms dtco thy =
let
+ val ty = Type (dtco, map TFree vs');
val thy_ref = Theory.check_thy thy;
val const = AxClass.param_of_inst thy (@{const_name eq_class.eq}, dtco);
- val get_thms = (fn () => get_eq' (Theory.deref thy_ref) dtco |> rev);
+ val eq_refl = @{thm HOL.eq_refl}
+ |> Thm.instantiate
+ ([pairself (Thm.ctyp_of thy) (TVar (("'a", 0), @{sort eq}), Logic.varifyT ty)], [])
+ |> Simpdata.mk_eq;
+ fun get_thms () = (eq_refl, false)
+ :: rev (map (rpair true) (get_eq' (Theory.deref thy_ref) dtco));
in
Code.add_funcl (const, Susp.delay get_thms) thy
end;
--- a/src/HOL/Tools/typecopy_package.ML Thu Sep 25 09:28:07 2008 +0200
+++ b/src/HOL/Tools/typecopy_package.ML Thu Sep 25 09:28:08 2008 +0200
@@ -124,10 +124,10 @@
let
val SOME { constr, proj_def, inject, vs, ... } = get_info thy tyco;
val vs' = (map o apsnd) (curry (Sorts.inter_sort (Sign.classes_of thy)) [HOLogic.class_eq]) vs;
- val ty = Logic.unvarifyT (Sign.the_const_type thy constr);
+ val ty = Type (tyco, map TFree vs');
+ val ty_constr = Logic.unvarifyT (Sign.the_const_type thy constr);
fun add_def tyco lthy =
let
- val ty = Type (tyco, map TFree vs');
fun mk_side const_name = Const (const_name, ty --> ty --> HOLogic.boolT)
$ Free ("x", ty) $ Free ("y", ty);
val def = HOLogic.mk_Trueprop (HOLogic.mk_eq
@@ -140,16 +140,23 @@
in (thm', lthy') end;
fun tac thms = Class.intro_classes_tac []
THEN ALLGOALS (ProofContext.fact_tac thms);
- fun add_eq_thm thy =
+ fun add_eq_thms thy =
let
val eq = inject
|> Code_Unit.constrain_thm [HOLogic.class_eq]
|> Simpdata.mk_eq
|> MetaSimplifier.rewrite_rule [Thm.transfer thy @{thm equals_eq}];
- in Code.add_func eq thy end;
+ val eq_refl = @{thm HOL.eq_refl}
+ |> Thm.instantiate
+ ([pairself (Thm.ctyp_of thy) (TVar (("'a", 0), @{sort eq}), Logic.varifyT ty)], []);
+ in
+ thy
+ |> Code.add_func eq
+ |> Code.add_nonlinear_func eq_refl
+ end;
in
thy
- |> Code.add_datatype [(constr, ty)]
+ |> Code.add_datatype [(constr, ty_constr)]
|> Code.add_func proj_def
|> TheoryTarget.instantiation ([tyco], vs', [HOLogic.class_eq])
|> add_def tyco
@@ -157,7 +164,7 @@
#> LocalTheory.exit
#> ProofContext.theory_of
#> Code.del_func thm
- #> add_eq_thm)
+ #> add_eq_thms)
end;
val setup =
--- a/src/HOL/ex/NormalForm.thy Thu Sep 25 09:28:07 2008 +0200
+++ b/src/HOL/ex/NormalForm.thy Thu Sep 25 09:28:08 2008 +0200
@@ -8,17 +8,30 @@
imports Main "~~/src/HOL/Real/Rational"
begin
+lemma [code nbe]:
+ "x = x \<longleftrightarrow> True" by rule+
+
+lemma [code nbe]:
+ "eq_class.eq (x::bool) x \<longleftrightarrow> True" unfolding eq by rule+
+
+lemma [code nbe]:
+ "eq_class.eq (x::nat) x \<longleftrightarrow> True" unfolding eq by rule+
+
lemma "True" by normalization
lemma "p \<longrightarrow> True" by normalization
-declare disj_assoc [code func]
-lemma "((P | Q) | R) = (P | (Q | R))" by normalization rule
+declare disj_assoc [code nbe]
+lemma "((P | Q) | R) = (P | (Q | R))" by normalization
declare disj_assoc [code func del]
-lemma "0 + (n::nat) = n" by normalization rule
-lemma "0 + Suc n = Suc n" by normalization rule
-lemma "Suc n + Suc m = n + Suc (Suc m)" by normalization rule
+lemma "0 + (n::nat) = n" by normalization
+lemma "0 + Suc n = Suc n" by normalization
+lemma "Suc n + Suc m = n + Suc (Suc m)" by normalization
lemma "~((0::nat) < (0::nat))" by normalization
datatype n = Z | S n
+
+lemma [code nbe]:
+ "eq_class.eq (x::n) x \<longleftrightarrow> True" unfolding eq by rule+
+
consts
add :: "n \<Rightarrow> n \<Rightarrow> n"
add2 :: "n \<Rightarrow> n \<Rightarrow> n"
@@ -40,9 +53,9 @@
lemma [code]: "add2 n Z = n"
by(induct n) auto
-lemma "add2 (add2 n m) k = add2 n (add2 m k)" by normalization rule
-lemma "add2 (add2 (S n) (S m)) (S k) = S(S(S(add2 n (add2 m k))))" by normalization rule
-lemma "add2 (add2 (S n) (add2 (S m) Z)) (S k) = S(S(S(add2 n (add2 m k))))" by normalization rule
+lemma "add2 (add2 n m) k = add2 n (add2 m k)" by normalization
+lemma "add2 (add2 (S n) (S m)) (S k) = S(S(S(add2 n (add2 m k))))" by normalization
+lemma "add2 (add2 (S n) (add2 (S m) Z)) (S k) = S(S(S(add2 n (add2 m k))))" by normalization
primrec
"mul Z = (%n. Z)"
@@ -59,18 +72,22 @@
lemma "exp (S(S Z)) (S(S(S(S Z)))) = exp (S(S(S(S Z)))) (S(S Z))" by normalization
lemma "(let ((x,y),(u,v)) = ((Z,Z),(Z,Z)) in add (add x y) (add u v)) = Z" by normalization
-lemma "split (%x y. x) (a, b) = a" by normalization rule
+lemma "split (%x y. x) (a, b) = a" by normalization
lemma "(%((x,y),(u,v)). add (add x y) (add u v)) ((Z,Z),(Z,Z)) = Z" by normalization
lemma "case Z of Z \<Rightarrow> True | S x \<Rightarrow> False" by normalization
lemma "[] @ [] = []" by normalization
-lemma "map f [x,y,z::'x] = [f x, f y, f z]" by normalization rule+
-lemma "[a, b, c] @ xs = a # b # c # xs" by normalization rule+
-lemma "[] @ xs = xs" by normalization rule
-lemma "map (%f. f True) [id, g, Not] = [True, g True, False]" by normalization rule+
+lemma "map f [x,y,z::'x] = [f x, f y, f z]" by normalization
+lemma "[a, b, c] @ xs = a # b # c # xs" by normalization
+lemma "[] @ xs = xs" by normalization
+lemma "map (%f. f True) [id, g, Not] = [True, g True, False]" by normalization
+
+lemma [code nbe]:
+ "eq_class.eq (x :: 'a\<Colon>eq list) x \<longleftrightarrow> True" unfolding eq by rule+
+
lemma "map (%f. f True) ([id, g, Not] @ fs) = [True, g True, False] @ map (%f. f True) fs" by normalization rule+
-lemma "rev [a, b, c] = [c, b, a]" by normalization rule+
+lemma "rev [a, b, c] = [c, b, a]" by normalization
normal_form "rev (a#b#cs) = rev cs @ [b, a]"
normal_form "map (%F. F [a,b,c::'x]) (map map [f,g,h])"
normal_form "map (%F. F ([a,b,c] @ ds)) (map map ([f,g,h]@fs))"
@@ -79,21 +96,24 @@
by normalization
normal_form "case xs of [] \<Rightarrow> True | x#xs \<Rightarrow> False"
normal_form "map (%x. case x of None \<Rightarrow> False | Some y \<Rightarrow> True) xs = P"
-lemma "let x = y in [x, x] = [y, y]" by normalization rule+
-lemma "Let y (%x. [x,x]) = [y, y]" by normalization rule+
+lemma "let x = y in [x, x] = [y, y]" by normalization
+lemma "Let y (%x. [x,x]) = [y, y]" by normalization
normal_form "case n of Z \<Rightarrow> True | S x \<Rightarrow> False"
-lemma "(%(x,y). add x y) (S z,S z) = S (add z (S z))" by normalization rule+
+lemma "(%(x,y). add x y) (S z,S z) = S (add z (S z))" by normalization
normal_form "filter (%x. x) ([True,False,x]@xs)"
normal_form "filter Not ([True,False,x]@xs)"
-lemma "[x,y,z] @ [a,b,c] = [x, y, z, a, b, c]" by normalization rule+
-lemma "(%(xs, ys). xs @ ys) ([a, b, c], [d, e, f]) = [a, b, c, d, e, f]" by normalization rule+
+lemma "[x,y,z] @ [a,b,c] = [x, y, z, a, b, c]" by normalization
+lemma "(%(xs, ys). xs @ ys) ([a, b, c], [d, e, f]) = [a, b, c, d, e, f]" by normalization
lemma "map (%x. case x of None \<Rightarrow> False | Some y \<Rightarrow> True) [None, Some ()] = [False, True]" by normalization
-lemma "last [a, b, c] = c" by normalization rule
-lemma "last ([a, b, c] @ xs) = last (c # xs)" by normalization rule
+lemma "last [a, b, c] = c" by normalization
+lemma "last ([a, b, c] @ xs) = last (c # xs)" by normalization
-lemma "(2::int) + 3 - 1 + (- k) * 2 = 4 + - k * 2" by normalization rule
+lemma [code nbe]:
+ "eq_class.eq (x :: int) x \<longleftrightarrow> True" unfolding eq by rule+
+
+lemma "(2::int) + 3 - 1 + (- k) * 2 = 4 + - k * 2" by normalization
lemma "(-4::int) * 2 = -8" by normalization
lemma "abs ((-4::int) + 2 * 1) = 2" by normalization
lemma "(2::int) + 3 = 5" by normalization
@@ -111,10 +131,10 @@
lemma "(42::rat) / 1704 = 1 / 284 + 3 / 142" by normalization
normal_form "Suc 0 \<in> set ms"
-lemma "f = f" by normalization rule+
-lemma "f x = f x" by normalization rule+
-lemma "(f o g) x = f (g x)" by normalization rule+
-lemma "(f o id) x = f x" by normalization rule+
+lemma "f = f" by normalization
+lemma "f x = f x" by normalization
+lemma "(f o g) x = f (g x)" by normalization
+lemma "(f o id) x = f x" by normalization
normal_form "(\<lambda>x. x)"
(* Church numerals: *)
--- a/src/Pure/Isar/code.ML Thu Sep 25 09:28:07 2008 +0200
+++ b/src/Pure/Isar/code.ML Thu Sep 25 09:28:08 2008 +0200
@@ -9,12 +9,13 @@
signature CODE =
sig
val add_func: thm -> theory -> theory
+ val add_nonlinear_func: thm -> theory -> theory
val add_liberal_func: thm -> theory -> theory
val add_default_func: thm -> theory -> theory
val add_default_func_attr: Attrib.src
val del_func: thm -> theory -> theory
val del_funcs: string -> theory -> theory
- val add_funcl: string * thm list Susp.T -> theory -> theory
+ val add_funcl: string * (thm * bool) list Susp.T -> theory -> theory
val map_pre: (MetaSimplifier.simpset -> MetaSimplifier.simpset) -> theory -> theory
val map_post: (MetaSimplifier.simpset -> MetaSimplifier.simpset) -> theory -> theory
val add_inline: thm -> theory -> theory
@@ -34,7 +35,7 @@
val coregular_algebra: theory -> Sorts.algebra
val operational_algebra: theory -> (sort -> sort) * Sorts.algebra
- val these_funcs: theory -> string -> thm list
+ val these_funcs: theory -> string -> (thm * bool) list
val get_datatype: theory -> string -> ((string * sort) list * (string * typ list) list)
val get_datatype_of_constr: theory -> string -> string option
val get_case_data: theory -> string -> (int * string list) option
@@ -115,41 +116,38 @@
(** logical and syntactical specification of executable code **)
-(* defining equations with default flag and lazy theorems *)
+(* defining equations with linear flag, default flag and lazy theorems *)
fun pretty_lthms ctxt r = case Susp.peek r
- of SOME thms => map (ProofContext.pretty_thm ctxt) thms
+ of SOME thms => map (ProofContext.pretty_thm ctxt o fst) thms
| NONE => [Pretty.str "[...]"];
fun certificate thy f r =
case Susp.peek r
- of SOME thms => (Susp.value o f thy) thms
+ of SOME thms => (Susp.value o burrow_fst (f thy)) thms
| NONE => let
val thy_ref = Theory.check_thy thy;
- in Susp.delay (fn () => (f (Theory.deref thy_ref) o Susp.force) r) end;
+ in Susp.delay (fn () => (burrow_fst (f (Theory.deref thy_ref)) o Susp.force) r) end;
-fun add_drop_redundant verbose thm thms =
+fun add_drop_redundant (thm, linear) thms =
let
- fun warn thm' = (if verbose
- then warning ("Code generator: dropping redundant defining equation\n" ^ Display.string_of_thm thm')
- else (); true);
val thy = Thm.theory_of_thm thm;
val args_of = snd o strip_comb o fst o Logic.dest_equals o Thm.plain_prop_of;
val args = args_of thm;
- fun matches [] _ = true
- | matches (Var _ :: xs) [] = matches xs []
- | matches (_ :: _) [] = false
- | matches (x :: xs) (y :: ys) = Pattern.matches thy (x, y) andalso matches xs ys;
- fun drop thm' = matches args (args_of thm') andalso warn thm';
- in thm :: filter_out drop thms end;
+ fun matches_args args' = length args <= length args' andalso
+ Pattern.matchess thy (args, curry Library.take (length args) args');
+ fun drop (thm', _) = if matches_args (args_of thm') then
+ (warning ("Code generator: dropping redundant defining equation\n" ^ Display.string_of_thm thm'); true)
+ else false;
+ in (thm, linear) :: filter_out drop thms end;
-fun add_thm _ thm (false, thms) = (false, Susp.value (add_drop_redundant true thm (Susp.force thms)))
- | add_thm true thm (true, thms) = (true, Susp.value (Susp.force thms @ [thm]))
+fun add_thm _ thm (false, thms) = (false, Susp.map_force (add_drop_redundant thm) thms)
+ | add_thm true thm (true, thms) = (true, Susp.map_force (fn thms => thms @ [thm]) thms)
| add_thm false thm (true, thms) = (false, Susp.value [thm]);
fun add_lthms lthms _ = (false, lthms);
-fun del_thm thm = apsnd (Susp.value o remove Thm.eq_thm_prop thm o Susp.force);
+fun del_thm thm = (apsnd o Susp.map_force) (remove (eq_fst Thm.eq_thm_prop) (thm, true));
fun merge_defthms ((true, _), defthms2) = defthms2
| merge_defthms (defthms1 as (false, _), (true, _)) = defthms1
@@ -173,7 +171,7 @@
(* specification data *)
datatype spec = Spec of {
- funcs: (bool * thm list Susp.T) Symtab.table,
+ funcs: (bool * (thm * bool) list Susp.T) Symtab.table,
dtyps: ((string * sort) list * (string * typ list) list) Symtab.table,
cases: (int * string list) Symtab.table * unit Symtab.table
};
@@ -479,7 +477,7 @@
val funcs = classparams
|> map_filter (fn c => try (AxClass.param_of_inst thy) (c, tyco))
|> map (Symtab.lookup ((the_funcs o the_exec) thy))
- |> (map o Option.map) (Susp.force o snd)
+ |> (map o Option.map) (map fst o Susp.force o snd)
|> maps these
|> map (Thm.transfer thy);
fun sorts_of [Type (_, tys)] = map (snd o dest_TVar) tys
@@ -600,9 +598,10 @@
| NONE => check_typ_fun (c, thm);
in check_typ (const_of_func thy thm, thm) end;
-val mk_func = Code_Unit.error_thm (assert_func_typ o Code_Unit.mk_func);
-val mk_liberal_func = Code_Unit.warning_thm (assert_func_typ o Code_Unit.mk_func);
-val mk_default_func = Code_Unit.try_thm (assert_func_typ o Code_Unit.mk_func);
+fun mk_func linear = Code_Unit.error_thm (assert_func_typ o Code_Unit.mk_func linear);
+val mk_liberal_func = Code_Unit.warning_thm (assert_func_typ o Code_Unit.mk_func true);
+val mk_syntactic_func = Code_Unit.warning_thm (assert_func_typ o Code_Unit.mk_func false);
+val mk_default_func = Code_Unit.try_thm (assert_func_typ o Code_Unit.mk_func true);
end; (*local*)
@@ -641,8 +640,8 @@
val is_undefined = Symtab.defined o snd o the_cases o the_exec;
-fun gen_add_func strict default thm thy =
- case (if strict then SOME o mk_func else mk_liberal_func) thm
+fun gen_add_func linear strict default thm thy =
+ case (if strict then SOME o mk_func linear else mk_liberal_func) thm
of SOME func =>
let
val c = const_of_func thy func;
@@ -656,15 +655,16 @@
else ();
in
(map_exec_purge (SOME [c]) o map_funcs) (Symtab.map_default
- (c, (true, Susp.value [])) (add_thm default func)) thy
+ (c, (true, Susp.value [])) (add_thm default (func, linear))) thy
end
| NONE => thy;
-val add_func = gen_add_func true false;
-val add_liberal_func = gen_add_func false false;
-val add_default_func = gen_add_func false true;
+val add_func = gen_add_func true true false;
+val add_liberal_func = gen_add_func true false false;
+val add_default_func = gen_add_func true false true;
+val add_nonlinear_func = gen_add_func false true false;
-fun del_func thm thy = case mk_liberal_func thm
+fun del_func thm thy = case mk_syntactic_func thm
of SOME func => let
val c = const_of_func thy func;
in map_exec_purge (SOME [c]) (map_funcs
@@ -762,6 +762,7 @@
in
TypeInterpretation.init
#> add_del_attribute ("func", (add_func, del_func))
+ #> add_simple_attribute ("nbe", add_nonlinear_func)
#> add_del_attribute ("inline", (add_inline, del_inline))
#> add_del_attribute ("post", (add_post, del_post))
end));
@@ -801,7 +802,7 @@
thms
|> apply_functrans thy
|> map (Code_Unit.rewrite_func pre)
- (*FIXME - must check gere: rewrite rule, defining equation, proper constant *)
+ (*FIXME - must check here: rewrite rule, defining equation, proper constant *)
|> map (AxClass.unoverload thy)
|> common_typ_funcs
end;
@@ -847,24 +848,24 @@
Symtab.lookup ((the_funcs o the_exec) thy) const
|> Option.map (Susp.force o snd)
|> these
- |> map (Thm.transfer thy);
+ |> (map o apfst) (Thm.transfer thy);
in
fun these_funcs thy const =
let
fun drop_refl thy = filter_out (is_equal o Term.fast_term_ord o Logic.dest_equals
- o ObjectLogic.drop_judgment thy o Thm.plain_prop_of);
+ o ObjectLogic.drop_judgment thy o Thm.plain_prop_of o fst);
in
get_funcs thy const
- |> preprocess thy
+ |> burrow_fst (preprocess thy)
|> drop_refl thy
end;
fun default_typ thy c = case default_typ_proto thy c
of SOME ty => Code_Unit.typscheme thy (c, ty)
| NONE => (case get_funcs thy c
- of thm :: _ => snd (Code_Unit.head_func (AxClass.unoverload thy thm))
+ of (thm, _) :: _ => snd (Code_Unit.head_func (AxClass.unoverload thy thm))
| [] => Code_Unit.typscheme thy (c, Sign.the_const_type thy c));
end; (*local*)
--- a/src/Tools/code/code_funcgr.ML Thu Sep 25 09:28:07 2008 +0200
+++ b/src/Tools/code/code_funcgr.ML Thu Sep 25 09:28:08 2008 +0200
@@ -9,7 +9,7 @@
signature CODE_FUNCGR =
sig
type T
- val funcs: T -> string -> thm list
+ val funcs: T -> string -> (thm * bool) list
val typ: T -> string -> (string * sort) list * typ
val all: T -> string list
val pretty: theory -> T -> Pretty.T
@@ -24,7 +24,7 @@
(** the graph type **)
-type T = (((string * sort) list * typ) * thm list) Graph.T;
+type T = (((string * sort) list * typ) * (thm * bool) list) Graph.T;
fun funcs funcgr =
these o Option.map snd o try (Graph.get_node funcgr);
@@ -41,7 +41,7 @@
|> map (fn (s, thms) =>
(Pretty.block o Pretty.fbreaks) (
Pretty.str s
- :: map Display.pretty_thm thms
+ :: map (Display.pretty_thm o fst) thms
))
|> Pretty.chunks;
@@ -57,7 +57,7 @@
| consts_of (const, thms as _ :: _) =
let
fun the_const (c, _) = if c = const then I else insert (op =) c
- in fold_consts the_const thms [] end;
+ in fold_consts the_const (map fst thms) [] end;
fun insts_of thy algebra tys sorts =
let
@@ -100,19 +100,19 @@
fun resort_funcss thy algebra funcgr =
let
val typ_funcgr = try (fst o Graph.get_node funcgr);
- val resort_dep = apsnd (resort_thms thy algebra typ_funcgr);
+ val resort_dep = (apsnd o burrow_fst) (resort_thms thy algebra typ_funcgr);
fun resort_rec typ_of (c, []) = (true, (c, []))
- | resort_rec typ_of (c, thms as thm :: _) = if is_some (AxClass.inst_of_param thy c)
+ | resort_rec typ_of (c, thms as (thm, _) :: _) = if is_some (AxClass.inst_of_param thy c)
then (true, (c, thms))
else let
val (_, (vs, ty)) = Code_Unit.head_func thm;
- val thms' as thm' :: _ = resort_thms thy algebra typ_of thms
+ val thms' as (thm', _) :: _ = burrow_fst (resort_thms thy algebra typ_of) thms
val (_, (vs', ty')) = Code_Unit.head_func thm'; (*FIXME simplify check*)
in (Sign.typ_equiv thy (ty, ty'), (c, thms')) end;
fun resort_recs funcss =
let
fun typ_of c = case these (AList.lookup (op =) funcss c)
- of thm :: _ => (SOME o snd o Code_Unit.head_func) thm
+ of (thm, _) :: _ => (SOME o snd o Code_Unit.head_func) thm
| [] => NONE;
val (unchangeds, funcss') = split_list (map (resort_rec typ_of) funcss);
val unchanged = fold (fn x => fn y => x andalso y) unchangeds true;
@@ -158,8 +158,8 @@
|> pair (SOME const)
else let
val thms = Code.these_funcs thy const
- |> Code_Unit.norm_args
- |> Code_Unit.norm_varnames Code_Name.purify_tvar Code_Name.purify_var;
+ |> burrow_fst Code_Unit.norm_args
+ |> burrow_fst (Code_Unit.norm_varnames Code_Name.purify_tvar Code_Name.purify_var);
val rhs = consts_of (const, thms);
in
auxgr
@@ -182,14 +182,14 @@
|> resort_funcss thy algebra funcgr
|> filter_out (can (Graph.get_node funcgr) o fst);
fun typ_func c [] = Code.default_typ thy c
- | typ_func c (thms as thm :: _) = (snd o Code_Unit.head_func) thm;
+ | typ_func c (thms as (thm, _) :: _) = (snd o Code_Unit.head_func) thm;
fun add_funcs (const, thms) =
Graph.new_node (const, (typ_func const thms, thms));
fun add_deps (funcs as (const, thms)) funcgr =
let
val deps = consts_of funcs;
val insts = instances_of_consts thy algebra funcgr
- (fold_consts (insert (op =)) thms []);
+ (fold_consts (insert (op =)) (map fst thms) []);
in
funcgr
|> ensure_consts thy algebra insts
@@ -278,16 +278,15 @@
(** diagnostic commands **)
-fun code_depgr thy [] = make thy []
- | code_depgr thy consts =
- let
- val gr = make thy consts;
- val select = Graph.all_succs gr consts;
- in
- gr
- |> Graph.subgraph (member (op =) select)
- |> Graph.map_nodes ((apsnd o map) (AxClass.overload thy))
- end;
+fun code_depgr thy consts =
+ let
+ val gr = make thy consts;
+ val select = Graph.all_succs gr consts;
+ in
+ gr
+ |> not (null consts) ? Graph.subgraph (member (op =) select)
+ |> Graph.map_nodes ((apsnd o map o apfst) (AxClass.overload thy))
+ end;
fun code_thms thy = Pretty.writeln o pretty thy o code_depgr thy;
--- a/src/Tools/code/code_haskell.ML Thu Sep 25 09:28:07 2008 +0200
+++ b/src/Tools/code/code_haskell.ML Thu Sep 25 09:28:08 2008 +0200
@@ -153,10 +153,11 @@
)
]
end
- | pr_stmt (name, Code_Thingol.Fun ((vs, ty), eqs)) =
+ | pr_stmt (name, Code_Thingol.Fun ((vs, ty), raw_eqs)) =
let
+ val eqs = filter (snd o snd) raw_eqs;
val tyvars = intro_vars (map fst vs) init_syms;
- fun pr_eq ((ts, t), thm) =
+ fun pr_eq ((ts, t), (thm, _)) =
let
val consts = map_filter
(fn c => if (is_some o syntax_const) c
@@ -248,7 +249,7 @@
| pr_stmt (_, Code_Thingol.Classinst ((class, (tyco, vs)), (_, classparam_insts))) =
let
val tyvars = intro_vars (map fst vs) init_syms;
- fun pr_instdef ((classparam, c_inst), thm) =
+ fun pr_instdef ((classparam, c_inst), (thm, _)) =
semicolon [
(str o classparam_name class) classparam,
str "=",
--- a/src/Tools/code/code_ml.ML Thu Sep 25 09:28:07 2008 +0200
+++ b/src/Tools/code/code_ml.ML Thu Sep 25 09:28:08 2008 +0200
@@ -27,12 +27,12 @@
val target_OCaml = "OCaml";
datatype ml_stmt =
- MLFuns of (string * (typscheme * ((iterm list * iterm) * thm) list)) list
+ MLFuns of (string * (typscheme * ((iterm list * iterm) * (thm * bool)) list)) list
| MLDatas of (string * ((vname * sort) list * (string * itype list) list)) list
| MLClass of string * (vname * ((class * string) list * (string * itype) list))
| MLClassinst of string * ((class * (string * (vname * sort) list))
* ((class * (string * (string * dict list list))) list
- * ((string * const) * thm) list));
+ * ((string * const) * (thm * bool)) list));
fun stmt_names_of (MLFuns fs) = map fst fs
| stmt_names_of (MLDatas ds) = map fst ds
@@ -192,7 +192,7 @@
val vs_dict = filter_out (null o snd) vs;
val shift = if null eqs' then I else
map (Pretty.block o single o Pretty.block o single);
- fun pr_eq definer ((ts, t), thm) =
+ fun pr_eq definer ((ts, t), (thm, _)) =
let
val consts = map_filter
(fn c => if (is_some o syntax_const) c
@@ -299,7 +299,7 @@
str "=",
pr_dicts NOBR [DictConst dss]
];
- fun pr_classparam ((classparam, c_inst), thm) =
+ fun pr_classparam ((classparam, c_inst), (thm, _)) =
concat [
(str o pr_label_classparam) classparam,
str "=",
@@ -453,7 +453,7 @@
in map (lookup_var vars') fished3 end;
fun pr_stmt (MLFuns (funns as funn :: funns')) =
let
- fun pr_eq ((ts, t), thm) =
+ fun pr_eq ((ts, t), (thm, _)) =
let
val consts = map_filter
(fn c => if (is_some o syntax_const) c
@@ -481,7 +481,7 @@
@@ str exc_str
)
end
- | pr_eqs _ _ [((ts, t), thm)] =
+ | pr_eqs _ _ [((ts, t), (thm, _))] =
let
val consts = map_filter
(fn c => if (is_some o syntax_const) c
@@ -611,7 +611,7 @@
str "=",
pr_dicts NOBR [DictConst dss]
];
- fun pr_classparam_inst ((classparam, c_inst), thm) =
+ fun pr_classparam_inst ((classparam, c_inst), (thm, _)) =
concat [
(str o deresolve) classparam,
str "=",
@@ -727,7 +727,7 @@
fold_map
(fn (name, Code_Thingol.Fun stmt) =>
map_nsp_fun_yield (mk_name_stmt false name) #>>
- rpair (name, stmt)
+ rpair (name, stmt |> apsnd (filter (snd o snd)))
| (name, _) =>
error ("Function block containing illegal statement: " ^ labelled_name name)
) stmts
@@ -895,7 +895,8 @@
val _ = if Code_Thingol.contains_dictvar t then
error "Term to be evaluated constains free dictionaries" else ();
val program' = program
- |> Graph.new_node (Code_Name.value_name, Code_Thingol.Fun (([], ty), [(([], t), Drule.dummy_thm)]))
+ |> Graph.new_node (Code_Name.value_name,
+ Code_Thingol.Fun (([], ty), [(([], t), (Drule.dummy_thm, true))]))
|> fold (curry Graph.add_edge Code_Name.value_name) deps;
val (value_code, [value_name']) = ml_code_of thy program' [Code_Name.value_name];
val sml_code = "let\n" ^ value_code ^ "\nin " ^ value_name'
--- a/src/Tools/code/code_thingol.ML Thu Sep 25 09:28:07 2008 +0200
+++ b/src/Tools/code/code_thingol.ML Thu Sep 25 09:28:08 2008 +0200
@@ -62,7 +62,7 @@
datatype stmt =
NoStmt
- | Fun of typscheme * ((iterm list * iterm) * thm) list
+ | Fun of typscheme * ((iterm list * iterm) * (thm * bool)) list
| Datatype of (vname * sort) list * (string * itype list) list
| Datatypecons of string
| Class of vname * ((class * string) list * (string * itype) list)
@@ -70,7 +70,7 @@
| Classparam of class
| Classinst of (class * (string * (vname * sort) list))
* ((class * (string * (string * dict list list))) list
- * ((string * const) * thm) list);
+ * ((string * const) * (thm * bool)) list);
type program = stmt Graph.T;
val empty_funs: program -> string list;
val map_terms_bottom_up: (iterm -> iterm) -> iterm -> iterm;
@@ -258,7 +258,7 @@
type typscheme = (vname * sort) list * itype;
datatype stmt =
NoStmt
- | Fun of typscheme * ((iterm list * iterm) * thm) list
+ | Fun of typscheme * ((iterm list * iterm) * (thm * bool)) list
| Datatype of (vname * sort) list * (string * itype list) list
| Datatypecons of string
| Class of vname * ((class * string) list * (string * itype) list)
@@ -266,7 +266,7 @@
| Classparam of class
| Classinst of (class * (string * (vname * sort) list))
* ((class * (string * (string * dict list list))) list
- * ((string * const) * thm) list);
+ * ((string * const) * (thm * bool)) list);
type program = stmt Graph.T;
@@ -423,14 +423,14 @@
fold_map (ensure_classrel thy algbr funcgr) classrels
#>> (fn classrels => DictVar (classrels, (unprefix "'" v, (k, length sort))))
in fold_map mk_dict typargs end
-and exprgen_eq thy algbr funcgr thm =
+and exprgen_eq thy algbr funcgr (thm, linear) =
let
val (args, rhs) = (apfst (snd o strip_comb) o Logic.dest_equals
o Logic.unvarify o prop_of) thm;
in
fold_map (exprgen_term thy algbr funcgr (SOME thm)) args
##>> exprgen_term thy algbr funcgr (SOME thm) rhs
- #>> rpair thm
+ #>> rpair (thm, linear)
end
and ensure_inst thy (algbr as (_, algebra)) funcgr (class, tyco) =
let
@@ -457,7 +457,7 @@
in
ensure_const thy algbr funcgr c
##>> exprgen_const thy algbr funcgr (SOME thm) c_ty
- #>> (fn (c, IConst c_inst) => ((c, c_inst), thm))
+ #>> (fn (c, IConst c_inst) => ((c, c_inst), (thm, true)))
end;
val stmt_inst =
ensure_class thy algbr funcgr class
@@ -485,7 +485,7 @@
val ty = Logic.unvarifyT raw_ty;
val thms = if (null o Term.typ_tfrees) ty orelse (null o fst o strip_type) ty
then raw_thms
- else map (Code_Unit.expand_eta 1) raw_thms;
+ else (map o apfst) (Code_Unit.expand_eta 1) raw_thms;
in
trns
|> fold_map (exprgen_tyvar_sort thy algbr funcgr) vs
@@ -642,7 +642,7 @@
fold_map (exprgen_tyvar_sort thy algbr funcgr) vs
##>> exprgen_typ thy algbr funcgr ty
##>> exprgen_term thy algbr funcgr NONE t
- #>> (fn ((vs, ty), t) => Fun ((vs, ty), [(([], t), Drule.dummy_thm)]));
+ #>> (fn ((vs, ty), t) => Fun ((vs, ty), [(([], t), (Drule.dummy_thm, true))]));
fun term_value (dep, program1) =
let
val Fun ((vs, ty), [(([], t), _)]) =
--- a/src/Tools/nbe.ML Thu Sep 25 09:28:07 2008 +0200
+++ b/src/Tools/nbe.ML Thu Sep 25 09:28:08 2008 +0200
@@ -15,10 +15,11 @@
| Free of string * Univ list (*free (uninterpreted) variables*)
| DFree of string * int (*free (uninterpreted) dictionary parameters*)
| BVar of int * Univ list
- | Abs of (int * (Univ list -> Univ)) * Univ list;
+ | Abs of (int * (Univ list -> Univ)) * Univ list
val apps: Univ -> Univ list -> Univ (*explicit applications*)
val abss: int -> (Univ list -> Univ) -> Univ
(*abstractions as closures*)
+ val same: Univ -> Univ -> bool
val univs_ref: (unit -> Univ list -> Univ list) option ref
val trace: bool ref
@@ -63,6 +64,13 @@
| Abs of (int * (Univ list -> Univ)) * Univ list
(*abstractions as closures*);
+fun same (Const (k, xs)) (Const (l, ys)) = k = l andalso sames xs ys
+ | same (Free (s, xs)) (Free (t, ys)) = s = t andalso sames xs ys
+ | same (DFree (s, k)) (DFree (t, l)) = s = t andalso k = l
+ | same (BVar (k, xs)) (BVar (l, ys)) = k = l andalso sames xs ys
+ | same _ _ = false
+and sames xs ys = length xs = length ys andalso forall (uncurry same) (xs ~~ ys);
+
(* constructor functions *)
fun abss n f = Abs ((n, f), []);
@@ -92,6 +100,11 @@
fun ml_Let d e = "let\n" ^ d ^ " in " ^ e ^ " end";
fun ml_as v t = "(" ^ v ^ " as " ^ t ^ ")";
+fun ml_and [] = "true"
+ | ml_and [x] = x
+ | ml_and xs = "(" ^ space_implode " andalso " xs ^ ")";
+fun ml_if b x y = "(if " ^ b ^ " then " ^ x ^ " else " ^ y ^ ")";
+
fun ml_list es = "[" ^ commas es ^ "]";
fun ml_fundefs ([(name, [([], e)])]) =
@@ -113,11 +126,12 @@
val univs_ref = ref (NONE : (unit -> Univ list -> Univ list) option);
local
- val prefix = "Nbe.";
- val name_ref = prefix ^ "univs_ref";
- val name_const = prefix ^ "Const";
- val name_abss = prefix ^ "abss";
- val name_apps = prefix ^ "apps";
+ val prefix = "Nbe.";
+ val name_ref = prefix ^ "univs_ref";
+ val name_const = prefix ^ "Const";
+ val name_abss = prefix ^ "abss";
+ val name_apps = prefix ^ "apps";
+ val name_same = prefix ^ "same";
in
val univs_cookie = (name_ref, univs_ref);
@@ -141,6 +155,8 @@
fun nbe_abss 0 f = f `$` ml_list []
| nbe_abss n f = name_abss `$$` [string_of_int n, f];
+fun nbe_same v1 v2 = "(" ^ name_same ^ " " ^ nbe_bound v1 ^ " " ^ nbe_bound v2 ^ ")";
+
end;
open Basic_Code_Thingol;
@@ -173,34 +189,62 @@
| assemble_idict (DictVar (supers, (v, (n, _)))) =
fold_rev (fn super => assemble_constapp super [] o single) supers (nbe_dict v n);
- fun assemble_iterm match_cont constapp =
+ fun assemble_iterm constapp =
let
- fun of_iterm t =
+ fun of_iterm match_cont t =
let
val (t', ts) = Code_Thingol.unfold_app t
- in of_iapp t' (fold_rev (cons o of_iterm) ts []) end
- and of_iapp (IConst (c, (dss, _))) ts = constapp c dss ts
- | of_iapp (IVar v) ts = nbe_apps (nbe_bound v) ts
- | of_iapp ((v, _) `|-> t) ts =
- nbe_apps (nbe_abss 1 (ml_abs (ml_list [nbe_bound v]) (of_iterm t))) ts
- | of_iapp (ICase (((t, _), cs), t0)) ts =
- nbe_apps (ml_cases (of_iterm t) (map (pairself of_iterm) cs
- @ [("_", case match_cont of SOME s => s | NONE => of_iterm t0)])) ts
+ in of_iapp match_cont t' (fold_rev (cons o of_iterm NONE) ts []) end
+ and of_iapp match_cont (IConst (c, (dss, _))) ts = constapp c dss ts
+ | of_iapp match_cont (IVar v) ts = nbe_apps (nbe_bound v) ts
+ | of_iapp match_cont ((v, _) `|-> t) ts =
+ nbe_apps (nbe_abss 1 (ml_abs (ml_list [nbe_bound v]) (of_iterm NONE t))) ts
+ | of_iapp match_cont (ICase (((t, _), cs), t0)) ts =
+ nbe_apps (ml_cases (of_iterm NONE t)
+ (map (fn (p, t) => (of_iterm NONE p, of_iterm match_cont t)) cs
+ @ [("_", case match_cont of SOME s => s | NONE => of_iterm NONE t0)])) ts
in of_iterm end;
+ fun subst_nonlin_vars args =
+ let
+ val vs = (fold o Code_Thingol.fold_varnames)
+ (fn v => AList.map_default (op =) (v, 0) (curry (op +) 1)) args [];
+ val names = Name.make_context (map fst vs);
+ fun declare v k ctxt = let val vs = Name.invents ctxt v k
+ in (vs, fold Name.declare vs ctxt) end;
+ val (vs_renames, _) = fold_map (fn (v, k) => if k > 1
+ then declare v (k - 1) #>> (fn vs => (v, vs))
+ else pair (v, [])) vs names;
+ val samepairs = maps (fn (v, vs) => map (pair v) vs) vs_renames;
+ fun subst_vars (t as IConst _) samepairs = (t, samepairs)
+ | subst_vars (t as IVar v) samepairs = (case AList.lookup (op =) samepairs v
+ of SOME v' => (IVar v', AList.delete (op =) v samepairs)
+ | NONE => (t, samepairs))
+ | subst_vars (t1 `$ t2) samepairs = samepairs
+ |> subst_vars t1
+ ||>> subst_vars t2
+ |>> (op `$)
+ | subst_vars (ICase (_, t)) samepairs = subst_vars t samepairs;
+ val (args', _) = fold_map subst_vars args samepairs;
+ in (samepairs, args') end;
+
fun assemble_eqn c dicts default_args (i, (args, rhs)) =
let
val is_eval = (c = "");
val default_rhs = nbe_apps_local (i+1) c (dicts @ default_args);
val match_cont = if is_eval then NONE else SOME default_rhs;
- val assemble_arg = assemble_iterm NONE
- (fn c => fn _ => fn ts => nbe_apps_constr idx_of c ts);
- val assemble_rhs = assemble_iterm match_cont assemble_constapp;
+ val assemble_arg = assemble_iterm
+ (fn c => fn _ => fn ts => nbe_apps_constr idx_of c ts) NONE;
+ val assemble_rhs = assemble_iterm assemble_constapp match_cont ;
+ val (samepairs, args') = subst_nonlin_vars args;
+ val s_args = map assemble_arg args';
+ val s_rhs = if null samepairs then assemble_rhs rhs
+ else ml_if (ml_and (map (uncurry nbe_same) samepairs))
+ (assemble_rhs rhs) default_rhs;
val eqns = if is_eval then
- [([ml_list (rev (dicts @ map assemble_arg args))], assemble_rhs rhs)]
+ [([ml_list (rev (dicts @ s_args))], s_rhs)]
else
- [([ml_list (rev (dicts @ map2 ml_as default_args
- (map assemble_arg args)))], assemble_rhs rhs),
+ [([ml_list (rev (dicts @ map2 ml_as default_args s_args))], s_rhs),
([ml_list (rev (dicts @ default_args))], default_rhs)]
in (nbe_fun i c, eqns) end;