src/HOL/Tools/Datatype/datatype_data.ML
author blanchet
Mon May 05 09:30:20 2014 +0200 (2014-05-05)
changeset 56858 0c3d0bc98abe
parent 56375 32e0da92c786
child 57983 6edc3529bb4e
permissions -rw-r--r--
simplify selectors in code views
     1 (*  Title:      HOL/Tools/Datatype/datatype_data.ML
     2     Author:     Stefan Berghofer, TU Muenchen
     3 
     4 Datatype package bookkeeping.
     5 *)
     6 
     signature DATATYPE_DATA =
     sig
       include DATATYPE_COMMON

       (* queries on the datatype registry *)

       (*all registered datatypes, keyed by datatype name*)
       val get_all : theory -> info Symtab.table
       val get_info : theory -> string -> info option
       (*like get_info, but fails with an error for an unknown datatype*)
       val the_info : theory -> string -> info
       (*datatype info for a constructor, selected by the result type of the constructor*)
       val info_of_constr : theory -> string * typ -> info option
       (*like info_of_constr, but falls back to some registered entry if the type does not decide*)
       val info_of_constr_permissive : theory -> string * typ -> info option
       (*datatype info for a case combinator name*)
       val info_of_case : theory -> string -> info option

       (* registration *)

       val register : (string * info) list -> theory -> theory

       (* derived queries *)

       (*type arguments and constructors (with argument types) of a datatype*)
       val the_spec : theory -> string -> (string * sort) list * (string * typ list) list
       val the_descr : theory -> string list ->
         descr * (string * sort) list * string list * string *
         (string list * string list) * (typ list * typ list)
       val all_distincts : theory -> typ list -> thm list list
       val get_constrs : theory -> string -> (string * typ) list option

       (* case names for induction and exhaustion rules *)

       val mk_case_names_induct : descr -> attribute
       val mk_case_names_exhausts : descr -> string list -> attribute list

       (* interpretation hooks *)

       val interpretation : (config -> string list -> theory -> theory) -> theory -> theory
       val interpretation_data : config * string list -> theory -> theory

       val setup : theory -> theory
     end;
    29 
    structure Datatype_Data: DATATYPE_DATA =
    struct

    (** theory data **)

    (* data management *)

    (*Datatype registry, indexed three ways:
        types   -- datatype name |-> info
        constrs -- constructor name |-> list of (datatype name, info); a list because the
                   same constructor name may be registered for several datatypes
        cases   -- case combinator name |-> info*)
    structure Data = Theory_Data
    (
      type T =
        {types: Datatype_Aux.info Symtab.table,
         constrs: (string * Datatype_Aux.info) list Symtab.table,
         cases: Datatype_Aux.info Symtab.table};

      val empty =
        {types = Symtab.empty, constrs = Symtab.empty, cases = Symtab.empty};
      val extend = I;
      (*clashing entries are accepted (K true) rather than rejected; per-constructor lists
        are merged as association lists keyed by datatype name*)
      fun merge
        ({types = types1, constrs = constrs1, cases = cases1},
         {types = types2, constrs = constrs2, cases = cases2}) : T =
        {types = Symtab.merge (K true) (types1, types2),
         constrs = Symtab.join (K (AList.merge (op =) (K true))) (constrs1, constrs2),
         cases = Symtab.merge (K true) (cases1, cases2)};
    );

    (*all registered datatypes, keyed by datatype name*)
    val get_all = #types o Data.get;
    val get_info = Symtab.lookup o get_all;

    (*like get_info, but fails with an error for an unregistered datatype*)
    fun the_info thy name =
      (case get_info thy name of
        SOME info => info
      | NONE => error ("Unknown datatype " ^ quote name));

    local

    (*all (datatype name, info) entries registered for constructor name c*)
    fun constr_entries thy c = Symtab.lookup_list (#constrs (Data.get thy)) c;

    (*name of the type constructor of the result (body) type of T, if it is a Type*)
    fun result_tyco T =
      (case body_type T of
        Type (tyco, _) => SOME tyco
      | _ => NONE);

    in

    (*info of the datatype built by constructor c, selected by the result type of T;
      NONE if that result type is not a Type or c is not registered for it*)
    fun info_of_constr thy (c, T) =
      (case result_tyco T of
        SOME tyco => AList.lookup (op =) (constr_entries thy c) tyco
      | NONE => NONE);

    (*like info_of_constr, but falls back to a default entry for c when the result type
      of T does not select one*)
    fun info_of_constr_permissive thy (c, T) =
      let
        val tab = constr_entries thy c;
        (*register conses new entries onto the front of the list, so List.last picks its
          oldest entry -- conservative wrt. overloaded constructors*)
        val default = if null tab then NONE else SOME (snd (List.last tab));
      in
        (case result_tyco T of
          NONE => default
        | SOME tyco =>
            (case AList.lookup (op =) tab tyco of
              NONE => default (*permissive*)
            | SOME info => SOME info))
      end;

    end;

    (*case combinator name |-> info*)
    val info_of_case = Symtab.lookup o #cases o Data.get;

    (*constructors read off the exhaust rule: each premise must have exactly one hypothesis,
      an equation whose right-hand side has the constructor as its head*)
    fun ctrs_of_exhaust exhaust =
      Logic.strip_imp_prems (prop_of exhaust) |>
      map (head_of o snd o HOLogic.dest_eq o HOLogic.dest_Trueprop o the_single
        o Logic.strip_assums_hyp);

    (*case combinator: head of the left-hand side of a case equation*)
    fun case_of_case_rewrite case_rewrite =
      head_of (fst (HOLogic.dest_eq (HOLogic.dest_Trueprop (prop_of case_rewrite))));

    (*Ctr_Sugar record built from a datatype info; discriminators, selectors and all
      theorems about them are not available here and are left empty. The case combinator
      is taken from the first case equation.*)
    fun ctr_sugar_of_info ({exhaust, nchotomy, inject, distinct, case_rewrites, case_cong,
        weak_case_cong, split, split_asm, ...} : Datatype_Aux.info) =
      {ctrs = ctrs_of_exhaust exhaust,
       casex = case_of_case_rewrite (hd case_rewrites),
       discs = [],
       selss = [],
       exhaust = exhaust,
       nchotomy = nchotomy,
       injects = inject,
       distincts = distinct,
       case_thms = case_rewrites,
       case_cong = case_cong,
       weak_case_cong = weak_case_cong,
       split = split,
       split_asm = split_asm,
       disc_defs = [],
       disc_thmss = [],
       discIs = [],
       sel_defs = [],
       sel_thmss = [],
       disc_excludesss = [],
       disc_exhausts = [],
       sel_exhausts = [],
       collapses = [],
       expands = [],
       sel_splits = [],
       sel_split_asms = [],
       case_eq_ifs = []};

    (*Record the given (datatype name, info) pairs in all three tables, then register a
      default ctr_sugar for each datatype.*)
    fun register dt_infos =
      Data.map (fn {types, constrs, cases} =>
        {types = types |> fold Symtab.update dt_infos,
         (*every constructor of every new datatype gets (datatype name, info) consed on*)
         constrs = constrs |> fold (fn (constr, dtname_info) =>
             Symtab.map_default (constr, []) (cons dtname_info))
           (maps (fn (dtname, info as {descr, index, ...} : Datatype_Aux.info) =>
              map (rpair (dtname, info) o fst) (#3 (the (AList.lookup (op =) descr index))))
            dt_infos),
         cases = cases |> fold Symtab.update
           (map (fn (_, info as {case_name, ...} : Datatype_Aux.info) => (case_name, info))
             dt_infos)}) #>
      fold (fn (key, info) =>
        Ctr_Sugar.default_register_ctr_sugar_global key (ctr_sugar_of_info info)) dt_infos;


    (* complex queries *)

    (*type arguments and constructors (with argument types) of datatype dtco; fails with
      an error if dtco is not registered*)
    fun the_spec thy dtco =
      let
        val {descr, index, ...} = the_info thy dtco;
        val (_, dtys, raw_cos) = the (AList.lookup (op =) descr index);
        val args = map Datatype_Aux.dest_DtTFree dtys;
        val cos = map (fn (co, tys) => (co, map (Datatype_Aux.typ_of_dtyp descr) tys)) raw_cos;
      in (args, cos) end;

    (*Descriptor data for the mutually recursive datatype whose type constructors are
      exactly raw_tycos (in any order). Returns
        (descr, type variables, type constructor names, name prefix,
         (base names, names for the remaining descr types), (datatype types, remaining types)).
      Fails with an error for an empty list, an unknown datatype, or if raw_tycos is not
      exactly the set of datatypes of one descriptor.*)
    fun the_descr _ [] = error "No type constructors given"
      | the_descr thy (raw_tycos as raw_tyco :: _) =
          let
            val info = the_info thy raw_tyco;
            val descr = #descr info;

            val (_, dtys, _) = the (AList.lookup (op =) descr (#index info));
            val vs = map Datatype_Aux.dest_DtTFree dtys;

            fun is_DtTFree (Datatype_Aux.DtTFree _) = true
              | is_DtTFree _ = false;
            (*split descr at the first entry whose arguments are not all type variables.
              NOTE(review): k is ~1 when there is no such entry; this relies on chop then
              returning the whole list as the first component -- confirm*)
            val k = find_index (fn (_, (_, dTs, _)) => not (forall is_DtTFree dTs)) descr;
            val protoTs as (dataTs, _) =
              chop k descr
              |> (pairself o map)
                (fn (_, (tyco, dTs, _)) => (tyco, map (Datatype_Aux.typ_of_dtyp descr) dTs));

            val tycos = map fst dataTs;
            val _ =
              if eq_set (op =) (tycos, raw_tycos) then ()
              else
                error ("Type constructors " ^ commas_quote raw_tycos ^
                  " do not belong exhaustively to one mutual recursive datatype");

            val (Ts, Us) = (pairself o map) Type protoTs;

            val names = map Long_Name.base_name tycos;
            (*fresh names for the remaining types, avoiding the datatype base names*)
            val (auxnames, _) =
              Name.make_context names
              |> fold_map (Name.variant o Datatype_Aux.name_of_typ) Us;
            val prefix = space_implode "_" names;

          in (descr, vs, tycos, prefix, (names, auxnames), (Ts, Us)) end;

    (*distinctness theorems of every registered datatype occurring anywhere in Ts;
      unregistered type constructors are skipped*)
    fun all_distincts thy Ts =
      let
        fun add_tycos (Type (tyco, Ts)) = insert (op =) tyco #> fold add_tycos Ts
          | add_tycos _ = I;
        val tycos = fold add_tycos Ts [];
      in map_filter (Option.map #distinct o get_info thy) tycos end;

    (*constructors of dtco with their full types, the datatype's type variables turned into
      schematic variables (index 0); NONE if the_spec fails for dtco*)
    fun get_constrs thy dtco =
      (case try (the_spec thy) dtco of
        SOME (args, cos) =>
          let
            fun subst (v, sort) = TVar ((v, 0), sort);
            fun subst_ty (TFree v) = subst v
              | subst_ty ty = ty;
            val dty = Type (dtco, map subst args);
            fun mk_co (co, tys) = (co, map (Term.map_atyps subst_ty) tys ---> dty);
          in SOME (map mk_co cos) end
      | NONE => NONE);



    (** various auxiliary **)

    (* case names *)

    local

    (*indices of the descr entries referenced recursively by a dtyp*)
    fun dt_recs (Datatype_Aux.DtTFree _) = []
      | dt_recs (Datatype_Aux.DtType (_, dts)) = maps dt_recs dts
      | dt_recs (Datatype_Aux.DtRec i) = [i];

    (*one case name per constructor: its base name followed by the base names of the
      recursively referenced types, joined with "_"*)
    fun dt_cases (descr: Datatype_Aux.descr) (_, args, constrs) =
      let
        fun the_bname i = Long_Name.base_name (#1 (the (AList.lookup (op =) descr i)));
        val bnames = map the_bname (distinct (op =) (maps dt_recs args));
      in map (fn (c, _) => space_implode "_" (Long_Name.base_name c :: bnames)) constrs end;

    fun induct_cases descr =
      Datatype_Prop.indexify_names (maps (dt_cases descr) (map #2 descr));

    fun exhaust_cases descr i = dt_cases descr (the (AList.lookup (op =) descr i));

    in

    fun mk_case_names_induct descr = Rule_Cases.case_names (induct_cases descr);

    (*one case_names attribute per descr entry whose type name is in new*)
    fun mk_case_names_exhausts descr new =
      map (Rule_Cases.case_names o exhaust_cases descr o #1)
        (filter (fn (_, (name, _, _)) => member (op =) new name) descr);

    end;



    (** document antiquotation **)

    (*@{datatype t} prints "datatype t = C1 ... | C2 ..." from the registered spec*)
    val antiq_setup =
      Thy_Output.antiquotation @{binding datatype} (Args.type_name {proper = true, strict = true})
        (fn {source = src, context = ctxt, ...} => fn dtco =>
          let
            val thy = Proof_Context.theory_of ctxt;
            val (vs, cos) = the_spec thy dtco;
            val ty = Type (dtco, map TFree vs);
            val pretty_typ_bracket = Syntax.pretty_typ (Config.put pretty_priority 1001 ctxt);
            fun pretty_constr (co, tys) =
              Pretty.block (Pretty.breaks
                (Syntax.pretty_term ctxt (Const (co, tys ---> ty)) ::
                  map pretty_typ_bracket tys));
            val pretty_datatype =
              Pretty.block
               (Pretty.keyword1 "datatype" :: Pretty.brk 1 ::
                Syntax.pretty_typ ctxt ty ::
                Pretty.str " =" :: Pretty.brk 1 ::
                flat (separate [Pretty.brk 1, Pretty.str "| "] (map (single o pretty_constr) cos)));
          in
            Thy_Output.output ctxt
              (Thy_Output.maybe_pretty_source (K (K pretty_datatype)) ctxt src [()])
          end);



    (** abstract theory extensions relative to a datatype characterisation **)

    (*interpretation data: configuration plus type names; equality compares the type
      names only*)
    structure Datatype_Interpretation = Interpretation
    (
      type T = Datatype_Aux.config * string list;
      val eq: T * T -> bool = eq_snd (op =);
    );

    (*run f with the naming set to the qualifier of the first type name, then restore the
      original naming.
      NOTE(review): not defined for an empty list of type names (raises Match)*)
    fun with_repaired_path f config (type_names as name :: _) thy =
      thy
      |> Sign.root_path
      |> Sign.add_path (Long_Name.qualifier name)
      |> f config type_names
      |> Sign.restore_naming thy;

    fun interpretation f = Datatype_Interpretation.interpretation (uncurry (with_repaired_path f));
    val interpretation_data = Datatype_Interpretation.data;



    (** setup theory **)

    val setup =
      antiq_setup #>
      Datatype_Interpretation.init;

    (*opened last, so that it provides the DATATYPE_COMMON part of the signature without
      shadowing any of the definitions above*)
    open Datatype_Aux;

    end;