src/HOL/Tools/Datatype/datatype_rep_proofs.ML
changeset 33959 2afc55e8ed27
parent 33957 e9afca2118d4
child 33963 977b94b64905
equal deleted inserted replaced
33958:a57f4c9d0a19 33959:2afc55e8ed27
     1 (*  Title:      HOL/Tools/datatype_rep_proofs.ML
     1 (*  Title:      HOL/Tools/datatype_rep_proofs.ML
     2     Author:     Stefan Berghofer, TU Muenchen
     2     Author:     Stefan Berghofer, TU Muenchen
     3 
     3 
     4 Definitional introduction of datatypes
     4 Definitional introduction of datatypes with proof of characteristic theorems:
     5 Proof of characteristic theorems:
       
     6 
     5 
     7  - injectivity of constructors
     6  - injectivity of constructors
     8  - distinctness of constructors
     7  - distinctness of constructors
     9  - induction theorem
     8  - induction theorem
    10 *)
     9 *)
    11 
    10 
    12 signature DATATYPE_REP_PROOFS =
    11 signature DATATYPE_REP_PROOFS =
    13 sig
    12 sig
    14   include DATATYPE_COMMON
    13   include DATATYPE_COMMON
    15   val representation_proofs : config -> info Symtab.table ->
    14   val add_datatype : config -> string list -> (string list * binding * mixfix *
    16     string list -> descr list -> (string * sort) list ->
    15     (binding * typ list * mixfix) list) list -> theory -> string list * theory
    17       (binding * mixfix) list -> (binding * mixfix) list list -> attribute
    16   val datatype_cmd : string list -> (string list * binding * mixfix *
    18         -> theory -> (thm list list * thm list list * thm) * theory
    17     (binding * string list * mixfix) list) list -> theory -> theory
    19 end;
    18 end;
    20 
    19 
    21 structure DatatypeRepProofs : DATATYPE_REP_PROOFS =
    20 structure DatatypeRepProofs : DATATYPE_REP_PROOFS =
    22 struct
    21 struct
    23 
    22 
       
    23 (** auxiliary **)
       
    24 
    24 open DatatypeAux;
    25 open DatatypeAux;
    25 
    26 
    26 val (_ $ (_ $ (_ $ (distinct_f $ _) $ _))) = hd (prems_of distinct_lemma);
    27 val (_ $ (_ $ (_ $ (distinct_f $ _) $ _))) = hd (prems_of distinct_lemma);
    27 
    28 
    28 val collect_simp = rewrite_rule [mk_meta_eq mem_Collect_eq];
    29 val collect_simp = rewrite_rule [mk_meta_eq mem_Collect_eq];
    29 
       
    30 
       
    31 (** theory context references **)
       
    32 
    30 
    33 fun exh_thm_of (dt_info : info Symtab.table) tname =
    31 fun exh_thm_of (dt_info : info Symtab.table) tname =
    34   #exhaust (the (Symtab.lookup dt_info tname));
    32   #exhaust (the (Symtab.lookup dt_info tname));
    35 
    33 
    36 (******************************************************************************)
    34 val node_name = @{type_name "Datatype.node"};
       
    35 val In0_name = @{const_name "Datatype.In0"};
       
    36 val In1_name = @{const_name "Datatype.In1"};
       
    37 val Scons_name = @{const_name "Datatype.Scons"};
       
    38 val Leaf_name = @{const_name "Datatype.Leaf"};
       
    39 val Numb_name = @{const_name "Datatype.Numb"};
       
    40 val Lim_name = @{const_name "Datatype.Lim"};
       
    41 val Suml_name = @{const_name "Sum_Type.Suml"};
       
    42 val Sumr_name = @{const_name "Sum_Type.Sumr"};
       
    43 
       
    44 val In0_inject = @{thm In0_inject};
       
    45 val In1_inject = @{thm In1_inject};
       
    46 val Scons_inject = @{thm Scons_inject};
       
    47 val Leaf_inject = @{thm Leaf_inject};
       
    48 val In0_eq = @{thm In0_eq};
       
    49 val In1_eq = @{thm In1_eq};
       
    50 val In0_not_In1 = @{thm In0_not_In1};
       
    51 val In1_not_In0 = @{thm In1_not_In0};
       
    52 val Lim_inject = @{thm Lim_inject};
       
    53 val Suml_inject = @{thm Suml_inject};
       
    54 val Sumr_inject = @{thm Sumr_inject};
       
    55 
       
    56 
       
    57 
       
    58 (** proof of characteristic theorems **)
    37 
    59 
    38 fun representation_proofs (config : config) (dt_info : info Symtab.table)
    60 fun representation_proofs (config : config) (dt_info : info Symtab.table)
    39       new_type_names descr sorts types_syntax constr_syntax case_names_induct thy =
    61       new_type_names descr sorts types_syntax constr_syntax case_names_induct thy =
    40   let
    62   let
    41     val Datatype_thy = ThyInfo.the_theory "Datatype" thy;
       
    42     val node_name = "Datatype.node";
       
    43     val In0_name = "Datatype.In0";
       
    44     val In1_name = "Datatype.In1";
       
    45     val Scons_name = "Datatype.Scons";
       
    46     val Leaf_name = "Datatype.Leaf";
       
    47     val Numb_name = "Datatype.Numb";
       
    48     val Lim_name = "Datatype.Lim";
       
    49     val Suml_name = "Datatype.Suml";
       
    50     val Sumr_name = "Datatype.Sumr";
       
    51 
       
    52     val [In0_inject, In1_inject, Scons_inject, Leaf_inject,
       
    53          In0_eq, In1_eq, In0_not_In1, In1_not_In0,
       
    54          Lim_inject, Suml_inject, Sumr_inject] = map (PureThy.get_thm Datatype_thy)
       
    55           ["In0_inject", "In1_inject", "Scons_inject", "Leaf_inject",
       
    56            "In0_eq", "In1_eq", "In0_not_In1", "In1_not_In0",
       
    57            "Lim_inject", "Suml_inject", "Sumr_inject"];
       
    58 
       
    59     val descr' = flat descr;
    63     val descr' = flat descr;
    60 
       
    61     val big_name = space_implode "_" new_type_names;
    64     val big_name = space_implode "_" new_type_names;
    62     val thy1 = Sign.add_path big_name thy;
    65     val thy1 = Sign.add_path big_name thy;
    63     val big_rec_name = big_name ^ "_rep_set";
    66     val big_rec_name = big_name ^ "_rep_set";
    64     val rep_set_names' =
    67     val rep_set_names' =
    65       (if length descr' = 1 then [big_rec_name] else
    68       (if length descr' = 1 then [big_rec_name] else
    81     val sumT = if null leafTs then HOLogic.unitT
    84     val sumT = if null leafTs then HOLogic.unitT
    82       else Balanced_Tree.make (fn (T, U) => Type ("+", [T, U])) leafTs;
    85       else Balanced_Tree.make (fn (T, U) => Type ("+", [T, U])) leafTs;
    83     val Univ_elT = HOLogic.mk_setT (Type (node_name, [sumT, branchT]));
    86     val Univ_elT = HOLogic.mk_setT (Type (node_name, [sumT, branchT]));
    84     val UnivT = HOLogic.mk_setT Univ_elT;
    87     val UnivT = HOLogic.mk_setT Univ_elT;
    85     val UnivT' = Univ_elT --> HOLogic.boolT;
    88     val UnivT' = Univ_elT --> HOLogic.boolT;
    86     val Collect = Const ("Collect", UnivT' --> UnivT);
    89     val Collect = Const (@{const_name Collect}, UnivT' --> UnivT);
    87 
    90 
    88     val In0 = Const (In0_name, Univ_elT --> Univ_elT);
    91     val In0 = Const (In0_name, Univ_elT --> Univ_elT);
    89     val In1 = Const (In1_name, Univ_elT --> Univ_elT);
    92     val In1 = Const (In1_name, Univ_elT --> Univ_elT);
    90     val Leaf = Const (Leaf_name, sumT --> Univ_elT);
    93     val Leaf = Const (Leaf_name, sumT --> Univ_elT);
    91     val Lim = Const (Lim_name, (branchT --> Univ_elT) --> Univ_elT);
    94     val Lim = Const (Lim_name, (branchT --> Univ_elT) --> Univ_elT);
    98           if n = 1 then x else
   101           if n = 1 then x else
    99           let val n2 = n div 2;
   102           let val n2 = n div 2;
   100               val Type (_, [T1, T2]) = T
   103               val Type (_, [T1, T2]) = T
   101           in
   104           in
   102             if i <= n2 then
   105             if i <= n2 then
   103               Const ("Sum_Type.Inl", T1 --> T) $ (mk_inj' T1 n2 i)
   106               Const (@{const_name "Sum_Type.Inl"}, T1 --> T) $ (mk_inj' T1 n2 i)
   104             else
   107             else
   105               Const ("Sum_Type.Inr", T2 --> T) $ (mk_inj' T2 (n - n2) (i - n2))
   108               Const (@{const_name "Sum_Type.Inr"}, T2 --> T) $ (mk_inj' T2 (n - n2) (i - n2))
   106           end
   109           end
   107       in mk_inj' sumT (length leafTs) (1 + find_index (fn T'' => T'' = T') leafTs)
   110       in mk_inj' sumT (length leafTs) (1 + find_index (fn T'' => T'' = T') leafTs)
   108       end;
   111       end;
   109 
   112 
   110     (* make injections for constructors *)
   113     (* make injections for constructors *)
   627 
   630 
   628   in
   631   in
   629     ((constr_inject', distinct_thms', dt_induct'), thy7)
   632     ((constr_inject', distinct_thms', dt_induct'), thy7)
   630   end;
   633   end;
   631 
   634 
       
   635 
       
   636 
       
   637 (** definitional introduction of datatypes **)
       
   638 
       
   639 fun gen_add_datatype prep_typ config new_type_names dts thy =
       
   640   let
       
   641     val _ = Theory.requires thy "Datatype" "datatype definitions";
       
   642 
       
   643     (* this theory is used just for parsing *)
       
   644     val tmp_thy = thy |>
       
   645       Theory.copy |>
       
   646       Sign.add_types (map (fn (tvs, tname, mx, _) =>
       
   647         (tname, length tvs, mx)) dts);
       
   648 
       
   649     val (tyvars, _, _, _)::_ = dts;
       
   650     val (new_dts, types_syntax) = ListPair.unzip (map (fn (tvs, tname, mx, _) =>
       
   651       let val full_tname = Sign.full_name tmp_thy (Binding.map_name (Syntax.type_name mx) tname)
       
   652       in
       
   653         (case duplicates (op =) tvs of
       
   654           [] =>
       
   655             if eq_set (op =) (tyvars, tvs) then ((full_tname, tvs), (tname, mx))
       
   656             else error ("Mutually recursive datatypes must have same type parameters")
       
   657         | dups => error ("Duplicate parameter(s) for datatype " ^ quote (Binding.str_of tname) ^
       
   658             " : " ^ commas dups))
       
   659       end) dts);
       
   660     val dt_names = map fst new_dts;
       
   661 
       
   662     val _ =
       
   663       (case duplicates (op =) (map fst new_dts) @ duplicates (op =) new_type_names of
       
   664         [] => ()
       
   665       | dups => error ("Duplicate datatypes: " ^ commas dups));
       
   666 
       
   667     fun prep_dt_spec (tvs, tname, mx, constrs) tname' (dts', constr_syntax, sorts, i) =
       
   668       let
       
   669         fun prep_constr (cname, cargs, mx') (constrs, constr_syntax', sorts') =
       
   670           let
       
   671             val (cargs', sorts'') = fold_map (prep_typ tmp_thy) cargs sorts';
       
   672             val _ =
       
   673               (case subtract (op =) tvs (fold (curry OldTerm.add_typ_tfree_names) cargs' []) of
       
   674                 [] => ()
       
   675               | vs => error ("Extra type variables on rhs: " ^ commas vs))
       
   676           in (constrs @ [(Sign.full_name_path tmp_thy tname'
       
   677                   (Binding.map_name (Syntax.const_name mx') cname),
       
   678                    map (dtyp_of_typ new_dts) cargs')],
       
   679               constr_syntax' @ [(cname, mx')], sorts'')
       
   680           end handle ERROR msg => cat_error msg
       
   681            ("The error above occured in constructor " ^ quote (Binding.str_of cname) ^
       
   682             " of datatype " ^ quote (Binding.str_of tname));
       
   683 
       
   684         val (constrs', constr_syntax', sorts') =
       
   685           fold prep_constr constrs ([], [], sorts)
       
   686 
       
   687       in
       
   688         case duplicates (op =) (map fst constrs') of
       
   689            [] =>
       
   690              (dts' @ [(i, (Sign.full_name tmp_thy (Binding.map_name (Syntax.type_name mx) tname),
       
   691                 map DtTFree tvs, constrs'))],
       
   692               constr_syntax @ [constr_syntax'], sorts', i + 1)
       
   693          | dups => error ("Duplicate constructors " ^ commas dups ^
       
   694              " in datatype " ^ quote (Binding.str_of tname))
       
   695       end;
       
   696 
       
   697     val (dts', constr_syntax, sorts', i) =
       
   698       fold2 prep_dt_spec dts new_type_names ([], [], [], 0);
       
   699     val sorts = sorts' @ map (rpair (Sign.defaultS tmp_thy)) (subtract (op =) (map fst sorts') tyvars);
       
   700     val dt_info = Datatype.get_all thy;
       
   701     val (descr, _) = unfold_datatypes tmp_thy dts' sorts dt_info dts' i;
       
   702     val _ = check_nonempty descr handle (exn as Datatype_Empty s) =>
       
   703       if #strict config then error ("Nonemptiness check failed for datatype " ^ s)
       
   704       else raise exn;
       
   705 
       
   706     val _ = message config ("Constructing datatype(s) " ^ commas_quote new_type_names);
       
   707 
       
   708   in
       
   709     thy
       
   710     |> representation_proofs config dt_info new_type_names descr sorts
       
   711         types_syntax constr_syntax (Datatype.mk_case_names_induct (flat descr))
       
   712     |-> (fn (inject, distinct, induct) => Datatype.derive_datatype_props
       
   713         config dt_names (SOME new_type_names) descr sorts
       
   714         induct inject distinct)
       
   715   end;
       
   716 
       
   717 val add_datatype = gen_add_datatype Datatype.cert_typ;
       
   718 val datatype_cmd = snd ooo gen_add_datatype Datatype.read_typ default_config;
       
   719 
       
   720 local
       
   721 
       
   722 structure P = OuterParse and K = OuterKeyword
       
   723 
       
   724 fun prep_datatype_decls args =
       
   725   let
       
   726     val names = map
       
   727       (fn ((((NONE, _), t), _), _) => Binding.name_of t | ((((SOME t, _), _), _), _) => t) args;
       
   728     val specs = map (fn ((((_, vs), t), mx), cons) =>
       
   729       (vs, t, mx, map (fn ((x, y), z) => (x, y, z)) cons)) args;
       
   730   in (names, specs) end;
       
   731 
       
   732 val parse_datatype_decl =
       
   733   (Scan.option (P.$$$ "(" |-- P.name --| P.$$$ ")") -- P.type_args -- P.binding -- P.opt_infix --
       
   734     (P.$$$ "=" |-- P.enum1 "|" (P.binding -- Scan.repeat P.typ -- P.opt_mixfix)));
       
   735 
       
   736 val parse_datatype_decls = P.and_list1 parse_datatype_decl >> prep_datatype_decls;
       
   737 
       
   738 in
       
   739 
       
   740 val _ =
       
   741   OuterSyntax.command "datatype" "define inductive datatypes" K.thy_decl
       
   742     (parse_datatype_decls >> (fn (names, specs) => Toplevel.theory (datatype_cmd names specs)));
       
   743 
   632 end;
   744 end;
       
   745 
       
   746 end;
       
   747 
       
   748 structure Datatype =
       
   749 struct
       
   750 
       
   751 open Datatype;
       
   752 open DatatypeRepProofs;
       
   753 
       
   754 end;