6 Recursor sugar ("primrec"). |
6 Recursor sugar ("primrec"). |
7 *) |
7 *) |
8 |
8 |
9 signature BNF_LFP_REC_SUGAR = |
9 signature BNF_LFP_REC_SUGAR = |
10 sig |
10 sig |
|
11 type basic_lfp_sugar = |
|
12 {T: typ, |
|
13 fp_res_index: int, |
|
14 ctor_recT: typ, |
|
15 ctr_defs: thm list, |
|
16 ctr_sugar: Ctr_Sugar.ctr_sugar, |
|
17 recx: term, |
|
18 rec_thms: thm list}; |
|
19 |
|
20 type lfp_rec_extension = |
|
21 {is_new_datatype: Proof.context -> string -> bool, |
|
22 get_basic_lfp_sugars: binding list -> typ list -> (term -> int list) -> |
|
23 (term * term list list) list list -> local_theory -> |
|
24 typ list * int list * basic_lfp_sugar list * thm list * thm list * thm * bool * local_theory, |
|
25 massage_nested_rec_call: Proof.context -> (term -> bool) -> (typ -> typ -> term -> term) -> |
|
26 typ list -> term -> term -> term -> term}; |
|
27 |
|
28 val register_lfp_rec_extension: lfp_rec_extension -> theory -> theory |
|
29 |
11 val add_primrec: (binding * typ option * mixfix) list -> |
30 val add_primrec: (binding * typ option * mixfix) list -> |
12 (Attrib.binding * term) list -> local_theory -> (term list * thm list list) * local_theory |
31 (Attrib.binding * term) list -> local_theory -> (term list * thm list list) * local_theory |
13 val add_primrec_cmd: (binding * string option * mixfix) list -> |
32 val add_primrec_cmd: (binding * string option * mixfix) list -> |
14 (Attrib.binding * string) list -> local_theory -> (term list * thm list list) * local_theory |
33 (Attrib.binding * string) list -> local_theory -> (term list * thm list list) * local_theory |
15 val add_primrec_global: (binding * typ option * mixfix) list -> |
34 val add_primrec_global: (binding * typ option * mixfix) list -> |
58 type rec_spec = |
75 type rec_spec = |
59 {recx: term, |
76 {recx: term, |
60 nested_map_idents: thm list, |
77 nested_map_idents: thm list, |
61 nested_map_comps: thm list, |
78 nested_map_comps: thm list, |
62 ctr_specs: rec_ctr_spec list}; |
79 ctr_specs: rec_ctr_spec list}; |
63 |
|
64 exception NOT_A_MAP of term; |
|
65 |
|
(* Abort with an error message that quotes the pretty-printed term [t] as an
   ill-formed recursive call. Never returns normally. *)
fun ill_formed_rec_call ctxt t =
  let
    val shown = quote (Syntax.string_of_term ctxt t);
  in
    error ("Ill-formed recursive call: " ^ shown)
  end;
|
(* Abort with an error message reporting an invalid map function inside the
   pretty-printed, quoted term [t]. Never returns normally. *)
fun invalid_map ctxt t =
  let
    val shown = quote (Syntax.string_of_term ctxt t);
  in
    error ("Invalid map function in " ^ shown)
  end;
|
(* Abort with an error message that quotes the pretty-printed term [t] as an
   unexpected recursive call. Never returns normally. *)
fun unexpected_rec_call ctxt t =
  let
    val shown = quote (Syntax.string_of_term ctxt t);
  in
    error ("Unexpected recursive call: " ^ shown)
  end;
|
72 |
|
(* massage_nested_rec_call ctxt has_call raw_massage_fun bound_Ts y y'
   returns a term-rewriting function (the local massage_call below).

   Parameters, as used in this body:
     ctxt             - context passed to the error reporters and to the
                        map helpers (build_map, dest_map, map_flattened_map_args)
     has_call         - predicate: does a term contain a recursive call?
     raw_massage_fun  - callback applied as "raw_massage_fun T U2 t" to a
                        call-containing function argument
     bound_Ts         - types of loose bound variables, used for typing
     y, y'            - the term being replaced and its replacement; their
                        types are yT and yU respectively

   Raises (via error): ill_formed_rec_call, invalid_map, unexpected_rec_call.
   NOTE(review): build_map, fst_const, mk_comp, mk_compN, mk_map, dest_map and
   map_flattened_map_args are defined elsewhere; their exact semantics are
   assumed from their use here - confirm against their definitions. *)
73 fun massage_nested_rec_call ctxt has_call raw_massage_fun bound_Ts y y' = |

74 let |

(* Reports an unexpected recursive call if t contains one; otherwise unit. *)
75 fun check_no_call t = if has_call t then unexpected_rec_call ctxt t else (); |

76 |

(* Type of a term relative to the bound-variable types bound_Ts. *)
77 val typof = curry fastype_of1 bound_Ts; |

78 val build_map_fst = build_map ctxt (fst_const o fst); |

79 |

80 val yT = typof y; |

81 val yU = typof y'; |

82 |

(* Term standing in for y: the map built for (yU, yT) applied to y'.
   Thunked so that it is only built when actually needed. *)
83 fun y_of_y' () = build_map_fst (yU, yT) $ y'; |

(* Replaces every atomic subterm equal to y by y_of_y' (). *)
84 val elim_y = Term.map_aterms (fn t => if t = y then y_of_y' () else t); |

85 |

(* Rewrites a function argument t that is not itself a map application.
   - For a composition "comp t1 t2": t1 must be call-free, t2 is rewritten
     recursively.
   - Otherwise, if t contains a call: U must be a product type whose first
     component equals T, and the work is delegated to raw_massage_fun;
     any other shape is reported through invalid_map.
   - Otherwise (no call): t is composed with build_map_fst (U, T). *)
86 fun massage_mutual_fun U T t = |

87 (case t of |

88 Const (@{const_name comp}, _) $ t1 $ t2 => |

89 mk_comp bound_Ts (tap check_no_call t1, massage_mutual_fun U T t2) |

90 | _ => |

91 if has_call t then |

92 (case try HOLogic.dest_prodT U of |

93 SOME (U1, U2) => if U1 = T then raw_massage_fun T U2 t else invalid_map ctxt t |

94 | NONE => invalid_map ctxt t) |

95 else |

96 mk_comp bound_Ts (t, build_map_fst (U, T))); |

97 |

(* Rewrites t when it is a map application for the type constructor s
   (as recognised by dest_map): the map constant is rebuilt with mk_map and
   each function argument is rewritten with massage_map_or_map_arg.
   Raises NOT_A_MAP t when t is not such a map or the types are not both
   Type constructors. *)
98 fun massage_map (Type (_, Us)) (Type (s, Ts)) t = |

99 (case try (dest_map ctxt s) t of |

100 SOME (map0, fs) => |

101 let |

102 val Type (_, ran_Ts) = range_type (typof t); |

103 val map' = mk_map (length fs) Us ran_Ts map0; |

104 val fs' = map_flattened_map_args ctxt s (map3 massage_map_or_map_arg Us Ts) fs; |

105 in |

106 Term.list_comb (map', fs') |

107 end |

108 | NONE => raise NOT_A_MAP t) |

109 | massage_map _ _ t = raise NOT_A_MAP t |

(* Rewrites one map argument: if the two types coincide the argument is kept
   as is (but must be call-free); otherwise it is tried as a nested map
   first and, on NOT_A_MAP, handled by massage_mutual_fun. *)
110 and massage_map_or_map_arg U T t = |

111 if T = U then |

112 tap check_no_call t |

113 else |

114 massage_map U T t |

115 handle NOT_A_MAP _ => massage_mutual_fun U T t; |

116 |

(* Entry point returned to the caller.
   Application case "t1 $ t2":
   - no call in t: just substitute y via elim_y;
   - argument is y itself: rewrite the function part as a map and apply it
     to y'; a NOT_A_MAP failure is reported through invalid_map;
   - argument is y applied to further arguments xs: those must be call-free;
     the call is re-expressed through mk_compN and rewritten recursively,
     then re-applied to xs;
   - any other argument: ill-formed recursive call.
   Non-application case: y itself is replaced by y_of_y' (); anything else
   is reported as an ill-formed recursive call. *)
117 fun massage_call (t as t1 $ t2) = |

118 if has_call t then |

119 if t2 = y then |

120 massage_map yU yT (elim_y t1) $ y' |

121 handle NOT_A_MAP t' => invalid_map ctxt t' |

122 else |

123 let val (g, xs) = Term.strip_comb t2 in |

124 if g = y then |

125 if exists has_call xs then unexpected_rec_call ctxt t2 |

126 else Term.list_comb (massage_call (mk_compN (length xs) bound_Ts (t1, y)), xs) |

127 else |

128 ill_formed_rec_call ctxt t |

129 end |

130 else |

131 elim_y t |

132 | massage_call t = if t = y then y_of_y' () else ill_formed_rec_call ctxt t; |

133 in |

134 massage_call |

135 end; |
|
136 |
80 |
137 type basic_lfp_sugar = |
81 type basic_lfp_sugar = |
138 {T: typ, |
82 {T: typ, |
139 fp_res_index: int, |
83 fp_res_index: int, |
140 ctor_recT: typ, |
84 ctor_recT: typ, |
141 ctr_defs: thm list, |
85 ctr_defs: thm list, |
142 ctr_sugar: ctr_sugar, |
86 ctr_sugar: ctr_sugar, |
143 recx: term, |
87 recx: term, |
144 rec_thms: thm list}; |
88 rec_thms: thm list}; |
145 |
89 |
146 fun basic_lfp_sugar_of ({T, fp_res = {xtor_co_iterss = ctor_iterss, ...}, fp_res_index, ctr_defs, |
90 type lfp_rec_extension = |
147 ctr_sugar, co_iters = iters, co_iter_thmss = iter_thmss, ...} : fp_sugar) = |
91 {is_new_datatype: Proof.context -> string -> bool, |
148 {T = T, fp_res_index = fp_res_index, |
92 get_basic_lfp_sugars: binding list -> typ list -> (term -> int list) -> |
149 ctor_recT = fastype_of (co_rec_of (nth ctor_iterss fp_res_index)), ctr_defs = ctr_defs, |
93 (term * term list list) list list -> local_theory -> |
150 ctr_sugar = ctr_sugar, recx = co_rec_of iters, rec_thms = co_rec_of iter_thmss}; |
94 typ list * int list * basic_lfp_sugar list * thm list * thm list * thm * bool * local_theory, |
151 |
95 massage_nested_rec_call: Proof.context -> (term -> bool) -> (typ -> typ -> term -> term) -> |
152 fun get_basic_lfp_sugars bs arg_Ts get_indices callssss0 lthy0 = |
96 typ list -> term -> term -> term -> term}; |
153 let |
97 |
154 val ((missing_arg_Ts, perm0_kks, |
98 structure Data = Theory_Data |
155 fp_sugars as {nested_bnfs, co_inducts = [induct_thm], ...} :: _, (lfp_sugar_thms, _)), |
99 ( |
156 lthy) = |
100 type T = lfp_rec_extension option; |
157 nested_to_mutual_fps Least_FP bs arg_Ts get_indices callssss0 lthy0; |
101 val empty = NONE; |
158 val nested_map_idents = map (unfold_thms lthy @{thms id_def} o map_id0_of_bnf) nested_bnfs; |
102 val extend = I; |
159 val nested_map_comps = map map_comp_of_bnf nested_bnfs; |
103 val merge = merge_options; |
160 in |
104 ); |
161 (missing_arg_Ts, perm0_kks, map basic_lfp_sugar_of fp_sugars, nested_map_idents, |
105 |
162 nested_map_comps, induct_thm, lfp_sugar_thms, lthy) |
106 val register_lfp_rec_extension = Data.put o SOME; |
163 end; |
107 |
|
(* Dispatches to the is_new_datatype hook of the extension registered in the
   theory data of [ctxt]. When no extension is registered, the result is the
   constant-false predicate. *)
fun is_new_datatype ctxt =
  let
    val thy = Proof_Context.theory_of ctxt;
  in
    (case Data.get thy of
      NONE => K false
    | SOME {is_new_datatype = hook, ...} => hook ctxt)
  end;
|
112 |
|
(* Dispatches to the get_basic_lfp_sugars hook of the extension registered in
   the theory data of [lthy], forwarding all arguments unchanged.
   Fails with error "Not implemented yet" when no extension is registered. *)
fun get_basic_lfp_sugars bs arg_Ts get_indices callssss lthy =
  let
    val thy = Proof_Context.theory_of lthy;
  in
    (case Data.get thy of
      NONE => error "Not implemented yet"
    | SOME {get_basic_lfp_sugars = hook, ...} => hook bs arg_Ts get_indices callssss lthy)
  end;
|
117 |
|
(* Dispatches to the massage_nested_rec_call hook of the extension registered
   in the theory data of [ctxt]; the remaining arguments are passed on by
   partial application.

   Fails with error "Not implemented yet" when no extension is registered.
   The NONE branch makes the case exhaustive: without it, a missing extension
   surfaced as a bare Match exception instead of a user-level error. The
   message matches the one used by the sibling get_basic_lfp_sugars wrapper. *)
fun massage_nested_rec_call ctxt =
  (case Data.get (Proof_Context.theory_of ctxt) of
    SOME {massage_nested_rec_call, ...} => massage_nested_rec_call ctxt
  | NONE => error "Not implemented yet");
164 |
121 |
165 fun rec_specs_of bs arg_Ts res_Ts get_indices callssss0 lthy0 = |
122 fun rec_specs_of bs arg_Ts res_Ts get_indices callssss0 lthy0 = |
166 let |
123 let |
167 val thy = Proof_Context.theory_of lthy0; |
124 val thy = Proof_Context.theory_of lthy0; |
168 |
125 |
169 val (missing_arg_Ts, perm0_kks, basic_lfp_sugars, nested_map_idents, nested_map_comps, |
126 val (missing_arg_Ts, perm0_kks, basic_lfp_sugars, nested_map_idents, nested_map_comps, |
170 induct_thm, lfp_sugar_thms, lthy) = |
127 induct_thm, n2m, lthy) = |
171 get_basic_lfp_sugars bs arg_Ts get_indices callssss0 lthy0; |
128 get_basic_lfp_sugars bs arg_Ts get_indices callssss0 lthy0; |
172 |
129 |
173 val perm_basic_lfp_sugars = sort (int_ord o pairself #fp_res_index) basic_lfp_sugars; |
130 val perm_basic_lfp_sugars = sort (int_ord o pairself #fp_res_index) basic_lfp_sugars; |
174 |
131 |
175 val indices = map #fp_res_index basic_lfp_sugars; |
132 val indices = map #fp_res_index basic_lfp_sugars; |
231 ({T, fp_res_index, ctr_sugar = {ctrs, ...}, recx, rec_thms, ...} : basic_lfp_sugar) = |
188 ({T, fp_res_index, ctr_sugar = {ctrs, ...}, recx, rec_thms, ...} : basic_lfp_sugar) = |
232 {recx = mk_co_iter thy Least_FP (substAT T) perm_Cs' recx, |
189 {recx = mk_co_iter thy Least_FP (substAT T) perm_Cs' recx, |
233 nested_map_idents = nested_map_idents, nested_map_comps = nested_map_comps, |
190 nested_map_idents = nested_map_idents, nested_map_comps = nested_map_comps, |
234 ctr_specs = mk_ctr_specs fp_res_index ctr_offset ctrs rec_thms}; |
191 ctr_specs = mk_ctr_specs fp_res_index ctr_offset ctrs rec_thms}; |
235 in |
192 in |
236 ((is_some lfp_sugar_thms, map2 mk_spec ctr_offsets basic_lfp_sugars, missing_arg_Ts, induct_thm, |
193 ((n2m, map2 mk_spec ctr_offsets basic_lfp_sugars, missing_arg_Ts, induct_thm, induct_thms), |
237 induct_thms), lthy) |
194 lthy) |
238 end; |
195 end; |
239 |
196 |
240 val undef_const = Const (@{const_name undefined}, dummyT); |
197 val undef_const = Const (@{const_name undefined}, dummyT); |
241 |
198 |
242 fun permute_args n t = |
199 fun permute_args n t = |
494 val eqns_data = map (dissect_eqn lthy0 fun_names) specs; |
451 val eqns_data = map (dissect_eqn lthy0 fun_names) specs; |
495 val funs_data = eqns_data |
452 val funs_data = eqns_data |
496 |> partition_eq ((op =) o pairself #fun_name) |
453 |> partition_eq ((op =) o pairself #fun_name) |
497 |> finds (fn (x, y) => x = #fun_name (hd y)) fun_names |> fst |
454 |> finds (fn (x, y) => x = #fun_name (hd y)) fun_names |> fst |
498 |> map (fn (x, y) => the_single y |
455 |> map (fn (x, y) => the_single y |
499 handle List.Empty => |
456 handle List.Empty => primrec_error ("missing equations for function " ^ quote x)); |
500 primrec_error ("missing equations for function " ^ quote x)); |
|
501 |
457 |
502 val has_call = exists_subterm (map (fst #>> Binding.name_of #> Free) fixes |> member (op =)); |
458 val has_call = exists_subterm (map (fst #>> Binding.name_of #> Free) fixes |> member (op =)); |
503 val arg_Ts = map (#rec_type o hd) funs_data; |
459 val arg_Ts = map (#rec_type o hd) funs_data; |
504 val res_Ts = map (#res_type o hd) funs_data; |
460 val res_Ts = map (#res_type o hd) funs_data; |
505 val callssss = funs_data |
461 val callssss = funs_data |
506 |> map (partition_eq ((op =) o pairself #ctr)) |
462 |> map (partition_eq ((op =) o pairself #ctr)) |
507 |> map (maps (map_filter (find_rec_calls has_call))); |
463 |> map (maps (map_filter (find_rec_calls has_call))); |
508 |
464 |
509 fun is_only_old_datatype (Type (s, _)) = |
465 fun is_only_old_datatype (Type (s, _)) = |
510 is_none (fp_sugar_of lthy0 s) andalso is_some (Datatype_Data.get_info thy s) |
466 is_some (Datatype_Data.get_info thy s) andalso not (is_new_datatype lthy0 s) |
511 | is_only_old_datatype _ = false; |
467 | is_only_old_datatype _ = false; |
512 |
468 |
513 val _ = if exists is_only_old_datatype arg_Ts then raise OLD_PRIMREC () else (); |
469 val _ = if exists is_only_old_datatype arg_Ts then raise OLD_PRIMREC () else (); |
514 val _ = (case filter_out (fn (_, T) => Sign.of_sort thy (T, HOLogic.typeS)) (bs ~~ res_Ts) of |
470 val _ = (case filter_out (fn (_, T) => Sign.of_sort thy (T, HOLogic.typeS)) (bs ~~ res_Ts) of |
515 [] => () |
471 [] => () |
516 | (b, _) :: _ => primrec_error ("type of " ^ Binding.print b ^ " contains top sort")); |
472 | (b, _) :: _ => primrec_error ("type of " ^ Binding.print b ^ " contains top sort")); |
517 |
473 |
518 val ((n2m, rec_specs, _, induct_thm, induct_thms), lthy) = |
474 val ((n2m, rec_specs, _, induct_thm, induct_thms), lthy) = |
519 rec_specs_of bs arg_Ts res_Ts (get_indices fixes) callssss lthy0; |
475 rec_specs_of bs arg_Ts res_Ts (get_free_indices fixes) callssss lthy0; |
520 |
476 |
521 val actual_nn = length funs_data; |
477 val actual_nn = length funs_data; |
522 |
478 |
523 val ctrs = maps (map #ctr o #ctr_specs) rec_specs; |
479 val ctrs = maps (map #ctr o #ctr_specs) rec_specs; |
524 val _ = |
480 val _ = |