src/HOL/Tools/Datatype/datatype.ML
changeset 32717 0e787c721fa3
parent 32712 ec5976f4d3d8
child 32718 45929374f1bd
--- a/src/HOL/Tools/Datatype/datatype.ML	Sun Sep 27 19:58:24 2009 +0200
+++ b/src/HOL/Tools/Datatype/datatype.ML	Sun Sep 27 20:15:45 2009 +0200
@@ -39,8 +39,9 @@
 
 open DatatypeAux;
 
+(** theory data **)
 
-(* theory data *)
+(* data management *)
 
 structure DatatypesData = TheoryDataFun
 (
@@ -62,13 +63,16 @@
 );
 
 val get_all = #types o DatatypesData.get;
-val map_datatypes = DatatypesData.map;
-
+val get_info = Symtab.lookup o get_all;
+fun the_info thy name = (case get_info thy name of
+      SOME info => info
+    | NONE => error ("Unknown datatype " ^ quote name));
 
-(** theory information about datatypes **)
+val info_of_constr = Symtab.lookup o #constrs o DatatypesData.get;
+val info_of_case = Symtab.lookup o #cases o DatatypesData.get;
 
-fun put_dt_infos (dt_infos : (string * info) list) =
-  map_datatypes (fn {types, constrs, cases} =>
+fun register (dt_infos : (string * info) list) =
+  DatatypesData.map (fn {types, constrs, cases} =>
     {types = fold Symtab.update dt_infos types,
      constrs = fold Symtab.default (*conservative wrt. overloaded constructors*)
        (maps (fn (_, info as {descr, index, ...}) => map (rpair info o fst)
@@ -77,19 +81,7 @@
        (map (fn (_, info as {case_name, ...}) => (case_name, info)) dt_infos)
        cases});
 
-val get_info = Symtab.lookup o get_all;
-
-fun the_info thy name = (case get_info thy name of
-      SOME info => info
-    | NONE => error ("Unknown datatype " ^ quote name));
-
-val info_of_constr = Symtab.lookup o #constrs o DatatypesData.get;
-val info_of_case = Symtab.lookup o #cases o DatatypesData.get;
-
-fun get_info_descr thy dtco =
-  get_info thy dtco
-  |> Option.map (fn info as { descr, index, ... } =>
-       (info, (((fn SOME (_, dtys, cos) => (dtys, cos)) o AList.lookup (op =) descr) index)));
+(* complex queries *)
 
 fun the_spec thy dtco =
   let
@@ -143,69 +135,9 @@
     | NONE => NONE;
 
 
-(** induct method setup **)
-
-(* case names *)
-
-local
-
-fun dt_recs (DtTFree _) = []
-  | dt_recs (DtType (_, dts)) = maps dt_recs dts
-  | dt_recs (DtRec i) = [i];
-
-fun dt_cases (descr: descr) (_, args, constrs) =
-  let
-    fun the_bname i = Long_Name.base_name (#1 (the (AList.lookup (op =) descr i)));
-    val bnames = map the_bname (distinct (op =) (maps dt_recs args));
-  in map (fn (c, _) => space_implode "_" (Long_Name.base_name c :: bnames)) constrs end;
-
-
-fun induct_cases descr =
-  DatatypeProp.indexify_names (maps (dt_cases descr) (map #2 descr));
-
-fun exhaust_cases descr i = dt_cases descr (the (AList.lookup (op =) descr i));
-
-in
-
-fun mk_case_names_induct descr = RuleCases.case_names (induct_cases descr);
-
-fun mk_case_names_exhausts descr new =
-  map (RuleCases.case_names o exhaust_cases descr o #1)
-    (filter (fn ((_, (name, _, _))) => member (op =) new name) descr);
-
-end;
+(** various auxiliary **)
 
-fun add_rules simps case_thms rec_thms inject distinct
-                  weak_case_congs cong_att =
-  PureThy.add_thmss [((Binding.name "simps", simps), []),
-    ((Binding.empty, flat case_thms @
-          flat distinct @ rec_thms), [Simplifier.simp_add]),
-    ((Binding.empty, rec_thms), [Code.add_default_eqn_attribute]),
-    ((Binding.empty, flat inject), [iff_add]),
-    ((Binding.empty, map (fn th => th RS notE) (flat distinct)), [Classical.safe_elim NONE]),
-    ((Binding.empty, weak_case_congs), [cong_att])]
-  #> snd;
-
-
-(* add_cases_induct *)
-
-fun add_cases_induct infos inducts thy =
-  let
-    fun named_rules (name, {index, exhaust, ...}: info) =
-      [((Binding.empty, nth inducts index), [Induct.induct_type name]),
-       ((Binding.empty, exhaust), [Induct.cases_type name])];
-    fun unnamed_rule i =
-      ((Binding.empty, nth inducts i), [Thm.kind_internal, Induct.induct_type ""]);
-  in
-    thy |> PureThy.add_thms
-      (maps named_rules infos @
-        map unnamed_rule (length infos upto length inducts - 1)) |> snd
-    |> PureThy.add_thmss [((Binding.name "inducts", inducts), [])] |> snd
-  end;
-
-
-
-(**** simplification procedure for showing distinctness of constructors ****)
+(* simplification procedure for showing distinctness of constructors *)
 
 fun stripT (i, Type ("fun", [_, T])) = stripT (i + 1, T)
   | stripT p = p;
@@ -233,17 +165,21 @@
           etac FalseE 1]))
       end;
 
+fun get_constr thy dtco =
+  get_info thy dtco
+  |> Option.map (fn { descr, index, ... } => (#3 o the o AList.lookup (op =) descr) index);
+
 fun distinct_proc thy ss (t as Const ("op =", _) $ t1 $ t2) =
   (case (stripC (0, t1), stripC (0, t2)) of
      ((i, Const (cname1, T1)), (j, Const (cname2, T2))) =>
          (case (stripT (0, T1), stripT (0, T2)) of
             ((i', Type (tname1, _)), (j', Type (tname2, _))) =>
                 if tname1 = tname2 andalso not (cname1 = cname2) andalso i = i' andalso j = j' then
-                   (case (get_info_descr thy) tname1 of
-                      SOME (_, (_, constrs)) => let val cnames = map fst constrs
+                   (case get_constr thy tname1 of
+                      SOME constrs => let val cnames = map fst constrs
                         in if cname1 mem cnames andalso cname2 mem cnames then
                              SOME (distinct_rule thy ss tname1
-                               (Logic.mk_equals (t, Const ("False", HOLogic.boolT))))
+                               (Logic.mk_equals (t, HOLogic.false_const)))
                            else NONE
                         end
                     | NONE => NONE)
@@ -260,29 +196,7 @@
 val simproc_setup =
   Simplifier.map_simpset (fn ss => ss addsimprocs [distinct_simproc]);
 
-
-(**** translation rules for case ****)
-
-fun make_case ctxt = DatatypeCase.make_case
-  (info_of_constr (ProofContext.theory_of ctxt)) ctxt;
-
-fun strip_case ctxt = DatatypeCase.strip_case
-  (info_of_case (ProofContext.theory_of ctxt));
-
-fun add_case_tr' case_names thy =
-  Sign.add_advanced_trfuns ([], [],
-    map (fn case_name =>
-      let val case_name' = Sign.const_syntax_name thy case_name
-      in (case_name', DatatypeCase.case_tr' info_of_case case_name')
-      end) case_names, []) thy;
-
-val trfun_setup =
-  Sign.add_advanced_trfuns ([],
-    [("_case_syntax", DatatypeCase.case_tr true info_of_constr)],
-    [], []);
-
-
-(* prepare types *)
+(* prepare datatype specifications *)
 
 fun read_typ thy ((Ts, sorts), str) =
   let
@@ -302,8 +216,7 @@
        | dups => error ("Inconsistent sort constraints for " ^ commas dups))
   end;
 
-
-(**** make datatype info ****)
+(* arrange data entries *)
 
 fun make_dt_info alt_names descr sorts inducts reccomb_names rec_thms
     ((((((((((i, (_, (tname, _, _))), case_name), case_thms),
@@ -326,62 +239,113 @@
     case_cong = case_cong,
     weak_case_cong = weak_case_cong});
 
+(* case names *)
+
+local
+
+fun dt_recs (DtTFree _) = []
+  | dt_recs (DtType (_, dts)) = maps dt_recs dts
+  | dt_recs (DtRec i) = [i];
+
+fun dt_cases (descr: descr) (_, args, constrs) =
+  let
+    fun the_bname i = Long_Name.base_name (#1 (the (AList.lookup (op =) descr i)));
+    val bnames = map the_bname (distinct (op =) (maps dt_recs args));
+  in map (fn (c, _) => space_implode "_" (Long_Name.base_name c :: bnames)) constrs end;
+
+fun induct_cases descr =
+  DatatypeProp.indexify_names (maps (dt_cases descr) (map #2 descr));
+
+fun exhaust_cases descr i = dt_cases descr (the (AList.lookup (op =) descr i));
+
+in
+
+fun mk_case_names_induct descr = RuleCases.case_names (induct_cases descr);
+
+fun mk_case_names_exhausts descr new =
+  map (RuleCases.case_names o exhaust_cases descr o #1)
+    (filter (fn ((_, (name, _, _))) => member (op =) new name) descr);
+
+end;
+
+(* note rules and interpretations *)
+
+fun add_rules simps case_thms rec_thms inject distinct
+                  weak_case_congs cong_att =
+  PureThy.add_thmss [((Binding.name "simps", simps), []),
+    ((Binding.empty, flat case_thms @
+          flat distinct @ rec_thms), [Simplifier.simp_add]),
+    ((Binding.empty, rec_thms), [Code.add_default_eqn_attribute]),
+    ((Binding.empty, flat inject), [iff_add]),
+    ((Binding.empty, map (fn th => th RS notE) (flat distinct)), [Classical.safe_elim NONE]),
+    ((Binding.empty, weak_case_congs), [cong_att])]
+  #> snd;
+
+fun add_cases_induct infos inducts thy =
+  let
+    fun named_rules (name, {index, exhaust, ...}: info) =
+      [((Binding.empty, nth inducts index), [Induct.induct_type name]),
+       ((Binding.empty, exhaust), [Induct.cases_type name])];
+    fun unnamed_rule i =
+      ((Binding.empty, nth inducts i), [Thm.kind_internal, Induct.induct_type ""]);
+  in
+    thy |> PureThy.add_thms
+      (maps named_rules infos @
+        map unnamed_rule (length infos upto length inducts - 1)) |> snd
+    |> PureThy.add_thmss [((Binding.name "inducts", inducts), [])] |> snd
+  end;
+
 structure DatatypeInterpretation = InterpretationFun
   (type T = config * string list val eq: T * T -> bool = eq_snd op =);
 fun interpretation f = DatatypeInterpretation.interpretation (uncurry f);
 
+(* translation rules for case *)
 
-(******************* definitional introduction of datatypes *******************)
+fun make_case ctxt = DatatypeCase.make_case
+  (info_of_constr (ProofContext.theory_of ctxt)) ctxt;
+
+fun strip_case ctxt = DatatypeCase.strip_case
+  (info_of_case (ProofContext.theory_of ctxt));
 
-fun add_datatype_def (config : config) new_type_names descr sorts types_syntax constr_syntax dt_info
-    case_names_induct case_names_exhausts thy =
-  let
-    val _ = message config ("Proofs for datatype(s) " ^ commas_quote new_type_names);
+fun add_case_tr' case_names thy =
+  Sign.add_advanced_trfuns ([], [],
+    map (fn case_name =>
+      let val case_name' = Sign.const_syntax_name thy case_name
+      in (case_name', DatatypeCase.case_tr' info_of_case case_name')
+      end) case_names, []) thy;
 
-    val ((inject, distinct, dist_rewrites, simproc_dists, induct), thy2) = thy |>
-      DatatypeRepProofs.representation_proofs config dt_info new_type_names descr sorts
-        types_syntax constr_syntax case_names_induct;
-    val inducts = Project_Rule.projections (ProofContext.init thy2) induct;
+val trfun_setup =
+  Sign.add_advanced_trfuns ([],
+    [("_case_syntax", DatatypeCase.case_tr true info_of_constr)],
+    [], []);
+
+(* document antiquotation *)
 
-    val (casedist_thms, thy3) = DatatypeAbsProofs.prove_casedist_thms config new_type_names descr
-      sorts induct case_names_exhausts thy2;
-    val ((reccomb_names, rec_thms), thy4) = DatatypeAbsProofs.prove_primrec_thms
-      config new_type_names descr sorts dt_info inject dist_rewrites
-      (Simplifier.theory_context thy3 dist_ss) induct thy3;
-    val ((case_thms, case_names), thy6) = DatatypeAbsProofs.prove_case_thms
-      config new_type_names descr sorts reccomb_names rec_thms thy4;
-    val (split_thms, thy7) = DatatypeAbsProofs.prove_split_thms config new_type_names
-      descr sorts inject dist_rewrites casedist_thms case_thms thy6;
-    val (nchotomys, thy8) = DatatypeAbsProofs.prove_nchotomys config new_type_names
-      descr sorts casedist_thms thy7;
-    val (case_congs, thy9) = DatatypeAbsProofs.prove_case_congs new_type_names
-      descr sorts nchotomys case_thms thy8;
-    val (weak_case_congs, thy10) = DatatypeAbsProofs.prove_weak_case_congs new_type_names
-      descr sorts thy9;
-
-    val dt_infos = map
-      (make_dt_info (SOME new_type_names) (flat descr) sorts (inducts, induct) reccomb_names rec_thms)
-      ((0 upto length (hd descr) - 1) ~~ hd descr ~~ case_names ~~ case_thms ~~
-        casedist_thms ~~ simproc_dists ~~ inject ~~ split_thms ~~ nchotomys ~~ case_congs ~~ weak_case_congs);
-
-    val simps = flat (distinct @ inject @ case_thms) @ rec_thms;
-    val dt_names = map fst dt_infos;
-
-    val thy12 =
-      thy10
-      |> add_case_tr' case_names
-      |> Sign.add_path (space_implode "_" new_type_names)
-      |> add_rules simps case_thms rec_thms inject distinct
-          weak_case_congs (Simplifier.attrib (op addcongs))
-      |> put_dt_infos dt_infos
-      |> add_cases_induct dt_infos inducts
-      |> Sign.parent_path
-      |> store_thmss "splits" new_type_names (map (fn (x, y) => [x, y]) split_thms) |> snd
-      |> DatatypeInterpretation.data (config, map fst dt_infos);
-  in (dt_names, thy12) end;
+val _ = ThyOutput.antiquotation "datatype" Args.tyname
+  (fn {source = src, context = ctxt, ...} => fn dtco =>
+    let
+      val thy = ProofContext.theory_of ctxt;
+      val (vs, cos) = the_spec thy dtco;
+      val ty = Type (dtco, map TFree vs);
+      fun pretty_typ_bracket (ty as Type (_, _ :: _)) =
+            Pretty.enclose "(" ")" [Syntax.pretty_typ ctxt ty]
+        | pretty_typ_bracket ty =
+            Syntax.pretty_typ ctxt ty;
+      fun pretty_constr (co, tys) =
+        (Pretty.block o Pretty.breaks)
+          (Syntax.pretty_term ctxt (Const (co, tys ---> ty)) ::
+            map pretty_typ_bracket tys);
+      val pretty_datatype =
+        Pretty.block
+          (Pretty.command "datatype" :: Pretty.brk 1 ::
+           Syntax.pretty_typ ctxt ty ::
+           Pretty.str " =" :: Pretty.brk 1 ::
+           flat (separate [Pretty.brk 1, Pretty.str "| "]
+             (map (single o pretty_constr) cos)));
+    in ThyOutput.output (ThyOutput.maybe_pretty_source (K pretty_datatype) src [()]) end);
 
 
-(*********************** declare existing type as datatype *********************)
+(** declare existing type as datatype **)
 
 fun prove_rep_datatype (config : config) alt_names new_type_names descr sorts induct inject half_distinct thy =
   let
@@ -441,7 +405,7 @@
       |> add_case_tr' case_names
       |> add_rules simps case_thms rec_thms inject distinct
            weak_case_congs (Simplifier.attrib (op addcongs))
-      |> put_dt_infos dt_infos
+      |> register dt_infos
       |> add_cases_induct dt_infos inducts
       |> Sign.parent_path
       |> store_thmss "splits" new_type_names (map (fn (x, y) => [x, y]) split_thms)
@@ -518,8 +482,54 @@
 val rep_datatype_cmd = gen_rep_datatype Syntax.read_term_global default_config (K I);
 
 
+(** definitional introduction of datatypes **)
 
-(******************************** add datatype ********************************)
+fun add_datatype_def (config : config) new_type_names descr sorts types_syntax constr_syntax dt_info
+    case_names_induct case_names_exhausts thy =
+  let
+    val _ = message config ("Proofs for datatype(s) " ^ commas_quote new_type_names);
+
+    val ((inject, distinct, dist_rewrites, simproc_dists, induct), thy2) = thy |>
+      DatatypeRepProofs.representation_proofs config dt_info new_type_names descr sorts
+        types_syntax constr_syntax case_names_induct;
+    val inducts = Project_Rule.projections (ProofContext.init thy2) induct;
+
+    val (casedist_thms, thy3) = DatatypeAbsProofs.prove_casedist_thms config new_type_names descr
+      sorts induct case_names_exhausts thy2;
+    val ((reccomb_names, rec_thms), thy4) = DatatypeAbsProofs.prove_primrec_thms
+      config new_type_names descr sorts dt_info inject dist_rewrites
+      (Simplifier.theory_context thy3 dist_ss) induct thy3;
+    val ((case_thms, case_names), thy6) = DatatypeAbsProofs.prove_case_thms
+      config new_type_names descr sorts reccomb_names rec_thms thy4;
+    val (split_thms, thy7) = DatatypeAbsProofs.prove_split_thms config new_type_names
+      descr sorts inject dist_rewrites casedist_thms case_thms thy6;
+    val (nchotomys, thy8) = DatatypeAbsProofs.prove_nchotomys config new_type_names
+      descr sorts casedist_thms thy7;
+    val (case_congs, thy9) = DatatypeAbsProofs.prove_case_congs new_type_names
+      descr sorts nchotomys case_thms thy8;
+    val (weak_case_congs, thy10) = DatatypeAbsProofs.prove_weak_case_congs new_type_names
+      descr sorts thy9;
+
+    val dt_infos = map
+      (make_dt_info (SOME new_type_names) (flat descr) sorts (inducts, induct) reccomb_names rec_thms)
+      ((0 upto length (hd descr) - 1) ~~ hd descr ~~ case_names ~~ case_thms ~~
+        casedist_thms ~~ simproc_dists ~~ inject ~~ split_thms ~~ nchotomys ~~ case_congs ~~ weak_case_congs);
+
+    val simps = flat (distinct @ inject @ case_thms) @ rec_thms;
+    val dt_names = map fst dt_infos;
+
+    val thy12 =
+      thy10
+      |> add_case_tr' case_names
+      |> Sign.add_path (space_implode "_" new_type_names)
+      |> add_rules simps case_thms rec_thms inject distinct
+          weak_case_congs (Simplifier.attrib (op addcongs))
+      |> register dt_infos
+      |> add_cases_induct dt_infos inducts
+      |> Sign.parent_path
+      |> store_thmss "splits" new_type_names (map (fn (x, y) => [x, y]) split_thms) |> snd
+      |> DatatypeInterpretation.data (config, map fst dt_infos);
+  in (dt_names, thy12) end;
 
 fun gen_add_datatype prep_typ (config : config) new_type_names dts thy =
   let
@@ -596,7 +606,6 @@
 val datatype_cmd = snd ooo gen_add_datatype read_typ default_config;
 
 
-
 (** package setup **)
 
 (* setup theory *)
@@ -607,7 +616,6 @@
   trfun_setup #>
   DatatypeInterpretation.init;
 
-
 (* outer syntax *)
 
 local
@@ -642,31 +650,5 @@
 
 end;
 
-
-(* document antiquotation *)
-
-val _ = ThyOutput.antiquotation "datatype" Args.tyname
-  (fn {source = src, context = ctxt, ...} => fn dtco =>
-    let
-      val thy = ProofContext.theory_of ctxt;
-      val (vs, cos) = the_spec thy dtco;
-      val ty = Type (dtco, map TFree vs);
-      fun pretty_typ_bracket (ty as Type (_, _ :: _)) =
-            Pretty.enclose "(" ")" [Syntax.pretty_typ ctxt ty]
-        | pretty_typ_bracket ty =
-            Syntax.pretty_typ ctxt ty;
-      fun pretty_constr (co, tys) =
-        (Pretty.block o Pretty.breaks)
-          (Syntax.pretty_term ctxt (Const (co, tys ---> ty)) ::
-            map pretty_typ_bracket tys);
-      val pretty_datatype =
-        Pretty.block
-          (Pretty.command "datatype" :: Pretty.brk 1 ::
-           Syntax.pretty_typ ctxt ty ::
-           Pretty.str " =" :: Pretty.brk 1 ::
-           flat (separate [Pretty.brk 1, Pretty.str "| "]
-             (map (single o pretty_constr) cos)));
-    in ThyOutput.output (ThyOutput.maybe_pretty_source (K pretty_datatype) src [()]) end);
-
 end;