code_datatype antiquotation; tuned
authorhaftmann
Wed, 22 Apr 2009 19:09:23 +0200
changeset 30962 f5fd07c558f9
parent 30961 541bfff659af
child 30963 f44736b9d804
code_datatype antiquotation; tuned
src/Tools/code/code_ml.ML
--- a/src/Tools/code/code_ml.ML	Wed Apr 22 19:09:22 2009 +0200
+++ b/src/Tools/code/code_ml.ML	Wed Apr 22 19:09:23 2009 +0200
@@ -911,36 +911,38 @@
   in (deresolver, nodes) end;
 
 fun serialize_ml target compile pr_module pr_stmt raw_module_name labelled_name reserved_names includes raw_module_alias
-  _ syntax_tyco syntax_const naming program cs destination =
+  _ syntax_tyco syntax_const naming program stmt_names destination =
   let
     val is_cons = Code_Thingol.is_cons program;
-    val stmt_names = Code_Target.stmt_names_of_destination destination;
-    val module_name = if null stmt_names then raw_module_name else SOME "Code";
+    val present_stmt_names = Code_Target.stmt_names_of_destination destination;
+    val is_present = not (null present_stmt_names);
+    val module_name = if is_present then SOME "Code" else raw_module_name;
     val (deresolver, nodes) = ml_node_of_program labelled_name module_name
       reserved_names raw_module_alias program;
     val reserved_names = Code_Printer.make_vars reserved_names;
     fun pr_node prefix (Dummy _) =
           NONE
-      | pr_node prefix (Stmt (_, stmt)) = if null stmt_names orelse
-          (not o null o filter (member (op =) stmt_names) o stmt_names_of) stmt then SOME
+      | pr_node prefix (Stmt (_, stmt)) = if is_present andalso
+          (null o filter (member (op =) present_stmt_names) o stmt_names_of) stmt
+          then NONE
+          else SOME
             (pr_stmt naming labelled_name syntax_tyco syntax_const reserved_names
               (deresolver prefix) is_cons stmt)
-          else NONE
       | pr_node prefix (Module (module_name, (_, nodes))) =
           separate (str "")
             ((map_filter (pr_node (prefix @ [module_name]) o Graph.get_node nodes)
               o rev o flat o Graph.strong_conn) nodes)
-          |> (if null stmt_names then pr_module module_name else Pretty.chunks)
+          |> (if is_present then Pretty.chunks else pr_module module_name)
           |> SOME;
-    val cs' = (map o try)
-      (deresolver (if is_some module_name then the_list module_name else [])) cs;
+    val stmt_names' = (map o try)
+      (deresolver (if is_some module_name then the_list module_name else [])) stmt_names;
     val p = Pretty.chunks (separate (str "") (map snd includes @ (map_filter
       (pr_node [] o Graph.get_node nodes) o rev o flat o Graph.strong_conn) nodes));
   in
     Code_Target.mk_serialization target
       (case compile of SOME compile => SOME (compile o Code_Target.code_of_pretty) | NONE => NONE)
       (fn NONE => Code_Target.code_writeln | SOME file => File.write file o Code_Target.code_of_pretty)
-      (rpair cs' o Code_Target.code_of_pretty) p destination
+      (rpair stmt_names' o Code_Target.code_of_pretty) p destination
   end;
 
 end; (*local*)
@@ -986,42 +988,69 @@
 
 structure CodeAntiqData = ProofDataFun
 (
-  type T = string list * (bool * (string * (string * (string * string) list) lazy));
-  fun init _ = ([], (true, ("", Lazy.value ("", []))));
+  type T = (string list * string list) * (bool * (string
+    * (string * ((string * string) list * (string * string) list)) lazy));
+  fun init _ = (([], []), (true, ("", Lazy.value ("", ([], [])))));
 );
 
 val is_first_occ = fst o snd o CodeAntiqData.get;
 
-fun delayed_code thy consts () =
+fun delayed_code thy tycos consts () =
   let
     val (consts', (naming, program)) = Code_Thingol.consts_program thy consts;
-    val (ml_code, consts'') = eval_code_of NONE thy naming program consts';
-    val const_tab = map2 (fn const => fn NONE =>
-      error ("Constant " ^ (quote o Code_Unit.string_of_const thy) const
-        ^ "\nhas a user-defined serialization")
-      | SOME const' => (const, const')) consts consts''
-  in (ml_code, const_tab) end;
+    val tycos' = map (the o Code_Thingol.lookup_tyco naming) tycos;
+    val (ml_code, target_names) = eval_code_of NONE thy naming program (consts' @ tycos');
+    val (consts'', tycos'') = chop (length consts') target_names;
+    val consts_map = map2 (fn const => fn NONE =>
+        error ("Constant " ^ (quote o Code_Unit.string_of_const thy) const
+          ^ "\nhas a user-defined serialization")
+      | SOME const'' => (const, const'')) consts consts''
+    val tycos_map = map2 (fn tyco => fn NONE =>
+        error ("Type " ^ (quote o Sign.extern_type thy) tyco
+          ^ "\nhas a user-defined serialization")
+      | SOME tyco'' => (tyco, tyco'')) tycos tycos'';
+  in (ml_code, (tycos_map, consts_map)) end;
 
-fun register_const const ctxt =
+fun register_code new_tycos new_consts ctxt =
   let
-    val (consts, (_, (struct_name, _))) = CodeAntiqData.get ctxt;
-    val consts' = insert (op =) const consts;
+    val ((tycos, consts), (_, (struct_name, _))) = CodeAntiqData.get ctxt;
+    val tycos' = fold (insert (op =)) new_tycos tycos;
+    val consts' = fold (insert (op =)) new_consts consts;
     val (struct_name', ctxt') = if struct_name = ""
       then ML_Antiquote.variant "Code" ctxt
       else (struct_name, ctxt);
-    val acc_code = Lazy.lazy (delayed_code (ProofContext.theory_of ctxt) consts');
-  in CodeAntiqData.put (consts', (false, (struct_name', acc_code))) ctxt' end;
+    val acc_code = Lazy.lazy (delayed_code (ProofContext.theory_of ctxt) tycos' consts');
+  in CodeAntiqData.put ((tycos', consts'), (false, (struct_name', acc_code))) ctxt' end;
+
+fun register_const const = register_code [] [const];
 
-fun print_code struct_name is_first const ctxt =
+fun register_datatype tyco constrs = register_code [tyco] constrs;
+
+fun print_const const all_struct_name tycos_map consts_map =
+  (Long_Name.append all_struct_name o the o AList.lookup (op =) consts_map) const;
+
+fun print_datatype tyco constrs all_struct_name tycos_map consts_map =
   let
-    val (consts, (_, (struct_code_name, acc_code))) = CodeAntiqData.get ctxt;
-    val (raw_ml_code, consts_map) = Lazy.force acc_code;
-    val const'' = Long_Name.append (Long_Name.append struct_name struct_code_name)
-      ((the o AList.lookup (op =) consts_map) const);
+    val upperize = implode o nth_map 0 Symbol.to_ascii_upper o explode;
+    fun check_base name name'' =
+      if upperize (Long_Name.base_name name) = upperize name''
+      then () else error ("Name as printed " ^ quote name''
+        ^ "\ndiffers from logical base name " ^ quote (Long_Name.base_name name) ^ "; sorry.");
+    val tyco'' = (the o AList.lookup (op =) tycos_map) tyco;
+    val constrs'' = map (the o AList.lookup (op =) consts_map) constrs;
+    val _ = check_base tyco tyco'';
+    val _ = map2 check_base constrs constrs'';
+  in "datatype " ^ tyco'' ^ " = datatype " ^ Long_Name.append all_struct_name tyco'' end;
+
+fun print_code struct_name is_first print_it ctxt =
+  let
+    val (_, (_, (struct_code_name, acc_code))) = CodeAntiqData.get ctxt;
+    val (raw_ml_code, (tycos_map, consts_map)) = Lazy.force acc_code;
     val ml_code = if is_first then "\nstructure " ^ struct_code_name
         ^ " =\nstruct\n\n" ^ raw_ml_code ^ "\nend;\n\n"
       else "";
-  in (ml_code, const'') end;
+    val all_struct_name = Long_Name.append struct_name struct_code_name;
+  in (ml_code, print_it all_struct_name tycos_map consts_map) end;
 
 in
 
@@ -1030,7 +1059,19 @@
     val const = Code_Unit.check_const (ProofContext.theory_of background) raw_const;
     val is_first = is_first_occ background;
     val background' = register_const const background;
-  in (print_code struct_name is_first const, background') end;
+  in (print_code struct_name is_first (print_const const), background') end;
+
+fun ml_code_datatype_antiq (raw_tyco, raw_constrs) {struct_name, background} =
+  let
+    val thy = ProofContext.theory_of background;
+    val tyco = Sign.intern_type thy raw_tyco;
+    val constrs = map (Code_Unit.check_const thy) raw_constrs;
+    val constrs' = (map fst o snd o Code.get_datatype thy) tyco;
+    val _ = if gen_eq_set (op =) (constrs, constrs') then ()
+      else error ("Type " ^ quote tyco ^ ": given constructors diverge from real constructors")
+    val is_first = is_first_occ background;
+    val background' = register_datatype tyco constrs background;
+  in (print_code struct_name is_first (print_datatype tyco constrs), background') end;
 
 end; (*local*)
 
@@ -1038,6 +1079,10 @@
 (** Isar setup **)
 
 val _ = ML_Context.add_antiq "code" (fn _ => Args.term >> ml_code_antiq);
+val _ = ML_Context.add_antiq "code_datatype" (fn _ =>
+  (Args.tyname --| Scan.lift (Args.$$$ "=")
+    -- (Args.term ::: Scan.repeat (Scan.lift (Args.$$$ "|") |-- Args.term)))
+      >> ml_code_datatype_antiq);
 
 fun isar_seri_sml module_name =
   Code_Target.parse_args (Scan.succeed ())