cleanup in datatype package
authorhaftmann
Thu Apr 06 16:10:46 2006 +0200 (2006-04-06)
changeset 19346c4c003abd830
parent 19345 73439b467e75
child 19347 e2e709f3f955
cleanup in datatype package
src/HOL/Tools/datatype_codegen.ML
src/HOL/Tools/datatype_package.ML
src/HOL/Tools/datatype_rep_proofs.ML
src/HOL/Tools/refute.ML
     1.1 --- a/src/HOL/Tools/datatype_codegen.ML	Thu Apr 06 16:10:22 2006 +0200
     1.2 +++ b/src/HOL/Tools/datatype_codegen.ML	Thu Apr 06 16:10:46 2006 +0200
     1.3 @@ -7,6 +7,11 @@
     1.4  
     1.5  signature DATATYPE_CODEGEN =
     1.6  sig
     1.7 +  val get_datatype_spec_thms: theory -> string
     1.8 +    -> (((string * sort) list * (string * typ list) list) * tactic) option
     1.9 +  val get_case_const_data: theory -> string -> (string * int) list option
    1.10 +  val get_all_datatype_cons: theory -> (string * string) list
    1.11 +  val get_datatype_case_consts: theory -> string list
    1.12    val setup: theory -> theory
    1.13  end;
    1.14  
    1.15 @@ -297,19 +302,58 @@
    1.16    | datatype_tycodegen _ _ _ _ _ _ _ = NONE;
    1.17  
    1.18  
    1.19 +(** code 2nd generation **)
    1.20 +
    1.21 +fun datatype_tac thy dtco =
    1.22 +  let
    1.23 +    val ctxt = Context.init_proof thy;
    1.24 +    val inject = (#inject o DatatypePackage.the_datatype thy) dtco;
    1.25 +    val simpset = Simplifier.context ctxt
    1.26 +      (MetaSimplifier.empty_ss addsimprocs [distinct_simproc]);
    1.27 +  in
    1.28 +    (TRY o ALLGOALS o resolve_tac) [HOL.eq_reflection]
    1.29 +    THEN (
    1.30 +      (ALLGOALS o resolve_tac) (eqTrueI :: inject)
    1.31 +      ORELSE (ALLGOALS o simp_tac) simpset
    1.32 +    )
    1.33 +    THEN (ALLGOALS o resolve_tac) [HOL.refl, Drule.reflexive_thm]
    1.34 +  end;
    1.35 +
    1.36 +fun get_datatype_spec_thms thy dtco =
    1.37 +  case DatatypePackage.get_datatype_spec thy dtco
    1.38 +   of SOME vs_cos =>
    1.39 +        SOME (vs_cos, datatype_tac thy dtco)
    1.40 +    | NONE => NONE;
    1.41 +
    1.42 +fun get_all_datatype_cons thy =
    1.43 +  Symtab.fold (fn (dtco, _) => fold
    1.44 +    (fn (co, _) => cons (co, dtco))
    1.45 +      ((snd o the oo DatatypePackage.get_datatype_spec) thy dtco))
    1.46 +        (DatatypePackage.get_datatypes thy) [];
    1.47 +
    1.48 +fun get_case_const_data thy c =
    1.49 +  case find_first (fn (_, {index, descr, case_name, ...}) =>
    1.50 +      case_name = c
    1.51 +    ) ((Symtab.dest o DatatypePackage.get_datatypes) thy)
    1.52 +   of NONE => NONE
    1.53 +    | SOME (_, {index, descr, ...}) =>
    1.54 +        (SOME o map (apsnd length) o #3 o the o AList.lookup (op =) descr) index;
    1.55 +
    1.56 +fun get_datatype_case_consts thy =
    1.57 +  Symtab.fold (fn (_, {case_name, ...}) => cons case_name)
    1.58 +    (DatatypePackage.get_datatypes thy) [];
    1.59 +
    1.60  val setup = 
    1.61    add_codegen "datatype" datatype_codegen #>
    1.62    add_tycodegen "datatype" datatype_tycodegen #>
    1.63 +  CodegenTheorems.add_datatype_extr
    1.64 +    get_datatype_spec_thms #>
    1.65    CodegenPackage.set_get_datatype
    1.66 -    DatatypePackage.get_datatype #>
    1.67 +    DatatypePackage.get_datatype_spec #>
    1.68    CodegenPackage.set_get_all_datatype_cons
    1.69 -    DatatypePackage.get_all_datatype_cons #>
    1.70 -  (fn thy => thy |> CodegenPackage.add_eqextr_default ("equality",
    1.71 -    (CodegenPackage.eqextr_eq
    1.72 -      DatatypePackage.get_eq_equations
    1.73 -      (Sign.read_term thy "False")))) #>
    1.74 +    get_all_datatype_cons #>
    1.75    CodegenPackage.ensure_datatype_case_consts
    1.76 -    DatatypePackage.get_datatype_case_consts
    1.77 -    DatatypePackage.get_case_const_data;
    1.78 +    get_datatype_case_consts
    1.79 +    get_case_const_data;
    1.80  
    1.81  end;
     2.1 --- a/src/HOL/Tools/datatype_package.ML	Thu Apr 06 16:10:22 2006 +0200
     2.2 +++ b/src/HOL/Tools/datatype_package.ML	Thu Apr 06 16:10:46 2006 +0200
     2.3 @@ -63,17 +63,11 @@
     2.4         size : thm list,
     2.5         simps : thm list} * theory
     2.6    val get_datatypes : theory -> DatatypeAux.datatype_info Symtab.table
     2.7 +  val get_datatype : theory -> string -> DatatypeAux.datatype_info option
     2.8 +  val the_datatype : theory -> string -> DatatypeAux.datatype_info
     2.9 +  val get_datatype_spec : theory -> string -> ((string * sort) list * (string * typ list) list) option
    2.10 +  val get_datatype_constrs : theory -> string -> (string * typ) list option
    2.11    val print_datatypes : theory -> unit
    2.12 -  val datatype_info : theory -> string -> DatatypeAux.datatype_info option
    2.13 -  val datatype_info_err : theory -> string -> DatatypeAux.datatype_info
    2.14 -  val get_datatype : theory -> string -> ((string * sort) list * (string * typ list) list) option
    2.15 -  val get_datatype_case_consts : theory -> string list
    2.16 -  val get_case_const_data : theory -> string -> (string * int) list option
    2.17 -  val get_all_datatype_cons : theory -> (string * string) list
    2.18 -  val get_eq_equations: theory -> string -> thm list
    2.19 -  val constrs_of : theory -> string -> term list option
    2.20 -  val case_const_of : theory -> string -> term option
    2.21 -  val weak_case_congs_of : theory -> thm list
    2.22    val setup: theory -> theory
    2.23  end;
    2.24  
    2.25 @@ -109,43 +103,41 @@
    2.26  
    2.27  (** theory information about datatypes **)
    2.28  
    2.29 -val datatype_info = Symtab.lookup o get_datatypes;
    2.30 +val get_datatype = Symtab.lookup o get_datatypes;
    2.31  
    2.32 -fun datatype_info_err thy name = (case datatype_info thy name of
    2.33 +fun the_datatype thy name = (case get_datatype thy name of
    2.34        SOME info => info
    2.35      | NONE => error ("Unknown datatype " ^ quote name));
    2.36  
    2.37 -fun constrs_of thy tname = (case datatype_info thy tname of
    2.38 -   SOME {index, descr, ...} =>
    2.39 -     let val (_, _, constrs) = valOf (AList.lookup (op =) descr index)
    2.40 -     in SOME (map (fn (cname, _) => Const (cname, Sign.the_const_type thy cname)) constrs)
    2.41 -     end
    2.42 - | _ => NONE);
    2.43 +fun get_datatype_descr thy dtco =
    2.44 +  get_datatype thy dtco
    2.45 +  |> Option.map (fn info as { descr, index, ... } => 
    2.46 +       (info, (((fn SOME (_, dtys, cos) => (dtys, cos)) o AList.lookup (op =) descr) index)));
    2.47  
    2.48 -fun case_const_of thy tname = (case datatype_info thy tname of
    2.49 -   SOME {case_name, ...} => SOME (Const (case_name, Sign.the_const_type thy case_name))
    2.50 - | _ => NONE);
    2.51 -
    2.52 -val weak_case_congs_of = map (#weak_case_cong o #2) o Symtab.dest o get_datatypes;
    2.53 -
    2.54 -fun get_datatype thy dtco =
    2.55 +fun get_datatype_spec thy dtco =
    2.56    let
    2.57 -    fun get_cons descr vs =
    2.58 -      apsnd (map (DatatypeAux.typ_of_dtyp descr
    2.59 -        ((map (rpair []) o map DatatypeAux.dest_DtTFree) vs)));
    2.60 -    fun get_info ({ sorts, descr, ... } : DatatypeAux.datatype_info) =
    2.61 -      (sorts,
    2.62 -        ((the oo get_first) (fn (_, (dtco', tys, cs)) =>
    2.63 -            if dtco = dtco'
    2.64 -            then SOME (map (get_cons descr tys) cs)
    2.65 -            else NONE) descr));
    2.66 -  in case Symtab.lookup (get_datatypes thy) dtco
    2.67 -   of SOME info => (SOME o get_info) info
    2.68 -    | NONE => NONE
    2.69 -  end;
    2.70 +    fun mk_cons typ_of_dtyp (co, tys) =
    2.71 +      (co, map typ_of_dtyp tys);
    2.72 +    fun mk_dtyp ({ sorts = raw_sorts, descr, ... } : DatatypeAux.datatype_info, (dtys, cos)) =
    2.73 +      let
    2.74 +        val sorts = map ((fn v => (v, (the o AList.lookup (op =) raw_sorts) v))
    2.75 +          o DatatypeAux.dest_DtTFree) dtys;
    2.76 +        val typ_of_dtyp = DatatypeAux.typ_of_dtyp descr sorts;
    2.77 +        val tys = map typ_of_dtyp dtys;
    2.78 +      in (sorts, map (mk_cons typ_of_dtyp) cos) end;
    2.79 +  in Option.map mk_dtyp (get_datatype_descr thy dtco) end;
    2.80  
    2.81 -fun get_datatype_case_consts thy =
    2.82 -  Symtab.fold (fn (_, {case_name, ...}) => cons case_name) (get_datatypes thy) [];
    2.83 +fun get_datatype_constrs thy dtco =
    2.84 +  case get_datatype_spec thy dtco
    2.85 +   of SOME (sorts, cos) =>
    2.86 +        let
    2.87 +          fun subst (v, sort) = TVar ((v, 0), sort);
    2.88 +          fun subst_ty (TFree v) = subst v
    2.89 +            | subst_ty ty = ty;
    2.90 +          val dty = Type (dtco, map subst sorts);
    2.91 +          fun mk_co (co, tys) = (co, map (Term.map_atyps subst_ty) tys ---> dty);
    2.92 +        in SOME (map mk_co cos) end
    2.93 +    | NONE => NONE;
    2.94  
    2.95  fun get_case_const_data thy c =
    2.96    case find_first (fn (_, {index, descr, case_name, ...}) =>
    2.97 @@ -155,37 +147,6 @@
    2.98      | SOME (_, {index, descr, ...}) =>
    2.99          (SOME o map (apsnd length) o #3 o the o AList.lookup (op =) descr) index;
   2.100  
   2.101 -fun get_all_datatype_cons thy =
   2.102 -  Symtab.fold (fn (dtco, _) => fold
   2.103 -    (fn (co, _) => cons (co, dtco))
   2.104 -      ((snd o the oo get_datatype) thy dtco)) (get_datatypes thy) [];
   2.105 -
   2.106 -fun get_eq_equations thy dtco =
   2.107 -  case get_datatype thy dtco
   2.108 -   of SOME (vars, cos) =>
   2.109 -        let
   2.110 -          fun co_inject thm =
   2.111 -            ((fst o dest_Const o fst o strip_comb o fst o HOLogic.dest_eq o fst
   2.112 -              o HOLogic.dest_eq o HOLogic.dest_Trueprop o Thm.prop_of) thm, thm RS HOL.eq_reflection);
   2.113 -          val inject = (map co_inject o #inject o the o datatype_info thy) dtco;
   2.114 -          fun mk_refl co =
   2.115 -            let
   2.116 -              fun infer t =
   2.117 -                (fst o Sign.infer_types (Sign.pp thy) thy (Sign.consts_of thy) (K NONE) (K NONE) [] true)
   2.118 -                  ([t], Type (dtco, map (fn (v, sort) => TVar ((v, 0), sort)) vars))
   2.119 -              val t = (Thm.cterm_of thy o infer) (Const (co, dummyT));
   2.120 -            in
   2.121 -              HOL.refl 
   2.122 -              |> Drule.instantiate' [(SOME o Thm.ctyp_of_term) t] [SOME t]
   2.123 -              |> (fn thm => thm RS Eq_TrueI)
   2.124 -            end;
   2.125 -          fun get_eq co =
   2.126 -           case AList.lookup (op =) inject co
   2.127 -            of SOME eq => eq
   2.128 -             | NONE => mk_refl co;
   2.129 -        in map (get_eq o fst) cos end
   2.130 -   | NONE => [];
   2.131 -
   2.132  fun find_tname var Bi =
   2.133    let val frees = map dest_Free (term_frees Bi)
   2.134        val params = rename_wrt_term Bi (Logic.strip_params Bi);
   2.135 @@ -243,7 +204,7 @@
   2.136  	| NONE =>
   2.137  	    let val tn = find_tname (hd (List.mapPartial I (List.concat varss))) Bi
   2.138                  val {sign, ...} = Thm.rep_thm state
   2.139 -	    in (#induction (datatype_info_err sign tn), "Induction rule for type " ^ tn) 
   2.140 +	    in (#induction (the_datatype sign tn), "Induction rule for type " ^ tn) 
   2.141  	    end
   2.142      val concls = HOLogic.dest_concls (Thm.concl_of rule);
   2.143      val insts = List.concat (map prep_inst (concls ~~ varss)) handle UnequalLengths =>
   2.144 @@ -276,7 +237,7 @@
   2.145        let val tn = infer_tname state i t in
   2.146          if tn = HOLogic.boolN then inst_tac [(("P", 0), t)] case_split_thm i state
   2.147          else case_inst_tac inst_tac t
   2.148 -               (#exhaustion (datatype_info_err (Thm.sign_of_thm state) tn))
   2.149 +               (#exhaustion (the_datatype (Thm.sign_of_thm state) tn))
   2.150                 i state
   2.151        end handle THM _ => Seq.empty;
   2.152  
   2.153 @@ -401,8 +362,8 @@
   2.154           (case (stripT (0, T1), stripT (0, T2)) of
   2.155              ((i', Type (tname1, _)), (j', Type (tname2, _))) =>
   2.156                  if tname1 = tname2 andalso not (cname1 = cname2) andalso i = i' andalso j = j' then
   2.157 -                   (case (constrs_of sg tname1) of
   2.158 -                      SOME constrs => let val cnames = map (fst o dest_Const) constrs
   2.159 +                   (case (get_datatype_descr sg) tname1 of
   2.160 +                      SOME (_, (_, constrs)) => let val cnames = map fst constrs
   2.161                          in if cname1 mem cnames andalso cname2 mem cnames then
   2.162                               let val eq_t = Logic.mk_equals (t, Const ("False", HOLogic.boolT));
   2.163                                   val eq_ct = cterm_of sg eq_t;
   2.164 @@ -410,7 +371,7 @@
   2.165                                   val [In0_inject, In1_inject, In0_not_In1, In1_not_In0] =
   2.166                                     map (get_thm Datatype_thy o Name)
   2.167                                       ["In0_inject", "In1_inject", "In0_not_In1", "In1_not_In0"]
   2.168 -                             in (case (#distinct (datatype_info_err sg tname1)) of
   2.169 +                             in (case (#distinct (the_datatype sg tname1)) of
   2.170                                   QuickAndDirty => SOME (Thm.invoke_oracle
   2.171                                     Datatype_thy distinctN (sg, ConstrDistinct eq_t))
   2.172                                 | FewConstrs thms => SOME (Goal.prove sg [] [] eq_t (K
     3.1 --- a/src/HOL/Tools/datatype_rep_proofs.ML	Thu Apr 06 16:10:22 2006 +0200
     3.2 +++ b/src/HOL/Tools/datatype_rep_proofs.ML	Thu Apr 06 16:10:46 2006 +0200
     3.3 @@ -184,7 +184,7 @@
     3.4          (TypedefPackage.add_typedef_i false (SOME name') (name, tvs, mx) c NONE
     3.5            (rtac exI 1 THEN
     3.6              QUIET_BREADTH_FIRST (has_fewer_prems 1)
     3.7 -            (resolve_tac rep_intrs 1))) thy |> #1)
     3.8 +            (resolve_tac rep_intrs 1))) thy |> snd)
     3.9                (parent_path flat_names thy2, types_syntax ~~ tyvars ~~
    3.10                  (Library.take (length newTs, consts)) ~~ new_type_names));
    3.11  
     4.1 --- a/src/HOL/Tools/refute.ML	Thu Apr 06 16:10:22 2006 +0200
     4.2 +++ b/src/HOL/Tools/refute.ML	Thu Apr 06 16:10:46 2006 +0200
     4.3 @@ -554,7 +554,7 @@
     4.4  						     | MATCH           => get_typedefn axms
     4.5  						     | Type.TYPE_MATCH => get_typedefn axms)
     4.6  				in
     4.7 -					case DatatypePackage.datatype_info thy s of
     4.8 +					case DatatypePackage.get_datatype thy s of
     4.9  					  SOME info =>  (* inductive datatype *)
    4.10  							(* only collect relevant type axioms for the argument types *)
    4.11  							Library.foldl collect_type_axioms (axs, Ts)
    4.12 @@ -664,14 +664,10 @@
    4.13  					fun is_IDT_constructor () =
    4.14  						(case body_type T of
    4.15  						  Type (s', _) =>
    4.16 -							(case DatatypePackage.constrs_of thy s' of
    4.17 +							(case DatatypePackage.get_datatype_constrs thy s' of
    4.18  							  SOME constrs =>
    4.19 -								Library.exists (fn c =>
    4.20 -									(case c of
    4.21 -									  Const (cname, ctype) =>
    4.22 -										cname = s andalso Sign.typ_instance thy (T, ctype)
    4.23 -									| _ =>
    4.24 -										raise REFUTE ("collect_axioms", "IDT constructor is not a constant")))
    4.25 +								Library.exists (fn (cname, cty) =>
    4.26 +								cname = s andalso Sign.typ_instance thy (T, cty))
    4.27  									constrs
    4.28  							| NONE =>
    4.29  								false)
    4.30 @@ -773,7 +769,7 @@
    4.31  				| Type ("prop", [])      => acc
    4.32  				| Type ("set", [T1])     => collect_types (T1, acc)
    4.33  				| Type (s, Ts)           =>
    4.34 -					(case DatatypePackage.datatype_info thy s of
    4.35 +					(case DatatypePackage.get_datatype thy s of
    4.36  					  SOME info =>  (* inductive datatype *)
    4.37  						let
    4.38  							val index               = #index info
    4.39 @@ -944,7 +940,7 @@
    4.40  			(* TODO: no warning needed for /positive/ occurrences of IDTs       *)
    4.41  			val _ = if Library.exists (fn
    4.42  				  Type (s, _) =>
    4.43 -					(case DatatypePackage.datatype_info thy s of
    4.44 +					(case DatatypePackage.get_datatype thy s of
    4.45  					  SOME info =>  (* inductive datatype *)
    4.46  						let
    4.47  							val index           = #index info
    4.48 @@ -1647,7 +1643,7 @@
    4.49  		val (typs, terms) = model
    4.50  		(* Term.typ -> (interpretation * model * arguments) option *)
    4.51  		fun interpret_term (Type (s, Ts)) =
    4.52 -			(case DatatypePackage.datatype_info thy s of
    4.53 +			(case DatatypePackage.get_datatype thy s of
    4.54  			  SOME info =>  (* inductive datatype *)
    4.55  				let
    4.56  					(* int option -- only recursive IDTs have an associated depth *)
    4.57 @@ -1723,7 +1719,7 @@
    4.58  			  Const (s, T) =>
    4.59  				(case body_type T of
    4.60  				  Type (s', Ts') =>
    4.61 -					(case DatatypePackage.datatype_info thy s' of
    4.62 +					(case DatatypePackage.get_datatype thy s' of
    4.63  					  SOME info =>  (* body type is an inductive datatype *)
    4.64  						let
    4.65  							val index               = #index info
    4.66 @@ -2511,7 +2507,7 @@
    4.67  	in
    4.68  		case typeof t of
    4.69  		  SOME (Type (s, Ts)) =>
    4.70 -			(case DatatypePackage.datatype_info thy s of
    4.71 +			(case DatatypePackage.get_datatype thy s of
    4.72  			  SOME info =>  (* inductive datatype *)
    4.73  				let
    4.74  					val (typs, _)           = model