src/HOL/Tools/datatype_package/datatype_package.ML
changeset 31668 a616e56a5ec8
parent 31604 eb2f9d709296
child 31689 84a14d2dc868
equal deleted inserted replaced
31667:cc969090c204 31668:a616e56a5ec8
     4 Datatype package for Isabelle/HOL.
     4 Datatype package for Isabelle/HOL.
     5 *)
     5 *)
     6 
     6 
     7 signature DATATYPE_PACKAGE =
     7 signature DATATYPE_PACKAGE =
     8 sig
     8 sig
       
     9   type datatype_config = DatatypeAux.datatype_config
     9   type datatype_info = DatatypeAux.datatype_info
    10   type datatype_info = DatatypeAux.datatype_info
    10   type descr = DatatypeAux.descr
    11   type descr = DatatypeAux.descr
    11   val get_datatypes : theory -> datatype_info Symtab.table
    12   val get_datatypes : theory -> datatype_info Symtab.table
    12   val get_datatype : theory -> string -> datatype_info option
    13   val get_datatype : theory -> string -> datatype_info option
    13   val the_datatype : theory -> string -> datatype_info
    14   val the_datatype : theory -> string -> datatype_info
    22   val make_case :  Proof.context -> bool -> string list -> term ->
    23   val make_case :  Proof.context -> bool -> string list -> term ->
    23     (term * term) list -> term * (term * (int * bool)) list
    24     (term * term) list -> term * (term * (int * bool)) list
    24   val strip_case : Proof.context -> bool -> term -> (term * (term * term) list) option
    25   val strip_case : Proof.context -> bool -> term -> (term * (term * term) list) option
    25   val read_typ: theory ->
    26   val read_typ: theory ->
    26     (typ list * (string * sort) list) * string -> typ list * (string * sort) list
    27     (typ list * (string * sort) list) * string -> typ list * (string * sort) list
    27   val interpretation : (string list -> theory -> theory) -> theory -> theory
    28   val interpretation : (datatype_config -> string list -> theory -> theory) -> theory -> theory
    28   val rep_datatype : ({distinct : thm list list,
    29   type rules = {distinct : thm list list,
    29        inject : thm list list,
    30     inject : thm list list,
    30        exhaustion : thm list,
    31     exhaustion : thm list,
    31        rec_thms : thm list,
    32     rec_thms : thm list,
    32        case_thms : thm list list,
    33     case_thms : thm list list,
    33        split_thms : (thm * thm) list,
    34     split_thms : (thm * thm) list,
    34        induction : thm,
    35     induction : thm,
    35        simps : thm list} -> Proof.context -> Proof.context) -> string list option -> term list
    36     simps : thm list}
    36     -> theory -> Proof.state;
    37   val rep_datatype : datatype_config -> (rules -> Proof.context -> Proof.context)
    37   val rep_datatype_cmd : string list option -> string list -> theory -> Proof.state;
    38     -> string list option -> term list -> theory -> Proof.state;
    38   val add_datatype : bool -> bool -> string list -> (string list * binding * mixfix *
    39   val rep_datatype_cmd : string list option -> string list -> theory -> Proof.state
    39     (binding * typ list * mixfix) list) list -> theory ->
    40   val add_datatype : datatype_config -> string list -> (string list * binding * mixfix *
    40       {distinct : thm list list,
    41     (binding * typ list * mixfix) list) list -> theory -> rules * theory
    41        inject : thm list list,
    42   val add_datatype_cmd : string list -> (string list * binding * mixfix *
    42        exhaustion : thm list,
    43     (binding * string list * mixfix) list) list -> theory -> rules * theory
    43        rec_thms : thm list,
       
    44        case_thms : thm list list,
       
    45        split_thms : (thm * thm) list,
       
    46        induction : thm,
       
    47        simps : thm list} * theory
       
    48   val add_datatype_cmd : bool -> string list -> (string list * binding * mixfix *
       
    49     (binding * string list * mixfix) list) list -> theory ->
       
    50       {distinct : thm list list,
       
    51        inject : thm list list,
       
    52        exhaustion : thm list,
       
    53        rec_thms : thm list,
       
    54        case_thms : thm list list,
       
    55        split_thms : (thm * thm) list,
       
    56        induction : thm,
       
    57        simps : thm list} * theory
       
    58   val setup: theory -> theory
    44   val setup: theory -> theory
    59   val quiet_mode : bool ref
       
    60   val print_datatypes : theory -> unit
    45   val print_datatypes : theory -> unit
    61 end;
    46 end;
    62 
    47 
    63 structure DatatypePackage : DATATYPE_PACKAGE =
    48 structure DatatypePackage : DATATYPE_PACKAGE =
    64 struct
    49 struct
    65 
    50 
    66 open DatatypeAux;
    51 open DatatypeAux;
    67 
       
    68 val quiet_mode = quiet_mode;
       
    69 
    52 
    70 
    53 
    71 (* theory data *)
    54 (* theory data *)
    72 
    55 
    73 structure DatatypesData = TheoryDataFun
    56 structure DatatypesData = TheoryDataFun
   356     inject = inject,
   339     inject = inject,
   357     nchotomy = nchotomy,
   340     nchotomy = nchotomy,
   358     case_cong = case_cong,
   341     case_cong = case_cong,
   359     weak_case_cong = weak_case_cong});
   342     weak_case_cong = weak_case_cong});
   360 
   343 
   361 structure DatatypeInterpretation = InterpretationFun(type T = string list val eq = op =);
   344 type rules = {distinct : thm list list,
   362 val interpretation = DatatypeInterpretation.interpretation;
   345   inject : thm list list,
       
   346   exhaustion : thm list,
       
   347   rec_thms : thm list,
       
   348   case_thms : thm list list,
       
   349   split_thms : (thm * thm) list,
       
   350   induction : thm,
       
   351   simps : thm list}
       
   352 
       
   353 structure DatatypeInterpretation = InterpretationFun
       
   354   (type T = datatype_config * string list val eq = eq_snd op =);
       
   355 fun interpretation f = DatatypeInterpretation.interpretation (uncurry f);
   363 
   356 
   364 
   357 
   365 (******************* definitional introduction of datatypes *******************)
   358 (******************* definitional introduction of datatypes *******************)
   366 
   359 
   367 fun add_datatype_def flat_names new_type_names descr sorts types_syntax constr_syntax dt_info
   360 fun add_datatype_def (config : datatype_config) new_type_names descr sorts types_syntax constr_syntax dt_info
   368     case_names_induct case_names_exhausts thy =
   361     case_names_induct case_names_exhausts thy =
   369   let
   362   let
   370     val _ = message ("Proofs for datatype(s) " ^ commas_quote new_type_names);
   363     val _ = message config ("Proofs for datatype(s) " ^ commas_quote new_type_names);
   371 
   364 
   372     val ((inject, distinct, dist_rewrites, simproc_dists, induct), thy2) = thy |>
   365     val ((inject, distinct, dist_rewrites, simproc_dists, induct), thy2) = thy |>
   373       DatatypeRepProofs.representation_proofs flat_names dt_info new_type_names descr sorts
   366       DatatypeRepProofs.representation_proofs config dt_info new_type_names descr sorts
   374         types_syntax constr_syntax case_names_induct;
   367         types_syntax constr_syntax case_names_induct;
   375 
   368 
   376     val (casedist_thms, thy3) = DatatypeAbsProofs.prove_casedist_thms new_type_names descr
   369     val (casedist_thms, thy3) = DatatypeAbsProofs.prove_casedist_thms config new_type_names descr
   377       sorts induct case_names_exhausts thy2;
   370       sorts induct case_names_exhausts thy2;
   378     val ((reccomb_names, rec_thms), thy4) = DatatypeAbsProofs.prove_primrec_thms
   371     val ((reccomb_names, rec_thms), thy4) = DatatypeAbsProofs.prove_primrec_thms
   379       flat_names new_type_names descr sorts dt_info inject dist_rewrites
   372       config new_type_names descr sorts dt_info inject dist_rewrites
   380       (Simplifier.theory_context thy3 dist_ss) induct thy3;
   373       (Simplifier.theory_context thy3 dist_ss) induct thy3;
   381     val ((case_thms, case_names), thy6) = DatatypeAbsProofs.prove_case_thms
   374     val ((case_thms, case_names), thy6) = DatatypeAbsProofs.prove_case_thms
   382       flat_names new_type_names descr sorts reccomb_names rec_thms thy4;
   375       config new_type_names descr sorts reccomb_names rec_thms thy4;
   383     val (split_thms, thy7) = DatatypeAbsProofs.prove_split_thms new_type_names
   376     val (split_thms, thy7) = DatatypeAbsProofs.prove_split_thms config new_type_names
   384       descr sorts inject dist_rewrites casedist_thms case_thms thy6;
   377       descr sorts inject dist_rewrites casedist_thms case_thms thy6;
   385     val (nchotomys, thy8) = DatatypeAbsProofs.prove_nchotomys new_type_names
   378     val (nchotomys, thy8) = DatatypeAbsProofs.prove_nchotomys config new_type_names
   386       descr sorts casedist_thms thy7;
   379       descr sorts casedist_thms thy7;
   387     val (case_congs, thy9) = DatatypeAbsProofs.prove_case_congs new_type_names
   380     val (case_congs, thy9) = DatatypeAbsProofs.prove_case_congs new_type_names
   388       descr sorts nchotomys case_thms thy8;
   381       descr sorts nchotomys case_thms thy8;
   389     val (weak_case_congs, thy10) = DatatypeAbsProofs.prove_weak_case_congs new_type_names
   382     val (weak_case_congs, thy10) = DatatypeAbsProofs.prove_weak_case_congs new_type_names
   390       descr sorts thy9;
   383       descr sorts thy9;
   404           weak_case_congs (Simplifier.attrib (op addcongs))
   397           weak_case_congs (Simplifier.attrib (op addcongs))
   405       |> put_dt_infos dt_infos
   398       |> put_dt_infos dt_infos
   406       |> add_cases_induct dt_infos induct
   399       |> add_cases_induct dt_infos induct
   407       |> Sign.parent_path
   400       |> Sign.parent_path
   408       |> store_thmss "splits" new_type_names (map (fn (x, y) => [x, y]) split_thms) |> snd
   401       |> store_thmss "splits" new_type_names (map (fn (x, y) => [x, y]) split_thms) |> snd
   409       |> DatatypeInterpretation.data (map fst dt_infos);
   402       |> DatatypeInterpretation.data (config, map fst dt_infos);
   410   in
   403   in
   411     ({distinct = distinct,
   404     ({distinct = distinct,
   412       inject = inject,
   405       inject = inject,
   413       exhaustion = casedist_thms,
   406       exhaustion = casedist_thms,
   414       rec_thms = rec_thms,
   407       rec_thms = rec_thms,
   419   end;
   412   end;
   420 
   413 
   421 
   414 
   422 (*********************** declare existing type as datatype *********************)
   415 (*********************** declare existing type as datatype *********************)
   423 
   416 
   424 fun prove_rep_datatype alt_names new_type_names descr sorts induct inject half_distinct thy =
   417 fun prove_rep_datatype (config : datatype_config) alt_names new_type_names descr sorts induct inject half_distinct thy =
   425   let
   418   let
   426     val ((_, [induct']), _) =
   419     val ((_, [induct']), _) =
   427       Variable.importT_thms [induct] (Variable.thm_context induct);
   420       Variable.importT_thms [induct] (Variable.thm_context induct);
   428 
   421 
   429     fun err t = error ("Ill-formed predicate in induction rule: " ^
   422     fun err t = error ("Ill-formed predicate in induction rule: " ^
   438 
   431 
   439     val distinct = (map o maps) (fn thm => [thm, thm RS not_sym]) half_distinct;
   432     val distinct = (map o maps) (fn thm => [thm, thm RS not_sym]) half_distinct;
   440     val (case_names_induct, case_names_exhausts) =
   433     val (case_names_induct, case_names_exhausts) =
   441       (mk_case_names_induct descr, mk_case_names_exhausts descr (map #1 dtnames));
   434       (mk_case_names_induct descr, mk_case_names_exhausts descr (map #1 dtnames));
   442 
   435 
   443     val _ = message ("Proofs for datatype(s) " ^ commas_quote new_type_names);
   436     val _ = message config ("Proofs for datatype(s) " ^ commas_quote new_type_names);
   444 
   437 
   445     val (casedist_thms, thy2) = thy |>
   438     val (casedist_thms, thy2) = thy |>
   446       DatatypeAbsProofs.prove_casedist_thms new_type_names [descr] sorts induct
   439       DatatypeAbsProofs.prove_casedist_thms config new_type_names [descr] sorts induct
   447         case_names_exhausts;
   440         case_names_exhausts;
   448     val ((reccomb_names, rec_thms), thy3) = DatatypeAbsProofs.prove_primrec_thms
   441     val ((reccomb_names, rec_thms), thy3) = DatatypeAbsProofs.prove_primrec_thms
   449       false new_type_names [descr] sorts dt_info inject distinct
   442       config new_type_names [descr] sorts dt_info inject distinct
   450       (Simplifier.theory_context thy2 dist_ss) induct thy2;
   443       (Simplifier.theory_context thy2 dist_ss) induct thy2;
   451     val ((case_thms, case_names), thy4) = DatatypeAbsProofs.prove_case_thms false
   444     val ((case_thms, case_names), thy4) = DatatypeAbsProofs.prove_case_thms
   452       new_type_names [descr] sorts reccomb_names rec_thms thy3;
   445       config new_type_names [descr] sorts reccomb_names rec_thms thy3;
   453     val (split_thms, thy5) = DatatypeAbsProofs.prove_split_thms
   446     val (split_thms, thy5) = DatatypeAbsProofs.prove_split_thms
   454       new_type_names [descr] sorts inject distinct casedist_thms case_thms thy4;
   447       config new_type_names [descr] sorts inject distinct casedist_thms case_thms thy4;
   455     val (nchotomys, thy6) = DatatypeAbsProofs.prove_nchotomys new_type_names
   448     val (nchotomys, thy6) = DatatypeAbsProofs.prove_nchotomys config new_type_names
   456       [descr] sorts casedist_thms thy5;
   449       [descr] sorts casedist_thms thy5;
   457     val (case_congs, thy7) = DatatypeAbsProofs.prove_case_congs new_type_names
   450     val (case_congs, thy7) = DatatypeAbsProofs.prove_case_congs new_type_names
   458       [descr] sorts nchotomys case_thms thy6;
   451       [descr] sorts nchotomys case_thms thy6;
   459     val (weak_case_congs, thy8) = DatatypeAbsProofs.prove_weak_case_congs new_type_names
   452     val (weak_case_congs, thy8) = DatatypeAbsProofs.prove_weak_case_congs new_type_names
   460       [descr] sorts thy7;
   453       [descr] sorts thy7;
   480       |> put_dt_infos dt_infos
   473       |> put_dt_infos dt_infos
   481       |> add_cases_induct dt_infos induct'
   474       |> add_cases_induct dt_infos induct'
   482       |> Sign.parent_path
   475       |> Sign.parent_path
   483       |> store_thmss "splits" new_type_names (map (fn (x, y) => [x, y]) split_thms)
   476       |> store_thmss "splits" new_type_names (map (fn (x, y) => [x, y]) split_thms)
   484       |> snd
   477       |> snd
   485       |> DatatypeInterpretation.data (map fst dt_infos);
   478       |> DatatypeInterpretation.data (config, map fst dt_infos);
   486   in
   479   in
   487     ({distinct = distinct,
   480     ({distinct = distinct,
   488       inject = inject,
   481       inject = inject,
   489       exhaustion = casedist_thms,
   482       exhaustion = casedist_thms,
   490       rec_thms = rec_thms,
   483       rec_thms = rec_thms,
   492       split_thms = split_thms,
   485       split_thms = split_thms,
   493       induction = induct',
   486       induction = induct',
   494       simps = simps}, thy11)
   487       simps = simps}, thy11)
   495   end;
   488   end;
   496 
   489 
   497 fun gen_rep_datatype prep_term after_qed alt_names raw_ts thy =
   490 fun gen_rep_datatype prep_term (config : datatype_config) after_qed alt_names raw_ts thy =
   498   let
   491   let
   499     fun constr_of_term (Const (c, T)) = (c, T)
   492     fun constr_of_term (Const (c, T)) = (c, T)
   500       | constr_of_term t =
   493       | constr_of_term t =
   501           error ("Not a constant: " ^ Syntax.string_of_term_global thy t);
   494           error ("Not a constant: " ^ Syntax.string_of_term_global thy t);
   502     fun no_constr (c, T) = error ("Bad constructor: "
   495     fun no_constr (c, T) = error ("Bad constructor: "
   548         val [[[induct]], injs, half_distincts] =
   541         val [[[induct]], injs, half_distincts] =
   549           unflat rules (map Drule.zero_var_indexes_list raw_thms);
   542           unflat rules (map Drule.zero_var_indexes_list raw_thms);
   550             (*FIXME somehow dubious*)
   543             (*FIXME somehow dubious*)
   551       in
   544       in
   552         ProofContext.theory_result
   545         ProofContext.theory_result
   553           (prove_rep_datatype alt_names new_type_names descr vs induct injs half_distincts)
   546           (prove_rep_datatype config alt_names new_type_names descr vs induct injs half_distincts)
   554         #-> after_qed
   547         #-> after_qed
   555       end;
   548       end;
   556   in
   549   in
   557     thy
   550     thy
   558     |> ProofContext.init
   551     |> ProofContext.init
   559     |> Proof.theorem_i NONE after_qed' ((map o map) (rpair []) (flat rules))
   552     |> Proof.theorem_i NONE after_qed' ((map o map) (rpair []) (flat rules))
   560   end;
   553   end;
   561 
   554 
   562 val rep_datatype = gen_rep_datatype Sign.cert_term;
   555 val rep_datatype = gen_rep_datatype Sign.cert_term;
   563 val rep_datatype_cmd = gen_rep_datatype Syntax.read_term_global (K I);
   556 val rep_datatype_cmd = gen_rep_datatype Syntax.read_term_global default_datatype_config (K I);
   564 
   557 
   565 
   558 
   566 
   559 
   567 (******************************** add datatype ********************************)
   560 (******************************** add datatype ********************************)
   568 
   561 
   569 fun gen_add_datatype prep_typ err flat_names new_type_names dts thy =
   562 fun gen_add_datatype prep_typ (config : datatype_config) new_type_names dts thy =
   570   let
   563   let
   571     val _ = Theory.requires thy "Datatype" "datatype definitions";
   564     val _ = Theory.requires thy "Datatype" "datatype definitions";
   572 
   565 
   573     (* this theory is used just for parsing *)
   566     (* this theory is used just for parsing *)
   574 
   567 
   596           let
   589           let
   597             val (cargs', sorts'') = Library.foldl (prep_typ tmp_thy) (([], sorts'), cargs);
   590             val (cargs', sorts'') = Library.foldl (prep_typ tmp_thy) (([], sorts'), cargs);
   598             val _ = (case fold (curry OldTerm.add_typ_tfree_names) cargs' [] \\ tvs of
   591             val _ = (case fold (curry OldTerm.add_typ_tfree_names) cargs' [] \\ tvs of
   599                 [] => ()
   592                 [] => ()
   600               | vs => error ("Extra type variables on rhs: " ^ commas vs))
   593               | vs => error ("Extra type variables on rhs: " ^ commas vs))
   601           in (constrs @ [((if flat_names then Sign.full_name tmp_thy else
   594           in (constrs @ [((if #flat_names config then Sign.full_name tmp_thy else
   602                 Sign.full_name_path tmp_thy tname')
   595                 Sign.full_name_path tmp_thy tname')
   603                   (Binding.map_name (Syntax.const_name mx') cname),
   596                   (Binding.map_name (Syntax.const_name mx') cname),
   604                    map (dtyp_of_typ new_dts) cargs')],
   597                    map (dtyp_of_typ new_dts) cargs')],
   605               constr_syntax' @ [(cname, mx')], sorts'')
   598               constr_syntax' @ [(cname, mx')], sorts'')
   606           end handle ERROR msg => cat_error msg
   599           end handle ERROR msg => cat_error msg
   624       fold prep_dt_spec (dts ~~ new_type_names) ([], [], [], 0);
   617       fold prep_dt_spec (dts ~~ new_type_names) ([], [], [], 0);
   625     val sorts = sorts' @ (map (rpair (Sign.defaultS tmp_thy)) (tyvars \\ map fst sorts'));
   618     val sorts = sorts' @ (map (rpair (Sign.defaultS tmp_thy)) (tyvars \\ map fst sorts'));
   626     val dt_info = get_datatypes thy;
   619     val dt_info = get_datatypes thy;
   627     val (descr, _) = unfold_datatypes tmp_thy dts' sorts dt_info dts' i;
   620     val (descr, _) = unfold_datatypes tmp_thy dts' sorts dt_info dts' i;
   628     val _ = check_nonempty descr handle (exn as Datatype_Empty s) =>
   621     val _ = check_nonempty descr handle (exn as Datatype_Empty s) =>
   629       if err then error ("Nonemptiness check failed for datatype " ^ s)
   622       if #strict config then error ("Nonemptiness check failed for datatype " ^ s)
   630       else raise exn;
   623       else raise exn;
   631 
   624 
   632     val descr' = flat descr;
   625     val descr' = flat descr;
   633     val case_names_induct = mk_case_names_induct descr';
   626     val case_names_induct = mk_case_names_induct descr';
   634     val case_names_exhausts = mk_case_names_exhausts descr' (map #1 new_dts);
   627     val case_names_exhausts = mk_case_names_exhausts descr' (map #1 new_dts);
   635   in
   628   in
   636     add_datatype_def
   629     add_datatype_def
   637       flat_names new_type_names descr sorts types_syntax constr_syntax dt_info
   630       (config : datatype_config) new_type_names descr sorts types_syntax constr_syntax dt_info
   638       case_names_induct case_names_exhausts thy
   631       case_names_induct case_names_exhausts thy
   639   end;
   632   end;
   640 
   633 
   641 val add_datatype = gen_add_datatype cert_typ;
   634 val add_datatype = gen_add_datatype cert_typ;
   642 val add_datatype_cmd = gen_add_datatype read_typ true;
   635 val add_datatype_cmd = gen_add_datatype read_typ default_datatype_config;
   643 
   636 
   644 
   637 
   645 
   638 
   646 (** package setup **)
   639 (** package setup **)
   647 
   640 
   666   let
   659   let
   667     val names = map
   660     val names = map
   668       (fn ((((NONE, _), t), _), _) => Binding.name_of t | ((((SOME t, _), _), _), _) => t) args;
   661       (fn ((((NONE, _), t), _), _) => Binding.name_of t | ((((SOME t, _), _), _), _) => t) args;
   669     val specs = map (fn ((((_, vs), t), mx), cons) =>
   662     val specs = map (fn ((((_, vs), t), mx), cons) =>
   670       (vs, t, mx, map (fn ((x, y), z) => (x, y, z)) cons)) args;
   663       (vs, t, mx, map (fn ((x, y), z) => (x, y, z)) cons)) args;
   671   in snd o add_datatype_cmd false names specs end;
   664   in snd o add_datatype_cmd names specs end;
   672 
   665 
   673 val _ =
   666 val _ =
   674   OuterSyntax.command "datatype" "define inductive datatypes" K.thy_decl
   667   OuterSyntax.command "datatype" "define inductive datatypes" K.thy_decl
   675     (P.and_list1 datatype_decl >> (Toplevel.theory o mk_datatype));
   668     (P.and_list1 datatype_decl >> (Toplevel.theory o mk_datatype));
   676 
   669