1 (* Title: HOL/Tools/datatype_aux.ML
3 Author: Stefan Berghofer, TU Muenchen
5 Auxiliary functions for defining datatypes.
(* DATATYPE_AUX: auxiliary values used when defining datatypes -- message
   output, name-space handling, theorem storage, tactics, and the internal
   dtyp/descr representation of datatype specifications.
   NOTE(review): this listing omits several original lines of the signature
   (e.g. the `sig` line, the lines preceding `| DtType` and the closing
   `end`); restore them from the full file before compiling. *)
8 signature DATATYPE_AUX =
10 val quiet_mode : bool ref
11 val message : string -> unit
(* left fold of a non-empty list, seeded with the list's head *)
13 val foldl1 : ('a * 'a -> 'a) -> 'a list -> 'a
(* theory name-space handling; no-ops when the bool (flat_names) is true *)
15 val add_path : bool -> string -> theory -> theory
16 val parent_path : bool -> theory -> theory
(* store theorems under a common label, one entry per type name *)
18 val store_thmss_atts : string -> string list -> attribute list list -> thm list list
19 -> theory -> thm list list * theory
20 val store_thmss : string -> string list -> thm list list -> theory -> thm list list * theory
21 val store_thms_atts : string -> string list -> attribute list list -> thm list
22 -> theory -> thm list * theory
23 val store_thms : string -> string list -> thm list -> theory -> thm list * theory
(* term and theorem construction helpers *)
25 val split_conj_thm : thm -> thm list
26 val mk_conj : term list -> term
27 val mk_disj : term list -> term
29 val app_bnds : term -> int -> term
(* tactics; the int argument is the index of the subgoal to work on *)
31 val cong_tac : int -> tactic
32 val indtac : thm -> int -> tactic
33 val exh_tac : (string -> thm) -> int -> tactic
(* representation of a datatype's distinctness theorems *)
35 datatype simproc_dist = QuickAndDirty
36 | FewConstrs of thm list
37 | ManyConstrs of thm * simpset;
41 | DtType of string * (dtyp list)
47 exception Datatype_Empty of string
(* conversions and queries on the internal dtyp / descr representation *)
48 val name_of_typ : typ -> string
49 val dtyp_of_typ : (string * string list) list -> typ -> dtyp
50 val mk_Free : string -> typ -> int -> term
51 val is_rec_type : dtyp -> bool
52 val typ_of_dtyp : descr -> (string * sort) list -> dtyp -> typ
53 val dest_DtTFree : dtyp -> string
54 val dest_DtRec : dtyp -> int
55 val strip_dtyp : dtyp -> dtyp list * dtyp
56 val body_index : dtyp -> int
57 val mk_fun_dtyp : dtyp list -> dtyp -> dtyp
58 val dest_TFree : typ -> string
59 val get_nonrec_types : descr -> (string * sort) list -> typ list
60 val get_branching_types : descr -> (string * sort) list -> typ list
61 val get_arities : descr -> int list
62 val get_rec_types : descr -> (string * sort) list -> typ list
63 val check_nonempty : descr list -> unit
64 val unfold_datatypes :
65 Sign.sg -> descr -> (string * sort) list -> datatype_info Symtab.table ->
66 descr -> int -> descr list * int
(* Implementation of the DATATYPE_AUX signature.
   NOTE(review): the `struct` keyword that should follow this header is
   absent from the visible listing (original lines 70-71 are missing). *)
69 structure DatatypeAux : DATATYPE_AUX =
(* Global switch: while set to true, message prints nothing. *)
val quiet_mode = ref false;

(* Print s with writeln, unless quiet_mode is currently set. *)
fun message s = if not (! quiet_mode) then writeln s else ();
(* foldl1 f [x1, x2, ..., xn] = f (... f (f (x1, x2), x3) ..., xn):
   a left fold whose initial accumulator is the head of the list.
   For [x] the result is x. For [] the Basis exception Empty is raised;
   the earlier definition had no [] clause, which gave a non-exhaustive
   match warning and an uninformative Match at run time.
   FIXME: move to library? *)
fun foldl1 _ [] = raise Empty
  | foldl1 f (x :: xs) = List.foldl (fn (y, acc) => f (acc, y)) x xs;
(* Enter / leave a theory name space. When flat_names is true the theory
   is passed through unchanged; otherwise Theory.add_path s and
   Theory.parent_path are applied, respectively. *)
fun add_path flat_names s thy =
  if flat_names then thy else Theory.add_path s thy;

fun parent_path flat_names thy =
  if flat_names then thy else Theory.parent_path thy;
82 (* store theorems in theory *)
(* store_thmss_atts label tnames attss thmss walks the triples
   (tname, atts, thms) of tnames ~~ attss ~~ thmss with fold_map, threading
   the theory. For each triple it adds the theorem list thms under the name
   label with attributes atts via PureThy.add_thmss, keeps the first element
   of add_thmss's result, and calls Theory.parent_path. The result pairs
   those stored theorem lists (one per type name) with the final theory.
   NOTE(review): original line 86 is missing from this listing. Since the
   body later calls Theory.parent_path and tname is otherwise unused, it
   presumably enters the tname name space (e.g. Theory.add_path tname);
   confirm against the full file.
   NOTE(review): the pattern `thm::_` is non-exhaustive; it relies on
   PureThy.add_thmss returning a non-empty list for a one-element input. *)
84 fun store_thmss_atts label tnames attss thmss =
85 fold_map (fn ((tname, atts), thms) =>
87 #> PureThy.add_thmss [((label, thms), atts)]
88 #-> (fn thm::_ => Theory.parent_path #> pair thm)
89 ) (tnames ~~ attss ~~ thmss);
(* store_thmss_atts without attributes: every type name gets the empty
   attribute list. *)
fun store_thmss label tnames =
  let val no_atts = map (fn _ => []) tnames
  in store_thmss_atts label tnames no_atts end;
(* store_thms_atts label tnames attss thmss: single-theorem counterpart of
   store_thmss_atts. For each triple (tname, atts, thms) of
   tnames ~~ attss ~~ thmss, where the third component is one theorem, it
   adds that theorem under label with attributes atts via PureThy.add_thms,
   keeps the first element of the result, and calls Theory.parent_path.
   The theory is threaded with fold_map.
   NOTE(review): original line 95 is missing from this listing. Given the
   Theory.parent_path below and the otherwise unused tname, it presumably
   enters the tname name space; confirm against the full file.
   NOTE(review): `thm::_` is a non-exhaustive pattern. *)
93 fun store_thms_atts label tnames attss thmss =
94 fold_map (fn ((tname, atts), thms) =>
96 #> PureThy.add_thms [((label, thms), atts)]
97 #-> (fn thm::_ => Theory.parent_path #> pair thm)
98 ) (tnames ~~ attss ~~ thmss);
(* store_thms_atts without attributes: every type name gets the empty
   attribute list. *)
fun store_thms label tnames =
  let val no_atts = map (fn _ => []) tnames
  in store_thms_atts label tnames no_atts end;
(* Split a theorem thm_1 & (thm_2 & (... & thm_n)) into [thm_1, ..., thm_n].
   Only the right operand is split further. A theorem to which conjunct1 /
   conjunct2 cannot be applied (RS raises THM) is returned as [th]. *)
fun split_conj_thm th =
  let
    val left = th RS conjunct1;
    val right = th RS conjunct2;
  in left :: split_conj_thm right end
  handle THM _ => [th];
(* Combine a list of HOL formulas with foldr1, using the binary operator
   "op &" (mk_conj) or "op |" (mk_disj). *)
fun mk_conj ts = foldr1 (HOLogic.mk_binop "op &") ts;
fun mk_disj ts = foldr1 (HOLogic.mk_binop "op |") ts;
(* app_bnds t i applies t to the loose bound variables
   Bound (i - 1), ..., Bound 0, in that order. For i <= 0 the argument
   list is empty. *)
fun app_bnds t i =
  let val bounds = map Bound (rev (0 upto (i - 1)))
  in list_comb (t, bounds) end;
(* cong_tac i: applies the theorem `cong` to subgoal i of st.
   The subgoal's conclusion (with assumptions and parameters stripped) must
   match _ $ (_ $ (f $ x) $ (g $ y)); presumably Trueprop (f x = g y), to be
   confirmed. `cong` is lifted into subgoal i with Thm.lift_rule. The heads
   of its f', g', x', y' are instantiated with f, g, x, y, each wrapped with
   list_abs over the parameters of the lifted rule's conclusion. The result
   is composed with the subgoal by compose_tac (2 premises). A THM raised
   on the way turns into no_tac.
   NOTE(review): original line 117 (presumably `let`) and lines 126-127
   (presumably `end` plus the fallback case branch) are missing from this
   listing. *)
114 fun cong_tac i st = (case Logic.strip_assums_concl
115 (List.nth (prems_of st, i - 1)) of
116 _ $ (_ $ (f $ x) $ (g $ y)) =>
118 val cong' = Thm.lift_rule (Thm.cprem_of st i) cong;
119 val _ $ (_ $ (f' $ x') $ (g' $ y')) =
120 Logic.strip_assums_concl (prop_of cong');
121 val insts = map (pairself (cterm_of (#sign (rep_thm st))) o
122 apsnd (curry list_abs (Logic.strip_params (concl_of cong'))) o
123 apfst head_of) [(f', f), (g', g), (x', x), (y', y)]
124 in compose_tac (false, cterm_instantiate insts cong', 2) i st
125 handle THM _ => no_tac st
129 (* instantiate induction rule *)
(* indtac indrule i st:
   - ts are the conjuncts of indrule's conclusion.
   - ts' are the conjuncts of subgoal i's conclusion, after its premises
     are stripped with Logic.strip_imp_concl.
   - getP splits each conjunct: if the first conjunct of ts is an
     implication, A --> B becomes (SOME A, B); otherwise t becomes (NONE, t).
   - abstr turns a subgoal conjunct into a predicate. In the NONE case it
     abstracts the single free variable of t2 (a Bind exception is raised
     unless there is exactly one). In the SOME case it abstracts over the
     middle argument t' of the premise.
   - Ps are the heads of indrule's conjuncts (after getP). Each is
     instantiated with the corresponding abstraction built from ts'.
   NOTE(review): original line 132 (presumably `let`) and lines 146-148
   (presumably applying indrule' to subgoal i and closing the `let`) are
   missing from this listing. *)
131 fun indtac indrule i st =
133 val ts = HOLogic.dest_conj (HOLogic.dest_Trueprop (concl_of indrule));
134 val ts' = HOLogic.dest_conj (HOLogic.dest_Trueprop
135 (Logic.strip_imp_concl (List.nth (prems_of st, i - 1))));
136 val getP = if can HOLogic.dest_imp (hd ts) then
137 (apfst SOME) o HOLogic.dest_imp else pair NONE;
138 fun abstr (t1, t2) = (case t1 of
139 NONE => let val [Free (s, T)] = add_term_frees (t2, [])
140 in absfree (s, T, t2) end
141 | SOME (_ $ t' $ _) => Abs ("x", fastype_of t', abstract_over (t', t2)))
142 val cert = cterm_of (Thm.sign_of_thm st);
143 val Ps = map (cert o head_of o snd o getP) ts;
144 val indrule' = cterm_instantiate (Ps ~~
145 (map (cert o abstr o getP) ts')) indrule
150 (* perform exhaustive case analysis on last parameter of subgoal i *)
(* exh_thm_of maps a type name to that type's exhaustion theorem.
   The last parameter of subgoal i must have a type built with Type
   (otherwise a Bind exception is raised). The exhaustion theorem for that
   type is lifted into subgoal i. From the last hypothesis of its first
   premise, the head of lhs is instantiated with
   %z1 ... zn. Bound 0, i.e. a function returning its last argument.
   The rule is then composed with subgoal i, using nprems_of exhaustion
   premises.
   NOTE(review): original lines 153 (presumably `let`) and 165 (presumably
   `end`) are missing from this listing. *)
152 fun exh_tac exh_thm_of i state =
154 val sg = Thm.sign_of_thm state;
155 val prem = List.nth (prems_of state, i - 1);
156 val params = Logic.strip_params prem;
157 val (_, Type (tname, _)) = hd (rev params);
158 val exhaustion = Thm.lift_rule (Thm.cprem_of state i) (exh_thm_of tname);
159 val prem' = hd (prems_of exhaustion);
160 val _ $ (_ $ lhs $ _) = hd (rev (Logic.strip_assums_hyp prem'));
161 val exhaustion' = cterm_instantiate [(cterm_of sg (head_of lhs),
162 cterm_of sg (foldr (fn ((_, T), t) => Abs ("z", T, t))
163 (Bound 0) params))] exhaustion
164 in compose_tac (false, exhaustion', nprems_of exhaustion) i state
167 (* handling of distinctness theorems *)
(* Stored in the `distinct` field of datatype_info. QuickAndDirty carries
   no data; FewConstrs carries a list of theorems; ManyConstrs carries one
   theorem and a simpset. Presumably these select how constructor
   distinctness is proved for small vs. large datatypes; confirm with the
   users of the `distinct` field. *)
169 datatype simproc_dist = QuickAndDirty
170 | FewConstrs of thm list
171 | ManyConstrs of thm * simpset;
173 (********************** Internal description of datatypes *********************)
(* dtyp constructors used in this file: DtTFree (a type variable name),
   DtType (a type constructor name with argument dtyps) and DtRec (an index
   into the descr being defined).
   NOTE(review): the declaration lines for DtTFree / DtRec (original lines
   174-176) are missing from this listing. *)
177 | DtType of string * (dtyp list)
180 (* information about datatypes *)
182 (* index, datatype name, type arguments, constructor name, types of constructor's arguments *)
(* Entries are looked up by their int index with AList.lookup, so the
   index acts as the key of this association list. *)
183 type descr = (int * (string * dtyp list * (string * dtyp list) list)) list;
(* Fields of the datatype_info record visible in this listing. The index
   and descr fields are also read by unfold_datatypes via the pattern
   {index, descr, ...}.
   NOTE(review): the record header and several fields (original lines
   184-187, 191, 193-194, 196-198) are missing from this listing. *)
188 sorts : (string * sort) list,
189 rec_names : string list,
190 rec_rewrites : thm list,
192 case_rewrites : thm list,
195 distinct : simproc_dist,
199 weak_case_cong : thm};
(* mk_Free s T i: the free variable of type T whose name is s followed by
   string_of_int i. *)
fun mk_Free s T i =
  let val name = s ^ string_of_int i
  in Free (name, T) end;
(* subst_DtTFree i substs dt:
   - each type variable DtTFree name is replaced by its binding in the
     association list substs, and left unchanged when unbound;
   - each recursive reference DtRec j is shifted to DtRec (i + j);
   - DtType arguments are processed recursively. *)
fun subst_DtTFree i substs dt =
  (case dt of
    DtTFree name =>
      (case AList.lookup (op =) substs name of
        SOME T => T
      | NONE => dt)
  | DtType (name, ts) => DtType (name, map (subst_DtTFree i substs) ts)
  | DtRec j => DtRec (i + j));
(* Raised by check_nonempty, carrying the name of a datatype that has no
   constructor yielding a non-empty type.
   NOTE(review): the exception `Datatype` raised by dest_DtTFree /
   dest_DtRec is not declared in the visible listing; presumably it is on
   the missing original line 209. *)
210 exception Datatype_Empty of string;
(* Name of a type variable dtyp; any other dtyp raises Datatype. *)
fun dest_DtTFree dt =
  (case dt of
    DtTFree a => a
  | _ => raise Datatype);
(* Index of a recursive reference; any other dtyp raises Datatype. *)
fun dest_DtRec dt =
  (case dt of
    DtRec i => i
  | _ => raise Datatype);
(* A dtyp is recursive if it is a DtRec, or is a DtType with a recursive
   argument at any depth. Every other dtyp is non-recursive. *)
fun is_rec_type dt =
  (case dt of
    DtRec _ => true
  | DtType (_, args) => exists is_rec_type args
  | _ => false);
(* Split a (curried) function dtyp T1 -> ... -> Tn -> B into
   ([T1, ..., Tn], B). A non-function dtyp B yields ([], B). *)
fun strip_dtyp dt =
  let
    fun go doms (DtType ("fun", [T, U])) = go (T :: doms) U
      | go doms T = (rev doms, T);
  in go [] dt end;
(* Index of the DtRec at the body of a (possibly function-typed) dtyp.
   Raises Datatype when the body is not a DtRec. *)
fun body_index dt =
  let val (_, body) = strip_dtyp dt
  in dest_DtRec body end;
(* Inverse of strip_dtyp: mk_fun_dtyp [T1, ..., Tn] U builds the curried
   function dtyp T1 -> ... -> Tn -> U. *)
fun mk_fun_dtyp Ts U =
  List.foldr (fn (T, acc) => DtType ("fun", [T, acc])) U Ts;
(* dest_TFree T returns the name n of a type variable T = TFree (n, _).
   Any other typ raises TYPE, carrying the offending type. The earlier
   single-clause definition was non-exhaustive and failed with a bare
   Match. In the visible source, dtyp_of_typ only calls this through `try`,
   so that caller sees NONE in either case. *)
fun dest_TFree (TFree (n, _)) = n
  | dest_TFree T = raise TYPE ("dest_TFree", [T], []);
(* Derive a name from a type. For Type (s, Ts), the non-empty names of
   the arguments Ts are followed by the base name of s (or "x" when that is
   not an identifier), all joined with "_"; presumably "nat_list" for
   nat list. Any other typ (type variables) yields "".
   NOTE(review): original line 236 (presumably `end`, closing the `let`) is
   missing from this listing. *)
232 fun name_of_typ (Type (s, Ts)) =
233 let val s' = Sign.base_name s
234 in space_implode "_" (List.filter (not o equal "") (map name_of_typ Ts) @
235 [if Syntax.is_identifier s' then s' else "x"])
237 | name_of_typ _ = "";
(* Convert a HOL type to the internal dtyp representation.
   new_dts lists the datatypes being defined, as (name, type variables).
   - A type variable becomes DtTFree.
   - A schematic type variable is rejected.
   - A type constructor not in new_dts becomes DtType, with its arguments
     converted recursively.
   - An occurrence of a datatype being defined becomes DtRec of its
     position in new_dts. This is only allowed when its arguments are
     exactly its declared type variables, in order. *)
fun dtyp_of_typ new_dts T =
  (case T of
    TFree (n, _) => DtTFree n
  | TVar _ => error "Illegal schematic type variable(s)"
  | Type (tname, Ts) =>
      (case AList.lookup (op =) new_dts tname of
        NONE => DtType (tname, map (dtyp_of_typ new_dts) Ts)
      | SOME vs =>
          if map (try dest_TFree) Ts <> map SOME vs
          then error ("Illegal occurrence of recursive type " ^ tname)
          else DtRec (find_index (fn (s, _) => s = tname) new_dts)));
(* Convert a dtyp back to a HOL type.
   - Type variables get their sort from sorts.
   - DtRec i is resolved through entry i of descr.
   - DtType is converted structurally.
   A missing lookup raises Option via `the`. *)
fun typ_of_dtyp descr sorts =
  let
    fun conv (DtTFree a) = TFree (a, the (AList.lookup (op =) sorts a))
      | conv (DtRec i) =
          let val (s, ds, _) = the (AList.lookup (op =) descr i)
          in Type (s, map conv ds) end
      | conv (DtType (s, ds)) = Type (s, map conv ds);
  in conv end;
255 (* find all non-recursive types in datatype description *)
(* Visits every constructor of every datatype in descr and accumulates
   its non-recursive argument dtyps (filter_out is_rec_type), merged into
   the accumulator with Isabelle's list `union`. The result order therefore
   depends on `union`. Each collected dtyp is then converted with
   typ_of_dtyp descr sorts. *)
257 fun get_nonrec_types descr sorts =
258 map (typ_of_dtyp descr sorts) (Library.foldl (fn (Ts, (_, (_, _, constrs))) =>
259 Library.foldl (fn (Ts', (_, cargs)) =>
260 filter_out is_rec_type cargs union Ts') (Ts, constrs)) ([], descr));
(* The HOL types of the datatypes described by descr, in descr order:
   Type (name, converted type arguments) for each entry. *)
fun get_rec_types descr sorts =
  let
    val conv = typ_of_dtyp descr sorts;
    fun rec_type (_, (tname, dts, _)) = Type (tname, map conv dts);
  in map rec_type descr end;
267 (* get all branching types *)
(* For every constructor argument of every datatype in descr, takes the
   domain dtyps of its (curried) function type via fst o strip_dtyp; the
   list is empty for non-function arguments. These are merged with
   Isabelle's `union` and converted with typ_of_dtyp descr sorts. The
   result order depends on `union`. *)
269 fun get_branching_types descr sorts =
270 map (typ_of_dtyp descr sorts) (Library.foldl (fn (Ts, (_, (_, _, constrs))) =>
271 Library.foldl (fn (Ts', (_, cargs)) => foldr op union Ts' (map (fst o strip_dtyp)
272 cargs)) (Ts, constrs)) ([], descr));
(* Collects, over all recursive constructor arguments in descr, the number
   of function domains of each argument (length o fst o strip_dtyp). A
   plain recursive argument contributes 0. The values are merged with
   Isabelle's `union`, so the result order depends on `union`. *)
274 fun get_arities descr = Library.foldl (fn (is, (_, (_, _, constrs))) =>
275 Library.foldl (fn (is', (_, cargs)) => map (length o fst o strip_dtyp)
276 (List.filter is_rec_type cargs) union is') (is, constrs)) ([], descr);
278 (* nonemptiness check for datatypes *)
(* Lookups go through the concatenation of all descrs, but only the
   datatypes of the first group (hd descr) are checked.
   - Datatype i is non-empty if some constructor has all arguments
     non-empty.
   - Only the body of a (function-typed) argument is examined, via
     strip_dtyp.
   - A DtRec body is non-empty if its datatype is non-empty and is not
     already on the current search path `is`.
   - Any other body counts as non-empty.
   On failure, assert_all's handler raises Datatype_Empty with the
   datatype's name.
   NOTE(review): original lines 281, 284 (presumably `let`) and 290, 293
   (presumably `end`) are missing from this listing. *)
280 fun check_nonempty descr =
282 val descr' = List.concat descr;
283 fun is_nonempty_dt is i =
285 val (_, _, constrs) = (the o AList.lookup (op =) descr') i;
286 fun arg_nonempty (_, DtRec i) = if i mem is then false
287 else is_nonempty_dt (i::is) i
288 | arg_nonempty _ = true;
289 in exists ((forall (arg_nonempty o strip_dtyp)) o snd) constrs
291 in assert_all (fn (i, _) => is_nonempty_dt [i] i) (hd descr)
292 (fn (_, (s, _, _)) => raise Datatype_Empty s)
295 (* unfold a list of mutually recursive datatype specifications *)
296 (* all types of the form DtType (dt_name, [..., DtRec _, ...]) *)
297 (* need to be unfolded *)
(* unfold_datatypes sign orig_descr sorts dt_info descr i
   Threads i (the next free datatype index) through all datatypes and
   constructors of descr.
   - Non-recursive constructor arguments are kept as they are.
   - For a recursive argument T = Us -> U:
     * if any domain in Us is recursive, typ_error reports a non-strictly
       positive occurrence;
     * if U is DtType (tname, dts) (a nested occurrence), tname's descr is
       fetched from dt_info by get_dt_descr. get_dt_descr substitutes dts
       for the type variables and shifts DtRec indices by i. That descr is
       unfolded recursively starting at i + length descr, and the argument
       is replaced by mk_fun_dtyp Us (DtRec index);
     * any other U is kept.
   Returns (unfolded descr :: descrs added for nested types, next index).
   Errors are reported via typ_error, which prints T through
   typ_of_dtyp (orig_descr @ descr) sorts.
   NOTE(review): several original lines are missing from this listing,
   including 300 (presumably `let`), 307 (the rest of the "is not a
   datatype" message), 316-317, 325 (presumably `else (case U of`),
   327 (presumably `let`), 348-349 and 351. *)
299 fun unfold_datatypes sign orig_descr sorts (dt_info : datatype_info Symtab.table) descr i =
301 fun typ_error T msg = error ("Non-admissible type expression\n" ^
302 Sign.string_of_typ sign (typ_of_dtyp (orig_descr @ descr) sorts T) ^ "\n" ^ msg);
304 fun get_dt_descr T i tname dts =
305 (case Symtab.lookup dt_info tname of
306 NONE => typ_error T (tname ^ " is not a datatype - can't use it in\
308 | (SOME {index, descr, ...}) =>
309 let val (_, vars, _) = (the o AList.lookup (op =) descr) index;
310 val subst = ((map dest_DtTFree vars) ~~ dts) handle UnequalLengths =>
311 typ_error T ("Type constructor " ^ tname ^ " used with wrong\
312 \ number of arguments")
313 in (i + index, map (fn (j, (tn, args, cs)) => (i + j,
314 (tn, map (subst_DtTFree i subst) args,
315 map (apsnd (map (subst_DtTFree i subst))) cs))) descr)
318 (* unfold a single constructor argument *)
320 fun unfold_arg ((i, Ts, descrs), T) =
321 if is_rec_type T then
322 let val (Us, U) = strip_dtyp T
323 in if exists is_rec_type Us then
324 typ_error T "Non-strictly positive recursive occurrence of type"
326 DtType (tname, dts) =>
328 val (index, descr) = get_dt_descr T i tname dts;
329 val (descr', i') = unfold_datatypes sign orig_descr sorts
330 dt_info descr (i + length descr)
331 in (i', Ts @ [mk_fun_dtyp Us (DtRec index)], descrs @ descr') end
332 | _ => (i, Ts @ [T], descrs))
334 else (i, Ts @ [T], descrs);
336 (* unfold a constructor *)
338 fun unfold_constr ((i, constrs, descrs), (cname, cargs)) =
339 let val (i', cargs', descrs') = Library.foldl unfold_arg ((i, [], descrs), cargs)
340 in (i', constrs @ [(cname, cargs')], descrs') end;
342 (* unfold a single datatype *)
344 fun unfold_datatype ((i, dtypes, descrs), (j, (tname, tvars, constrs))) =
345 let val (i', constrs', descrs') =
346 Library.foldl unfold_constr ((i, [], descrs), constrs)
347 in (i', dtypes @ [(j, (tname, tvars, constrs'))], descrs')
350 val (i', descr', descrs) = Library.foldl unfold_datatype ((i, [],[]), descr);
352 in (descr' :: descrs, i') end;