74 (map dest_Free (term_frees rhs) \\ lfrees); |
74 (map dest_Free (term_frees rhs) \\ lfrees); |
75 case assoc (rec_fns, fname) of |
75 case assoc (rec_fns, fname) of |
76 NONE => |
76 NONE => |
77 (fname, (tname, rpos, [(cname, (ls, cargs, rs, rhs, eq))]))::rec_fns |
77 (fname, (tname, rpos, [(cname, (ls, cargs, rs, rhs, eq))]))::rec_fns |
78 | SOME (_, rpos', eqns) => |
78 | SOME (_, rpos', eqns) => |
79 if is_some (assoc (eqns, cname)) then |
79 if isSome (assoc (eqns, cname)) then |
80 raise RecError "constructor already occurred as pattern" |
80 raise RecError "constructor already occurred as pattern" |
81 else if rpos <> rpos' then |
81 else if rpos <> rpos' then |
82 raise RecError "position of recursive argument inconsistent" |
82 raise RecError "position of recursive argument inconsistent" |
83 else |
83 else |
84 overwrite (rec_fns, |
84 overwrite (rec_fns, |
88 end |
88 end |
89 handle RecError s => primrec_eq_err sign s eq; |
89 handle RecError s => primrec_eq_err sign s eq; |
90 |
90 |
91 fun process_fun sign descr rec_eqns ((i, fname), (fnames, fnss)) = |
91 fun process_fun sign descr rec_eqns ((i, fname), (fnames, fnss)) = |
92 let |
92 let |
93 val (_, (tname, _, constrs)) = nth_elem (i, descr); |
93 val (_, (tname, _, constrs)) = List.nth (descr, i); |
94 |
94 |
95 (* substitute "fname ls x rs" by "y ls rs" for (x, (_, y)) in subs *) |
95 (* substitute "fname ls x rs" by "y ls rs" for (x, (_, y)) in subs *) |
96 |
96 |
97 fun subst [] x = x |
97 fun subst [] x = x |
98 | subst subs (fs, Abs (a, T, t)) = |
98 | subst subs (fs, Abs (a, T, t)) = |
102 let val (f, ts) = strip_comb t; |
102 let val (f, ts) = strip_comb t; |
103 in |
103 in |
104 if is_Const f andalso (fst (dest_Const f)) mem (map fst rec_eqns) then |
104 if is_Const f andalso (fst (dest_Const f)) mem (map fst rec_eqns) then |
105 let |
105 let |
106 val (fname', _) = dest_Const f; |
106 val (fname', _) = dest_Const f; |
107 val (_, rpos, _) = the (assoc (rec_eqns, fname')); |
107 val (_, rpos, _) = valOf (assoc (rec_eqns, fname')); |
108 val ls = take (rpos, ts); |
108 val ls = Library.take (rpos, ts); |
109 val rest = drop (rpos, ts); |
109 val rest = Library.drop (rpos, ts); |
110 val (x', rs) = (hd rest, tl rest) |
110 val (x', rs) = (hd rest, tl rest) |
111 handle LIST _ => raise RecError ("not enough arguments\ |
111 handle Empty => raise RecError ("not enough arguments\ |
112 \ in recursive application\nof function " ^ quote fname' ^ " on rhs"); |
112 \ in recursive application\nof function " ^ quote fname' ^ " on rhs"); |
113 val (x, xs) = strip_comb x' |
113 val (x, xs) = strip_comb x' |
114 in |
114 in |
115 (case assoc (subs, x) of |
115 (case assoc (subs, x) of |
116 NONE => |
116 NONE => |
138 NONE => (warning ("No equation for constructor " ^ quote cname ^ |
138 NONE => (warning ("No equation for constructor " ^ quote cname ^ |
139 "\nin definition of function " ^ quote fname); |
139 "\nin definition of function " ^ quote fname); |
140 (fnames', fnss', (Const ("arbitrary", dummyT))::fns)) |
140 (fnames', fnss', (Const ("arbitrary", dummyT))::fns)) |
141 | SOME (ls, cargs', rs, rhs, eq) => |
141 | SOME (ls, cargs', rs, rhs, eq) => |
142 let |
142 let |
143 val recs = filter (is_rec_type o snd) (cargs' ~~ cargs); |
143 val recs = List.filter (is_rec_type o snd) (cargs' ~~ cargs); |
144 val rargs = map fst recs; |
144 val rargs = map fst recs; |
145 val subs = map (rpair dummyT o fst) |
145 val subs = map (rpair dummyT o fst) |
146 (rev (rename_wrt_term rhs rargs)); |
146 (rev (rename_wrt_term rhs rargs)); |
147 val ((fnames'', fnss''), rhs') = |
147 val ((fnames'', fnss''), rhs') = |
148 (subst (map (fn ((x, y), z) => |
148 (subst (map (fn ((x, y), z) => |
158 NONE => |
158 NONE => |
159 if exists (equal fname o snd) fnames then |
159 if exists (equal fname o snd) fnames then |
160 raise RecError ("inconsistent functions for datatype " ^ quote tname) |
160 raise RecError ("inconsistent functions for datatype " ^ quote tname) |
161 else |
161 else |
162 let |
162 let |
163 val (_, _, eqns) = the (assoc (rec_eqns, fname)); |
163 val (_, _, eqns) = valOf (assoc (rec_eqns, fname)); |
164 val (fnames', fnss', fns) = foldr (trans eqns) |
164 val (fnames', fnss', fns) = Library.foldr (trans eqns) |
165 (constrs, ((i, fname)::fnames, fnss, [])) |
165 (constrs, ((i, fname)::fnames, fnss, [])) |
166 in |
166 in |
167 (fnames', (i, (fname, #1 (snd (hd eqns)), fns))::fnss') |
167 (fnames', (i, (fname, #1 (snd (hd eqns)), fns))::fnss') |
168 end |
168 end |
169 | SOME fname' => |
169 | SOME fname' => |
177 fun get_fns fns (((i, (tname, _, constrs)), rec_name), (fs, defs)) = |
177 fun get_fns fns (((i, (tname, _, constrs)), rec_name), (fs, defs)) = |
178 case assoc (fns, i) of |
178 case assoc (fns, i) of |
179 NONE => |
179 NONE => |
180 let |
180 let |
181 val dummy_fns = map (fn (_, cargs) => Const ("arbitrary", |
181 val dummy_fns = map (fn (_, cargs) => Const ("arbitrary", |
182 replicate ((length cargs) + (length (filter is_rec_type cargs))) |
182 replicate ((length cargs) + (length (List.filter is_rec_type cargs))) |
183 dummyT ---> HOLogic.unitT)) constrs; |
183 dummyT ---> HOLogic.unitT)) constrs; |
184 val _ = warning ("No function definition for datatype " ^ quote tname) |
184 val _ = warning ("No function definition for datatype " ^ quote tname) |
185 in |
185 in |
186 (dummy_fns @ fs, defs) |
186 (dummy_fns @ fs, defs) |
187 end |
187 end |
190 |
190 |
191 (* make definition *) |
191 (* make definition *) |
192 |
192 |
193 fun make_def sign fs (fname, ls, rec_name, tname) = |
193 fun make_def sign fs (fname, ls, rec_name, tname) = |
194 let |
194 let |
195 val rhs = foldr (fn (T, t) => Abs ("", T, t)) |
195 val rhs = Library.foldr (fn (T, t) => Abs ("", T, t)) |
196 ((map snd ls) @ [dummyT], |
196 ((map snd ls) @ [dummyT], |
197 list_comb (Const (rec_name, dummyT), |
197 list_comb (Const (rec_name, dummyT), |
198 fs @ map Bound (0 ::(length ls downto 1)))); |
198 fs @ map Bound (0 ::(length ls downto 1)))); |
199 val defpair = (Sign.base_name fname ^ "_" ^ Sign.base_name tname ^ "_def", |
199 val defpair = (Sign.base_name fname ^ "_" ^ Sign.base_name tname ^ "_def", |
200 Logic.mk_equals (Const (fname, dummyT), rhs)) |
200 Logic.mk_equals (Const (fname, dummyT), rhs)) |
214 |
214 |
215 fun prepare_induct ({descr, induction, ...}: datatype_info) rec_eqns = |
215 fun prepare_induct ({descr, induction, ...}: datatype_info) rec_eqns = |
216 let |
216 let |
217 fun constrs_of (_, (_, _, cs)) = |
217 fun constrs_of (_, (_, _, cs)) = |
218 map (fn (cname:string, (_, cargs, _, _, _)) => (cname, map fst cargs)) cs; |
218 map (fn (cname:string, (_, cargs, _, _, _)) => (cname, map fst cargs)) cs; |
219 val params_of = Library.assocs (flat (map constrs_of rec_eqns)); |
219 val params_of = Library.assocs (List.concat (map constrs_of rec_eqns)); |
220 in |
220 in |
221 induction |
221 induction |
222 |> RuleCases.rename_params (map params_of (flat (map (map #1 o #3 o #2) descr))) |
222 |> RuleCases.rename_params (map params_of (List.concat (map (map #1 o #3 o #2) descr))) |
223 |> RuleCases.save induction |
223 |> RuleCases.save induction |
224 end; |
224 end; |
225 |
225 |
226 fun add_primrec_i alt_name eqns_atts thy = |
226 fun add_primrec_i alt_name eqns_atts thy = |
227 let |
227 let |
228 val (eqns, atts) = split_list eqns_atts; |
228 val (eqns, atts) = split_list eqns_atts; |
229 val sg = Theory.sign_of thy; |
229 val sg = Theory.sign_of thy; |
230 val dt_info = DatatypePackage.get_datatypes thy; |
230 val dt_info = DatatypePackage.get_datatypes thy; |
231 val rec_eqns = foldr (process_eqn sg) (map snd eqns, []); |
231 val rec_eqns = Library.foldr (process_eqn sg) (map snd eqns, []); |
232 val tnames = distinct (map (#1 o snd) rec_eqns); |
232 val tnames = distinct (map (#1 o snd) rec_eqns); |
233 val dts = find_dts dt_info tnames tnames; |
233 val dts = find_dts dt_info tnames tnames; |
234 val main_fns = |
234 val main_fns = |
235 map (fn (tname, {index, ...}) => |
235 map (fn (tname, {index, ...}) => |
236 (index, |
236 (index, |
237 fst (the (find_first (fn f => #1 (snd f) = tname) rec_eqns)))) |
237 fst (valOf (find_first (fn f => #1 (snd f) = tname) rec_eqns)))) |
238 dts; |
238 dts; |
239 val {descr, rec_names, rec_rewrites, ...} = |
239 val {descr, rec_names, rec_rewrites, ...} = |
240 if null dts then |
240 if null dts then |
241 primrec_err ("datatypes " ^ commas_quote tnames ^ |
241 primrec_err ("datatypes " ^ commas_quote tnames ^ |
242 "\nare not mutually recursive") |
242 "\nare not mutually recursive") |
243 else snd (hd dts); |
243 else snd (hd dts); |
244 val (fnames, fnss) = foldr (process_fun sg descr rec_eqns) |
244 val (fnames, fnss) = Library.foldr (process_fun sg descr rec_eqns) |
245 (main_fns, ([], [])); |
245 (main_fns, ([], [])); |
246 val (fs, defs) = foldr (get_fns fnss) (descr ~~ rec_names, ([], [])); |
246 val (fs, defs) = Library.foldr (get_fns fnss) (descr ~~ rec_names, ([], [])); |
247 val defs' = map (make_def sg fs) defs; |
247 val defs' = map (make_def sg fs) defs; |
248 val names1 = map snd fnames; |
248 val names1 = map snd fnames; |
249 val names2 = map fst rec_eqns; |
249 val names2 = map fst rec_eqns; |
250 val primrec_name = |
250 val primrec_name = |
251 if alt_name = "" then (space_implode "_" (map (Sign.base_name o #1) defs)) else alt_name; |
251 if alt_name = "" then (space_implode "_" (map (Sign.base_name o #1) defs)) else alt_name; |