src/Pure/codegen.ML
author berghofe
Mon Dec 10 15:39:34 2001 +0100 (2001-12-10)
changeset 12452 68493b92e7a6
parent 12311 ce5f9e61c037
child 12490 d2a2c479b3cb
permissions -rw-r--r--
- Added code generator interface for types
- Changed type of invoke_codegen
     1 (*  Title:      Pure/codegen.ML
     2     ID:         $Id$
     3     Author:     Stefan Berghofer, TU Muenchen
     4     License:    GPL (GNU GENERAL PUBLIC LICENSE)
     5 
     6 Generic code generator.
     7 *)
     8 
signature CODEGEN =
sig
  (*when quiet_mode is set, message prints nothing*)
  val quiet_mode : bool ref
  val message : string -> unit

  (*Elements of a target-language code annotation:
      Arg    - the next argument, printed at this position;
      Ignore - the next argument, consumed but not printed;
      Pretty - literal target-language text;
      Quote  - an embedded term / type whose code is generated recursively.*)
  datatype 'a mixfix =
      Arg
    | Ignore
    | Pretty of Pretty.T
    | Quote of 'a;

  (*code generator for terms ('a = term) or types ('a = typ)*)
  type 'a codegen

  (*register a named term / type code generator; error if the name is taken*)
  val add_codegen: string -> term codegen -> theory -> theory
  val add_tycodegen: string -> typ codegen -> theory -> theory
  (*print the names of all registered term and type code generators*)
  val print_codegens: theory -> unit
  (*generate ML text for a list of (name, term) bindings; generate_code
    parses the terms from strings, generate_code_i takes them as given*)
  val generate_code: theory -> (string * string) list -> string
  val generate_code_i: theory -> (string * term) list -> string
  (*associate constants (optionally at a type instance) / type constructors
    with target-language code annotations; error on duplicates*)
  val assoc_consts: (xstring * string option * term mixfix list) list -> theory -> theory
  val assoc_consts_i: (xstring * typ option * term mixfix list) list -> theory -> theory
  val assoc_types: (xstring * typ mixfix list) list -> theory -> theory
  (*look up the annotation associated with a constant at a type / with a type constructor*)
  val get_assoc_code: theory -> string -> typ -> term mixfix list option
  val get_assoc_type: theory -> string -> typ mixfix list option
  (*invoke_codegen thy dep brack (gr, t): try the registered generators in
    turn on t, where dep names the graph node requiring the code and brack
    requests parentheses around applications; error if none applies*)
  val invoke_codegen: theory -> string -> bool ->
    (exn option * string) Graph.T * term -> (exn option * string) Graph.T * Pretty.T
  val invoke_tycodegen: theory -> string -> bool ->
    (exn option * string) Graph.T * typ -> (exn option * string) Graph.T * Pretty.T
  (*valid ML identifiers for constant / type names*)
  val mk_const_id: Sign.sg -> string -> string
  val mk_type_id: Sign.sg -> string -> string
  (*rename free and schematic variables that clash with ML reserved words*)
  val rename_term: term -> term
  (*definitional axiom for a constant at a type: ((args, rhs), index), where
    index is None iff the constant has exactly one definition*)
  val get_defn: theory -> string -> typ -> ((term list * term) * int option) option
  (*is_instance thy T1 T2: T1 is an instance of T2 (T2 is varified first)*)
  val is_instance: theory -> typ -> typ -> bool
  (*pretty-printing helpers for generated code*)
  val parens: Pretty.T -> Pretty.T
  val mk_app: bool -> Pretty.T -> Pretty.T list -> Pretty.T
  (*eta_expand t ts i: abstract over the arguments missing from t applied to ts*)
  val eta_expand: term -> term list -> int -> term
  (*parse an annotation string; the function reads the quoted parts*)
  val parse_mixfix: (string -> 'a) -> string -> 'a mixfix list
  val parsers: OuterSyntax.parser list
  val setup: (theory -> theory) list
end;
    48 
    49 structure Codegen : CODEGEN =
    50 struct
    51 
    52 val quiet_mode = ref true;
    53 fun message s = if !quiet_mode then () else writeln s;
    54 
    55 (**** Mixfix syntax ****)
    56 
(*Annotation elements: Arg and Ignore both consume one argument (Ignore
  drops it from the output), Pretty is literal text, Quote embeds a term /
  type whose code is generated and spliced in.*)
datatype 'a mixfix =
    Arg
  | Ignore
  | Pretty of Pretty.T
  | Quote of 'a;
    62 
    63 fun is_arg Arg = true
    64   | is_arg Ignore = true
    65   | is_arg _ = false;
    66 
    67 fun quotes_of [] = []
    68   | quotes_of (Quote q :: ms) = q :: quotes_of ms
    69   | quotes_of (_ :: ms) = quotes_of ms;
    70 
    71 fun args_of [] xs = ([], xs)
    72   | args_of (Arg :: ms) (x :: xs) = apfst (cons x) (args_of ms xs)
    73   | args_of (Ignore :: ms) (_ :: xs) = args_of ms xs
    74   | args_of (_ :: ms) xs = args_of ms xs;
    75 
    76 val num_args = length o filter is_arg;
    77 
    78 
    79 (**** theory data ****)
    80 
    81 (* data kind 'Pure/codegen' *)
    82 
(*A code generator takes the theory, the dependency graph (nodes carry an
  exn option and generated code text), the name of the node requiring the
  code, a flag requesting parentheses, and the term / type to translate.
  It returns None when it does not apply, so the next one can be tried.*)
type 'a codegen = theory -> (exn option * string) Graph.T ->
  string -> bool -> 'a -> ((exn option * string) Graph.T * Pretty.T) option;
    85 
    86 structure CodegenArgs =
    87 struct
    88   val name = "Pure/codegen";
    89   type T =
    90     {codegens : (string * term codegen) list,
    91      tycodegens : (string * typ codegen) list,
    92      consts : ((string * typ) * term mixfix list) list,
    93      types : (string * typ mixfix list) list};
    94 
    95   val empty = {codegens = [], tycodegens = [], consts = [], types = []};
    96   val copy = I;
    97   val prep_ext = I;
    98 
    99   fun merge ({codegens = codegens1, tycodegens = tycodegens1, consts = consts1, types = types1},
   100              {codegens = codegens2, tycodegens = tycodegens2, consts = consts2, types = types2}) =
   101     {codegens = rev (merge_alists (rev codegens1) (rev codegens2)),
   102      tycodegens = rev (merge_alists (rev tycodegens1) (rev tycodegens2)),
   103      consts   = merge_alists consts1 consts2,
   104      types    = merge_alists types1 types2};
   105 
   106   fun print sg ({codegens, tycodegens, ...} : T) =
   107     Pretty.writeln (Pretty.chunks
   108       [Pretty.strs ("term code generators:" :: map fst codegens),
   109        Pretty.strs ("type code generators:" :: map fst tycodegens)]);
   110 end;
   111 
   112 structure CodegenData = TheoryDataFun(CodegenArgs);
   113 val print_codegens = CodegenData.print;
   114 
   115 
   116 (**** add new code generators to theory ****)
   117 
   118 fun add_codegen name f thy =
   119   let val {codegens, tycodegens, consts, types} = CodegenData.get thy
   120   in (case assoc (codegens, name) of
   121       None => CodegenData.put {codegens = (name, f) :: codegens,
   122         tycodegens = tycodegens, consts = consts, types = types} thy
   123     | Some _ => error ("Code generator " ^ name ^ " already declared"))
   124   end;
   125 
   126 fun add_tycodegen name f thy =
   127   let val {codegens, tycodegens, consts, types} = CodegenData.get thy
   128   in (case assoc (tycodegens, name) of
   129       None => CodegenData.put {tycodegens = (name, f) :: tycodegens,
   130         codegens = codegens, consts = consts, types = types} thy
   131     | Some _ => error ("Code generator " ^ name ^ " already declared"))
   132   end;
   133 
   134 
   135 (**** associate constants with target language code ****)
   136 
   137 fun gen_assoc_consts prep_type xs thy = foldl (fn (thy, (s, tyopt, syn)) =>
   138   let
   139     val sg = sign_of thy;
   140     val {codegens, tycodegens, consts, types} = CodegenData.get thy;
   141     val cname = Sign.intern_const sg s;
   142   in
   143     (case Sign.const_type sg cname of
   144        Some T =>
   145          let val T' = (case tyopt of
   146                 None => T
   147               | Some ty =>
   148                   let val U = prep_type sg ty
   149                   in if Type.typ_instance (Sign.tsig_of sg, U, T) then U
   150                     else error ("Illegal type constraint for constant " ^ cname)
   151                   end)
   152          in (case assoc (consts, (cname, T')) of
   153              None => CodegenData.put {codegens = codegens,
   154                tycodegens = tycodegens,
   155                consts = ((cname, T'), syn) :: consts, types = types} thy
   156            | Some _ => error ("Constant " ^ cname ^ " already associated with code"))
   157          end
   158      | _ => error ("Not a constant: " ^ s))
   159   end) (thy, xs);
   160 
   161 val assoc_consts_i = gen_assoc_consts (K I);
   162 val assoc_consts = gen_assoc_consts (fn sg => typ_of o read_ctyp sg);
   163 
   164 (**** associate types with target language types ****)
   165 
   166 fun assoc_types xs thy = foldl (fn (thy, (s, syn)) =>
   167   let
   168     val {codegens, tycodegens, consts, types} = CodegenData.get thy;
   169     val tc = Sign.intern_tycon (sign_of thy) s
   170   in
   171     (case assoc (types, tc) of
   172        None => CodegenData.put {codegens = codegens,
   173          tycodegens = tycodegens, consts = consts,
   174          types = (tc, syn) :: types} thy
   175      | Some _ => error ("Type " ^ tc ^ " already associated with code"))
   176   end) (thy, xs);
   177 
   178 fun get_assoc_type thy s = assoc (#types (CodegenData.get thy), s);
   179 
   180 
   181 (**** retrieve definition of constant ****)
   182 
   183 fun is_instance thy T1 T2 =
   184   Type.typ_instance (Sign.tsig_of (sign_of thy), T1, Type.varifyT T2);
   185 
   186 fun get_assoc_code thy s T = apsome snd (find_first (fn ((s', T'), _) =>
   187   s = s' andalso is_instance thy T T') (#consts (CodegenData.get thy)));
   188 
   189 fun get_defn thy s T =
   190   let
   191     val axms = flat (map (Symtab.dest o #axioms o Theory.rep_theory)
   192       (thy :: Theory.ancestors_of thy));
   193     val defs = mapfilter (fn (_, t) =>
   194       (let
   195          val (lhs, rhs) = Logic.dest_equals t;
   196          val (c, args) = strip_comb lhs;
   197          val (s', T') = dest_Const c
   198        in if s=s' then Some (T', (args, rhs)) else None end) handle TERM _ =>
   199          None) axms;
   200     val i = find_index (is_instance thy T o fst) defs
   201   in
   202     if i>=0 then Some (snd (nth_elem (i, defs)),
   203       if length defs = 1 then None else Some i)
   204     else None
   205   end;
   206 
   207 
   208 (**** make valid ML identifiers ****)
   209 
   210 fun gen_mk_id kind rename sg s =
   211   let
   212     val (xs as x::_) = explode (rename (space_implode "_"
   213       (NameSpace.unpack (Sign.cond_extern sg kind s))));
   214     fun check_str [] = ""
   215       | check_str (x::xs) =
   216           (if Symbol.is_letter x orelse Symbol.is_digit x orelse x="_" then x
   217            else "_" ^ string_of_int (ord x)) ^ check_str xs
   218   in
   219     (if not (Symbol.is_letter x) then "id" else "") ^ check_str xs
   220   end;
   221 
   222 val mk_const_id = gen_mk_id Sign.constK I;
   223 val mk_type_id = gen_mk_id Sign.typeK
   224   (fn s => if s mem ThmDatabase.ml_reserved then s ^ "_type" else s);
   225 
   226 fun rename_term t =
   227   let
   228     val names = add_term_names (t, map (fst o fst o dest_Var) (term_vars t));
   229     val clash = names inter ThmDatabase.ml_reserved;
   230     val ps = clash ~~ variantlist (clash, names);
   231 
   232     fun rename (Var ((a, i), T)) = Var ((if_none (assoc (ps, a)) a, i), T)
   233       | rename (Free (a, T)) = Free (if_none (assoc (ps, a)) a, T)
   234       | rename (Abs (s, T, t)) = Abs (s, T, rename t)
   235       | rename (t $ u) = rename t $ rename u
   236       | rename t = t;
   237   in
   238     rename t
   239   end;
   240 
   241 
   242 (**** invoke suitable code generator for term / type ****)
   243 
   244 fun invoke_codegen thy dep brack (gr, t) = (case get_first
   245    (fn (_, f) => f thy gr dep brack t) (#codegens (CodegenData.get thy)) of
   246       None => error ("Unable to generate code for term:\n" ^
   247         Sign.string_of_term (sign_of thy) t ^ "\nrequired by:\n" ^
   248         commas (Graph.all_succs gr [dep]))
   249     | Some x => x);
   250 
   251 fun invoke_tycodegen thy dep brack (gr, T) = (case get_first
   252    (fn (_, f) => f thy gr dep brack T) (#tycodegens (CodegenData.get thy)) of
   253       None => error ("Unable to generate code for type:\n" ^
   254         Sign.string_of_typ (sign_of thy) T ^ "\nrequired by:\n" ^
   255         commas (Graph.all_succs gr [dep]))
   256     | Some x => x);
   257 
   258 
   259 (**** code generator for mixfix expressions ****)
   260 
   261 fun parens p = Pretty.block [Pretty.str "(", p, Pretty.str ")"];
   262 
   263 fun pretty_fn [] p = [p]
   264   | pretty_fn (x::xs) p = Pretty.str ("fn " ^ x ^ " =>") ::
   265       Pretty.brk 1 :: pretty_fn xs p;
   266 
   267 fun pretty_mixfix [] [] _ = []
   268   | pretty_mixfix (Arg :: ms) (p :: ps) qs = p :: pretty_mixfix ms ps qs
   269   | pretty_mixfix (Ignore :: ms) ps qs = pretty_mixfix ms ps qs
   270   | pretty_mixfix (Pretty p :: ms) ps qs = p :: pretty_mixfix ms ps qs
   271   | pretty_mixfix (Quote _ :: ms) ps (q :: qs) = q :: pretty_mixfix ms ps qs;
   272 
   273 
   274 (**** default code generators ****)
   275 
   276 fun eta_expand t ts i =
   277   let
   278     val (Ts, _) = strip_type (fastype_of t);
   279     val j = i - length ts
   280   in
   281     foldr (fn (T, t) => Abs ("x", T, t))
   282       (take (j, Ts), list_comb (t, ts @ map Bound (j-1 downto 0)))
   283   end;
   284 
   285 fun mk_app _ p [] = p
   286   | mk_app brack p ps = if brack then
   287        Pretty.block (Pretty.str "(" ::
   288          separate (Pretty.brk 1) (p :: ps) @ [Pretty.str ")"])
   289      else Pretty.block (separate (Pretty.brk 1) (p :: ps));
   290 
   291 fun new_names t xs = variantlist (xs,
   292   map (fst o fst o dest_Var) (term_vars t) union
   293   add_term_names (t, ThmDatabase.ml_reserved));
   294 
   295 fun new_name t x = hd (new_names t [x]);
   296 
(*Default term code generator. Handles applications whose head is a
  variable, a constant (via its associated annotation, or by turning its
  definition into an ML declaration stored in the graph) or an abstraction;
  returns None for any other head.
    gr    - dependency graph, nodes carry (exn option, generated text),
    dep   - name of the node that requires the code being generated,
    brack - whether an application must be parenthesized*)
fun default_codegen thy gr dep brack t =
  let
    val (u, ts) = strip_comb t;
    (*generate code for a list of argument terms, threading the graph*)
    fun codegens brack = foldl_map (invoke_codegen thy dep brack)
  in (case u of
      (*schematic variable: name followed by its index unless it is 0*)
      Var ((s, i), _) =>
        let val (gr', ps) = codegens true (gr, ts)
        in Some (gr', mk_app brack (Pretty.str (s ^
           (if i=0 then "" else string_of_int i))) ps)
        end

    | Free (s, _) =>
        let val (gr', ps) = codegens true (gr, ts)
        in Some (gr', mk_app brack (Pretty.str s) ps) end

    | Const (s, T) =>
      (case get_assoc_code thy s T of
         Some ms =>
           let val i = num_args ms
           (*fewer arguments than the annotation consumes: eta-expand first*)
           in if length ts < i then
               default_codegen thy gr dep brack (eta_expand u ts i)
             else
               let
                 (*ts1: arguments placed by the annotation,
                   ts2: surplus arguments applied to the result*)
                 val (ts1, ts2) = args_of ms ts;
                 val (gr1, ps1) = codegens false (gr, ts1);
                 val (gr2, ps2) = codegens true (gr1, ts2);
                 val (gr3, ps3) = codegens false (gr2, quotes_of ms);
               in
                 Some (gr3, mk_app brack (Pretty.block (pretty_mixfix ms ps1 ps3)) ps2)
               end
           end
       | None => (case get_defn thy s T of
           None => None
         | Some ((args, rhs), k) =>
             let
               (*ML name of the definition, suffixed with its index when the
                 constant has several definitions*)
               val id = mk_const_id (sign_of thy) s ^ (case k of
                 None => "" | Some i => "_def" ^ string_of_int i);
               val (gr', ps) = codegens true (gr, ts);
             in
               (*If adding the edge id -> dep succeeds, only the dependency is
                 recorded. Graph.UNDEF presumably means no node for id exists
                 yet, in which case the declaration for id is generated.*)
               Some (Graph.add_edge (id, dep) gr' handle Graph.UNDEF _ =>
                 let
                   val _ = message ("expanding definition of " ^ s);
                   val (Ts, _) = strip_type T;
                   (*a definition without explicit arguments but of function
                     type is eta-expanded by one argument, so it becomes a fun*)
                   val (args', rhs') =
                     if not (null args) orelse null Ts then (args, rhs) else
                       let val v = Free (new_name rhs "x", hd Ts)
                       in ([v], betapply (rhs, v)) end;
                   (*the node for id is created (with empty text) before its
                     right-hand side is generated, with id as the requiring node*)
                   val (gr1, p) = invoke_codegen thy id false
                     (Graph.add_edge (id, dep)
                        (Graph.new_node (id, (None, "")) gr'), rhs');
                   val (gr2, xs) = codegens false (gr1, args');
                 (*store "val id = ...;" or "fun id xs = ...;" as the node's text*)
                 in Graph.map_node id (K (None, Pretty.string_of (Pretty.block
                   (Pretty.str (if null args' then "val " else "fun ") ::
                    separate (Pretty.brk 1) (Pretty.str id :: xs) @
                    [Pretty.str " =", Pretty.brk 1, p, Pretty.str ";"])) ^ "\n\n")) gr2
                 end, mk_app brack (Pretty.str id) ps)
             end))

    (*abstraction: printed as nested ML "fn" with fresh names that avoid
      names in the body and ML reserved words*)
    | Abs _ =>
      let
        val (bs, Ts) = ListPair.unzip (strip_abs_vars u);
        val t = strip_abs_body u
        val bs' = new_names t bs;
        val (gr1, ps) = codegens true (gr, ts);
        val (gr2, p) = invoke_codegen thy dep false
          (gr1, subst_bounds (map Free (rev (bs' ~~ Ts)), t));
      in
        Some (gr2, mk_app brack (Pretty.block (Pretty.str "(" :: pretty_fn bs' p @
          [Pretty.str ")"])) ps)
      end

    | _ => None)
  end;
   370 
   371 fun default_tycodegen thy gr dep brack (TVar ((s, i), _)) =
   372       Some (gr, Pretty.str (s ^ (if i = 0 then "" else string_of_int i)))
   373   | default_tycodegen thy gr dep brack (TFree (s, _)) = Some (gr, Pretty.str s)
   374   | default_tycodegen thy gr dep brack (Type (s, Ts)) =
   375       (case assoc (#types (CodegenData.get thy), s) of
   376          None => None
   377        | Some ms =>
   378            let
   379              val (gr', ps) = foldl_map
   380                (invoke_tycodegen thy dep false) (gr, fst (args_of ms Ts));
   381              val (gr'', qs) = foldl_map
   382                (invoke_tycodegen thy dep false) (gr', quotes_of ms)
   383            in Some (gr'', Pretty.block (pretty_mixfix ms ps qs)) end);
   384 
   385 
   386 fun output_code gr xs = implode (map (snd o Graph.get_node gr)
   387   (rev (Graph.all_preds gr xs)));
   388 
   389 fun gen_generate_code prep_term thy = Pretty.setmp_margin 80 (fn xs =>
   390   let
   391     val sg = sign_of thy;
   392     val gr = Graph.new_node ("<Top>", (None, "")) Graph.empty;
   393     val (gr', ps) = foldl_map (fn (gr, (s, t)) => apsnd (pair s)
   394       (invoke_codegen thy "<Top>" false (gr, t)))
   395         (gr, map (apsnd (prep_term sg)) xs)
   396   in
   397     "structure Generated =\nstruct\n\n" ^
   398     output_code gr' ["<Top>"] ^
   399     space_implode "\n\n" (map (fn (s', p) => Pretty.string_of (Pretty.block
   400       [Pretty.str ("val " ^ s' ^ " ="), Pretty.brk 1, p, Pretty.str ";"])) ps) ^
   401     "\n\nend;\n\nopen Generated;\n"
   402   end);
   403 
   404 val generate_code_i = gen_generate_code (K I);
   405 val generate_code = gen_generate_code
   406   (fn sg => term_of o read_cterm sg o rpair TypeInfer.logicT);
   407 
   408 
   409 (**** Interface ****)
   410 
(*Parse a code annotation string:
    "_"           -> Arg,
    "?"           -> Ignore,
    "/" + spaces  -> a break whose width is the number of spaces,
    "{* ... *}"   -> Quote of the enclosed text, read with rd,
    other text    -> literal Pretty.str;
  "'" escapes the following symbol. Error if the whole string is not consumed.*)
fun parse_mixfix rd s =
  (case Scan.finite Symbol.stopper (Scan.repeat
     (   $$ "_" >> K Arg
      || $$ "?" >> K Ignore
      || $$ "/" |-- Scan.repeat ($$ " ") >> (Pretty o Pretty.brk o length)
      || $$ "{" |-- $$ "*" |-- Scan.repeat1
           (   $$ "'" |-- Scan.one Symbol.not_eof
            || Scan.unless ($$ "*" -- $$ "}") (Scan.one Symbol.not_eof)) --|
         $$ "*" --| $$ "}" >> (Quote o rd o implode)
      (*plain text stops before any symbol that starts another element*)
      || Scan.repeat1
           (   $$ "'" |-- Scan.one Symbol.not_eof
            || Scan.unless ($$ "_" || $$ "?" || $$ "/" || $$ "{" |-- $$ "*")
                 (Scan.one Symbol.not_eof)) >> (Pretty o Pretty.str o implode)))
       (Symbol.explode s) of
     (p, []) => p
   | _ => error ("Malformed annotation: " ^ quote s));
   427 
(*abbreviations for the outer syntax parsers below*)
structure P = OuterParse and K = OuterSyntax.Keyword;
   429 
   430 val assoc_typeP =
   431   OuterSyntax.command "types_code"
   432   "associate types with target language types" K.thy_decl
   433     (Scan.repeat1 (P.xname --| P.$$$ "(" -- P.string --| P.$$$ ")") >>
   434      (fn xs => Toplevel.theory (fn thy => assoc_types
   435        (map (fn (name, mfx) => (name, parse_mixfix
   436          (typ_of o read_ctyp (sign_of thy)) mfx)) xs) thy)));
   437 
   438 val assoc_constP =
   439   OuterSyntax.command "consts_code"
   440   "associate constants with target language code" K.thy_decl
   441     (Scan.repeat1
   442        (P.xname -- (Scan.option (P.$$$ "::" |-- P.string)) --|
   443         P.$$$ "(" -- P.string --| P.$$$ ")") >>
   444      (fn xs => Toplevel.theory (fn thy => assoc_consts
   445        (map (fn ((name, optype), mfx) => (name, optype, parse_mixfix
   446          (term_of o read_cterm (sign_of thy) o rpair TypeInfer.logicT) mfx))
   447            xs) thy)));
   448 
   449 val generate_codeP =
   450   OuterSyntax.command "generate_code" "generates code for terms" K.thy_decl
   451     (Scan.option (P.$$$ "(" |-- P.string --| P.$$$ ")") --
   452      Scan.repeat1 (P.name --| P.$$$ "=" -- P.string) >>
   453      (fn (opt_fname, xs) => Toplevel.theory (fn thy =>
   454         ((case opt_fname of
   455             None => use_text Context.ml_output false
   456           | Some fname => File.write (Path.unpack fname))
   457               (generate_code thy xs); thy))));
   458 
(*outer syntax commands types_code, consts_code and generate_code*)
val parsers = [assoc_typeP, assoc_constP, generate_codeP];
   460 
(*initialize the theory data, register the default term and type code
  generators, and map the type constructor "fun" to the ML arrow type*)
val setup =
  [CodegenData.init,
   add_codegen "default" default_codegen,
   add_tycodegen "default" default_tycodegen,
   assoc_types [("fun", parse_mixfix (K dummyT) "(_ ->/ _)")]];
   466 
   467 end;
   468 
(*make the code generator commands available in the outer syntax*)
OuterSyntax.add_parsers Codegen.parsers;