src/HOL/Tools/datatype_codegen.ML
changeset 31610 2b47e8e37c11
parent 31609 8d353e3214d0
child 31611 a577f77af93f
equal deleted inserted replaced
31609:8d353e3214d0 31610:2b47e8e37c11
     1 (*  Title:      HOL/Tools/datatype_codegen.ML
       
     2     Author:     Stefan Berghofer and Florian Haftmann, TU Muenchen
       
     3 
       
     4 Code generator facilities for inductive datatypes.
       
     5 *)
       
     6 
       
     (* Public interface of the datatype code generator setup. *)
     signature DATATYPE_CODEGEN =
     sig
       (* registers the "datatype" term and type code generators and the
          datatype interpretation that adds code data for new datatypes *)
       val setup: theory -> theory
       (* case certificate for the datatype named by the string, built from
          its case_rewrites *)
       val mk_case_cert: theory -> string -> thm
       (* code equations for eq_class.eq on the datatype named by the string;
          the flag is false only for the trailing reflexivity equation *)
       val mk_eq_eqns: theory -> string -> (thm * bool) list
       (* for datatype index i of the description: a constructor reachable
          with minimal nesting of recursive arguments, with that depth *)
       val find_shortest_path: DatatypeAux.descr -> int -> (string * int) option
     end;
       
    14 
       
    15 structure DatatypeCodegen : DATATYPE_CODEGEN =
       
    16 struct
       
    17 
       
    18 (** SML code generator **)
       
    19 
       
    20 open Codegen;
       
    21 
       
    22 (**** datatype definition ****)
       
    23 
       
    24 (* find shortest path to constructor with no recursive arguments *)
       
    25 
       
    (* find_nonempty descr seen idx: among the constructors of datatype idx,
       pick one of minimal depth.  An argument of non-recursive type has
       depth 0; a recursive argument DtRec k has depth 1 + the depth found
       for datatype k.  Indices in seen (the path so far) are not entered
       again, and a constructor that would need one of them is discarded.
       Returns SOME (constructor name, depth), or NONE if every constructor
       is discarded.  Raises Option if a looked-up index is not in descr. *)
    fun find_nonempty (descr: DatatypeAux.descr) seen idx =
      let
        val (_, _, constrs) = the (AList.lookup (op =) descr idx);

        fun arg_depth (_, DatatypeAux.DtRec k) =
              if member (op =) seen k then NONE
              else
                (case find_nonempty descr (k :: seen) k of
                  NONE => NONE
                | SOME (_, d) => SOME (d + 1))
          | arg_depth _ = SOME 0;

        (* all argument depths are computed; the constructor is usable only
           if every one of them is defined, and its depth is their maximum *)
        fun constr_depth (cname, dts) =
          let val depths = map (arg_depth o DatatypeAux.strip_dtyp) dts
          in
            if exists is_none depths then NONE
            else SOME (cname, fold (curry Int.max o the) depths 0)
          end;

        val ranked = sort (int_ord o pairself snd) (map_filter constr_depth constrs);
      in
        if null ranked then NONE else SOME (hd ranked)
      end;
       
    41 
       
    (* Starting at datatype index root (which is also the initial path),
       find a constructor of minimal recursive-argument depth. *)
    fun find_shortest_path descr root = find_nonempty descr [root] root;
       
    43 
       
    (* add_dt_defs thy defs dep module descr sorts gr:
       make sure the SML code for the datatype group descr is a node of the
       code-generation graph gr with an edge to dep, and return
       (module', updated graph).
       - Only entries of descr whose type arguments are all type variables
         (DtTFree) are emitted (descr'); the first of them names the node
         "<tname> (type)".  NOTE(review): raises Bind if descr' is empty.
       - module' is if_library (thyname_of_type thy tname) module.
       - sorts is passed to DatatypeAux.typ_of_dtyp.
       - If add_edge_acyclic succeeds, the graph is returned with the new
         edge (unchanged if the edge would create a cycle).  If it raises
         Graph.UNDEF -- presumably because node_id is not in gr yet -- the
         node is created and filled with the "datatype ... and ..."
         declaration, followed by term_of functions when "term_of" is in
         !mode and random generators gen_<type> when "test" is in !mode. *)
    44 fun add_dt_defs thy defs dep module (descr: DatatypeAux.descr) sorts gr =

    45   let
           (* descr': entries whose type arguments all destruct as DtTFree *)
    46     val descr' = List.filter (can (map DatatypeAux.dest_DtTFree o #2 o snd)) descr;

           (* rtnames: names of the types in descr' that have a constructor
              with a recursive argument; passed to mk_gen below *)
    47     val rtnames = map (#1 o snd) (List.filter (fn (_, (_, _, cs)) =>

    48       exists (exists DatatypeAux.is_rec_type o snd) cs) descr');

    49 

           (* the first type of descr' names the graph node *)
    50     val (_, (tname, _, _)) :: _ = descr';

    51     val node_id = tname ^ " (type)";

    52     val module' = if_library (thyname_of_type thy tname) module;

    53 

           (* mk_dtdef prfx entries gr: one clause per entry, the first with
              prfx ("datatype ") and the rest with "and ", of the form
              prfx <tvars> type_id = C1 of T1 * ... | C2 ...
              Constructor argument types are produced by invoke_tycodegen
              (with node_id as dependent), threading gr. *)
    54     fun mk_dtdef prfx [] gr = ([], gr)

    55       | mk_dtdef prfx ((_, (tname, dts, cs))::xs) gr =

    56           let

    57             val tvs = map DatatypeAux.dest_DtTFree dts;

    58             val cs' = map (apsnd (map (DatatypeAux.typ_of_dtyp descr sorts))) cs;

    59             val ((_, type_id), gr') = mk_type_id module' tname gr;

                   (* per constructor: printed argument types paired with the
                      constructor id *)
    60             val (ps, gr'') = gr' |>

    61               fold_map (fn (cname, cargs) =>

    62                 fold_map (invoke_tycodegen thy defs node_id module' false)

    63                   cargs ##>>

    64                 mk_const_id module' cname) cs';

    65             val (rest, gr''') = mk_dtdef "and " xs gr''

    66           in

    67             (Pretty.block (str prfx ::

    68                (if null tvs then [] else

    69                   [mk_tuple (map str tvs), str " "]) @

    70                [str (type_id ^ " ="), Pretty.brk 1] @

    71                List.concat (separate [Pretty.brk 1, str "| "]

    72                  (map (fn (ps', (_, cname)) => [Pretty.block

    73                    (str cname ::

    74                     (if null ps' then [] else

    75                      List.concat ([str " of", Pretty.brk 1] ::

    76                        separate [str " *", Pretty.brk 1]

    77                          (map single ps'))))]) ps))) :: rest, gr''')

    78           end;

    79 

           (* mk_constr_term cname Ts T ps: prints the ML expression
              Const ("cname", <Ts ---> T>) $ p1 $ p2 ... *)
    80     fun mk_constr_term cname Ts T ps =

    81       List.concat (separate [str " $", Pretty.brk 1]

    82         ([str ("Const (\"" ^ cname ^ "\","), Pretty.brk 1,

    83           mk_type false (Ts ---> T), str ")"] :: ps));

    84 

           (* mk_term_of_def gr prfx entries: term_of clauses for each
              datatype, one per constructor.  A clause maps the constructor
              pattern (C, or C applied to the tuple x1, ..., xn) to
              Const ("C", ...) $ <term_of for U1> x1 $ ...
              The graph gr is only read here, not threaded. *)
    85     fun mk_term_of_def gr prfx [] = []

    86       | mk_term_of_def gr prfx ((_, (tname, dts, cs)) :: xs) =

    87           let

    88             val cs' = map (apsnd (map (DatatypeAux.typ_of_dtyp descr sorts))) cs;

    89             val dts' = map (DatatypeAux.typ_of_dtyp descr sorts) dts;

    90             val T = Type (tname, dts');

    91             val rest = mk_term_of_def gr "and " xs;

                   (* the fold threads the clause prefix: prfx for the first
                      constructor, "  | " for all later ones *)
    92             val (eqs, _) = fold_map (fn (cname, Ts) => fn prfx =>

    93               let val args = map (fn i =>

    94                 str ("x" ^ string_of_int i)) (1 upto length Ts)

    95               in (Pretty.blk (4,

    96                 [str prfx, mk_term_of gr module' false T, Pretty.brk 1,

    97                  if null Ts then str (snd (get_const_id gr cname))

    98                  else parens (Pretty.block

    99                    [str (snd (get_const_id gr cname)),

   100                     Pretty.brk 1, mk_tuple args]),

   101                  str " =", Pretty.brk 1] @

   102                  mk_constr_term cname Ts T

   103                    (map2 (fn x => fn U => [Pretty.block [mk_term_of gr module' false U,

   104                       Pretty.brk 1, x]]) args Ts)), "  | ")

   105               end) cs' prfx

   106           in eqs @ rest end;

   107 

           (* mk_gen_of_def gr prfx entries: random generator definitions
              gen_<type> for each datatype (index i in descr).  The emitted
              code takes an integer i that is decremented ("(i-1)") for
              recursive constructors and, when i = 0, falls back to the
              shortest-path constructor; j is passed on to argument
              generators -- presumably a size bound, TODO confirm.
              The graph gr is only read here. *)
   108     fun mk_gen_of_def gr prfx [] = []

   109       | mk_gen_of_def gr prfx ((i, (tname, dts, cs)) :: xs) =

   110           let

   111             val tvs = map DatatypeAux.dest_DtTFree dts;

   112             val Us = map (DatatypeAux.typ_of_dtyp descr sorts) dts;

   113             val T = Type (tname, Us);

                   (* cs1: constructors with a recursive argument; cs2: the rest *)
   114             val (cs1, cs2) =

   115               List.partition (exists DatatypeAux.is_rec_type o snd) cs;

                   (* constructor on the shortest path to a non-recursive one,
                      used for the i = 0 case below.
                      NOTE(review): Bind if find_shortest_path returns NONE *)
   116             val SOME (cname, _) = find_shortest_path descr i;

   117 

                   (* mk_delay p prints  fn () => p *)
   118             fun mk_delay p = Pretty.block

   119               [str "fn () =>", Pretty.brk 1, p];

   120 

                   (* mk_force p prints  p () *)
   121             fun mk_force p = Pretty.block [p, Pretty.brk 1, str "()"];

   122 

                   (* mk_constr s b (cname, dts): a let expression binding a
                      pair (x_k, t_k) for every argument from the generator
                      produced by mk_gen (given s and rtnames), applied to j,
                      or to 0 for recursive arguments when b is true.  The
                      result is the pair of the constructor applied to the
                      x_k (tupled if there are at least two) and a delayed
                      term built from the forced t_k. *)
   123             fun mk_constr s b (cname, dts) =

   124               let

   125                 val gs = map (fn dt => mk_app false (mk_gen gr module' false rtnames s

   126                     (DatatypeAux.typ_of_dtyp descr sorts dt))

   127                   [str (if b andalso DatatypeAux.is_rec_type dt then "0"

   128                      else "j")]) dts;

   129                 val Ts = map (DatatypeAux.typ_of_dtyp descr sorts) dts;

   130                 val xs = map str

   131                   (DatatypeProp.indexify_names (replicate (length dts) "x"));

   132                 val ts = map str

   133                   (DatatypeProp.indexify_names (replicate (length dts) "t"));

   134                 val (_, id) = get_const_id gr cname

   135               in

   136                 mk_let

   137                   (map2 (fn p => fn q => mk_tuple [p, q]) xs ts ~~ gs)

   138                   (mk_tuple

   139                     [case xs of

   140                        _ :: _ :: _ => Pretty.block

   141                          [str id, Pretty.brk 1, mk_tuple xs]

   142                      | _ => mk_app false (str id) xs,

   143                      mk_delay (Pretty.block (mk_constr_term cname Ts T

   144                        (map (single o mk_force) ts)))])

   145               end;

   146 

                   (* mk_choice: build one of the given constructors with
                      "(i-1)" as the mk_constr argument; with several
                      constructors the choice is left to one_of applied to
                      the list of delayed alternatives *)
   147             fun mk_choice [c] = mk_constr "(i-1)" false c

   148               | mk_choice cs = Pretty.block [str "one_of",

   149                   Pretty.brk 1, Pretty.blk (1, str "[" ::

   150                   List.concat (separate [str ",", Pretty.fbrk]

   151                     (map (single o mk_delay o mk_constr "(i-1)" false) cs)) @

   152                   [str "]"]), Pretty.brk 1, str "()"];

   153 

                   (* two extra parameters per type variable: <tv>G and <tv>T *)
   154             val gs = maps (fn s =>

   155               let val s' = strip_tname s

   156               in [str (s' ^ "G"), str (s' ^ "T")] end) tvs;

   157             val gen_name = "gen_" ^ snd (get_type_id gr tname)

   158 

   159           in

                 (* clause head: prfx gen_name[']  <gs> [i] j = ...; the primed
                    name and the parameter i appear only if cs1 is non-empty *)
   160             Pretty.blk (4, separate (Pretty.brk 1) 

   161                 (str (prfx ^ gen_name ^

   162                    (if null cs1 then "" else "'")) :: gs @

   163                  (if null cs1 then [] else [str "i"]) @

   164                  [str "j"]) @

   165               [str " =", Pretty.brk 1] @

                   (* body: with both kinds of constructors, frequency weighs
                      recursive ones by i and non-recursive ones by 1; with
                      only recursive constructors, i = 0 builds the
                      shortest-path constructor cname (mk_constr "0" true) and
                      other i choose among cs1; with only non-recursive
                      constructors, choose among cs2 *)
   166               (if not (null cs1) andalso not (null cs2)

   167                then [str "frequency", Pretty.brk 1,

   168                  Pretty.blk (1, [str "[",

   169                    mk_tuple [str "i", mk_delay (mk_choice cs1)],

   170                    str ",", Pretty.fbrk,

   171                    mk_tuple [str "1", mk_delay (mk_choice cs2)],

   172                    str "]"]), Pretty.brk 1, str "()"]

   173                else if null cs2 then

   174                  [Pretty.block [str "(case", Pretty.brk 1,

   175                    str "i", Pretty.brk 1, str "of",

   176                    Pretty.brk 1, str "0 =>", Pretty.brk 1,

   177                    mk_constr "0" true (cname, valOf (AList.lookup (op =) cs cname)),

   178                    Pretty.brk 1, str "| _ =>", Pretty.brk 1,

   179                    mk_choice cs1, str ")"]]

   180                else [mk_choice cs2])) ::

                 (* with recursive constructors, also emit the wrapper
                    and gen_name <gs> i = gen_name' <gs> i i *)
   181             (if null cs1 then []

   182              else [Pretty.blk (4, separate (Pretty.brk 1) 

   183                  (str ("and " ^ gen_name) :: gs @ [str "i"]) @

   184                [str " =", Pretty.brk 1] @

   185                separate (Pretty.brk 1) (str (gen_name ^ "'") :: gs @

   186                  [str "i", str "i"]))]) @

   187             mk_gen_of_def gr "and " xs

   188           end

   189 

   190   in

       (* existing node: just add the edge (ignored if it would close a
          cycle); on Graph.UNDEF create and fill the node instead *)
   191     (module', (add_edge_acyclic (node_id, dep) gr

   192         handle Graph.CYCLES _ => gr) handle Graph.UNDEF _ =>

   193          let

              (* the node is inserted with empty code before generating its
                 contents, so the generators below can refer to it *)
   194            val gr1 = add_edge (node_id, dep)

   195              (new_node (node_id, (NONE, "", "")) gr);

   196            val (dtdef, gr2) = mk_dtdef "datatype " descr' gr1 ;

   197          in

              (* node contents: declaration, then optional term_of and
                 generator definitions, each terminated by ";" *)
   198            map_node node_id (K (NONE, module',

   199              string_of (Pretty.blk (0, separate Pretty.fbrk dtdef @

   200                [str ";"])) ^ "\n\n" ^

   201              (if "term_of" mem !mode then

   202                 string_of (Pretty.blk (0, separate Pretty.fbrk

   203                   (mk_term_of_def gr2 "fun " descr') @ [str ";"])) ^ "\n\n"

   204               else "") ^

   205              (if "test" mem !mode then

   206                 string_of (Pretty.blk (0, separate Pretty.fbrk

   207                   (mk_gen_of_def gr2 "fun " descr') @ [str ";"])) ^ "\n\n"

   208               else ""))) gr2

   209          end)

   210   end;
       
   211 
       
   212 
       
   213 (**** case expressions ****)
       
   214 
       
   (* pretty_case: print an application of a datatype case combinator c.
      constrs lists (name, argument dtyps) of the datatype's constructors.
      The first length constrs arguments in ts are the branches, the next
      one is the scrutinee, and any further arguments are applied to the
      resulting case expression.  If ts does not reach the scrutinee, the
      term is eta-expanded and generated again.
      Returns the pretty-printed code and the updated graph. *)
   fun pretty_case thy defs dep module brack constrs (c as Const (_, T)) ts gr =
     let
       val arity = length constrs;
       fun gen b u g = invoke_codegen thy defs dep module b u g;
     in
       if length ts > arity then
         let
           val branches = Library.take (arity, ts);
           val scrut :: extra = Library.drop (arity, ts);
           (* names already used in the branches; bound variables avoid them *)
           val used = List.foldr OldTerm.add_term_names
             (map (fst o fst o dest_Var) (List.foldr OldTerm.add_term_vars [] branches))
             branches;
           val (branchTs, scrutT) = split_last (Library.take (arity + 1, fst (strip_type T)));

           (* one "C x1 ... xn => body" clause per constructor, threading the graph:
              pattern first, then body, then the remaining clauses *)
           fun clauses [] [] [] g = ([], g)
             | clauses ((cname, cargs) :: more_cs) (br :: more_brs) (brT :: more_Ts) g =
                 let
                   val n = length cargs;
                   val names = Name.variant_list used (replicate n "x");
                   val argTs = Library.take (n, fst (strip_type brT));
                   val vars = map Free (names ~~ argTs);
                   val (pat, g1) = gen false (list_comb (Const (cname, argTs ---> scrutT), vars)) g;
                   val (rhs, g2) = gen false (Envir.beta_norm (list_comb (br, vars))) g1;
                   val (rest, g3) = clauses more_cs more_brs more_Ts g2;
                 in ([Pretty.block [pat, str " =>", Pretty.brk 1, rhs]] :: rest, g3) end;

           val (cls, gr1) = clauses constrs branches branchTs gr;
           val (scrut_p, gr2) = gen false scrut gr1;
           val (extra_ps, gr3) = fold_map (gen true) extra gr2;
           val case_p = Pretty.block (str "(case " :: scrut_p :: str " of" :: Pretty.brk 1 ::
             List.concat (separate [Pretty.brk 1, str "| "] cls) @ [str ")"]);
           val wrap = if brack andalso not (null extra) then parens else I;
         in (wrap (Pretty.block (separate (Pretty.brk 1) (case_p :: extra_ps))), gr3) end
       else gen brack (eta_expand c ts (arity + 1)) gr
     end;
       
   253 
       
   254 
       
   255 (**** constructors ****)
       
   256 
       
   (* pretty_constr: print constructor s (term c), whose argument dtyps are
      args, applied to ts.  An under-applied constructor with several
      arguments is eta-expanded and generated again.  Constructors with at
      least two arguments are printed as "C (a1, ..., an)" (bracketed if
      brack); others go through mk_app.
      Returns the pretty-printed code and the updated graph. *)
   fun pretty_constr thy defs dep module brack args (c as Const (s, T)) ts gr =
     let
       val arity = length args;
     in
       if arity <= 1 orelse length ts >= arity then
         let
           val cid = mk_qual_id module (get_const_id gr s);
           (* the argument of a unary constructor is generated with brack = true *)
           val (arg_ps, gr') = fold_map (invoke_codegen thy defs dep module (arity = 1)) ts gr;
           val code =
             if arity >= 2 then
               (if brack then parens else I) (Pretty.block [str cid, Pretty.brk 1, mk_tuple arg_ps])
             else mk_app brack (str cid) arg_ps;
         in (code, gr') end
       else invoke_codegen thy defs dep module brack (eta_expand c ts arity) gr
     end;
       
   272 
       
   273 
       
   274 (**** code generators for terms and types ****)
       
   275 
       
   (* Term code generator for datatypes.  Handles applications of case
      combinators (pretty_case) and constructors (pretty_constr), unless
      code is associated with the constant via get_assoc_code; any other
      term yields NONE. *)
   fun datatype_codegen thy defs dep module brack t gr =
     let
       fun has_assoc_code s T = is_some (get_assoc_code thy (s, T));

       (* constructor s :: T whose result type U is a Type; the result type
          must be the datatype the constructor belongs to.  The argument
          type U is generated first to update the graph. *)
       fun constr_code c s T ts =
         (case (DatatypePackage.datatype_of_constr thy s, strip_type T) of
           (SOME {index, descr, ...}, (_, U as Type (tyname, _))) =>
             if has_assoc_code s T then NONE
             else
               let
                 val SOME (tyname', _, constrs) = AList.lookup (op =) descr index;
                 val SOME args = AList.lookup (op =) constrs s;
               in
                 if tyname = tyname' then
                   SOME (pretty_constr thy defs dep module brack args c ts
                     (snd (invoke_tycodegen thy defs dep module false U gr)))
                 else NONE
               end
         | _ => NONE);
     in
       (case strip_comb t of
         (c as Const (s, T), ts) =>
           (case DatatypePackage.datatype_of_case thy s of
             SOME {index, descr, ...} =>
               if has_assoc_code s T then NONE
               else SOME (pretty_case thy defs dep module brack
                 (#3 (the (AList.lookup (op =) descr index))) c ts gr)
           | NONE => constr_code c s T ts)
       | _ => NONE)
     end;
       
   296 
       
   (* Type code generator for datatypes: for a registered datatype s with
      no associated type code, generate its argument types, make sure its
      definition is emitted (add_dt_defs), and print "(args) tyid".
      Anything else yields NONE. *)
   fun datatype_tycodegen thy defs dep module brack (Type (s, Ts)) gr =
         (case DatatypePackage.get_datatype thy s of
           SOME {descr, sorts, ...} =>
             if is_none (get_assoc_type thy s) then
               let
                 val (arg_ps, gr1) = fold_map (invoke_tycodegen thy defs dep module false) Ts gr;
                 val (module', gr2) = add_dt_defs thy defs dep module descr sorts gr1;
                 val (tyid, gr3) = mk_type_id module' s gr2;
                 val prefix = if null Ts then [] else [mk_tuple arg_ps, str " "];
               in SOME (Pretty.block (prefix @ [str (mk_qual_id module tyid)]), gr3) end
             else NONE
         | NONE => NONE)
     | datatype_tycodegen _ _ _ _ _ _ _ = NONE;
       
   312 
       
   313 
       
   314 (** generic code generator **)
       
   315 
       
   316 (* case certificates *)
       
   317 
       
   (* mk_case_cert thy tyco: case certificate for datatype tyco.  The
      case_rewrites are normalised (unvarified, made meta equations with
      zero var indexes).  Then the case combinator applied to all but its
      last argument is folded into a fresh variable "case" via an assumed
      equation, which is discharged.  The free variables of the first
      equation are generalised and the result is unoverloaded and
      type-varified. *)
   fun mk_case_cert thy tyco =
     let
       val rewrites = #case_rewrites (DatatypePackage.the_datatype thy tyco);
       val unvarified = Thm.unvarify (Conjunction.intr_balanced rewrites);
       val meta_eqs =
         map Simpdata.mk_meta_eq (Conjunction.elim_balanced (length rewrites) unvarified);
       val eqns as first_eqn :: _ = map Drule.zero_var_indexes meta_eqs;
       val first_prop = Thm.prop_of first_eqn;
       val params = fold_aterms (fn Free (v, _) => insert (op =) v | _ => I) first_prop [];
       (* lhs of the first equation without its last argument *)
       val (head, args) = Term.strip_comb (fst (Logic.dest_equals first_prop));
       val case_app = list_comb (head, fst (split_last args));
       val case_var = Free (Name.variant params "case", Term.fastype_of case_app);
       val asm = Thm.cterm_of thy (Logic.mk_equals (case_var, case_app));
       val fold_case = Thm.symmetric (Thm.assume asm);
     in
       Conjunction.intr_balanced eqns
       |> MetaSimplifier.rewrite_rule [fold_case]
       |> Thm.implies_intr asm
       |> Thm.generalize ([], params) 0
       |> AxClass.unoverload thy
       |> Thm.varifyT
     end;
       
   348 
       
   349 
       
   350 (* equality *)
       
   351 
       
   (* mk_eq_eqns thy dtco: prove code equations for eq_class.eq on datatype
      dtco at its type variables.  In order, these are:
      eq C C = True for each nullary constructor C; the injectivity
      equations with their HOL equality replaced by eq; both orientations
      eq a b = False for each distinctness fact -- all tagged true -- and
      finally eq x x = True tagged false.  Goals are proved by simp_tac with
      eq, eq_True, the injectivity theorems and the distinctness simproc,
      then turned into equations by Simpdata.mk_eq. *)
   fun mk_eq_eqns thy dtco =
     let
       val (vs, cos) = DatatypePackage.the_datatype_spec thy dtco;
       val {descr, index, inject = inject_thms, ...} = DatatypePackage.the_datatype thy dtco;
       val T = Type (dtco, map TFree vs);
       val eq_c = Const (@{const_name eq_class.eq}, T --> T --> HOLogic.boolT);
       fun eq_is value (a, b) = HOLogic.mk_eq (eq_c $ a $ b, value);
       val is_true = eq_is HOLogic.true_const;
       val is_false = eq_is HOLogic.false_const;

       fun nullary (c, []) = SOME (HOLogic.mk_Trueprop (is_true (Const (c, T), Const (c, T))))
         | nullary _ = NONE;
       (* Trueprop ((C xs = C ys) = rhs)  ~>  Trueprop (eq (C xs) (C ys) = rhs) *)
       fun inject_eqn (tp $ (iff $ (_ $ a $ b) $ rhs)) = tp $ (iff $ (eq_c $ a $ b) $ rhs);
       (* Trueprop (~ (a = b))  ~>  eq a b = False and eq b a = False *)
       fun distinct_eqns (tp $ (_ $ (_ $ a $ b))) =
         [tp $ is_false (a, b), tp $ is_false (b, a)];

       val goals =
         map_filter nullary cos
         @ map inject_eqn (nth (DatatypeProp.make_injs [descr] vs) index)
         @ maps distinct_eqns (snd (nth (DatatypeProp.make_distincts [descr] vs) index));
       val x = Free ("x", T);
       val ss = Simplifier.context (ProofContext.init thy) (HOL_basic_ss
         addsimps (map Simpdata.mk_eq (@{thm eq} :: @{thm eq_True} :: inject_thms))
         addsimprocs [DatatypePackage.distinct_simproc]);
       fun prove goal =
         Simpdata.mk_eq (Goal.prove_global thy [] [] goal (K (ALLGOALS (simp_tac ss))));
       val tagged = map (fn goal => (prove goal, true)) goals;
     in tagged @ [(prove (HOLogic.mk_Trueprop (is_true (x, x))), false)] end;
       
   377 
       
   (* add_equality vs dtcos thy: instantiate class eq for the datatypes
      dtcos (type variables vs) by defining eq_class.eq as HOL "op =".
      The resulting definitional theorems are removed as code equations
      (Code.del_eqn), and the reversed equations from mk_eq_eqns are
      registered lazily instead (Code.add_eqnl). *)
   fun add_equality vs dtcos thy =
     let
       (* define eq x y = (x = y) on dtco; the theorem is exported into a
          context initialised from the theory of lthy *)
       fun define_eq dtco lthy =
         let
           val T = Type (dtco, map TFree vs);
           val x = Free ("x", T);
           val y = Free ("y", T);
           fun rel c = Const (c, T --> T --> HOLogic.boolT) $ x $ y;
           val spec = Syntax.check_term lthy (HOLogic.mk_Trueprop
             (HOLogic.mk_eq (rel @{const_name eq_class.eq}, rel @{const_name "op ="})));
           val ((_, (_, raw_thm)), lthy') =
             Specification.definition (NONE, (Attrib.empty_binding, spec)) lthy;
           val export = ProofContext.export lthy' (ProofContext.init (ProofContext.theory_of lthy));
         in (singleton export raw_thm, lthy') end;

       (* the instance proof discharges the class goals with the definitions *)
       fun prove_instance def_thms =
         Class.prove_instantiation_exit_result (map o Morphism.thm)
           (fn _ => fn thms => Class.intro_classes_tac [] THEN ALLGOALS (ProofContext.fact_tac thms))
           def_thms;

       fun register_eqns dtco thy' =
         let
           val eq_inst = AxClass.param_of_inst thy' (@{const_name eq_class.eq}, dtco);
           val thy_ref = Theory.check_thy thy';
         in
           Code.add_eqnl (eq_inst, Lazy.lazy (fn () => rev (mk_eq_eqns (Theory.deref thy_ref) dtco))) thy'
         end;

       val (def_thms, lthy) =
         fold_map define_eq dtcos (TheoryTarget.instantiation (dtcos, vs, [HOLogic.class_eq]) thy);
       val (def_thms', thy') = prove_instance def_thms lthy;
     in
       thy'
       |> fold Code.del_eqn def_thms'
       |> fold register_eqns dtcos
     end;
       
   412 
       
   413 
       
   414 (* liberal addition of code data for datatypes *)
       
   415 
       
   (* mk_constr_consts thy vs dtco cos: give each constructor (name, arg
      types) of dtco its full type.  Returns SOME of that list if the
      unoverloaded constants are accepted by Code.constrset_of_consts
      (i.e. it raises no exception), NONE otherwise. *)
   fun mk_constr_consts thy vs dtco cos =
     let
       val T = Type (dtco, map TFree vs);
       val typed = map (fn (c, arg_tys) => (c, arg_tys ---> T)) cos;
       val unoverloaded = map (fn (c, ty) => (AxClass.unoverload_const thy (c, ty), ty)) typed;
     in
       (case try (Code.constrset_of_consts thy) unoverloaded of
         NONE => NONE
       | SOME _ => SOME typed)
     end;
       
   424 
       
   (* add_all_code dtcos thy: datatype interpretation hook.  If every
      datatype in dtcos has an acceptable constructor set, register the
      datatypes, their case rewrites as default equations, their case
      certificates and eq instances; otherwise leave thy unchanged.
      Case rewrites and certificates are computed in either case. *)
   fun add_all_code dtcos thy =
     let
       val specs = map (DatatypePackage.the_datatype_spec thy) dtcos;
       val (vs :: _, coss) = split_list specs;
       val constr_sets = map2 (mk_constr_consts thy vs) dtcos coss;
       val css = if forall is_some constr_sets then map the constr_sets else [];
       val case_rewrites = maps (fn dtco => #case_rewrites (DatatypePackage.the_datatype thy dtco)) dtcos;
       val certs = map (mk_case_cert thy) dtcos;
       fun register thy' = thy'
         |> fold Code.add_datatype css
         |> fold_rev Code.add_default_eqn case_rewrites
         |> fold Code.add_case certs
         |> add_equality vs dtcos;
     in
       if null css then thy else register thy
     end;
       
   441 
       
   442 
       
   443 
       
   444 (** theory setup **)
       
   445 
       
   (* Theory setup: apply, in this order, the "datatype" term code
      generator, the "datatype" type code generator and the datatype
      interpretation add_all_code (fold I [f, g, h] = f #> g #> h). *)
   val setup = fold I
     [add_codegen "datatype" datatype_codegen,
      add_tycodegen "datatype" datatype_tycodegen,
      DatatypePackage.interpretation add_all_code]
       
   450 
       
   451 end;