code_datatype antiquotation; tuned
authorhaftmann
Wed Apr 22 19:09:23 2009 +0200 (2009-04-22)
changeset 30962f5fd07c558f9
parent 30961 541bfff659af
child 30963 f44736b9d804
code_datatype antiquotation; tuned
src/Tools/code/code_ml.ML
     1.1 --- a/src/Tools/code/code_ml.ML	Wed Apr 22 19:09:22 2009 +0200
     1.2 +++ b/src/Tools/code/code_ml.ML	Wed Apr 22 19:09:23 2009 +0200
     1.3 @@ -911,36 +911,38 @@
     1.4    in (deresolver, nodes) end;
     1.5  
     1.6  fun serialize_ml target compile pr_module pr_stmt raw_module_name labelled_name reserved_names includes raw_module_alias
     1.7 -  _ syntax_tyco syntax_const naming program cs destination =
     1.8 +  _ syntax_tyco syntax_const naming program stmt_names destination =
     1.9    let
    1.10      val is_cons = Code_Thingol.is_cons program;
    1.11 -    val stmt_names = Code_Target.stmt_names_of_destination destination;
    1.12 -    val module_name = if null stmt_names then raw_module_name else SOME "Code";
    1.13 +    val present_stmt_names = Code_Target.stmt_names_of_destination destination;
    1.14 +    val is_present = not (null present_stmt_names);
    1.15 +    val module_name = if is_present then SOME "Code" else raw_module_name;
    1.16      val (deresolver, nodes) = ml_node_of_program labelled_name module_name
    1.17        reserved_names raw_module_alias program;
    1.18      val reserved_names = Code_Printer.make_vars reserved_names;
    1.19      fun pr_node prefix (Dummy _) =
    1.20            NONE
    1.21 -      | pr_node prefix (Stmt (_, stmt)) = if null stmt_names orelse
    1.22 -          (not o null o filter (member (op =) stmt_names) o stmt_names_of) stmt then SOME
    1.23 +      | pr_node prefix (Stmt (_, stmt)) = if is_present andalso
    1.24 +          (null o filter (member (op =) present_stmt_names) o stmt_names_of) stmt
    1.25 +          then NONE
    1.26 +          else SOME
    1.27              (pr_stmt naming labelled_name syntax_tyco syntax_const reserved_names
    1.28                (deresolver prefix) is_cons stmt)
    1.29 -          else NONE
    1.30        | pr_node prefix (Module (module_name, (_, nodes))) =
    1.31            separate (str "")
    1.32              ((map_filter (pr_node (prefix @ [module_name]) o Graph.get_node nodes)
    1.33                o rev o flat o Graph.strong_conn) nodes)
    1.34 -          |> (if null stmt_names then pr_module module_name else Pretty.chunks)
    1.35 +          |> (if is_present then Pretty.chunks else pr_module module_name)
    1.36            |> SOME;
    1.37 -    val cs' = (map o try)
    1.38 -      (deresolver (if is_some module_name then the_list module_name else [])) cs;
    1.39 +    val stmt_names' = (map o try)
    1.40 +      (deresolver (if is_some module_name then the_list module_name else [])) stmt_names;
    1.41      val p = Pretty.chunks (separate (str "") (map snd includes @ (map_filter
    1.42        (pr_node [] o Graph.get_node nodes) o rev o flat o Graph.strong_conn) nodes));
    1.43    in
    1.44      Code_Target.mk_serialization target
    1.45        (case compile of SOME compile => SOME (compile o Code_Target.code_of_pretty) | NONE => NONE)
    1.46        (fn NONE => Code_Target.code_writeln | SOME file => File.write file o Code_Target.code_of_pretty)
    1.47 -      (rpair cs' o Code_Target.code_of_pretty) p destination
    1.48 +      (rpair stmt_names' o Code_Target.code_of_pretty) p destination
    1.49    end;
    1.50  
    1.51  end; (*local*)
    1.52 @@ -986,42 +988,69 @@
    1.53  
    1.54  structure CodeAntiqData = ProofDataFun
    1.55  (
    1.56 -  type T = string list * (bool * (string * (string * (string * string) list) lazy));
    1.57 -  fun init _ = ([], (true, ("", Lazy.value ("", []))));
    1.58 +  type T = (string list * string list) * (bool * (string
    1.59 +    * (string * ((string * string) list * (string * string) list)) lazy));
    1.60 +  fun init _ = (([], []), (true, ("", Lazy.value ("", ([], [])))));
    1.61  );
    1.62  
    1.63  val is_first_occ = fst o snd o CodeAntiqData.get;
    1.64  
    1.65 -fun delayed_code thy consts () =
    1.66 +fun delayed_code thy tycos consts () =
    1.67    let
    1.68      val (consts', (naming, program)) = Code_Thingol.consts_program thy consts;
    1.69 -    val (ml_code, consts'') = eval_code_of NONE thy naming program consts';
    1.70 -    val const_tab = map2 (fn const => fn NONE =>
    1.71 -      error ("Constant " ^ (quote o Code_Unit.string_of_const thy) const
    1.72 -        ^ "\nhas a user-defined serialization")
    1.73 -      | SOME const' => (const, const')) consts consts''
    1.74 -  in (ml_code, const_tab) end;
    1.75 +    val tycos' = map (the o Code_Thingol.lookup_tyco naming) tycos;
    1.76 +    val (ml_code, target_names) = eval_code_of NONE thy naming program (consts' @ tycos');
    1.77 +    val (consts'', tycos'') = chop (length consts') target_names;
    1.78 +    val consts_map = map2 (fn const => fn NONE =>
    1.79 +        error ("Constant " ^ (quote o Code_Unit.string_of_const thy) const
    1.80 +          ^ "\nhas a user-defined serialization")
    1.81 +      | SOME const'' => (const, const'')) consts consts''
    1.82 +    val tycos_map = map2 (fn tyco => fn NONE =>
    1.83 +        error ("Type " ^ (quote o Sign.extern_type thy) tyco
    1.84 +          ^ "\nhas a user-defined serialization")
    1.85 +      | SOME tyco'' => (tyco, tyco'')) tycos tycos'';
    1.86 +  in (ml_code, (tycos_map, consts_map)) end;
    1.87  
    1.88 -fun register_const const ctxt =
    1.89 +fun register_code new_tycos new_consts ctxt =
    1.90    let
    1.91 -    val (consts, (_, (struct_name, _))) = CodeAntiqData.get ctxt;
    1.92 -    val consts' = insert (op =) const consts;
    1.93 +    val ((tycos, consts), (_, (struct_name, _))) = CodeAntiqData.get ctxt;
    1.94 +    val tycos' = fold (insert (op =)) new_tycos tycos;
    1.95 +    val consts' = fold (insert (op =)) new_consts consts;
    1.96      val (struct_name', ctxt') = if struct_name = ""
    1.97        then ML_Antiquote.variant "Code" ctxt
    1.98        else (struct_name, ctxt);
    1.99 -    val acc_code = Lazy.lazy (delayed_code (ProofContext.theory_of ctxt) consts');
   1.100 -  in CodeAntiqData.put (consts', (false, (struct_name', acc_code))) ctxt' end;
   1.101 +    val acc_code = Lazy.lazy (delayed_code (ProofContext.theory_of ctxt) tycos' consts');
   1.102 +  in CodeAntiqData.put ((tycos', consts'), (false, (struct_name', acc_code))) ctxt' end;
   1.103 +
   1.104 +fun register_const const = register_code [] [const];
   1.105  
   1.106 -fun print_code struct_name is_first const ctxt =
   1.107 +fun register_datatype tyco constrs = register_code [tyco] constrs;
   1.108 +
   1.109 +fun print_const const all_struct_name tycos_map consts_map =
   1.110 +  (Long_Name.append all_struct_name o the o AList.lookup (op =) consts_map) const;
   1.111 +
   1.112 +fun print_datatype tyco constrs all_struct_name tycos_map consts_map =
   1.113    let
   1.114 -    val (consts, (_, (struct_code_name, acc_code))) = CodeAntiqData.get ctxt;
   1.115 -    val (raw_ml_code, consts_map) = Lazy.force acc_code;
   1.116 -    val const'' = Long_Name.append (Long_Name.append struct_name struct_code_name)
   1.117 -      ((the o AList.lookup (op =) consts_map) const);
   1.118 +    val upperize = implode o nth_map 0 Symbol.to_ascii_upper o explode;
   1.119 +    fun check_base name name'' =
   1.120 +      if upperize (Long_Name.base_name name) = upperize name''
   1.121 +      then () else error ("Name as printed " ^ quote name''
   1.122 +        ^ "\ndiffers from logical base name " ^ quote (Long_Name.base_name name) ^ "; sorry.");
   1.123 +    val tyco'' = (the o AList.lookup (op =) tycos_map) tyco;
   1.124 +    val constrs'' = map (the o AList.lookup (op =) consts_map) constrs;
   1.125 +    val _ = check_base tyco tyco'';
   1.126 +    val _ = map2 check_base constrs constrs'';
   1.127 +  in "datatype " ^ tyco'' ^ " = datatype " ^ Long_Name.append all_struct_name tyco'' end;
   1.128 +
   1.129 +fun print_code struct_name is_first print_it ctxt =
   1.130 +  let
   1.131 +    val (_, (_, (struct_code_name, acc_code))) = CodeAntiqData.get ctxt;
   1.132 +    val (raw_ml_code, (tycos_map, consts_map)) = Lazy.force acc_code;
   1.133      val ml_code = if is_first then "\nstructure " ^ struct_code_name
   1.134          ^ " =\nstruct\n\n" ^ raw_ml_code ^ "\nend;\n\n"
   1.135        else "";
   1.136 -  in (ml_code, const'') end;
   1.137 +    val all_struct_name = Long_Name.append struct_name struct_code_name;
   1.138 +  in (ml_code, print_it all_struct_name tycos_map consts_map) end;
   1.139  
   1.140  in
   1.141  
   1.142 @@ -1030,7 +1059,19 @@
   1.143      val const = Code_Unit.check_const (ProofContext.theory_of background) raw_const;
   1.144      val is_first = is_first_occ background;
   1.145      val background' = register_const const background;
   1.146 -  in (print_code struct_name is_first const, background') end;
   1.147 +  in (print_code struct_name is_first (print_const const), background') end;
   1.148 +
   1.149 +fun ml_code_datatype_antiq (raw_tyco, raw_constrs) {struct_name, background} =
   1.150 +  let
   1.151 +    val thy = ProofContext.theory_of background;
   1.152 +    val tyco = Sign.intern_type thy raw_tyco;
   1.153 +    val constrs = map (Code_Unit.check_const thy) raw_constrs;
   1.154 +    val constrs' = (map fst o snd o Code.get_datatype thy) tyco;
   1.155 +    val _ = if gen_eq_set (op =) (constrs, constrs') then ()
   1.156 +      else error ("Type " ^ quote tyco ^ ": given constructors diverge from real constructors")
   1.157 +    val is_first = is_first_occ background;
   1.158 +    val background' = register_datatype tyco constrs background;
   1.159 +  in (print_code struct_name is_first (print_datatype tyco constrs), background') end;
   1.160  
   1.161  end; (*local*)
   1.162  
   1.163 @@ -1038,6 +1079,10 @@
   1.164  (** Isar setup **)
   1.165  
   1.166  val _ = ML_Context.add_antiq "code" (fn _ => Args.term >> ml_code_antiq);
   1.167 +val _ = ML_Context.add_antiq "code_datatype" (fn _ =>
   1.168 +  (Args.tyname --| Scan.lift (Args.$$$ "=")
   1.169 +    -- (Args.term ::: Scan.repeat (Scan.lift (Args.$$$ "|") |-- Args.term)))
   1.170 +      >> ml_code_datatype_antiq);
   1.171  
   1.172  fun isar_seri_sml module_name =
   1.173    Code_Target.parse_args (Scan.succeed ())