204 in (var, ((Binding.conceal (Binding.name def_name), []), rhs)) end; |
204 in (var, ((Binding.conceal (Binding.name def_name), []), rhs)) end; |


(* find datatypes which contain all datatypes in tnames' *)

(*Walk the candidate type names in the third argument and keep those
  datatype entries whose descriptor covers every name in tnames'.

  dt_info:  table consulted via Symtab.lookup for each candidate name
  tnames':  names that must all occur among the type names of the
            entry's #descr field
  result:   (tname, info) pairs for the candidates that pass the subset
            test, in the order the candidates were given

  Raises a primrec error as soon as a candidate name has no entry in
  dt_info.*)
fun find_dts _ _ [] = []
  | find_dts dt_info tnames' (tname :: tnames) =
      (case Symtab.lookup dt_info tname of
        NONE => primrec_error (quote tname ^ " is not a datatype")
      | SOME (dt : Datatype_Aux.info) =>
          (*the constraint on dt fixes the record type for the #descr selector*)
          if subset (op =) (tnames', map (#1 o snd) (#descr dt)) then
            (tname, dt) :: (find_dts dt_info tnames' tnames)
          else find_dts dt_info tnames' tnames);


(* distill primitive definition(s) from primrec specification *)

(*Turn a primrec specification into the definitions to be made plus a
  proof function for the user-supplied equations.

  ctxt:   context used for fixedness checks, theory lookup, making the
          definitions and printing terms in error messages
  fixes:  declared function bindings; only the binding name of each
          entry is inspected here
  eqs:    the specification equations

  Returns ((prefix, (fs, defs)), prove) where
    prefix  is the base names of the raw definitions joined with "_",
    fs      is the first component accumulated by get_fns,
    defs    is the list produced by make_def for each raw definition,
    prove   takes a context and the definitions as obtained after
            defining them, and proves each of eqs.

  Any PrimrecError raised while evaluating the body is re-raised as a
  plain error prefixed with "Primrec definition error:", followed by the
  offending equation if one was recorded.*)
fun distill ctxt fixes eqs =
  let
    (*a variable counts as fixed if the context says so or if it is one
      of the functions being declared*)
    val eqns = fold_rev (process_eqn (fn v => Variable.is_fixed ctxt v
      orelse exists (fn ((w, _), _) => v = Binding.name_of w) fixes)) eqs [];
    val tnames = distinct (op =) (map (#1 o snd) eqns);
    val dts = find_dts (Datatype_Data.get_all (Proof_Context.theory_of ctxt)) tnames tnames;
    (*for every suitable datatype: its index paired with the name of the
      first processed equation over that type*)
    val main_fns = map (fn (tname, {index, ...}) =>
      (index, (fst o the o find_first (fn (_, x) => #1 x = tname)) eqns)) dts;
    (*the first suitable datatype provides descriptor and recursor data;
      no suitable datatype at all is an error*)
    val {descr, rec_names, rec_rewrites, ...} =
      if null dts then primrec_error
        ("datatypes " ^ commas_quote tnames ^ "\nare not mutually recursive")
      else snd (hd dts);
    val (fnames, fnss) = fold_rev (process_fun descr eqns) main_fns ([], []);
    val (fs, raw_defs) = fold_rev (get_fns fnss) (descr ~~ rec_names) ([], []);
    val defs = map (make_def ctxt fixes fs) raw_defs;
    val names = map snd fnames;
    val names_eqns = map fst eqns;
    (*the functions found via the datatypes must be exactly the ones
      that the equations talk about*)
    val _ =
      if eq_set (op =) (names, names_eqns) then ()
      else primrec_error ("functions " ^ commas_quote names_eqns ^
        "\nare not mutually recursive");
    val rec_rewrites' = map mk_meta_eq rec_rewrites;
    val prefix = space_implode "_" (map (Long_Name.base_name o #1) raw_defs);
    (*note: this ctxt and defs shadow the outer ones; the caller passes
      the context and definitions it wants the proofs to run against*)
    fun prove ctxt defs =
      let
        val frees = fold (Variable.add_free_names ctxt) eqs [];
        val rewrites = rec_rewrites' @ map (snd o snd) defs;
        (*rewrite with the recursor rules and the definitions, then
          close the goal by reflexivity*)
        fun tac _ = EVERY [rewrite_goals_tac rewrites, rtac refl 1];
      in map (fn eq => Goal.prove ctxt frees [] eq tac) eqs end;
  in ((prefix, (fs, defs)), prove) end
  handle PrimrecError (msg, some_eqn) =>
    error ("Primrec definition error:\n" ^ msg ^
      (case some_eqn of
        SOME eqn => "\nin\n" ^ quote (Syntax.string_of_term ctxt eqn)
      | NONE => ""));


(* primrec definition *)

(*Make the definitions distilled from a primrec specification in the
  local theory and prove the specification equations from them.

  fixes:  declared function bindings, passed on to distill
  ts:     the specification equations, passed on to distill
  lthy:   local theory in which the definitions are made

  Returns ((prefix, (ts', thms)), lthy') where ts' is the first
  component of each result of Local_Theory.define, thms is what prove
  yields for the equations, and lthy' is the local theory after all
  definitions.*)
fun add_primrec_simple fixes ts lthy =
  let
    (*the component ignored here is not needed in this function*)
    val ((prefix, (_, defs)), prove) = distill lthy fixes ts;
  in
    lthy
    |> fold_map Local_Theory.define defs
    (*the backquote operator pairs the computed result with the
      unchanged local theory; prove runs in the context that already
      contains the definitions*)
    |-> (fn defs => `(fn lthy => (prefix, (map fst defs, prove lthy defs))))
  end;