464 HEADGOAL (rtac (funpow num_extra_args (fn thm => thm RS fun_cong) recx RS trans)) THEN |
464 HEADGOAL (rtac (funpow num_extra_args (fn thm => thm RS fun_cong) recx RS trans)) THEN |
465 unfold_thms_tac ctxt (@{thms id_def split comp_def fst_conv snd_conv} @ map_comps @ |
465 unfold_thms_tac ctxt (@{thms id_def split comp_def fst_conv snd_conv} @ map_comps @ |
466 map_idents) THEN |
466 map_idents) THEN |
467 HEADGOAL (rtac refl); |
467 HEADGOAL (rtac refl); |
468 |
468 |
469 fun prepare_primrec fixes specs lthy = |
469 fun prepare_primrec fixes specs lthy0 = |
470 let |
470 let |
471 val thy = Proof_Context.theory_of lthy; |
471 val thy = Proof_Context.theory_of lthy0; |
472 |
472 |
473 val (bs, mxs) = map_split (apfst fst) fixes; |
473 val (bs, mxs) = map_split (apfst fst) fixes; |
474 val fun_names = map Binding.name_of bs; |
474 val fun_names = map Binding.name_of bs; |
475 val eqns_data = map (dissect_eqn lthy fun_names) specs; |
475 val eqns_data = map (dissect_eqn lthy0 fun_names) specs; |
476 val funs_data = eqns_data |
476 val funs_data = eqns_data |
477 |> partition_eq ((op =) o pairself #fun_name) |
477 |> partition_eq ((op =) o pairself #fun_name) |
478 |> finds (fn (x, y) => x = #fun_name (hd y)) fun_names |> fst |
478 |> finds (fn (x, y) => x = #fun_name (hd y)) fun_names |> fst |
479 |> map (fn (x, y) => the_single y |
479 |> map (fn (x, y) => the_single y |
480 handle List.Empty => |
480 handle List.Empty => |
486 val callssss = funs_data |
486 val callssss = funs_data |
487 |> map (partition_eq ((op =) o pairself #ctr)) |
487 |> map (partition_eq ((op =) o pairself #ctr)) |
488 |> map (maps (map_filter (find_rec_calls has_call))); |
488 |> map (maps (map_filter (find_rec_calls has_call))); |
489 |
489 |
490 fun is_only_old_datatype (Type (s, _)) = |
490 fun is_only_old_datatype (Type (s, _)) = |
491 is_none (fp_sugar_of lthy s) andalso is_some (Datatype_Data.get_info thy s) |
491 is_none (fp_sugar_of lthy0 s) andalso is_some (Datatype_Data.get_info thy s) |
492 | is_only_old_datatype _ = false; |
492 | is_only_old_datatype _ = false; |
493 |
493 |
494 val _ = if exists is_only_old_datatype arg_Ts then raise OLD_PRIMREC () else (); |
494 val _ = if exists is_only_old_datatype arg_Ts then raise OLD_PRIMREC () else (); |
495 val _ = (case filter_out (fn (_, T) => Sign.of_sort thy (T, HOLogic.typeS)) (bs ~~ res_Ts) of |
495 val _ = (case filter_out (fn (_, T) => Sign.of_sort thy (T, HOLogic.typeS)) (bs ~~ res_Ts) of |
496 [] => () |
496 [] => () |
497 | (b, _) :: _ => primrec_error ("type of " ^ Binding.print b ^ " contains top sort")); |
497 | (b, _) :: _ => primrec_error ("type of " ^ Binding.print b ^ " contains top sort")); |
498 |
498 |
499 val ((n2m, rec_specs, _, induct_thm, induct_thms), lthy') = |
499 val ((n2m, rec_specs, _, induct_thm, induct_thms), lthy) = |
500 rec_specs_of bs arg_Ts res_Ts (get_indices fixes) callssss lthy; |
500 rec_specs_of bs arg_Ts res_Ts (get_indices fixes) callssss lthy0; |
501 |
501 |
502 val actual_nn = length funs_data; |
502 val actual_nn = length funs_data; |
503 |
503 |
504 val _ = let val ctrs = (maps (map #ctr o #ctr_specs) rec_specs) in |
504 val _ = let val ctrs = (maps (map #ctr o #ctr_specs) rec_specs) in |
505 map (fn {ctr, user_eqn, ...} => member (op =) ctrs ctr orelse |
505 map (fn {ctr, user_eqn, ...} => member (op =) ctrs ctr orelse |
506 primrec_error_eqn ("argument " ^ quote (Syntax.string_of_term lthy' ctr) ^ |
506 primrec_error_eqn ("argument " ^ quote (Syntax.string_of_term lthy ctr) ^ |
507 " is not a constructor in left-hand side") user_eqn) eqns_data end; |
507 " is not a constructor in left-hand side") user_eqn) eqns_data end; |
508 |
508 |
509 val defs = build_defs lthy' bs mxs funs_data rec_specs has_call; |
509 val defs = build_defs lthy bs mxs funs_data rec_specs has_call; |
510 |
510 |
511 fun prove lthy def_thms' ({ctr_specs, nested_map_idents, nested_map_comps, ...} : rec_spec) |
511 fun prove lthy' def_thms' ({ctr_specs, nested_map_idents, nested_map_comps, ...} : rec_spec) |
512 (fun_data : eqn_data list) = |
512 (fun_data : eqn_data list) = |
513 let |
513 let |
|
514 val js = |
|
515 find_indices (op = o pairself (fn {fun_name, ctr, ...} => (fun_name, ctr))) |
|
516 fun_data eqns_data; |
|
517 |
514 val def_thms = map (snd o snd) def_thms'; |
518 val def_thms = map (snd o snd) def_thms'; |
515 val simp_thmss = finds (fn (x, y) => #ctr x = #ctr y) fun_data ctr_specs |
519 val simp_thms = finds (fn (x, y) => #ctr x = #ctr y) fun_data ctr_specs |
516 |> fst |
520 |> fst |
517 |> map_filter (try (fn (x, [y]) => |
521 |> map_filter (try (fn (x, [y]) => |
518 (#user_eqn x, length (#left_args x) + length (#right_args x), #rec_thm y))) |
522 (#fun_name x, #user_eqn x, length (#left_args x) + length (#right_args x), #rec_thm y))) |
519 |> map (fn (user_eqn, num_extra_args, rec_thm) => |
523 |> map2 (fn j => fn (fun_name, user_eqn, num_extra_args, rec_thm) => |
520 mk_primrec_tac lthy num_extra_args nested_map_idents nested_map_comps def_thms rec_thm |
524 mk_primrec_tac lthy' num_extra_args nested_map_idents nested_map_comps def_thms rec_thm |
521 |> K |> Goal.prove_sorry lthy [] [] user_eqn |
525 |> K |> Goal.prove_sorry lthy' [] [] user_eqn |
522 |> Thm.close_derivation); |
526 (* for code extraction from proof terms: *) |
523 val poss = |
527 |> singleton (Proof_Context.export lthy' lthy) |
524 find_indices (op = o pairself (fn {fun_name, ctr, ...} => (fun_name, ctr))) |
528 |> Thm.name_derivation (Sign.full_name thy (Binding.name fun_name) ^ |
525 fun_data eqns_data; |
529 Long_Name.separator ^ simpsN ^ |
|
530 (if js = [0] then "" else "_" ^ string_of_int (j + 1)))) |
|
531 js; |
526 in |
532 in |
527 (poss, simp_thmss) |
533 (js, simp_thms) |
528 end; |
534 end; |
529 |
535 |
530 val notes = |
536 val notes = |
531 (if n2m then |
537 (if n2m then |
532 map2 (fn name => fn thm => |
538 map2 (fn name => fn thm => |
544 ((Binding.qualify true common_name (Binding.name thmN), attrs), [(thms, [])])); |
550 ((Binding.qualify true common_name (Binding.name thmN), attrs), [(thms, [])])); |
545 in |
551 in |
546 (((fun_names, defs), |
552 (((fun_names, defs), |
547 fn lthy => fn defs => |
553 fn lthy => fn defs => |
548 split_list (map2 (prove lthy defs) (take actual_nn rec_specs) funs_data)), |
554 split_list (map2 (prove lthy defs) (take actual_nn rec_specs) funs_data)), |
549 lthy' |> Local_Theory.notes (notes @ common_notes) |> snd) |
555 lthy |> Local_Theory.notes (notes @ common_notes) |> snd) |
550 end; |
556 end; |
551 |
557 |
552 fun add_primrec_simple fixes ts lthy = |
558 fun add_primrec_simple fixes ts lthy = |
553 let |
559 let |
554 val (((names, defs), prove), lthy') = prepare_primrec fixes ts lthy |
560 val (((names, defs), prove), lthy') = prepare_primrec fixes ts lthy |
586 ((Binding.qualify true prefix (Binding.name simpsN), []), [(thms, [])]) :: notes |
592 ((Binding.qualify true prefix (Binding.name simpsN), []), [(thms, [])]) :: notes |
587 end); |
593 end); |
588 in |
594 in |
589 lthy |
595 lthy |
590 |> add_primrec_simple fixes (map snd specs) |
596 |> add_primrec_simple fixes (map snd specs) |
591 |-> (fn (names, (ts, (posss, simpss))) => |
597 |-> (fn (names, (ts, (jss, simpss))) => |
592 Spec_Rules.add Spec_Rules.Equational (ts, flat simpss) |
598 Spec_Rules.add Spec_Rules.Equational (ts, flat simpss) |
593 #> Local_Theory.notes (mk_notes posss names simpss) |
599 #> Local_Theory.notes (mk_notes jss names simpss) |
594 #>> pair ts o map snd) |
600 #>> pair ts o map snd) |
595 end |
601 end |
596 handle OLD_PRIMREC () => old_primrec raw_fixes raw_spec lthy |>> apsnd single; |
602 handle OLD_PRIMREC () => old_primrec raw_fixes raw_spec lthy |>> apsnd single; |
597 |
603 |
598 val add_primrec = gen_primrec Primrec.add_primrec Specification.check_spec; |
604 val add_primrec = gen_primrec Primrec.add_primrec Specification.check_spec; |
599 val add_primrec_cmd = gen_primrec Primrec.add_primrec_cmd Specification.read_spec; |
605 val add_primrec_cmd = gen_primrec Primrec.add_primrec_cmd Specification.read_spec; |
600 |
606 |
601 fun add_primrec_global fixes specs thy = |
607 fun add_primrec_global fixes specs = |
602 let |
608 Named_Target.theory_init |
603 val lthy = Named_Target.theory_init thy; |
609 #> add_primrec fixes specs |
604 val ((ts, simpss), lthy') = add_primrec fixes specs lthy; |
610 ##> Local_Theory.exit_global; |
605 val simpss' = burrow (Proof_Context.export lthy' lthy) simpss; |
611 |
606 in ((ts, simpss'), Local_Theory.exit_global lthy') end; |
612 fun add_primrec_overloaded ops fixes specs = |
607 |
613 Overloading.overloading ops |
608 fun add_primrec_overloaded ops fixes specs thy = |
614 #> add_primrec fixes specs |
609 let |
615 ##> Local_Theory.exit_global; |
610 val lthy = Overloading.overloading ops thy; |
|
611 val ((ts, simpss), lthy') = add_primrec fixes specs lthy; |
|
612 val simpss' = burrow (Proof_Context.export lthy' lthy) simpss; |
|
613 in ((ts, simpss'), Local_Theory.exit_global lthy') end; |
|
614 |
616 |
615 val _ = Outer_Syntax.local_theory @{command_spec "primrec"} |
617 val _ = Outer_Syntax.local_theory @{command_spec "primrec"} |
616 "define primitive recursive functions" |
618 "define primitive recursive functions" |
617 (Parse.fixes -- Parse_Spec.where_alt_specs >> (snd oo uncurry add_primrec_cmd)); |
619 (Parse.fixes -- Parse_Spec.where_alt_specs >> (snd oo uncurry add_primrec_cmd)); |
618 |
620 |