src/HOL/Tools/Datatype/datatype.ML
author haftmann
Tue Jun 23 14:50:34 2009 +0200 (2009-06-23)
changeset 31781 861e675f01e6
parent 31775 2b04504fcb69
child 31783 cfbe9609ceb1
permissions -rw-r--r--
add_datatype interface yields type names and less rules
     1 (*  Title:      HOL/Tools/Datatype/datatype.ML
     2     Author:     Stefan Berghofer, TU Muenchen
     3 
     4 Datatype package for Isabelle/HOL.
     5 *)
     6 
(*Public interface of the datatype package.*)
signature DATATYPE =
sig
  include DATATYPE_COMMON
  (*definitional introduction of a group of mutually recursive datatypes;
    yields the full names of the new types and selected rules*)
  val add_datatype : config -> string list -> (string list * binding * mixfix *
    (binding * typ list * mixfix) list) list -> theory ->
     (string list * {inject : thm list list,
      rec_thms : thm list,
      case_thms : thm list list,
      induction : thm}) * theory
  (*like add_datatype, but constructor argument types are read from strings
    and the default configuration is used*)
  val datatype_cmd : string list -> (string list * binding * mixfix *
    (binding * string list * mixfix) list) list -> theory -> theory
  (*register existing types, given by their constructor terms, as datatypes:
    opens a proof of induction, injectivity and distinctness; the continuation
    receives the names of the registered types*)
  val rep_datatype : config -> (string list -> Proof.context -> Proof.context)
    -> string list option -> term list -> theory -> Proof.state
  val rep_datatype_cmd : string list option -> string list -> theory -> Proof.state
  (*datatype infos, indexed by full type name*)
  val get_datatypes : theory -> info Symtab.table
  val get_datatype : theory -> string -> info option
  (*fails with an error for unknown type names*)
  val the_datatype : theory -> string -> info
  (*info of the datatype a constructor resp. a case constant belongs to*)
  val datatype_of_constr : theory -> string -> info option
  val datatype_of_case : theory -> string -> info option
  (*sort-annotated type parameters and constructors with their argument types*)
  val the_datatype_spec : theory -> string -> (string * sort) list * (string * typ list) list
  (*description of the datatype group formed by exactly the given type constructors*)
  val the_datatype_descr : theory -> string list
    -> descr * (string * sort) list * string list
      * (string list * string list) * (typ list * typ list)
  (*constructors with their types over schematic type variables;
    NONE if the_datatype_spec fails for the given name*)
  val get_datatype_constrs : theory -> string -> (string * typ) list option
  (*register f as datatype interpretation; it is applied to the
    (config, type names) announced for datatype groups*)
  val interpretation : (config -> string list -> theory -> theory) -> theory -> theory
  (*simproc rewriting equations between distinct constructors to False*)
  val distinct_simproc : simproc
  val make_case :  Proof.context -> bool -> string list -> term ->
    (term * term) list -> term * (term * (int * bool)) list
  val strip_case : Proof.context -> bool -> term -> (term * (term * term) list) option
  (*read a type, appending it to the list and accumulating its type variables with sorts*)
  val read_typ: theory ->
    (typ list * (string * sort) list) * string -> typ list * (string * sort) list
  val setup: theory -> theory
end;
    40 
    41 structure Datatype : DATATYPE =
    42 struct
    43 
    44 open DatatypeAux;
    45 
    46 
    47 (* theory data *)
    48 
    49 structure DatatypesData = TheoryDataFun
    50 (
    51   type T =
    52     {types: info Symtab.table,
    53      constrs: info Symtab.table,
    54      cases: info Symtab.table};
    55 
    56   val empty =
    57     {types = Symtab.empty, constrs = Symtab.empty, cases = Symtab.empty};
    58   val copy = I;
    59   val extend = I;
    60   fun merge _
    61     ({types = types1, constrs = constrs1, cases = cases1},
    62      {types = types2, constrs = constrs2, cases = cases2}) =
    63     {types = Symtab.merge (K true) (types1, types2),
    64      constrs = Symtab.merge (K true) (constrs1, constrs2),
    65      cases = Symtab.merge (K true) (cases1, cases2)};
    66 );
    67 
    68 val get_datatypes = #types o DatatypesData.get;
    69 val map_datatypes = DatatypesData.map;
    70 
    71 
    72 (** theory information about datatypes **)
    73 
    74 fun put_dt_infos (dt_infos : (string * info) list) =
    75   map_datatypes (fn {types, constrs, cases} =>
    76     {types = fold Symtab.update dt_infos types,
    77      constrs = fold Symtab.default (*conservative wrt. overloaded constructors*)
    78        (maps (fn (_, info as {descr, index, ...}) => map (rpair info o fst)
    79           (#3 (the (AList.lookup op = descr index)))) dt_infos) constrs,
    80      cases = fold Symtab.update
    81        (map (fn (_, info as {case_name, ...}) => (case_name, info)) dt_infos)
    82        cases});
    83 
    84 val get_datatype = Symtab.lookup o get_datatypes;
    85 
    86 fun the_datatype thy name = (case get_datatype thy name of
    87       SOME info => info
    88     | NONE => error ("Unknown datatype " ^ quote name));
    89 
    90 val datatype_of_constr = Symtab.lookup o #constrs o DatatypesData.get;
    91 val datatype_of_case = Symtab.lookup o #cases o DatatypesData.get;
    92 
    93 fun get_datatype_descr thy dtco =
    94   get_datatype thy dtco
    95   |> Option.map (fn info as { descr, index, ... } =>
    96        (info, (((fn SOME (_, dtys, cos) => (dtys, cos)) o AList.lookup (op =) descr) index)));
    97 
    98 fun the_datatype_spec thy dtco =
    99   let
   100     val info as { descr, index, sorts = raw_sorts, ... } = the_datatype thy dtco;
   101     val SOME (_, dtys, raw_cos) = AList.lookup (op =) descr index;
   102     val sorts = map ((fn v => (v, (the o AList.lookup (op =) raw_sorts) v))
   103       o DatatypeAux.dest_DtTFree) dtys;
   104     val cos = map
   105       (fn (co, tys) => (co, map (DatatypeAux.typ_of_dtyp descr sorts) tys)) raw_cos;
   106   in (sorts, cos) end;
   107 
   108 fun the_datatype_descr thy (raw_tycos as raw_tyco :: _) =
   109   let
   110     val info = the_datatype thy raw_tyco;
   111     val descr = #descr info;
   112 
   113     val SOME (_, dtys, raw_cos) = AList.lookup (op =) descr (#index info);
   114     val vs = map ((fn v => (v, (the o AList.lookup (op =) (#sorts info)) v))
   115       o dest_DtTFree) dtys;
   116 
   117     fun is_DtTFree (DtTFree _) = true
   118       | is_DtTFree _ = false
   119     val k = find_index (fn (_, (_, dTs, _)) => not (forall is_DtTFree dTs)) descr;
   120     val protoTs as (dataTs, _) = chop k descr
   121       |> (pairself o map) (fn (_, (tyco, dTs, _)) => (tyco, map (typ_of_dtyp descr vs) dTs));
   122     
   123     val tycos = map fst dataTs;
   124     val _ = if gen_eq_set (op =) (tycos, raw_tycos) then ()
   125       else error ("Type constructors " ^ commas (map quote raw_tycos)
   126         ^ "do not belong exhaustively to one mutual recursive datatype");
   127 
   128     val (Ts, Us) = (pairself o map) Type protoTs;
   129 
   130     val names = map Long_Name.base_name (the_default tycos (#alt_names info));
   131     val (auxnames, _) = Name.make_context names
   132       |> fold_map (yield_singleton Name.variants o name_of_typ) Us
   133 
   134   in (descr, vs, tycos, (names, auxnames), (Ts, Us)) end;
   135 
   136 fun get_datatype_constrs thy dtco =
   137   case try (the_datatype_spec thy) dtco
   138    of SOME (sorts, cos) =>
   139         let
   140           fun subst (v, sort) = TVar ((v, 0), sort);
   141           fun subst_ty (TFree v) = subst v
   142             | subst_ty ty = ty;
   143           val dty = Type (dtco, map subst sorts);
   144           fun mk_co (co, tys) = (co, map (Term.map_atyps subst_ty) tys ---> dty);
   145         in SOME (map mk_co cos) end
   146     | NONE => NONE;
   147 
   148 
   149 (** induct method setup **)
   150 
   151 (* case names *)
   152 
   153 local
   154 
   155 fun dt_recs (DtTFree _) = []
   156   | dt_recs (DtType (_, dts)) = maps dt_recs dts
   157   | dt_recs (DtRec i) = [i];
   158 
   159 fun dt_cases (descr: descr) (_, args, constrs) =
   160   let
   161     fun the_bname i = Long_Name.base_name (#1 (the (AList.lookup (op =) descr i)));
   162     val bnames = map the_bname (distinct (op =) (maps dt_recs args));
   163   in map (fn (c, _) => space_implode "_" (Long_Name.base_name c :: bnames)) constrs end;
   164 
   165 
   166 fun induct_cases descr =
   167   DatatypeProp.indexify_names (maps (dt_cases descr) (map #2 descr));
   168 
   169 fun exhaust_cases descr i = dt_cases descr (the (AList.lookup (op =) descr i));
   170 
   171 in
   172 
   173 fun mk_case_names_induct descr = RuleCases.case_names (induct_cases descr);
   174 
   175 fun mk_case_names_exhausts descr new =
   176   map (RuleCases.case_names o exhaust_cases descr o #1)
   177     (filter (fn ((_, (name, _, _))) => member (op =) new name) descr);
   178 
   179 end;
   180 
   181 fun add_rules simps case_thms rec_thms inject distinct
   182                   weak_case_congs cong_att =
   183   PureThy.add_thmss [((Binding.name "simps", simps), []),
   184     ((Binding.empty, flat case_thms @
   185           flat distinct @ rec_thms), [Simplifier.simp_add]),
   186     ((Binding.empty, rec_thms), [Code.add_default_eqn_attribute]),
   187     ((Binding.empty, flat inject), [iff_add]),
   188     ((Binding.empty, map (fn th => th RS notE) (flat distinct)), [Classical.safe_elim NONE]),
   189     ((Binding.empty, weak_case_congs), [cong_att])]
   190   #> snd;
   191 
   192 
   193 (* add_cases_induct *)
   194 
   195 fun add_cases_induct infos induction thy =
   196   let
   197     val inducts = ProjectRule.projections (ProofContext.init thy) induction;
   198 
   199     fun named_rules (name, {index, exhaustion, ...}: info) =
   200       [((Binding.empty, nth inducts index), [Induct.induct_type name]),
   201        ((Binding.empty, exhaustion), [Induct.cases_type name])];
   202     fun unnamed_rule i =
   203       ((Binding.empty, nth inducts i), [Thm.kind_internal, Induct.induct_type ""]);
   204   in
   205     thy |> PureThy.add_thms
   206       (maps named_rules infos @
   207         map unnamed_rule (length infos upto length inducts - 1)) |> snd
   208     |> PureThy.add_thmss [((Binding.name "inducts", inducts), [])] |> snd
   209   end;
   210 
   211 
   212 
   213 (**** simplification procedure for showing distinctness of constructors ****)
   214 
   215 fun stripT (i, Type ("fun", [_, T])) = stripT (i + 1, T)
   216   | stripT p = p;
   217 
   218 fun stripC (i, f $ x) = stripC (i + 1, f)
   219   | stripC p = p;
   220 
   221 val distinctN = "constr_distinct";
   222 
(*Prove eq_t, the meta-equation between an equality of two applications of
  distinct constructors of datatype tname and False (see distinct_proc), from
  the distinctness information stored for tname.*)
fun distinct_rule thy ss tname eq_t = case #distinct (the_datatype thy tname) of
    (*explicit distinctness theorems: refute the equality with one of them*)
    FewConstrs thms => Goal.prove (Simplifier.the_context ss) [] [] eq_t (K
      (EVERY [rtac eq_reflection 1, rtac iffI 1, rtac notE 1,
        atac 2, resolve_tac thms 1, etac FalseE 1]))
  | ManyConstrs (thm, simpset) =>
      (*no explicit theorems: apply thm to the equality, simplify with the
        stored simpset and refute with the In0/In1 facts of theory Datatype*)
      let
        val [In0_inject, In1_inject, In0_not_In1, In1_not_In0] =
          map (PureThy.get_thm (ThyInfo.the_theory "Datatype" thy))
            ["In0_inject", "In1_inject", "In0_not_In1", "In1_not_In0"];
      in
        Goal.prove (Simplifier.the_context ss) [] [] eq_t (K
        (EVERY [rtac eq_reflection 1, rtac iffI 1, dtac thm 1,
          full_simp_tac (Simplifier.inherit_context ss simpset) 1,
          REPEAT (dresolve_tac [In0_inject, In1_inject] 1),
          eresolve_tac [In0_not_In1 RS notE, In1_not_In0 RS notE] 1,
          etac FalseE 1]))
      end;
   240 
   241 fun distinct_proc thy ss (t as Const ("op =", _) $ t1 $ t2) =
   242   (case (stripC (0, t1), stripC (0, t2)) of
   243      ((i, Const (cname1, T1)), (j, Const (cname2, T2))) =>
   244          (case (stripT (0, T1), stripT (0, T2)) of
   245             ((i', Type (tname1, _)), (j', Type (tname2, _))) =>
   246                 if tname1 = tname2 andalso not (cname1 = cname2) andalso i = i' andalso j = j' then
   247                    (case (get_datatype_descr thy) tname1 of
   248                       SOME (_, (_, constrs)) => let val cnames = map fst constrs
   249                         in if cname1 mem cnames andalso cname2 mem cnames then
   250                              SOME (distinct_rule thy ss tname1
   251                                (Logic.mk_equals (t, Const ("False", HOLogic.boolT))))
   252                            else NONE
   253                         end
   254                     | NONE => NONE)
   255                 else NONE
   256           | _ => NONE)
   257    | _ => NONE)
   258   | distinct_proc _ _ _ = NONE;
   259 
(*simproc rewriting equations between distinct constructors to False;
  triggered on equations s = t*)
val distinct_simproc =
  Simplifier.simproc @{theory HOL} distinctN ["s = t"] distinct_proc;

(*HOL_ss extended by the distinctness simproc*)
val dist_ss = HOL_ss addsimprocs [distinct_simproc];

(*adds the distinctness simproc to the theory's simpset*)
val simproc_setup =
  Simplifier.map_simpset (fn ss => ss addsimprocs [distinct_simproc]);
   267 
   268 
   269 (**** translation rules for case ****)
   270 
   271 fun make_case ctxt = DatatypeCase.make_case
   272   (datatype_of_constr (ProofContext.theory_of ctxt)) ctxt;
   273 
   274 fun strip_case ctxt = DatatypeCase.strip_case
   275   (datatype_of_case (ProofContext.theory_of ctxt));
   276 
   277 fun add_case_tr' case_names thy =
   278   Sign.add_advanced_trfuns ([], [],
   279     map (fn case_name =>
   280       let val case_name' = Sign.const_syntax_name thy case_name
   281       in (case_name', DatatypeCase.case_tr' datatype_of_case case_name')
   282       end) case_names, []) thy;
   283 
   284 val trfun_setup =
   285   Sign.add_advanced_trfuns ([],
   286     [("_case_syntax", DatatypeCase.case_tr true datatype_of_constr)],
   287     [], []);
   288 
   289 
   290 (* prepare types *)
   291 
   292 fun read_typ thy ((Ts, sorts), str) =
   293   let
   294     val ctxt = ProofContext.init thy
   295       |> fold (Variable.declare_typ o TFree) sorts;
   296     val T = Syntax.read_typ ctxt str;
   297   in (Ts @ [T], Term.add_tfreesT T sorts) end;
   298 
   299 fun cert_typ sign ((Ts, sorts), raw_T) =
   300   let
   301     val T = Type.no_tvars (Sign.certify_typ sign raw_T) handle
   302       TYPE (msg, _, _) => error msg;
   303     val sorts' = Term.add_tfreesT T sorts;
   304   in (Ts @ [T],
   305       case duplicates (op =) (map fst sorts') of
   306          [] => sorts'
   307        | dups => error ("Inconsistent sort constraints for " ^ commas dups))
   308   end;
   309 
   310 
   311 (**** make datatype info ****)
   312 
   313 fun make_dt_info alt_names descr sorts induct reccomb_names rec_thms
   314     (((((((((i, (_, (tname, _, _))), case_name), case_thms),
   315       exhaustion_thm), distinct_thm), inject), nchotomy), case_cong), weak_case_cong) =
   316   (tname,
   317    {index = i,
   318     alt_names = alt_names,
   319     descr = descr,
   320     sorts = sorts,
   321     rec_names = reccomb_names,
   322     rec_rewrites = rec_thms,
   323     case_name = case_name,
   324     case_rewrites = case_thms,
   325     induction = induct,
   326     exhaustion = exhaustion_thm,
   327     distinct = distinct_thm,
   328     inject = inject,
   329     nchotomy = nchotomy,
   330     case_cong = case_cong,
   331     weak_case_cong = weak_case_cong});
   332 
   333 structure DatatypeInterpretation = InterpretationFun
   334   (type T = config * string list val eq: T * T -> bool = eq_snd op =);
   335 fun interpretation f = DatatypeInterpretation.interpretation (uncurry f);
   336 
   337 
   338 (******************* definitional introduction of datatypes *******************)
   339 
   340 fun add_datatype_def (config : config) new_type_names descr sorts types_syntax constr_syntax dt_info
   341     case_names_induct case_names_exhausts thy =
   342   let
   343     val _ = message config ("Proofs for datatype(s) " ^ commas_quote new_type_names);
   344 
   345     val ((inject, distinct, dist_rewrites, simproc_dists, induct), thy2) = thy |>
   346       DatatypeRepProofs.representation_proofs config dt_info new_type_names descr sorts
   347         types_syntax constr_syntax case_names_induct;
   348 
   349     val (casedist_thms, thy3) = DatatypeAbsProofs.prove_casedist_thms config new_type_names descr
   350       sorts induct case_names_exhausts thy2;
   351     val ((reccomb_names, rec_thms), thy4) = DatatypeAbsProofs.prove_primrec_thms
   352       config new_type_names descr sorts dt_info inject dist_rewrites
   353       (Simplifier.theory_context thy3 dist_ss) induct thy3;
   354     val ((case_thms, case_names), thy6) = DatatypeAbsProofs.prove_case_thms
   355       config new_type_names descr sorts reccomb_names rec_thms thy4;
   356     val (split_thms, thy7) = DatatypeAbsProofs.prove_split_thms config new_type_names
   357       descr sorts inject dist_rewrites casedist_thms case_thms thy6;
   358     val (nchotomys, thy8) = DatatypeAbsProofs.prove_nchotomys config new_type_names
   359       descr sorts casedist_thms thy7;
   360     val (case_congs, thy9) = DatatypeAbsProofs.prove_case_congs new_type_names
   361       descr sorts nchotomys case_thms thy8;
   362     val (weak_case_congs, thy10) = DatatypeAbsProofs.prove_weak_case_congs new_type_names
   363       descr sorts thy9;
   364 
   365     val dt_infos = map
   366       (make_dt_info (SOME new_type_names) (flat descr) sorts induct reccomb_names rec_thms)
   367       ((0 upto length (hd descr) - 1) ~~ hd descr ~~ case_names ~~ case_thms ~~
   368         casedist_thms ~~ simproc_dists ~~ inject ~~ nchotomys ~~ case_congs ~~ weak_case_congs);
   369 
   370     val simps = flat (distinct @ inject @ case_thms) @ rec_thms;
   371     val dt_names = map fst dt_infos;
   372 
   373     val thy12 =
   374       thy10
   375       |> add_case_tr' case_names
   376       |> Sign.add_path (space_implode "_" new_type_names)
   377       |> add_rules simps case_thms rec_thms inject distinct
   378           weak_case_congs (Simplifier.attrib (op addcongs))
   379       |> put_dt_infos dt_infos
   380       |> add_cases_induct dt_infos induct
   381       |> Sign.parent_path
   382       |> store_thmss "splits" new_type_names (map (fn (x, y) => [x, y]) split_thms) |> snd
   383       |> DatatypeInterpretation.data (config, map fst dt_infos);
   384   in
   385     ((dt_names, {inject = inject,
   386       rec_thms = rec_thms,
   387       case_thms = case_thms,
   388       induction = induct}), thy12)
   389   end;
   390 
   391 
   392 (*********************** declare existing type as datatype *********************)
   393 
   394 fun prove_rep_datatype (config : config) alt_names new_type_names descr sorts induct inject half_distinct thy =
   395   let
   396     val ((_, [induct']), _) =
   397       Variable.importT_thms [induct] (Variable.thm_context induct);
   398 
   399     fun err t = error ("Ill-formed predicate in induction rule: " ^
   400       Syntax.string_of_term_global thy t);
   401 
   402     fun get_typ (t as _ $ Var (_, Type (tname, Ts))) =
   403           ((tname, map (fst o dest_TFree) Ts) handle TERM _ => err t)
   404       | get_typ t = err t;
   405     val dtnames = map get_typ (HOLogic.dest_conj (HOLogic.dest_Trueprop (Thm.concl_of induct')));
   406 
   407     val dt_info = get_datatypes thy;
   408 
   409     val distinct = (map o maps) (fn thm => [thm, thm RS not_sym]) half_distinct;
   410     val (case_names_induct, case_names_exhausts) =
   411       (mk_case_names_induct descr, mk_case_names_exhausts descr (map #1 dtnames));
   412 
   413     val _ = message config ("Proofs for datatype(s) " ^ commas_quote new_type_names);
   414 
   415     val (casedist_thms, thy2) = thy |>
   416       DatatypeAbsProofs.prove_casedist_thms config new_type_names [descr] sorts induct
   417         case_names_exhausts;
   418     val ((reccomb_names, rec_thms), thy3) = DatatypeAbsProofs.prove_primrec_thms
   419       config new_type_names [descr] sorts dt_info inject distinct
   420       (Simplifier.theory_context thy2 dist_ss) induct thy2;
   421     val ((case_thms, case_names), thy4) = DatatypeAbsProofs.prove_case_thms
   422       config new_type_names [descr] sorts reccomb_names rec_thms thy3;
   423     val (split_thms, thy5) = DatatypeAbsProofs.prove_split_thms
   424       config new_type_names [descr] sorts inject distinct casedist_thms case_thms thy4;
   425     val (nchotomys, thy6) = DatatypeAbsProofs.prove_nchotomys config new_type_names
   426       [descr] sorts casedist_thms thy5;
   427     val (case_congs, thy7) = DatatypeAbsProofs.prove_case_congs new_type_names
   428       [descr] sorts nchotomys case_thms thy6;
   429     val (weak_case_congs, thy8) = DatatypeAbsProofs.prove_weak_case_congs new_type_names
   430       [descr] sorts thy7;
   431 
   432     val ((_, [induct']), thy10) =
   433       thy8
   434       |> store_thmss "inject" new_type_names inject
   435       ||>> store_thmss "distinct" new_type_names distinct
   436       ||> Sign.add_path (space_implode "_" new_type_names)
   437       ||>> PureThy.add_thms [((Binding.name "induct", induct), [case_names_induct])];
   438 
   439     val dt_infos = map (make_dt_info alt_names descr sorts induct' reccomb_names rec_thms)
   440       ((0 upto length descr - 1) ~~ descr ~~ case_names ~~ case_thms ~~ casedist_thms ~~
   441         map FewConstrs distinct ~~ inject ~~ nchotomys ~~ case_congs ~~ weak_case_congs);
   442 
   443     val simps = flat (distinct @ inject @ case_thms) @ rec_thms;
   444     val dt_names = map fst dt_infos;
   445 
   446     val thy11 =
   447       thy10
   448       |> add_case_tr' case_names
   449       |> add_rules simps case_thms rec_thms inject distinct
   450            weak_case_congs (Simplifier.attrib (op addcongs))
   451       |> put_dt_infos dt_infos
   452       |> add_cases_induct dt_infos induct'
   453       |> Sign.parent_path
   454       |> store_thmss "splits" new_type_names (map (fn (x, y) => [x, y]) split_thms)
   455       |> snd
   456       |> DatatypeInterpretation.data (config, dt_names);
   457   in (dt_names, thy11) end;
   458 
   459 fun gen_rep_datatype prep_term (config : config) after_qed alt_names raw_ts thy =
   460   let
   461     fun constr_of_term (Const (c, T)) = (c, T)
   462       | constr_of_term t =
   463           error ("Not a constant: " ^ Syntax.string_of_term_global thy t);
   464     fun no_constr (c, T) = error ("Bad constructor: "
   465       ^ Sign.extern_const thy c ^ "::"
   466       ^ Syntax.string_of_typ_global thy T);
   467     fun type_of_constr (cT as (_, T)) =
   468       let
   469         val frees = OldTerm.typ_tfrees T;
   470         val (tyco, vs) = ((apsnd o map) (dest_TFree) o dest_Type o snd o strip_type) T
   471           handle TYPE _ => no_constr cT
   472         val _ = if has_duplicates (eq_fst (op =)) vs then no_constr cT else ();
   473         val _ = if length frees <> length vs then no_constr cT else ();
   474       in (tyco, (vs, cT)) end;
   475 
   476     val raw_cs = AList.group (op =) (map (type_of_constr o constr_of_term o prep_term thy) raw_ts);
   477     val _ = case map_filter (fn (tyco, _) =>
   478         if Symtab.defined (get_datatypes thy) tyco then SOME tyco else NONE) raw_cs
   479      of [] => ()
   480       | tycos => error ("Type(s) " ^ commas (map quote tycos)
   481           ^ " already represented inductivly");
   482     val raw_vss = maps (map (map snd o fst) o snd) raw_cs;
   483     val ms = case distinct (op =) (map length raw_vss)
   484      of [n] => 0 upto n - 1
   485       | _ => error ("Different types in given constructors");
   486     fun inter_sort m = map (fn xs => nth xs m) raw_vss
   487       |> Library.foldr1 (Sorts.inter_sort (Sign.classes_of thy))
   488     val sorts = map inter_sort ms;
   489     val vs = Name.names Name.context Name.aT sorts;
   490 
   491     fun norm_constr (raw_vs, (c, T)) = (c, map_atyps
   492       (TFree o (the o AList.lookup (op =) (map fst raw_vs ~~ vs)) o fst o dest_TFree) T);
   493 
   494     val cs = map (apsnd (map norm_constr)) raw_cs;
   495     val dtyps_of_typ = map (dtyp_of_typ (map (rpair (map fst vs) o fst) cs))
   496       o fst o strip_type;
   497     val new_type_names = map Long_Name.base_name (the_default (map fst cs) alt_names);
   498 
   499     fun mk_spec (i, (tyco, constr)) = (i, (tyco,
   500       map (DtTFree o fst) vs,
   501       (map o apsnd) dtyps_of_typ constr))
   502     val descr = map_index mk_spec cs;
   503     val injs = DatatypeProp.make_injs [descr] vs;
   504     val half_distincts = map snd (DatatypeProp.make_distincts [descr] vs);
   505     val ind = DatatypeProp.make_ind [descr] vs;
   506     val rules = (map o map o map) Logic.close_form [[[ind]], injs, half_distincts];
   507 
   508     fun after_qed' raw_thms =
   509       let
   510         val [[[induct]], injs, half_distincts] =
   511           unflat rules (map Drule.zero_var_indexes_list raw_thms);
   512             (*FIXME somehow dubious*)
   513       in
   514         ProofContext.theory_result
   515           (prove_rep_datatype config alt_names new_type_names descr vs induct injs half_distincts)
   516         #-> after_qed
   517       end;
   518   in
   519     thy
   520     |> ProofContext.init
   521     |> Proof.theorem_i NONE after_qed' ((map o map) (rpair []) (flat rules))
   522   end;
   523 
   524 val rep_datatype = gen_rep_datatype Sign.cert_term;
   525 val rep_datatype_cmd = gen_rep_datatype Syntax.read_term_global default_config (K I);
   526 
   527 
   528 
   529 (******************************** add datatype ********************************)
   530 
   531 fun gen_add_datatype prep_typ (config : config) new_type_names dts thy =
   532   let
   533     val _ = Theory.requires thy "Datatype" "datatype definitions";
   534 
   535     (* this theory is used just for parsing *)
   536 
   537     val tmp_thy = thy |>
   538       Theory.copy |>
   539       Sign.add_types (map (fn (tvs, tname, mx, _) =>
   540         (tname, length tvs, mx)) dts);
   541 
   542     val (tyvars, _, _, _)::_ = dts;
   543     val (new_dts, types_syntax) = ListPair.unzip (map (fn (tvs, tname, mx, _) =>
   544       let val full_tname = Sign.full_name tmp_thy (Binding.map_name (Syntax.type_name mx) tname)
   545       in (case duplicates (op =) tvs of
   546             [] => if eq_set (tyvars, tvs) then ((full_tname, tvs), (tname, mx))
   547                   else error ("Mutually recursive datatypes must have same type parameters")
   548           | dups => error ("Duplicate parameter(s) for datatype " ^ quote (Binding.str_of tname) ^
   549               " : " ^ commas dups))
   550       end) dts);
   551 
   552     val _ = (case duplicates (op =) (map fst new_dts) @ duplicates (op =) new_type_names of
   553       [] => () | dups => error ("Duplicate datatypes: " ^ commas dups));
   554 
   555     fun prep_dt_spec ((tvs, tname, mx, constrs), tname') (dts', constr_syntax, sorts, i) =
   556       let
   557         fun prep_constr (cname, cargs, mx') (constrs, constr_syntax', sorts') =
   558           let
   559             val (cargs', sorts'') = Library.foldl (prep_typ tmp_thy) (([], sorts'), cargs);
   560             val _ = (case fold (curry OldTerm.add_typ_tfree_names) cargs' [] \\ tvs of
   561                 [] => ()
   562               | vs => error ("Extra type variables on rhs: " ^ commas vs))
   563           in (constrs @ [((if #flat_names config then Sign.full_name tmp_thy else
   564                 Sign.full_name_path tmp_thy tname')
   565                   (Binding.map_name (Syntax.const_name mx') cname),
   566                    map (dtyp_of_typ new_dts) cargs')],
   567               constr_syntax' @ [(cname, mx')], sorts'')
   568           end handle ERROR msg => cat_error msg
   569            ("The error above occured in constructor " ^ quote (Binding.str_of cname) ^
   570             " of datatype " ^ quote (Binding.str_of tname));
   571 
   572         val (constrs', constr_syntax', sorts') =
   573           fold prep_constr constrs ([], [], sorts)
   574 
   575       in
   576         case duplicates (op =) (map fst constrs') of
   577            [] =>
   578              (dts' @ [(i, (Sign.full_name tmp_thy (Binding.map_name (Syntax.type_name mx) tname),
   579                 map DtTFree tvs, constrs'))],
   580               constr_syntax @ [constr_syntax'], sorts', i + 1)
   581          | dups => error ("Duplicate constructors " ^ commas dups ^
   582              " in datatype " ^ quote (Binding.str_of tname))
   583       end;
   584 
   585     val (dts', constr_syntax, sorts', i) =
   586       fold prep_dt_spec (dts ~~ new_type_names) ([], [], [], 0);
   587     val sorts = sorts' @ (map (rpair (Sign.defaultS tmp_thy)) (tyvars \\ map fst sorts'));
   588     val dt_info = get_datatypes thy;
   589     val (descr, _) = unfold_datatypes tmp_thy dts' sorts dt_info dts' i;
   590     val _ = check_nonempty descr handle (exn as Datatype_Empty s) =>
   591       if #strict config then error ("Nonemptiness check failed for datatype " ^ s)
   592       else raise exn;
   593 
   594     val descr' = flat descr;
   595     val case_names_induct = mk_case_names_induct descr';
   596     val case_names_exhausts = mk_case_names_exhausts descr' (map #1 new_dts);
   597   in
   598     add_datatype_def
   599       (config : config) new_type_names descr sorts types_syntax constr_syntax dt_info
   600       case_names_induct case_names_exhausts thy
   601   end;
   602 
   603 val add_datatype = gen_add_datatype cert_typ;
   604 val datatype_cmd = snd ooo gen_add_datatype read_typ default_config;
   605 
   606 
   607 
   608 (** package setup **)
   609 
   610 (* setup theory *)
   611 
   612 val setup =
   613   DatatypeRepProofs.distinctness_limit_setup #>
   614   simproc_setup #>
   615   trfun_setup #>
   616   DatatypeInterpretation.init;
   617 
   618 
   619 (* outer syntax *)
   620 
local

structure P = OuterParse and K = OuterKeyword

(*split the parsed declarations into the type names (the optional name given
  in parentheses, otherwise the name of the type binding) and the
  specifications in the shape expected by datatype_cmd*)
fun prep_datatype_decls args =
  let
    val names = map
      (fn ((((NONE, _), t), _), _) => Binding.name_of t | ((((SOME t, _), _), _), _) => t) args;
    val specs = map (fn ((((_, vs), t), mx), cons) =>
      (vs, t, mx, map (fn ((x, y), z) => (x, y, z)) cons)) args;
  in (names, specs) end;

(*one declaration: optional "(" name ")", type arguments, binding, optional
  infix, "=" and "|"-separated constructors, each a binding with argument
  types and optional mixfix*)
val parse_datatype_decl =
  (Scan.option (P.$$$ "(" |-- P.name --| P.$$$ ")") -- P.type_args -- P.binding -- P.opt_infix --
    (P.$$$ "=" |-- P.enum1 "|" (P.binding -- Scan.repeat P.typ -- P.opt_mixfix)));

val parse_datatype_decls = P.and_list1 parse_datatype_decl >> prep_datatype_decls;

in

(*command "datatype": defines the declared datatypes via datatype_cmd*)
val _ =
  OuterSyntax.command "datatype" "define inductive datatypes" K.thy_decl
    (parse_datatype_decls >> (fn (names, specs) => Toplevel.theory (datatype_cmd names specs)));

(*command "rep_datatype": optional type names in parentheses followed by the
  constructor terms; enters the proof opened by rep_datatype_cmd*)
val _ =
  OuterSyntax.command "rep_datatype" "represent existing types inductively" K.thy_goal
    (Scan.option (P.$$$ "(" |-- Scan.repeat1 P.name --| P.$$$ ")") -- Scan.repeat1 P.term
      >> (fn (alt_names, ts) => Toplevel.print
           o Toplevel.theory_to_proof (rep_datatype_cmd alt_names ts)));

end;
   652 
   653 
   654 (* document antiquotation *)
   655 
   656 val _ = ThyOutput.antiquotation "datatype" Args.tyname
   657   (fn {source = src, context = ctxt, ...} => fn dtco =>
   658     let
   659       val thy = ProofContext.theory_of ctxt;
   660       val (vs, cos) = the_datatype_spec thy dtco;
   661       val ty = Type (dtco, map TFree vs);
   662       fun pretty_typ_bracket (ty as Type (_, _ :: _)) =
   663             Pretty.enclose "(" ")" [Syntax.pretty_typ ctxt ty]
   664         | pretty_typ_bracket ty =
   665             Syntax.pretty_typ ctxt ty;
   666       fun pretty_constr (co, tys) =
   667         (Pretty.block o Pretty.breaks)
   668           (Syntax.pretty_term ctxt (Const (co, tys ---> ty)) ::
   669             map pretty_typ_bracket tys);
   670       val pretty_datatype =
   671         Pretty.block
   672           (Pretty.command "datatype" :: Pretty.brk 1 ::
   673            Syntax.pretty_typ ctxt ty ::
   674            Pretty.str " =" :: Pretty.brk 1 ::
   675            flat (separate [Pretty.brk 1, Pretty.str "| "]
   676              (map (single o pretty_constr) cos)));
   677     in ThyOutput.output (ThyOutput.maybe_pretty_source (K pretty_datatype) src [()]) end);
   678 
   679 end;
   680