using integers for pattern matching
authorhaftmann
Wed, 13 Feb 2008 09:35:33 +0100
changeset 26064 65585de05a66
parent 26063 b2862698dc79
child 26065 d80a49f51b94
using integers for pattern matching
src/Tools/nbe.ML
--- a/src/Tools/nbe.ML	Wed Feb 13 09:35:32 2008 +0100
+++ b/src/Tools/nbe.ML	Wed Feb 13 09:35:33 2008 +0100
@@ -11,7 +11,7 @@
   val norm_term: theory -> term -> term
 
   datatype Univ =
-      Const of string * Univ list            (*named (uninterpreted) constants*)
+      Const of int * Univ list               (*named (uninterpreted) constants*)
     | Free of string * Univ list             (*free (uninterpreted) variables*)
     | DFree of string * int                  (*free (uninterpreted) dictionary parameters*)
     | BVar of int * Univ list
@@ -56,7 +56,7 @@
 *)
 
 datatype Univ =
-    Const of string * Univ list        (*named (uninterpreted) constants*)
+    Const of int * Univ list           (*named (uninterpreted) constants*)
   | Free of string * Univ list         (*free variables*)
   | DFree of string * int              (*free (uninterpreted) dictionary parameters*)
   | BVar of int * Univ list            (*bound named variables*)
@@ -76,10 +76,6 @@
   | apps (Free (name, xs)) ys = Free (name, ys @ xs)
   | apps (BVar (name, xs)) ys = BVar (name, ys @ xs);
 
-(* universe graph *)
-
-type univ_gr = Univ option Graph.T;
-
 
 (** assembling and compiling ML code from terms **)
 
@@ -134,8 +130,8 @@
 fun nbe_apps t [] = t
   | nbe_apps t ts = name_apps `$$` [t, ml_list (rev ts)];
 fun nbe_apps_local c ts = nbe_fun c `$` ml_list (rev ts);
-fun nbe_apps_constr c ts =
-  name_const `$` ("(" ^ ML_Syntax.print_string c ^ ", " ^ ml_list (rev ts) ^ ")");
+fun nbe_apps_constr idx ts =
+  name_const `$` ("(" ^ string_of_int idx ^ ", " ^ ml_list (rev ts) ^ ")");
 
 fun nbe_abss 0 f = f `$` ml_list []
   | nbe_abss n f = name_abss `$$` [string_of_int n, f];
@@ -146,7 +142,7 @@
 
 (* code generation *)
 
-fun assemble_eqnss deps eqnss =
+fun assemble_eqnss idx_of deps eqnss =
   let
     fun prep_eqns (c, (vs, eqns)) =
       let
@@ -165,7 +161,7 @@
             end else nbe_apps (nbe_abss n (nbe_fun c)) ts'
         | NONE => if member (op =) deps c
             then nbe_apps (nbe_fun c) ts'
-            else nbe_apps_constr c ts'
+            else nbe_apps_constr (idx_of c) ts'
       end
     and assemble_idict (DictConst (inst, dss)) =
           assemble_constapp inst dss []
@@ -189,14 +185,17 @@
 
     fun assemble_eqns (c, (num_args, (dicts, eqns))) =
       let
-        val assemble_arg = assemble_iterm (fn c => fn _ => fn ts => nbe_apps_constr c ts);
+        val assemble_arg = assemble_iterm
+          (fn c => fn _ => fn ts => nbe_apps_constr (idx_of c) ts);
         val assemble_rhs = assemble_iterm assemble_constapp;
         fun assemble_eqn (args, rhs) =
           ([ml_list (rev (dicts @ map assemble_arg args))], assemble_rhs rhs);
         val default_args = map nbe_bound (Name.invent_list [] "a" num_args);
-        val default_eqn = ([ml_list (rev default_args)], nbe_apps_constr c default_args);
+        val default_eqn = if c = "" then NONE
+          else SOME ([ml_list (rev default_args)],
+            nbe_apps_constr (idx_of c) default_args);
       in
-        ((nbe_fun c, map assemble_eqn eqns @ [default_eqn]),
+        ((nbe_fun c, map assemble_eqn eqns @ the_list default_eqn),
           nbe_abss num_args (nbe_fun c))
       end;
 
@@ -209,12 +208,14 @@
 fun compile_eqnss gr raw_deps [] = []
   | compile_eqnss gr raw_deps eqnss = 
       let
+        val (deps, deps_vals) = split_list (map_filter
+          (fn dep => Option.map (fn univ => (dep, univ)) (fst ((Graph.get_node gr dep)))) raw_deps);
+        val idx_of = raw_deps
+          |> map (fn dep => (dep, snd (Graph.get_node gr dep)))
+          |> AList.lookup (op =)
+          |> (fn f => the o f);
+        val s = assemble_eqnss idx_of deps eqnss;
         val cs = map fst eqnss;
-        val (deps, deps_vals) = raw_deps
-          |> map_filter (fn dep => case Graph.get_node gr dep of NONE => NONE
-              | SOME univ => SOME (dep, univ))
-          |> split_list;
-        val s = assemble_eqnss deps eqnss;
       in
         s
         |> tracing (fn s => "\n--- code to be evaluated:\n" ^ s)
@@ -258,25 +259,30 @@
     val names = map (fst o fst) stmts_deps;
     val names_deps = map (fn ((name, _), deps) => (name, deps)) stmts_deps;
     val eqnss = maps (eqns_of_stmt o fst) stmts_deps;
-    val deps = names_deps
+    val refl_deps = names_deps
       |> maps snd
       |> distinct (op =)
-      |> subtract (op =) names;
+      |> fold (insert (op =)) names;
+    fun new_node name (gr, (maxidx, idx_tab)) = if can (Graph.get_node gr) name
+      then (gr, (maxidx, idx_tab))
+      else (Graph.new_node (name, (NONE, maxidx)) gr,
+        (maxidx + 1, Inttab.update_new (maxidx, name) idx_tab));
     fun compile gr = eqnss
-      |> compile_eqnss gr deps
+      |> compile_eqnss gr refl_deps
       |> rpair gr;
   in
-    fold (fn name => Graph.default_node (name, NONE)) names
-    #> fold (fn name => Graph.default_node (name, NONE)) deps
-    #> fold (fn (name, deps) => fold (curry Graph.add_edge name) deps) names_deps
-    #> compile
-    #-> fold (fn (name, univ) => Graph.map_node name (K (SOME univ)))
+    fold new_node refl_deps
+    #> apfst (fold (fn (name, deps) => fold (curry Graph.add_edge name) deps) names_deps
+      #> compile
+      #-> fold (fn (name, univ) => (Graph.map_node name o apfst) (K (SOME univ))))
   end;
 
 fun ensure_stmts code =
   let
-    fun add_stmts names gr = if exists ((can o Graph.get_node) gr) names then gr else gr
-      |> compile_stmts (map (fn name => ((name, Graph.get_node code name),
+    fun add_stmts names (gr, (maxidx, idx_tab)) = if exists ((can o Graph.get_node) gr) names
+      then (gr, (maxidx, idx_tab))
+      else (gr, (maxidx, idx_tab))
+        |> compile_stmts (map (fn name => ((name, Graph.get_node code name),
           Graph.imm_succs code name)) names);
   in fold_rev add_stmts (Graph.strong_conn code) end;
 
@@ -299,23 +305,27 @@
 
 (* reification *)
 
-fun term_of_univ thy t =
+fun term_of_univ thy idx_tab t =
   let
     fun take_until f [] = []
       | take_until f (x::xs) = if f x then [] else x :: take_until f xs;
-    fun is_dict (Const (c, _)) =
-          (is_some o CodeName.class_rev thy) c
-          orelse (is_some o CodeName.classrel_rev thy) c
-          orelse (is_some o CodeName.instance_rev thy) c
+    fun is_dict (Const (idx, _)) =
+          let
+            val c = the (Inttab.lookup idx_tab idx);
+          in
+            (is_some o CodeName.class_rev thy) c
+            orelse (is_some o CodeName.classrel_rev thy) c
+            orelse (is_some o CodeName.instance_rev thy) c
+          end
       | is_dict (DFree _) = true
       | is_dict _ = false;
     fun of_apps bounds (t, ts) =
       fold_map (of_univ bounds) ts
       #>> (fn ts' => list_comb (t, rev ts'))
-    and of_univ bounds (Const (name, ts)) typidx =
+    and of_univ bounds (Const (idx, ts)) typidx =
           let
             val ts' = take_until is_dict ts;
-            val SOME c = CodeName.const_rev thy name;
+            val c = (the o CodeName.const_rev thy o the o Inttab.lookup idx_tab) idx;
             val T = Code.default_typ thy c;
             val T' = map_type_tvar (fn ((v, i), S) => TypeInfer.param (typidx + i) (v, S)) T;
             val typidx' = typidx + maxidx_of_typ T' + 1;
@@ -334,12 +344,14 @@
 
 structure Nbe_Functions = CodeDataFun
 (
-  type T = univ_gr;
-  val empty = Graph.empty;
-  fun merge _ = Graph.merge (K true);
-  fun purge _ NONE _ = Graph.empty
-    | purge NONE _ _ = Graph.empty
-    | purge (SOME thy) (SOME cs) gr =
+  type T = (Univ option * int) Graph.T * (int * string Inttab.table);
+  val empty = (Graph.empty, (0, Inttab.empty));
+  fun merge _ ((gr1, (maxidx1, idx_tab1)), (gr2, (maxidx2, idx_tab2))) =
+    (Graph.merge (K true) (gr1, gr2), (IntInf.max (maxidx1, maxidx2),
+      Inttab.merge (K true) (idx_tab1, idx_tab2)));
+  fun purge _ NONE _ = empty
+    | purge NONE _ _ = empty
+    | purge (SOME thy) (SOME cs) (gr, (maxidx, idx_tab)) =
         let
           val cs_exisiting =
             map_filter (CodeName.const_rev thy) (Graph.keys gr);
@@ -347,15 +359,19 @@
               o map (CodeName.const thy)
               o filter (member (op =) cs_exisiting)
             ) cs;
-        in Graph.del_nodes dels gr end;
+        in (Graph.del_nodes dels gr, (maxidx, idx_tab)) end;
 );
 
 (* compilation, evaluation and reification *)
 
 fun compile_eval thy code vs_ty_t deps =
-  vs_ty_t
-  |> eval_term (Nbe_Functions.change thy (ensure_stmts code)) deps
-  |> term_of_univ thy;
+  let
+    val (gr, (_, idx_tab)) = Nbe_Functions.change thy (ensure_stmts code);
+  in
+    vs_ty_t
+    |> eval_term gr deps
+    |> term_of_univ thy idx_tab
+  end;
 
 (* evaluation with type reconstruction *)