285 |> Syntax.check_terms lthy |
285 |> Syntax.check_terms lthy |
286 |> map3 (fn b => fn mx => fn t => ((b, mx), ((Binding.conceal (Thm.def_binding b), []), t))) |
286 |> map3 (fn b => fn mx => fn t => ((b, mx), ((Binding.conceal (Thm.def_binding b), []), t))) |
287 bs mxs |
287 bs mxs |
288 end; |
288 end; |
289 |
289 |
290 fun find_rec_calls has_call (eqn_data : eqn_data) = |
290 fun massage_comp ctxt has_call bound_Ts t = |
291 let |
291 massage_nested_corec_call ctxt has_call (K (K (K I))) bound_Ts (fastype_of1 (bound_Ts, t)) t; |
292 fun find (Abs (_, _, b)) ctr_arg = find b ctr_arg |
292 |
293 | find (t as _ $ _) ctr_arg = |
293 fun find_rec_calls ctxt has_call (eqn_data : eqn_data) = |
|
294 let |
|
295 fun find bound_Ts (Abs (_, T, b)) ctr_arg = find (T :: bound_Ts) b ctr_arg |
|
296 | find bound_Ts (t as _ $ _) ctr_arg = |
294 let |
297 let |
|
298 val typof = curry fastype_of1 bound_Ts; |
295 val (f', args') = strip_comb t; |
299 val (f', args') = strip_comb t; |
296 val n = find_index (equal ctr_arg) args'; |
300 val n = find_index (equal ctr_arg o head_of) args'; |
297 in |
301 in |
298 if n < 0 then |
302 if n < 0 then |
299 find f' ctr_arg @ maps (fn x => find x ctr_arg) args' |
303 find bound_Ts f' ctr_arg @ maps (fn x => find bound_Ts x ctr_arg) args' |
300 else |
304 else |
301 let val (f, args) = chop n args' |>> curry list_comb f' in |
305 let |
|
306 val (f, args as arg :: _) = chop n args' |>> curry list_comb f' |
|
307 val (arg_head, arg_args) = Term.strip_comb arg; |
|
308 in |
302 if has_call f then |
309 if has_call f then |
303 f :: maps (fn x => find x ctr_arg) args |
310 mk_partial_compN (length arg_args) (typof f) (typof arg_head) f :: |
|
311 maps (fn x => find bound_Ts x ctr_arg) args |
304 else |
312 else |
305 find f ctr_arg @ maps (fn x => find x ctr_arg) args |
313 find bound_Ts f ctr_arg @ maps (fn x => find bound_Ts x ctr_arg) args |
306 end |
314 end |
307 end |
315 end |
308 | find _ _ = []; |
316 | find _ _ _ = []; |
309 in |
317 in |
310 map (find (#rhs_term eqn_data)) (#ctr_args eqn_data) |
318 map (find [] (#rhs_term eqn_data)) (#ctr_args eqn_data) |
311 |> (fn [] => NONE | callss => SOME (#ctr eqn_data, callss)) |
319 |> (fn [] => NONE | callss => SOME (#ctr eqn_data, callss)) |
312 end; |
320 end; |
313 |
321 |
314 fun prepare_primrec fixes specs lthy = |
322 fun prepare_primrec fixes specs lthy = |
315 let |
323 let |
325 val has_call = exists_subterm (map (fst #>> Binding.name_of #> Free) fixes |> member (op =)); |
333 val has_call = exists_subterm (map (fst #>> Binding.name_of #> Free) fixes |> member (op =)); |
326 val arg_Ts = map (#rec_type o hd) funs_data; |
334 val arg_Ts = map (#rec_type o hd) funs_data; |
327 val res_Ts = map (#res_type o hd) funs_data; |
335 val res_Ts = map (#res_type o hd) funs_data; |
328 val callssss = funs_data |
336 val callssss = funs_data |
329 |> map (partition_eq ((op =) o pairself #ctr)) |
337 |> map (partition_eq ((op =) o pairself #ctr)) |
330 |> map (maps (map_filter (find_rec_calls has_call))); |
338 |> map (maps (map_filter (find_rec_calls lthy has_call))); |
331 |
339 |
332 val ((n2m, rec_specs, _, induct_thm, induct_thms), lthy') = |
340 val ((n2m, rec_specs, _, induct_thm, induct_thms), lthy') = |
333 rec_specs_of bs arg_Ts res_Ts (get_indices fixes) callssss lthy; |
341 rec_specs_of bs arg_Ts res_Ts (get_indices fixes) callssss lthy; |
334 |
342 |
335 val actual_nn = length funs_data; |
343 val actual_nn = length funs_data; |