cleanup in datatype package
authorhaftmann
Thu, 06 Apr 2006 16:10:46 +0200
changeset 19346 c4c003abd830
parent 19345 73439b467e75
child 19347 e2e709f3f955
cleanup in datatype package
src/HOL/Tools/datatype_codegen.ML
src/HOL/Tools/datatype_package.ML
src/HOL/Tools/datatype_rep_proofs.ML
src/HOL/Tools/refute.ML
--- a/src/HOL/Tools/datatype_codegen.ML	Thu Apr 06 16:10:22 2006 +0200
+++ b/src/HOL/Tools/datatype_codegen.ML	Thu Apr 06 16:10:46 2006 +0200
@@ -7,6 +7,11 @@
 
 signature DATATYPE_CODEGEN =
 sig
+  val get_datatype_spec_thms: theory -> string
+    -> (((string * sort) list * (string * typ list) list) * tactic) option
+  val get_case_const_data: theory -> string -> (string * int) list option
+  val get_all_datatype_cons: theory -> (string * string) list
+  val get_datatype_case_consts: theory -> string list
   val setup: theory -> theory
 end;
 
@@ -297,19 +302,58 @@
   | datatype_tycodegen _ _ _ _ _ _ _ = NONE;
 
 
+(** code 2nd generation **)
+
+fun datatype_tac thy dtco =
+  let
+    val ctxt = Context.init_proof thy;
+    val inject = (#inject o DatatypePackage.the_datatype thy) dtco;
+    val simpset = Simplifier.context ctxt
+      (MetaSimplifier.empty_ss addsimprocs [distinct_simproc]);
+  in
+    (TRY o ALLGOALS o resolve_tac) [HOL.eq_reflection]
+    THEN (
+      (ALLGOALS o resolve_tac) (eqTrueI :: inject)
+      ORELSE (ALLGOALS o simp_tac) simpset
+    )
+    THEN (ALLGOALS o resolve_tac) [HOL.refl, Drule.reflexive_thm]
+  end;
+
+fun get_datatype_spec_thms thy dtco =
+  case DatatypePackage.get_datatype_spec thy dtco
+   of SOME vs_cos =>
+        SOME (vs_cos, datatype_tac thy dtco)
+    | NONE => NONE;
+
+fun get_all_datatype_cons thy =
+  Symtab.fold (fn (dtco, _) => fold
+    (fn (co, _) => cons (co, dtco))
+      ((snd o the oo DatatypePackage.get_datatype_spec) thy dtco))
+        (DatatypePackage.get_datatypes thy) [];
+
+fun get_case_const_data thy c =
+  case find_first (fn (_, {index, descr, case_name, ...}) =>
+      case_name = c
+    ) ((Symtab.dest o DatatypePackage.get_datatypes) thy)
+   of NONE => NONE
+    | SOME (_, {index, descr, ...}) =>
+        (SOME o map (apsnd length) o #3 o the o AList.lookup (op =) descr) index;
+
+fun get_datatype_case_consts thy =
+  Symtab.fold (fn (_, {case_name, ...}) => cons case_name)
+    (DatatypePackage.get_datatypes thy) [];
+
 val setup = 
   add_codegen "datatype" datatype_codegen #>
   add_tycodegen "datatype" datatype_tycodegen #>
+  CodegenTheorems.add_datatype_extr
+    get_datatype_spec_thms #>
   CodegenPackage.set_get_datatype
-    DatatypePackage.get_datatype #>
+    DatatypePackage.get_datatype_spec #>
   CodegenPackage.set_get_all_datatype_cons
-    DatatypePackage.get_all_datatype_cons #>
-  (fn thy => thy |> CodegenPackage.add_eqextr_default ("equality",
-    (CodegenPackage.eqextr_eq
-      DatatypePackage.get_eq_equations
-      (Sign.read_term thy "False")))) #>
+    get_all_datatype_cons #>
   CodegenPackage.ensure_datatype_case_consts
-    DatatypePackage.get_datatype_case_consts
-    DatatypePackage.get_case_const_data;
+    get_datatype_case_consts
+    get_case_const_data;
 
 end;
--- a/src/HOL/Tools/datatype_package.ML	Thu Apr 06 16:10:22 2006 +0200
+++ b/src/HOL/Tools/datatype_package.ML	Thu Apr 06 16:10:46 2006 +0200
@@ -63,17 +63,11 @@
        size : thm list,
        simps : thm list} * theory
   val get_datatypes : theory -> DatatypeAux.datatype_info Symtab.table
+  val get_datatype : theory -> string -> DatatypeAux.datatype_info option
+  val the_datatype : theory -> string -> DatatypeAux.datatype_info
+  val get_datatype_spec : theory -> string -> ((string * sort) list * (string * typ list) list) option
+  val get_datatype_constrs : theory -> string -> (string * typ) list option
   val print_datatypes : theory -> unit
-  val datatype_info : theory -> string -> DatatypeAux.datatype_info option
-  val datatype_info_err : theory -> string -> DatatypeAux.datatype_info
-  val get_datatype : theory -> string -> ((string * sort) list * (string * typ list) list) option
-  val get_datatype_case_consts : theory -> string list
-  val get_case_const_data : theory -> string -> (string * int) list option
-  val get_all_datatype_cons : theory -> (string * string) list
-  val get_eq_equations: theory -> string -> thm list
-  val constrs_of : theory -> string -> term list option
-  val case_const_of : theory -> string -> term option
-  val weak_case_congs_of : theory -> thm list
   val setup: theory -> theory
 end;
 
@@ -109,43 +103,41 @@
 
 (** theory information about datatypes **)
 
-val datatype_info = Symtab.lookup o get_datatypes;
+val get_datatype = Symtab.lookup o get_datatypes;
 
-fun datatype_info_err thy name = (case datatype_info thy name of
+fun the_datatype thy name = (case get_datatype thy name of
       SOME info => info
     | NONE => error ("Unknown datatype " ^ quote name));
 
-fun constrs_of thy tname = (case datatype_info thy tname of
-   SOME {index, descr, ...} =>
-     let val (_, _, constrs) = valOf (AList.lookup (op =) descr index)
-     in SOME (map (fn (cname, _) => Const (cname, Sign.the_const_type thy cname)) constrs)
-     end
- | _ => NONE);
+fun get_datatype_descr thy dtco =
+  get_datatype thy dtco
+  |> Option.map (fn info as { descr, index, ... } => 
+       (info, (((fn SOME (_, dtys, cos) => (dtys, cos)) o AList.lookup (op =) descr) index)));
 
-fun case_const_of thy tname = (case datatype_info thy tname of
-   SOME {case_name, ...} => SOME (Const (case_name, Sign.the_const_type thy case_name))
- | _ => NONE);
-
-val weak_case_congs_of = map (#weak_case_cong o #2) o Symtab.dest o get_datatypes;
-
-fun get_datatype thy dtco =
+fun get_datatype_spec thy dtco =
   let
-    fun get_cons descr vs =
-      apsnd (map (DatatypeAux.typ_of_dtyp descr
-        ((map (rpair []) o map DatatypeAux.dest_DtTFree) vs)));
-    fun get_info ({ sorts, descr, ... } : DatatypeAux.datatype_info) =
-      (sorts,
-        ((the oo get_first) (fn (_, (dtco', tys, cs)) =>
-            if dtco = dtco'
-            then SOME (map (get_cons descr tys) cs)
-            else NONE) descr));
-  in case Symtab.lookup (get_datatypes thy) dtco
-   of SOME info => (SOME o get_info) info
-    | NONE => NONE
-  end;
+    fun mk_cons typ_of_dtyp (co, tys) =
+      (co, map typ_of_dtyp tys);
+    fun mk_dtyp ({ sorts = raw_sorts, descr, ... } : DatatypeAux.datatype_info, (dtys, cos)) =
+      let
+        val sorts = map ((fn v => (v, (the o AList.lookup (op =) raw_sorts) v))
+          o DatatypeAux.dest_DtTFree) dtys;
+        val typ_of_dtyp = DatatypeAux.typ_of_dtyp descr sorts;
+        val tys = map typ_of_dtyp dtys;
+      in (sorts, map (mk_cons typ_of_dtyp) cos) end;
+  in Option.map mk_dtyp (get_datatype_descr thy dtco) end;
 
-fun get_datatype_case_consts thy =
-  Symtab.fold (fn (_, {case_name, ...}) => cons case_name) (get_datatypes thy) [];
+fun get_datatype_constrs thy dtco =
+  case get_datatype_spec thy dtco
+   of SOME (sorts, cos) =>
+        let
+          fun subst (v, sort) = TVar ((v, 0), sort);
+          fun subst_ty (TFree v) = subst v
+            | subst_ty ty = ty;
+          val dty = Type (dtco, map subst sorts);
+          fun mk_co (co, tys) = (co, map (Term.map_atyps subst_ty) tys ---> dty);
+        in SOME (map mk_co cos) end
+    | NONE => NONE;
 
 fun get_case_const_data thy c =
   case find_first (fn (_, {index, descr, case_name, ...}) =>
@@ -155,37 +147,6 @@
     | SOME (_, {index, descr, ...}) =>
         (SOME o map (apsnd length) o #3 o the o AList.lookup (op =) descr) index;
 
-fun get_all_datatype_cons thy =
-  Symtab.fold (fn (dtco, _) => fold
-    (fn (co, _) => cons (co, dtco))
-      ((snd o the oo get_datatype) thy dtco)) (get_datatypes thy) [];
-
-fun get_eq_equations thy dtco =
-  case get_datatype thy dtco
-   of SOME (vars, cos) =>
-        let
-          fun co_inject thm =
-            ((fst o dest_Const o fst o strip_comb o fst o HOLogic.dest_eq o fst
-              o HOLogic.dest_eq o HOLogic.dest_Trueprop o Thm.prop_of) thm, thm RS HOL.eq_reflection);
-          val inject = (map co_inject o #inject o the o datatype_info thy) dtco;
-          fun mk_refl co =
-            let
-              fun infer t =
-                (fst o Sign.infer_types (Sign.pp thy) thy (Sign.consts_of thy) (K NONE) (K NONE) [] true)
-                  ([t], Type (dtco, map (fn (v, sort) => TVar ((v, 0), sort)) vars))
-              val t = (Thm.cterm_of thy o infer) (Const (co, dummyT));
-            in
-              HOL.refl 
-              |> Drule.instantiate' [(SOME o Thm.ctyp_of_term) t] [SOME t]
-              |> (fn thm => thm RS Eq_TrueI)
-            end;
-          fun get_eq co =
-           case AList.lookup (op =) inject co
-            of SOME eq => eq
-             | NONE => mk_refl co;
-        in map (get_eq o fst) cos end
-   | NONE => [];
-
 fun find_tname var Bi =
   let val frees = map dest_Free (term_frees Bi)
       val params = rename_wrt_term Bi (Logic.strip_params Bi);
@@ -243,7 +204,7 @@
 	| NONE =>
 	    let val tn = find_tname (hd (List.mapPartial I (List.concat varss))) Bi
                 val {sign, ...} = Thm.rep_thm state
-	    in (#induction (datatype_info_err sign tn), "Induction rule for type " ^ tn) 
+	    in (#induction (the_datatype sign tn), "Induction rule for type " ^ tn) 
 	    end
     val concls = HOLogic.dest_concls (Thm.concl_of rule);
     val insts = List.concat (map prep_inst (concls ~~ varss)) handle UnequalLengths =>
@@ -276,7 +237,7 @@
       let val tn = infer_tname state i t in
         if tn = HOLogic.boolN then inst_tac [(("P", 0), t)] case_split_thm i state
         else case_inst_tac inst_tac t
-               (#exhaustion (datatype_info_err (Thm.sign_of_thm state) tn))
+               (#exhaustion (the_datatype (Thm.sign_of_thm state) tn))
                i state
       end handle THM _ => Seq.empty;
 
@@ -401,8 +362,8 @@
          (case (stripT (0, T1), stripT (0, T2)) of
             ((i', Type (tname1, _)), (j', Type (tname2, _))) =>
                 if tname1 = tname2 andalso not (cname1 = cname2) andalso i = i' andalso j = j' then
-                   (case (constrs_of sg tname1) of
-                      SOME constrs => let val cnames = map (fst o dest_Const) constrs
+                   (case (get_datatype_descr sg) tname1 of
+                      SOME (_, (_, constrs)) => let val cnames = map fst constrs
                         in if cname1 mem cnames andalso cname2 mem cnames then
                              let val eq_t = Logic.mk_equals (t, Const ("False", HOLogic.boolT));
                                  val eq_ct = cterm_of sg eq_t;
@@ -410,7 +371,7 @@
                                  val [In0_inject, In1_inject, In0_not_In1, In1_not_In0] =
                                    map (get_thm Datatype_thy o Name)
                                      ["In0_inject", "In1_inject", "In0_not_In1", "In1_not_In0"]
-                             in (case (#distinct (datatype_info_err sg tname1)) of
+                             in (case (#distinct (the_datatype sg tname1)) of
                                  QuickAndDirty => SOME (Thm.invoke_oracle
                                    Datatype_thy distinctN (sg, ConstrDistinct eq_t))
                                | FewConstrs thms => SOME (Goal.prove sg [] [] eq_t (K
--- a/src/HOL/Tools/datatype_rep_proofs.ML	Thu Apr 06 16:10:22 2006 +0200
+++ b/src/HOL/Tools/datatype_rep_proofs.ML	Thu Apr 06 16:10:46 2006 +0200
@@ -184,7 +184,7 @@
         (TypedefPackage.add_typedef_i false (SOME name') (name, tvs, mx) c NONE
           (rtac exI 1 THEN
             QUIET_BREADTH_FIRST (has_fewer_prems 1)
-            (resolve_tac rep_intrs 1))) thy |> #1)
+            (resolve_tac rep_intrs 1))) thy |> snd)
               (parent_path flat_names thy2, types_syntax ~~ tyvars ~~
                 (Library.take (length newTs, consts)) ~~ new_type_names));
 
--- a/src/HOL/Tools/refute.ML	Thu Apr 06 16:10:22 2006 +0200
+++ b/src/HOL/Tools/refute.ML	Thu Apr 06 16:10:46 2006 +0200
@@ -554,7 +554,7 @@
 						     | MATCH           => get_typedefn axms
 						     | Type.TYPE_MATCH => get_typedefn axms)
 				in
-					case DatatypePackage.datatype_info thy s of
+					case DatatypePackage.get_datatype thy s of
 					  SOME info =>  (* inductive datatype *)
 							(* only collect relevant type axioms for the argument types *)
 							Library.foldl collect_type_axioms (axs, Ts)
@@ -664,14 +664,10 @@
 					fun is_IDT_constructor () =
 						(case body_type T of
 						  Type (s', _) =>
-							(case DatatypePackage.constrs_of thy s' of
+							(case DatatypePackage.get_datatype_constrs thy s' of
 							  SOME constrs =>
-								Library.exists (fn c =>
-									(case c of
-									  Const (cname, ctype) =>
-										cname = s andalso Sign.typ_instance thy (T, ctype)
-									| _ =>
-										raise REFUTE ("collect_axioms", "IDT constructor is not a constant")))
+								Library.exists (fn (cname, cty) =>
+								cname = s andalso Sign.typ_instance thy (T, cty))
 									constrs
 							| NONE =>
 								false)
@@ -773,7 +769,7 @@
 				| Type ("prop", [])      => acc
 				| Type ("set", [T1])     => collect_types (T1, acc)
 				| Type (s, Ts)           =>
-					(case DatatypePackage.datatype_info thy s of
+					(case DatatypePackage.get_datatype thy s of
 					  SOME info =>  (* inductive datatype *)
 						let
 							val index               = #index info
@@ -944,7 +940,7 @@
 			(* TODO: no warning needed for /positive/ occurrences of IDTs       *)
 			val _ = if Library.exists (fn
 				  Type (s, _) =>
-					(case DatatypePackage.datatype_info thy s of
+					(case DatatypePackage.get_datatype thy s of
 					  SOME info =>  (* inductive datatype *)
 						let
 							val index           = #index info
@@ -1647,7 +1643,7 @@
 		val (typs, terms) = model
 		(* Term.typ -> (interpretation * model * arguments) option *)
 		fun interpret_term (Type (s, Ts)) =
-			(case DatatypePackage.datatype_info thy s of
+			(case DatatypePackage.get_datatype thy s of
 			  SOME info =>  (* inductive datatype *)
 				let
 					(* int option -- only recursive IDTs have an associated depth *)
@@ -1723,7 +1719,7 @@
 			  Const (s, T) =>
 				(case body_type T of
 				  Type (s', Ts') =>
-					(case DatatypePackage.datatype_info thy s' of
+					(case DatatypePackage.get_datatype thy s' of
 					  SOME info =>  (* body type is an inductive datatype *)
 						let
 							val index               = #index info
@@ -2511,7 +2507,7 @@
 	in
 		case typeof t of
 		  SOME (Type (s, Ts)) =>
-			(case DatatypePackage.datatype_info thy s of
+			(case DatatypePackage.get_datatype thy s of
 			  SOME info =>  (* inductive datatype *)
 				let
 					val (typs, _)           = model