src/HOL/Tools/datatype_codegen.ML
author wenzelm
Wed Apr 13 18:34:22 2005 +0200 (2005-04-13)
changeset 15703 727ef1b8b3ee
parent 15574 b1d1b5bfc464
child 16645 a152d6b21c31
permissions -rw-r--r--
*** empty log message ***
     1 (*  Title:      HOL/Tools/datatype_codegen.ML
     2     ID:         $Id$
     3     Author:     Stefan Berghofer, TU Muenchen
     4 
     5 Code generator for inductive datatypes.
     6 *)
     7 
     8 signature DATATYPE_CODEGEN =
     9 sig
    10   val setup: (theory -> theory) list
    11 end;
    12 
    13 structure DatatypeCodegen : DATATYPE_CODEGEN =
    14 struct
    15 
    16 open Codegen;
    17 
    18 fun mk_tuple [p] = p
    19   | mk_tuple ps = Pretty.block (Pretty.str "(" ::
    20       List.concat (separate [Pretty.str ",", Pretty.brk 1] (map single ps)) @
    21         [Pretty.str ")"]);
    22 
    23 (**** datatype definition ****)
    24 
    25 (* find shortest path to constructor with no recursive arguments *)
    26 
    27 fun find_nonempty (descr: DatatypeAux.descr) is i =
    28   let
    29     val (_, _, constrs) = valOf (assoc (descr, i));
    30     fun arg_nonempty (_, DatatypeAux.DtRec i) = if i mem is then NONE
    31           else Option.map (curry op + 1 o snd) (find_nonempty descr (i::is) i)
    32       | arg_nonempty _ = SOME 0;
    33     fun max xs = Library.foldl
    34       (fn (NONE, _) => NONE
    35         | (SOME i, SOME j) => SOME (Int.max (i, j))
    36         | (_, NONE) => NONE) (SOME 0, xs);
    37     val xs = sort (int_ord o pairself snd)
    38       (List.mapPartial (fn (s, dts) => Option.map (pair s)
    39         (max (map (arg_nonempty o DatatypeAux.strip_dtyp) dts))) constrs)
    40   in case xs of [] => NONE | x :: _ => SOME x end;
    41 
    42 fun add_dt_defs thy dep gr (descr: DatatypeAux.descr) =
    43   let
    44     val sg = sign_of thy;
    45     val tab = DatatypePackage.get_datatypes thy;
    46 
    47     val descr' = List.filter (can (map DatatypeAux.dest_DtTFree o #2 o snd)) descr;
    48     val rtnames = map (#1 o snd) (List.filter (fn (_, (_, _, cs)) =>
    49       exists (exists DatatypeAux.is_rec_type o snd) cs) descr');
    50 
    51     val (_, (_, _, (cname, _) :: _)) :: _ = descr';
    52     val dname = mk_const_id sg cname;
    53 
    54     fun mk_dtdef gr prfx [] = (gr, [])
    55       | mk_dtdef gr prfx ((_, (tname, dts, cs))::xs) =
    56           let
    57             val tvs = map DatatypeAux.dest_DtTFree dts;
    58             val sorts = map (rpair []) tvs;
    59             val cs' = map (apsnd (map (DatatypeAux.typ_of_dtyp descr sorts))) cs;
    60             val (gr', ps) = foldl_map (fn (gr, (cname, cargs)) =>
    61               apsnd (pair cname) (foldl_map
    62                 (invoke_tycodegen thy dname false) (gr, cargs))) (gr, cs');
    63             val (gr'', rest) = mk_dtdef gr' "and " xs
    64           in
    65             (gr'',
    66              Pretty.block (Pretty.str prfx ::
    67                (if null tvs then [] else
    68                   [mk_tuple (map Pretty.str tvs), Pretty.str " "]) @
    69                [Pretty.str (mk_type_id sg tname ^ " ="), Pretty.brk 1] @
    70                List.concat (separate [Pretty.brk 1, Pretty.str "| "]
    71                  (map (fn (cname, ps') => [Pretty.block
    72                    (Pretty.str (mk_const_id sg cname) ::
    73                     (if null ps' then [] else
    74                      List.concat ([Pretty.str " of", Pretty.brk 1] ::
    75                        separate [Pretty.str " *", Pretty.brk 1]
    76                          (map single ps'))))]) ps))) :: rest)
    77           end;
    78 
    79     fun mk_term_of_def prfx [] = []
    80       | mk_term_of_def prfx ((_, (tname, dts, cs)) :: xs) =
    81           let
    82             val tvs = map DatatypeAux.dest_DtTFree dts;
    83             val sorts = map (rpair []) tvs;
    84             val cs' = map (apsnd (map (DatatypeAux.typ_of_dtyp descr sorts))) cs;
    85             val dts' = map (DatatypeAux.typ_of_dtyp descr sorts) dts;
    86             val T = Type (tname, dts');
    87             val rest = mk_term_of_def "and " xs;
    88             val (_, eqs) = foldl_map (fn (prfx, (cname, Ts)) =>
    89               let val args = map (fn i =>
    90                 Pretty.str ("x" ^ string_of_int i)) (1 upto length Ts)
    91               in ("  | ", Pretty.blk (4,
    92                 [Pretty.str prfx, mk_term_of sg false T, Pretty.brk 1,
    93                  if null Ts then Pretty.str (mk_const_id sg cname)
    94                  else parens (Pretty.block [Pretty.str (mk_const_id sg cname),
    95                     Pretty.brk 1, mk_tuple args]),
    96                  Pretty.str " =", Pretty.brk 1] @
    97                  List.concat (separate [Pretty.str " $", Pretty.brk 1]
    98                    ([Pretty.str ("Const (\"" ^ cname ^ "\","), Pretty.brk 1,
    99                      mk_type false (Ts ---> T), Pretty.str ")"] ::
   100                     map (fn (x, U) => [Pretty.block [mk_term_of sg false U,
   101                       Pretty.brk 1, x]]) (args ~~ Ts)))))
   102               end) (prfx, cs')
   103           in eqs @ rest end;
   104 
   105     fun mk_gen_of_def prfx [] = []
   106       | mk_gen_of_def prfx ((i, (tname, dts, cs)) :: xs) =
   107           let
   108             val tvs = map DatatypeAux.dest_DtTFree dts;
   109             val sorts = map (rpair []) tvs;
   110             val (cs1, cs2) =
   111               List.partition (exists DatatypeAux.is_rec_type o snd) cs;
   112             val SOME (cname, _) = find_nonempty descr [i] i;
   113 
   114             fun mk_delay p = Pretty.block
   115               [Pretty.str "fn () =>", Pretty.brk 1, p];
   116 
   117             fun mk_constr s b (cname, dts) =
   118               let
   119                 val gs = map (fn dt => mk_app false (mk_gen sg false rtnames s
   120                     (DatatypeAux.typ_of_dtyp descr sorts dt))
   121                   [Pretty.str (if b andalso DatatypeAux.is_rec_type dt then "0"
   122                      else "j")]) dts;
   123                 val id = mk_const_id sg cname
   124               in case gs of
   125                   _ :: _ :: _ => Pretty.block
   126                     [Pretty.str id, Pretty.brk 1, mk_tuple gs]
   127                 | _ => mk_app false (Pretty.str id) (map parens gs)
   128               end;
   129 
   130             fun mk_choice [c] = mk_constr "(i-1)" false c
   131               | mk_choice cs = Pretty.block [Pretty.str "one_of",
   132                   Pretty.brk 1, Pretty.blk (1, Pretty.str "[" ::
   133                   List.concat (separate [Pretty.str ",", Pretty.fbrk]
   134                     (map (single o mk_delay o mk_constr "(i-1)" false) cs)) @
   135                   [Pretty.str "]"]), Pretty.brk 1, Pretty.str "()"];
   136 
   137             val gs = map (Pretty.str o suffix "G" o strip_tname) tvs;
   138             val gen_name = "gen_" ^ mk_type_id sg tname
   139 
   140           in
   141             Pretty.blk (4, separate (Pretty.brk 1) 
   142                 (Pretty.str (prfx ^ gen_name ^
   143                    (if null cs1 then "" else "'")) :: gs @
   144                  (if null cs1 then [] else [Pretty.str "i"]) @
   145                  [Pretty.str "j"]) @
   146               [Pretty.str " =", Pretty.brk 1] @
   147               (if not (null cs1) andalso not (null cs2)
   148                then [Pretty.str "frequency", Pretty.brk 1,
   149                  Pretty.blk (1, [Pretty.str "[",
   150                    mk_tuple [Pretty.str "i", mk_delay (mk_choice cs1)],
   151                    Pretty.str ",", Pretty.fbrk,
   152                    mk_tuple [Pretty.str "1", mk_delay (mk_choice cs2)],
   153                    Pretty.str "]"]), Pretty.brk 1, Pretty.str "()"]
   154                else if null cs2 then
   155                  [Pretty.block [Pretty.str "(case", Pretty.brk 1,
   156                    Pretty.str "i", Pretty.brk 1, Pretty.str "of",
   157                    Pretty.brk 1, Pretty.str "0 =>", Pretty.brk 1,
   158                    mk_constr "0" true (cname, valOf (assoc (cs, cname))),
   159                    Pretty.brk 1, Pretty.str "| _ =>", Pretty.brk 1,
   160                    mk_choice cs1, Pretty.str ")"]]
   161                else [mk_choice cs2])) ::
   162             (if null cs1 then []
   163              else [Pretty.blk (4, separate (Pretty.brk 1) 
   164                  (Pretty.str ("and " ^ gen_name) :: gs @ [Pretty.str "i"]) @
   165                [Pretty.str " =", Pretty.brk 1] @
   166                separate (Pretty.brk 1) (Pretty.str (gen_name ^ "'") :: gs @
   167                  [Pretty.str "i", Pretty.str "i"]))]) @
   168             mk_gen_of_def "and " xs
   169           end
   170 
   171   in
   172     ((Graph.add_edge_acyclic (dname, dep) gr
   173         handle Graph.CYCLES _ => gr) handle Graph.UNDEF _ =>
   174          let
   175            val gr1 = Graph.add_edge (dname, dep)
   176              (Graph.new_node (dname, (NONE, "")) gr);
   177            val (gr2, dtdef) = mk_dtdef gr1 "datatype " descr';
   178          in
   179            Graph.map_node dname (K (NONE,
   180              Pretty.string_of (Pretty.blk (0, separate Pretty.fbrk dtdef @
   181                [Pretty.str ";"])) ^ "\n\n" ^
   182              (if "term_of" mem !mode then
   183                 Pretty.string_of (Pretty.blk (0, separate Pretty.fbrk
   184                   (mk_term_of_def "fun " descr') @ [Pretty.str ";"])) ^ "\n\n"
   185               else "") ^
   186              (if "test" mem !mode then
   187                 Pretty.string_of (Pretty.blk (0, separate Pretty.fbrk
   188                   (mk_gen_of_def "fun " descr') @ [Pretty.str ";"])) ^ "\n\n"
   189               else ""))) gr2
   190          end)
   191   end;
   192 
   193 
   194 (**** case expressions ****)
   195 
   196 fun pretty_case thy gr dep brack constrs (c as Const (_, T)) ts =
   197   let val i = length constrs
   198   in if length ts <= i then
   199        invoke_codegen thy dep brack (gr, eta_expand c ts (i+1))
   200     else
   201       let
   202         val ts1 = Library.take (i, ts);
   203         val t :: ts2 = Library.drop (i, ts);
   204         val names = foldr add_term_names
   205           (map (fst o fst o dest_Var) (foldr add_term_vars [] ts1)) ts1;
   206         val (Ts, dT) = split_last (Library.take (i+1, fst (strip_type T)));
   207 
   208         fun pcase gr [] [] [] = ([], gr)
   209           | pcase gr ((cname, cargs)::cs) (t::ts) (U::Us) =
   210               let
   211                 val j = length cargs;
   212                 val xs = variantlist (replicate j "x", names);
   213                 val Us' = Library.take (j, fst (strip_type U));
   214                 val frees = map Free (xs ~~ Us');
   215                 val (gr0, cp) = invoke_codegen thy dep false
   216                   (gr, list_comb (Const (cname, Us' ---> dT), frees));
   217                 val t' = Envir.beta_norm (list_comb (t, frees));
   218                 val (gr1, p) = invoke_codegen thy dep false (gr0, t');
   219                 val (ps, gr2) = pcase gr1 cs ts Us;
   220               in
   221                 ([Pretty.block [cp, Pretty.str " =>", Pretty.brk 1, p]] :: ps, gr2)
   222               end;
   223 
   224         val (ps1, gr1) = pcase gr constrs ts1 Ts;
   225         val ps = List.concat (separate [Pretty.brk 1, Pretty.str "| "] ps1);
   226         val (gr2, p) = invoke_codegen thy dep false (gr1, t);
   227         val (gr3, ps2) = foldl_map (invoke_codegen thy dep true) (gr2, ts2)
   228       in (gr3, (if not (null ts2) andalso brack then parens else I)
   229         (Pretty.block (separate (Pretty.brk 1)
   230           (Pretty.block ([Pretty.str "(case ", p, Pretty.str " of",
   231              Pretty.brk 1] @ ps @ [Pretty.str ")"]) :: ps2))))
   232       end
   233   end;
   234 
   235 
   236 (**** constructors ****)
   237 
   238 fun pretty_constr thy gr dep brack args (c as Const (s, _)) ts =
   239   let val i = length args
   240   in if i > 1 andalso length ts < i then
   241       invoke_codegen thy dep brack (gr, eta_expand c ts i)
   242      else
   243        let
   244          val id = mk_const_id (sign_of thy) s;
   245          val (gr', ps) = foldl_map (invoke_codegen thy dep (i = 1)) (gr, ts);
   246        in (case args of
   247           _ :: _ :: _ => (gr', (if brack then parens else I)
   248             (Pretty.block [Pretty.str id, Pretty.brk 1, mk_tuple ps]))
   249         | _ => (gr', mk_app brack (Pretty.str id) ps))
   250        end
   251   end;
   252 
   253 
   254 (**** code generators for terms and types ****)
   255 
   256 fun datatype_codegen thy gr dep brack t = (case strip_comb t of
   257    (c as Const (s, T), ts) =>
   258        (case find_first (fn (_, {index, descr, case_name, ...}) =>
   259          s = case_name orelse
   260            isSome (assoc (#3 (valOf (assoc (descr, index))), s)))
   261              (Symtab.dest (DatatypePackage.get_datatypes thy)) of
   262           NONE => NONE
   263         | SOME (tname, {index, descr, ...}) =>
   264            if isSome (get_assoc_code thy s T) then NONE else
   265            let val SOME (_, _, constrs) = assoc (descr, index)
   266            in (case (assoc (constrs, s), strip_type T) of
   267                (NONE, _) => SOME (pretty_case thy gr dep brack
   268                  (#3 (valOf (assoc (descr, index)))) c ts)
   269              | (SOME args, (_, Type _)) => SOME (pretty_constr thy
   270                  (fst (invoke_tycodegen thy dep false (gr, snd (strip_type T))))
   271                  dep brack args c ts)
   272              | _ => NONE)
   273            end)
   274  |  _ => NONE);
   275 
   276 fun datatype_tycodegen thy gr dep brack (Type (s, Ts)) =
   277       (case Symtab.lookup (DatatypePackage.get_datatypes thy, s) of
   278          NONE => NONE
   279        | SOME {descr, ...} =>
   280            if isSome (get_assoc_type thy s) then NONE else
   281            let val (gr', ps) = foldl_map
   282              (invoke_tycodegen thy dep false) (gr, Ts)
   283            in SOME (add_dt_defs thy dep gr' descr,
   284              Pretty.block ((if null Ts then [] else
   285                [mk_tuple ps, Pretty.str " "]) @
   286                [Pretty.str (mk_type_id (sign_of thy) s)]))
   287            end)
   288   | datatype_tycodegen _ _ _ _ _ = NONE;
   289 
   290 
   291 val setup =
   292   [add_codegen "datatype" datatype_codegen,
   293    add_tycodegen "datatype" datatype_tycodegen];
   294 
   295 end;