|
(*  Title:      Tools/case_translation.ML
    Author:     Konrad Slind, Cambridge University Computer Laboratory
    Author:     Stefan Berghofer, TU Muenchen
    Author:     Dmitriy Traytel, TU Muenchen

Nested case expressions via a generic data slot for case combinators and constructors.
*)
|
8 |
|
(*Interface of the nested case expression machinery.*)
signature CASE_TRANSLATION =
sig
  (*reaction of make_case to redundant clauses*)
  datatype config = Error | Warning | Quiet

  (*syntax translation*)
  val case_tr: Proof.context -> term list -> term

  (*lookup of registered (case combinator, constructors) pairs*)
  val lookup_by_constr: Proof.context -> string * typ -> (term * term list) option
  val lookup_by_constr_permissive: Proof.context -> string * typ -> (term * term list) option
  val lookup_by_case: Proof.context -> string -> (term * term list) option

  (*construction and destruction of nested case expressions*)
  val make_case: Proof.context -> config -> Name.context -> term -> (term * term) list -> term
  val print_case_translations: Proof.context -> unit
  val strip_case: Proof.context -> bool -> term -> term

  (*configuration option consulted by the uncheck phase*)
  val show_cases: bool Config.T

  (*installation and registration*)
  val setup: theory -> theory
  val register: term -> term list -> Context.generic -> Context.generic
end;
|
23 |
|
24 structure Case_Translation: CASE_TRANSLATION = |
|
25 struct |
|
26 |
|
27 (** data management **) |
|
28 |
|
(*Registered case information:
    constrs: constructor name -> list of (type name, (case combinator, constructors))
    cases:   case combinator name -> (case combinator, constructors)*)
datatype data = Data of
 {constrs: (string * (term * term list)) list Symtab.table,
  cases: (term * term list) Symtab.table};

(*Pack a pair of tables (constructor table, case table) into a data record.*)
fun make_data (constr_tab, case_tab) =
  Data {constrs = constr_tab, cases = case_tab};
|
34 |
|
(*Generic data slot holding the registered case information.  Merging joins
  the per-constructor entry lists and merges the case combinator table,
  accepting clashing entries without complaint (K true).*)
structure Data = Generic_Data
(
  type T = data;
  val empty = make_data (Symtab.empty, Symtab.empty);
  val extend = I;
  fun merge (Data left, Data right) =
    let
      val constrs =
        Symtab.join (K (AList.merge (op =) (K true))) (#constrs left, #constrs right);
      val cases = Symtab.merge (K true) (#cases left, #cases right);
    in make_data (constrs, cases) end;
);
|
47 |
|
(*Apply f to the (constrs, cases) pair of tables stored in a generic context.*)
fun map_data f =
  Data.map (fn Data tabs => make_data (f (#constrs tabs, #cases tabs)));

(*Update only the constructor table.*)
fun map_constrs f = map_data (apfst f);

(*Update only the case combinator table.*)
fun map_cases f = map_data (apsnd f);
|
52 |
|
(*Record of tables stored for a proof context.*)
fun rep_data ctxt =
  (case Data.get (Context.Proof ctxt) of Data tabs => tabs);
|
54 |
|
(*Type of the term analysed by a case combinator: skip one argument type per
  constructor, then take the domain of what remains.*)
fun T_of_data (comb, constrs) =
  domain_type (funpow (length constrs) range_type (fastype_of comb));

(*Name of the type constructor of T_of_data.*)
fun Tname_of_data data = fst (dest_Type (T_of_data data));
|
61 |
|
(*Constructor table of a proof context.*)
fun constrs_of ctxt = #constrs (rep_data ctxt);

(*Case combinator table of a proof context.*)
fun cases_of ctxt = #cases (rep_data ctxt);
|
64 |
|
(*Strict lookup: among the entries registered for constructor name c, pick the
  one whose type name equals the head of the body type of T; NONE if the body
  type is not a Type or no entry matches.*)
fun lookup_by_constr ctxt (c, T) =
  (case body_type T of
    Type (tyco, _) =>
      AList.lookup (op =) (Symtab.lookup_list (constrs_of ctxt) c) tyco
  | _ => NONE);
|
73 |
|
(*Permissive lookup: like lookup_by_constr, but if the body type of T gives no
  type name, or no entry is registered under that type name, fall back to the
  last entry in the list stored for constructor name c (if any).*)
fun lookup_by_constr_permissive ctxt (c, T) =
  let
    val entries = Symtab.lookup_list (constrs_of ctxt) c;
    (*conservative wrt. overloaded constructors*)
    val fallback =
      (case entries of
        [] => NONE
      | _ => SOME (snd (List.last entries)));
  in
    (case body_type T of
      Type (tyco, _) =>
        (case AList.lookup (op =) entries tyco of
          NONE => fallback  (*permissive*)
        | found => found)
    | _ => fallback)
  end;
|
88 |
|
(*Look up the (case combinator, constructors) pair by case combinator name.*)
fun lookup_by_case ctxt = Symtab.lookup (cases_of ctxt);
|
90 |
|
91 |
|
92 (** installation **) |
|
93 |
|
(*Raise a user-level error with the standard case-expression prefix.*)
fun case_error msg = error ("Error in case expression:\n" ^ msg);

(*Name of a constant; NONE if dest_Const fails on the term.*)
fun name_of t = try (fst o dest_Const) t;
|
97 |
|
98 (* parse translation *) |
|
99 |
|
(*Wrap t in a "_constrainAbs" syntax node carrying the type constraint tT.*)
fun constrain_Abs tT t =
  let val marker = Syntax.const @{syntax_const "_constrainAbs"}
  in marker $ t $ tT end;
|
101 |
|
(*Parse translation: turn "_case_syntax" applied to a scrutinee and a
  "_case2"-separated sequence of "_case1" clauses into a case_cons/case_nil
  list of clauses applied to the scrutinee.  In each clause, every free
  variable of the pattern that is not a declared constant is bound by a
  case_abs abstraction around a case_elem node.*)
fun case_tr ctxt [scrutinee, clauses] =
      let
        val thy = Proof_Context.theory_of ctxt;

        fun is_const s =
          Sign.declared_const thy (Proof_Context.intern_const ctxt s);

        (*bind variable v in body, re-attaching the collected type constraints*)
        fun bind_var v constraints body =
          Syntax.const @{const_syntax case_abs} $
            fold constrain_Abs constraints (absfree v body);

        (*walk the pattern; arguments are handled before function positions*)
        fun bind_pattern (Const ("_constrain", _) $ p $ tT) tTs body =
              bind_pattern p (tT :: tTs) body
          | bind_pattern (Free (v as (x, _))) tTs body =
              if is_const x then body else bind_var v tTs body
          | bind_pattern (p $ q) _ body = bind_pattern p [] (bind_pattern q [] body)
          | bind_pattern _ _ body = body;

        fun dest_case1 (Const (@{syntax_const "_case1"}, _) $ lhs $ rhs) =
              bind_pattern lhs []
                (Syntax.const @{const_syntax case_elem} $
                  Term_Position.strip_positions lhs $ rhs)
          | dest_case1 _ = case_error "dest_case1";

        fun dest_case2 (Const (@{syntax_const "_case2"}, _) $ first $ rest) =
              first :: dest_case2 rest
          | dest_case2 last = [last];

        fun add_clause clause rest =
          Syntax.const @{const_syntax case_cons} $ dest_case1 clause $ rest;
      in
        fold_rev add_clause (dest_case2 clauses)
          (Syntax.const @{const_syntax case_nil}) $ scrutinee
      end
  | case_tr _ _ = case_error "case_tr";
|
133 |
|
(*Register case_tr as advanced parse translation for "_case_syntax".*)
val trfun_setup =
  let val parse_translations = [(@{syntax_const "_case_syntax"}, case_tr)]
  in Sign.add_advanced_trfuns ([], parse_translations, [], []) end;
|
138 |
|
139 |
|
140 (* print translation *) |
|
141 |
|
(*Print translation for case_cons: rebuild "_case_syntax" from the first
  clause, the remaining clause list and the scrutinee.  The pattern matches
  are deliberately left non-exhaustive, so unexpected input raises Match.*)
fun case_tr' [first, rest, scrutinee] =
  let
    (*strip case_abs binders, collecting variant names, then emit "_case1"*)
    fun mk_clause (Const (@{const_syntax case_abs}, _) $ Abs (s, T, body)) bound used =
          let val (s', used') = Name.variant s used
          in mk_clause body ((s', T) :: bound) used' end
      | mk_clause (Const (@{const_syntax case_elem}, _) $ pat $ rhs) bound _ =
          Syntax.const @{syntax_const "_case1"} $
            subst_bounds (map Syntax_Trans.mark_bound_abs bound, pat) $
            subst_bounds (map Syntax_Trans.mark_bound_body bound, rhs);

    fun clause_of t = mk_clause t [] (Term.declare_term_frees t Name.context);

    fun mk_clauses (Const (@{const_syntax case_nil}, _)) = []
      | mk_clauses (Const (@{const_syntax case_cons}, _) $ t $ u) =
          clause_of t :: mk_clauses u;

    fun join (clause, more) = Syntax.const @{syntax_const "_case2"} $ clause $ more;
  in
    Syntax.const @{syntax_const "_case_syntax"} $ scrutinee $
      foldr1 join (clause_of first :: mk_clauses rest)
  end;
|
163 |
|
(*Register case_tr' as print translation for case_cons.*)
val trfun_setup' =
  let val print_translations = [(@{const_syntax "case_cons"}, case_tr')]
  in Sign.add_trfuns ([], [], print_translations, []) end;
|
166 |
|
167 |
|
168 (* declarations *) |
|
169 |
|
(*Register a case combinator with its constructors: after generalisation via
  Variable.polymorphic, the pair is stored under the combinator name in the
  case table and, tagged with the type name, under each constructor name in
  the constructor table.*)
fun register raw_case_comb raw_constrs context =
  let
    val ctxt = Context.proof_of context;
    val case_comb = singleton (Variable.polymorphic ctxt) raw_case_comb;
    val constrs = Variable.polymorphic ctxt raw_constrs;
    val case_key = fst (dest_Const case_comb);
    val constr_keys = map (fst o dest_Const) constrs;
    val info = (case_comb, constrs);
    val entry = (Tname_of_data info, info);
    fun add_constr key = Symtab.cons_list (key, entry);
  in
    map_cases (Symtab.update (case_key, info))
      (map_constrs (fold add_constr constr_keys) context)
  end;
|
186 |
|
187 |
|
188 (* (Un)check phases *) |
|
189 |
|
(*How make_case reports redundant clauses: as an error, as a warning, or not
  at all.*)
datatype config = Error | Warning | Quiet;

(*Internal failure: message plus the tag of the offending clause (negative if
  no clause can be blamed).*)
exception CASE_ERROR of string * int;
|
193 |
|
(*Match type pattern pat against ob in the theory of ctxt, starting from the
  empty type environment.*)
fun match_type ctxt pat ob =
  let val thy = Proof_Context.theory_of ctxt
  in Sign.typ_match thy (pat, ob) Vartab.empty end;
|
196 |
|
197 |
|
(*Each pattern carries with it a tag i, which denotes the clause it
  came from. i = ~1 indicates that the clause was added by pattern
  completion.*)
|
201 |
|
(*Declare the free variables of a row (right-hand side, patterns and the
  variables of its prefix) in a name context.*)
fun add_row_used ((prfx, pats), (tm, _)) =
  let val row_terms = tm :: (pats @ map Free prfx)
  in fold Term.declare_term_frees row_terms end;
|
204 |
|
205 (* try to preserve names given by user *) |
|
(*Keep a non-empty name; for an empty one take over the name of a Free.*)
fun default_name name t =
  (case (name, t) of
    ("", Free (name', _)) => name'
  | _ => name);
|
208 |
|
209 |
|
210 (*Produce an instance of a constructor, plus fresh variables for its arguments.*) |
|
(*Instance of constructor c at column type colty, paired with fresh free
  variables for its arguments (names are variants with respect to used).
  Raises CASE_ERROR if the constructor's result type does not match colty.*)
fun fresh_constr ctxt colty used c =
  let
    val T = snd (dest_Const c);
    val arg_Ts = binder_types T;
    val arg_names =
      fst (fold_map Name.variant
        (Datatype_Prop.make_tnames (map Logic.unvarifyT_global arg_Ts)) used);
    val theta = match_type ctxt (body_type T) colty
      handle Type.TYPE_MATCH => raise CASE_ERROR ("type mismatch", ~1);
    val instantiate = Envir.subst_term_types theta;
  in (instantiate c, map (instantiate o Free) (arg_names ~~ arg_Ts)) end;
|
223 |
|
(*Go through a list of rows and pick out the ones beginning with a
  pattern with constructor = name.*)
|
(*Split rows by the head of their first pattern.  Rows whose first pattern is
  constructor name applied to the right number of arguments go into the first
  list, with the constructor arguments prepended to the remaining patterns;
  all other constructor rows go into the second list.  Both lists keep the
  original row order.  Also returns one name per constructor argument, taken
  from the first matching row that has a free variable in that position.

  Raises CASE_ERROR for a wrong argument count or a non-constructor pattern
  (tagged with the clause), and for a row without patterns.*)
fun mk_group (name, T) rows =
  let
    val k = length (binder_types T);

    fun add_row (row as ((prfx, p :: ps), rhs as (_, i))) ((in_group, not_in_group), names) =
          (case strip_comb p of
            (Const (name', _), args) =>
              if name = name' then
                if length args = k then
                  ((((prfx, args @ ps), rhs) :: in_group, not_in_group),
                   map2 default_name names args)
                else raise CASE_ERROR ("Wrong number of arguments for constructor " ^ quote name, i)
              else ((in_group, row :: not_in_group), names)
          | _ => raise CASE_ERROR ("Not a constructor pattern", i))
        (*internal error, reported like in expand and v_to_prfx instead of a raw Match*)
      | add_row ((_, []), _) _ = raise CASE_ERROR ("mk_group: row without patterns", ~1);
  in
    fold add_row rows (([], []), replicate k "") |>> pairself rev
  end;
|
241 |
|
242 |
|
243 (* Partitioning *) |
|
244 |
|
(*Partition the rows by the constructor heading their first pattern, producing
  one subproblem per constructor (in the order of constructors).  Each
  subproblem records the instantiated constructor, fresh variables for its
  arguments, suggested argument names, and its group of rows.  A constructor
  matched by no row gets a single row with fresh variables and right-hand side
  undefined, tagged ~1.

  Raises CASE_ERROR if there are no rows, if rows are left over after all
  constructors have been processed, or if the first row has no patterns.*)
fun partition _ _ _ _ _ [] = raise CASE_ERROR ("partition: no rows", ~1)
  | partition ctxt used constructors colty res_ty
        (rows as (((prfx, _ :: ps), _) :: _)) =
      let
        (*row for a constructor that is not mentioned by any clause*)
        fun missing_row gvars used' =
          let
            val Ts = map fastype_of ps;
            val (xs, _) =
              fold_map Name.variant
                (replicate (length ps) "x")
                (fold Term.declare_term_frees gvars used');
          in
            ((prfx, gvars @ map Free (xs ~~ Ts)),
             (Const (@{const_name undefined}, res_ty), ~1))
          end;

        fun part [] [] = []
          | part [] ((_, (_, i)) :: _) = raise CASE_ERROR ("Not a constructor pattern", i)
          | part (c :: cs) rows =
              let
                val ((in_group, not_in_group), names) = mk_group (dest_Const c) rows;
                val used' = fold add_row_used in_group used;
                val (c', gvars) = fresh_constr ctxt colty used' c;
                val group =
                  if null in_group then [missing_row gvars used']
                  else in_group;
              in
                {constructor = c',
                 new_formals = gvars,
                 names = names,
                 group = group} :: part cs not_in_group
              end;
      in part constructors rows end
    (*internal error, reported as CASE_ERROR instead of a raw Match*)
  | partition _ _ _ _ _ _ = raise CASE_ERROR ("partition: malformed rows", ~1);
|
277 |
|
(*Move a leading variable pattern onto the prefix of the row.*)
fun v_to_prfx (prfx, pats) =
  (case pats of
    Free v :: rest => (v :: prfx, rest)
  | _ => raise CASE_ERROR ("mk_case: v_to_prfx", ~1));
|
280 |
|
281 |
|
282 (* Translation of pattern terms into nested case expressions. *) |
|
283 |
|
(*Translate a matrix of pattern rows into a nested case term.  The result is
  a function taking the list of terms being analysed (one per pattern column)
  and the rows; it returns the tags of the clauses that ended up in the tree
  together with the tree itself.*)
fun mk_case ctxt used range_ty =
  let
    val get_info = lookup_by_constr_permissive ctxt;

    (*a row starting with a variable becomes one row per constructor, with the
      variable replaced by the constructor applied to fresh variables*)
    fun expand _ _ _ ((_, []), _) = raise CASE_ERROR ("mk_case: expand", ~1)
      | expand constructors used ty (row as ((prfx, p :: ps), (rhs, tag))) =
          if not (is_Free p) then [row]
          else
            let
              val used' = add_row_used row used;
              fun instantiate c =
                let val capp = list_comb (fresh_constr ctxt ty used' c)
                in ((prfx, capp :: ps), (subst_free [(p, capp)] rhs, tag)) end;
            in map instantiate constructors end;

    (*binder name for constructor arguments without a suggested name*)
    val (name, _) = Name.variant "a" used;

    fun first_pattern ((_, p :: _), (_, i)) = (p, i);

    fun mk _ [] = raise CASE_ERROR ("no rows", ~1)
      | mk [] (((_, []), (tm, tag)) :: _) = ([tag], tm)  (* Done *)
        (*a single variable pattern in the first row: later rows are dropped*)
      | mk path ((row as ((_, [Free _]), _)) :: _ :: _) = mk path [row]
      | mk (u :: us) (rows as ((_, _ :: _), _) :: _) =
          let val col0 = map first_pattern rows in
            (case Option.map (apfst head_of) (find_first (not o is_Free o fst) col0) of
              NONE =>
                (*variables only: substitute the analysed term and move on*)
                let
                  val rows' =
                    map (fn ((v, _), (lhs, (rhs, tag))) =>
                      (v_to_prfx lhs, (subst_free [(v, u)] rhs, tag))) (col0 ~~ rows);
                in mk us rows' end
            | SOME (Const (cname, cT), i) =>
                (case get_info (cname, cT) of
                  NONE => raise CASE_ERROR ("Not a datatype constructor: " ^ quote cname, i)
                | SOME (case_comb, constructors) =>
                    let
                      val pty = body_type cT;
                      val used' = fold Term.declare_term_frees us used;
                      val nrows = maps (expand constructors used' pty) rows;
                      val subproblems =
                        partition ctxt used' constructors pty range_ty nrows;
                      val (pat_rect, dtrees) =
                        subproblems
                        |> map (fn {new_formals, group, ...} => mk (new_formals @ us) group)
                        |> split_list;
                      fun abstract (x as Free (_, T), s) body =
                        Abs (if s = "" then name else s, T, abstract_over (x, body));
                      val case_functions =
                        map2 (fn {new_formals, names, ...} =>
                          fold_rev abstract (new_formals ~~ names))
                        subproblems dtrees;
                      val args = case_functions @ [u];
                      val types = map fastype_of args;
                      val case_const = Const (the (name_of case_comb), types ---> range_ty);
                    in (flat pat_rect, list_comb (case_const, args)) end)
            | SOME (t, i) =>
                raise CASE_ERROR ("Not a datatype constructor: " ^ Syntax.string_of_term ctxt t, i))
          end
      | mk _ _ = raise CASE_ERROR ("Malformed row matrix", ~1)
  in mk end;
|
341 |
|
342 |
|
343 (* replace occurrences of dummy_pattern by distinct variables *) |
|
(*Replace each dummy_pattern constant by a Free whose name is a variant of "x",
  threading the name context; applications are traversed left to right and
  all other terms are returned unchanged.*)
fun replace_dummies t used =
  (case t of
    Const (@{const_name dummy_pattern}, T) =>
      Name.variant "x" used |>> (fn x => Free (x, T))
  | f $ a =>
      let
        val (f', used1) = replace_dummies f used;
        val (a', used2) = replace_dummies a used1;
      in (f' $ a', used2) end
  | _ => (t, used));
|
353 |
|
354 (*Repeated variable occurrences in a pattern are not allowed.*) |
|
(*Collect the free variables of pat, failing with case_error as soon as one
  of them is encountered a second time.*)
fun no_repeat_vars ctxt pat =
  let
    fun check (x as Free (s, _)) seen =
          if member (op aconv) seen x then
            case_error (quote s ^ " occurs repeatedly in the pattern " ^
              quote (Syntax.string_of_term ctxt pat))
          else x :: seen
      | check _ seen = seen;
  in fold_aterms check pat [] end;
|
363 |
|
(*Build a nested case term for x from a list of (pattern, right-hand side)
  clauses.

  - every pattern is checked for repeated variables;
  - dummy patterns are replaced by variables with fresh names;
  - all right-hand sides must have the same type, which becomes the range type;
  - a CASE_ERROR raised by mk_case is turned into a user error, quoting the
    offending clause when its tag is non-negative;
  - clauses whose tag does not occur in the resulting tree are redundant and
    are reported according to config (error, warning, or nothing).*)
fun make_case ctxt config used x clauses =
  let
    fun string_of_clause (pat, rhs) =
      Syntax.string_of_term ctxt (Syntax.const @{syntax_const "_case1"} $ pat $ rhs);

    (*the check is run only for its error; no result list is built*)
    val _ = List.app (fn (pat, _) => (no_repeat_vars ctxt pat; ())) clauses;

    (*the name context resulting from this pass is not needed: the context
      passed to mk_case is rebuilt from the finished rows below*)
    val (rows, _) = used |>
      fold (fn (pat, rhs) =>
        Term.declare_term_frees pat #> Term.declare_term_frees rhs) clauses |>
      fold_map (fn (i, (pat, rhs)) => fn used =>
        let val (pat', used') = replace_dummies pat used
        in ((([], [pat']), (rhs, i)), used') end)
          (map_index I clauses);

    val rangeT =
      (case distinct (op =) (map (fastype_of o snd) clauses) of
        [] => case_error "no clauses given"
      | [T] => T
      | _ => case_error "all cases must have the same result type");

    val used' = fold add_row_used rows used;

    val (tags, case_tm) =
      mk_case ctxt used' rangeT [x] rows
        handle CASE_ERROR (msg, i) =>
          case_error
            (msg ^ (if i < 0 then "" else "\nIn clause\n" ^ string_of_clause (nth clauses i)));

    (*tags of rows that did not make it into the tree*)
    val _ =
      (case subtract (op =) tags (map (snd o snd) rows) of
        [] => ()
      | is =>
          (case config of Error => case_error | Warning => warning | Quiet => fn _ => ())
            ("The following clauses are redundant (covered by preceding clauses):\n" ^
              cat_lines (map (string_of_clause o nth clauses) is)));
  in
    case_tm
  end;
|
397 |
|
398 |
|
399 (* term check *) |
|
400 |
|
(*Turn one encoded clause into a (pattern, right-hand side) pair: each
  case_abs binder is replaced by a Free with a variant name, and the bound
  variables of the case_elem components are substituted accordingly.*)
fun decode_clause t xs used =
  (case t of
    Const (@{const_name case_abs}, _) $ Abs (s, T, body) =>
      let val (s', used') = Name.variant s used
      in decode_clause body (Free (s', T) :: xs) used' end
  | Const (@{const_name case_elem}, _) $ pat $ rhs =>
      (subst_bounds (xs, pat), subst_bounds (xs, rhs))
  | _ => case_error "decode_clause");

(*Decode a case_cons/case_nil list of clauses.*)
fun decode_cases t =
  (case t of
    Const (@{const_name case_nil}, _) => []
  | Const (@{const_name case_cons}, _) $ clause $ rest =>
      let val used = Term.declare_term_frees clause Name.context
      in decode_clause clause [] used :: decode_cases rest end
  | _ => case_error "decode_cases");
|
413 |
|
(*Check phase: replace every application of an encoded clause list (headed by
  case_cons) to a scrutinee by the term produced by make_case, descending
  through applications and abstractions.  Redundant clauses are errors.*)
fun check_case ctxt =
  let
    fun decode_case ((clauses as Const (@{const_name case_cons}, _) $ _ $ _) $ scrutinee) =
          let
            val x = decode_case scrutinee;
            val decoded = decode_cases clauses;
          in make_case ctxt Error Name.context x decoded end
      | decode_case (f $ a) = decode_case f $ decode_case a
      | decode_case (Abs (x, T, body)) =
          let val (x', body') = Term.dest_abs (x, T, body)
          in Term.absfree (x', T) (decode_case body') end
      | decode_case t = t;
  in
    map decode_case
  end;
|
426 |
|
(*Install check_case as term check "case" at stage 1.*)
val term_check_setup =
  let val install = Syntax_Phases.term_check 1 "case" check_case
  in Context.theory_map install end;
|
429 |
|
430 |
|
431 (* Pretty printing of nested case expressions *) |
|
432 |
|
433 (* destruct one level of pattern matching *) |
|
434 |
|
(*Destruct one level of pattern matching.  If t is a registered case
  combinator applied to one branch per constructor plus a scrutinee, return
  the scrutinee and a list of (pattern, body) clauses; otherwise NONE.

  Branches that do not depend on the constructor arguments are grouped by
  body.  If some group has body undefined, it is omitted (or only the first
  clause is kept when it covers all constructors).  Otherwise the largest
  group, unless it is a singleton, is summarised by a default clause whose
  pattern is a dummy pattern (d = true) or a variable named by a variant of
  "x" (d = false).*)
fun dest_case ctxt d used t =
  (case apfst name_of (strip_comb t) of
    (SOME cname, ts as _ :: _) =>
      let
        val (fs, x) = split_last ts;

        (*replace the first i binders of branch f by free variables with
          variant names; a branch with fewer than i binders is applied to
          additional variables named after "x" with types taken from Us*)
        fun strip_abs i Us f =
          let
            val zs = strip_abs_vars f;
            val j = length zs;
            val (xs, ys) =
              if j < i then (zs @ map (pair "x") (drop j Us), [])
              else chop i zs;
            val body = fold_rev Term.abs ys (strip_abs_body f);
            val names =
              fst (fold_map Name.variant (map fst xs) (Term.declare_term_names body used));
            val xs' = map Free (names ~~ map snd xs);
            val (xs1, xs2) = chop j xs';
          in (xs', list_comb (subst_bounds (rev xs1, body), xs2)) end;

        (*true if branch f has fewer than i binders, or its body refers to
          one of its first i binders*)
        fun is_dependent i f =
          let val k = length (strip_abs_vars f) - i
          in k < 0 orelse exists (fn j => j >= k) (loose_bnos (strip_abs_body f)) end;

        (*group the constructors of independent branches by branch body*)
        fun count_cases (c, (_, body), dependent) =
          if dependent then I
          else AList.map_default (op aconv) (body, []) (cons c);

        fun is_undefined tm = (name_of tm = SOME @{const_name undefined});

        fun mk_case (c, (xs, body), _) = (list_comb (c, xs), body);
      in
        (case lookup_by_case ctxt cname of
          SOME (_, constructors) =>
            if length fs = length constructors then
              let
                val n = length constructors;
                val cases = map (fn (Const (s, U), f) =>
                  let
                    val Us = binder_types U;
                    val k = length Us;
                    val p as (xs, _) = strip_abs k Us f;
                  in
                    (Const (s, map fastype_of xs ---> fastype_of x), p, is_dependent k f)
                  end) (constructors ~~ fs);
                (*groups in order of decreasing size*)
                val cases' =
                  sort (int_ord o swap o pairself (length o snd))
                    (fold_rev count_cases cases []);
                val R = fastype_of t;
                val dummy =
                  if d then Term.dummy_pattern R
                  else Free (fst (Name.variant "x" used), R);
                fun default_clause rhs = (dummy, ([], rhs), false);
                val selected =
                  (case find_first (is_undefined o fst) cases' of
                    SOME (_, cs) =>
                      if length cs = n then [hd cases]
                      else filter_out (fn (_, (_, body), _) => is_undefined body) cases
                  | NONE =>
                      (case cases' of
                        [] => cases
                      | (default, cs) :: _ =>
                          if length cs = 1 then cases
                          else if length cs = n then [hd cases, default_clause default]
                          else
                            filter_out (fn (c, _, _) => member (op aconv) cs c) cases @
                              [default_clause default]));
              in
                SOME (x, map mk_case selected)
              end
            else NONE
        | _ => NONE)
      end
  | _ => NONE);
|
504 |
|
505 |
|
506 (* destruct nested patterns *) |
|
507 |
|
(*Encode a (pattern, right-hand side) clause: a case_elem node wrapped in one
  case_abs abstraction per free variable of the pattern.  S is the pattern
  type, T the result type.*)
fun encode_clause S T (pat, rhs) =
  let
    val elem = Const (@{const_name case_elem}, S --> T --> S --> T) $ pat $ rhs;
    fun bind (v as (_, U)) body =
      Const (@{const_name case_abs}, (U --> T) --> T) $ Term.absfree v body;
  in fold bind (Term.add_frees pat []) elem end;

(*Encode a list of clauses as a case_cons list terminated by case_nil.*)
fun encode_cases S T clauses =
  let
    val nil_clauses = Const (@{const_name case_nil}, S --> T);
    fun add clause rest =
      Const (@{const_name case_cons}, (S --> T) --> (S --> T) --> S --> T) $
        encode_clause S T clause $ rest;
  in fold_rev add clauses nil_clauses end;

(*Encode a scrutinee with a non-empty list of clauses; the types are taken
  from the first clause.*)
fun encode_case (t, ps) =
  (case ps of
    (pat, rhs) :: _ => encode_cases (fastype_of pat) (fastype_of rhs) ps $ t
  | [] => case_error "encode_case");
|
522 |
|
(*Flatten nested matching inside one clause: if rhs is itself a case
  distinction on a free variable that occurs in pat but in none of the inner
  right-hand sides, substitute each inner pattern for that variable in pat
  and recurse; otherwise return the clause unchanged.*)
fun strip_case' ctxt d (clause as (pat, rhs)) =
  let
    fun occurs_in v = Term.exists_subterm (fn t => v aconv t);
    fun refine v (pat', rhs') = (subst_free [(v, pat')] pat, rhs');
  in
    (case dest_case ctxt d (Term.declare_term_frees pat Name.context) rhs of
      SOME (v as Free _, clauses) =>
        if occurs_in v pat andalso not (exists (occurs_in v o snd) clauses)
        then maps (strip_case' ctxt d o refine v) clauses
        else [clause]
    | _ => [clause])
  end;
|
534 |
|
(*Replace case combinator applications in t by their encoded clause form,
  flattening nested patterns.  Where t is not a case combinator application,
  descend into applications and abstractions.*)
fun strip_case ctxt d t =
  let val recur = strip_case ctxt d in
    (case dest_case ctxt d Name.context t of
      SOME (x, clauses) => encode_case (x, maps (strip_case' ctxt d) clauses)
    | NONE =>
        (case t of
          f $ a => recur f $ recur a
        | Abs (x, T, body) =>
            let val (x', body') = Term.dest_abs (x, T, body)
            in Term.absfree (x', T) (recur body') end
        | _ => t))
  end;
|
545 |
|
546 |
|
547 (* term uncheck *) |
|
548 |
|
(*Boolean configuration option show_cases, true by default.*)
val show_cases = Attrib.setup_config_bool @{binding show_cases} (fn _ => true);

(*Uncheck phase: apply strip_case (with dummy patterns for default clauses)
  to all terms, unless show_cases is switched off.*)
fun uncheck_case ctxt ts =
  if not (Config.get ctxt show_cases) then ts
  else map (strip_case ctxt true) ts;

(*Install uncheck_case as term uncheck "case" at stage 1.*)
val term_uncheck_setup =
  let val install = Syntax_Phases.term_uncheck 1 "case" uncheck_case
  in Context.theory_map install end;
|
556 |
|
557 |
|
558 (* theory setup *) |
|
559 |
|
(*Theory setup: parse translation, print translation, check phase and uncheck
  phase, applied in this order.*)
fun setup thy =
  thy
  |> trfun_setup
  |> trfun_setup'
  |> term_check_setup
  |> term_uncheck_setup;
|
565 |
|
566 |
|
567 (* outer syntax commands *) |
|
568 |
|
(*Print all registered case combinators: for each entry of the case table,
  the type name followed by the combinator and its constructors.*)
fun print_case_translations ctxt =
  let
    fun quoted_term t = Pretty.quote (Syntax.pretty_term ctxt t);

    fun labelled label prts =
      Pretty.block [Pretty.brk 3, Pretty.block (Pretty.str label :: Pretty.brk 1 :: prts)];

    fun show_case (_, data as (comb, ctrs)) =
      let
        val heading =
          Pretty.string_of (Pretty.block [Pretty.str (Tname_of_data data), Pretty.str ":"]);
      in
        Pretty.big_list heading
          [labelled "combinator:" [quoted_term comb],
           labelled "constructors:" [Pretty.list "" "" (map quoted_term ctrs)]]
      end;

    val entries = map show_case (Symtab.dest (cases_of ctxt));
  in
    Pretty.writeln (Pretty.big_list "Case translations:" entries)
  end;
|
584 |
|
(*Outer syntax command "print_case_translations": takes no arguments and
  prints the case translations of the current toplevel context.*)
val _ =
  let
    val print = Toplevel.keep (print_case_translations o Toplevel.context_of);
  in
    Outer_Syntax.improper_command @{command_spec "print_case_translations"}
      "print registered case combinators and constructors"
      (Scan.succeed print)
  end
|
589 |
|
590 end; |