|
(*  Title:      HOL/Tools/Old_Datatype/old_primrec.ML
    Author:     Norbert Voelker, FernUni Hagen
    Author:     Stefan Berghofer, TU Muenchen
    Author:     Florian Haftmann, TU Muenchen

Primitive recursive functions on datatypes.
*)
|
8 |
|
(* Public interface of the primrec package for old-style datatypes. *)
signature OLD_PRIMREC =
sig
  (* primrec in a local theory; fixes and equations are given as typ/term values
     and prepared with Specification.check_spec *)
  val add_primrec:
    (binding * typ option * mixfix) list ->
    (Attrib.binding * term) list ->
    local_theory -> (term list * thm list) * local_theory

  (* same as add_primrec, but types and equations are given as strings
     and prepared with Specification.read_spec *)
  val add_primrec_cmd:
    (binding * string option * mixfix) list ->
    (Attrib.binding * string) list ->
    local_theory -> (term list * thm list) * local_theory

  (* add_primrec run inside Named_Target.theory_init, returning a theory *)
  val add_primrec_global:
    (binding * typ option * mixfix) list ->
    (Attrib.binding * term) list ->
    theory -> (term list * thm list) * theory

  (* add_primrec run inside Overloading.overloading for the given operations *)
  val add_primrec_overloaded:
    (string * (string * typ) * bool) list ->
    (binding * typ option * mixfix) list ->
    (Attrib.binding * term) list ->
    theory -> (term list * thm list) * theory

  (* core definition step: yields the name prefix built from the defined
     functions, the defined terms and the proven equations; notes no facts *)
  val add_primrec_simple:
    ((binding * typ) * mixfix) list -> term list ->
    local_theory -> (string * (term list * thm list)) * local_theory
end;
|
23 |
|
24 structure Old_Primrec : OLD_PRIMREC = |
|
25 struct |
|
26 |
|
(* Internal failure: a message plus, when known, the offending equation. *)
exception PrimrecError of string * term option;

(* fail without an associated equation *)
fun primrec_error msg = raise PrimrecError (msg, NONE);

(* fail and record the equation the message refers to *)
fun primrec_error_eqn msg eqn = raise PrimrecError (msg, SOME eqn);
|
31 |
|
32 |
|
33 (* preprocessing of equations *) |
|
34 |
|
(* Analyse one equation of the form "f ls (C cargs) rs = rhs" and merge it into
   rec_fns, an association list
     fname -> (tname, rpos, [(cname, (ls, cargs, rs, rhs, eqn))]).
   is_fixed decides whether a free variable may serve as function head (and
   may occur freely on the rhs).  Any PrimrecError raised here without an
   equation is re-raised with spec attached. *)
fun process_eqn is_fixed spec rec_fns =
  let
    (* strip outer Pure.all quantifiers; the bound variables become frees whose
       names are chosen distinct from the frees already in the body *)
    val (vs, Ts) = split_list (strip_qnt_vars @{const_name Pure.all} spec);
    val body = strip_qnt_body @{const_name Pure.all} spec;
    val used = fold_aterms (fn Free (v, _) => insert (op =) v | _ => I) body [];
    val (vs', _) = fold_map Name.variant vs (Name.make_context used);
    val eqn = subst_bounds (rev (map2 (curry Free) vs' Ts), body);

    val (lhs, rhs) = HOLogic.dest_eq (HOLogic.dest_Trueprop eqn)
      handle TERM _ => primrec_error "not a proper equation";
    val (recfun, args) = strip_comb lhs;

    fun bad_head () = primrec_error "illegal head of function equation";
    val fname =
      (case recfun of
        Free (v, _) => if is_fixed v then v else bad_head ()
      | _ => bad_head ());

    (* split the arguments into: leading frees, non-variable middle, trailing frees *)
    val (ls', rest) = take_prefix is_Free args;
    val (middle, rs') = take_suffix is_Free rest;
    val rpos = length ls';

    val (constr, cargs') =
      (case middle of
        [] => primrec_error "constructor missing"
      | pat :: _ => strip_comb pat);
    val (cname, T) = dest_Const constr
      handle TERM _ => primrec_error "ill-formed constructor";
    val (tname, _) = dest_Type (body_type T)
      handle TYPE _ => primrec_error "cannot determine datatype associated with function";

    val (ls, cargs, rs) =
      (map dest_Free ls', map dest_Free cargs', map dest_Free rs')
      handle TERM _ => primrec_error "illegal argument in pattern";
    val lfrees = ls @ rs @ cargs;

    (* complain about a non-empty list of offending variables; the equation is
       attached by the handler at the end of this function *)
    fun check_vars _ [] = ()
      | check_vars s vars = primrec_error (s ^ commas_quote (map fst vars));

    val _ =
      if length middle > 1 then primrec_error "more than one non-variable in pattern"
      else ();
    val _ = check_vars "repeated variable names in pattern: " (duplicates (op =) lfrees);
    val _ = check_vars "extra variables on rhs: "
      (filter_out (is_fixed o fst) (subtract (op =) lfrees (Term.add_frees rhs [])));

    val entry = (cname, (ls, cargs, rs, rhs, eqn));
  in
    (case AList.lookup (op =) rec_fns fname of
      NONE => (fname, (tname, rpos, [entry])) :: rec_fns
    | SOME (_, rpos', eqns) =>
        if AList.defined (op =) eqns cname then
          primrec_error "constructor already occurred as pattern"
        else if rpos <> rpos' then
          primrec_error "position of recursive argument inconsistent"
        else AList.update (op =) (fname, (tname, rpos, entry :: eqns)) rec_fns)
  end handle PrimrecError (msg, NONE) => primrec_error_eqn msg spec;
|
92 |
|
(* Translate the equations of function fname, associated with entry i of the
   datatype description descr, into arguments for the recursion combinator.
   The accumulator is (fnames, fnss): fnames maps descr indices to function
   names already handled, fnss collects (i, (fname, ls, fns)).  Functions
   reached through recursive calls on the right-hand sides are processed
   on the way. *)
fun process_fun descr eqns (i, fname) (fnames, fnss) =
  let
    val (_, (tname, _, constrs)) = nth descr i;

    fun inconsistent () =
      primrec_error ("inconsistent functions for datatype " ^ quote tname);

    (* substitute "fname ls x rs" by "y ls rs" for (x, (_, y)) in subs;
       acc is the (fnames, fnss) accumulator threaded through *)
    fun subst [] t acc = (t, acc)
      | subst subs (Abs (a, T, t)) acc =
          let val (t', acc') = subst subs t acc
          in (Abs (a, T, t'), acc') end
      | subst subs (t as (_ $ _)) acc =
          let
            val (f, ts) = strip_comb t;
            (* is the head one of the functions being defined? *)
            fun is_rec_fun (Free (v, _)) = AList.defined (op =) eqns v
              | is_rec_fun _ = false;
          in
            if is_rec_fun f then
              let
                val (fname', _) = dest_Free f;
                val (_, rpos, _) = the (AList.lookup (op =) eqns fname');
                val (ls, rs) = chop rpos ts;
                val (x', rs') =
                  (case rs of
                    x' :: rs => (x', rs)
                  | [] => primrec_error ("not enough arguments in recursive application\n" ^
                      "of function " ^ quote fname' ^ " on rhs"));
                val (x, xs) = strip_comb x';
              in
                (case AList.lookup (op =) subs x of
                  NONE =>
                    let val (ts', acc') = fold_map (subst subs) ts acc
                    in (list_comb (f, ts'), acc') end
                | SOME (i', y) =>
                    let
                      val (ts', acc') = fold_map (subst subs) (xs @ ls @ rs') acc;
                      (* make sure the called function itself gets processed *)
                      val acc'' = process_fun descr eqns (i', fname') acc';
                    in (list_comb (y, ts'), acc'') end)
              end
            else
              (case fold_map (subst subs) (f :: ts) acc of
                (f' :: ts', acc') => (list_comb (f', ts'), acc'))
          end
      | subst _ t acc = (t, acc);

    (* translate rec equations into function arguments suitable for rec comb *)
    fun trans constr_eqns (cname, cargs) (fnames', fnss', fns) =
      (case AList.lookup (op =) constr_eqns cname of
        NONE =>
          let
            val _ = warning ("No equation for constructor " ^ quote cname ^
              "\nin definition of function " ^ quote fname);
          in (fnames', fnss', Const (@{const_name undefined}, dummyT) :: fns) end
      | SOME (ls, cargs', rs, rhs, eq) =>
          let
            val recs = filter (Old_Datatype_Aux.is_rec_type o snd) (cargs' ~~ cargs);
            val rargs = map fst recs;
            val subs = map (rpair dummyT o fst) (rev (Term.rename_wrt_term rhs rargs));
            val rec_subs = map2 (fn (x, y) => fn z =>
              (Free x, (Old_Datatype_Aux.body_index y, Free z))) recs subs;
            val (rhs', (fnames'', fnss'')) = subst rec_subs rhs (fnames', fnss')
              handle PrimrecError (s, NONE) => primrec_error_eqn s eq;
          in
            (fnames'', fnss'', fold_rev absfree (cargs' @ subs @ ls @ rs) rhs' :: fns)
          end);
  in
    (case AList.lookup (op =) fnames i of
      SOME fname' => if fname = fname' then (fnames, fnss) else inconsistent ()
    | NONE =>
        if exists (fn (_, v) => fname = v) fnames then inconsistent ()
        else
          let
            val (_, _, fun_eqns) = the (AList.lookup (op =) eqns fname);
            val (fnames', fnss', fns) =
              fold_rev (trans fun_eqns) constrs ((i, fname) :: fnames, fnss, []);
          in
            (fnames', (i, (fname, #1 (snd (hd fun_eqns)), fns)) :: fnss')
          end)
  end;
|
176 |
|
177 |
|
178 (* prepare functions needed for definitions *) |
|
179 |
|
(* Collect, for datatype entry i of the description, the arguments of the
   recursion combinator (prepended to fs) and a raw definition record
   (prepended to defs).  If no function was given for this entry, one
   "undefined" constant per constructor is used and a warning is issued. *)
fun get_fns fns ((i : int, (tname, _, constrs)), rec_name) (fs, defs) =
  let
    (* undefined constant taking one dummy-typed argument per constructor
       argument plus one per recursive constructor argument, returning unit *)
    fun dummy_fn (_, cargs) =
      let
        val arity = length cargs + length (filter Old_Datatype_Aux.is_rec_type cargs);
      in Const (@{const_name undefined}, replicate arity dummyT ---> HOLogic.unitT) end;
  in
    (case AList.lookup (op =) fns i of
      SOME (fname, ls, fs') => (fs' @ fs, (fname, ls, rec_name, tname) :: defs)
    | NONE =>
        let
          val dummy_fns = map dummy_fn constrs;
          val _ = warning ("No function definition for datatype " ^ quote tname);
        in (dummy_fns @ fs, defs) end)
  end;
|
192 |
|
193 |
|
194 (* make definition *) |
|
195 |
|
(* Build the definition of function fname in terms of the recursion
   combinator rec_name applied to fs.  The result pairs the binding and mixfix
   taken from fixes with a concealed "<fname>_def"-style binding and the
   type-checked right-hand side.

   Raises PrimrecError if fname has no entry in fixes.  This can happen when
   the head of an equation is merely fixed in the context rather than declared
   by the primrec itself; previously this surfaced as a raw Bind exception. *)
fun make_def ctxt fixes fs (fname, ls, rec_name, _) =
  let
    val (var, varT) =
      (case get_first (fn ((b, T), mx) =>
          if Binding.name_of b = fname then SOME ((b, mx), T) else NONE) fixes of
        SOME decl => decl
      | NONE => primrec_error ("no declaration for function " ^ quote fname));
    val def_name = Thm.def_name (Long_Name.base_name fname);
    (* abstract over the ls arguments and the recursive argument; the
       recursive argument (Bound 0) is passed to the combinator first *)
    val raw_rhs = fold_rev (fn T => fn t => Abs ("", T, t)) (map snd ls @ [dummyT])
      (list_comb (Const (rec_name, dummyT), fs @ map Bound (0 :: (length ls downto 1))));
    val rhs = singleton (Syntax.check_terms ctxt) (Type.constraint varT raw_rhs);
  in (var, ((Binding.conceal (Binding.name def_name), []), rhs)) end;
|
205 |
|
206 |
|
207 (* find datatypes which contain all datatypes in tnames' *) |
|
208 |
|
(* Select from tnames those datatypes whose description mentions every type
   name in tnames'.  Fails on the first name that is not registered in
   dt_info. *)
fun find_dts dt_info tnames' tnames =
  let
    fun select tname =
      (case Symtab.lookup dt_info tname of
        NONE => primrec_error (quote tname ^ " is not a datatype")
      | SOME (dt : Old_Datatype_Aux.info) =>
          if subset (op =) (tnames', map (#1 o snd) (#descr dt))
          then SOME (tname, dt)
          else NONE);
  in map_filter select tnames end;
|
217 |
|
218 |
|
219 (* distill primitive definition(s) from primrec specification *) |
|
220 |
|
(* Turn a primrec specification (fixes plus equations eqs) into
     ((prefix, (fs, defs)), prove)
   where defs are the definitions to be made, prefix joins the base names of
   the defined functions with "_", and prove derives the original equations
   from the recursion rewrite rules and the given definitional theorems.
   Every PrimrecError is converted into a regular error message here. *)
fun distill ctxt fixes eqs =
  let
    (* admissible function heads: fixed in the context or declared in fixes *)
    fun declared v = exists (fn ((w, _), _) => v = Binding.name_of w) fixes;
    fun is_fixed v = Variable.is_fixed ctxt v orelse declared v;

    val eqns = fold_rev (process_eqn is_fixed) eqs [];
    val tnames = distinct (op =) (map (#1 o snd) eqns);
    val dt_info = Old_Datatype_Data.get_all (Proof_Context.theory_of ctxt);
    val dts = find_dts dt_info tnames tnames;
    val main_fns = map (fn (tname, {index, ...}) =>
      (index, (fst o the o find_first (fn (_, x) => #1 x = tname)) eqns)) dts;
    val {descr, rec_names, rec_rewrites, ...} =
      (case dts of
        [] =>
          primrec_error ("datatypes " ^ commas_quote tnames ^ "\nare not mutually recursive")
      | (_, dt) :: _ => dt);

    val (fnames, fnss) = fold_rev (process_fun descr eqns) main_fns ([], []);
    val (fs, raw_defs) = fold_rev (get_fns fnss) (descr ~~ rec_names) ([], []);
    val defs = map (make_def ctxt fixes fs) raw_defs;

    (* every function with equations must have been reached from main_fns *)
    val names = map snd fnames;
    val names_eqns = map fst eqns;
    val _ =
      if eq_set (op =) (names, names_eqns) then ()
      else
        primrec_error ("functions " ^ commas_quote names_eqns ^ "\nare not mutually recursive");

    val rec_rewrites' = map mk_meta_eq rec_rewrites;
    val prefix = space_implode "_" (map (Long_Name.base_name o #1) raw_defs);

    fun prove ctxt defs =
      let
        val frees = fold (Variable.add_free_names ctxt) eqs [];
        val rewrites = rec_rewrites' @ map (snd o snd) defs;
        fun prove_eq eq =
          Goal.prove ctxt frees [] eq (fn {context = ctxt', ...} =>
            EVERY [rewrite_goals_tac ctxt' rewrites, rtac refl 1]);
      in map prove_eq eqs end;
  in ((prefix, (fs, defs)), prove) end
  handle PrimrecError (msg, some_eqn) =>
    let
      val location =
        (case some_eqn of
          SOME eqn => "\nin\n" ^ quote (Syntax.string_of_term ctxt eqn)
        | NONE => "");
    in error ("Primrec definition error:\n" ^ msg ^ location) end;
|
258 |
|
259 |
|
260 (* primrec definition *) |
|
261 |
|
(* Make the definitions distilled from the equations ts and prove the
   equations in the resulting local theory.  Returns the name prefix, the
   defined terms and the proven equations, plus the updated local theory. *)
fun add_primrec_simple fixes ts lthy =
  let
    val ((prefix, (_, defs)), prove) = distill lthy fixes ts;
    val (defined, lthy') = fold_map Local_Theory.define defs lthy;
  in
    ((prefix, (map fst defined, prove lthy' defined)), lthy')
  end;
|
270 |
|
local

(* Common driver: prepare the specification with prep_spec, make the
   definitions, register the equations as Spec_Rules, note each equation under
   its own binding (with the default code equation attribute added) and all of
   them together as "<prefix>.simps" with attributes [simp, nitpick_simp]. *)
fun gen_primrec prep_spec raw_fixes raw_spec lthy =
  let
    val (fixes, spec) = fst (prep_spec raw_fixes raw_spec lthy);

    fun attr_bindings prefix =
      map (fn ((b, attrs), _) =>
        (Binding.qualify false prefix b, Code.add_default_eqn_attrib :: attrs)) spec;
    fun simp_attr_binding prefix =
      (Binding.qualify true prefix (Binding.name "simps"), @{attributes [simp, nitpick_simp]});

    val ((prefix, (ts, simps)), lthy1) = add_primrec_simple fixes (map snd spec) lthy;
    val lthy2 = Spec_Rules.add Spec_Rules.Equational (ts, simps) lthy1;
    val (noted, lthy3) =
      fold_map Local_Theory.note (attr_bindings prefix ~~ map single simps) lthy2;
    val ((_, simps'), lthy4) =
      Local_Theory.note (simp_attr_binding prefix, maps snd noted) lthy3;
  in
    ((ts, simps'), lthy4)
  end;

in

val add_primrec = gen_primrec Specification.check_spec;
val add_primrec_cmd = gen_primrec Specification.read_spec;

end;
|
296 |
|
(* Theory-level variant of add_primrec: runs it in the local theory obtained
   from Named_Target.theory_init, exports the resulting equations back to that
   initial context and exits to the global theory. *)
fun add_primrec_global fixes specs thy =
  let
    val start = Named_Target.theory_init thy;
    val ((ts, simps), finish) = add_primrec fixes specs start;
    val exported = Proof_Context.export finish start simps;
    val thy' = Local_Theory.exit_global finish;
  in ((ts, exported), thy') end;
|
303 |
|
(* Theory-level variant of add_primrec for overloaded operations: runs it in
   the context set up by Overloading.overloading ops, exports the resulting
   equations back to that initial context and exits to the global theory. *)
fun add_primrec_overloaded ops fixes specs thy =
  let
    val start = Overloading.overloading ops thy;
    val ((ts, simps), finish) = add_primrec fixes specs start;
    val exported = Proof_Context.export finish start simps;
    val thy' = Local_Theory.exit_global finish;
  in ((ts, exported), thy') end;
|
310 |
|
311 end; |