src/Pure/codegen.ML
author wenzelm
Fri Mar 01 22:32:38 2002 +0100 (2002-03-01)
changeset 13003 3d5807d45439
parent 12824 cdf586d56b8a
child 13073 cc9d7f403a4b
permissions -rw-r--r--
clarified outer syntax;
     1 (*  Title:      Pure/codegen.ML
     2     ID:         $Id$
     3     Author:     Stefan Berghofer, TU Muenchen
     4     License:    GPL (GNU GENERAL PUBLIC LICENSE)
     5 
     6 Generic code generator.
     7 *)
     8 
     9 signature CODEGEN =
       (*Interface of the generic code generator: diagnostic output, code templates
         ("mixfix" lists), theory data holding pluggable term/type code generators and
         template associations, helper functions for writing code generators, and the
         outer-syntax parsers / theory setup built from them.*)
    10 sig
       (*message prints its argument only while quiet_mode is false*)
    11   val quiet_mode : bool ref
    12   val message : string -> unit
    13 
       (*elements of a code template, as produced by parse_mixfix*)
    14   datatype 'a mixfix =
    15       Arg
    16     | Ignore
    17     | Pretty of Pretty.T
    18     | Quote of 'a;
    19 
       (*a code generator for items of type 'a (terms or types)*)
    20   type 'a codegen
    21 
       (*registration in the theory data; each of these fails with an error if the
         given name is already declared*)
    22   val add_codegen: string -> term codegen -> theory -> theory
    23   val add_tycodegen: string -> typ codegen -> theory -> theory
    24   val add_attribute: string -> theory attribute -> theory -> theory
    25   val print_codegens: theory -> unit
       (*produce the text of an ML structure "Generated" for named terms;
         generate_code reads the terms from strings, generate_code_i takes them as is*)
    26   val generate_code: theory -> (string * string) list -> string
    27   val generate_code_i: theory -> (string * term) list -> string
       (*associate constants / type constructors with code templates and look them up*)
    28   val assoc_consts: (xstring * string option * term mixfix list) list -> theory -> theory
    29   val assoc_consts_i: (xstring * typ option * term mixfix list) list -> theory -> theory
    30   val assoc_types: (xstring * typ mixfix list) list -> theory -> theory
    31   val get_assoc_code: theory -> string -> typ -> term mixfix list option
    32   val get_assoc_type: theory -> string -> typ mixfix list option
       (*run the registered generators in turn; an error is raised if none applies*)
    33   val invoke_codegen: theory -> string -> bool ->
    34     (exn option * string) Graph.T * term -> (exn option * string) Graph.T * Pretty.T
    35   val invoke_tycodegen: theory -> string -> bool ->
    36     (exn option * string) Graph.T * typ -> (exn option * string) Graph.T * Pretty.T
       (*helpers for implementing code generators*)
    37   val mk_const_id: Sign.sg -> string -> string
    38   val mk_type_id: Sign.sg -> string -> string
    39   val rename_term: term -> term
    40   val get_defn: theory -> string -> typ -> ((term list * term) * int option) option
    41   val is_instance: theory -> typ -> typ -> bool
    42   val parens: Pretty.T -> Pretty.T
    43   val mk_app: bool -> Pretty.T -> Pretty.T list -> Pretty.T
    44   val eta_expand: term -> term list -> int -> term
    45   val parse_mixfix: (string -> 'a) -> string -> 'a mixfix list
       (*outer-syntax commands and theory setup*)
    46   val parsers: OuterSyntax.parser list
    47   val setup: (theory -> theory) list
    48 end;
    49 
    50 structure Codegen : CODEGEN =
    51 struct
    52 
    (*Diagnostic messages are suppressed as long as quiet_mode holds true.*)
    val quiet_mode = ref true;

    (*Print s via writeln unless quiet_mode is set.*)
    fun message s =
      (case !quiet_mode of
         true => ()
       | false => writeln s);
    55 
    56 (**** Mixfix syntax ****)
    57 
    58 datatype 'a mixfix =
       (*One element of a code template (see parse_mixfix and args_of):
           Arg     - "_" in the template; takes the next argument
           Ignore  - "?" in the template; drops the next argument
           Pretty  - literal output (text or a break)
           Quote   - material between "{*" and "*}", converted by the reader
                     function given to parse_mixfix*)
    59     Arg
    60   | Ignore
    61   | Pretty of Pretty.T
    62   | Quote of 'a;
    63 
    (*True for exactly those template elements that consume an argument.*)
    fun is_arg m =
      (case m of
         Arg => true
       | Ignore => true
       | _ => false);
    67 
    (*Collect the payloads of all Quote elements, in template order.*)
    fun quotes_of ms = mapfilter (fn Quote q => Some q | _ => None) ms;
    71 
    (*Walk template ms along the argument list xs: Arg keeps the next argument,
      Ignore drops it, anything else (or an exhausted argument list) just moves on
      in the template.  Result: (kept arguments, arguments left over).*)
    fun args_of ms xs =
      (case (ms, xs) of
         ([], _) => ([], xs)
       | (Arg :: ms', x :: xs') =>
           let val (kept, rest) = args_of ms' xs'
           in (x :: kept, rest) end
       | (Ignore :: ms', _ :: xs') => args_of ms' xs'
       | (_ :: ms', _) => args_of ms' xs);
    76 
    (*Number of argument-consuming elements (Arg or Ignore) in a template.*)
    fun num_args ms =
      foldl (fn (n, m) => if is_arg m then n + 1 else n) (0, ms);
    78 
    79 
    80 (**** theory data ****)
    81 
    82 (* data kind 'Pure/codegen' *)
    83 
    84 type 'a codegen = theory -> (exn option * string) Graph.T ->
    85   string -> bool -> 'a -> ((exn option * string) Graph.T * Pretty.T) option;
       (*A code generator receives the theory, the graph of code generated so far
         (output_code concatenates the string component of its nodes), the name of
         the graph node on whose behalf code is requested, a flag that mk_app uses to
         decide whether an application is put in parentheses, and the item itself.
         It returns None if it is not applicable, so that the next registered
         generator is tried (see invoke_codegen / invoke_tycodegen).*)
    86 
    (*Argument structure for TheoryDataFun: the code generator tables kept in
      every theory.*)
    structure CodegenArgs =
    struct
      val name = "Pure/codegen";

      type T =
        {codegens : (string * term codegen) list,
         tycodegens : (string * typ codegen) list,
         consts : ((string * typ) * term mixfix list) list,
         types : (string * typ mixfix list) list,
         attrs: (string * theory attribute) list};

      val empty =
        {codegens = [], tycodegens = [], consts = [], types = [], attrs = []};
      val copy = I;
      val prep_ext = I;

      (*Field-wise merge of two tables.  The two generator lists are merged on
        their reversals and the result is reversed again; the other tables are
        merged directly.*)
      fun merge (data1 : T, data2 : T) =
        let
          fun merge_rev (xs, ys) = rev (merge_alists (rev xs) (rev ys));
        in
          {codegens = merge_rev (#codegens data1, #codegens data2),
           tycodegens = merge_rev (#tycodegens data1, #tycodegens data2),
           consts = merge_alists (#consts data1) (#consts data2),
           types = merge_alists (#types data1) (#types data2),
           attrs = merge_alists (#attrs data1) (#attrs data2)}
        end;

      (*Print the names of all registered term and type code generators.*)
      fun print sg (data : T) =
        let
          fun names_of gens = map fst gens;
          val lines =
            [Pretty.strs ("term code generators:" :: names_of (#codegens data)),
             Pretty.strs ("type code generators:" :: names_of (#tycodegens data))];
        in Pretty.writeln (Pretty.chunks lines) end;
    end;
   118 
   (*Theory data instance for the tables declared in CodegenArgs.*)
   structure CodegenData = TheoryDataFun(CodegenArgs);

   (*Print the registered code generators of a theory.*)
   fun print_codegens thy = CodegenData.print thy;
   121 
   122 
   123 (**** add new code generators to theory ****)
   124 
   (*Register term code generator f under name, in front of the existing ones;
     a name that is already taken is an error.*)
   fun add_codegen name f thy =
     let
       val data = CodegenData.get thy;
       val known = #codegens data;
     in
       (case assoc (known, name) of
          Some _ => error ("Code generator " ^ name ^ " already declared")
        | None =>
            CodegenData.put
              {codegens = (name, f) :: known,
               tycodegens = #tycodegens data,
               consts = #consts data,
               types = #types data,
               attrs = #attrs data} thy)
     end;
   133 
   (*Register type code generator f under name, in front of the existing ones;
     a name that is already taken is an error.*)
   fun add_tycodegen name f thy =
     let
       val data = CodegenData.get thy;
       val known = #tycodegens data;
     in
       (case assoc (known, name) of
          Some _ => error ("Code generator " ^ name ^ " already declared")
        | None =>
            CodegenData.put
              {codegens = #codegens data,
               tycodegens = (name, f) :: known,
               consts = #consts data,
               types = #types data,
               attrs = #attrs data} thy)
     end;
   142 
   143 
   144 (**** code attribute ****)
   145 
   (*Register a named variant of the "code" attribute; a name that is already
     taken is an error.*)
   fun add_attribute name att thy =
     let
       val data = CodegenData.get thy;
       val known = #attrs data;
     in
       (case assoc (known, name) of
          Some _ => error ("Code attribute " ^ name ^ " already declared")
        | None =>
            CodegenData.put
              {codegens = #codegens data,
               tycodegens = #tycodegens data,
               consts = #consts data,
               types = #types data,
               attrs = (name, att) :: known} thy)
     end;
   154 
   (*Attribute syntax: an optional name (default "") selects one of the
     attributes registered through add_attribute; unknown names are an error.*)
   val code_attr =
     let
       fun lookup thy name =
         (case assoc (#attrs (CodegenData.get thy), name) of
            Some att => att
          | None => error ("Unknown code attribute: " ^ quote name));
     in
       Attrib.syntax (Scan.depend (fn thy =>
         Scan.optional Args.name "" >> (fn name => (thy, lookup thy name))))
     end;
   160 
   161 
   162 (**** associate constants with target language code ****)
   163 
   (*Associate each constant in xs with a code template.  An entry is
     (name, optional type constraint, template); prep_type turns the constraint
     into a typ, which must be an instance of the declared type of the constant.
     The template is stored under (internal name, constrained or declared type).*)
   fun gen_assoc_consts prep_type xs thy =
     let
       fun add_one (thy', (s, tyopt, syn)) =
         let
           val sg = sign_of thy';
           val data = CodegenData.get thy';
           val cname = Sign.intern_const sg s;

           (*type under which the template is registered, given declared type T*)
           fun constrain T =
             (case tyopt of
                None => T
              | Some ty =>
                  let val U = prep_type sg ty in
                    if Type.typ_instance (Sign.tsig_of sg, U, T) then U
                    else error ("Illegal type constraint for constant " ^ cname)
                  end);
         in
           (case Sign.const_type sg cname of
              Some T =>
                let val T' = constrain T in
                  (case assoc (#consts data, (cname, T')) of
                     Some _ => error ("Constant " ^ cname ^ " already associated with code")
                   | None =>
                       CodegenData.put
                         {codegens = #codegens data,
                          tycodegens = #tycodegens data,
                          consts = ((cname, T'), syn) :: #consts data,
                          types = #types data,
                          attrs = #attrs data} thy')
                end
            | _ => error ("Not a constant: " ^ s))
         end;
     in foldl add_one (thy, xs) end;
   188 
   (*Variant taking type constraints as typ values, used unchanged.*)
   fun assoc_consts_i xs = gen_assoc_consts (fn _ => fn T => T) xs;

   (*Variant reading type constraints from strings.*)
   fun assoc_consts xs =
     gen_assoc_consts (fn sg => fn s => typ_of (read_ctyp sg s)) xs;
   191 
   192 (**** associate types with target language types ****)
   193 
   (*Associate each type constructor in xs with a code template; a constructor
     that already has a template is an error.*)
   fun assoc_types xs thy =
     let
       fun add_one (thy', (s, syn)) =
         let
           val data = CodegenData.get thy';
           val tc = Sign.intern_tycon (sign_of thy') s;
         in
           (case assoc (#types data, tc) of
              Some _ => error ("Type " ^ tc ^ " already associated with code")
            | None =>
                CodegenData.put
                  {codegens = #codegens data,
                   tycodegens = #tycodegens data,
                   consts = #consts data,
                   types = (tc, syn) :: #types data,
                   attrs = #attrs data} thy')
         end;
     in foldl add_one (thy, xs) end;
   205 
   (*Template associated with type constructor s, if any.*)
   fun get_assoc_type thy s =
     let val {types, ...} = CodegenData.get thy
     in assoc (types, s) end;
   207 
   208 
   209 (**** retrieve definition of constant ****)
   210 
   (*Is T1 an instance of T2, after T2 has been passed through Type.varifyT?*)
   fun is_instance thy T1 T2 =
     let
       val tsig = Sign.tsig_of (sign_of thy);
       val T2' = Type.varifyT T2;
     in Type.typ_instance (tsig, T1, T2') end;
   213 
   (*Template of the first association for constant s whose registered type has
     T as an instance, if any.*)
   fun get_assoc_code thy s T =
     let
       fun matches ((s', T'), _) = (s = s') andalso is_instance thy T T';
     in
       (case find_first matches (#consts (CodegenData.get thy)) of
          None => None
        | Some (_, ms) => Some ms)
     end;
   216 
   (*Look for a defining equation of constant s at type T among the axioms of thy
     and all its ancestors.  Candidates are the axioms of the form "lhs == rhs"
     whose lhs has head constant s.  The result is ((args of lhs, rhs), k) for the
     first candidate whose constant type has T as an instance; k is None if there
     is exactly one candidate, otherwise the position of the chosen one.*)
   fun get_defn thy s T =
     let
       val thys = thy :: Theory.ancestors_of thy;
       val axms =
         flat (map (fn th => Symtab.dest (#axioms (Theory.rep_theory th))) thys);

       (*Some (type of head, (args, rhs)) if the axiom is an equation for s*)
       fun dest_def (_, t) =
         (let
            val (lhs, rhs) = Logic.dest_equals t;
            val (c, args) = strip_comb lhs;
            val (s', T') = dest_Const c;
          in if s = s' then Some (T', (args, rhs)) else None end)
         handle TERM _ => None;

       val defs = mapfilter dest_def axms;
       val i = find_index (fn (T', _) => is_instance thy T T') defs;
     in
       if i < 0 then None
       else
         Some (snd (nth_elem (i, defs)),
           if length defs = 1 then None else Some i)
     end;
   234 
   235 
   236 (**** make valid ML identifiers ****)
   237 
   (*Turn the external form of name s (of the given kind) into an ML identifier:
     the name components are joined with "_" and passed through rename; then
     blanks become "_", letters/digits are kept, and any other character is
     replaced by "_" followed by its character code.  The result is prefixed with
     "id" if the first character is not a letter.
     NOTE: the pattern binding below fails on an empty name.*)
   fun gen_mk_id kind rename sg s =
     let
       val ext = Sign.cond_extern sg kind s;
       val joined = rename (space_implode "_" (NameSpace.unpack ext));
       val (cs as c :: _) = explode joined;

       fun conv " " = "_"
         | conv ch =
             if Symbol.is_letdig ch then ch
             else "_" ^ string_of_int (ord ch);

       val prefix = if Symbol.is_letter c then "" else "id";
     in prefix ^ implode (map conv cs) end;
   250 
   (*ML identifier for a constant.  Names that coincide with an ML reserved word
     get the suffix "_const" (analogous to the "_type" suffix used for type
     identifiers); otherwise default_codegen would emit declarations such as
     "fun val ..." that the ML compiler rejects.*)
   val mk_const_id = gen_mk_id Sign.constK
     (fn s => if s mem ThmDatabase.ml_reserved then s ^ "_const" else s);
   (*ML identifier for a type constructor; names that coincide with an ML
     reserved word get the suffix "_type".*)
   val mk_type_id =
     let
       fun avoid_reserved s =
         if s mem ThmDatabase.ml_reserved then s ^ "_type" else s;
     in gen_mk_id Sign.typeK avoid_reserved end;
   254 
   (*Rename those Var and Free names of t that clash with ML reserved words to
     fresh variants; names of bound variables are left untouched.*)
   fun rename_term t =
     let
       val var_names = map (fn v => fst (fst (dest_Var v))) (term_vars t);
       val names = add_term_names (t, var_names);
       val clash = names inter ThmDatabase.ml_reserved;
       val ps = clash ~~ variantlist (clash, names);

       (*replacement for name a, or a itself if it does not clash*)
       fun subst a = if_none (assoc (ps, a)) a;

       fun ren (Var ((a, i), T)) = Var ((subst a, i), T)
         | ren (Free (a, T)) = Free (subst a, T)
         | ren (Abs (s, T, u)) = Abs (s, T, ren u)
         | ren (u $ v) = ren u $ ren v
         | ren u = u;
     in ren t end;
   269 
   270 
   271 (**** invoke suitable code generator for term / type ****)
   272 
   (*Try the registered term code generators in order and return the result of
     the first one that succeeds; if none does, report the term together with
     Graph.all_succs of node dep.*)
   fun invoke_codegen thy dep brack (gr, t) =
     let
       val gens = #codegens (CodegenData.get thy);
       fun try_gen (_, f) = f thy gr dep brack t;
     in
       (case get_first try_gen gens of
          Some res => res
        | None =>
            error ("Unable to generate code for term:\n" ^
              Sign.string_of_term (sign_of thy) t ^ "\nrequired by:\n" ^
              commas (Graph.all_succs gr [dep])))
     end;
   279 
   (*Try the registered type code generators in order and return the result of
     the first one that succeeds; if none does, report the type together with
     Graph.all_succs of node dep.*)
   fun invoke_tycodegen thy dep brack (gr, T) =
     let
       val gens = #tycodegens (CodegenData.get thy);
       fun try_gen (_, f) = f thy gr dep brack T;
     in
       (case get_first try_gen gens of
          Some res => res
        | None =>
            error ("Unable to generate code for type:\n" ^
              Sign.string_of_typ (sign_of thy) T ^ "\nrequired by:\n" ^
              commas (Graph.all_succs gr [dep])))
     end;
   286 
   287 
   288 (**** code generator for mixfix expressions ****)
   289 
   (*Enclose p in round parentheses.*)
   fun parens p =
     let
       val opn = Pretty.str "(";
       val cls = Pretty.str ")";
     in Pretty.block [opn, p, cls] end;
   291 
   (*Pretty elements for "fn x1 => ... fn xn => p", with a break after each "=>".*)
   fun pretty_fn xs p =
     foldr (fn (x, rest) => Pretty.str ("fn " ^ x ^ " =>") :: Pretty.brk 1 :: rest)
       (xs, [p]);
   295 
   (*Instantiate template ms: each Arg is replaced by the next element of ps,
     each Quote by the next element of qs, Pretty elements are copied and Ignore
     elements are dropped.  The match is deliberately partial (as before): it
     fails if ps or qs runs out, or if elements of ps are left over.*)
   fun pretty_mixfix ms ps qs =
     (case (ms, ps, qs) of
        ([], [], _) => []
      | (Arg :: ms', p :: ps', _) => p :: pretty_mixfix ms' ps' qs
      | (Ignore :: ms', _, _) => pretty_mixfix ms' ps qs
      | (Pretty p :: ms', _, _) => p :: pretty_mixfix ms' ps qs
      | (Quote _ :: ms', _, q :: qs') => q :: pretty_mixfix ms' ps qs');
   301 
   302 
   303 (**** default code generators ****)
   304 
   (*Eta-expand the application of head t to arguments ts so that t receives i
     arguments in total: the j = i - length ts missing arguments are supplied as
     bound variables under j new abstractions.
     The abstraction types are the argument types of t that FOLLOW the ones
     already consumed by ts; the previous version took the first j argument types
     of t, which is wrong whenever ts is non-empty.*)
   fun eta_expand t ts i =
     let
       val k = length ts;
       val (Ts, _) = strip_type (fastype_of t);
       val j = i - k;
       (*types of the j arguments still missing after the k supplied ones*)
       val Us = take (j, drop (k, Ts));
     in
       foldr (fn (T, u) => Abs ("x", T, u))
         (Us, list_comb (t, ts @ map Bound (j-1 downto 0)))
     end;
   313 
   (*Application of p to arguments ps, separated by breaks; it is put in
     parentheses iff brack is true.  Without arguments p is returned as is.*)
   fun mk_app _ p [] = p
     | mk_app brack p ps =
         let val body = separate (Pretty.brk 1) (p :: ps) in
           if brack then Pretty.block (Pretty.str "(" :: body @ [Pretty.str ")"])
           else Pretty.block body
         end;
   319 
   (*Variants of the names xs that avoid the Var names of t, the names collected
     by add_term_names for t, and the ML reserved words.*)
   fun new_names t xs =
     let
       val var_names = map (fst o fst o dest_Var) (term_vars t);
       val used = var_names union add_term_names (t, ThmDatabase.ml_reserved);
     in variantlist (xs, used) end;
   323 
   (*Single-name version of new_names.*)
   fun new_name t x =
     let val ys = new_names t [x]
     in hd ys end;
   325 
   326 fun default_codegen thy gr dep brack t =
       (*Default term code generator.  t is split into head u and arguments ts;
         the head decides what is generated:
           Var / Free  - the variable name applied to the code of the arguments
           Const       - a registered code template if there is one, otherwise a
                         function generated from a defining equation (get_defn)
           Abs         - an ML "fn" abstraction
         Any other head, or a constant with neither template nor definition,
         yields None.*)
   327   let
   328     val (u, ts) = strip_comb t;
           (*generate code for a list of terms, threading the graph through*)
   329     fun codegens brack = foldl_map (invoke_codegen thy dep brack)
   330   in (case u of
           (*schematic variable: name followed by its index unless that is 0*)
   331       Var ((s, i), _) =>
   332         let val (gr', ps) = codegens true (gr, ts)
   333         in Some (gr', mk_app brack (Pretty.str (s ^
   334            (if i=0 then "" else string_of_int i))) ps)
   335         end
   336 
   337     | Free (s, _) =>
   338         let val (gr', ps) = codegens true (gr, ts)
   339         in Some (gr', mk_app brack (Pretty.str s) ps) end
   340 
   341     | Const (s, T) =>
   342       (case get_assoc_code thy s T of
   343          Some ms =>
   344            let val i = num_args ms
                  (*fewer arguments than the template consumes: eta-expand and retry*)
   345            in if length ts < i then
   346                default_codegen thy gr dep brack (eta_expand u ts i)
   347              else
   348                let
                      (*ts1: arguments placed into the template, generated without
                        parentheses; ts2: remaining arguments, to which the
                        instantiated template is applied; quoted template parts are
                        generated as terms, also without parentheses*)
   349                  val (ts1, ts2) = args_of ms ts;
   350                  val (gr1, ps1) = codegens false (gr, ts1);
   351                  val (gr2, ps2) = codegens true (gr1, ts2);
   352                  val (gr3, ps3) = codegens false (gr2, quotes_of ms);
   353                in
   354                  Some (gr3, mk_app brack (Pretty.block (pretty_mixfix ms ps1 ps3)) ps2)
   355                end
   356            end
   357        | None => (case get_defn thy s T of
   358            None => None
   359          | Some ((args, rhs), k) =>
   360              let
                    (*node name of the constant; "_def" ^ k distinguishes between
                      several candidate definitions*)
   361                val id = mk_const_id (sign_of thy) s ^ (case k of
   362                  None => "" | Some i => "_def" ^ string_of_int i);
   363                val (gr', ps) = codegens true (gr, ts);
   364              in
                    (*adding the edge (id, dep) raises Graph.UNDEF if there is no node
                      id yet; only then is code for the definition generated*)
   365                Some (Graph.add_edge (id, dep) gr' handle Graph.UNDEF _ =>
   366                  let
   367                    val _ = message ("expanding definition of " ^ s);
   368                    val (Ts, _) = strip_type T;
                          (*a definition without arguments whose type has argument
                            types is applied to a fresh variable "x"*)
   369                    val (args', rhs') =
   370                      if not (null args) orelse null Ts then (args, rhs) else
   371                        let val v = Free (new_name rhs "x", hd Ts)
   372                        in ([v], betapply (rhs, v)) end;
                          (*the node is created with empty code and linked to dep before
                            the right-hand side is generated on behalf of node id*)
   373                    val (gr1, p) = invoke_codegen thy id false
   374                      (Graph.add_edge (id, dep)
   375                         (Graph.new_node (id, (None, "")) gr'), rhs');
   376                    val (gr2, xs) = codegens false (gr1, args');
   377                    val (gr3, ty) = invoke_tycodegen thy id false (gr2, T);
                        (*node text: "val id : ty = rhs;" if there are no arguments,
                          otherwise "fun id args = rhs;"*)
   378                  in Graph.map_node id (K (None, Pretty.string_of (Pretty.block
   379                    (separate (Pretty.brk 1) (if null args' then
   380                        [Pretty.str ("val " ^ id ^ " :"), ty]
   381                      else Pretty.str ("fun " ^ id) :: xs) @
   382                     [Pretty.str " =", Pretty.brk 1, p, Pretty.str ";"])) ^ "\n\n")) gr3
   383                  end, mk_app brack (Pretty.str id) ps)
   384              end))
   385 
   386     | Abs _ =>
   387       let
               (*strip all leading abstractions, choose fresh names for the bound
                 variables and substitute them as Frees into the body*)
   388         val (bs, Ts) = ListPair.unzip (strip_abs_vars u);
   389         val t = strip_abs_body u
   390         val bs' = new_names t bs;
   391         val (gr1, ps) = codegens true (gr, ts);
   392         val (gr2, p) = invoke_codegen thy dep false
   393           (gr1, subst_bounds (map Free (rev (bs' ~~ Ts)), t));
   394       in
   395         Some (gr2, mk_app brack (Pretty.block (Pretty.str "(" :: pretty_fn bs' p @
   396           [Pretty.str ")"])) ps)
   397       end
   398 
   399     | _ => None)
   400   end;
   401 
   (*Default type code generator.  Type variables are printed by name (schematic
     ones followed by their index unless that is 0).  A type constructor is
     translated through its registered template; None is returned if it has none.
     A template with more argument slots than the type has arguments used to
     surface as an anonymous Match exception from pretty_mixfix; it is now
     reported as a proper error naming the type.*)
   fun default_tycodegen thy gr dep brack (TVar ((s, i), _)) =
         Some (gr, Pretty.str (s ^ (if i = 0 then "" else string_of_int i)))
     | default_tycodegen thy gr dep brack (TFree (s, _)) = Some (gr, Pretty.str s)
     | default_tycodegen thy gr dep brack (Type (s, Ts)) =
         (case assoc (#types (CodegenData.get thy), s) of
            None => None
          | Some ms =>
              let
                val _ =
                  if length Ts < num_args ms then
                    error ("Code template for type " ^ quote s ^ " expects " ^
                      string_of_int (num_args ms) ^ " argument(s), but the type has only " ^
                      string_of_int (length Ts))
                  else ();
                (*template arguments first, then the quoted types of the template*)
                val (gr', ps) = foldl_map
                  (invoke_tycodegen thy dep false) (gr, fst (args_of ms Ts));
                val (gr'', qs) = foldl_map
                  (invoke_tycodegen thy dep false) (gr', quotes_of ms)
              in Some (gr'', Pretty.block (pretty_mixfix ms ps qs)) end);
   415 
   416 
   (*Concatenate the code strings of the nodes returned by Graph.all_preds for
     xs, taken in reverse order.*)
   fun output_code gr xs =
     let
       val nodes = rev (Graph.all_preds gr xs);
       fun code_of x = snd (Graph.get_node gr x);
     in implode (map code_of nodes) end;
   419 
   (*Generate an ML structure "Generated" for the named terms xs (prepared by
     prep_term): the code collected in the graph comes first, followed by one
     "val name = code;" binding per term.  Runs under a pretty-printing margin
     of 80.*)
   fun gen_generate_code prep_term thy =
     let
       fun generate xs =
         let
           val sg = sign_of thy;
           val top = "<Top>";
           val gr = Graph.new_node (top, (None, "")) Graph.empty;

           fun gen_one (gr, (s, t)) =
             let val (gr', p) = invoke_codegen thy top false (gr, t)
             in (gr', (s, p)) end;

           val (gr', ps) =
             foldl_map gen_one (gr, map (fn (s, t) => (s, prep_term sg t)) xs);

           fun pretty_val (s', p) =
             Pretty.string_of (Pretty.block
               [Pretty.str ("val " ^ s' ^ " ="), Pretty.brk 1, p, Pretty.str ";"]);
         in
           "structure Generated =\nstruct\n\n" ^
           output_code gr' [top] ^
           space_implode "\n\n" (map pretty_val ps) ^
           "\n\nend;\n\nopen Generated;\n"
         end;
     in Pretty.setmp_margin 80 generate end;
   434 
   (*Variant taking the terms as they are.*)
   fun generate_code_i thy = gen_generate_code (fn _ => fn t => t) thy;

   (*Variant reading the terms from strings.*)
   fun generate_code thy =
     gen_generate_code
       (fn sg => fn s => term_of (read_cterm sg (s, TypeInfer.logicT))) thy;
   438 
   439 
   440 (**** Interface ****)
   441 
   (*Parse a code template:
       "_"        argument
       "?"        ignored argument
       "/"        break, as wide as the number of blanks that follow
       {* ... *}  quoted material, converted by rd
     anything else is literal text.  A single quote escapes the character after
     it.  Input that cannot be consumed completely is reported as malformed.*)
   fun parse_mixfix rd s =
     let
       val any = Scan.one Symbol.not_eof;
       val escaped = $$ "'" |-- any;
       val quote_open = $$ "{" |-- $$ "*";
       val quote_close = $$ "*" -- $$ "}";

       val arg = $$ "_" >> K Arg;
       val skip = $$ "?" >> K Ignore;
       val break = $$ "/" |-- Scan.repeat ($$ " ") >> (Pretty o Pretty.brk o length);
       val quoted =
         quote_open |-- Scan.repeat1 (escaped || Scan.unless quote_close any) --|
           $$ "*" --| $$ "}" >> (Quote o rd o implode);
       val text =
         Scan.repeat1
           (escaped || Scan.unless ($$ "_" || $$ "?" || $$ "/" || quote_open) any) >>
           (Pretty o Pretty.str o implode);

       val template = Scan.repeat (arg || skip || break || quoted || text);
     in
       (case Scan.finite Symbol.stopper template (Symbol.explode s) of
          (p, []) => p
        | _ => error ("Malformed annotation: " ^ quote s))
     end;
   458 
   459 structure P = OuterParse and K = OuterSyntax.Keyword;
       (*Abbreviations for the command definitions that follow.  Structure K does
         not clash with the combinator K used in this file: structures and values
         have separate name spaces.*)
   460 
   (*Command "types_code": one or more entries of the form  name ( "template" );
     quoted parts of a template are read as types in the current theory.*)
   val assoc_typeP =
     let
       val entry = P.xname --| P.$$$ "(" -- P.string --| P.$$$ ")";

       fun declare xs thy =
         let
           val read_typ = typ_of o read_ctyp (sign_of thy);
           fun prep (name, mfx) = (name, parse_mixfix read_typ mfx);
         in assoc_types (map prep xs) thy end;
     in
       OuterSyntax.command "types_code"
         "associate types with target language types" K.thy_decl
         (Scan.repeat1 entry >> (fn xs => Toplevel.theory (declare xs)))
     end;
   468 
   (*Command "consts_code": one or more entries of the form
       name [:: type] ( "template" );
     quoted parts of a template are read as terms in the current theory.*)
   val assoc_constP =
     let
       val constraint = Scan.option (P.$$$ "::" |-- P.typ);
       val entry = P.xname -- constraint --| P.$$$ "(" -- P.string --| P.$$$ ")";

       fun declare xs thy =
         let
           val read_term =
             term_of o read_cterm (sign_of thy) o rpair TypeInfer.logicT;
           fun prep ((name, optype), mfx) =
             (name, optype, parse_mixfix read_term mfx);
         in assoc_consts (map prep xs) thy end;
     in
       OuterSyntax.command "consts_code"
         "associate constants with target language code" K.thy_decl
         (Scan.repeat1 entry >> (fn xs => Toplevel.theory (declare xs)))
     end;
   479 
   (*Command "generate_code": an optional file name in parentheses followed by
     one or more bindings  name = term.  With a file name the generated code is
     written to that file, otherwise it is passed to use_text.  The theory is
     returned unchanged.*)
   val generate_codeP =
     let
       val target = Scan.option (P.$$$ "(" |-- P.name --| P.$$$ ")");
       val binding = P.name --| P.$$$ "=" -- P.term;

       fun run (opt_fname, xs) thy =
         let
           (*the output function is selected before the code is generated*)
           val emit =
             (case opt_fname of
                Some fname => File.write (Path.unpack fname)
              | None => use_text Context.ml_output false);
           val code = generate_code thy xs;
         in emit code; thy end;
     in
       OuterSyntax.command "generate_code" "generates code for terms" K.thy_decl
         (target -- Scan.repeat1 binding >> (fn spec => Toplevel.theory (run spec)))
     end;
   489 
   490 val parsers = [assoc_typeP, assoc_constP, generate_codeP];
       (*the three outer-syntax commands defined above: "types_code", "consts_code"
         and "generate_code"*)
   491 
   (*Theory setup: initialise the theory data, register the default term and type
     code generators, associate type "fun" with the template "(_ ->/ _)", and
     declare the "code" attribute.*)
   val setup =
     let
       val fun_type_syntax = ("fun", parse_mixfix (K dummyT) "(_ ->/ _)");
       val code_attribute =
         ("code",
          (code_attr, K Attrib.undef_local_attribute),
          "declare theorems for code generation");
     in
       [CodegenData.init,
        add_codegen "default" default_codegen,
        add_tycodegen "default" default_tycodegen,
        assoc_types [fun_type_syntax],
        Attrib.add_attributes [code_attribute]]
     end;
   500 
   501 end;
   502 
   503 OuterSyntax.add_parsers Codegen.parsers;
       (*executed when this file is loaded: passes the commands collected in
         Codegen.parsers to OuterSyntax.add_parsers*)