non left-linear equations for nbe
authorhaftmann
Thu, 25 Sep 2008 09:28:08 +0200
changeset 28350 715163ec93c0
parent 28349 46a0dc9b51bb
child 28351 abfc66969d1f
non left-linear equations for nbe
NEWS
src/HOL/Tools/datatype_codegen.ML
src/HOL/Tools/typecopy_package.ML
src/HOL/ex/NormalForm.thy
src/Pure/Isar/code.ML
src/Tools/code/code_funcgr.ML
src/Tools/code/code_haskell.ML
src/Tools/code/code_ml.ML
src/Tools/code/code_thingol.ML
src/Tools/nbe.ML
--- a/NEWS	Thu Sep 25 09:28:07 2008 +0200
+++ b/NEWS	Thu Sep 25 09:28:08 2008 +0200
@@ -66,7 +66,10 @@
 
 *** HOL ***
 
-* HOL/Main: command "value" now integrates different evaluation
+* Normalization by evaluation now allows non-leftlinear equations.
+Declare with attribute [code nbe].
+
+* Command "value" now integrates different evaluation
 mechanisms.  The result of the first successful evaluation mechanism
 is printed.  In square brackets a particular named evaluation
 mechanisms may be specified (currently, [SML], [code] or [nbe]).  See
--- a/src/HOL/Tools/datatype_codegen.ML	Thu Sep 25 09:28:07 2008 +0200
+++ b/src/HOL/Tools/datatype_codegen.ML	Thu Sep 25 09:28:08 2008 +0200
@@ -433,9 +433,9 @@
   let
     val vs' = (map o apsnd)
       (curry (Sorts.inter_sort (Sign.classes_of thy)) [HOLogic.class_eq]) vs;
-    fun add_def tyco lthy =
+    fun add_def dtco lthy =
       let
-        val ty = Type (tyco, map TFree vs');
+        val ty = Type (dtco, map TFree vs');
         fun mk_side const_name = Const (const_name, ty --> ty --> HOLogic.boolT)
           $ Free ("x", ty) $ Free ("y", ty);
         val def = HOLogic.mk_Trueprop (HOLogic.mk_eq
@@ -454,9 +454,15 @@
       |> map (MetaSimplifier.rewrite_rule [Thm.transfer thy @{thm equals_eq}]);
     fun add_eq_thms dtco thy =
       let
+        val ty = Type (dtco, map TFree vs');
         val thy_ref = Theory.check_thy thy;
         val const = AxClass.param_of_inst thy (@{const_name eq_class.eq}, dtco);
-        val get_thms = (fn () => get_eq' (Theory.deref thy_ref) dtco |> rev);
+        val eq_refl = @{thm HOL.eq_refl}
+          |> Thm.instantiate
+              ([pairself (Thm.ctyp_of thy) (TVar (("'a", 0), @{sort eq}), Logic.varifyT ty)], [])
+          |> Simpdata.mk_eq;
+        fun get_thms () = (eq_refl, false)
+          :: rev (map (rpair true) (get_eq' (Theory.deref thy_ref) dtco));
       in
         Code.add_funcl (const, Susp.delay get_thms) thy
       end;
--- a/src/HOL/Tools/typecopy_package.ML	Thu Sep 25 09:28:07 2008 +0200
+++ b/src/HOL/Tools/typecopy_package.ML	Thu Sep 25 09:28:08 2008 +0200
@@ -124,10 +124,10 @@
   let
     val SOME { constr, proj_def, inject, vs, ... } = get_info thy tyco;
     val vs' = (map o apsnd) (curry (Sorts.inter_sort (Sign.classes_of thy)) [HOLogic.class_eq]) vs;
-    val ty = Logic.unvarifyT (Sign.the_const_type thy constr);
+    val ty = Type (tyco, map TFree vs');
+    val ty_constr = Logic.unvarifyT (Sign.the_const_type thy constr);
     fun add_def tyco lthy =
       let
-        val ty = Type (tyco, map TFree vs');
         fun mk_side const_name = Const (const_name, ty --> ty --> HOLogic.boolT)
           $ Free ("x", ty) $ Free ("y", ty);
         val def = HOLogic.mk_Trueprop (HOLogic.mk_eq
@@ -140,16 +140,23 @@
       in (thm', lthy') end;
     fun tac thms = Class.intro_classes_tac []
       THEN ALLGOALS (ProofContext.fact_tac thms);
-    fun add_eq_thm thy = 
+    fun add_eq_thms thy = 
       let
         val eq = inject
           |> Code_Unit.constrain_thm [HOLogic.class_eq]
           |> Simpdata.mk_eq
           |> MetaSimplifier.rewrite_rule [Thm.transfer thy @{thm equals_eq}];
-      in Code.add_func eq thy end;
+        val eq_refl = @{thm HOL.eq_refl}
+          |> Thm.instantiate
+              ([pairself (Thm.ctyp_of thy) (TVar (("'a", 0), @{sort eq}), Logic.varifyT ty)], []);
+      in
+        thy
+        |> Code.add_func eq
+        |> Code.add_nonlinear_func eq_refl
+      end;
   in
     thy
-    |> Code.add_datatype [(constr, ty)]
+    |> Code.add_datatype [(constr, ty_constr)]
     |> Code.add_func proj_def
     |> TheoryTarget.instantiation ([tyco], vs', [HOLogic.class_eq])
     |> add_def tyco
@@ -157,7 +164,7 @@
     #> LocalTheory.exit
     #> ProofContext.theory_of
     #> Code.del_func thm
-    #> add_eq_thm)
+    #> add_eq_thms)
   end;
 
 val setup =
--- a/src/HOL/ex/NormalForm.thy	Thu Sep 25 09:28:07 2008 +0200
+++ b/src/HOL/ex/NormalForm.thy	Thu Sep 25 09:28:08 2008 +0200
@@ -8,17 +8,30 @@
 imports Main "~~/src/HOL/Real/Rational"
 begin
 
+lemma [code nbe]:
+  "x = x \<longleftrightarrow> True" by rule+
+
+lemma [code nbe]:
+  "eq_class.eq (x::bool) x \<longleftrightarrow> True" unfolding eq by rule+
+
+lemma [code nbe]:
+  "eq_class.eq (x::nat) x \<longleftrightarrow> True" unfolding eq by rule+
+
 lemma "True" by normalization
 lemma "p \<longrightarrow> True" by normalization
-declare disj_assoc [code func]
-lemma "((P | Q) | R) = (P | (Q | R))" by normalization rule
+declare disj_assoc [code nbe]
+lemma "((P | Q) | R) = (P | (Q | R))" by normalization
 declare disj_assoc [code func del]
-lemma "0 + (n::nat) = n" by normalization rule
-lemma "0 + Suc n = Suc n" by normalization rule
-lemma "Suc n + Suc m = n + Suc (Suc m)" by normalization rule
+lemma "0 + (n::nat) = n" by normalization
+lemma "0 + Suc n = Suc n" by normalization
+lemma "Suc n + Suc m = n + Suc (Suc m)" by normalization
 lemma "~((0::nat) < (0::nat))" by normalization
 
 datatype n = Z | S n
+
+lemma [code nbe]:
+  "eq_class.eq (x::n) x \<longleftrightarrow> True" unfolding eq by rule+
+
 consts
   add :: "n \<Rightarrow> n \<Rightarrow> n"
   add2 :: "n \<Rightarrow> n \<Rightarrow> n"
@@ -40,9 +53,9 @@
 lemma [code]: "add2 n Z = n"
   by(induct n) auto
 
-lemma "add2 (add2 n m) k = add2 n (add2 m k)" by normalization rule
-lemma "add2 (add2 (S n) (S m)) (S k) = S(S(S(add2 n (add2 m k))))" by normalization rule
-lemma "add2 (add2 (S n) (add2 (S m) Z)) (S k) = S(S(S(add2 n (add2 m k))))" by normalization rule
+lemma "add2 (add2 n m) k = add2 n (add2 m k)" by normalization
+lemma "add2 (add2 (S n) (S m)) (S k) = S(S(S(add2 n (add2 m k))))" by normalization
+lemma "add2 (add2 (S n) (add2 (S m) Z)) (S k) = S(S(S(add2 n (add2 m k))))" by normalization
 
 primrec
   "mul Z = (%n. Z)"
@@ -59,18 +72,22 @@
 lemma "exp (S(S Z)) (S(S(S(S Z)))) = exp (S(S(S(S Z)))) (S(S Z))" by normalization
 
 lemma "(let ((x,y),(u,v)) = ((Z,Z),(Z,Z)) in add (add x y) (add u v)) = Z" by normalization
-lemma "split (%x y. x) (a, b) = a" by normalization rule
+lemma "split (%x y. x) (a, b) = a" by normalization
 lemma "(%((x,y),(u,v)). add (add x y) (add u v)) ((Z,Z),(Z,Z)) = Z" by normalization
 
 lemma "case Z of Z \<Rightarrow> True | S x \<Rightarrow> False" by normalization
 
 lemma "[] @ [] = []" by normalization
-lemma "map f [x,y,z::'x] = [f x, f y, f z]" by normalization rule+
-lemma "[a, b, c] @ xs = a # b # c # xs" by normalization rule+
-lemma "[] @ xs = xs" by normalization rule
-lemma "map (%f. f True) [id, g, Not] = [True, g True, False]" by normalization rule+
+lemma "map f [x,y,z::'x] = [f x, f y, f z]" by normalization
+lemma "[a, b, c] @ xs = a # b # c # xs" by normalization
+lemma "[] @ xs = xs" by normalization
+lemma "map (%f. f True) [id, g, Not] = [True, g True, False]" by normalization
+
+lemma [code nbe]:
+  "eq_class.eq (x :: 'a\<Colon>eq list) x \<longleftrightarrow> True" unfolding eq by rule+
+
 lemma "map (%f. f True) ([id, g, Not] @ fs) = [True, g True, False] @ map (%f. f True) fs" by normalization rule+
-lemma "rev [a, b, c] = [c, b, a]" by normalization rule+
+lemma "rev [a, b, c] = [c, b, a]" by normalization
 normal_form "rev (a#b#cs) = rev cs @ [b, a]"
 normal_form "map (%F. F [a,b,c::'x]) (map map [f,g,h])"
 normal_form "map (%F. F ([a,b,c] @ ds)) (map map ([f,g,h]@fs))"
@@ -79,21 +96,24 @@
   by normalization
 normal_form "case xs of [] \<Rightarrow> True | x#xs \<Rightarrow> False"
 normal_form "map (%x. case x of None \<Rightarrow> False | Some y \<Rightarrow> True) xs = P"
-lemma "let x = y in [x, x] = [y, y]" by normalization rule+
-lemma "Let y (%x. [x,x]) = [y, y]" by normalization rule+
+lemma "let x = y in [x, x] = [y, y]" by normalization
+lemma "Let y (%x. [x,x]) = [y, y]" by normalization
 normal_form "case n of Z \<Rightarrow> True | S x \<Rightarrow> False"
-lemma "(%(x,y). add x y) (S z,S z) = S (add z (S z))" by normalization rule+
+lemma "(%(x,y). add x y) (S z,S z) = S (add z (S z))" by normalization
 normal_form "filter (%x. x) ([True,False,x]@xs)"
 normal_form "filter Not ([True,False,x]@xs)"
 
-lemma "[x,y,z] @ [a,b,c] = [x, y, z, a, b, c]" by normalization rule+
-lemma "(%(xs, ys). xs @ ys) ([a, b, c], [d, e, f]) = [a, b, c, d, e, f]" by normalization rule+
+lemma "[x,y,z] @ [a,b,c] = [x, y, z, a, b, c]" by normalization
+lemma "(%(xs, ys). xs @ ys) ([a, b, c], [d, e, f]) = [a, b, c, d, e, f]" by normalization
 lemma "map (%x. case x of None \<Rightarrow> False | Some y \<Rightarrow> True) [None, Some ()] = [False, True]" by normalization
 
-lemma "last [a, b, c] = c" by normalization rule
-lemma "last ([a, b, c] @ xs) = last (c # xs)" by normalization rule
+lemma "last [a, b, c] = c" by normalization
+lemma "last ([a, b, c] @ xs) = last (c # xs)" by normalization
 
-lemma "(2::int) + 3 - 1 + (- k) * 2 = 4 + - k * 2" by normalization rule
+lemma [code nbe]:
+  "eq_class.eq (x :: int) x \<longleftrightarrow> True" unfolding eq by rule+
+
+lemma "(2::int) + 3 - 1 + (- k) * 2 = 4 + - k * 2" by normalization
 lemma "(-4::int) * 2 = -8" by normalization
 lemma "abs ((-4::int) + 2 * 1) = 2" by normalization
 lemma "(2::int) + 3 = 5" by normalization
@@ -111,10 +131,10 @@
 lemma "(42::rat) / 1704 = 1 / 284 + 3 / 142" by normalization
 normal_form "Suc 0 \<in> set ms"
 
-lemma "f = f" by normalization rule+
-lemma "f x = f x" by normalization rule+
-lemma "(f o g) x = f (g x)" by normalization rule+
-lemma "(f o id) x = f x" by normalization rule+
+lemma "f = f" by normalization
+lemma "f x = f x" by normalization
+lemma "(f o g) x = f (g x)" by normalization
+lemma "(f o id) x = f x" by normalization
 normal_form "(\<lambda>x. x)"
 
 (* Church numerals: *)
--- a/src/Pure/Isar/code.ML	Thu Sep 25 09:28:07 2008 +0200
+++ b/src/Pure/Isar/code.ML	Thu Sep 25 09:28:08 2008 +0200
@@ -9,12 +9,13 @@
 signature CODE =
 sig
   val add_func: thm -> theory -> theory
+  val add_nonlinear_func: thm -> theory -> theory
   val add_liberal_func: thm -> theory -> theory
   val add_default_func: thm -> theory -> theory
   val add_default_func_attr: Attrib.src
   val del_func: thm -> theory -> theory
   val del_funcs: string -> theory -> theory
-  val add_funcl: string * thm list Susp.T -> theory -> theory
+  val add_funcl: string * (thm * bool) list Susp.T -> theory -> theory
   val map_pre: (MetaSimplifier.simpset -> MetaSimplifier.simpset) -> theory -> theory
   val map_post: (MetaSimplifier.simpset -> MetaSimplifier.simpset) -> theory -> theory
   val add_inline: thm -> theory -> theory
@@ -34,7 +35,7 @@
 
   val coregular_algebra: theory -> Sorts.algebra
   val operational_algebra: theory -> (sort -> sort) * Sorts.algebra
-  val these_funcs: theory -> string -> thm list
+  val these_funcs: theory -> string -> (thm * bool) list
   val get_datatype: theory -> string -> ((string * sort) list * (string * typ list) list)
   val get_datatype_of_constr: theory -> string -> string option
   val get_case_data: theory -> string -> (int * string list) option
@@ -115,41 +116,38 @@
 
 (** logical and syntactical specification of executable code **)
 
-(* defining equations with default flag and lazy theorems *)
+(* defining equations with linear flag, default flag and lazy theorems *)
 
 fun pretty_lthms ctxt r = case Susp.peek r
- of SOME thms => map (ProofContext.pretty_thm ctxt) thms
+ of SOME thms => map (ProofContext.pretty_thm ctxt o fst) thms
   | NONE => [Pretty.str "[...]"];
 
 fun certificate thy f r =
   case Susp.peek r
-   of SOME thms => (Susp.value o f thy) thms
+   of SOME thms => (Susp.value o burrow_fst (f thy)) thms
     | NONE => let
         val thy_ref = Theory.check_thy thy;
-      in Susp.delay (fn () => (f (Theory.deref thy_ref) o Susp.force) r) end;
+      in Susp.delay (fn () => (burrow_fst (f (Theory.deref thy_ref)) o Susp.force) r) end;
 
-fun add_drop_redundant verbose thm thms =
+fun add_drop_redundant (thm, linear) thms =
   let
-    fun warn thm' = (if verbose
-      then warning ("Code generator: dropping redundant defining equation\n" ^ Display.string_of_thm thm')
-      else (); true);
     val thy = Thm.theory_of_thm thm;
     val args_of = snd o strip_comb o fst o Logic.dest_equals o Thm.plain_prop_of;
     val args = args_of thm;
-    fun matches [] _ = true
-      | matches (Var _ :: xs) [] = matches xs []
-      | matches (_ :: _) [] = false
-      | matches (x :: xs) (y :: ys) = Pattern.matches thy (x, y) andalso matches xs ys;
-    fun drop thm' = matches args (args_of thm') andalso warn thm';
-  in thm :: filter_out drop thms end;
+    fun matches_args args' = length args <= length args' andalso
+      Pattern.matchess thy (args, curry Library.take (length args) args');
+    fun drop (thm', _) = if matches_args (args_of thm') then 
+      (warning ("Code generator: dropping redundant defining equation\n" ^ Display.string_of_thm thm'); true)
+      else false;
+  in (thm, linear) :: filter_out drop thms end;
 
-fun add_thm _ thm (false, thms) = (false, Susp.value (add_drop_redundant true thm (Susp.force thms)))
-  | add_thm true thm (true, thms) = (true, Susp.value (Susp.force thms @ [thm]))
+fun add_thm _ thm (false, thms) = (false, Susp.map_force (add_drop_redundant thm) thms)
+  | add_thm true thm (true, thms) = (true, Susp.map_force (fn thms => thms @ [thm]) thms)
   | add_thm false thm (true, thms) = (false, Susp.value [thm]);
 
 fun add_lthms lthms _ = (false, lthms);
 
-fun del_thm thm = apsnd (Susp.value o remove Thm.eq_thm_prop thm o Susp.force);
+fun del_thm thm = (apsnd o Susp.map_force) (remove (eq_fst Thm.eq_thm_prop) (thm, true));
 
 fun merge_defthms ((true, _), defthms2) = defthms2
   | merge_defthms (defthms1 as (false, _), (true, _)) = defthms1
@@ -173,7 +171,7 @@
 (* specification data *)
 
 datatype spec = Spec of {
-  funcs: (bool * thm list Susp.T) Symtab.table,
+  funcs: (bool * (thm * bool) list Susp.T) Symtab.table,
   dtyps: ((string * sort) list * (string * typ list) list) Symtab.table,
   cases: (int * string list) Symtab.table * unit Symtab.table
 };
@@ -479,7 +477,7 @@
     val funcs = classparams
       |> map_filter (fn c => try (AxClass.param_of_inst thy) (c, tyco))
       |> map (Symtab.lookup ((the_funcs o the_exec) thy))
-      |> (map o Option.map) (Susp.force o snd)
+      |> (map o Option.map) (map fst o Susp.force o snd)
       |> maps these
       |> map (Thm.transfer thy);
     fun sorts_of [Type (_, tys)] = map (snd o dest_TVar) tys
@@ -600,9 +598,10 @@
         | NONE => check_typ_fun (c, thm);
   in check_typ (const_of_func thy thm, thm) end;
 
-val mk_func = Code_Unit.error_thm (assert_func_typ o Code_Unit.mk_func);
-val mk_liberal_func = Code_Unit.warning_thm (assert_func_typ o Code_Unit.mk_func);
-val mk_default_func = Code_Unit.try_thm (assert_func_typ o Code_Unit.mk_func);
+fun mk_func linear = Code_Unit.error_thm (assert_func_typ o Code_Unit.mk_func linear);
+val mk_liberal_func = Code_Unit.warning_thm (assert_func_typ o Code_Unit.mk_func true);
+val mk_syntactic_func = Code_Unit.warning_thm (assert_func_typ o Code_Unit.mk_func false);
+val mk_default_func = Code_Unit.try_thm (assert_func_typ o Code_Unit.mk_func true);
 
 end; (*local*)
 
@@ -641,8 +640,8 @@
 
 val is_undefined = Symtab.defined o snd o the_cases o the_exec;
 
-fun gen_add_func strict default thm thy =
-  case (if strict then SOME o mk_func else mk_liberal_func) thm
+fun gen_add_func linear strict default thm thy =
+  case (if strict then SOME o mk_func linear else mk_liberal_func) thm
    of SOME func =>
         let
           val c = const_of_func thy func;
@@ -656,15 +655,16 @@
             else ();
         in
           (map_exec_purge (SOME [c]) o map_funcs) (Symtab.map_default
-            (c, (true, Susp.value [])) (add_thm default func)) thy
+            (c, (true, Susp.value [])) (add_thm default (func, linear))) thy
         end
     | NONE => thy;
 
-val add_func = gen_add_func true false;
-val add_liberal_func = gen_add_func false false;
-val add_default_func = gen_add_func false true;
+val add_func = gen_add_func true true false;
+val add_liberal_func = gen_add_func true false false;
+val add_default_func = gen_add_func true false true;
+val add_nonlinear_func = gen_add_func false true false;
 
-fun del_func thm thy = case mk_liberal_func thm
+fun del_func thm thy = case mk_syntactic_func thm
  of SOME func => let
         val c = const_of_func thy func;
       in map_exec_purge (SOME [c]) (map_funcs
@@ -762,6 +762,7 @@
   in
     TypeInterpretation.init
     #> add_del_attribute ("func", (add_func, del_func))
+    #> add_simple_attribute ("nbe", add_nonlinear_func)
     #> add_del_attribute ("inline", (add_inline, del_inline))
     #> add_del_attribute ("post", (add_post, del_post))
   end));
@@ -801,7 +802,7 @@
     thms
     |> apply_functrans thy
     |> map (Code_Unit.rewrite_func pre)
-    (*FIXME - must check gere: rewrite rule, defining equation, proper constant *)
+    (*FIXME - must check here: rewrite rule, defining equation, proper constant *)
     |> map (AxClass.unoverload thy)
     |> common_typ_funcs
   end;
@@ -847,24 +848,24 @@
   Symtab.lookup ((the_funcs o the_exec) thy) const
   |> Option.map (Susp.force o snd)
   |> these
-  |> map (Thm.transfer thy);
+  |> (map o apfst) (Thm.transfer thy);
 
 in
 
 fun these_funcs thy const =
   let
     fun drop_refl thy = filter_out (is_equal o Term.fast_term_ord o Logic.dest_equals
-      o ObjectLogic.drop_judgment thy o Thm.plain_prop_of);
+      o ObjectLogic.drop_judgment thy o Thm.plain_prop_of o fst);
   in
     get_funcs thy const
-    |> preprocess thy
+    |> burrow_fst (preprocess thy)
     |> drop_refl thy
   end;
 
 fun default_typ thy c = case default_typ_proto thy c
  of SOME ty => Code_Unit.typscheme thy (c, ty)
   | NONE => (case get_funcs thy c
-     of thm :: _ => snd (Code_Unit.head_func (AxClass.unoverload thy thm))
+     of (thm, _) :: _ => snd (Code_Unit.head_func (AxClass.unoverload thy thm))
       | [] => Code_Unit.typscheme thy (c, Sign.the_const_type thy c));
 
 end; (*local*)
--- a/src/Tools/code/code_funcgr.ML	Thu Sep 25 09:28:07 2008 +0200
+++ b/src/Tools/code/code_funcgr.ML	Thu Sep 25 09:28:08 2008 +0200
@@ -9,7 +9,7 @@
 signature CODE_FUNCGR =
 sig
   type T
-  val funcs: T -> string -> thm list
+  val funcs: T -> string -> (thm * bool) list
   val typ: T -> string -> (string * sort) list * typ
   val all: T -> string list
   val pretty: theory -> T -> Pretty.T
@@ -24,7 +24,7 @@
 
 (** the graph type **)
 
-type T = (((string * sort) list * typ) * thm list) Graph.T;
+type T = (((string * sort) list * typ) * (thm * bool) list) Graph.T;
 
 fun funcs funcgr =
   these o Option.map snd o try (Graph.get_node funcgr);
@@ -41,7 +41,7 @@
   |> map (fn (s, thms) =>
        (Pretty.block o Pretty.fbreaks) (
          Pretty.str s
-         :: map Display.pretty_thm thms
+         :: map (Display.pretty_thm o fst) thms
        ))
   |> Pretty.chunks;
 
@@ -57,7 +57,7 @@
   | consts_of (const, thms as _ :: _) = 
       let
         fun the_const (c, _) = if c = const then I else insert (op =) c
-      in fold_consts the_const thms [] end;
+      in fold_consts the_const (map fst thms) [] end;
 
 fun insts_of thy algebra tys sorts =
   let
@@ -100,19 +100,19 @@
 fun resort_funcss thy algebra funcgr =
   let
     val typ_funcgr = try (fst o Graph.get_node funcgr);
-    val resort_dep = apsnd (resort_thms thy algebra typ_funcgr);
+    val resort_dep = (apsnd o burrow_fst) (resort_thms thy algebra typ_funcgr);
     fun resort_rec typ_of (c, []) = (true, (c, []))
-      | resort_rec typ_of (c, thms as thm :: _) = if is_some (AxClass.inst_of_param thy c)
+      | resort_rec typ_of (c, thms as (thm, _) :: _) = if is_some (AxClass.inst_of_param thy c)
           then (true, (c, thms))
           else let
             val (_, (vs, ty)) = Code_Unit.head_func thm;
-            val thms' as thm' :: _ = resort_thms thy algebra typ_of thms
+            val thms' as (thm', _) :: _ = burrow_fst (resort_thms thy algebra typ_of) thms
             val (_, (vs', ty')) = Code_Unit.head_func thm'; (*FIXME simplify check*)
           in (Sign.typ_equiv thy (ty, ty'), (c, thms')) end;
     fun resort_recs funcss =
       let
         fun typ_of c = case these (AList.lookup (op =) funcss c)
-         of thm :: _ => (SOME o snd o Code_Unit.head_func) thm
+         of (thm, _) :: _ => (SOME o snd o Code_Unit.head_func) thm
           | [] => NONE;
         val (unchangeds, funcss') = split_list (map (resort_rec typ_of) funcss);
         val unchanged = fold (fn x => fn y => x andalso y) unchangeds true;
@@ -158,8 +158,8 @@
     |> pair (SOME const)
   else let
     val thms = Code.these_funcs thy const
-      |> Code_Unit.norm_args
-      |> Code_Unit.norm_varnames Code_Name.purify_tvar Code_Name.purify_var;
+      |> burrow_fst Code_Unit.norm_args
+      |> burrow_fst (Code_Unit.norm_varnames Code_Name.purify_tvar Code_Name.purify_var);
     val rhs = consts_of (const, thms);
   in
     auxgr
@@ -182,14 +182,14 @@
       |> resort_funcss thy algebra funcgr
       |> filter_out (can (Graph.get_node funcgr) o fst);
     fun typ_func c [] = Code.default_typ thy c
-      | typ_func c (thms as thm :: _) = (snd o Code_Unit.head_func) thm;
+      | typ_func c (thms as (thm, _) :: _) = (snd o Code_Unit.head_func) thm;
     fun add_funcs (const, thms) =
       Graph.new_node (const, (typ_func const thms, thms));
     fun add_deps (funcs as (const, thms)) funcgr =
       let
         val deps = consts_of funcs;
         val insts = instances_of_consts thy algebra funcgr
-          (fold_consts (insert (op =)) thms []);
+          (fold_consts (insert (op =)) (map fst thms) []);
       in
         funcgr
         |> ensure_consts thy algebra insts
@@ -278,16 +278,15 @@
 
 (** diagnostic commands **)
 
-fun code_depgr thy [] = make thy []
-  | code_depgr thy consts =
-      let
-        val gr = make thy consts;
-        val select = Graph.all_succs gr consts;
-      in
-        gr
-        |> Graph.subgraph (member (op =) select) 
-        |> Graph.map_nodes ((apsnd o map) (AxClass.overload thy))
-      end;
+fun code_depgr thy consts =
+  let
+    val gr = make thy consts;
+    val select = Graph.all_succs gr consts;
+  in
+    gr
+    |> not (null consts) ? Graph.subgraph (member (op =) select) 
+    |> Graph.map_nodes ((apsnd o map o apfst) (AxClass.overload thy))
+  end;
 
 fun code_thms thy = Pretty.writeln o pretty thy o code_depgr thy;
 
--- a/src/Tools/code/code_haskell.ML	Thu Sep 25 09:28:07 2008 +0200
+++ b/src/Tools/code/code_haskell.ML	Thu Sep 25 09:28:08 2008 +0200
@@ -153,10 +153,11 @@
               )
             ]
           end
-      | pr_stmt (name, Code_Thingol.Fun ((vs, ty), eqs)) =
+      | pr_stmt (name, Code_Thingol.Fun ((vs, ty), raw_eqs)) =
           let
+            val eqs = filter (snd o snd) raw_eqs;
             val tyvars = intro_vars (map fst vs) init_syms;
-            fun pr_eq ((ts, t), thm) =
+            fun pr_eq ((ts, t), (thm, _)) =
               let
                 val consts = map_filter
                   (fn c => if (is_some o syntax_const) c
@@ -248,7 +249,7 @@
       | pr_stmt (_, Code_Thingol.Classinst ((class, (tyco, vs)), (_, classparam_insts))) =
           let
             val tyvars = intro_vars (map fst vs) init_syms;
-            fun pr_instdef ((classparam, c_inst), thm) =
+            fun pr_instdef ((classparam, c_inst), (thm, _)) =
               semicolon [
                 (str o classparam_name class) classparam,
                 str "=",
--- a/src/Tools/code/code_ml.ML	Thu Sep 25 09:28:07 2008 +0200
+++ b/src/Tools/code/code_ml.ML	Thu Sep 25 09:28:08 2008 +0200
@@ -27,12 +27,12 @@
 val target_OCaml = "OCaml";
 
 datatype ml_stmt =
-    MLFuns of (string * (typscheme * ((iterm list * iterm) * thm) list)) list
+    MLFuns of (string * (typscheme * ((iterm list * iterm) * (thm * bool)) list)) list
   | MLDatas of (string * ((vname * sort) list * (string * itype list) list)) list
   | MLClass of string * (vname * ((class * string) list * (string * itype) list))
   | MLClassinst of string * ((class * (string * (vname * sort) list))
         * ((class * (string * (string * dict list list))) list
-      * ((string * const) * thm) list));
+      * ((string * const) * (thm * bool)) list));
 
 fun stmt_names_of (MLFuns fs) = map fst fs
   | stmt_names_of (MLDatas ds) = map fst ds
@@ -192,7 +192,7 @@
                     val vs_dict = filter_out (null o snd) vs;
                     val shift = if null eqs' then I else
                       map (Pretty.block o single o Pretty.block o single);
-                    fun pr_eq definer ((ts, t), thm) =
+                    fun pr_eq definer ((ts, t), (thm, _)) =
                       let
                         val consts = map_filter
                           (fn c => if (is_some o syntax_const) c
@@ -299,7 +299,7 @@
                 str "=",
                 pr_dicts NOBR [DictConst dss]
               ];
-            fun pr_classparam ((classparam, c_inst), thm) =
+            fun pr_classparam ((classparam, c_inst), (thm, _)) =
               concat [
                 (str o pr_label_classparam) classparam,
                 str "=",
@@ -453,7 +453,7 @@
       in map (lookup_var vars') fished3 end;
     fun pr_stmt (MLFuns (funns as funn :: funns')) =
           let
-            fun pr_eq ((ts, t), thm) =
+            fun pr_eq ((ts, t), (thm, _)) =
               let
                 val consts = map_filter
                   (fn c => if (is_some o syntax_const) c
@@ -481,7 +481,7 @@
                       @@ str exc_str
                     )
                   end
-              | pr_eqs _ _ [((ts, t), thm)] =
+              | pr_eqs _ _ [((ts, t), (thm, _))] =
                   let
                     val consts = map_filter
                       (fn c => if (is_some o syntax_const) c
@@ -611,7 +611,7 @@
                 str "=",
                 pr_dicts NOBR [DictConst dss]
               ];
-            fun pr_classparam_inst ((classparam, c_inst), thm) =
+            fun pr_classparam_inst ((classparam, c_inst), (thm, _)) =
               concat [
                 (str o deresolve) classparam,
                 str "=",
@@ -727,7 +727,7 @@
       fold_map
         (fn (name, Code_Thingol.Fun stmt) =>
               map_nsp_fun_yield (mk_name_stmt false name) #>>
-                rpair (name, stmt)
+                rpair (name, stmt |> apsnd (filter (snd o snd)))
           | (name, _) =>
               error ("Function block containing illegal statement: " ^ labelled_name name)
         ) stmts
@@ -895,7 +895,8 @@
         val _ = if Code_Thingol.contains_dictvar t then
           error "Term to be evaluated constains free dictionaries" else ();
         val program' = program
-          |> Graph.new_node (Code_Name.value_name, Code_Thingol.Fun (([], ty), [(([], t), Drule.dummy_thm)]))
+          |> Graph.new_node (Code_Name.value_name,
+              Code_Thingol.Fun (([], ty), [(([], t), (Drule.dummy_thm, true))]))
           |> fold (curry Graph.add_edge Code_Name.value_name) deps;
         val (value_code, [value_name']) = ml_code_of thy program' [Code_Name.value_name];
         val sml_code = "let\n" ^ value_code ^ "\nin " ^ value_name'
--- a/src/Tools/code/code_thingol.ML	Thu Sep 25 09:28:07 2008 +0200
+++ b/src/Tools/code/code_thingol.ML	Thu Sep 25 09:28:08 2008 +0200
@@ -62,7 +62,7 @@
 
   datatype stmt =
       NoStmt
-    | Fun of typscheme * ((iterm list * iterm) * thm) list
+    | Fun of typscheme * ((iterm list * iterm) * (thm * bool)) list
     | Datatype of (vname * sort) list * (string * itype list) list
     | Datatypecons of string
     | Class of vname * ((class * string) list * (string * itype) list)
@@ -70,7 +70,7 @@
     | Classparam of class
     | Classinst of (class * (string * (vname * sort) list))
           * ((class * (string * (string * dict list list))) list
-        * ((string * const) * thm) list);
+        * ((string * const) * (thm * bool)) list);
   type program = stmt Graph.T;
   val empty_funs: program -> string list;
   val map_terms_bottom_up: (iterm -> iterm) -> iterm -> iterm;
@@ -258,7 +258,7 @@
 type typscheme = (vname * sort) list * itype;
 datatype stmt =
     NoStmt
-  | Fun of typscheme * ((iterm list * iterm) * thm) list
+  | Fun of typscheme * ((iterm list * iterm) * (thm * bool)) list
   | Datatype of (vname * sort) list * (string * itype list) list
   | Datatypecons of string
   | Class of vname * ((class * string) list * (string * itype) list)
@@ -266,7 +266,7 @@
   | Classparam of class
   | Classinst of (class * (string * (vname * sort) list))
         * ((class * (string * (string * dict list list))) list
-      * ((string * const) * thm) list);
+      * ((string * const) * (thm * bool)) list);
 
 type program = stmt Graph.T;
 
@@ -423,14 +423,14 @@
           fold_map (ensure_classrel thy algbr funcgr) classrels
           #>> (fn classrels => DictVar (classrels, (unprefix "'" v, (k, length sort))))
   in fold_map mk_dict typargs end
-and exprgen_eq thy algbr funcgr thm =
+and exprgen_eq thy algbr funcgr (thm, linear) =
   let
     val (args, rhs) = (apfst (snd o strip_comb) o Logic.dest_equals
       o Logic.unvarify o prop_of) thm;
   in
     fold_map (exprgen_term thy algbr funcgr (SOME thm)) args
     ##>> exprgen_term thy algbr funcgr (SOME thm) rhs
-    #>> rpair thm
+    #>> rpair (thm, linear)
   end
 and ensure_inst thy (algbr as (_, algebra)) funcgr (class, tyco) =
   let
@@ -457,7 +457,7 @@
       in
         ensure_const thy algbr funcgr c
         ##>> exprgen_const thy algbr funcgr (SOME thm) c_ty
-        #>> (fn (c, IConst c_inst) => ((c, c_inst), thm))
+        #>> (fn (c, IConst c_inst) => ((c, c_inst), (thm, true)))
       end;
     val stmt_inst =
       ensure_class thy algbr funcgr class
@@ -485,7 +485,7 @@
         val ty = Logic.unvarifyT raw_ty;
         val thms = if (null o Term.typ_tfrees) ty orelse (null o fst o strip_type) ty
           then raw_thms
-          else map (Code_Unit.expand_eta 1) raw_thms;
+          else (map o apfst) (Code_Unit.expand_eta 1) raw_thms;
       in
         trns
         |> fold_map (exprgen_tyvar_sort thy algbr funcgr) vs
@@ -642,7 +642,7 @@
       fold_map (exprgen_tyvar_sort thy algbr funcgr) vs
       ##>> exprgen_typ thy algbr funcgr ty
       ##>> exprgen_term thy algbr funcgr NONE t
-      #>> (fn ((vs, ty), t) => Fun ((vs, ty), [(([], t), Drule.dummy_thm)]));
+      #>> (fn ((vs, ty), t) => Fun ((vs, ty), [(([], t), (Drule.dummy_thm, true))]));
     fun term_value (dep, program1) =
       let
         val Fun ((vs, ty), [(([], t), _)]) =
--- a/src/Tools/nbe.ML	Thu Sep 25 09:28:07 2008 +0200
+++ b/src/Tools/nbe.ML	Thu Sep 25 09:28:08 2008 +0200
@@ -15,10 +15,11 @@
     | Free of string * Univ list             (*free (uninterpreted) variables*)
     | DFree of string * int                  (*free (uninterpreted) dictionary parameters*)
     | BVar of int * Univ list
-    | Abs of (int * (Univ list -> Univ)) * Univ list;
+    | Abs of (int * (Univ list -> Univ)) * Univ list
   val apps: Univ -> Univ list -> Univ        (*explicit applications*)
   val abss: int -> (Univ list -> Univ) -> Univ
                                              (*abstractions as closures*)
+  val same: Univ -> Univ -> bool
 
   val univs_ref: (unit -> Univ list -> Univ list) option ref
   val trace: bool ref
@@ -63,6 +64,13 @@
   | Abs of (int * (Univ list -> Univ)) * Univ list
                                        (*abstractions as closures*);
 
+fun same (Const (k, xs)) (Const (l, ys)) = k = l andalso sames xs ys
+  | same (Free (s, xs)) (Free (t, ys)) = s = t andalso sames xs ys
+  | same (DFree (s, k)) (DFree (t, l)) = s = t andalso k = l
+  | same (BVar (k, xs)) (BVar (l, ys)) = k = l andalso sames xs ys
+  | same _ _ = false
+and sames xs ys = length xs = length ys andalso forall (uncurry same) (xs ~~ ys);
+
 (* constructor functions *)
 
 fun abss n f = Abs ((n, f), []);
@@ -92,6 +100,11 @@
 fun ml_Let d e = "let\n" ^ d ^ " in " ^ e ^ " end";
 fun ml_as v t = "(" ^ v ^ " as " ^ t ^ ")";
 
+fun ml_and [] = "true"
+  | ml_and [x] = x
+  | ml_and xs = "(" ^ space_implode " andalso " xs ^ ")";
+fun ml_if b x y = "(if " ^ b ^ " then " ^ x ^ " else " ^ y ^ ")";
+
 fun ml_list es = "[" ^ commas es ^ "]";
 
 fun ml_fundefs ([(name, [([], e)])]) =
@@ -113,11 +126,12 @@
 val univs_ref = ref (NONE : (unit -> Univ list -> Univ list) option);
 
 local
-  val prefix =          "Nbe.";
-  val name_ref =        prefix ^ "univs_ref";
-  val name_const =      prefix ^ "Const";
-  val name_abss =       prefix ^ "abss";
-  val name_apps =       prefix ^ "apps";
+  val prefix =      "Nbe.";
+  val name_ref =    prefix ^ "univs_ref";
+  val name_const =  prefix ^ "Const";
+  val name_abss =   prefix ^ "abss";
+  val name_apps =   prefix ^ "apps";
+  val name_same =   prefix ^ "same";
 in
 
 val univs_cookie = (name_ref, univs_ref);
@@ -141,6 +155,8 @@
 fun nbe_abss 0 f = f `$` ml_list []
   | nbe_abss n f = name_abss `$$` [string_of_int n, f];
 
+fun nbe_same v1 v2 = "(" ^ name_same ^ " " ^ nbe_bound v1 ^ " " ^ nbe_bound v2 ^ ")";
+
 end;
 
 open Basic_Code_Thingol;
@@ -173,34 +189,62 @@
       | assemble_idict (DictVar (supers, (v, (n, _)))) =
           fold_rev (fn super => assemble_constapp super [] o single) supers (nbe_dict v n);
 
-    fun assemble_iterm match_cont constapp =
+    fun assemble_iterm constapp =
       let
-        fun of_iterm t =
+        fun of_iterm match_cont t =
           let
             val (t', ts) = Code_Thingol.unfold_app t
-          in of_iapp t' (fold_rev (cons o of_iterm) ts []) end
-        and of_iapp (IConst (c, (dss, _))) ts = constapp c dss ts
-          | of_iapp (IVar v) ts = nbe_apps (nbe_bound v) ts
-          | of_iapp ((v, _) `|-> t) ts =
-              nbe_apps (nbe_abss 1 (ml_abs (ml_list [nbe_bound v]) (of_iterm t))) ts
-          | of_iapp (ICase (((t, _), cs), t0)) ts =
-              nbe_apps (ml_cases (of_iterm t) (map (pairself of_iterm) cs
-                @ [("_", case match_cont of SOME s => s | NONE => of_iterm t0)])) ts
+          in of_iapp match_cont t' (fold_rev (cons o of_iterm NONE) ts []) end
+        and of_iapp match_cont (IConst (c, (dss, _))) ts = constapp c dss ts
+          | of_iapp match_cont (IVar v) ts = nbe_apps (nbe_bound v) ts
+          | of_iapp match_cont ((v, _) `|-> t) ts =
+              nbe_apps (nbe_abss 1 (ml_abs (ml_list [nbe_bound v]) (of_iterm NONE t))) ts
+          | of_iapp match_cont (ICase (((t, _), cs), t0)) ts =
+              nbe_apps (ml_cases (of_iterm NONE t)
+                (map (fn (p, t) => (of_iterm NONE p, of_iterm match_cont t)) cs
+                  @ [("_", case match_cont of SOME s => s | NONE => of_iterm NONE t0)])) ts
       in of_iterm end;
 
+    fun subst_nonlin_vars args =
+      let
+        val vs = (fold o Code_Thingol.fold_varnames)
+          (fn v => AList.map_default (op =) (v, 0) (curry (op +) 1)) args [];
+        val names = Name.make_context (map fst vs);
+        fun declare v k ctxt = let val vs = Name.invents ctxt v k
+          in (vs, fold Name.declare vs ctxt) end;
+        val (vs_renames, _) = fold_map (fn (v, k) => if k > 1
+          then declare v (k - 1) #>> (fn vs => (v, vs))
+          else pair (v, [])) vs names;
+        val samepairs = maps (fn (v, vs) => map (pair v) vs) vs_renames;
+        fun subst_vars (t as IConst _) samepairs = (t, samepairs)
+          | subst_vars (t as IVar v) samepairs = (case AList.lookup (op =) samepairs v
+             of SOME v' => (IVar v', AList.delete (op =) v samepairs)
+              | NONE => (t, samepairs))
+          | subst_vars (t1 `$ t2) samepairs = samepairs
+              |> subst_vars t1
+              ||>> subst_vars t2
+              |>> (op `$)
+          | subst_vars (ICase (_, t)) samepairs = subst_vars t samepairs;
+        val (args', _) = fold_map subst_vars args samepairs;
+      in (samepairs, args') end;
+
     fun assemble_eqn c dicts default_args (i, (args, rhs)) =
       let
         val is_eval = (c = "");
         val default_rhs = nbe_apps_local (i+1) c (dicts @ default_args);
         val match_cont = if is_eval then NONE else SOME default_rhs;
-        val assemble_arg = assemble_iterm NONE
-          (fn c => fn _ => fn ts => nbe_apps_constr idx_of c ts);
-        val assemble_rhs = assemble_iterm match_cont assemble_constapp;
+        val assemble_arg = assemble_iterm
+          (fn c => fn _ => fn ts => nbe_apps_constr idx_of c ts) NONE;
+        val assemble_rhs = assemble_iterm assemble_constapp match_cont ;
+        val (samepairs, args') = subst_nonlin_vars args;
+        val s_args = map assemble_arg args';
+        val s_rhs = if null samepairs then assemble_rhs rhs
+          else ml_if (ml_and (map (uncurry nbe_same) samepairs))
+            (assemble_rhs rhs) default_rhs;
         val eqns = if is_eval then
-            [([ml_list (rev (dicts @ map assemble_arg args))], assemble_rhs rhs)]
+            [([ml_list (rev (dicts @ s_args))], s_rhs)]
           else
-            [([ml_list (rev (dicts @ map2 ml_as default_args
-                (map assemble_arg args)))], assemble_rhs rhs),
+            [([ml_list (rev (dicts @ map2 ml_as default_args s_args))], s_rhs),
               ([ml_list (rev (dicts @ default_args))], default_rhs)]
       in (nbe_fun i c, eqns) end;