src/Tools/nbe.ML
changeset 24347 245ff8661b8c
parent 24292 26ac9fe0e80e
child 24381 560e8ecdf633
--- a/src/Tools/nbe.ML	Mon Aug 20 18:07:30 2007 +0200
+++ b/src/Tools/nbe.ML	Mon Aug 20 18:07:31 2007 +0200
@@ -187,19 +187,19 @@
     fun of_iterm t =
       let
         val (t', ts) = CodeThingol.unfold_app t
-      in of_itermapp t' (fold (cons o of_iterm) ts []) end
-    and of_itermapp (IConst (c, (dss, _))) ts =
-          (case num_args c
-           of SOME n => if n <= length ts
-                then let val (args2, args1) = chop (length ts - n) ts
-                in nbe_apps (nbe_fun c `$` ml_list args1) args2
-                end else nbe_const c ts
-            | NONE => if is_fun c then nbe_apps (nbe_fun c) ts
-                else nbe_const c ts)
-      | of_itermapp (IVar v) ts = nbe_apps (nbe_bound v) ts
-      | of_itermapp ((v, _) `|-> t) ts =
+      in of_iapp t' (fold (cons o of_iterm) ts []) end
+    and of_iconst c ts = case num_args c
+     of SOME n => if n <= length ts
+          then let val (args2, args1) = chop (length ts - n) ts
+          in nbe_apps (nbe_fun c `$` ml_list args1) args2
+          end else nbe_const c ts
+      | NONE => if is_fun c then nbe_apps (nbe_fun c) ts
+          else nbe_const c ts
+    and of_iapp (IConst (c, (dss, _))) ts = of_iconst c ts
+      | of_iapp (IVar v) ts = nbe_apps (nbe_bound v) ts
+      | of_iapp ((v, _) `|-> t) ts =
           nbe_apps (nbe_abss 1 (ml_abs (ml_list [nbe_bound v]) (of_iterm t))) ts
-      | of_itermapp (ICase (((t, _), cs), t0)) ts =
+      | of_iapp (ICase (((t, _), cs), t0)) ts =
           nbe_apps (ml_cases (of_iterm t) (map (pairself of_iterm) cs
             @ [("_", of_iterm t0)])) ts
   in of_iterm end;
@@ -215,8 +215,8 @@
     val default_eqn = ([ml_list default_params], nbe_const c default_params);
   in map assemble_eqn eqns @ [default_eqn] end;
 
-fun assemble_eqnss thy is_fun [] = ([], "")
-  | assemble_eqnss thy is_fun eqnss =
+fun assemble_eqnss thy is_fun ([], deps) = ([], "")
+  | assemble_eqnss thy is_fun (eqnss, deps) =
       let
         val cs = map fst eqnss;
         val num_args = cs ~~ map (fn (_, (args, rhs) :: _) => length args) eqnss;
@@ -238,25 +238,26 @@
     val result = ml_list [nbe_value `$` ml_list (map nbe_free frees)];
   in ([nbe_value], ml_Let (bind_funs @ [bind_value]) result) end;
 
-fun eqns_of_stmt (name, CodeThingol.Fun ([], _)) =
+fun eqns_of_stmt ((_, CodeThingol.Fun ([], _)), _) =
       NONE
-  | eqns_of_stmt (name, CodeThingol.Fun (eqns, _)) =
-      SOME (name, eqns)
-  | eqns_of_stmt (_, CodeThingol.Datatypecons _) =
+  | eqns_of_stmt ((name, CodeThingol.Fun (eqns, _)), deps) =
+      SOME ((name, eqns), deps)
+  | eqns_of_stmt ((_, CodeThingol.Datatypecons _), _) =
       NONE
-  | eqns_of_stmt (_, CodeThingol.Datatype _) =
+  | eqns_of_stmt ((_, CodeThingol.Datatype _), _) =
       NONE
-  | eqns_of_stmt (_, CodeThingol.Class _) =
+  | eqns_of_stmt ((_, CodeThingol.Class _), _) =
       NONE
-  | eqns_of_stmt (_, CodeThingol.Classrel _) =
+  | eqns_of_stmt ((_, CodeThingol.Classrel _), _) =
       NONE
-  | eqns_of_stmt (_, CodeThingol.Classop _) =
+  | eqns_of_stmt ((_, CodeThingol.Classop _), _) =
       NONE
-  | eqns_of_stmt (_, CodeThingol.Classinst _) =
+  | eqns_of_stmt ((_, CodeThingol.Classinst _), _) =
       NONE;
 
 fun compile_stmts thy is_fun =
   map_filter eqns_of_stmt
+  #> split_list
   #> assemble_eqnss thy is_fun
   #> compile_univs (Nbe_Functions.get thy);
 
@@ -278,9 +279,11 @@
         val compiled = compile_stmts thy (Symtab.defined tab) stmts;
       in Nbe_Functions.change thy (fold Symtab.update compiled) end;
     val nbe_tab = Nbe_Functions.get thy;
-    val stmtss =
-      map (AList.make (Graph.get_node code)) (rev (Graph.strong_conn code))
-      |> (map o filter_out) (Symtab.defined nbe_tab o fst)
+    val stmtss = rev (Graph.strong_conn code)
+      |> (map o map_filter) (fn name => if Symtab.defined nbe_tab name
+           then NONE
+           else SOME ((name, Graph.get_node code name), Graph.imm_succs code name))
+      |> filter_out null
   in fold compile' stmtss nbe_tab end;
 
 (* re-conversion *)
@@ -311,6 +314,7 @@
 
 fun eval thy code t t' deps =
   let
+    val ty = type_of t;
     fun subst_Frees [] = I
       | subst_Frees inst =
           Term.map_aterms (fn (t as Term.Free (s, _)) => the_default t (AList.lookup (op =) inst s)
@@ -318,12 +322,11 @@
     val anno_vars =
       subst_Frees (map (fn (s, T) => (s, Term.Free (s, T))) (Term.add_frees t []))
       #> subst_Vars (map (fn (ixn, T) => (ixn, Var (ixn, T))) (Term.add_vars t []))
+    fun constrain t =
+      singleton (ProofContext.infer_types_pats (ProofContext.init thy)) (TypeInfer.constrain t ty);
     fun check_tvars t = if null (Term.term_tvars t) then t else
       error ("Illegal schematic type variables in normalized term: "
         ^ setmp show_types true (Sign.string_of_term thy) t);
-    val ty = type_of t;
-    fun constrain t =
-      singleton (ProofContext.infer_types_pats (ProofContext.init thy)) (TypeInfer.constrain t ty);
   in
     (t', deps)
     |> eval_term thy (Symtab.defined (ensure_funs thy code))