src/Pure/Tools/nbe_codegen.ML
changeset 20846 5fde744176d7
parent 20706 f77bd47a70df
child 20856 9f7f0bf89e7d
--- a/src/Pure/Tools/nbe_codegen.ML	Mon Oct 02 23:01:09 2006 +0200
+++ b/src/Pure/Tools/nbe_codegen.ML	Mon Oct 02 23:01:11 2006 +0200
@@ -16,7 +16,7 @@
 
 signature NBE_CODEGEN =
 sig
-  val generate: (string -> bool) -> string * CodegenThingol.def -> string;
+  val generate: theory -> (string -> bool) -> (string * thm list) list -> string option;
   val nterm_to_term: theory -> NBE_Eval.nterm -> term;
 end
 
@@ -58,8 +58,11 @@
 
 fun eqns name ees =
   let fun eqn (es,e) = name ^ " " ^ space_implode " " es ^ " = " ^ e
-  in "fun " ^ space_implode "\n  | " (map eqn ees) ^ ";\n" end;
+  in space_implode "\n  | " (map eqn ees) end;
 
+fun eqnss (es :: ess) = prefix "fun " es :: map (prefix "and ") ess
+  |> space_implode "\n"
+  |> suffix "\n";
 
 fun Val v s = "val " ^ v ^ " = " ^ s;
 fun Let ds e = "let\n" ^ (space_implode "\n" ds) ^ " in " ^ e ^ " end"
@@ -88,83 +91,74 @@
 		   S.abs (S.tup []) (S.app Eval_C
 	(S.quote nm))]);
 
-open BasicCodegenThingol;
-
-fun mk_rexpr defined nm ar =
+fun mk_rexpr defined names ar =
   let
-    fun mk args = CodegenThingol.map_pure (mk' args)
-    and mk' args (IConst (c, _)) =
-          if c = nm then selfcall nm ar args
-          else if defined c then S.nbe_apps (MLname c) args
-          else S.app Eval_Constr (S.tup [S.quote c, S.list args])
-      | mk' args (IVar s) = S.nbe_apps (MLvname s) args
-      | mk' args (e1 `$ e2) = mk (args @ [mk [] e2]) e1
-      | mk' args ((nm, _) `|-> e) = S.nbe_apps (mk_nbeFUN (nm, mk [] e)) args;
+    fun mk args (Const (c, _)) = 
+          if member (op =) names c then selfcall c ar args
+            else if defined c then S.nbe_apps (MLname c) args
+            else S.app Eval_Constr (S.tup [S.quote c, S.list args])
+      | mk args (Free (v, _)) = S.nbe_apps (MLvname v) args
+      | mk args (t1 $ t2) = mk (args @ [mk [] t2]) t1
+      | mk args (Abs (v, _, t)) = S.nbe_apps (mk_nbeFUN (v, mk [] t)) args;
   in mk [] end;
 
 val mk_lexpr =
   let
-    fun mk args = CodegenThingol.map_pure (mk' args)
-    and mk' args (IConst (c, _)) =
+    fun mk args (Const (c, _)) =
           S.app Eval_Constr (S.tup [S.quote c, S.list args])
-      | mk' args (IVar s) = if args = [] then MLvname s else 
+      | mk args (Free (v, _)) = if null args then MLvname v else 
           sys_error "NBE mk_lexpr illegal higher order pattern"
-      | mk' args (e1 `$ e2) = mk (args @ [mk [] e2]) e1
-      | mk' args (_ `|-> _) =
+      | mk args (t1 $ t2) = mk (args @ [mk [] t2]) t1
+      | mk args (Abs _) =
           sys_error "NBE mk_lexpr illegal pattern";
   in mk [] end;
 
-fun mk_eqn defined nm ar (lhs,e) =
-  if has_duplicates (op =) (fold CodegenThingol.add_varnames lhs []) then [] else
-  [([S.list(map mk_lexpr (rev lhs))], mk_rexpr defined nm ar e)];
-
 fun lookup nm = S.Val (MLname nm) (tab_lookup (S.quote nm));
 
-fun generate defined (nm, CodegenThingol.Fun (eqns, _)) =
+fun generate thy defined [(_, [])] = NONE
+  | generate thy defined raw_eqnss =
       let
-        val ar = (length o fst o hd) eqns;
-        val params = (S.list o rev o MLparams) ar;
-        val funs =
+        val eqnss0 = map (fn (name, thms as thm :: _) =>
+          (name, ((length o snd o strip_comb o fst o Logic.dest_equals o prop_of) thm,
+            map (apfst (snd o strip_comb) o Logic.dest_equals o Logic.unvarify
+              o prop_of) thms)))
+          raw_eqnss;
+        val eqnss = (map o apsnd o apsnd o map) (fn (args, t) =>
+          (map (NBE_Eval.prep_term thy) args, NBE_Eval.prep_term thy t)) eqnss0
+        val names = map fst eqnss;
+        val used_funs =
           []
-          |> fold (fn (_, e) => CodegenThingol.add_constnames e) eqns
-          |> remove (op =) nm;
-        val globals = map lookup (filter defined funs);
-        val default_eqn = ([params], S.app Eval_Constr (S.tup[S.quote nm,params]));
-        val code = S.eqns (MLname nm)
-                          (maps (mk_eqn defined nm ar) eqns @ [default_eqn])
-        val register = tab_update
-            (S.app Eval_mk_Fun (S.tup[S.quote nm, MLname nm, string_of_int ar]))
-      in
-        S.Let (globals @ [code]) register
-      end
-  | generate _ _ = "";
+          |> fold (fold (fold_aterms (fn Const (c, _) => insert (op =) c
+                                      | _ => I)) o map snd o snd o snd) eqnss
+          |> subtract (op =) names;
+        fun mk_def (name, (ar, eqns)) =
+          let
+            fun mk_eqn (args, t) = ([S.list (map mk_lexpr (rev args))],
+              mk_rexpr defined names ar t);
+            val default_params = (S.list o rev o MLparams) ar;
+            val default_eqn = ([default_params], S.app Eval_Constr (S.tup [S.quote name, default_params]));
+          in S.eqns (MLname name) (map mk_eqn eqns @ [default_eqn]) end;
+        val globals = map lookup (filter defined used_funs);
+        fun register (name, (ar, _)) = tab_update
+            (S.app Eval_mk_Fun (S.tup [S.quote name, MLname name, string_of_int ar]))
+      in SOME (S.Let (globals @ [S.eqnss (map mk_def eqnss)]) (space_implode "; " (map register eqnss))) end;
 
 open NBE_Eval;
 
 val tcount = ref 0;
 
-(* FIXME get rid of TFree case!!! *)
 fun varifyT ty =
   let val ty' = map_type_tvar (fn ((s,i),S) => TypeInfer.param (!tcount + i) (s,S)) ty;
       val _ = (tcount := !tcount + maxidx_of_typ ty + 1);
-      val ty'' = map_type_tfree (TypeInfer.param (!tcount)) ty'
-  in tcount := !tcount+1; ty'' end;
+  in tcount := !tcount+1; ty' end;
 
 fun nterm_to_term thy t =
   let
-    fun consts_of (C s) = insert (op =) s
-      | consts_of (V _) = I
-      | consts_of (B _) = I
-      | consts_of (A (t1, t2)) = consts_of t1 #> consts_of t2
-      | consts_of (AbsN (_, t)) = consts_of t;
-    val consts = consts_of t [];
-    val ctab = AList.make (CodegenPackage.const_of_idf thy) consts;
-    val the_const = apsnd varifyT o the o AList.lookup (op =) ctab;
-    fun to_term bounds (C s) = Const (the_const s)
-      | to_term bounds (V s) = Free (s, dummyT)
-      | to_term bounds (B i) = Bound (find_index (fn j => i = j) bounds)
-      | to_term bounds (A (t1, t2)) = to_term bounds t1 $ to_term bounds t2
-      | to_term bounds (AbsN (i, t)) =
+   fun to_term bounds (C s) = Const ((apsnd varifyT o CodegenPackage.const_of_idf thy) s)
+     | to_term bounds (V s) = Free (s, dummyT)
+     | to_term bounds (B i) = Bound (find_index (fn j => i = j) bounds)
+     | to_term bounds (A (t1, t2)) = to_term bounds t1 $ to_term bounds t2
+     | to_term bounds (AbsN (i, t)) =
           Abs("u", dummyT, to_term (i::bounds) t);
   in tcount := 0; to_term [] t end;