src/HOL/Tools/datatype_package.ML
author wenzelm
Wed Jun 18 18:55:10 2008 +0200 (2008-06-18)
changeset 27261 5b3101338f42
parent 27130 4ba366056426
child 27277 7b7ce2d7fafe
permissions -rw-r--r--
eliminated old Sign.read_term/Thm.read_cterm etc.;
     1 (*  Title:      HOL/Tools/datatype_package.ML
     2     ID:         $Id$
     3     Author:     Stefan Berghofer, TU Muenchen
     4 
     5 Datatype package for Isabelle/HOL.
     6 *)
     7 
(* Public interface of the datatype package. *)
signature DATATYPE_PACKAGE =
sig
  (*re-export of DatatypeAux.quiet_mode*)
  val quiet_mode : bool ref
  (*table of all registered datatypes, keyed by full type name*)
  val get_datatypes : theory -> DatatypeAux.datatype_info Symtab.table
  (*print the externalized names of all registered datatypes*)
  val print_datatypes : theory -> unit
  (*lookup by full type name; the_datatype fails with "Unknown datatype"*)
  val get_datatype : theory -> string -> DatatypeAux.datatype_info option
  val the_datatype : theory -> string -> DatatypeAux.datatype_info
  (*lookup of the datatype owning a constructor / a case combinator*)
  val datatype_of_constr : theory -> string -> DatatypeAux.datatype_info option
  val datatype_of_case : theory -> string -> DatatypeAux.datatype_info option
  (*type parameters with their sorts, and constructors with argument types*)
  val the_datatype_spec : theory -> string -> (string * sort) list * (string * typ list) list
  (*constructors with their full types, type parameters turned into schematic
    TVars (index 0); NONE if the name does not denote a datatype*)
  val get_datatype_constrs : theory -> string -> (string * typ) list option
  (*map the constructor argument types of the given (mutually recursive)
    datatypes through the atom / dtyp / rtyp interpretation functions*)
  val construction_interpretation : theory
    -> {atom : typ -> 'a, dtyp : string -> 'a, rtyp : string -> 'a list -> 'a}
    -> (string * sort) list -> string list
    -> (string * (string * 'a list) list) list
  (*simproc rewriting equations between distinct constructors to False*)
  val distinct_simproc : simproc
  (*case-expression translation, delegated to DatatypeCase*)
  val make_case :  Proof.context -> bool -> string list -> term ->
    (term * term) list -> term * (term * (int * bool)) list
  val strip_case : Proof.context -> bool -> term -> (term * (term * term) list) option
  (*register a hook applied to the names of newly introduced datatypes*)
  val interpretation : (string list -> theory -> theory) -> theory -> theory
  (*declare existing types as datatypes from their constructor terms (with
    optional alternative names): opens a proof of induction, injectivity and
    distinctness; the first argument receives the derived theorems*)
  val rep_datatype : ({distinct : thm list list,
       inject : thm list list,
       exhaustion : thm list,
       rec_thms : thm list,
       case_thms : thm list list,
       split_thms : (thm * thm) list,
       induction : thm,
       simps : thm list} -> Proof.context -> Proof.context) -> string list option -> term list
    -> theory -> Proof.state;
  (*as rep_datatype, but the constructor terms are given as strings and no
    continuation is run after the proof*)
  val rep_datatype_cmd : string list option -> string list -> theory -> Proof.state;
  (*define new datatypes.  First flag: if true, a failed nonemptiness check is
    reported as an error, otherwise the Datatype_Empty exception is re-raised.
    Second flag: flat constructor names (no per-type name path).*)
  val add_datatype : bool -> bool -> string list -> (string list * bstring * mixfix *
    (bstring * typ list * mixfix) list) list -> theory ->
      {distinct : thm list list,
       inject : thm list list,
       exhaustion : thm list,
       rec_thms : thm list,
       case_thms : thm list list,
       split_thms : (thm * thm) list,
       induction : thm,
       simps : thm list} * theory
  (*as add_datatype with argument types given as strings; the flag is the
    flat-names flag, and nonemptiness failures are always errors*)
  val add_datatype_cmd : bool -> string list -> (string list * bstring * mixfix *
    (bstring * string list * mixfix) list) list -> theory ->
      {distinct : thm list list,
       inject : thm list list,
       exhaustion : thm list,
       rec_thms : thm list,
       case_thms : thm list list,
       split_thms : (thm * thm) list,
       induction : thm,
       simps : thm list} * theory
  (*theory setup: distinctness limit, simproc, case syntax, interpretation hooks*)
  val setup: theory -> theory
end;
    60 
    61 structure DatatypePackage : DATATYPE_PACKAGE =
    62 struct
    63 
    64 open DatatypeAux;
    65 
    66 val quiet_mode = quiet_mode;
    67 
    68 
    69 (* theory data *)
    70 
    71 structure DatatypesData = TheoryDataFun
    72 (
    73   type T =
    74     {types: datatype_info Symtab.table,
    75      constrs: datatype_info Symtab.table,
    76      cases: datatype_info Symtab.table};
    77 
    78   val empty =
    79     {types = Symtab.empty, constrs = Symtab.empty, cases = Symtab.empty};
    80   val copy = I;
    81   val extend = I;
    82   fun merge _
    83     ({types = types1, constrs = constrs1, cases = cases1},
    84      {types = types2, constrs = constrs2, cases = cases2}) =
    85     {types = Symtab.merge (K true) (types1, types2),
    86      constrs = Symtab.merge (K true) (constrs1, constrs2),
    87      cases = Symtab.merge (K true) (cases1, cases2)};
    88 );
    89 
    90 val get_datatypes = #types o DatatypesData.get;
    91 val map_datatypes = DatatypesData.map;
    92 
    93 fun print_datatypes thy =
    94   Pretty.writeln (Pretty.strs ("datatypes:" ::
    95     map #1 (NameSpace.extern_table (Sign.type_space thy, get_datatypes thy))));
    96 
    97 
    98 (** theory information about datatypes **)
    99 
   100 fun put_dt_infos (dt_infos : (string * datatype_info) list) =
   101   map_datatypes (fn {types, constrs, cases} =>
   102     {types = fold Symtab.update dt_infos types,
   103      constrs = fold Symtab.update
   104        (maps (fn (_, info as {descr, index, ...}) => map (rpair info o fst)
   105           (#3 (the (AList.lookup op = descr index)))) dt_infos) constrs,
   106      cases = fold Symtab.update
   107        (map (fn (_, info as {case_name, ...}) => (case_name, info)) dt_infos)
   108        cases});
   109 
   110 val get_datatype = Symtab.lookup o get_datatypes;
   111 
   112 fun the_datatype thy name = (case get_datatype thy name of
   113       SOME info => info
   114     | NONE => error ("Unknown datatype " ^ quote name));
   115 
   116 val datatype_of_constr = Symtab.lookup o #constrs o DatatypesData.get;
   117 val datatype_of_case = Symtab.lookup o #cases o DatatypesData.get;
   118 
   119 fun get_datatype_descr thy dtco =
   120   get_datatype thy dtco
   121   |> Option.map (fn info as { descr, index, ... } =>
   122        (info, (((fn SOME (_, dtys, cos) => (dtys, cos)) o AList.lookup (op =) descr) index)));
   123 
   124 fun the_datatype_spec thy dtco =
   125   let
   126     val info as { descr, index, sorts = raw_sorts, ... } = the_datatype thy dtco;
   127     val SOME (_, dtys, raw_cos) = AList.lookup (op =) descr index;
   128     val sorts = map ((fn v => (v, (the o AList.lookup (op =) raw_sorts) v))
   129       o DatatypeAux.dest_DtTFree) dtys;
   130     val cos = map
   131       (fn (co, tys) => (co, map (DatatypeAux.typ_of_dtyp descr sorts) tys)) raw_cos;
   132   in (sorts, cos) end;
   133 
   134 fun get_datatype_constrs thy dtco =
   135   case try (the_datatype_spec thy) dtco
   136    of SOME (sorts, cos) =>
   137         let
   138           fun subst (v, sort) = TVar ((v, 0), sort);
   139           fun subst_ty (TFree v) = subst v
   140             | subst_ty ty = ty;
   141           val dty = Type (dtco, map subst sorts);
   142           fun mk_co (co, tys) = (co, map (Term.map_atyps subst_ty) tys ---> dty);
   143         in SOME (map mk_co cos) end
   144     | NONE => NONE;
   145 
   146 fun construction_interpretation thy { atom, dtyp, rtyp } sorts tycos =
   147   let
   148     val descr = (#descr o the_datatype thy o hd) tycos;
   149     val k = length tycos;
   150     val descr_of = the o AList.lookup (op =) descr;
   151     fun interpT (T as DtTFree _) = atom (typ_of_dtyp descr sorts T)
   152       | interpT (T as DtType (tyco, Ts)) = if is_rec_type T
   153           then rtyp tyco (map interpT Ts)
   154           else atom (typ_of_dtyp descr sorts T)
   155       | interpT (DtRec l) = if l < k then (dtyp o #1 o descr_of) l
   156           else let val (tyco, Ts, _) = descr_of l
   157           in rtyp tyco (map interpT Ts) end;
   158     fun interpC (c, Ts) = (c, map interpT Ts);
   159     fun interpK (_, (tyco, _, cs)) = (tyco, map interpC cs);
   160   in map interpK (Library.take (k, descr)) end;
   161 
   162 
   163 
   164 (** induct method setup **)
   165 
   166 (* case names *)
   167 
   168 local
   169 
   170 fun dt_recs (DtTFree _) = []
   171   | dt_recs (DtType (_, dts)) = maps dt_recs dts
   172   | dt_recs (DtRec i) = [i];
   173 
   174 fun dt_cases (descr: descr) (_, args, constrs) =
   175   let
   176     fun the_bname i = Sign.base_name (#1 (the (AList.lookup (op =) descr i)));
   177     val bnames = map the_bname (distinct (op =) (maps dt_recs args));
   178   in map (fn (c, _) => space_implode "_" (Sign.base_name c :: bnames)) constrs end;
   179 
   180 
   181 fun induct_cases descr =
   182   DatatypeProp.indexify_names (maps (dt_cases descr) (map #2 descr));
   183 
   184 fun exhaust_cases descr i = dt_cases descr (the (AList.lookup (op =) descr i));
   185 
   186 in
   187 
   188 fun mk_case_names_induct descr = RuleCases.case_names (induct_cases descr);
   189 
   190 fun mk_case_names_exhausts descr new =
   191   map (RuleCases.case_names o exhaust_cases descr o #1)
   192     (filter (fn ((_, (name, _, _))) => member (op =) new name) descr);
   193 
   194 end;
   195 
   196 fun add_rules simps case_thms rec_thms inject distinct
   197                   weak_case_congs cong_att =
   198   PureThy.add_thmss [(("simps", simps), []),
   199     (("", flat case_thms @
   200           flat distinct @ rec_thms), [Simplifier.simp_add]),
   201     (("", rec_thms), [RecfunCodegen.add_default]),
   202     (("", flat inject), [iff_add]),
   203     (("", map (fn th => th RS notE) (flat distinct)), [Classical.safe_elim NONE]),
   204     (("", weak_case_congs), [cong_att])]
   205   #> snd;
   206 
   207 
   208 (* add_cases_induct *)
   209 
   210 fun add_cases_induct infos induction thy =
   211   let
   212     val inducts = ProjectRule.projections (ProofContext.init thy) induction;
   213 
   214     fun named_rules (name, {index, exhaustion, ...}: datatype_info) =
   215       [(("", nth inducts index), [Induct.induct_type name]),
   216        (("", exhaustion), [Induct.cases_type name])];
   217     fun unnamed_rule i =
   218       (("", nth inducts i), [PureThy.kind_internal, Induct.induct_type ""]);
   219   in
   220     thy |> PureThy.add_thms
   221       (maps named_rules infos @
   222         map unnamed_rule (length infos upto length inducts - 1)) |> snd
   223     |> PureThy.add_thmss [(("inducts", inducts), [])] |> snd
   224   end;
   225 
   226 
   227 
   228 (**** simplification procedure for showing distinctness of constructors ****)
   229 
   230 fun stripT (i, Type ("fun", [_, T])) = stripT (i + 1, T)
   231   | stripT p = p;
   232 
   233 fun stripC (i, f $ x) = stripC (i + 1, f)
   234   | stripC p = p;
   235 
(*name of the distinctness simproc (see distinct_simproc)*)
val distinctN = "constr_distinct";

(* Prove eq_t, an equation "(lhs = rhs) == False" between applications of two
   distinct constructors of datatype tname.  The tactic depends on how the
   datatype info stores distinctness. *)
fun distinct_rule thy ss tname eq_t = case #distinct (the_datatype thy tname) of
    (*distinctness theorems stored directly: refute the equation with one of them*)
    FewConstrs thms => Goal.prove (Simplifier.the_context ss) [] [] eq_t (K
      (EVERY [rtac eq_reflection 1, rtac iffI 1, rtac notE 1,
        atac 2, resolve_tac thms 1, etac FalseE 1]))
    (*stored as (thm, simpset): apply thm to the equation assumption, simplify
      with simpset, then refute using the In0/In1 lemmas of theory Datatype.
      NOTE(review): thm presumably maps constructor equality to equality of
      representations -- confirm against DatatypeRepProofs*)
  | ManyConstrs (thm, simpset) =>
      let
        val [In0_inject, In1_inject, In0_not_In1, In1_not_In0] =
          map (PureThy.get_thm (ThyInfo.the_theory "Datatype" thy))
            ["In0_inject", "In1_inject", "In0_not_In1", "In1_not_In0"];
      in
        Goal.prove (Simplifier.the_context ss) [] [] eq_t (K
        (EVERY [rtac eq_reflection 1, rtac iffI 1, dtac thm 1,
          full_simp_tac (Simplifier.inherit_context ss simpset) 1,
          REPEAT (dresolve_tac [In0_inject, In1_inject] 1),
          eresolve_tac [In0_not_In1 RS notE, In1_not_In0 RS notE] 1,
          etac FalseE 1]))
      end;
   255 
   256 fun distinct_proc thy ss (t as Const ("op =", _) $ t1 $ t2) =
   257   (case (stripC (0, t1), stripC (0, t2)) of
   258      ((i, Const (cname1, T1)), (j, Const (cname2, T2))) =>
   259          (case (stripT (0, T1), stripT (0, T2)) of
   260             ((i', Type (tname1, _)), (j', Type (tname2, _))) =>
   261                 if tname1 = tname2 andalso not (cname1 = cname2) andalso i = i' andalso j = j' then
   262                    (case (get_datatype_descr thy) tname1 of
   263                       SOME (_, (_, constrs)) => let val cnames = map fst constrs
   264                         in if cname1 mem cnames andalso cname2 mem cnames then
   265                              SOME (distinct_rule thy ss tname1
   266                                (Logic.mk_equals (t, Const ("False", HOLogic.boolT))))
   267                            else NONE
   268                         end
   269                     | NONE => NONE)
   270                 else NONE
   271           | _ => NONE)
   272    | _ => NONE)
   273   | distinct_proc _ _ _ = NONE;
   274 
   275 val distinct_simproc =
   276   Simplifier.simproc HOL.thy distinctN ["s = t"] distinct_proc;
   277 
   278 val dist_ss = HOL_ss addsimprocs [distinct_simproc];
   279 
   280 val simproc_setup =
   281   Simplifier.map_simpset (fn ss => ss addsimprocs [distinct_simproc]);
   282 
   283 
   284 (**** translation rules for case ****)
   285 
   286 fun make_case ctxt = DatatypeCase.make_case
   287   (datatype_of_constr (ProofContext.theory_of ctxt)) ctxt;
   288 
   289 fun strip_case ctxt = DatatypeCase.strip_case
   290   (datatype_of_case (ProofContext.theory_of ctxt));
   291 
   292 fun add_case_tr' case_names thy =
   293   Sign.add_advanced_trfuns ([], [],
   294     map (fn case_name =>
   295       let val case_name' = Sign.const_syntax_name thy case_name
   296       in (case_name', DatatypeCase.case_tr' datatype_of_case case_name')
   297       end) case_names, []) thy;
   298 
   299 val trfun_setup =
   300   Sign.add_advanced_trfuns ([],
   301     [("_case_syntax", DatatypeCase.case_tr true datatype_of_constr)],
   302     [], []);
   303 
   304 
   305 (* prepare types *)
   306 
   307 fun read_typ sign ((Ts, sorts), str) =
   308   let
   309     val T = Type.no_tvars (Sign.read_def_typ (sign, AList.lookup (op =)
   310       (map (apfst (rpair ~1)) sorts)) str) handle TYPE (msg, _, _) => error msg
   311   in (Ts @ [T], add_typ_tfrees (T, sorts)) end;
   312 
   313 fun cert_typ sign ((Ts, sorts), raw_T) =
   314   let
   315     val T = Type.no_tvars (Sign.certify_typ sign raw_T) handle
   316       TYPE (msg, _, _) => error msg;
   317     val sorts' = add_typ_tfrees (T, sorts)
   318   in (Ts @ [T],
   319       case duplicates (op =) (map fst sorts') of
   320          [] => sorts'
   321        | dups => error ("Inconsistent sort constraints for " ^ commas dups))
   322   end;
   323 
   324 
   325 (**** make datatype info ****)
   326 
   327 fun make_dt_info alt_names descr sorts induct reccomb_names rec_thms
   328     (((((((((i, (_, (tname, _, _))), case_name), case_thms),
   329       exhaustion_thm), distinct_thm), inject), nchotomy), case_cong), weak_case_cong) =
   330   (tname,
   331    {index = i,
   332     alt_names = alt_names,
   333     descr = descr,
   334     sorts = sorts,
   335     rec_names = reccomb_names,
   336     rec_rewrites = rec_thms,
   337     case_name = case_name,
   338     case_rewrites = case_thms,
   339     induction = induct,
   340     exhaustion = exhaustion_thm,
   341     distinct = distinct_thm,
   342     inject = inject,
   343     nchotomy = nchotomy,
   344     case_cong = case_cong,
   345     weak_case_cong = weak_case_cong});
   346 
   347 structure DatatypeInterpretation = InterpretationFun(type T = string list val eq = op =);
   348 val interpretation = DatatypeInterpretation.interpretation;
   349 
   350 
   351 (******************* definitional introduction of datatypes *******************)
   352 
(* Definitional introduction of the datatypes of descr: prove the
   representation theorems and the derived case distinction, primitive
   recursion, case, split, nchotomy and congruence theorems, then register
   the datatype infos and declare the rules.  Returns the main theorems
   together with the extended theory. *)
fun add_datatype_def flat_names new_type_names descr sorts types_syntax constr_syntax dt_info
    case_names_induct case_names_exhausts thy =
  let
    val _ = message ("Proofs for datatype(s) " ^ commas_quote new_type_names);

    (*representation: injectivity, distinctness (as theorems, as rewrites and
      in the form stored in the infos) and induction*)
    val ((inject, distinct, dist_rewrites, simproc_dists, induct), thy2) = thy |>
      DatatypeRepProofs.representation_proofs flat_names dt_info new_type_names descr sorts
        types_syntax constr_syntax case_names_induct;

    (*derived theorems; each step threads the theory to the next*)
    val (casedist_thms, thy3) = DatatypeAbsProofs.prove_casedist_thms new_type_names descr
      sorts induct case_names_exhausts thy2;
    val ((reccomb_names, rec_thms), thy4) = DatatypeAbsProofs.prove_primrec_thms
      flat_names new_type_names descr sorts dt_info inject dist_rewrites
      (Simplifier.theory_context thy3 dist_ss) induct thy3;
    val ((case_thms, case_names), thy6) = DatatypeAbsProofs.prove_case_thms
      flat_names new_type_names descr sorts reccomb_names rec_thms thy4;
    val (split_thms, thy7) = DatatypeAbsProofs.prove_split_thms new_type_names
      descr sorts inject dist_rewrites casedist_thms case_thms thy6;
    val (nchotomys, thy8) = DatatypeAbsProofs.prove_nchotomys new_type_names
      descr sorts casedist_thms thy7;
    val (case_congs, thy9) = DatatypeAbsProofs.prove_case_congs new_type_names
      descr sorts nchotomys case_thms thy8;
    val (weak_case_congs, thy10) = DatatypeAbsProofs.prove_weak_case_congs new_type_names
      descr sorts thy9;

    (*one info per entry of hd descr (indices 0 .. length (hd descr) - 1);
      each info records the whole flattened descr*)
    val dt_infos = map (make_dt_info NONE (flat descr) sorts induct reccomb_names rec_thms)
      ((0 upto length (hd descr) - 1) ~~ (hd descr) ~~ case_names ~~ case_thms ~~
        casedist_thms ~~ simproc_dists ~~ inject ~~ nchotomys ~~ case_congs ~~ weak_case_congs);

    val simps = flat (distinct @ inject @ case_thms) @ rec_thms;

    (*rules and induct/cases rules are declared under a path named after the
      joined type names; "splits" is stored after returning to the parent
      path; interpretation hooks run last*)
    val thy12 =
      thy10
      |> add_case_tr' case_names
      |> Sign.add_path (space_implode "_" new_type_names)
      |> add_rules simps case_thms rec_thms inject distinct
          weak_case_congs (Simplifier.attrib (op addcongs))
      |> put_dt_infos dt_infos
      |> add_cases_induct dt_infos induct
      |> Sign.parent_path
      |> store_thmss "splits" new_type_names (map (fn (x, y) => [x, y]) split_thms) |> snd
      |> DatatypeInterpretation.data (map fst dt_infos);
  in
    ({distinct = distinct,
      inject = inject,
      exhaustion = casedist_thms,
      rec_thms = rec_thms,
      case_thms = case_thms,
      split_thms = split_thms,
      induction = induct,
      simps = simps}, thy12)
  end;
   405 
   406 
   407 (*********************** declare existing type as datatype *********************)
   408 
   409 fun prove_rep_datatype alt_names new_type_names descr sorts induct inject distinct thy =
   410   let
   411     val ((_, [induct']), _) =
   412       Variable.importT_thms [induct] (Variable.thm_context induct);
   413 
   414     fun err t = error ("Ill-formed predicate in induction rule: " ^
   415       Syntax.string_of_term_global thy t);
   416 
   417     fun get_typ (t as _ $ Var (_, Type (tname, Ts))) =
   418           ((tname, map (fst o dest_TFree) Ts) handle TERM _ => err t)
   419       | get_typ t = err t;
   420     val dtnames = map get_typ (HOLogic.dest_conj (HOLogic.dest_Trueprop (Thm.concl_of induct')));
   421 
   422     val dt_info = get_datatypes thy;
   423 
   424     val (case_names_induct, case_names_exhausts) =
   425       (mk_case_names_induct descr, mk_case_names_exhausts descr (map #1 dtnames));
   426 
   427     val _ = message ("Proofs for datatype(s) " ^ commas_quote new_type_names);
   428 
   429     val (casedist_thms, thy2) = thy |>
   430       DatatypeAbsProofs.prove_casedist_thms new_type_names [descr] sorts induct
   431         case_names_exhausts;
   432     val ((reccomb_names, rec_thms), thy3) = DatatypeAbsProofs.prove_primrec_thms
   433       false new_type_names [descr] sorts dt_info inject distinct
   434       (Simplifier.theory_context thy2 dist_ss) induct thy2;
   435     val ((case_thms, case_names), thy4) = DatatypeAbsProofs.prove_case_thms false
   436       new_type_names [descr] sorts reccomb_names rec_thms thy3;
   437     val (split_thms, thy5) = DatatypeAbsProofs.prove_split_thms
   438       new_type_names [descr] sorts inject distinct casedist_thms case_thms thy4;
   439     val (nchotomys, thy6) = DatatypeAbsProofs.prove_nchotomys new_type_names
   440       [descr] sorts casedist_thms thy5;
   441     val (case_congs, thy7) = DatatypeAbsProofs.prove_case_congs new_type_names
   442       [descr] sorts nchotomys case_thms thy6;
   443     val (weak_case_congs, thy8) = DatatypeAbsProofs.prove_weak_case_congs new_type_names
   444       [descr] sorts thy7;
   445 
   446     val ((_, [induct']), thy10) =
   447       thy8
   448       |> store_thmss "inject" new_type_names inject
   449       ||>> store_thmss "distinct" new_type_names distinct
   450       ||> Sign.add_path (space_implode "_" new_type_names)
   451       ||>> PureThy.add_thms [(("induct", induct), [case_names_induct])];
   452 
   453     val dt_infos = map (make_dt_info alt_names descr sorts induct' reccomb_names rec_thms)
   454       ((0 upto length descr - 1) ~~ descr ~~ case_names ~~ case_thms ~~ casedist_thms ~~
   455         map FewConstrs distinct ~~ inject ~~ nchotomys ~~ case_congs ~~ weak_case_congs);
   456 
   457     val simps = flat (distinct @ inject @ case_thms) @ rec_thms;
   458 
   459     val thy11 =
   460       thy10
   461       |> add_case_tr' case_names
   462       |> add_rules simps case_thms rec_thms inject distinct
   463            weak_case_congs (Simplifier.attrib (op addcongs))
   464       |> put_dt_infos dt_infos
   465       |> add_cases_induct dt_infos induct'
   466       |> Sign.parent_path
   467       |> store_thmss "splits" new_type_names (map (fn (x, y) => [x, y]) split_thms)
   468       |> snd
   469       |> DatatypeInterpretation.data (map fst dt_infos);
   470   in
   471     ({distinct = distinct,
   472       inject = inject,
   473       exhaustion = casedist_thms,
   474       rec_thms = rec_thms,
   475       case_thms = case_thms,
   476       split_thms = split_thms,
   477       induction = induct',
   478       simps = simps}, thy11)
   479   end;
   480 
   481 fun gen_rep_datatype prep_term after_qed alt_names raw_ts thy =
   482   let
   483     fun constr_of_term (Const (c, T)) = (c, T)
   484       | constr_of_term t =
   485           error ("Not a constant: " ^ Syntax.string_of_term_global thy t);
   486     fun no_constr (c, T) = error ("Bad constructor: "
   487       ^ Sign.extern_const thy c ^ "::"
   488       ^ Syntax.string_of_typ_global thy T);
   489     fun type_of_constr (cT as (_, T)) =
   490       let
   491         val frees = typ_tfrees T;
   492         val (tyco, vs) = ((apsnd o map) (dest_TFree) o dest_Type o snd o strip_type) T
   493           handle TYPE _ => no_constr cT
   494         val _ = if has_duplicates (eq_fst (op =)) vs then no_constr cT else ();
   495         val _ = if length frees <> length vs then no_constr cT else ();
   496       in (tyco, (vs, cT)) end;
   497 
   498     val raw_cs = AList.group (op =) (map (type_of_constr o constr_of_term o prep_term thy) raw_ts);
   499     val _ = case map_filter (fn (tyco, _) =>
   500         if Symtab.defined (get_datatypes thy) tyco then SOME tyco else NONE) raw_cs
   501      of [] => ()
   502       | tycos => error ("Type(s) " ^ commas (map quote tycos)
   503           ^ " already represented inductivly");
   504     val raw_vss = maps (map (map snd o fst) o snd) raw_cs;
   505     val ms = case distinct (op =) (map length raw_vss)
   506      of [n] => 0 upto n - 1
   507       | _ => error ("Different types in given constructors");
   508     fun inter_sort m = map (fn xs => nth xs m) raw_vss
   509       |> Library.foldr1 (Sorts.inter_sort (Sign.classes_of thy))
   510     val sorts = map inter_sort ms;
   511     val vs = Name.names Name.context Name.aT sorts;
   512 
   513     fun norm_constr (raw_vs, (c, T)) = (c, map_atyps
   514       (TFree o (the o AList.lookup (op =) (map fst raw_vs ~~ vs)) o fst o dest_TFree) T);
   515 
   516     val cs = map (apsnd (map norm_constr)) raw_cs;
   517     val dtyps_of_typ = map (dtyp_of_typ (map (rpair (map fst vs) o fst) cs))
   518       o fst o strip_type;
   519     val new_type_names = map NameSpace.base (the_default (map fst cs) alt_names);
   520 
   521     fun mk_spec (i, (tyco, constr)) = (i, (tyco,
   522       map (DtTFree o fst) vs,
   523       (map o apsnd) dtyps_of_typ constr))
   524     val descr = map_index mk_spec cs;
   525     val injs = DatatypeProp.make_injs [descr] vs;
   526     val distincts = map snd (DatatypeProp.make_distincts [descr] vs);
   527     val ind = DatatypeProp.make_ind [descr] vs;
   528     val rules = (map o map o map) Logic.close_form [[[ind]], injs, distincts];
   529 
   530     fun after_qed' raw_thms =
   531       let
   532         val [[[induct]], injs, distincts] =
   533           unflat rules (map Drule.zero_var_indexes_list raw_thms);
   534             (*FIXME somehow dubious*)
   535       in
   536         ProofContext.theory_result
   537           (prove_rep_datatype alt_names new_type_names descr vs induct injs distincts)
   538         #-> after_qed
   539       end;
   540   in
   541     thy
   542     |> ProofContext.init
   543     |> Proof.theorem_i NONE after_qed' ((map o map) (rpair []) (flat rules))
   544   end;
   545 
   546 val rep_datatype = gen_rep_datatype Sign.cert_term;
   547 val rep_datatype_cmd = gen_rep_datatype Syntax.read_term_global (K I);
   548 
   549 
   550 
   551 (******************************** add datatype ********************************)
   552 
   553 fun gen_add_datatype prep_typ err flat_names new_type_names dts thy =
   554   let
   555     val _ = Theory.requires thy "Datatype" "datatype definitions";
   556 
   557     (* this theory is used just for parsing *)
   558 
   559     val tmp_thy = thy |>
   560       Theory.copy |>
   561       Sign.add_types (map (fn (tvs, tname, mx, _) =>
   562         (tname, length tvs, mx)) dts);
   563 
   564     val (tyvars, _, _, _)::_ = dts;
   565     val (new_dts, types_syntax) = ListPair.unzip (map (fn (tvs, tname, mx, _) =>
   566       let val full_tname = Sign.full_name tmp_thy (Syntax.type_name tname mx)
   567       in (case duplicates (op =) tvs of
   568             [] => if eq_set (tyvars, tvs) then ((full_tname, tvs), (tname, mx))
   569                   else error ("Mutually recursive datatypes must have same type parameters")
   570           | dups => error ("Duplicate parameter(s) for datatype " ^ full_tname ^
   571               " : " ^ commas dups))
   572       end) dts);
   573 
   574     val _ = (case duplicates (op =) (map fst new_dts) @ duplicates (op =) new_type_names of
   575       [] => () | dups => error ("Duplicate datatypes: " ^ commas dups));
   576 
   577     fun prep_dt_spec (tvs, tname, mx, constrs) (dts', constr_syntax, sorts, i) =
   578       let
   579         fun prep_constr (cname, cargs, mx') (constrs, constr_syntax', sorts') =
   580           let
   581             val (cargs', sorts'') = Library.foldl (prep_typ tmp_thy) (([], sorts'), cargs);
   582             val _ = (case fold (curry add_typ_tfree_names) cargs' [] \\ tvs of
   583                 [] => ()
   584               | vs => error ("Extra type variables on rhs: " ^ commas vs))
   585           in (constrs @ [((if flat_names then Sign.full_name tmp_thy else
   586                 Sign.full_name_path tmp_thy tname) (Syntax.const_name cname mx'),
   587                    map (dtyp_of_typ new_dts) cargs')],
   588               constr_syntax' @ [(cname, mx')], sorts'')
   589           end handle ERROR msg =>
   590             cat_error msg ("The error above occured in constructor " ^ cname ^
   591               " of datatype " ^ tname);
   592 
   593         val (constrs', constr_syntax', sorts') =
   594           fold prep_constr constrs ([], [], sorts)
   595 
   596       in
   597         case duplicates (op =) (map fst constrs') of
   598            [] =>
   599              (dts' @ [(i, (Sign.full_name tmp_thy (Syntax.type_name tname mx),
   600                 map DtTFree tvs, constrs'))],
   601               constr_syntax @ [constr_syntax'], sorts', i + 1)
   602          | dups => error ("Duplicate constructors " ^ commas dups ^
   603              " in datatype " ^ tname)
   604       end;
   605 
   606     val (dts', constr_syntax, sorts', i) = fold prep_dt_spec dts ([], [], [], 0);
   607     val sorts = sorts' @ (map (rpair (Sign.defaultS tmp_thy)) (tyvars \\ map fst sorts'));
   608     val dt_info = get_datatypes thy;
   609     val (descr, _) = unfold_datatypes tmp_thy dts' sorts dt_info dts' i;
   610     val _ = check_nonempty descr handle (exn as Datatype_Empty s) =>
   611       if err then error ("Nonemptiness check failed for datatype " ^ s)
   612       else raise exn;
   613 
   614     val descr' = flat descr;
   615     val case_names_induct = mk_case_names_induct descr';
   616     val case_names_exhausts = mk_case_names_exhausts descr' (map #1 new_dts);
   617   in
   618     add_datatype_def
   619       flat_names new_type_names descr sorts types_syntax constr_syntax dt_info
   620       case_names_induct case_names_exhausts thy
   621   end;
   622 
   623 val add_datatype = gen_add_datatype cert_typ;
   624 val add_datatype_cmd = gen_add_datatype read_typ true;
   625 
   626 
   627 (** a datatype antiquotation **)
   628 
   629 local
   630 
   631 val sym_datatype = Pretty.str "\\isacommand{datatype}";
   632 val sym_binder = Pretty.str "{\\isacharequal}";
   633 val sym_of = Pretty.str "of";
   634 val sym_sep = Pretty.str "{\\isacharbar}";
   635 
   636 in
   637 
   638 fun args_datatype (ctxt, args) =
   639   let
   640     val (tyco, (ctxt', args')) = Args.tyname (ctxt, args);
   641     val thy = Context.theory_of ctxt';
   642     val spec = the_datatype_spec thy tyco;
   643   in ((tyco, spec), (ctxt', args')) end;
   644 
   645 fun pretty_datatype ctxt (dtco, (vs, cos)) =
   646   let
   647     val ty = Type (dtco, map TFree vs);
   648     fun pretty_typ_br ty =
   649       let
   650         val p = Syntax.pretty_typ ctxt ty;
   651         val s = explode (Pretty.str_of p);
   652       in if member (op =) s " " then Pretty.enclose "(" ")" [p]
   653         else p
   654       end;
   655     fun pretty_constr (co, []) =
   656           Syntax.pretty_term ctxt (Const (co, ty))
   657       | pretty_constr (co, [ty']) =
   658           (Pretty.block o Pretty.breaks)
   659             [Syntax.pretty_term ctxt (Const (co, ty' --> ty)),
   660               sym_of, Syntax.pretty_typ ctxt ty']
   661       | pretty_constr (co, tys) =
   662           (Pretty.block o Pretty.breaks)
   663             (Syntax.pretty_term ctxt (Const (co, tys ---> ty)) ::
   664               sym_of :: map pretty_typ_br tys);
   665   in (Pretty.block o Pretty.breaks) (
   666     sym_datatype
   667     :: Syntax.pretty_typ ctxt ty
   668     :: sym_binder
   669     :: separate sym_sep (map pretty_constr cos)
   670   ) end
   671 
   672 end;
   673 
   674 (** package setup **)
   675 
   676 (* setup theory *)
   677 
   678 val setup =
   679   DatatypeRepProofs.distinctness_limit_setup #>
   680   simproc_setup #>
   681   trfun_setup #>
   682   DatatypeInterpretation.init;
   683 
   684 
   685 (* outer syntax *)
   686 
   687 local structure P = OuterParse and K = OuterKeyword in
   688 
   689 val datatype_decl =
   690   Scan.option (P.$$$ "(" |-- P.name --| P.$$$ ")") -- P.type_args -- P.name -- P.opt_infix --
   691     (P.$$$ "=" |-- P.enum1 "|" (P.name -- Scan.repeat P.typ -- P.opt_mixfix));
   692 
   693 fun mk_datatype args =
   694   let
   695     val names = map (fn ((((NONE, _), t), _), _) => t | ((((SOME t, _), _), _), _) => t) args;
   696     val specs = map (fn ((((_, vs), t), mx), cons) =>
   697       (vs, t, mx, map (fn ((x, y), z) => (x, y, z)) cons)) args;
   698   in snd o add_datatype_cmd false names specs end;
   699 
   700 val _ =
   701   OuterSyntax.command "datatype" "define inductive datatypes" K.thy_decl
   702     (P.and_list1 datatype_decl >> (Toplevel.theory o mk_datatype));
   703 
   704 val _ =
   705   OuterSyntax.command "rep_datatype" "represent existing types inductively" K.thy_goal
   706     (Scan.option (P.$$$ "(" |-- Scan.repeat1 P.name --| P.$$$ ")") -- Scan.repeat1 P.term
   707       >> (fn (alt_names, ts) => Toplevel.print
   708            o Toplevel.theory_to_proof (rep_datatype_cmd alt_names ts)));
   709 
   710 val _ =
   711   ThyOutput.add_commands [("datatype",
   712     ThyOutput.args args_datatype (ThyOutput.output pretty_datatype))];
   713 
   714 end;
   715 
   716 end;
   717