src/Tools/nbe.ML
changeset 28350 715163ec93c0
parent 28337 93964076e7b8
child 28423 9fc3befd8191
--- a/src/Tools/nbe.ML	Thu Sep 25 09:28:07 2008 +0200
+++ b/src/Tools/nbe.ML	Thu Sep 25 09:28:08 2008 +0200
@@ -15,10 +15,11 @@
     | Free of string * Univ list             (*free (uninterpreted) variables*)
     | DFree of string * int                  (*free (uninterpreted) dictionary parameters*)
     | BVar of int * Univ list
-    | Abs of (int * (Univ list -> Univ)) * Univ list;
+    | Abs of (int * (Univ list -> Univ)) * Univ list
   val apps: Univ -> Univ list -> Univ        (*explicit applications*)
   val abss: int -> (Univ list -> Univ) -> Univ
                                              (*abstractions as closures*)
+  val same: Univ -> Univ -> bool
 
   val univs_ref: (unit -> Univ list -> Univ list) option ref
   val trace: bool ref
@@ -63,6 +64,13 @@
   | Abs of (int * (Univ list -> Univ)) * Univ list
                                        (*abstractions as closures*);
 
+fun same (Const (k, xs)) (Const (l, ys)) = k = l andalso sames xs ys
+  | same (Free (s, xs)) (Free (t, ys)) = s = t andalso sames xs ys
+  | same (DFree (s, k)) (DFree (t, l)) = s = t andalso k = l
+  | same (BVar (k, xs)) (BVar (l, ys)) = k = l andalso sames xs ys
+  | same _ _ = false
+and sames xs ys = length xs = length ys andalso forall (uncurry same) (xs ~~ ys);
+
 (* constructor functions *)
 
 fun abss n f = Abs ((n, f), []);
@@ -92,6 +100,11 @@
 fun ml_Let d e = "let\n" ^ d ^ " in " ^ e ^ " end";
 fun ml_as v t = "(" ^ v ^ " as " ^ t ^ ")";
 
+fun ml_and [] = "true"
+  | ml_and [x] = x
+  | ml_and xs = "(" ^ space_implode " andalso " xs ^ ")";
+fun ml_if b x y = "(if " ^ b ^ " then " ^ x ^ " else " ^ y ^ ")";
+
 fun ml_list es = "[" ^ commas es ^ "]";
 
 fun ml_fundefs ([(name, [([], e)])]) =
@@ -113,11 +126,12 @@
 val univs_ref = ref (NONE : (unit -> Univ list -> Univ list) option);
 
 local
-  val prefix =          "Nbe.";
-  val name_ref =        prefix ^ "univs_ref";
-  val name_const =      prefix ^ "Const";
-  val name_abss =       prefix ^ "abss";
-  val name_apps =       prefix ^ "apps";
+  val prefix =      "Nbe.";
+  val name_ref =    prefix ^ "univs_ref";
+  val name_const =  prefix ^ "Const";
+  val name_abss =   prefix ^ "abss";
+  val name_apps =   prefix ^ "apps";
+  val name_same =   prefix ^ "same";
 in
 
 val univs_cookie = (name_ref, univs_ref);
@@ -141,6 +155,8 @@
 fun nbe_abss 0 f = f `$` ml_list []
   | nbe_abss n f = name_abss `$$` [string_of_int n, f];
 
+fun nbe_same v1 v2 = "(" ^ name_same ^ " " ^ nbe_bound v1 ^ " " ^ nbe_bound v2 ^ ")";
+
 end;
 
 open Basic_Code_Thingol;
@@ -173,34 +189,62 @@
       | assemble_idict (DictVar (supers, (v, (n, _)))) =
           fold_rev (fn super => assemble_constapp super [] o single) supers (nbe_dict v n);
 
-    fun assemble_iterm match_cont constapp =
+    fun assemble_iterm constapp =
       let
-        fun of_iterm t =
+        fun of_iterm match_cont t =
           let
             val (t', ts) = Code_Thingol.unfold_app t
-          in of_iapp t' (fold_rev (cons o of_iterm) ts []) end
-        and of_iapp (IConst (c, (dss, _))) ts = constapp c dss ts
-          | of_iapp (IVar v) ts = nbe_apps (nbe_bound v) ts
-          | of_iapp ((v, _) `|-> t) ts =
-              nbe_apps (nbe_abss 1 (ml_abs (ml_list [nbe_bound v]) (of_iterm t))) ts
-          | of_iapp (ICase (((t, _), cs), t0)) ts =
-              nbe_apps (ml_cases (of_iterm t) (map (pairself of_iterm) cs
-                @ [("_", case match_cont of SOME s => s | NONE => of_iterm t0)])) ts
+          in of_iapp match_cont t' (fold_rev (cons o of_iterm NONE) ts []) end
+        and of_iapp match_cont (IConst (c, (dss, _))) ts = constapp c dss ts
+          | of_iapp match_cont (IVar v) ts = nbe_apps (nbe_bound v) ts
+          | of_iapp match_cont ((v, _) `|-> t) ts =
+              nbe_apps (nbe_abss 1 (ml_abs (ml_list [nbe_bound v]) (of_iterm NONE t))) ts
+          | of_iapp match_cont (ICase (((t, _), cs), t0)) ts =
+              nbe_apps (ml_cases (of_iterm NONE t)
+                (map (fn (p, t) => (of_iterm NONE p, of_iterm match_cont t)) cs
+                  @ [("_", case match_cont of SOME s => s | NONE => of_iterm NONE t0)])) ts
       in of_iterm end;
 
+    fun subst_nonlin_vars args =
+      let
+        val vs = (fold o Code_Thingol.fold_varnames)
+          (fn v => AList.map_default (op =) (v, 0) (curry (op +) 1)) args [];
+        val names = Name.make_context (map fst vs);
+        fun declare v k ctxt = let val vs = Name.invents ctxt v k
+          in (vs, fold Name.declare vs ctxt) end;
+        val (vs_renames, _) = fold_map (fn (v, k) => if k > 1
+          then declare v (k - 1) #>> (fn vs => (v, vs))
+          else pair (v, [])) vs names;
+        val samepairs = maps (fn (v, vs) => map (pair v) vs) vs_renames;
+        fun subst_vars (t as IConst _) samepairs = (t, samepairs)
+          | subst_vars (t as IVar v) samepairs = (case AList.lookup (op =) samepairs v
+             of SOME v' => (IVar v', AList.delete (op =) v samepairs)
+              | NONE => (t, samepairs))
+          | subst_vars (t1 `$ t2) samepairs = samepairs
+              |> subst_vars t1
+              ||>> subst_vars t2
+              |>> (op `$)
+          | subst_vars (ICase (_, t)) samepairs = subst_vars t samepairs;
+        val (args', _) = fold_map subst_vars args samepairs;
+      in (samepairs, args') end;
+
     fun assemble_eqn c dicts default_args (i, (args, rhs)) =
       let
         val is_eval = (c = "");
         val default_rhs = nbe_apps_local (i+1) c (dicts @ default_args);
         val match_cont = if is_eval then NONE else SOME default_rhs;
-        val assemble_arg = assemble_iterm NONE
-          (fn c => fn _ => fn ts => nbe_apps_constr idx_of c ts);
-        val assemble_rhs = assemble_iterm match_cont assemble_constapp;
+        val assemble_arg = assemble_iterm
+          (fn c => fn _ => fn ts => nbe_apps_constr idx_of c ts) NONE;
+        val assemble_rhs = assemble_iterm assemble_constapp match_cont ;
+        val (samepairs, args') = subst_nonlin_vars args;
+        val s_args = map assemble_arg args';
+        val s_rhs = if null samepairs then assemble_rhs rhs
+          else ml_if (ml_and (map (uncurry nbe_same) samepairs))
+            (assemble_rhs rhs) default_rhs;
         val eqns = if is_eval then
-            [([ml_list (rev (dicts @ map assemble_arg args))], assemble_rhs rhs)]
+            [([ml_list (rev (dicts @ s_args))], s_rhs)]
           else
-            [([ml_list (rev (dicts @ map2 ml_as default_args
-                (map assemble_arg args)))], assemble_rhs rhs),
+            [([ml_list (rev (dicts @ map2 ml_as default_args s_args))], s_rhs),
               ([ml_list (rev (dicts @ default_args))], default_rhs)]
       in (nbe_fun i c, eqns) end;