53 (term * term list list) list list -> local_theory -> |
53 (term * term list list) list list -> local_theory -> |
54 typ list * int list * basic_lfp_sugar list * thm list * thm list * thm * Token.src list * bool |
54 typ list * int list * basic_lfp_sugar list * thm list * thm list * thm * Token.src list * bool |
55 * local_theory |
55 * local_theory |
56 val rec_specs_of: binding list -> typ list -> typ list -> term list -> |
56 val rec_specs_of: binding list -> typ list -> typ list -> term list -> |
57 (term * term list list) list list -> local_theory -> |
57 (term * term list list) list list -> local_theory -> |
58 (bool * rec_spec list * typ list * thm * thm list * Token.src list) * local_theory |
58 (bool * rec_spec list * typ list * thm * thm list * Token.src list * typ list) * local_theory |
|
59 |
|
60 val primrec_interpretation: |
|
61 string -> (BNF_FP_Rec_Sugar_Util.fp_rec_sugar -> local_theory -> local_theory) -> |
|
62 theory -> theory |
59 |
63 |
60 val add_primrec: (binding * typ option * mixfix) list -> |
64 val add_primrec: (binding * typ option * mixfix) list -> |
61 (Attrib.binding * term) list -> local_theory -> (term list * thm list list) * local_theory |
65 (Attrib.binding * term) list -> local_theory -> (term list * thm list list) * local_theory |
62 val add_primrec_cmd: primrec_option list -> (binding * string option * mixfix) list -> |
66 val add_primrec_cmd: primrec_option list -> (binding * string option * mixfix) list -> |
63 (Attrib.binding * string) list -> local_theory -> (term list * thm list list) * local_theory |
67 (Attrib.binding * string) list -> local_theory -> (term list * thm list list) * local_theory |
86 val code_nitpicksimp_simp_attrs = Code.add_default_eqn_attrib :: nitpicksimp_attrs @ simp_attrs; |
90 val code_nitpicksimp_simp_attrs = Code.add_default_eqn_attrib :: nitpicksimp_attrs @ simp_attrs; |
87 |
91 |
88 exception OLD_PRIMREC of unit; |
92 exception OLD_PRIMREC of unit; |
89 exception PRIMREC of string * term list; |
93 exception PRIMREC of string * term list; |
90 |
94 |
91 datatype primrec_option = Nonexhaustive_Option; |
95 datatype primrec_option = Nonexhaustive_Option | Transfer_Option; |
92 |
96 |
93 datatype rec_call = |
97 datatype rec_call = |
94 No_Rec of int * typ | |
98 No_Rec of int * typ | |
95 Mutual_Rec of (int * typ) * (int * typ) | |
99 Mutual_Rec of (int * typ) * (int * typ) | |
96 Nested_Rec of int * typ; |
100 Nested_Rec of int * typ; |
172 fun rewrite_nested_rec_call ctxt = |
176 fun rewrite_nested_rec_call ctxt = |
173 (case Data.get (Proof_Context.theory_of ctxt) of |
177 (case Data.get (Proof_Context.theory_of ctxt) of |
174 SOME {rewrite_nested_rec_call = SOME f, ...} => f ctxt |
178 SOME {rewrite_nested_rec_call = SOME f, ...} => f ctxt |
175 | _ => error "Unsupported nested recursion"); |
179 | _ => error "Unsupported nested recursion"); |
176 |
180 |
|
181 val transfer_primrec = morph_fp_rec_sugar o Morphism.transfer_morphism; |
|
182 |
|
183 structure Primrec_Plugin = Plugin(type T = fp_rec_sugar); |
|
184 |
|
185 fun primrec_interpretation name f = |
|
186 Primrec_Plugin.interpretation name (fn fp_rec_sugar => fn lthy => |
|
187 f (transfer_primrec (Proof_Context.theory_of lthy) fp_rec_sugar) lthy); |
|
188 |
|
189 val interpret_primrec = Primrec_Plugin.data_default; |
|
190 |
177 fun rec_specs_of bs arg_Ts res_Ts callers callssss0 lthy0 = |
191 fun rec_specs_of bs arg_Ts res_Ts callers callssss0 lthy0 = |
178 let |
192 let |
179 val thy = Proof_Context.theory_of lthy0; |
193 val thy = Proof_Context.theory_of lthy0; |
180 |
194 |
181 val (missing_arg_Ts, perm0_kks, basic_lfp_sugars, fp_nesting_map_ident0s, fp_nesting_map_comps, |
195 val (missing_arg_Ts, perm0_kks, basic_lfp_sugars, fp_nesting_map_ident0s, fp_nesting_map_comps, |
240 {recx = mk_co_rec thy Least_FP perm_Cs' (substAT T) recx, |
254 {recx = mk_co_rec thy Least_FP perm_Cs' (substAT T) recx, |
241 fp_nesting_map_ident0s = fp_nesting_map_ident0s, fp_nesting_map_comps = fp_nesting_map_comps, |
255 fp_nesting_map_ident0s = fp_nesting_map_ident0s, fp_nesting_map_comps = fp_nesting_map_comps, |
242 ctr_specs = mk_ctr_specs fp_res_index ctr_offset ctrs rec_thms}; |
256 ctr_specs = mk_ctr_specs fp_res_index ctr_offset ctrs rec_thms}; |
243 in |
257 in |
244 ((n2m, map2 mk_spec ctr_offsets basic_lfp_sugars, missing_arg_Ts, common_induct, inducts, |
258 ((n2m, map2 mk_spec ctr_offsets basic_lfp_sugars, missing_arg_Ts, common_induct, inducts, |
245 induct_attrs), lthy) |
259 induct_attrs, map #T basic_lfp_sugars), lthy) |
246 end; |
260 end; |
247 |
261 |
248 val undef_const = Const (@{const_name undefined}, dummyT); |
262 val undef_const = Const (@{const_name undefined}, dummyT); |
249 |
263 |
250 type eqn_data = { |
264 type eqn_data = { |
399 t |
413 t |
400 |> subst_rec_calls ctxt get_ctr_pos has_call ctr_args mutual_calls nested_calls |
414 |> subst_rec_calls ctxt get_ctr_pos has_call ctr_args mutual_calls nested_calls |
401 |> fold_rev lambda (args @ left_args @ right_args) |
415 |> fold_rev lambda (args @ left_args @ right_args) |
402 end); |
416 end); |
403 |
417 |
404 fun build_defs ctxt nonexhaustive bs mxs (funs_data : eqn_data list list) |
418 fun build_defs ctxt nonexhaustives bs mxs (funs_data : eqn_data list list) |
405 (rec_specs : rec_spec list) has_call = |
419 (rec_specs : rec_spec list) has_call = |
406 let |
420 let |
407 val n_funs = length funs_data; |
421 val n_funs = length funs_data; |
408 |
422 |
409 val ctr_spec_eqn_data_list' = |
423 val ctr_spec_eqn_data_list' = |
410 map #ctr_specs (take n_funs rec_specs) ~~ funs_data |
424 maps (fn ((xs, ys), z) => |
411 |> maps (uncurry (finds (fn (x, y) => #ctr x = #ctr y)) |
425 let |
412 ##> (fn x => null x orelse |
426 val zs = replicate (length xs) z |
413 raise PRIMREC ("excess equations in definition", map #rhs_term x)) #> fst); |
427 val (b, c) = finds (fn ((x,_), y) => #ctr x = #ctr y) (xs ~~ zs) ys |
414 val _ = ctr_spec_eqn_data_list' |> map (fn ({ctr, ...}, x) => |
428 val (_ : bool ) = (fn x => null x orelse |
|
429 raise PRIMREC ("excess equations in definition", map #rhs_term x)) c |
|
430 in b end) (map #ctr_specs (take n_funs rec_specs) ~~ funs_data ~~ nonexhaustives); |
|
431 |
|
432 val (_ : unit list) = ctr_spec_eqn_data_list' |> map (fn (({ctr, ...}, nonexhaustive), x) => |
415 if length x > 1 then raise PRIMREC ("multiple equations for constructor", map #user_eqn x) |
433 if length x > 1 then raise PRIMREC ("multiple equations for constructor", map #user_eqn x) |
416 else if length x = 1 orelse nonexhaustive then () |
434 else if length x = 1 orelse nonexhaustive then () |
417 else warning ("no equation for constructor " ^ Syntax.string_of_term ctxt ctr)); |
435 else warning ("no equation for constructor " ^ Syntax.string_of_term ctxt ctr)); |
418 |
436 |
419 val ctr_spec_eqn_data_list = |
437 val ctr_spec_eqn_data_list = |
420 ctr_spec_eqn_data_list' @ (drop n_funs rec_specs |> maps #ctr_specs |> map (rpair [])); |
438 map (fn ((x, y), z) => (x, z)) ctr_spec_eqn_data_list' @ |
|
439 (drop n_funs rec_specs |> maps #ctr_specs |> map (rpair [])); |
421 |
440 |
422 val recs = take n_funs rec_specs |> map #recx; |
441 val recs = take n_funs rec_specs |> map #recx; |
423 val rec_args = ctr_spec_eqn_data_list |
442 val rec_args = ctr_spec_eqn_data_list |
424 |> sort (op < o apply2 (#offset o fst) |> make_ord) |
443 |> sort (op < o apply2 (#offset o fst) |> make_ord) |
425 |> map (uncurry (build_rec_arg ctxt funs_data has_call) o apsnd (try the_single)); |
444 |> map (uncurry (build_rec_arg ctxt funs_data has_call) o apsnd (try the_single)); |
470 unfold_thms_tac ctxt fun_defs THEN |
489 unfold_thms_tac ctxt fun_defs THEN |
471 HEADGOAL (rtac (funpow num_extra_args (fn thm => thm RS fun_cong) recx RS trans)) THEN |
490 HEADGOAL (rtac (funpow num_extra_args (fn thm => thm RS fun_cong) recx RS trans)) THEN |
472 unfold_thms_tac ctxt (nested_simps ctxt @ map_ident0s @ map_comps) THEN |
491 unfold_thms_tac ctxt (nested_simps ctxt @ map_ident0s @ map_comps) THEN |
473 HEADGOAL (rtac refl); |
492 HEADGOAL (rtac refl); |
474 |
493 |
475 fun prepare_primrec nonexhaustive fixes specs lthy0 = |
494 fun prepare_primrec nonexhaustives transfers fixes specs lthy0 = |
476 let |
495 let |
477 val thy = Proof_Context.theory_of lthy0; |
496 val thy = Proof_Context.theory_of lthy0; |
478 |
497 |
479 val (bs, mxs) = map_split (apfst fst) fixes; |
498 val (bs, mxs) = map_split (apfst fst) fixes; |
480 val fun_names = map Binding.name_of bs; |
499 val fun_names = map Binding.name_of bs; |
500 val _ = if exists is_only_old_datatype arg_Ts then raise OLD_PRIMREC () else (); |
519 val _ = if exists is_only_old_datatype arg_Ts then raise OLD_PRIMREC () else (); |
501 val _ = (case filter_out (fn (_, T) => Sign.of_sort thy (T, @{sort type})) (bs ~~ res_Ts) of |
520 val _ = (case filter_out (fn (_, T) => Sign.of_sort thy (T, @{sort type})) (bs ~~ res_Ts) of |
502 [] => () |
521 [] => () |
503 | (b, _) :: _ => raise PRIMREC ("type of " ^ Binding.print b ^ " contains top sort", [])); |
522 | (b, _) :: _ => raise PRIMREC ("type of " ^ Binding.print b ^ " contains top sort", [])); |
504 |
523 |
505 val ((n2m, rec_specs, _, common_induct, inducts, induct_attrs), lthy) = |
524 val ((n2m, rec_specs, _, common_induct, inducts, induct_attrs, Ts), lthy) = |
506 rec_specs_of bs arg_Ts res_Ts frees callssss lthy0; |
525 rec_specs_of bs arg_Ts res_Ts frees callssss lthy0; |
507 |
526 |
508 val actual_nn = length funs_data; |
527 val actual_nn = length funs_data; |
509 |
528 |
510 val ctrs = maps (map #ctr o #ctr_specs) rec_specs; |
529 val ctrs = maps (map #ctr o #ctr_specs) rec_specs; |
511 val _ = |
530 val _ = |
512 map (fn {ctr, user_eqn, ...} => member (op =) ctrs ctr orelse |
531 map (fn {ctr, user_eqn, ...} => member (op =) ctrs ctr orelse |
513 raise PRIMREC ("argument " ^ quote (Syntax.string_of_term lthy ctr) ^ |
532 raise PRIMREC ("argument " ^ quote (Syntax.string_of_term lthy ctr) ^ |
514 " is not a constructor in left-hand side", [user_eqn])) eqns_data; |
533 " is not a constructor in left-hand side", [user_eqn])) eqns_data; |
515 |
534 |
516 val defs = build_defs lthy nonexhaustive bs mxs funs_data rec_specs has_call; |
535 val defs = build_defs lthy nonexhaustives bs mxs funs_data rec_specs has_call; |
517 |
536 |
518 fun prove lthy' def_thms' ({ctr_specs, fp_nesting_map_ident0s, fp_nesting_map_comps, ...} |
537 fun prove def_thms' ({ctr_specs, fp_nesting_map_ident0s, fp_nesting_map_comps, ...} |
519 : rec_spec) (fun_data : eqn_data list) = |
538 : rec_spec) (fun_data : eqn_data list) lthy' = |
520 let |
539 let |
521 val js = |
540 val js = |
522 find_indices (op = o apply2 (fn {fun_name, ctr, ...} => (fun_name, ctr))) |
541 find_indices (op = o apply2 (fn {fun_name, ctr, ...} => (fun_name, ctr))) |
523 fun_data eqns_data; |
542 fun_data eqns_data; |
524 |
543 |
553 |> map (fn (thmN, thms, attrs) => |
572 |> map (fn (thmN, thms, attrs) => |
554 ((Binding.qualify true common_name (Binding.name thmN), attrs), [(thms, [])])); |
573 ((Binding.qualify true common_name (Binding.name thmN), attrs), [(thms, [])])); |
555 in |
574 in |
556 (((fun_names, defs), |
575 (((fun_names, defs), |
557 fn lthy => fn defs => |
576 fn lthy => fn defs => |
558 split_list (map2 (prove lthy defs) (take actual_nn rec_specs) funs_data)), |
577 let |
|
578 val def_thms = map (snd o snd) defs; |
|
579 val ts = map fst defs; |
|
580 val phi = Local_Theory.target_morphism lthy; |
|
581 in |
|
582 map_prod split_list |
|
583 (interpret_primrec {transfers = transfers, fun_names = fun_names, |
|
584 funs = (map (Morphism.term phi) ts), fun_defs = (Morphism.fact phi def_thms), |
|
585 fpTs = (take actual_nn Ts)}) |
|
586 (@{fold_map 2} (prove defs) (take actual_nn rec_specs) funs_data lthy) |
|
587 end), |
559 lthy |> Local_Theory.notes (notes @ common_notes) |> snd) |
588 lthy |> Local_Theory.notes (notes @ common_notes) |> snd) |
560 end; |
589 end; |
561 |
590 |
562 fun add_primrec_simple' opts fixes ts lthy = |
591 fun add_primrec_simple' opts fixes ts lthy = |
563 let |
592 let |
564 val nonexhaustive = member (op =) opts Nonexhaustive_Option; |
593 val actual_nn = length fixes; |
565 val (((names, defs), prove), lthy') = prepare_primrec nonexhaustive fixes ts lthy |
594 val nonexhaustives = replicate actual_nn (member (op =) opts Nonexhaustive_Option); |
|
595 val transfers = replicate actual_nn (member (op =) opts Transfer_Option); |
|
596 val (((names, defs), prove), lthy') = prepare_primrec nonexhaustives transfers fixes ts lthy |
566 handle ERROR str => raise PRIMREC (str, []); |
597 handle ERROR str => raise PRIMREC (str, []); |
567 in |
598 in |
568 lthy' |
599 lthy' |
569 |> fold_map Local_Theory.define defs |
600 |> fold_map Local_Theory.define defs |
570 |-> (fn defs => `(fn lthy => (names, (map fst defs, prove lthy defs)))) |
601 |-> (fn defs => fn lthy => |
|
602 let val (thms, lthy) = prove lthy defs; |
|
603 in ((names, (map fst defs, thms)), lthy) end) |
571 end |
604 end |
572 handle PRIMREC (str, eqns) => |
605 handle PRIMREC (str, eqns) => |
573 if null eqns then |
606 if null eqns then |
574 error ("primrec error:\n " ^ str) |
607 error ("primrec error:\n " ^ str) |
575 else |
608 else |
624 Overloading.overloading ops |
657 Overloading.overloading ops |
625 #> add_primrec fixes specs |
658 #> add_primrec fixes specs |
626 ##> Local_Theory.exit_global; |
659 ##> Local_Theory.exit_global; |
627 |
660 |
628 val primrec_option_parser = Parse.group (fn () => "option") |
661 val primrec_option_parser = Parse.group (fn () => "option") |
629 (Parse.reserved "nonexhaustive" >> K Nonexhaustive_Option) |
662 (Parse.reserved "nonexhaustive" >> K Nonexhaustive_Option |
|
663 || Parse.reserved "transfer" >> K Transfer_Option) |
630 |
664 |
631 val _ = Outer_Syntax.local_theory @{command_spec "primrec"} |
665 val _ = Outer_Syntax.local_theory @{command_spec "primrec"} |
632 "define primitive recursive functions" |
666 "define primitive recursive functions" |
633 ((Scan.optional (@{keyword "("} |-- |
667 ((Scan.optional (@{keyword "("} |-- |
634 Parse.!!! (Parse.list1 primrec_option_parser) --| @{keyword ")"}) []) -- |
668 Parse.!!! (Parse.list1 primrec_option_parser) --| @{keyword ")"}) []) -- |