|
1 (* Title: HOL/Tools/datatype_case.ML |
|
2 ID: $Id$ |
|
3 Author: Konrad Slind, Cambridge University Computer Laboratory |
|
4 Stefan Berghofer, TU Muenchen |
|
5 |
|
6 Nested case expressions on datatypes. |
|
7 *) |
|
8 |
|
(* Interface of the translation between nested case expressions on
   datatypes and applications of datatype case combinators. *)
signature DATATYPE_CASE =
sig
  (* make_case tab ctxt err used x clauses:
     compile the (pattern, rhs) clauses matching on the scrutinee x into a
     term built from case combinators.  Names in `used` are avoided when
     inventing variables.  If err is true, redundant clauses are reported
     as an error, otherwise only as a warning.  The second component lists
     every pattern produced, tagged with (clause index, user-given flag). *)
  val make_case: (string -> DatatypeAux.datatype_info option) ->
    Proof.context -> bool -> string list -> term -> (term * term) list ->
    term * (term * (int * bool)) list

  (* dest_case tab d used t:
     undo one level of case-combinator application, returning the
     scrutinee and its (pattern, rhs) clauses, or NONE.  d selects whether
     a catch-all clause uses dummy_pattern (true) or a fresh variable
     avoiding `used` (false). *)
  val dest_case: (string -> DatatypeAux.datatype_info option) -> bool ->
    string list -> term -> (term * (term * term) list) option

  (* strip_case tab d t:
     like dest_case, but additionally flattens nested case expressions on
     pattern variables into deeper patterns. *)
  val strip_case: (string -> DatatypeAux.datatype_info option) -> bool ->
    term -> (term * (term * term) list) option

  (* Parse translation for case syntax. *)
  val case_tr: (theory -> string -> DatatypeAux.datatype_info option) ->
    Proof.context -> term list -> term

  (* Print translation for the case constant whose name is the string
     argument; raises Match when the term cannot be shown as a case. *)
  val case_tr': (theory -> string -> DatatypeAux.datatype_info option) ->
    string -> Proof.context -> term list -> term
end;
|
23 |
|
structure DatatypeCase : DATATYPE_CASE =
struct

(* Internal failure of the pattern-match compiler: a message plus the
   index of the offending clause, or ~1 when no single clause is to blame
   (see gen_make_case, which turns it into a user-level error). *)
exception CASE_ERROR of string * int;

(* Match type pattern `pat` against type `ob` from an empty type
   environment.  Failures surface as Type.TYPE_MATCH, which fresh_constr
   handles. *)
fun match_type thy pat ob = Sign.typ_match thy (pat, ob) Vartab.empty;

(*---------------------------------------------------------------------------
 * Get information about datatypes
 *---------------------------------------------------------------------------*)

(* Look up `s` via `tab` (mk_case passes a constructor name, gen_dest_case a
   case-combinator name).  On success, return the datatype's case
   combinator name and its constructors as Consts with varified types. *)
fun ty_info (tab : string -> DatatypeAux.datatype_info option) s =
  case tab s of
    SOME {descr, case_name, index, sorts, ...} =>
      let
        val (_, (tname, dts, constrs)) = nth descr index;
        val mk_ty = DatatypeAux.typ_of_dtyp descr sorts;
        val T = Type (tname, map mk_ty dts)
      in
        SOME {case_name = case_name,
          constructors = map (fn (cname, dts') =>
            Const (cname, Logic.varifyT (map mk_ty dts' ---> T))) constrs}
      end
  | NONE => NONE;


(*---------------------------------------------------------------------------
 * Each pattern carries with it a tag (i,b) where
 * i is the clause it came from and
 * b=true indicates that clause was given by the user
 * (or is an instantiation of a user supplied pattern)
 * b=false --> i = ~1
 *---------------------------------------------------------------------------*)

(* Apply f to the term part of a (term, tag) pair. *)
fun pattern_map f (tm,x) = (f tm, x);

(* Substitute free variables in the term part of a (term, tag) pair. *)
fun pattern_subst theta = pattern_map (subst_free theta);

(* Clause index i of an element (pat, (i, b)). *)
fun row_of_pat x = fst (snd x);

(* Add the free-variable names of a row's rhs, remaining patterns and
   prefix to `used`. *)
fun add_row_used ((prfx, pats), (tm, tag)) used =
  foldl add_term_free_names (foldl add_term_free_names
    (add_term_free_names (tm, used)) pats) prfx;

(* try to preserve names given by user: for each argument position keep
   the name already chosen; if none yet ("") and the argument is a Free,
   adopt that variable's name. *)
fun default_names names ts =
  map (fn ("", Free (name', _)) => name' | (name, _) => name) (names ~~ ts);

(* Peel nested _constrain wrappers off t; return the bare term and the
   list of stripped type constraints. *)
fun strip_constraints (Const ("_constrain", _) $ t $ tT) =
      strip_constraints t ||> cons tT
  | strip_constraints t = (t, []);

(* Wrap t in a _constrain to the function type built from "fun", tT and
   "dummy" (presumably: domain tT, range left unconstrained). *)
fun mk_fun_constrain tT t = Syntax.const "_constrain" $ t $
  (Syntax.free "fun" $ tT $ Syntax.free "dummy");


(*---------------------------------------------------------------------------
 * Produce an instance of a constructor, plus genvars for its arguments.
 *---------------------------------------------------------------------------*)
(* The constructor c is instantiated so that its result type matches the
   column type colty; argument variables get names from
   DatatypeProp.make_tnames made distinct from `used`.
   Raises CASE_ERROR ("type mismatch", ~1) if the types do not match. *)
fun fresh_constr ty_match ty_inst colty used c =
  let
    val (_, Ty) = dest_Const c
    val Ts = binder_types Ty;
    val names = Name.variant_list used
      (DatatypeProp.make_tnames (map Logic.unvarifyT Ts));
    val ty = body_type Ty;
    val ty_theta = ty_match ty colty handle Type.TYPE_MATCH =>
      raise CASE_ERROR ("type mismatch", ~1)
    val c' = ty_inst ty_theta c
    val gvars = map (ty_inst ty_theta o Free) (names ~~ Ts)
  in (c', gvars)
  end;


(*---------------------------------------------------------------------------
 * Goes through a list of rows and picks out the ones beginning with a
 * pattern with constructor = name.
 *---------------------------------------------------------------------------*)
(* Selected rows have their first pattern replaced by the constructor's
   arguments (with _constrain wrappers stripped).  Returns
   ((in_group, not_in_group), (names, constraints)), both row lists in the
   original order; `names`/`constraints` are collected per argument
   position.  Raises CASE_ERROR on an arity mismatch or when the head of
   the first pattern is not a constant. *)
fun mk_group (name, T) rows =
  let val k = length (binder_types T)
  in fold (fn (row as ((prfx, p :: rst), rhs as (_, (i, _)))) =>
    fn ((in_group, not_in_group), (names, cnstrts)) => (case strip_comb p of
      (Const (name', _), args) =>
        if name = name' then
          if length args = k then
            let val (args', cnstrts') = split_list (map strip_constraints args)
            in
              ((((prfx, args' @ rst), rhs) :: in_group, not_in_group),
               (default_names names args', map2 append cnstrts cnstrts'))
            end
          else raise CASE_ERROR
            ("Wrong number of arguments for constructor " ^ name, i)
        else ((in_group, row :: not_in_group), (names, cnstrts))
    | _ => raise CASE_ERROR ("Not a constructor pattern", i)))
    rows (([], []), (replicate k "", replicate k [])) |>> pairself rev
  end;

(*---------------------------------------------------------------------------
 * Partition the rows. Not efficient: we should use hashing.
 *---------------------------------------------------------------------------*)
(* Produce one subproblem per constructor, in the order of `constructors`.
   A constructor that no row mentions gets a single row whose rhs is
   HOL.undefined, tagged (~1, false).  Rows left over after all
   constructors are consumed raise CASE_ERROR. *)
fun partition _ _ _ _ _ _ _ [] = raise CASE_ERROR ("partition: no rows", ~1)
  | partition ty_match ty_inst type_of used constructors colty res_ty
        (rows as (((prfx, _ :: rstp), _) :: _)) =
      let
        fun part {constrs = [], rows = [], A} = rev A
          | part {constrs = [], rows = (_, (_, (i, _))) :: _, A} =
              raise CASE_ERROR ("Not a constructor pattern", i)
          | part {constrs = c :: crst, rows, A} =
              let
                val ((in_group, not_in_group), (names, cnstrts)) =
                  mk_group (dest_Const c) rows;
                val used' = fold add_row_used in_group used;
                val (c', gvars) = fresh_constr ty_match ty_inst colty used' c;
                val in_group' =
                  if null in_group  (* Constructor not given *)
                  then
                    let
                      val Ts = map type_of rstp;
                      val xs = Name.variant_list
                        (foldl add_term_free_names used' gvars)
                        (replicate (length rstp) "x")
                    in
                      [((prfx, gvars @ map Free (xs ~~ Ts)),
                        (Const ("HOL.undefined", res_ty), (~1, false)))]
                    end
                  else in_group
              in
                part {constrs = crst,
                  rows = not_in_group,
                  A = {constructor = c',
                    new_formals = gvars,
                    names = names,
                    constraints = cnstrts,
                    group = in_group'} :: A}
              end
      in part {constrs = constructors, rows = rows, A = []}
      end;

(*---------------------------------------------------------------------------
 * Misc. routines used in mk_case
 *---------------------------------------------------------------------------*)

(* Rebuild constructor patterns: in every (prfx, tag, plist) of l, fold the
   first L patterns (L = arity of c) back into an application of c'. *)
fun mk_pat ((c, c'), l) =
  let
    val L = length (binder_types (fastype_of c))
    fun build (prfx, tag, plist) =
      let val (args, plist') = chop L plist
      in (prfx, tag, list_comb (c', args) :: plist') end
  in map build l end;

(* Move the first pattern of a row onto its prefix. *)
fun v_to_prfx (prfx, v::pats) = (v::prfx,pats)
  | v_to_prfx _ = raise CASE_ERROR ("mk_case: v_to_prfx", ~1);

(* Inverse of v_to_prfx: move the head of the prefix back onto the patterns. *)
fun v_to_pats (v::prfx,tag, pats) = (prfx, tag, v::pats)
  | v_to_pats _ = raise CASE_ERROR ("mk_case: v_to_pats", ~1);


(*----------------------------------------------------------------------------
 * Translation of pattern terms into nested case expressions.
 *
 * This performs the translation and also builds the full set of patterns.
 * Thus it supports the construction of induction theorems even when an
 * incomplete set of patterns is given.
 *---------------------------------------------------------------------------*)
(* Returns a function over {path, rows}: `path` is the list of terms still
   to be matched and `rows` the row matrix.  Its result is the list of
   (prefix, tag, patterns) produced and the compiled case term. *)
fun mk_case tab ctxt ty_match ty_inst type_of used range_ty =
  let
    (* fallback name for abstractions with no user-given variable name *)
    val name = Name.variant used "a";
    (* A row whose first pattern is a variable is replicated once per
       constructor, substituting the constructor instance into the rhs. *)
    fun expand constructors used ty ((_, []), _) =
          raise CASE_ERROR ("mk_case: expand_var_row", ~1)
      | expand constructors used ty (row as ((prfx, p :: rst), rhs)) =
          if is_Free p then
            let
              val used' = add_row_used row used;
              fun expnd c =
                let val capp =
                  list_comb (fresh_constr ty_match ty_inst ty used' c)
                in ((prfx, capp :: rst), pattern_subst [(p, capp)] rhs)
                end
            in map expnd constructors end
          else [row]
    fun mk {rows = [], ...} = raise CASE_ERROR ("no rows", ~1)
      | mk {path = [], rows = ((prfx, []), (tm, tag)) :: _} =  (* Done *)
          ([(prfx, tag, [])], tm)
      (* a single-variable first row matches everything: later rows are
         dropped (gen_make_case reports them as redundant) *)
      | mk {path, rows as ((row as ((_, [Free _]), _)) :: _ :: _)} =
          mk {path = path, rows = [row]}
      | mk {path = u :: rstp, rows as ((_, _ :: _), _) :: _} =
          let val col0 = map (fn ((_, p :: _), (_, (i, _))) => (p, i)) rows
          in case Option.map (apfst head_of)
              (find_first (not o is_Free o fst) col0) of
              NONE =>
                (* first column consists of variables only: bind them to u *)
                let
                  val rows' = map (fn ((v, _), row) => row ||>
                    pattern_subst [(v, u)] |>> v_to_prfx) (col0 ~~ rows);
                  val (pref_patl, tm) = mk {path = rstp, rows = rows'}
                in (map v_to_pats pref_patl, tm) end
            | SOME (Const (cname, cT), i) => (case ty_info tab cname of
                NONE => raise CASE_ERROR ("Not a datatype constructor: " ^ cname, i)
              | SOME {case_name, constructors} =>
                let
                  val pty = body_type cT;
                  val used' = foldl add_term_free_names used rstp;
                  val nrows = maps (expand constructors used' pty) rows;
                  val subproblems = partition ty_match ty_inst type_of used'
                    constructors pty range_ty nrows;
                  val constructors' = map #constructor subproblems
                  val news = map (fn {new_formals, group, ...} =>
                    {path = new_formals @ rstp, rows = group}) subproblems;
                  val (pat_rect, dtrees) = split_list (map mk news);
                  (* abstract each subtree over its constructor's argument
                     variables, reattaching any user type constraints *)
                  val case_functions = map2
                    (fn {new_formals, names, constraints, ...} =>
                       fold_rev (fn ((x as Free (_, T), s), cnstrts) => fn t =>
                         Abs (if s = "" then name else s, T,
                           abstract_over (x, t)) |>
                         fold mk_fun_constrain cnstrts)
                           (new_formals ~~ names ~~ constraints))
                    subproblems dtrees;
                  val types = map type_of (case_functions @ [u]);
                  val case_const = Const (case_name, types ---> range_ty)
                  val tree = list_comb (case_const, case_functions @ [u])
                  val pat_rect1 = maps mk_pat
                    (constructors ~~ constructors' ~~ pat_rect)
                in (pat_rect1, tree)
                end)
            | SOME (t, i) => raise CASE_ERROR ("Not a datatype constructor: " ^
                ProofContext.string_of_term ctxt t, i)
          end
      | mk _ = raise CASE_ERROR ("Malformed row matrix", ~1)
  in mk
  end;

(* Report a user-level error in a case expression. *)
fun case_error s = error ("Error in case expression:\n" ^ s);

(* Repeated variable occurrences in a pattern are not allowed. *)
fun no_repeat_vars ctxt pat = fold_aterms
  (fn x as Free (s, _) => (fn xs =>
      if member op aconv xs x then
        case_error (quote s ^ " occurs repeatedly in the pattern " ^
          quote (ProofContext.string_of_term ctxt pat))
      else x :: xs)
    | _ => I) pat [];

(* Common driver for make_case / make_case_untyped.  Compiles `clauses`
   matching on x; CASE_ERRORs become user errors that quote the offending
   clause.  Clauses whose index never reaches a final pattern are
   redundant: an error if `err`, otherwise a warning.  Returns the case
   term and the patterns sorted by clause index. *)
fun gen_make_case ty_match ty_inst type_of tab ctxt err used x clauses =
  let
    fun string_of_clause (pat, rhs) = ProofContext.string_of_term ctxt
      (Syntax.const "_case1" $ pat $ rhs);
    val _ = map (no_repeat_vars ctxt o fst) clauses;
    val rows = map_index (fn (i, (pat, rhs)) =>
      (([], [pat]), (rhs, (i, true)))) clauses;
    val rangeT = (case distinct op = (map (type_of o snd) clauses) of
        [] => case_error "no clauses given"
      | [T] => T
      | _ => case_error "all cases must have the same result type");
    val used' = fold add_row_used rows used;
    val (patts, case_tm) = mk_case tab ctxt ty_match ty_inst type_of
        used' rangeT {path = [x], rows = rows}
      handle CASE_ERROR (msg, i) => case_error (msg ^
        (if i < 0 then ""
         else "\nIn clause\n" ^ string_of_clause (nth clauses i)));
    val patts1 = map
      (fn (_, tag, [pat]) => (pat, tag)
        | _ => case_error "error in pattern-match translation") patts;
    val patts2 = Library.sort (Library.int_ord o Library.pairself row_of_pat) patts1
    val finals = map row_of_pat patts2
    val originals = map (row_of_pat o #2) rows
    val _ = case originals \\ finals of
        [] => ()
      | is => (if err then case_error else warning)
          ("The following clauses are redundant (covered by preceding clauses):\n" ^
           space_implode "\n" (map (string_of_clause o nth clauses) is));
  in
    (case_tm, patts2)
  end;

(* Typed variant: real type matching and instantiation. *)
fun make_case tab ctxt = gen_make_case
  (match_type (ProofContext.theory_of ctxt)) Envir.subst_TVars fastype_of tab ctxt;
(* Untyped variant used during parsing: every type is dummyT. *)
val make_case_untyped = gen_make_case (K (K Vartab.empty))
  (K (Term.map_types (K dummyT))) (K dummyT);


(* parse translation *)

(* Translate `_case_syntax`-style arguments [t, u]: t is the scrutinee and
   u the _case2 chain of _case1 clauses.  Type constraints written on
   patterns are moved onto the scrutinee. *)
fun case_tr tab_of ctxt [t, u] =
    let
      val thy = ProofContext.theory_of ctxt;
      (* replace occurrences of dummy_pattern by distinct variables *)
      (* internalize constant names *)
      fun prep_pat ((c as Const ("_constrain", _)) $ t $ tT) used =
            let val (t', used') = prep_pat t used
            in (c $ t' $ tT, used') end
        | prep_pat (Const ("dummy_pattern", T)) used =
            let val x = Name.variant used "x"
            in (Free (x, T), x :: used) end
        | prep_pat (Const (s, T)) used =
            (case try (unprefix Syntax.constN) s of
               SOME c => (Const (c, T), used)
             | NONE => (Const (Sign.intern_const thy s, T), used))
        | prep_pat (v as Free (s, T)) used =
            let val s' = Sign.intern_const thy s
            in
              if Sign.declared_const thy s' then
                (Const (s', T), used)
              else (v, used)
            end
        | prep_pat (t $ u) used =
            let
              val (t', used') = prep_pat t used;
              val (u', used'') = prep_pat u used'
            in
              (t' $ u', used'')
            end
        | prep_pat t used = case_error ("Bad pattern: " ^
            ProofContext.string_of_term ctxt t);
      fun dest_case1 (t as Const ("_case1", _) $ l $ r) =
            let val (l', cnstrts) = strip_constraints l
            in ((fst (prep_pat l' (add_term_free_names (t, []))), r), cnstrts)
            end
        | dest_case1 t = case_error "dest_case1";
      fun dest_case2 (Const ("_case2", _) $ t $ u) = t :: dest_case2 u
        | dest_case2 t = [t];
      val (cases, cnstrts) = split_list (map dest_case1 (dest_case2 u));
      val (case_tm, _) = make_case_untyped (tab_of thy) ctxt true []
        (fold (fn tT => fn t => Syntax.const "_constrain" $ t $ tT)
          (flat cnstrts) t) cases;
    in case_tm end
  | case_tr _ _ ts = case_error "case_tr";


(*---------------------------------------------------------------------------
 * Pretty printing of nested case expressions
 *---------------------------------------------------------------------------*)

(* destruct one level of pattern matching *)

(* For t = case_name f_1 ... f_n x, return SOME (x, clauses) with one
   (pattern, rhs) clause per constructor.  A group of constructors sharing
   the same non-dependent body may be collapsed into one catch-all clause
   (dummy_pattern if d, else a fresh variable avoiding `used`), and
   HOL.undefined branches are omitted.  Returns NONE if t is not such an
   application or a branch has fewer binders than its constructor. *)
fun gen_dest_case name_of type_of tab d used t =
  case apfst name_of (strip_comb t) of
    (SOME cname, ts as _ :: _) =>
      let
        val (fs, x) = split_last ts;
        (* open the first i binders of t as fresh Free variables *)
        fun strip_abs i t =
          let
            val zs = strip_abs_vars t;
            val _ = if length zs < i then raise CASE_ERROR ("", 0) else ();
            val (xs, ys) = chop i zs;
            val u = list_abs (ys, strip_abs_body t);
            val xs' = map Free (Name.variant_list (add_term_names (u, used))
              (map fst xs) ~~ map snd xs)
          in (xs', subst_bounds (rev xs', u)) end;
        (* does the body refer to any of the first i bound variables? *)
        fun is_dependent i t =
          let val k = length (strip_abs_vars t) - i
          in k < 0 orelse exists (fn j => j >= k)
            (loose_bnos (strip_abs_body t))
          end;
        fun count_cases (_, _, true) = I
          | count_cases (c, (_, body), false) =
              AList.map_default op aconv (body, []) (cons c);
        val is_undefined = name_of #> equal (SOME "HOL.undefined");
        fun mk_case (c, (xs, body), _) = (list_comb (c, xs), body)
      in case ty_info tab cname of
          SOME {constructors, case_name} =>
            if length fs = length constructors then
              let
                val cases = map (fn (Const (s, U), t) =>
                  let
                    val k = length (binder_types U);
                    val p as (xs, _) = strip_abs k t
                  in
                    (Const (s, map type_of xs ---> type_of x),
                     p, is_dependent k t)
                  end) (constructors ~~ fs);
                val cases' = sort (int_ord o swap o pairself (length o snd))
                  (fold_rev count_cases cases []);
                (* The catch-all pattern replaces constructor patterns on the
                   scrutinee x, so it must have x's type -- not the type of
                   the whole case expression t. *)
                val xT = type_of x;
                val dummy = if d then Const ("dummy_pattern", xT)
                  else Free (Name.variant used "x", xT)
              in
                SOME (x, map mk_case (case find_first (is_undefined o fst) cases' of
                    SOME (_, cs) =>
                    if length cs = length constructors then [hd cases]
                    else filter_out (fn (_, (_, body), _) => is_undefined body) cases
                  | NONE => case cases' of
                    [] => cases
                  | (default, cs) :: _ =>
                    if length cs = 1 then cases
                    else if length cs = length constructors then
                      [hd cases, (dummy, ([], default), false)]
                    else
                      filter_out (fn (c, _, _) => member op aconv cs c) cases @
                        [(dummy, ([], default), false)]))
              end handle CASE_ERROR _ => NONE
            else NONE
        | _ => NONE
      end
  | _ => NONE;

(* Logical terms: real constant names and types. *)
val dest_case = gen_dest_case (try (dest_Const #> fst)) fastype_of;
(* Syntax terms: Syntax.constN-prefixed names, types ignored. *)
val dest_case' = gen_dest_case
  (try (dest_Const #> fst #> unprefix Syntax.constN)) (K dummyT);


(* destruct nested patterns *)

(* If rhs is itself a case on a variable `exp` of pat that no nested rhs
   mentions, substitute each nested pattern for exp in pat and recurse. *)
fun strip_case' dest (pat, rhs) =
  case dest (add_term_free_names (pat, [])) rhs of
    SOME (exp as Free _, clauses) =>
      if member op aconv (term_frees pat) exp andalso
        not (exists (fn (_, rhs') =>
          member op aconv (term_frees rhs') exp) clauses)
      then
        maps (strip_case' dest) (map (fn (pat', rhs') =>
          (subst_free [(exp, pat')] pat, rhs')) clauses)
      else [(pat, rhs)]
  | _ => [(pat, rhs)];

fun gen_strip_case dest t = case dest [] t of
    SOME (x, clauses) =>
      SOME (x, maps (strip_case' dest) clauses)
  | NONE => NONE;

val strip_case = gen_strip_case oo dest_case;
(* NB: shadows the function strip_case' above (already captured by
   gen_strip_case); case_tr' below uses this value. *)
val strip_case' = gen_strip_case oo dest_case';


(* print translation *)

(* Print an application of case constant `cname` to `ts` as case syntax;
   raises Match when it cannot be destructed. *)
fun case_tr' tab_of cname ctxt ts =
  let
    val thy = ProofContext.theory_of ctxt;
    val consts = ProofContext.consts_of ctxt;
    (* pattern variables are marked bound in both pattern and rhs *)
    fun mk_clause (pat, rhs) =
      let val xs = term_frees pat
      in
        Syntax.const "_case1" $
          map_aterms
            (fn Free p => Syntax.mark_boundT p
              | Const (s, _) => Const (Consts.extern_early consts s, dummyT)
              | t => t) pat $
          map_aterms
            (fn x as Free (s, _) =>
                  if member op aconv xs x then Syntax.mark_bound s else x
              | t => t) rhs
      end
  in case strip_case' (tab_of thy) true (list_comb (Syntax.const cname, ts)) of
      SOME (x, clauses) => Syntax.const "_case_syntax" $ x $
        foldr1 (fn (t, u) => Syntax.const "_case2" $ t $ u)
          (map mk_clause clauses)
    | NONE => raise Match
  end;

end;