30 |
30 |
31 exception CASE_ERROR of string * int; |
31 exception CASE_ERROR of string * int; |
32 |
32 |
33 fun match_type thy pat ob = Sign.typ_match thy (pat, ob) Vartab.empty; |
33 fun match_type thy pat ob = Sign.typ_match thy (pat, ob) Vartab.empty; |
34 |
34 |
35 (*--------------------------------------------------------------------------- |
35 (* Get information about datatypes *) |
36 * Get information about datatypes |
|
37 *---------------------------------------------------------------------------*) |
|
38 |
36 |
39 fun ty_info tab sT = |
37 fun ty_info tab sT = |
40 (case tab sT of |
38 (case tab sT of |
41 SOME ({descr, case_name, index, sorts, ...} : info) => |
39 SOME ({descr, case_name, index, sorts, ...} : info) => |
42 let |
40 let |
49 Const (cname, Logic.varifyT_global (map mk_ty dts' ---> T))) constrs} |
47 Const (cname, Logic.varifyT_global (map mk_ty dts' ---> T))) constrs} |
50 end |
48 end |
51 | NONE => NONE); |
49 | NONE => NONE); |
52 |
50 |
53 |
51 |
54 (*--------------------------------------------------------------------------- |
52 (*Each pattern carries with it a tag i, which denotes the clause it |
55 * Each pattern carries with it a tag i, which denotes |
53 came from. i = ~1 indicates that the clause was added by pattern |
56 * the clause it came from. i = ~1 indicates that |
54 completion.*) |
57 * the clause was added by pattern completion. |
|
58 *---------------------------------------------------------------------------*) |
|
59 |
|
60 fun pattern_subst theta (tm, x) = (subst_free theta tm, x); |
|
61 |
55 |
62 fun add_row_used ((prfx, pats), (tm, tag)) = |
56 fun add_row_used ((prfx, pats), (tm, tag)) = |
63 fold Term.add_free_names (tm :: pats @ map Free prfx); |
57 fold Term.add_free_names (tm :: pats @ map Free prfx); |
64 |
58 |
65 (* try to preserve names given by user *) |
59 (*try to preserve names given by user*) |
66 fun default_names names ts = |
60 fun default_names names ts = |
67 map (fn ("", Free (name', _)) => name' | (name, _) => name) (names ~~ ts); |
61 map (fn ("", Free (name', _)) => name' | (name, _) => name) (names ~~ ts); |
68 |
62 |
69 fun strip_constraints (Const (@{syntax_const "_constrain"}, _) $ t $ tT) = |
63 fun strip_constraints (Const (@{syntax_const "_constrain"}, _) $ t $ tT) = |
70 strip_constraints t ||> cons tT |
64 strip_constraints t ||> cons tT |
73 fun mk_fun_constrain tT t = |
67 fun mk_fun_constrain tT t = |
74 Syntax.const @{syntax_const "_constrain"} $ t $ |
68 Syntax.const @{syntax_const "_constrain"} $ t $ |
75 (Syntax.const @{type_syntax fun} $ tT $ Syntax.const @{type_syntax dummy}); |
69 (Syntax.const @{type_syntax fun} $ tT $ Syntax.const @{type_syntax dummy}); |
76 |
70 |
77 |
71 |
78 (*--------------------------------------------------------------------------- |
72 (*Produce an instance of a constructor, plus fresh variables for its arguments.*) |
79 * Produce an instance of a constructor, plus genvars for its arguments. |
|
80 *---------------------------------------------------------------------------*) |
|
81 fun fresh_constr ty_match ty_inst colty used c = |
73 fun fresh_constr ty_match ty_inst colty used c = |
82 let |
74 let |
83 val (_, Ty) = dest_Const c |
75 val (_, Ty) = dest_Const c |
84 val Ts = binder_types Ty; |
76 val Ts = binder_types Ty; |
85 val names = Name.variant_list used |
77 val names = Name.variant_list used |
90 val c' = ty_inst ty_theta c |
82 val c' = ty_inst ty_theta c |
91 val gvars = map (ty_inst ty_theta o Free) (names ~~ Ts) |
83 val gvars = map (ty_inst ty_theta o Free) (names ~~ Ts) |
92 in (c', gvars) end; |
84 in (c', gvars) end; |
93 |
85 |
94 |
86 |
95 (*--------------------------------------------------------------------------- |
87 (*Goes through a list of rows and picks out the ones beginning with a |
96 * Goes through a list of rows and picks out the ones beginning with a |
88 pattern with constructor = name.*) |
97 * pattern with constructor = name. |
|
98 *---------------------------------------------------------------------------*) |
|
99 fun mk_group (name, T) rows = |
89 fun mk_group (name, T) rows = |
100 let val k = length (binder_types T) in |
90 let val k = length (binder_types T) in |
101 fold (fn (row as ((prfx, p :: rst), rhs as (_, i))) => |
91 fold (fn (row as ((prfx, p :: rst), rhs as (_, i))) => |
102 fn ((in_group, not_in_group), (names, cnstrts)) => |
92 fn ((in_group, not_in_group), (names, cnstrts)) => |
103 (case strip_comb p of |
93 (case strip_comb p of |
114 else ((in_group, row :: not_in_group), (names, cnstrts)) |
104 else ((in_group, row :: not_in_group), (names, cnstrts)) |
115 | _ => raise CASE_ERROR ("Not a constructor pattern", i))) |
105 | _ => raise CASE_ERROR ("Not a constructor pattern", i))) |
116 rows (([], []), (replicate k "", replicate k [])) |>> pairself rev |
106 rows (([], []), (replicate k "", replicate k [])) |>> pairself rev |
117 end; |
107 end; |
118 |
108 |
119 (*--------------------------------------------------------------------------- |
109 |
120 * Partition the rows. Not efficient: we should use hashing. |
110 (* Partitioning *) |
121 *---------------------------------------------------------------------------*) |
111 |
122 fun partition _ _ _ _ _ _ _ [] = raise CASE_ERROR ("partition: no rows", ~1) |
112 fun partition _ _ _ _ _ _ _ [] = raise CASE_ERROR ("partition: no rows", ~1) |
123 | partition ty_match ty_inst type_of used constructors colty res_ty |
113 | partition ty_match ty_inst type_of used constructors colty res_ty |
124 (rows as (((prfx, _ :: rstp), _) :: _)) = |
114 (rows as (((prfx, _ :: rstp), _) :: _)) = |
125 let |
115 let |
126 fun part {constrs = [], rows = [], A} = rev A |
116 fun part {constrs = [], rows = [], A} = rev A |
154 constraints = cnstrts, |
144 constraints = cnstrts, |
155 group = in_group'} :: A} |
145 group = in_group'} :: A} |
156 end |
146 end |
157 in part {constrs = constructors, rows = rows, A = []} end; |
147 in part {constrs = constructors, rows = rows, A = []} end; |
158 |
148 |
159 (*--------------------------------------------------------------------------- |
|
160 * Misc. routines used in mk_case |
|
161 *---------------------------------------------------------------------------*) |
|
162 |
|
163 fun v_to_prfx (prfx, Free v::pats) = (v::prfx,pats) |
149 fun v_to_prfx (prfx, Free v::pats) = (v::prfx,pats) |
164 | v_to_prfx _ = raise CASE_ERROR ("mk_case: v_to_prfx", ~1); |
150 | v_to_prfx _ = raise CASE_ERROR ("mk_case: v_to_prfx", ~1); |
165 |
151 |
166 |
152 |
167 (*---------------------------------------------------------------------------- |
153 (* Translation of pattern terms into nested case expressions. *) |
168 * Translation of pattern terms into nested case expressions. |
154 |
169 * |
|
170 * This performs the translation and also builds the full set of patterns. |
|
171 * Thus it supports the construction of induction theorems even when an |
|
172 * incomplete set of patterns is given. |
|
173 *---------------------------------------------------------------------------*) |
|
174 |
|
175 fun mk_case tab ctxt ty_match ty_inst type_of used range_ty = |
155 fun mk_case tab ctxt ty_match ty_inst type_of used range_ty = |
176 let |
156 let |
177 val name = Name.variant used "a"; |
157 val name = Name.variant used "a"; |
178 fun expand constructors used ty ((_, []), _) = |
158 fun expand constructors used ty ((_, []), _) = |
179 raise CASE_ERROR ("mk_case: expand_var_row", ~1) |
159 raise CASE_ERROR ("mk_case: expand_var_row", ~1) |
180 | expand constructors used ty (row as ((prfx, p :: rst), rhs)) = |
160 | expand constructors used ty (row as ((prfx, p :: rst), (rhs, tag))) = |
181 if is_Free p then |
161 if is_Free p then |
182 let |
162 let |
183 val used' = add_row_used row used; |
163 val used' = add_row_used row used; |
184 fun expnd c = |
164 fun expnd c = |
185 let val capp = |
165 let val capp = |
186 list_comb (fresh_constr ty_match ty_inst ty used' c) |
166 list_comb (fresh_constr ty_match ty_inst ty used' c) |
187 in ((prfx, capp :: rst), pattern_subst [(p, capp)] rhs) |
167 in ((prfx, capp :: rst), (subst_free [(p, capp)] rhs, tag)) |
188 end |
168 end |
189 in map expnd constructors end |
169 in map expnd constructors end |
190 else [row] |
170 else [row] |
191 fun mk {rows = [], ...} = raise CASE_ERROR ("no rows", ~1) |
171 fun mk {rows = [], ...} = raise CASE_ERROR ("no rows", ~1) |
192 | mk {path = [], rows = ((prfx, []), (tm, tag)) :: _} = (* Done *) |
172 | mk {path = [], rows = ((prfx, []), (tm, tag)) :: _} = (* Done *) |
197 let val col0 = map (fn ((_, p :: _), (_, i)) => (p, i)) rows in |
177 let val col0 = map (fn ((_, p :: _), (_, i)) => (p, i)) rows in |
198 (case Option.map (apfst head_of) (find_first (not o is_Free o fst) col0) of |
178 (case Option.map (apfst head_of) (find_first (not o is_Free o fst) col0) of |
199 NONE => |
179 NONE => |
200 let |
180 let |
201 val rows' = map (fn ((v, _), row) => row ||> |
181 val rows' = map (fn ((v, _), row) => row ||> |
202 pattern_subst [(v, u)] |>> v_to_prfx) (col0 ~~ rows); |
182 apfst (subst_free [(v, u)]) |>> v_to_prfx) (col0 ~~ rows); |
203 in mk {path = rstp, rows = rows'} end |
183 in mk {path = rstp, rows = rows'} end |
204 | SOME (Const (cname, cT), i) => |
184 | SOME (Const (cname, cT), i) => |
205 (case ty_info tab (cname, cT) of |
185 (case ty_info tab (cname, cT) of |
206 NONE => raise CASE_ERROR ("Not a datatype constructor: " ^ cname, i) |
186 NONE => raise CASE_ERROR ("Not a datatype constructor: " ^ cname, i) |
207 | SOME {case_name, constructors} => |
187 | SOME {case_name, constructors} => |
232 | mk _ = raise CASE_ERROR ("Malformed row matrix", ~1) |
212 | mk _ = raise CASE_ERROR ("Malformed row matrix", ~1) |
233 in mk end; |
213 in mk end; |
234 |
214 |
235 fun case_error s = error ("Error in case expression:\n" ^ s); |
215 fun case_error s = error ("Error in case expression:\n" ^ s); |
236 |
216 |
237 (* Repeated variable occurrences in a pattern are not allowed. *) |
217 (*Repeated variable occurrences in a pattern are not allowed.*) |
238 fun no_repeat_vars ctxt pat = fold_aterms |
218 fun no_repeat_vars ctxt pat = fold_aterms |
239 (fn x as Free (s, _) => (fn xs => |
219 (fn x as Free (s, _) => (fn xs => |
240 if member op aconv xs x then |
220 if member op aconv xs x then |
241 case_error (quote s ^ " occurs repeatedly in the pattern " ^ |
221 case_error (quote s ^ " occurs repeatedly in the pattern " ^ |
242 quote (Syntax.string_of_term ctxt pat)) |
222 quote (Syntax.string_of_term ctxt pat)) |
322 (flat cnstrts) t) cases; |
302 (flat cnstrts) t) cases; |
323 in case_tm end |
303 in case_tm end |
324 | case_tr _ _ _ ts = case_error "case_tr"; |
304 | case_tr _ _ _ ts = case_error "case_tr"; |
325 |
305 |
326 |
306 |
327 (*--------------------------------------------------------------------------- |
307 (* Pretty printing of nested case expressions *) |
328 * Pretty printing of nested case expressions |
|
329 *---------------------------------------------------------------------------*) |
|
330 |
308 |
331 (* destruct one level of pattern matching *) |
309 (* destruct one level of pattern matching *) |
332 |
310 |
333 fun gen_dest_case name_of type_of tab d used t = |
311 fun gen_dest_case name_of type_of tab d used t = |
334 (case apfst name_of (strip_comb t) of |
312 (case apfst name_of (strip_comb t) of |