src/HOL/Tools/Datatype/datatype_codegen.ML
author wenzelm
Tue Apr 19 23:57:28 2011 +0200 (2011-04-19)
changeset 42411 ff997038e8eb
parent 42361 23f352990944
child 43324 2b47822868e4
permissions -rw-r--r--
eliminated Codegen.mode in favour of explicit argument;
     1 (*  Title:      HOL/Tools/Datatype/datatype_codegen.ML
     2     Author:     Stefan Berghofer and Florian Haftmann, TU Muenchen
     3 
     4 Code generator facilities for inductive datatypes.
     5 *)
     6 
     7 signature DATATYPE_CODEGEN =
     8 sig
         (* Theory setup: installs add_all_code as datatype interpretation and
            registers the "datatype" term and type code generators. *)
     9   val setup: theory -> theory
    10 end;
    11 
    12 structure Datatype_Codegen : DATATYPE_CODEGEN =
    13 struct
    14 
    15 (** generic code generator **)
    16 
    17 (* liberal addition of code data for datatypes *)
    18 
    19 fun mk_constr_consts thy vs tyco cos =
       (* Constructors cos of datatype tyco (type variables vs), each paired with its
          full type  arg_tys ---> tyco vs.  Yields SOME of that list when
          Code.constrset_of_consts accepts the unoverloaded constants, and NONE when
          that call raises (caught by try). *)
         let
           val dataT = Type (tyco, map TFree vs);
           val consts = map (fn (c, arg_tys) => (c, arg_tys ---> dataT)) cos;
           fun unoverload (c_T as (_, T)) = (AxClass.unoverload_const thy c_T, T);
         in
           try (Code.constrset_of_consts thy) (map unoverload consts)
           |> Option.map (K consts)
         end;
    27 
    28 
    29 (* case certificates *)
    30 
    31 fun mk_case_cert thy tyco =
       (* Case certificate for datatype tyco (registered via Code.add_case in
          add_all_code), built from the datatype's #case_rewrites: the rewrites are
          conjoined, the shared head "case constant applied to all but its last
          argument" is replaced by a fresh variable under the assumption
          case == <that head>, the assumption becomes a premise, the remaining free
          variables are generalized, and the result is unoverloaded and varified. *)
    32   let
    33     val raw_thms =
    34       (#case_rewrites o Datatype_Data.the_info thy) tyco;
           (* Unvarify all rewrites in one step (as a balanced conjunction), split them
              again, and turn each into a meta-equality with normalized variable
              indexes.  The pattern hd_thm :: _ assumes at least one rewrite. *)
    35     val thms as hd_thm :: _ = raw_thms
    36       |> Conjunction.intr_balanced
    37       |> Thm.unvarify_global
    38       |> Conjunction.elim_balanced (length raw_thms)
    39       |> map Simpdata.mk_meta_eq
    40       |> map Drule.zero_var_indexes
           (* Names of the free variables of the first rewrite. *)
    41     val params = fold_aterms (fn (Free (v, _)) => insert (op =) v
    42       | _ => I) (Thm.prop_of hd_thm) [];
           (* Left-hand side of the first rewrite with its last argument dropped
              (presumably the scrutinee position).  Called rhs because it becomes the
              right-hand side of the assumption lhs == rhs below. *)
    43     val rhs = hd_thm
    44       |> Thm.prop_of
    45       |> Logic.dest_equals
    46       |> fst
    47       |> Term.strip_comb
    48       |> apsnd (fst o split_last)
    49       |> list_comb;
           (* Fresh free variable, named as a variant of "case" avoiding params. *)
    50     val lhs = Free (Name.variant params "case", Term.fastype_of rhs);
    51     val asm = (Thm.cterm_of thy o Logic.mk_equals) (lhs, rhs);
    52   in
    53     thms
    54     |> Conjunction.intr_balanced
           (* Rewrite occurrences of rhs into lhs, using the assumed  rhs == lhs. *)
    55     |> Raw_Simplifier.rewrite_rule [(Thm.symmetric o Thm.assume) asm]
           (* Discharge the assumption as a premise of the certificate. *)
    56     |> Thm.implies_intr asm
           (* Turn the free variables params into schematic variables. *)
    57     |> Thm.generalize ([], params) 0
    58     |> AxClass.unoverload thy
    59     |> Thm.varifyT_global
    60   end;
    61 
    62 
    63 (* equality *)
    64 
    65 fun mk_eq_eqns thy tyco =
       (* Code equations for HOL.equal on datatype tyco.  Returns (eqns, refl_eqn):
          eqns are the trivial injections, injections and distinctness equations (in
          that order); refl_eqn states  equal x x = True.  Each is proved with
          Skip_Proof.prove_global by simp_tac and converted with Simpdata.mk_eq. *)
    66   let
    67     val (vs, cos) = Datatype_Data.the_spec thy tyco;
    68     val { descr, index, inject = inject_thms, distinct = distinct_thms, ... } =
    69       Datatype_Data.the_info thy tyco;
    70     val ty = Type (tyco, map TFree vs);
           (* HOL.equal t1 t2 at the datatype type ty. *)
    71     fun mk_eq (t1, t2) = Const (@{const_name HOL.equal}, ty --> ty --> HOLogic.boolT)
    72       $ t1 $ t2;
    73     fun true_eq t12 = HOLogic.mk_eq (mk_eq t12, HOLogic.true_const);
    74     fun false_eq t12 = HOLogic.mk_eq (mk_eq t12, HOLogic.false_const);
           (* equal c c = True for every nullary constructor c. *)
    75     val triv_injects = map_filter
    76      (fn (c, []) => SOME (HOLogic.mk_Trueprop (true_eq (Const (c, ty), Const (c, ty))))
    77        | _ => NONE) cos;
           (* Injectivity propositions from Datatype_Prop.make_injs, with the head of
              the inner left-hand application replaced by HOL.equal.  The pattern is
              not exhaustive: it assumes  trueprop $ (equiv $ (_ $ t1 $ t2) $ rhs). *)
    78     fun prep_inject (trueprop $ (equiv $ (_ $ t1 $ t2) $ rhs)) =
    79       trueprop $ (equiv $ mk_eq (t1, t2) $ rhs);
    80     val injects = map prep_inject (nth (Datatype_Prop.make_injs [descr] vs) index);
           (* Each distinctness proposition, matched as  trueprop $ (not $ (_ $ t1 $ t2)),
              yields both  equal t1 t2 = False  and  equal t2 t1 = False. *)
    81     fun prep_distinct (trueprop $ (not $ (_ $ t1 $ t2))) =
    82       [trueprop $ false_eq (t1, t2), trueprop $ false_eq (t2, t1)];
    83     val distincts = maps prep_distinct (snd (nth (Datatype_Prop.make_distincts [descr] vs) index));
    84     val refl = HOLogic.mk_Trueprop (true_eq (Free ("x", ty), Free ("x", ty)));
           (* HOL_basic_ss plus the equal and eq_True facts and this datatype's inject
              and distinct theorems, all converted by Simpdata.mk_eq. *)
    85     val simpset = Simplifier.global_context thy (HOL_basic_ss addsimps 
    86       (map Simpdata.mk_eq (@{thm equal} :: @{thm eq_True} :: inject_thms @ distinct_thms)));
    87     fun prove prop = Skip_Proof.prove_global thy [] [] prop (K (ALLGOALS (simp_tac simpset)))
    88       |> Simpdata.mk_eq;
    89   in (map prove (triv_injects @ injects @ distincts), prove refl) end;
    90 
    91 fun add_equality vs tycos thy =
       (* Instantiates class HOLogic.class_equal for the datatypes tycos (type
          variables vs).
          - For each datatype it defines  HOL.equal x y = HOL.eq x y.
          - It proves the instantiation from these definitions.
          - It calls Code.del_eqn on the definitional theorems.
          - It notes the equations from mk_eq_eqns as <base>.eq.refl
            (Code.add_nbe_default_eqn_attribute) and <base>.eq.simps
            (Code.add_default_eqn_attribute). *)
    92   let
           (* Defines HOL.equal on tyco as HOL.eq in lthy; returns the definitional
              theorem exported to a global context of lthy's theory, plus lthy'. *)
    93     fun add_def tyco lthy =
    94       let
    95         val ty = Type (tyco, map TFree vs);
    96         fun mk_side const_name = Const (const_name, ty --> ty --> HOLogic.boolT)
    97           $ Free ("x", ty) $ Free ("y", ty);
    98         val def = HOLogic.mk_Trueprop (HOLogic.mk_eq
    99           (mk_side @{const_name HOL.equal}, mk_side @{const_name HOL.eq}));
   100         val def' = Syntax.check_term lthy def;
   101         val ((_, (_, thm)), lthy') = Specification.definition
   102           (NONE, (Attrib.empty_binding, def')) lthy;
   103         val ctxt_thy = Proof_Context.init_global (Proof_Context.theory_of lthy);
   104         val thm' = singleton (Proof_Context.export lthy' ctxt_thy) thm;
   105       in (thm', lthy') end;
           (* Instantiation proof: intro_classes, then close every goal by the given
              definitional theorems. *)
   106     fun tac thms = Class.intro_classes_tac []
   107       THEN ALLGOALS (Proof_Context.fact_tac thms);
           (* Binding <base name of tyco>.eq.<name>. *)
   108     fun prefix tyco = Binding.qualify true (Long_Name.base_name tyco) o Binding.qualify true "eq" o Binding.name;
           (* Notes the mk_eq_eqns results (computed on a checkpointed theory): the
              reflexivity equation as "refl" and the remaining equations, in reverse
              order, as "simps". *)
   109     fun add_eq_thms tyco =
   110       Theory.checkpoint
   111       #> `(fn thy => mk_eq_eqns thy tyco)
   112       #-> (fn (thms, thm) => Global_Theory.note_thmss Thm.lemmaK
   113           [((prefix tyco "refl", [Code.add_nbe_default_eqn_attribute]), [([thm], [])]),
   114             ((prefix tyco "simps", [Code.add_default_eqn_attribute]), [(rev thms, [])])])
   115       #> snd
   116   in
   117     thy
   118     |> Class.instantiation (tycos, vs, [HOLogic.class_equal])
   119     |> fold_map add_def tycos
   120     |-> (fn def_thms => Class.prove_instantiation_exit_result (map o Morphism.thm)
   121          (fn _ => fn def_thms => tac def_thms) def_thms)
           (* Code.del_eqn on the definitional theorems -- presumably so that the
              equations noted by add_eq_thms are used for code generation instead. *)
   122     |-> (fn def_thms => fold Code.del_eqn def_thms)
   123     |> fold add_eq_thms tycos
   124   end;
   125 
   126 
   127 (* register a datatype etc. *)
   128 
   129 fun add_all_code config tycos thy =
       (* Datatype interpretation hook (installed by setup) for the datatypes tycos.
          It acts only if mk_constr_consts succeeds for every datatype in tycos.
          In that case it registers with the code generator:
          - their constructor sets;
          - their case rewrites as default equations;
          - their case certificates.
          It then calls add_equality for those tycos for which Sorts.mg_domain finds
          no arity for HOLogic.class_equal. *)
         let
           (* The type variables vs of the first spec are used for all datatypes. *)
           val (vs :: _, coss) = split_list (map (Datatype_Data.the_spec thy) tycos);
           val constr_opts = map2 (mk_constr_consts thy vs) tycos coss;
           val css = if forall is_some constr_opts then map the constr_opts else [];
           val case_rewrites = maps (#case_rewrites o Datatype_Data.the_info thy) tycos;
           val certs = map (mk_case_cert thy) tycos;
           fun has_equal tyco =
             can (Sorts.mg_domain (Sign.classes_of thy) tyco) [HOLogic.class_equal];
           val tycos_eq = filter (not o has_equal) tycos;
         in
           if null css then thy
           else
             thy
             |> tap (fn _ => Datatype_Aux.message config "Registering datatype for code generator ...")
             |> fold Code.add_datatype css
             |> fold_rev Code.add_default_eqn case_rewrites
             |> fold Code.add_case certs
             |> (if null tycos_eq then I else add_equality vs tycos_eq)
         end;
   148 
   149 
   150 (** SML code generator **)
   151 
   152 (* datatype definition *)
   153 
   154 fun add_dt_defs thy mode defs dep module descr sorts gr =
       (* Ensures the code graph gr holds the SML code for the datatypes in descr and
          returns (module', gr').  Only entries whose type arguments are all type
          variables are emitted (descr').  The node is "<first tname> (type)".
          - If the node exists, only the edge (node_id, dep) is added; the graph is
            left unchanged when that edge would create a cycle.
          - If the node is undefined (Graph.UNDEF), it is created and filled with:
            the "datatype" declaration; the term_of functions when mode contains
            "term_of"; the random generators when mode contains "test". *)
   155   let
   156     val descr' = filter (can (map Datatype_Aux.dest_DtTFree o #2 o snd)) descr;
           (* Names of emitted datatypes with some constructor having a recursive argument. *)
   157     val rtnames = map (#1 o snd) (filter (fn (_, (_, _, cs)) =>
   158       exists (exists Datatype_Aux.is_rec_type o snd) cs) descr');
   159 
   160     val (_, (tname, _, _)) :: _ = descr';
   161     val node_id = tname ^ " (type)";
           (* Target module chosen by Codegen.if_library from the type's theory name and
              module (presumably the former in library mode -- TODO confirm). *)
   162     val module' = Codegen.if_library mode (Codegen.thyname_of_type thy tname) module;
   163 
           (* One declaration clause per datatype: prfx, the type variables, the type id,
              " =", then alternatives "C of T1 * ... * Tn" separated by "| ".  Later
              clauses use the prefix "and ".  Threads the code graph through
              mk_type_id, invoke_tycodegen and mk_const_id. *)
   164     fun mk_dtdef prfx [] gr = ([], gr)
   165       | mk_dtdef prfx ((_, (tname, dts, cs))::xs) gr =
   166           let
   167             val tvs = map Datatype_Aux.dest_DtTFree dts;
   168             val cs' = map (apsnd (map (Datatype_Aux.typ_of_dtyp descr sorts))) cs;
   169             val ((_, type_id), gr') = Codegen.mk_type_id module' tname gr;
   170             val (ps, gr'') = gr' |>
   171               fold_map (fn (cname, cargs) =>
   172                 fold_map (Codegen.invoke_tycodegen thy mode defs node_id module' false)
   173                   cargs ##>>
   174                 Codegen.mk_const_id module' cname) cs';
   175             val (rest, gr''') = mk_dtdef "and " xs gr''
   176           in
   177             (Pretty.block (Codegen.str prfx ::
   178                (if null tvs then [] else
   179                   [Codegen.mk_tuple (map Codegen.str tvs), Codegen.str " "]) @
   180                [Codegen.str (type_id ^ " ="), Pretty.brk 1] @
   181                flat (separate [Pretty.brk 1, Codegen.str "| "]
   182                  (map (fn (ps', (_, cname)) => [Pretty.block
   183                    (Codegen.str cname ::
   184                     (if null ps' then [] else
   185                      flat ([Codegen.str " of", Pretty.brk 1] ::
   186                        separate [Codegen.str " *", Pretty.brk 1]
   187                          (map single ps'))))]) ps))) :: rest, gr''')
   188           end;
   189 
           (* ML text  Const ("cname", <Ts ---> T>) $ p1 $ ... $ pn  as pretty blocks. *)
   190     fun mk_constr_term cname Ts T ps =
   191       flat (separate [Codegen.str " $", Pretty.brk 1]
   192         ([Codegen.str ("Const (\"" ^ cname ^ "\","), Pretty.brk 1,
   193           Codegen.mk_type false (Ts ---> T), Codegen.str ")"] :: ps));
   194 
           (* term_of equations, one per constructor.  Each applies the term_of function
              for T to the constructor pattern (arguments x1 .. xn) and returns the
              constructor term applied to term_of of each argument.  Clauses of one type
              are joined with "  | "; later types use the prefix "and ". *)
   195     fun mk_term_of_def gr prfx [] = []
   196       | mk_term_of_def gr prfx ((_, (tname, dts, cs)) :: xs) =
   197           let
   198             val cs' = map (apsnd (map (Datatype_Aux.typ_of_dtyp descr sorts))) cs;
   199             val dts' = map (Datatype_Aux.typ_of_dtyp descr sorts) dts;
   200             val T = Type (tname, dts');
   201             val rest = mk_term_of_def gr "and " xs;
   202             val (eqs, _) = fold_map (fn (cname, Ts) => fn prfx =>
   203               let val args = map (fn i =>
   204                 Codegen.str ("x" ^ string_of_int i)) (1 upto length Ts)
   205               in (Pretty.blk (4,
   206                 [Codegen.str prfx, Codegen.mk_term_of gr module' false T, Pretty.brk 1,
   207                  if null Ts then Codegen.str (snd (Codegen.get_const_id gr cname))
   208                  else Codegen.parens (Pretty.block
   209                    [Codegen.str (snd (Codegen.get_const_id gr cname)),
   210                     Pretty.brk 1, Codegen.mk_tuple args]),
   211                  Codegen.str " =", Pretty.brk 1] @
   212                  mk_constr_term cname Ts T
   213                    (map2 (fn x => fn U => [Pretty.block [Codegen.mk_term_of gr module' false U,
   214                       Pretty.brk 1, x]]) args Ts)), "  | ")
   215               end) cs' prfx
   216           in eqs @ rest end;
   217 
           (* Random generators gen_<type id>, one per datatype, taking a <tv>G and a
              <tv>T argument per type variable.
              - With recursive constructors (cs1): emits  gen_<id>' ... i j  and
                gen_<id> ... i = gen_<id>' ... i i.
              - Otherwise: emits  gen_<id> ... j.
              Generated values are paired with a delayed term (fn () => ...). *)
   218     fun mk_gen_of_def gr prfx [] = []
   219       | mk_gen_of_def gr prfx ((i, (tname, dts, cs)) :: xs) =
   220           let
   221             val tvs = map Datatype_Aux.dest_DtTFree dts;
   222             val Us = map (Datatype_Aux.typ_of_dtyp descr sorts) dts;
   223             val T = Type (tname, Us);
   224             val (cs1, cs2) =
   225               List.partition (exists Datatype_Aux.is_rec_type o snd) cs;
                   (* Constructor used for size 0 when every constructor is recursive. *)
   226             val SOME (cname, _) = Datatype_Aux.find_shortest_path descr i;
   227 
   228             fun mk_delay p = Pretty.block
   229               [Codegen.str "fn () =>", Pretty.brk 1, p];
   230 
   231             fun mk_force p = Pretty.block [p, Pretty.brk 1, Codegen.str "()"];
   232 
                   (* Generator expression for one constructor.  It binds (x_k, t_k) from
                      the argument generators: Codegen.mk_gen with size s, applied to "0"
                      for recursive arguments when b holds, and to "j" otherwise.  It
                      returns (constructor value, fn () => constructor term). *)
   233             fun mk_constr s b (cname, dts) =
   234               let
   235                 val gs = map (fn dt => Codegen.mk_app false
   236                     (Codegen.mk_gen gr module' false rtnames s
   237                       (Datatype_Aux.typ_of_dtyp descr sorts dt))
   238                   [Codegen.str (if b andalso Datatype_Aux.is_rec_type dt then "0"
   239                      else "j")]) dts;
   240                 val Ts = map (Datatype_Aux.typ_of_dtyp descr sorts) dts;
   241                 val xs = map Codegen.str
   242                   (Datatype_Prop.indexify_names (replicate (length dts) "x"));
   243                 val ts = map Codegen.str
   244                   (Datatype_Prop.indexify_names (replicate (length dts) "t"));
   245                 val (_, id) = Codegen.get_const_id gr cname;
   246               in
   247                 Codegen.mk_let
   248                   (map2 (fn p => fn q => Codegen.mk_tuple [p, q]) xs ts ~~ gs)
   249                   (Codegen.mk_tuple
   250                     [case xs of
   251                        _ :: _ :: _ => Pretty.block
   252                          [Codegen.str id, Pretty.brk 1, Codegen.mk_tuple xs]
   253                      | _ => Codegen.mk_app false (Codegen.str id) xs,
   254                      mk_delay (Pretty.block (mk_constr_term cname Ts T
   255                        (map (single o mk_force) ts)))])
   256               end;
   257 
                   (* One constructor directly; several via one_of over delayed
                      alternatives; argument size "(i-1)". *)
   258             fun mk_choice [c] = mk_constr "(i-1)" false c
   259               | mk_choice cs = Pretty.block [Codegen.str "one_of",
   260                   Pretty.brk 1, Pretty.blk (1, Codegen.str "[" ::
   261                   flat (separate [Codegen.str ",", Pretty.fbrk]
   262                     (map (single o mk_delay o mk_constr "(i-1)" false) cs)) @
   263                   [Codegen.str "]"]), Pretty.brk 1, Codegen.str "()"];
   264 
                   (* Per type variable: its <tv>G and <tv>T parameters. *)
   265             val gs = maps (fn s =>
   266               let val s' = Codegen.strip_tname s
   267               in [Codegen.str (s' ^ "G"), Codegen.str (s' ^ "T")] end) tvs;
   268             val gen_name = "gen_" ^ snd (Codegen.get_type_id gr tname)
   269 
   270           in
   271             Pretty.blk (4, separate (Pretty.brk 1) 
   272                 (Codegen.str (prfx ^ gen_name ^
   273                    (if null cs1 then "" else "'")) :: gs @
   274                  (if null cs1 then [] else [Codegen.str "i"]) @
   275                  [Codegen.str "j"]) @
   276               [Codegen.str " =", Pretty.brk 1] @
                     (* Body, by kind of constructors:
                        - recursive and non-recursive: frequency with weights i and 1;
                        - only recursive: size 0 uses cname, otherwise choose among cs1;
                        - only non-recursive: choose among cs2. *)
   277               (if not (null cs1) andalso not (null cs2)
   278                then [Codegen.str "frequency", Pretty.brk 1,
   279                  Pretty.blk (1, [Codegen.str "[",
   280                    Codegen.mk_tuple [Codegen.str "i", mk_delay (mk_choice cs1)],
   281                    Codegen.str ",", Pretty.fbrk,
   282                    Codegen.mk_tuple [Codegen.str "1", mk_delay (mk_choice cs2)],
   283                    Codegen.str "]"]), Pretty.brk 1, Codegen.str "()"]
   284                else if null cs2 then
   285                  [Pretty.block [Codegen.str "(case", Pretty.brk 1,
   286                    Codegen.str "i", Pretty.brk 1, Codegen.str "of",
   287                    Pretty.brk 1, Codegen.str "0 =>", Pretty.brk 1,
   288                    mk_constr "0" true (cname, the (AList.lookup (op =) cs cname)),
   289                    Pretty.brk 1, Codegen.str "| _ =>", Pretty.brk 1,
   290                    mk_choice cs1, Codegen.str ")"]]
   291                else [mk_choice cs2])) ::
   292             (if null cs1 then []
   293              else [Pretty.blk (4, separate (Pretty.brk 1) 
   294                  (Codegen.str ("and " ^ gen_name) :: gs @ [Codegen.str "i"]) @
   295                [Codegen.str " =", Pretty.brk 1] @
   296                separate (Pretty.brk 1) (Codegen.str (gen_name ^ "'") :: gs @
   297                  [Codegen.str "i", Codegen.str "i"]))]) @
   298             mk_gen_of_def gr "and " xs
   299           end
   300 
   301   in
           (* Existing node: only add the edge (unchanged graph on a cycle).
              Missing node (Graph.UNDEF): create it and store the generated code. *)
   302     (module', (Codegen.add_edge_acyclic (node_id, dep) gr
   303         handle Graph.CYCLES _ => gr) handle Graph.UNDEF _ =>
   304          let
   305            val gr1 = Codegen.add_edge (node_id, dep)
   306              (Codegen.new_node (node_id, (NONE, "", "")) gr);
   307            val (dtdef, gr2) = mk_dtdef "datatype " descr' gr1 ;
   308          in
   309            Codegen.map_node node_id (K (NONE, module',
   310              Codegen.string_of (Pretty.blk (0, separate Pretty.fbrk dtdef @
   311                [Codegen.str ";"])) ^ "\n\n" ^
   312              (if member (op =) mode "term_of" then
   313                 Codegen.string_of (Pretty.blk (0, separate Pretty.fbrk
   314                   (mk_term_of_def gr2 "fun " descr') @ [Codegen.str ";"])) ^ "\n\n"
   315               else "") ^
   316              (if member (op =) mode "test" then
   317                 Codegen.string_of (Pretty.blk (0, separate Pretty.fbrk
   318                   (mk_gen_of_def gr2 "fun " descr') @ [Codegen.str ";"])) ^ "\n\n"
   319               else ""))) gr2
   320          end)
   321   end;
   322 
   323 
   324 (* case expressions *)
   325 
   326 fun pretty_case thy mode defs dep module brack constrs (c as Const (_, T)) ts gr =
       (* SML code for the case combinator c applied to ts.  constrs are the
          datatype's (constructor, argument dtyps) pairs.
          - With at most length constrs arguments, c is eta-expanded to
            length constrs + 1 arguments and generated via invoke_codegen.
          - Otherwise emits "(case <scrutinee> of C xs => body | ...)", followed by
            any surplus arguments.  The whole is parenthesized if brack holds and
            surplus arguments exist. *)
   327   let val i = length constrs
   328   in if length ts <= i then
   329        Codegen.invoke_codegen thy mode defs dep module brack (Codegen.eta_expand c ts (i+1)) gr
   330     else
   331       let
               (* ts1: one branch function per constructor; t: scrutinee; ts2: surplus *)
   332         val ts1 = take i ts;
   333         val t :: ts2 = drop i ts;
               (* Names occurring in the branch functions, avoided by the fresh
                  pattern variables. *)
   334         val names = List.foldr OldTerm.add_term_names
   335           (map (fst o fst o dest_Var) (List.foldr OldTerm.add_term_vars [] ts1)) ts1;
               (* Ts: types of the branch functions; dT: type of the scrutinee argument. *)
   336         val (Ts, dT) = split_last (take (i+1) (binder_types T));
   337 
               (* One clause "pattern => body" per constructor.  The pattern is the
                  constructor applied to fresh variables.  The body is the branch
                  function applied to those variables, then beta-normalized. *)
   338         fun pcase [] [] [] gr = ([], gr)
   339           | pcase ((cname, cargs)::cs) (t::ts) (U::Us) gr =
   340               let
   341                 val j = length cargs;
   342                 val xs = Name.variant_list names (replicate j "x");
   343                 val Us' = take j (binder_types U);
   344                 val frees = map2 (curry Free) xs Us';
   345                 val (cp, gr0) = Codegen.invoke_codegen thy mode defs dep module false
   346                   (list_comb (Const (cname, Us' ---> dT), frees)) gr;
   347                 val t' = Envir.beta_norm (list_comb (t, frees));
   348                 val (p, gr1) = Codegen.invoke_codegen thy mode defs dep module false t' gr0;
   349                 val (ps, gr2) = pcase cs ts Us gr1;
   350               in
   351                 ([Pretty.block [cp, Codegen.str " =>", Pretty.brk 1, p]] :: ps, gr2)
   352               end;
   353 
   354         val (ps1, gr1) = pcase constrs ts1 Ts gr ;
   355         val ps = flat (separate [Pretty.brk 1, Codegen.str "| "] ps1);
   356         val (p, gr2) = Codegen.invoke_codegen thy mode defs dep module false t gr1;
               (* Surplus arguments are generated with brackets (brack = true). *)
   357         val (ps2, gr3) = fold_map (Codegen.invoke_codegen thy mode defs dep module true) ts2 gr2;
   358       in ((if not (null ts2) andalso brack then Codegen.parens else I)
   359         (Pretty.block (separate (Pretty.brk 1)
   360           (Pretty.block ([Codegen.str "(case ", p, Codegen.str " of",
   361              Pretty.brk 1] @ ps @ [Codegen.str ")"]) :: ps2))), gr3)
   362       end
   363   end;
   364 
   365 
   366 (* constructors *)
   367 
   368 fun pretty_constr thy mode defs dep module brack args (c as Const (s, T)) ts gr =
       (* SML code for constructor c applied to ts; args are the constructor's
          argument dtyps.
          - Constructors with two or more arguments take a tuple.  If such a
            constructor is under-applied, c is eta-expanded to that arity and the
            whole term is generated via invoke_codegen.
          - Otherwise the arguments are generated (bracketed exactly when the
            constructor has one argument) and applied to the qualified name. *)
         let
           val arity = length args;
           val tupled = arity > 1;
         in
           if tupled andalso length ts < arity then
             Codegen.invoke_codegen thy mode defs dep module brack
               (Codegen.eta_expand c ts arity) gr
           else
             let
               val qual_id = Codegen.mk_qual_id module (Codegen.get_const_id gr s);
               val (arg_ps, gr') =
                 fold_map (Codegen.invoke_codegen thy mode defs dep module (arity = 1)) ts gr;
               val head = Codegen.str qual_id;
               val wrap = if brack then Codegen.parens else I;
             in
               if tupled
               then (wrap (Pretty.block [head, Pretty.brk 1, Codegen.mk_tuple arg_ps]), gr')
               else (Codegen.mk_app brack head arg_ps, gr')
             end
         end;
   384 
   385 
   386 (* code generators for terms and types *)
   387 
   388 fun datatype_codegen thy mode defs dep module brack t gr =
       (* Term code generator, registered as "datatype" in setup.  For t = c ts with
          a constant head c:
          - c is a datatype case combinator: generated by pretty_case;
          - c is a datatype constructor whose result type constructor is the one of
            its descr entry: generated by pretty_constr.
          Returns NONE if c has associated code (Codegen.get_assoc_code) or neither
          case applies -- presumably leaving t to other code generators. *)
   389   (case strip_comb t of
   390     (c as Const (s, T), ts) =>
   391       (case Datatype_Data.info_of_case thy s of
   392         SOME {index, descr, ...} =>
   393           if is_some (Codegen.get_assoc_code thy (s, T)) then NONE
   394           else
   395             SOME (pretty_case thy mode defs dep module brack
   396               (#3 (the (AList.lookup op = descr index))) c ts gr)
   397       | NONE =>
   398           (case (Datatype_Data.info_of_constr thy (s, T), body_type T) of
   399             (SOME {index, descr, ...}, U as Type (tyname, _)) =>
   400               if is_some (Codegen.get_assoc_code thy (s, T)) then NONE
   401               else
   402                 let
                           (* The descr entry and the constructor's argument dtyps; both
                              val SOME bindings are non-exhaustive. *)
   403                   val SOME (tyname', _, constrs) = AList.lookup op = descr index;
   404                   val SOME args = AList.lookup op = constrs s;
   405                 in
   406                   if tyname <> tyname' then NONE
   407                   else
                           (* The result type U is generated only for the updated code
                              graph; its pretty-printed form is discarded (snd). *)
   408                     SOME
   409                       (pretty_constr thy mode defs
   410                         dep module brack args c ts
   411                         (snd (Codegen.invoke_tycodegen thy mode defs dep module false U gr)))
   412                 end
   413           | _ => NONE))
   414   | _ => NONE);
   415 
   416 fun datatype_tycodegen thy mode defs dep module brack (Type (s, Ts)) gr =
       (* Type code generator, registered as "datatype" in setup.  For a datatype s
          without associated code (Codegen.get_assoc_type) it:
          - generates the argument types Ts;
          - ensures the datatype definitions exist (add_dt_defs);
          - prints the argument tuple (if any) followed by the qualified type id.
          Returns NONE for anything else. *)
             (case Datatype_Data.get_info thy s of
               SOME {descr, sorts, ...} =>
                 if is_none (Codegen.get_assoc_type thy s) then
                   let
                     val (arg_ps, gr1) =
                       fold_map (Codegen.invoke_tycodegen thy mode defs dep module false) Ts gr;
                     val (module', gr2) = add_dt_defs thy mode defs dep module descr sorts gr1;
                     val (tyid, gr3) = Codegen.mk_type_id module' s gr2;
                     val args_prefix =
                       if null Ts then [] else [Codegen.mk_tuple arg_ps, Codegen.str " "];
                     val name_p = Codegen.str (Codegen.mk_qual_id module tyid);
                   in SOME (Pretty.block (args_prefix @ [name_p]), gr3) end
                 else NONE
             | NONE => NONE)
         | datatype_tycodegen _ _ _ _ _ _ _ _ = NONE;
   431 
   432 
   433 (** theory setup **)
   434 
   435 val setup =
       (* Theory setup: installs add_all_code as datatype interpretation and registers
          the "datatype" term and type code generators.  The three partial
          applications are bound here, so, as in the point-free composition, they are
          evaluated once when this structure is elaborated. *)
         let
           val register_code_data = Datatype_Data.interpretation add_all_code;
           val register_term_codegen = Codegen.add_codegen "datatype" datatype_codegen;
           val register_type_codegen = Codegen.add_tycodegen "datatype" datatype_tycodegen;
         in register_code_data #> register_term_codegen #> register_type_codegen end;
   439 
   440 end;