Haskell uses generic flat_program combinator
authorhaftmann
Tue, 07 Sep 2010 16:05:18 +0200
changeset 39204 3d30f501b7c2
parent 39203 b2f9a6f4b84b
child 39205 13c6e91efcb6
Haskell uses generic flat_program combinator
src/Tools/Code/code_haskell.ML
--- a/src/Tools/Code/code_haskell.ML	Tue Sep 07 11:08:58 2010 +0200
+++ b/src/Tools/Code/code_haskell.ML	Tue Sep 07 16:05:18 2010 +0200
@@ -261,7 +261,7 @@
           end;
   in print_stmt end;
 
-type flat_program = ((string * Code_Thingol.stmt) Graph.T * ((string * (string list * string list)) list)) Graph.T;
+type flat_program = ((string * Code_Thingol.stmt option) Graph.T * string list) Graph.T;
 
 fun flat_program labelled_name { module_alias, module_prefix, reserved,
       empty_nsp, namify_stmt, modify_stmt } program =
@@ -277,11 +277,9 @@
     fun add_stmt name stmt =
       let
         val (module_name, base) = dest_name name;
-      in case modify_stmt stmt
-       of SOME stmt' => 
-            Graph.default_node (module_name, (Graph.empty, []))
-            #> (Graph.map_node module_name o apfst) (Graph.new_node (name, (base, stmt')))
-        | NONE => I
+      in
+        Graph.default_node (module_name, (Graph.empty, []))
+        #> (Graph.map_node module_name o apfst) (Graph.new_node (name, (base, stmt)))
       end;
     fun add_dependency name name' =
       let
@@ -289,14 +287,13 @@
         val (module_name', base') = dest_name name';
       in if module_name = module_name'
         then (Graph.map_node module_name o apfst) (Graph.add_edge (name, name'))
-        else (Graph.map_node module_name o apsnd)
-          (AList.map_default (op =) (module_name', []) (insert (op =) name'))
+        else (Graph.map_node module_name o apsnd) (AList.map_default (op =) (module_name', []) (insert (op =) name'))
       end;
     val proto_program = Graph.empty
       |> Graph.fold (fn (name, (stmt, _)) => add_stmt name stmt) program
       |> Graph.fold (fn (name, (_, (_, names))) => fold (add_dependency name) names) program;
 
-    (* name declarations *)
+    (* name declarations and statement modifications *)
     fun declare name (base, stmt) (gr, nsp) = 
       let
         val (base', nsp') = namify_stmt stmt base nsp;
@@ -304,45 +301,36 @@
       in (gr', nsp') end;
     fun declarations gr = (gr, empty_nsp)
       |> fold (fn name => declare name (Graph.get_node gr name)) (Graph.keys gr) 
-      |> fst;
-    val intermediate_program = proto_program
-      |> Graph.map ((K o apfst) declarations);
+      |> fst
+      |> (Graph.map o K o apsnd) modify_stmt;
+    val flat_program = proto_program
+      |> (Graph.map o K o apfst) declarations;
 
     (* qualified and unqualified imports, deresolving *)
     fun base_deresolver name = fst (Graph.get_node
-      (fst (Graph.get_node intermediate_program (fst (dest_name name)))) name);
-    fun classify_imports gr imports =
+      (fst (Graph.get_node flat_program (fst (dest_name name)))) name);
+    fun classify_names gr imports =
       let
         val import_tab = maps
           (fn (module_name, names) => map (rpair module_name) names) imports;
         val imported_names = map fst import_tab;
         val here_names = Graph.keys gr;
-        val qualified_names = []
-          |> fold (fn name => AList.map_default (op =) (base_deresolver name, [])
-               (insert (op =) name)) (here_names @ imported_names)
-          |> filter (fn (_, names) => length names > 1)
-          |> maps snd;
-        val name_tab = Symtab.empty
-          |> fold (fn name => Symtab.update (name, base_deresolver name)) here_names
-          |> fold (fn name => Symtab.update (name,
-               if member (op =) qualified_names name
-               then Long_Name.append (the (AList.lookup (op =) import_tab name))
-                 (base_deresolver name)
-               else base_deresolver name)) imported_names;
-        val imports' = (map o apsnd) (List.partition (member (op =) qualified_names))
-          imports;
-      in (name_tab, imports') end;
-    val classified = AList.make (uncurry classify_imports o Graph.get_node intermediate_program)
-      (Graph.keys intermediate_program);
-    val flat_program = Graph.map (apsnd o K o snd o the o AList.lookup (op =) classified)
-      intermediate_program;
+      in
+        Symtab.empty
+        |> fold (fn name => Symtab.update (name, base_deresolver name)) here_names
+        |> fold (fn name => Symtab.update (name,
+            Long_Name.append (the (AList.lookup (op =) import_tab name))
+              (base_deresolver name))) imported_names
+      end;
+    val name_tabs = AList.make (uncurry classify_names o Graph.get_node flat_program)
+      (Graph.keys flat_program);
     val deresolver_tab = Symtab.empty
-      |> fold (fn (module_name, (name_tab, _)) => Symtab.update (module_name, name_tab)) classified;
+      |> fold (fn (module_name, name_tab) => Symtab.update (module_name, name_tab)) name_tabs;
     fun deresolver module_name name =
       the (Symtab.lookup (the (Symtab.lookup deresolver_tab module_name)) name)
       handle Option => error ("Unknown statement name: " ^ labelled_name name);
 
-  in (deresolver, flat_program) end;
+  in { deresolver = deresolver, flat_program = flat_program } end;
 
 fun haskell_program_of_program labelled_name module_alias module_prefix reserved =
   let
@@ -379,70 +367,16 @@
         modify_stmt = fn stmt => if select_stmt stmt then SOME stmt else NONE }
   end;
 
-fun mk_name_module reserved module_prefix module_alias program =
-  let
-    val fragments_tab = Code_Namespace.build_module_namespace { module_alias = module_alias,
-      module_prefix = module_prefix, reserved = reserved } program;
-  in Long_Name.implode o the o Symtab.lookup fragments_tab end;
-
-fun haskell_program_of_program labelled_name module_prefix reserved module_alias program =
-  let
-    val reserved = Name.make_context reserved;
-    val mk_name_module = mk_name_module reserved module_prefix module_alias program;
-    fun add_stmt (name, (stmt, deps)) =
-      let
-        val (module_name, base) = Code_Namespace.dest_name name;
-        val module_name' = mk_name_module module_name;
-        val mk_name_stmt = yield_singleton Name.variants;
-        fun add_fun upper (nsp_fun, nsp_typ) =
-          let
-            val (base', nsp_fun') =
-              mk_name_stmt (if upper then first_upper base else base) nsp_fun
-          in (base', (nsp_fun', nsp_typ)) end;
-        fun add_typ (nsp_fun, nsp_typ) =
-          let
-            val (base', nsp_typ') = mk_name_stmt (first_upper base) nsp_typ
-          in (base', (nsp_fun, nsp_typ')) end;
-        val add_name = case stmt
-         of Code_Thingol.Fun (_, (_, SOME _)) => pair base
-          | Code_Thingol.Fun _ => add_fun false
-          | Code_Thingol.Datatype _ => add_typ
-          | Code_Thingol.Datatypecons _ => add_fun true
-          | Code_Thingol.Class _ => add_typ
-          | Code_Thingol.Classrel _ => pair base
-          | Code_Thingol.Classparam _ => add_fun false
-          | Code_Thingol.Classinst _ => pair base;
-        fun add_stmt' base' = case stmt
-         of Code_Thingol.Fun (_, (_, SOME _)) =>
-              I
-          | Code_Thingol.Datatypecons _ =>
-              cons (name, (Long_Name.append module_name' base', NONE))
-          | Code_Thingol.Classrel _ => I
-          | Code_Thingol.Classparam _ =>
-              cons (name, (Long_Name.append module_name' base', NONE))
-          | _ => cons (name, (Long_Name.append module_name' base', SOME stmt));
-      in
-        Symtab.map_default (module_name', ([], ([], (reserved, reserved))))
-              (apfst (fold (insert (op = : string * string -> bool)) deps))
-        #> `(fn program => add_name ((snd o snd o the o Symtab.lookup program) module_name'))
-        #-> (fn (base', names) =>
-              (Symtab.map_entry module_name' o apsnd) (fn (stmts, _) =>
-              (add_stmt' base' stmts, names)))
-      end;
-    val hs_program = fold add_stmt (AList.make (fn name =>
-      (Graph.get_node program name, Graph.imm_succs program name))
-      (Graph.strong_conn program |> flat)) Symtab.empty;
-    fun deresolver name = (fst o the o AList.lookup (op =) ((fst o snd o the
-      o Symtab.lookup hs_program) ((mk_name_module o fst o Code_Namespace.dest_name) name))) name
-      handle Option => error ("Unknown statement name: " ^ labelled_name name);
-  in (deresolver, hs_program) end;
-
 fun serialize_haskell module_prefix string_classes { labelled_name, reserved_syms,
     includes, module_alias, class_syntax, tyco_syntax, const_syntax, program } =
   let
+
+    (* build program *)
     val reserved = fold (insert (op =) o fst) includes reserved_syms;
-    val (deresolver, hs_program) = haskell_program_of_program labelled_name
-      module_prefix reserved module_alias program;
+    val { deresolver, flat_program = haskell_program } = haskell_program_of_program
+      labelled_name module_alias module_prefix (Name.make_context reserved) program;
+
+    (* print statements *)
     val contr_classparam_typs = Code_Thingol.contr_classparam_typs program;
     fun deriving_show tyco =
       let
@@ -457,58 +391,52 @@
               andalso forall (deriv' tycos) tys
           | deriv' _ (ITyVar _) = true
       in deriv [] tyco end;
-    val reserved = make_vars reserved;
-    fun print_stmt qualified = print_haskell_stmt labelled_name
-      class_syntax tyco_syntax const_syntax reserved
-      (if qualified then deresolver else Long_Name.base_name o deresolver)
-      contr_classparam_typs
+    fun print_stmt deresolve = print_haskell_stmt labelled_name
+      class_syntax tyco_syntax const_syntax (make_vars reserved)
+      deresolve contr_classparam_typs
       (if string_classes then deriving_show else K false);
-    fun print_module name content =
-      (name, Pretty.chunks2 [
-        str ("module " ^ name ^ " where {"),
-        content,
-        str "}"
-      ]);
-    fun serialize_module (module_name', (deps, (stmts, _))) =
+
+    (* print modules *)
+    val import_includes_ps =
+      map (fn (name, _) => str ("import qualified " ^ name ^ ";")) includes;
+    fun print_module_frame module_name ps =
+      (module_name, Pretty.chunks2 (
+        str "{-# OPTIONS_GHC -fglasgow-exts #-}"
+        :: str ("module " ^ module_name ^ " where {")
+        :: ps
+        @| str "}"
+      ));
+    fun print_module module_name (gr, imports) =
       let
-        val stmt_names = map fst stmts;
-        val qualified = true;
-        val imports = subtract (op =) stmt_names deps
-          |> distinct (op =)
-          |> map_filter (try deresolver)
-          |> map Long_Name.qualifier
-          |> distinct (op =);
-        fun print_import_include (name, _) = str ("import qualified " ^ name ^ ";");
-        fun print_import_module name = str ((if qualified
-          then "import qualified "
-          else "import ") ^ name ^ ";");
-        val import_ps = map print_import_include includes @ map print_import_module imports
-        val content = Pretty.chunks2 ((if null import_ps then [] else [Pretty.chunks import_ps])
-            @ map_filter
-              (fn (name, (_, SOME stmt)) => SOME (markup_stmt name (print_stmt qualified (name, stmt)))
-                | (_, (_, NONE)) => NONE) stmts
-          );
-      in print_module module_name' content end;
-    fun write_module width (SOME destination) (modlname, content) =
+        val deresolve = deresolver module_name
+        fun print_import module_name = (semicolon o map str) ["import qualified", module_name];
+        val import_ps = import_includes_ps @ map (print_import o fst) imports;
+        fun print_stmt' gr name = case Graph.get_node gr name
+         of (_, NONE) => NONE
+          | (_, SOME stmt) => SOME (markup_stmt name (print_stmt deresolve (name, stmt)));
+        val body_ps = map_filter (print_stmt' gr) ((flat o rev o Graph.strong_conn) gr);
+      in
+        print_module_frame module_name
+          ((if null import_ps then [] else [Pretty.chunks import_ps]) @ body_ps)
+      end;
+
+    (*serialization*)
+    fun write_module width (SOME destination) (module_name, content) =
           let
             val _ = File.check destination;
-            val filename = case modlname
-             of "" => Path.explode "Main.hs"
-              | _ => (Path.ext "hs" o Path.explode o implode o separate "/"
-                    o Long_Name.explode) modlname;
-            val pathname = Path.append destination filename;
-            val _ = File.mkdir_leaf (Path.dir pathname);
-          in File.write pathname
-            ("{-# OPTIONS_GHC -fglasgow-exts #-}\n\n"
-              ^ format [] width content)
-          end
+            val filepath = (Path.append destination o Path.ext "hs" o Path.explode o implode
+              o separate "/" o Long_Name.explode) module_name;
+            val _ = File.mkdir_leaf (Path.dir filepath);
+          in File.write filepath (format [] width content) end
       | write_module width NONE (_, content) = writeln (format [] width content);
   in
     Code_Target.serialization
       (fn width => fn destination => K () o map (write_module width destination))
-      (fn present => fn width => rpair (fn _ => error "no deresolving") o format present width o Pretty.chunks o map snd)
-      (map (uncurry print_module) includes
-        @ map serialize_module (Symtab.dest hs_program))
+      (fn present => fn width => rpair (fn _ => error "no deresolving")
+        o format present width o Pretty.chunks o map snd)
+      (map (uncurry print_module_frame o apsnd single) includes
+        @ map (fn module_name => print_module module_name (Graph.get_node haskell_program module_name))
+          ((flat o rev o Graph.strong_conn) haskell_program))
   end;
 
 val serializer : Code_Target.serializer =