eliminated fun/val confusion
authorhaftmann
Mon, 29 Dec 2008 13:23:53 +0100
changeset 29189 ee8572f3bb57
parent 29188 ff41885a1234
child 29193 4410739f97a6
child 29197 6d4cb27ed19c
eliminated fun/val confusion
src/Tools/code/code_ml.ML
--- a/src/Tools/code/code_ml.ML	Sun Dec 28 14:41:47 2008 -0800
+++ b/src/Tools/code/code_ml.ML	Mon Dec 29 13:23:53 2008 +0100
@@ -1,5 +1,4 @@
 (*  Title:      Tools/code/code_ml.ML
-    ID:         $Id$
     Author:     Florian Haftmann, TU Muenchen
 
 Serializer for SML and OCaml.
@@ -25,17 +24,21 @@
 val target_OCaml = "OCaml";
 
 datatype ml_stmt =
-    MLFuns of ((string * (typscheme * ((iterm list * iterm) * (thm * bool)) list)) * bool (*val flag*)) list
+    MLExc of string * int
+  | MLVal of string * ((typscheme * iterm) * (thm * bool))
+  | MLFuns of (string * (typscheme * ((iterm list * iterm) * (thm * bool)) list)) list * string list
   | MLDatas of (string * ((vname * sort) list * (string * itype list) list)) list
   | MLClass of string * (vname * ((class * string) list * (string * itype) list))
   | MLClassinst of string * ((class * (string * (vname * sort) list))
         * ((class * (string * (string * dict list list))) list
       * ((string * const) * (thm * bool)) list));
 
-fun stmt_names_of (MLFuns fs) = map (fst o fst) fs
+fun stmt_names_of (MLExc (name, _)) = [name]
+  | stmt_names_of (MLVal (name, _)) = [name]
+  | stmt_names_of (MLFuns (fs, _)) = map fst fs
   | stmt_names_of (MLDatas ds) = map fst ds
-  | stmt_names_of (MLClass (c, _)) = [c]
-  | stmt_names_of (MLClassinst (i, _)) = [i];
+  | stmt_names_of (MLClass (name, _)) = [name]
+  | stmt_names_of (MLClassinst (name, _)) = [name];
 
 
 (** SML serailizer **)
@@ -157,73 +160,83 @@
             )
           end
       | pr_case is_closure thm vars fxy ((_, []), _) = str "raise Fail \"empty case\"";
-    fun pr_stmt (MLFuns (funns as (funn :: funns'))) =
+    fun pr_stmt (MLExc (name, n)) =
+          let
+            val exc_str =
+              (ML_Syntax.print_string o NameSpace.base o NameSpace.qualifier) name;
+          in
+            concat (
+              str (if n = 0 then "val" else "fun")
+              :: (str o deresolve) name
+              :: map str (replicate n "_")
+              @ str "="
+              :: str "raise"
+              :: str "(Fail"
+              @@ str (exc_str ^ ")")
+            )
+          end
+      | pr_stmt (MLVal (name, (((vs, ty), t), (thm, _)))) =
           let
-            val definer =
+            val consts = map_filter
+              (fn c => if (is_some o syntax_const) c
+                then NONE else (SOME o NameSpace.base o deresolve) c)
+                (Code_Thingol.fold_constnames (insert (op =)) t []);
+            val vars = reserved_names
+              |> Code_Name.intro_vars consts;
+          in
+            concat [
+              str "val",
+              (str o deresolve) name,
+              str ":",
+              pr_typ NOBR ty,
+              str "=",
+              pr_term (K false) thm vars NOBR t
+            ]
+          end
+      | pr_stmt (MLFuns (funn :: funns, pseudo_funs)) =
+          let
+            fun pr_funn definer (name, ((vs, ty), eqs as eq :: eqs')) =
               let
-                fun no_args _ (((ts, _), _) :: _) = length ts
-                  | no_args ty [] = (length o fst o Code_Thingol.unfold_fun) ty;
-                fun mk 0 [] = "val"
-                  | mk 0 vs = if (null o filter_out (null o snd)) vs
-                      then "val" else "fun"
-                  | mk k _ = "fun";
-                fun chk ((_, ((vs, ty), eqs)), _) NONE = SOME (mk (no_args ty eqs) vs)
-                  | chk ((_, ((vs, ty), eqs)), _) (SOME defi) =
-                      if defi = mk (no_args ty eqs) vs then SOME defi
-                      else error ("Mixing simultaneous vals and funs not implemented: "
-                        ^ commas (map (labelled_name o fst o fst) funns));
-              in the (fold chk funns NONE) end;
-            fun pr_funn definer ((name, ((vs, ty), [])), _) =
+                val vs_dict = filter_out (null o snd) vs;
+                val shift = if null eqs' then I else
+                  map (Pretty.block o single o Pretty.block o single);
+                fun pr_eq definer ((ts, t), (thm, _)) =
                   let
-                    val vs_dict = filter_out (null o snd) vs;
-                    val n = length vs_dict + (length o fst o Code_Thingol.unfold_fun) ty;
-                    val exc_str =
-                      (ML_Syntax.print_string o NameSpace.base o NameSpace.qualifier) name;
+                    val consts = map_filter
+                      (fn c => if (is_some o syntax_const) c
+                        then NONE else (SOME o NameSpace.base o deresolve) c)
+                        ((fold o Code_Thingol.fold_constnames) (insert (op =)) (t :: ts) []);
+                    val vars = reserved_names
+                      |> Code_Name.intro_vars consts
+                      |> Code_Name.intro_vars ((fold o Code_Thingol.fold_unbound_varnames)
+                           (insert (op =)) ts []);
                   in
                     concat (
                       str definer
                       :: (str o deresolve) name
-                      :: map str (replicate n "_")
+                      :: (if member (op =) pseudo_funs name then [str "()"]
+                          else pr_tyvar_dicts vs_dict
+                            @ map (pr_term (member (op =) pseudo_funs) thm vars BR) ts)
                       @ str "="
-                      :: str "raise"
-                      :: str "(Fail"
-                      @@ str (exc_str ^ ")")
+                      @@ pr_term (member (op =) pseudo_funs) thm vars NOBR t
                     )
                   end
-              | pr_funn definer ((name, ((vs, ty), eqs as eq :: eqs')), _) =
-                  let
-                    val vs_dict = filter_out (null o snd) vs;
-                    val shift = if null eqs' then I else
-                      map (Pretty.block o single o Pretty.block o single);
-                    fun pr_eq definer ((ts, t), (thm, _)) =
-                      let
-                        val consts = map_filter
-                          (fn c => if (is_some o syntax_const) c
-                            then NONE else (SOME o NameSpace.base o deresolve) c)
-                            ((fold o Code_Thingol.fold_constnames) (insert (op =)) (t :: ts) []);
-                        val vars = reserved_names
-                          |> Code_Name.intro_vars consts
-                          |> Code_Name.intro_vars ((fold o Code_Thingol.fold_unbound_varnames)
-                               (insert (op =)) ts []);
-                      in
-                        concat (
-                          [str definer, (str o deresolve) name]
-                          @ (if null ts andalso null vs_dict
-                             then [str ":", pr_typ NOBR ty]
-                             else
-                               pr_tyvar_dicts vs_dict
-                               @ map (pr_term (K false) thm vars BR) ts)
-                       @ [str "=", pr_term (K false) thm vars NOBR t]
-                        )
-                      end
-                  in
-                    (Pretty.block o Pretty.fbreaks o shift) (
-                      pr_eq definer eq
-                      :: map (pr_eq "|") eqs'
-                    )
-                  end;
-            val (ps, p) = split_last (pr_funn definer funn :: map (pr_funn "and") funns');
-          in Pretty.chunks (ps @ [Pretty.block ([p, str ";"])]) end
+              in
+                (Pretty.block o Pretty.fbreaks o shift) (
+                  pr_eq definer eq
+                  :: map (pr_eq "|") eqs'
+                )
+              end;
+            fun pr_pseudo_fun name = concat [
+                str "val",
+                (str o deresolve) name,
+                str "=",
+                (str o deresolve) name,
+                str "();"
+              ];
+            val (ps, p) = split_last (pr_funn "fun" funn :: map (pr_funn "and") funns);
+            val pseudo_ps = map pr_pseudo_fun pseudo_funs;
+          in Pretty.chunks (ps @ Pretty.block ([p, str ";"]) :: pseudo_ps) end
      | pr_stmt (MLDatas (datas as (data :: datas'))) =
           let
             fun pr_co (co, []) =
@@ -250,7 +263,7 @@
                   );
             val (ps, p) = split_last
               (pr_data "datatype" data :: map (pr_data "and") datas');
-          in Pretty.chunks (ps @ [Pretty.block ([p, str ";"])]) end
+          in Pretty.chunks (ps @| Pretty.block ([p, str ";"])) end
      | pr_stmt (MLClass (class, (v, (superclasses, classparams)))) =
           let
             val w = Code_Name.first_upper v ^ "_";
@@ -457,7 +470,39 @@
         val (fished3, _) = Name.variants fished2 Name.context;
         val vars' = Code_Name.intro_vars fished3 vars;
       in map (Code_Name.lookup_var vars') fished3 end;
-    fun pr_stmt (MLFuns (funns as funn :: funns')) =
+    fun pr_stmt (MLExc (name, n)) =
+          let
+            val exc_str =
+              (ML_Syntax.print_string o NameSpace.base o NameSpace.qualifier) name;
+          in
+            concat (
+              str "let"
+              :: (str o deresolve) name
+              :: map str (replicate n "_")
+              @ str "="
+              :: str "failwith"
+              @@ str exc_str
+            )
+          end
+      | pr_stmt (MLVal (name, (((vs, ty), t), (thm, _)))) =
+          let
+            val consts = map_filter
+              (fn c => if (is_some o syntax_const) c
+                then NONE else (SOME o NameSpace.base o deresolve) c)
+                (Code_Thingol.fold_constnames (insert (op =)) t []);
+            val vars = reserved_names
+              |> Code_Name.intro_vars consts;
+          in
+            concat [
+              str "let",
+              (str o deresolve) name,
+              str ":",
+              pr_typ NOBR ty,
+              str "=",
+              pr_term (K false) thm vars NOBR t
+            ]
+          end
+      | pr_stmt (MLFuns (funn :: funns, pseudo_funs)) =
           let
             fun pr_eq ((ts, t), (thm, _)) =
               let
@@ -470,24 +515,12 @@
                   |> Code_Name.intro_vars ((fold o Code_Thingol.fold_unbound_varnames)
                       (insert (op =)) ts []);
               in concat [
-                (Pretty.block o Pretty.commas) (map (pr_term (K false) thm vars NOBR) ts),
+                (Pretty.block o Pretty.commas)
+                  (map (pr_term (member (op =) pseudo_funs) thm vars NOBR) ts),
                 str "->",
-                pr_term (K false) thm vars NOBR t
+                pr_term (member (op =) pseudo_funs) thm vars NOBR t
               ] end;
-            fun pr_eqs name ty [] =
-                  let
-                    val n = (length o fst o Code_Thingol.unfold_fun) ty;
-                    val exc_str =
-                      (ML_Syntax.print_string o NameSpace.base o NameSpace.qualifier) name;
-                  in
-                    concat (
-                      map str (replicate n "_")
-                      @ str "="
-                      :: str "failwith"
-                      @@ str exc_str
-                    )
-                  end
-              | pr_eqs _ _ [((ts, t), (thm, _))] =
+            fun pr_eqs is_pseudo [((ts, t), (thm, _))] =
                   let
                     val consts = map_filter
                       (fn c => if (is_some o syntax_const) c
@@ -499,12 +532,13 @@
                           (insert (op =)) ts []);
                   in
                     concat (
-                      map (pr_term (K false) thm vars BR) ts
+                      (if is_pseudo then [str "()"]
+                        else map (pr_term (member (op =) pseudo_funs) thm vars BR) ts)
                       @ str "="
-                      @@ pr_term (K false) thm vars NOBR t
+                      @@ pr_term (member (op =) pseudo_funs) thm vars NOBR t
                     )
                   end
-              | pr_eqs _ _ (eqs as (eq as (([_], _), _)) :: eqs') =
+              | pr_eqs _ (eqs as (eq as (([_], _), _)) :: eqs') =
                   Pretty.block (
                     str "="
                     :: Pretty.brk 1
@@ -514,7 +548,7 @@
                     :: maps (append [Pretty.fbrk, str "|", Pretty.brk 1]
                           o single o pr_eq) eqs'
                   )
-              | pr_eqs _ _ (eqs as eq :: eqs') =
+              | pr_eqs _ (eqs as eq :: eqs') =
                   let
                     val consts = map_filter
                       (fn c => if (is_some o syntax_const) c
@@ -541,16 +575,25 @@
                            o single o pr_eq) eqs'
                     )
                   end;
-            fun pr_funn definer ((name, ((vs, ty), eqs)), _) =
+            fun pr_funn definer (name, ((vs, ty), eqs)) =
               concat (
                 str definer
                 :: (str o deresolve) name
                 :: pr_tyvar_dicts (filter_out (null o snd) vs)
-                @| pr_eqs name ty eqs
+                @| pr_eqs (member (op =) pseudo_funs name) eqs
               );
+            fun pr_pseudo_fun name = concat [
+                str "let",
+                (str o deresolve) name,
+                str "=",
+                (str o deresolve) name,
+                str "();;"
+              ];
+            val (ps, p) = split_last (pr_funn "fun" funn :: map (pr_funn "and") funns);
             val (ps, p) = split_last
-              (pr_funn "let rec" funn :: map (pr_funn "and") funns');
-          in Pretty.chunks (ps @ [Pretty.block ([p, str ";;"])]) end
+              (pr_funn "let rec" funn :: map (pr_funn "and") funns);
+            val pseudo_ps = map pr_pseudo_fun pseudo_funs;
+          in Pretty.chunks (ps @ Pretty.block ([p, str ";;"]) :: pseudo_ps) end
      | pr_stmt (MLDatas (datas as (data :: datas'))) =
           let
             fun pr_co (co, []) =
@@ -577,7 +620,7 @@
                   );
             val (ps, p) = split_last
               (pr_data "type" data :: map (pr_data "and") datas');
-          in Pretty.chunks (ps @ [Pretty.block ([p, str ";;"])]) end
+          in Pretty.chunks (ps @| Pretty.block ([p, str ";;"])) end
      | pr_stmt (MLClass (class, (v, (superclasses, classparams)))) =
           let
             val w = "_" ^ Code_Name.first_upper v;
@@ -729,15 +772,33 @@
         val base' = if upper then Code_Name.first_upper base else base;
         val ([base''], nsp') = Name.variants [base'] nsp;
       in (base'', nsp') end;
-    fun add_funs stmts =
-      fold_map
+    fun rearrange_fun name (tysm as (vs, ty), raw_eqs) =
+      let
+        val eqs = filter (snd o snd) raw_eqs;
+        val (eqs', is_value) = if null (filter_out (null o snd) vs) then case eqs
+         of [(([], t), thm)] => if (not o null o fst o Code_Thingol.unfold_fun) ty
+            then ([(([IVar "x"], t `$ IVar "x"), thm)], false)
+            else (eqs, not (Code_Thingol.fold_constnames
+              (fn name' => fn b => b orelse name = name') t false))
+          | _ => (eqs, false)
+          else (eqs, false)
+      in ((name, (tysm, eqs')), is_value) end;
+    fun check_kind [((name, (tysm, [(([], t), thm)])), true)] = MLVal (name, ((tysm, t), thm))
+      | check_kind [((name, ((vs, ty), [])), _)] =
+          MLExc (name, (length o filter_out (null o snd)) vs + (length o fst o Code_Thingol.unfold_fun) ty)
+      | check_kind funns =
+          MLFuns (map fst funns, map_filter
+            (fn ((name, ((vs, _), [(([], _), _)])), _) =>
+                  if null (filter_out (null o snd) vs) then SOME name else NONE
+              | _ => NONE) funns);
+    fun add_funs stmts = fold_map
         (fn (name, Code_Thingol.Fun (_, stmt)) =>
-              map_nsp_fun_yield (mk_name_stmt false name) #>>
-                rpair ((name, stmt |> apsnd (filter (snd o snd))), false)
+              map_nsp_fun_yield (mk_name_stmt false name)
+              #>> rpair (rearrange_fun name stmt)
           | (name, _) =>
               error ("Function block containing illegal statement: " ^ labelled_name name)
         ) stmts
-      #>> (split_list #> apsnd MLFuns);
+      #>> (split_list #> apsnd check_kind);
     fun add_datatypes stmts =
       fold_map
         (fn (name, Code_Thingol.Datatype (_, stmt)) =>