6 Recursor sugar ("primrec"). |
6 Recursor sugar ("primrec"). |
7 *) |
7 *) |
8 |
8 |
9 signature BNF_LFP_REC_SUGAR = |
9 signature BNF_LFP_REC_SUGAR = |
10 sig |
10 sig |
|
11 datatype primrec_option = Nonexhaustive_Option |
|
12 |
11 type basic_lfp_sugar = |
13 type basic_lfp_sugar = |
12 {T: typ, |
14 {T: typ, |
13 fp_res_index: int, |
15 fp_res_index: int, |
14 C: typ, |
16 C: typ, |
15 fun_arg_Tsss : typ list list list, |
17 fun_arg_Tsss : typ list list list, |
31 |
33 |
32 val register_lfp_rec_extension: lfp_rec_extension -> theory -> theory |
34 val register_lfp_rec_extension: lfp_rec_extension -> theory -> theory |
33 |
35 |
34 val add_primrec: (binding * typ option * mixfix) list -> |
36 val add_primrec: (binding * typ option * mixfix) list -> |
35 (Attrib.binding * term) list -> local_theory -> (term list * thm list list) * local_theory |
37 (Attrib.binding * term) list -> local_theory -> (term list * thm list list) * local_theory |
36 val add_primrec_cmd: (binding * string option * mixfix) list -> |
38 val add_primrec_cmd: primrec_option list -> (binding * string option * mixfix) list -> |
37 (Attrib.binding * string) list -> local_theory -> (term list * thm list list) * local_theory |
39 (Attrib.binding * string) list -> local_theory -> (term list * thm list list) * local_theory |
38 val add_primrec_global: (binding * typ option * mixfix) list -> |
40 val add_primrec_global: (binding * typ option * mixfix) list -> |
39 (Attrib.binding * term) list -> theory -> (term list * thm list list) * theory |
41 (Attrib.binding * term) list -> theory -> (term list * thm list list) * theory |
40 val add_primrec_overloaded: (string * (string * typ) * bool) list -> |
42 val add_primrec_overloaded: (string * (string * typ) * bool) list -> |
41 (binding * typ option * mixfix) list -> |
43 (binding * typ option * mixfix) list -> |
59 val simp_attrs = @{attributes [simp]}; |
61 val simp_attrs = @{attributes [simp]}; |
60 val code_nitpicksimp_simp_attrs = Code.add_default_eqn_attrib :: nitpicksimp_attrs @ simp_attrs; |
62 val code_nitpicksimp_simp_attrs = Code.add_default_eqn_attrib :: nitpicksimp_attrs @ simp_attrs; |
61 |
63 |
62 exception OLD_PRIMREC of unit; |
64 exception OLD_PRIMREC of unit; |
63 exception PRIMREC of string * term list; |
65 exception PRIMREC of string * term list; |
|
66 |
|
67 datatype primrec_option = Nonexhaustive_Option; |
64 |
68 |
65 datatype rec_call = |
69 datatype rec_call = |
66 No_Rec of int * typ | |
70 No_Rec of int * typ | |
67 Mutual_Rec of (int * typ) * (int * typ) | |
71 Mutual_Rec of (int * typ) * (int * typ) | |
68 Nested_Rec of int * typ; |
72 Nested_Rec of int * typ; |
344 t |
348 t |
345 |> subst_rec_calls lthy get_ctr_pos has_call ctr_args mutual_calls nested_calls |
349 |> subst_rec_calls lthy get_ctr_pos has_call ctr_args mutual_calls nested_calls |
346 |> fold_rev lambda (args @ left_args @ right_args) |
350 |> fold_rev lambda (args @ left_args @ right_args) |
347 end); |
351 end); |
348 |
352 |
349 fun build_defs lthy bs mxs (funs_data : eqn_data list list) (rec_specs : rec_spec list) has_call = |
353 fun build_defs lthy nonexhaustive bs mxs (funs_data : eqn_data list list) |
|
354 (rec_specs : rec_spec list) has_call = |
350 let |
355 let |
351 val n_funs = length funs_data; |
356 val n_funs = length funs_data; |
352 |
357 |
353 val ctr_spec_eqn_data_list' = |
358 val ctr_spec_eqn_data_list' = |
354 (take n_funs rec_specs |> map #ctr_specs) ~~ funs_data |
359 (take n_funs rec_specs |> map #ctr_specs) ~~ funs_data |
355 |> maps (uncurry (finds (fn (x, y) => #ctr x = #ctr y)) |
360 |> maps (uncurry (finds (fn (x, y) => #ctr x = #ctr y)) |
356 ##> (fn x => null x orelse |
361 ##> (fn x => null x orelse |
357 raise PRIMREC ("excess equations in definition", map #rhs_term x)) #> fst); |
362 raise PRIMREC ("excess equations in definition", map #rhs_term x)) #> fst); |
358 val _ = ctr_spec_eqn_data_list' |> map (fn (_, x) => length x <= 1 orelse |
363 val _ = ctr_spec_eqn_data_list' |> map (fn ({ctr, ...}, x) => |
359 raise PRIMREC ("multiple equations for constructor", map #user_eqn x)); |
364 if length x > 1 then raise PRIMREC ("multiple equations for constructor", map #user_eqn x) |
|
365 else if length x = 1 orelse nonexhaustive then () |
|
366 else warning ("no equation for constructor " ^ Syntax.string_of_term lthy ctr)); |
360 |
367 |
361 val ctr_spec_eqn_data_list = |
368 val ctr_spec_eqn_data_list = |
362 ctr_spec_eqn_data_list' @ (drop n_funs rec_specs |> maps #ctr_specs |> map (rpair [])); |
369 ctr_spec_eqn_data_list' @ (drop n_funs rec_specs |> maps #ctr_specs |> map (rpair [])); |
363 |
370 |
364 val recs = take n_funs rec_specs |> map #recx; |
371 val recs = take n_funs rec_specs |> map #recx; |
412 unfold_thms_tac ctxt fun_defs THEN |
419 unfold_thms_tac ctxt fun_defs THEN |
413 HEADGOAL (rtac (funpow num_extra_args (fn thm => thm RS fun_cong) recx RS trans)) THEN |
420 HEADGOAL (rtac (funpow num_extra_args (fn thm => thm RS fun_cong) recx RS trans)) THEN |
414 unfold_thms_tac ctxt (nested_simps ctxt @ map_comps @ map_idents) THEN |
421 unfold_thms_tac ctxt (nested_simps ctxt @ map_comps @ map_idents) THEN |
415 HEADGOAL (rtac refl); |
422 HEADGOAL (rtac refl); |
416 |
423 |
417 fun prepare_primrec fixes specs lthy0 = |
424 fun prepare_primrec nonexhaustive fixes specs lthy0 = |
418 let |
425 let |
419 val thy = Proof_Context.theory_of lthy0; |
426 val thy = Proof_Context.theory_of lthy0; |
420 |
427 |
421 val (bs, mxs) = map_split (apfst fst) fixes; |
428 val (bs, mxs) = map_split (apfst fst) fixes; |
422 val fun_names = map Binding.name_of bs; |
429 val fun_names = map Binding.name_of bs; |
453 val _ = |
460 val _ = |
454 map (fn {ctr, user_eqn, ...} => member (op =) ctrs ctr orelse |
461 map (fn {ctr, user_eqn, ...} => member (op =) ctrs ctr orelse |
455 raise PRIMREC ("argument " ^ quote (Syntax.string_of_term lthy ctr) ^ |
462 raise PRIMREC ("argument " ^ quote (Syntax.string_of_term lthy ctr) ^ |
456 " is not a constructor in left-hand side", [user_eqn])) eqns_data; |
463 " is not a constructor in left-hand side", [user_eqn])) eqns_data; |
457 |
464 |
458 val defs = build_defs lthy bs mxs funs_data rec_specs has_call; |
465 val defs = build_defs lthy nonexhaustive bs mxs funs_data rec_specs has_call; |
459 |
466 |
460 fun prove lthy' def_thms' ({ctr_specs, nested_map_idents, nested_map_comps, ...} : rec_spec) |
467 fun prove lthy' def_thms' ({ctr_specs, nested_map_idents, nested_map_comps, ...} : rec_spec) |
461 (fun_data : eqn_data list) = |
468 (fun_data : eqn_data list) = |
462 let |
469 let |
463 val js = |
470 val js = |
503 fn lthy => fn defs => |
510 fn lthy => fn defs => |
504 split_list (map2 (prove lthy defs) (take actual_nn rec_specs) funs_data)), |
511 split_list (map2 (prove lthy defs) (take actual_nn rec_specs) funs_data)), |
505 lthy |> Local_Theory.notes (notes @ common_notes) |> snd) |
512 lthy |> Local_Theory.notes (notes @ common_notes) |> snd) |
506 end; |
513 end; |
507 |
514 |
508 fun add_primrec_simple fixes ts lthy = |
515 fun add_primrec_simple' opts fixes ts lthy = |
509 let |
516 let |
510 val (((names, defs), prove), lthy') = prepare_primrec fixes ts lthy |
517 val nonexhaustive = member (op =) opts Nonexhaustive_Option; |
|
518 val (((names, defs), prove), lthy') = prepare_primrec nonexhaustive fixes ts lthy |
511 handle ERROR str => raise PRIMREC (str, []); |
519 handle ERROR str => raise PRIMREC (str, []); |
512 in |
520 in |
513 lthy' |
521 lthy' |
514 |> fold_map Local_Theory.define defs |
522 |> fold_map Local_Theory.define defs |
515 |-> (fn defs => `(fn lthy => (names, (map fst defs, prove lthy defs)))) |
523 |-> (fn defs => `(fn lthy => (names, (map fst defs, prove lthy defs)))) |
519 error ("primrec error:\n " ^ str) |
527 error ("primrec error:\n " ^ str) |
520 else |
528 else |
521 error ("primrec error:\n " ^ str ^ "\nin\n " ^ |
529 error ("primrec error:\n " ^ str ^ "\nin\n " ^ |
522 space_implode "\n " (map (quote o Syntax.string_of_term lthy) eqns)); |
530 space_implode "\n " (map (quote o Syntax.string_of_term lthy) eqns)); |
523 |
531 |
524 fun gen_primrec old_primrec prep_spec (raw_fixes : (binding * 'a option * mixfix) list) raw_spec |
532 val add_primrec_simple = add_primrec_simple' []; |
525 lthy = |
533 |
|
534 fun gen_primrec old_primrec prep_spec opts |
|
535 (raw_fixes : (binding * 'a option * mixfix) list) raw_spec lthy = |
526 let |
536 let |
527 val d = duplicates (op =) (map (Binding.name_of o #1) raw_fixes) |
537 val d = duplicates (op =) (map (Binding.name_of o #1) raw_fixes) |
528 val _ = null d orelse raise PRIMREC ("duplicate function name(s): " ^ commas d, []); |
538 val _ = null d orelse raise PRIMREC ("duplicate function name(s): " ^ commas d, []); |
529 |
539 |
530 val (fixes, specs) = fst (prep_spec raw_fixes raw_spec lthy); |
540 val (fixes, specs) = fst (prep_spec raw_fixes raw_spec lthy); |
541 in |
551 in |
542 ((Binding.qualify true prefix (Binding.name simpsN), []), [(thms, [])]) :: notes |
552 ((Binding.qualify true prefix (Binding.name simpsN), []), [(thms, [])]) :: notes |
543 end); |
553 end); |
544 in |
554 in |
545 lthy |
555 lthy |
546 |> add_primrec_simple fixes (map snd specs) |
556 |> add_primrec_simple' opts fixes (map snd specs) |
547 |-> (fn (names, (ts, (jss, simpss))) => |
557 |-> (fn (names, (ts, (jss, simpss))) => |
548 Spec_Rules.add Spec_Rules.Equational (ts, flat simpss) |
558 Spec_Rules.add Spec_Rules.Equational (ts, flat simpss) |
549 #> Local_Theory.notes (mk_notes jss names simpss) |
559 #> Local_Theory.notes (mk_notes jss names simpss) |
550 #>> pair ts o map snd) |
560 #>> pair ts o map snd) |
551 end |
561 end |
552 handle OLD_PRIMREC () => old_primrec raw_fixes raw_spec lthy |>> apsnd single; |
562 handle OLD_PRIMREC () => old_primrec raw_fixes raw_spec lthy |>> apsnd single; |
553 |
563 |
554 val add_primrec = gen_primrec Primrec.add_primrec Specification.check_spec; |
564 val add_primrec = gen_primrec Primrec.add_primrec Specification.check_spec []; |
555 val add_primrec_cmd = gen_primrec Primrec.add_primrec_cmd Specification.read_spec; |
565 val add_primrec_cmd = gen_primrec Primrec.add_primrec_cmd Specification.read_spec; |
556 |
566 |
557 fun add_primrec_global fixes specs = |
567 fun add_primrec_global fixes specs = |
558 Named_Target.theory_init |
568 Named_Target.theory_init |
559 #> add_primrec fixes specs |
569 #> add_primrec fixes specs |
562 fun add_primrec_overloaded ops fixes specs = |
572 fun add_primrec_overloaded ops fixes specs = |
563 Overloading.overloading ops |
573 Overloading.overloading ops |
564 #> add_primrec fixes specs |
574 #> add_primrec fixes specs |
565 ##> Local_Theory.exit_global; |
575 ##> Local_Theory.exit_global; |
566 |
576 |
|
577 val primrec_option_parser = Parse.group (fn () => "option") |
|
578 (Parse.reserved "nonexhaustive" >> K Nonexhaustive_Option) |
|
579 |
567 val _ = Outer_Syntax.local_theory @{command_spec "primrec"} |
580 val _ = Outer_Syntax.local_theory @{command_spec "primrec"} |
568 "define primitive recursive functions" |
581 "define primitive recursive functions" |
569 (Parse.fixes -- Parse_Spec.where_alt_specs >> (snd oo uncurry add_primrec_cmd)); |
582 ((Scan.optional (@{keyword "("} |-- |
|
583 Parse.!!! (Parse.list1 primrec_option_parser) --| @{keyword ")"}) []) -- |
|
584 (Parse.fixes -- Parse_Spec.where_alt_specs) |
|
585 >> (fn (opts, (fixes, spec)) => snd o add_primrec_cmd opts fixes spec)); |
570 |
586 |
571 end; |
587 end; |