209 if tnames' subset (map (#1 o snd) (#descr dt)) then |
211 if tnames' subset (map (#1 o snd) (#descr dt)) then |
210 (tname, dt)::(find_dts dt_info tnames' tnames) |
212 (tname, dt)::(find_dts dt_info tnames' tnames) |
211 else find_dts dt_info tnames' tnames); |
213 else find_dts dt_info tnames' tnames); |
212 |
214 |
213 |
215 |
214 (* primrec definition *) |
216 (* distill primitive definition(s) from primrec specification *) |
215 |
217 |
216 local |
218 fun distill lthy fixes eqs = |
217 |
219 let |
218 fun prove_spec ctxt names rec_rewrites defs eqs = |
|
219 let |
|
220 val rewrites = map mk_meta_eq rec_rewrites @ map (snd o snd) defs; |
|
221 fun tac _ = EVERY [rewrite_goals_tac rewrites, rtac refl 1]; |
|
222 val _ = message ("Proving equations for primrec function(s) " ^ commas_quote names); |
|
223 in map (fn (a, t) => (a, [Goal.prove ctxt [] [] t tac])) eqs end; |
|
224 |
|
225 fun gen_primrec set_group prep_spec raw_fixes raw_spec lthy = |
|
226 let |
|
227 val (fixes, spec) = fst (prep_spec raw_fixes raw_spec lthy); |
|
228 val eqns = fold_rev (process_eqn (fn v => Variable.is_fixed lthy v |
220 val eqns = fold_rev (process_eqn (fn v => Variable.is_fixed lthy v |
229 orelse exists (fn ((w, _), _) => v = Binding.name_of w) fixes) o snd) spec []; |
221 orelse exists (fn ((w, _), _) => v = Binding.name_of w) fixes)) eqs []; |
230 val tnames = distinct (op =) (map (#1 o snd) eqns); |
222 val tnames = distinct (op =) (map (#1 o snd) eqns); |
231 val dts = find_dts (DatatypePackage.get_datatypes (ProofContext.theory_of lthy)) tnames tnames; |
223 val dts = find_dts (DatatypePackage.get_datatypes (ProofContext.theory_of lthy)) tnames tnames; |
232 val main_fns = map (fn (tname, {index, ...}) => |
224 val main_fns = map (fn (tname, {index, ...}) => |
233 (index, (fst o the o find_first (fn (_, x) => #1 x = tname)) eqns)) dts; |
225 (index, (fst o the o find_first (fn (_, x) => #1 x = tname)) eqns)) dts; |
234 val {descr, rec_names, rec_rewrites, ...} = |
226 val {descr, rec_names, rec_rewrites, ...} = |
235 if null dts then primrec_error |
227 if null dts then primrec_error |
236 ("datatypes " ^ commas_quote tnames ^ "\nare not mutually recursive") |
228 ("datatypes " ^ commas_quote tnames ^ "\nare not mutually recursive") |
237 else snd (hd dts); |
229 else snd (hd dts); |
238 val (fnames, fnss) = fold_rev (process_fun descr eqns) main_fns ([], []); |
230 val (fnames, fnss) = fold_rev (process_fun descr eqns) main_fns ([], []); |
239 val (fs, defs) = fold_rev (get_fns fnss) (descr ~~ rec_names) ([], []); |
231 val (fs, raw_defs) = fold_rev (get_fns fnss) (descr ~~ rec_names) ([], []); |
240 val names1 = map snd fnames; |
232 val defs = map (make_def lthy fixes fs) raw_defs; |
241 val names2 = map fst eqns; |
233 val names = map snd fnames; |
242 val _ = if gen_eq_set (op =) (names1, names2) then () |
234 val names_eqns = map fst eqns; |
243 else primrec_error ("functions " ^ commas_quote names2 ^ |
235 val _ = if gen_eq_set (op =) (names, names_eqns) then () |
|
236 else primrec_error ("functions " ^ commas_quote names_eqns ^ |
244 "\nare not mutually recursive"); |
237 "\nare not mutually recursive"); |
245 val prefix = space_implode "_" (map (Long_Name.base_name o #1) defs); |
238 val rec_rewrites' = map mk_meta_eq rec_rewrites; |
246 val qualify = Binding.qualify false prefix; |
239 val prefix = space_implode "_" (map (Long_Name.base_name o #1) raw_defs); |
247 val spec' = (map o apfst) |
240 fun prove lthy defs = |
248 (fn (b, attrs) => (qualify b, Code.add_default_eqn_attrib :: attrs)) spec; |
241 let |
249 val simp_atts = map (Attrib.internal o K) |
242 val rewrites = rec_rewrites' @ map (snd o snd) defs; |
250 [Simplifier.simp_add, Nitpick_Const_Simp_Thms.add, Quickcheck_RecFun_Simp_Thms.add]; |
243 fun tac _ = EVERY [rewrite_goals_tac rewrites, rtac refl 1]; |
|
244 val _ = message ("Proving equations for primrec function(s) " ^ commas_quote names); |
|
245 in map (fn eq => [Goal.prove lthy [] [] eq tac]) eqs end; |
|
246 in ((prefix, (fs, defs)), prove) end |
|
247 handle PrimrecError (msg, some_eqn) => |
|
248 error ("Primrec definition error:\n" ^ msg ^ (case some_eqn |
|
249 of SOME eqn => "\nin\n" ^ quote (Syntax.string_of_term lthy eqn) |
|
250 | NONE => "")); |
|
251 |
|
252 |
|
253 (* primrec definition *) |
|
254 |
|
255 fun add_primrec_simple fixes spec lthy = |
|
256 let |
|
257 val ((prefix, (fs, defs)), prove) = distill lthy fixes (map snd spec); |
|
258 in |
|
259 lthy |
|
260 |> fold_map (LocalTheory.define Thm.definitionK) defs |
|
261 |-> (fn defs => `(fn lthy => (prefix, prove lthy defs))) |
|
262 end; |
|
263 |
|
264 local |
|
265 |
|
266 fun gen_primrec set_group prep_spec raw_fixes raw_spec lthy = |
|
267 let |
|
268 val (fixes, spec) = fst (prep_spec raw_fixes raw_spec lthy); |
|
269 fun attr_bindings prefix = map (fn ((b, attrs), _) => |
|
270 (Binding.qualify false prefix b, Code.add_default_eqn_attrib :: attrs)) spec; |
|
271 fun simp_attr_binding prefix = (Binding.qualify false prefix (Binding.name "simps"), |
|
272 map (Attrib.internal o K) |
|
273 [Simplifier.simp_add, Nitpick_Const_Simp_Thms.add, Quickcheck_RecFun_Simp_Thms.add]); |
251 in |
274 in |
252 lthy |
275 lthy |
253 |> set_group ? LocalTheory.set_group (serial_string ()) |
276 |> set_group ? LocalTheory.set_group (serial_string ()) |
254 |> fold_map (LocalTheory.define Thm.definitionK o make_def lthy fixes fs) defs |
277 |> add_primrec_simple fixes spec |
255 |-> (fn defs => `(fn ctxt => prove_spec ctxt names1 rec_rewrites defs spec')) |
278 |-> (fn (prefix, simps) => fold_map (LocalTheory.note Thm.generatedK) |
256 |-> (fn simps => fold_map (LocalTheory.note Thm.generatedK) simps) |
279 (attr_bindings prefix ~~ simps) |
257 |-> (fn simps' => LocalTheory.note Thm.generatedK |
280 #-> (fn simps' => LocalTheory.note Thm.generatedK |
258 ((qualify (Binding.qualified_name "simps"), simp_atts), maps snd simps')) |
281 (simp_attr_binding prefix, maps snd simps'))) |
259 |>> snd |
282 |>> snd |
260 end handle PrimrecError (msg, some_eqn) => |
283 end; |
261 error ("Primrec definition error:\n" ^ msg ^ (case some_eqn |
|
262 of SOME eqn => "\nin\n" ^ quote (Syntax.string_of_term lthy eqn) |
|
263 | NONE => "")); |
|
264 |
284 |
265 in |
285 in |
266 |
286 |
267 val add_primrec = gen_primrec false Specification.check_spec; |
287 val add_primrec = gen_primrec false Specification.check_spec; |
268 val add_primrec_cmd = gen_primrec true Specification.read_spec; |
288 val add_primrec_cmd = gen_primrec true Specification.read_spec; |