|
1 (* Title: HOLCF/Tools/fixrec.ML |
|
2 Author: Amber Telfer and Brian Huffman |
|
3 |
|
4 Recursive function definition package for HOLCF. |
|
5 *) |
|
6 |
|
(* Public interface of the fixrec package. *)
signature FIXREC =
sig
  (* fixrec on already-typed input; the bool is the "strict" flag *)
  val add_fixrec:
    bool ->
    (binding * typ option * mixfix) list ->
    (Attrib.binding * term) list ->
    local_theory -> local_theory

  (* fixrec on unparsed input, as used by the outer syntax command *)
  val add_fixrec_cmd:
    bool ->
    (binding * string option * mixfix) list ->
    (Attrib.binding * string) list ->
    local_theory -> local_theory

  (* fixpat on terms / on unparsed input *)
  val add_fixpat: Thm.binding * term list -> theory -> theory
  val add_fixpat_cmd: Attrib.binding * string list -> theory -> theory

  (* register (pattern constant, match function) name pairs *)
  val add_matchers: (string * string) list -> theory -> theory

  val setup: theory -> theory
end;
|
18 |
|
19 structure Fixrec :> FIXREC = |
|
20 struct |
|
21 |
|
(* theorems used below to derive unfolding and induction rules *)
val def_cont_fix_eq = @{thm def_cont_fix_eq};
val def_cont_fix_ind = @{thm def_cont_fix_ind};

(* abort with a uniformly prefixed error message *)
fun fixrec_err msg = error ("fixrec definition error:\n" ^ msg);

(* like fixrec_err, but also prints the offending equation *)
fun fixrec_eq_err thy msg eq =
  let val shown = quote (Syntax.string_of_term_global thy eq)
  in fixrec_err (msg ^ "\nin\n" ^ shown) end;
|
29 |
|
30 (*************************************************************************) |
|
31 (***************************** building types ****************************) |
|
32 (*************************************************************************) |
|
33 |
|
(* continuous function type constructor; ->> is taken from holcf_logic.ML *)
fun cfunT (argT, resT) = Type(@{type_name "->"}, [argT, resT]);

infixr 6 ->>; val (op ->>) = cfunT;

(* iterated continuous function type: T1 ->> ... ->> Tn ->> U *)
fun cfunsT (argTs, resT) = foldr cfunT resT argTs;

(* splits a continuous function type into (argument type, result type) *)
fun dest_cfunT T =
  (case T of
    Type(@{type_name "->"}, [argT, resT]) => (argT, resT)
  | _ => raise TYPE ("dest_cfunT", [T], []));
|
43 |
|
(* argument types of an iterated continuous function type *)
fun binder_cfun T =
  (case T of
    Type(@{type_name "->"}, [argT, resT]) => argT :: binder_cfun resT
  | _ => []);

(* final result type of an iterated continuous function type *)
fun body_cfun T =
  (case T of
    Type(@{type_name "->"}, [_, resT]) => body_cfun resT
  | _ => T);

(* both of the above at once *)
fun strip_cfun T : typ list * typ =
  let
    val argTs = binder_cfun T;
    val resT = body_cfun T;
  in (argTs, resT) end;
|
52 |
|
(* the maybe type over T *)
fun maybeT T = Type(@{type_name "maybe"}, [T]);

(* inverse of maybeT *)
fun dest_maybeT T =
  (case T of
    Type(@{type_name "maybe"}, [A]) => A
  | _ => raise TYPE ("dest_maybeT", [T], []));

(* right-nested product of a list of types; unit for the empty list *)
fun tupleT Ts =
  (case Ts of
    [] => HOLogic.unitT
  | [T] => T
  | T :: rest => HOLogic.mk_prodT (T, tupleT rest));

(* type of the match function for a constant of type T, with result type U *)
fun matchT (T, U) =
  let
    val argTs = binder_cfun T;
    val resT = body_cfun T;
  in resT ->> cfunsT (argTs, U) ->> U end;
|
64 |
|
65 |
|
66 (*************************************************************************) |
|
67 (***************************** building terms ****************************) |
|
68 (*************************************************************************) |
|
69 |
|
val mk_trp = HOLogic.mk_Trueprop;

(* splits a HOL equation under Trueprop into (lhs, rhs) *)
fun dest_eqs t = (HOLogic.dest_eq o HOLogic.dest_Trueprop) t;

(* head of a nested continuous application f`x1`...`xn *)
fun chead_of t =
  (case t of
    Const(@{const_name Rep_CFun}, _) $ f $ _ => chead_of f
  | _ => t);
|
78 |
|
(* the continuous application constant at types S, T *)
fun capply_const (S, T) =
  let
    val contT = S ->> T;
    val funT = S --> T;
  in Const(@{const_name Rep_CFun}, contT --> funT) end;

(* the continuous abstraction constant at types S, T *)
fun cabs_const (S, T) =
  let
    val funT = S --> T;
    val contT = S ->> T;
  in Const(@{const_name Abs_CFun}, funT --> contT) end;

(* wraps an ordinary function term into a continuous function *)
fun mk_cabs t =
  let
    val funT = Term.fastype_of t;
    val argT = Term.domain_type funT;
    val resT = Term.range_type funT;
  in cabs_const (argT, resT) $ t end;

(* continuous application t`u; t must have a continuous function type *)
fun mk_capply (t, u) =
  (case Term.fastype_of t of
    Type(@{type_name "->"}, [S, T]) => capply_const (S, T) $ t $ u
  | _ =>
      raise TERM ("mk_capply " ^ ML_Syntax.print_list ML_Syntax.print_term [t, u], [t, u]));
|
95 |
|
(* infix shorthands for meta equality, HOL equality and continuous application *)
infix 0 ==;  val (op ==) = Logic.mk_equals;
infix 1 ===; val (op ===) = HOLogic.mk_eq;
infix 9 ` ;  val (op `) = mk_capply;

(* builds the expression (LAM v. rhs) *)
fun big_lambda v rhs =
  let
    val argT = Term.fastype_of v;
    val resT = Term.fastype_of rhs;
  in cabs_const (argT, resT) $ Term.lambda v rhs end;

(* builds the expression (LAM v1 v2 .. vn. rhs), abstracting vn first *)
fun big_lambdas vs rhs = fold_rev big_lambda vs rhs;
|
107 |
|
(* Fixrec.return applied to t *)
fun mk_return t =
  let
    val T = Term.fastype_of t;
    val return_const = Const(@{const_name Fixrec.return}, T ->> maybeT T);
  in return_const ` t end;

(* Fixrec.bind applied to t and the continuation u *)
fun mk_bind (t, u) =
  let
    val (T, mU) = dest_cfunT (Term.fastype_of u);
    val bind_const = Const(@{const_name Fixrec.bind}, maybeT T ->> (T ->> mU) ->> mU);
  in bind_const ` t ` u end;

(* Fixrec.mplus applied to t and u; the type is taken from t *)
fun mk_mplus (t, u) =
  let
    val mT = Term.fastype_of t;
    val mplus_const = Const(@{const_name Fixrec.mplus}, mT ->> mT ->> mT);
  in mplus_const ` t ` u end;

(* Fixrec.run applied to t, which must have a maybe type *)
fun mk_run t =
  let
    val mT = Term.fastype_of t;
    val T = dest_maybeT mT;
    val run_const = Const(@{const_name Fixrec.run}, mT ->> T);
  in run_const ` t end;

(* fix applied to the continuous function t *)
fun mk_fix t =
  let
    val (T, _) = dest_cfunT (Term.fastype_of t);
    val fix_const = Const(@{const_name fix}, (T ->> T) ->> T);
  in fix_const ` t end;

(* the proposition "cont t" (without Trueprop) *)
fun mk_cont t =
  let
    val T = Term.fastype_of t;
    val cont_const = Const(@{const_name cont}, T --> HOLogic.boolT);
  in cont_const $ t end;
|
133 |
|
val mk_fst = HOLogic.mk_fst;
val mk_snd = HOLogic.mk_snd;

(* builds the right-nested tuple (t1,t2,..,tn); unit for the empty list *)
fun mk_tuple ts =
  (case ts of
    [] => HOLogic.unit
  | [t] => t
  | t :: rest => HOLogic.mk_prod (t, mk_tuple rest));

(* builds the expression (%(v1,v2,..,vn). rhs) *)
fun lambda_tuple vs rhs =
  (case vs of
    [] => Term.lambda (Free("unit", HOLogic.unitT)) rhs
  | [v] => Term.lambda v rhs
  | v :: rest => HOLogic.mk_split (Term.lambda v (lambda_tuple rest rhs)));
|
147 |
|
148 |
|
149 (*************************************************************************) |
|
150 (************* fixed-point definitions and unfolding theorems ************) |
|
151 (*************************************************************************) |
|
152 |
|
(* Defines all functions of one declaration as components of a single fixed
   point and derives the joint induction rule plus one unfolding rule per
   function.
     fixes: the declared functions
     spec:  one equation "f = rhs" per function, paired positionally with fixes
   Returns (context, base names, definition results, unfolding theorems). *)
fun add_fixdefs
  (fixes : ((binding * typ) * mixfix) list)
  (spec : (Attrib.binding * term) list)
  (lthy : local_theory) =
  let
    val thy = ProofContext.theory_of lthy;
    val names = map (fn ((b, _), _) => Binding.name_of b) fixes;
    val all_names = space_implode "_" names;
    val (lhss, rhss) = ListPair.unzip (map (fn (_, eq) => dest_eqs eq) spec);

    (* the functional maps the tuple of functions to the tuple of right-hand sides *)
    val functional = lambda_tuple lhss (mk_tuple rhss);
    val fixpoint = mk_fix (mk_cabs functional);

    val cont_goal = mk_trp (mk_cont functional);
    val cont_thm =
      Goal.prove lthy [] [] cont_goal (K (simp_tac (local_simpset_of lthy) 1));

    (* one definition per function: projections out of the tupled fixed point *)
    fun one_def (Free(n, _)) r =
          ((Binding.name (Long_Name.base_name n ^ "_def"), []), r)
      | one_def _ _ = fixrec_err "fixdefs: lhs not of correct form";
    fun defs [] _ = []
      | defs [l] r = [one_def l r]
      | defs (l :: ls) r = one_def l (mk_fst r) :: defs ls (mk_snd r);
    val fixdefs = defs lhss fixpoint;

    val (fixdef_thms : (term * (string * thm)) list, lthy') =
      fold_map (LocalTheory.define Thm.definitionK)
        (map (apfst fst) fixes ~~ fixdefs) lthy;

    (* recombine the individual definitions into one equation on tuples *)
    fun pair_equalI (thm1, thm2) = @{thm Pair_equalI} OF [thm1, thm2];
    val tuple_fixdef_thm =
      foldr1 pair_equalI (map (fn (_, (_, th)) => th) fixdef_thms);

    val P = Var (("P", 0), map Term.fastype_of lhss ---> HOLogic.boolT);
    val predicate = lambda_tuple lhss (list_comb (P, lhss));
    val tuple_induct_thm =
      (def_cont_fix_ind OF [tuple_fixdef_thm, cont_thm])
      |> Drule.instantiate' [] [SOME (Thm.cterm_of thy predicate)]
      |> LocalDefs.unfold lthy @{thms split_paired_all split_conv split_strict};
    val tuple_unfold_thm =
      (def_cont_fix_eq OF [tuple_fixdef_thm, cont_thm])
      |> LocalDefs.unfold lthy' @{thms split_conv};

    (* split the tupled unfolding rule into one named rule per function *)
    fun unfolds [] _ = []
      | unfolds [n] thm = [(n ^ "_unfold", thm)]
      | unfolds (n :: ns) thm =
          let
            val thmL = thm RS @{thm Pair_eqD1};
            val thmR = thm RS @{thm Pair_eqD2};
          in (n ^ "_unfold", thmL) :: unfolds ns thmR end;
    val unfold_thms = unfolds names tuple_unfold_thm;

    fun mk_note (n, thm) = ((Binding.name n, []), [thm]);
    val noted = (all_names ^ "_induct", tuple_induct_thm) :: unfold_thms;
    val (_, lthy'') =
      fold_map (LocalTheory.note Thm.generatedK o mk_note) noted lthy';
  in
    (lthy'', names, fixdef_thms, map snd unfold_thms)
  end;
|
203 |
|
204 (*************************************************************************) |
|
205 (*********** monadic notation and pattern matching compilation ***********) |
|
206 (*************************************************************************) |
|
207 |
|
(* theory data: table from pattern constant names to match function names *)
structure FixrecMatchData = TheoryDataFun (
  type T = string Symtab.table;
  val empty = Symtab.empty;
  fun copy (tab : T) = tab;
  fun extend (tab : T) = tab;
  fun merge _ (tabs : T * T) : T = Symtab.merge (fn _ => true) tabs;
);

(* associate match functions with pattern constants *)
fun add_matchers ms =
  FixrecMatchData.map (fn tab => fold Symtab.update ms tab);
|
218 |
|
(* collects, without duplicates, the base names of constants and the names of
   free and abstracted variables occurring in t *)
fun taken_names (t : term) : bstring list =
  let
    fun collect tm acc =
      (case tm of
        Const(c, _) => insert (op =) (Long_Name.base_name c) acc
      | Free(x, _) => insert (op =) x acc
      | f $ u => collect f (collect u acc)
      | Abs(x, _, body) => collect body (insert (op =) x acc)
      | _ => acc);
  in collect t [] end;
|
229 |
|
(* Builds a monadic term for matching a constructor pattern.
     pat:   the part of the pattern still to be decomposed
     rhs:   the term to continue with when the match succeeds
     vs:    argument variables collected so far
     taken: names that must not be reused
   Returns (matching term, variable standing for the matched value, taken'). *)
fun pre_build match_name pat rhs vs taken =
  (case pat of
    Const(@{const_name Rep_CFun}, _) $ f $ (v as Free _) =>
      (* a variable argument is collected as it is *)
      pre_build match_name f rhs (v :: vs) taken
  | Const(@{const_name Rep_CFun}, _) $ f $ arg =>
      (* any other argument is itself a pattern: compile it first *)
      let val (rhs', arg_var, taken') = pre_build match_name arg rhs [] taken;
      in pre_build match_name f rhs' (arg_var :: vs) taken' end
  | Const(c, T) =>
      let
        val fresh = Name.variant taken "v";
        (* drop one argument type per collected variable *)
        fun result_type (Type(@{type_name "->"}, [_, U])) (_ :: rest) = result_type U rest
          | result_type U _ = U;
        val scrutinee = Free(fresh, result_type T vs);
        val matcher = Const(match_name c, matchT (T, fastype_of rhs));
        val continuation = big_lambdas vs rhs;
      in
        (matcher ` scrutinee ` continuation, scrutinee, fresh :: taken)
      end
  | Free(n, _) => fixrec_err ("expected constructor, found free variable " ^ quote n)
  | _ => fixrec_err "pre_build: invalid pattern");
|
251 |
|
(* Builds a monadic term for matching a function definition pattern.
   Returns ((head term, arity), matcher), where arity is the number of
   arguments the head is applied to. *)
fun building match_name pat rhs vs taken =
  let
    (* reached the head of the left-hand side: abstract over all arguments *)
    fun finish () = ((pat, length vs), big_lambdas vs rhs);
  in
    (case pat of
      Const(@{const_name Rep_CFun}, _) $ f $ (v as Free _) =>
        building match_name f rhs (v :: vs) taken
    | Const(@{const_name Rep_CFun}, _) $ f $ arg =>
        let val (rhs', arg_var, taken') = pre_build match_name arg rhs [] taken;
        in building match_name f rhs' (arg_var :: vs) taken' end
    | Free _ => finish ()
    | Const _ => finish ()
    | _ => fixrec_err ("function is not declared as constant in theory: "
                      ^ ML_Syntax.print_term pat))
  end;
|
265 |
|
(* removes all outermost meta-level universal quantifiers *)
fun strip_alls t =
  if not (Logic.is_all t) then t
  else strip_alls (snd (Logic.dest_all t));

(* compiles one (possibly quantified, conditional) equation via building *)
fun match_eq match_name eq =
  let
    val concl = Logic.strip_imp_concl (strip_alls eq);
    val (lhs, rhs) = dest_eqs concl;
    val used = taken_names eq;
  in building match_name lhs (mk_return rhs) [] used end;
|
275 |
|
(* Combines the matchers ms, each of the form LAM x1 .. xn. body with
   n = arity, into LAM x1 .. xn. run (body1 +++ .. +++ bodyk).
   Note that "run" is applied to the sum here as well. *)
fun fatbar arity ms =
  let
    val too_few = "fatbar: internal error, not enough LAMs";
    (* argument types of the n outermost LAMs and the type of what is below them *)
    fun lam_types 0 t = ([], Term.fastype_of t)
      | lam_types n (_ $ Abs(_, T, body)) =
          let val (Ts, U) = lam_types (n - 1) body in (T :: Ts, U) end
      | lam_types _ _ = fixrec_err too_few;
    (* body below the n outermost LAMs *)
    fun strip_lams 0 t = t
      | strip_lams n (_ $ Abs(_, _, body)) = strip_lams (n - 1) body
      | strip_lams _ _ = fixrec_err too_few;
    (* puts LAMs back, innermost argument type first *)
    fun wrap_lams ([], _) t = t
      | wrap_lams (T :: Ts, U) t =
          wrap_lams (Ts, T ->> U) (cabs_const (T, U) $ Abs("", T, t));
    val bodies = map (strip_lams arity) ms;
    val msum = foldr1 mk_mplus bodies;
    val (Ts, U) = lam_types arity (hd ms);
  in
    wrap_lams (rev Ts, dest_maybeT U) (mk_run msum)
  end;
|
294 |
|
(* This is the pattern-matching compiler function.
   eqs: all equations of one block, i.e. the equations defining one function.
   Returns the single proposition "f = matcher" for that block.
   Fails with a fixrec error if the block is empty, if the equations have
   different heads, or if they have different arities. *)
fun compile_pats match_name eqs =
  let
    val (heads_arities, mats) = ListPair.unzip (map (match_eq match_name) eqs);
    val (heads, arities) = ListPair.unzip heads_arities;
    (* an empty block is reported explicitly instead of failing a pattern binding *)
    val cname =
      (case heads of
        [] => fixrec_err "empty block of equations"
      | n :: rest =>
          if forall (fn x => n = x) rest then n
          else fixrec_err "all equations in block must define the same function");
    val arity =
      (case arities of
        [] => fixrec_err "empty block of equations"
      | a :: rest =>
          if forall (fn x => a = x) rest then a
          else fixrec_err "all equations in block must have the same arity");
    val rhs = fatbar arity mats;
  in
    mk_trp (cname === rhs)
  end;
|
308 |
|
309 (*************************************************************************) |
|
310 (********************** Proving associated theorems **********************) |
|
311 (*************************************************************************) |
|
312 |
|
(* proves each equation of one block as a theorem: rewrite the left-hand side
   with the function's unfolding rule, then finish with the simplifier *)
fun make_simps lthy (unfold_thm, eqns : (Attrib.binding * term) list) =
  let
    val subst_rule = unfold_thm RS @{thm ssubst_lhs};
    val tac = EVERY [rtac subst_rule 1, asm_simp_tac (local_simpset_of lthy) 1];
    fun prove (bind, eqn) = (bind, Goal.prove lthy [] [] eqn (K tac));
  in
    map prove eqns
  end;
|
324 |
|
325 (*************************************************************************) |
|
326 (************************* Main fixrec function **************************) |
|
327 (*************************************************************************) |
|
328 |
|
local
(* code adapted from HOL/Tools/primrec.ML *)

(* Common implementation of add_fixrec and add_fixrec_cmd.
     set_group: accepted but not used by this function
     prep_spec: checks or reads the raw fixes and equations
     strict:    when true, the given equations are proved and noted as
                simp rules; otherwise only the definitions are made
   The equations are grouped into one block per declared function.  Blocks
   are ordered like the declared functions, because add_fixdefs pairs them
   positionally with fixes and the unfolding rules it returns are zipped
   back with the blocks.
   Fails with a fixrec error if an equation head is not a declared function
   or if a declared function has no equation. *)
fun gen_fixrec
  (set_group : bool)
  prep_spec
  (strict : bool)
  raw_fixes
  raw_spec
  (lthy : local_theory) =
  let
    val (fixes : ((binding * typ) * mixfix) list,
         spec : (Attrib.binding * term) list) =
          fst (prep_spec raw_fixes raw_spec lthy);
    val names = map (Binding.name_of o fst o fst) fixes;
    fun check_head n =
      if member (op =) names n then n
      else fixrec_err ("illegal equation head " ^ quote n ^
        ", expected one of " ^ commas_quote names);
    val chead_of_spec =
      chead_of o fst o dest_eqs o Logic.strip_imp_concl o strip_alls o snd;
    fun name_of (Free (n, _)) = check_head n
      | name_of _ = fixrec_err ("unknown term");
    val all_names = map (name_of o chead_of_spec) spec;
    fun block_of_name n =
      (case map_filter
          (fn (m, eq) => if m = n then SOME eq else NONE)
          (all_names ~~ spec) of
        [] => fixrec_err ("no equations given for function " ^ quote n)
      | block => block);
    val blocks = map block_of_name names;

    val matcher_tab = FixrecMatchData.get (ProofContext.theory_of lthy);
    fun match_name c =
      case Symtab.lookup matcher_tab c of SOME m => m
        | NONE => fixrec_err ("unknown pattern constructor: " ^ c);

    val matches = map (compile_pats match_name) (map (map snd) blocks);
    val spec' = map (pair Attrib.empty_binding) matches;
    val (lthy', _, _, unfold_thms) =
      add_fixdefs fixes spec' lthy;
  in
    if strict then let (* only prove simp rules if strict = true *)
      val simps : (Attrib.binding * thm) list list =
        map (make_simps lthy') (unfold_thms ~~ blocks);
      fun mk_bind n : Attrib.binding =
       (Binding.name (n ^ "_simps"),
         [Attrib.internal (K Simplifier.simp_add)]);
      (* one collective "<name>_simps" fact per function ... *)
      val simps1 : (Attrib.binding * thm list) list =
        map (fn (n,xs) => (mk_bind n, map snd xs)) (names ~~ simps);
      (* ... plus each equation under the binding the user gave it *)
      val simps2 : (Attrib.binding * thm list) list =
        map (apsnd (fn thm => [thm])) (List.concat simps);
      val (_, lthy'') = lthy'
        |> fold_map (LocalTheory.note Thm.generatedK) (simps1 @ simps2);
    in
      lthy''
    end
    else lthy'
  end;

in

val add_fixrec = gen_fixrec false Specification.check_spec;
val add_fixrec_cmd = gen_fixrec true Specification.read_spec;

end; (* local *)
|
389 |
|
390 (*************************************************************************) |
|
391 (******************************** Fixpat *********************************) |
|
392 (*************************************************************************) |
|
393 |
|
(* proves "t = ?x" for a pattern t headed by a constant c, by substituting
   with the stored theorem "c_unfold" and then simplifying *)
fun fix_pat thy t =
  let
    val T = fastype_of t;
    val goal = mk_trp (HOLogic.eq_const T $ t $ Var (("x", 0), T));
    val cname =
      (case chead_of t of
        Const(c, _) => c
      | _ => fixrec_err "function is not declared as constant in theory");
    val unfold_thm = PureThy.get_thm thy (cname ^ "_unfold");
  in
    Goal.prove_global thy [] [] goal
      (fn _ => EVERY [stac unfold_thm 1, simp_tac (simpset_of thy) 1])
  end;
|
404 |
|
(* prepares attributes and patterns, proves one rewrite rule per pattern with
   fix_pat, and stores all of them under the given name *)
fun gen_add_fixpat prep_term prep_attrib ((name, srcs), strings) thy =
  let
    val atts = map (prep_attrib thy) srcs;
    val pats = map (prep_term thy) strings;
    val rules = map (fix_pat thy) pats;
    val (_, thy') = PureThy.add_thmss [((name, rules), atts)] thy;
  in thy' end;

(* variant for certified terms and ready-made attributes *)
val add_fixpat = gen_add_fixpat Sign.cert_term (K I);
(* variant for unparsed terms and attribute sources *)
val add_fixpat_cmd = gen_add_fixpat Syntax.read_term_global Attrib.attribute;
|
416 |
|
417 |
|
418 (*************************************************************************) |
|
419 (******************************** Parsers ********************************) |
|
420 (*************************************************************************) |
|
421 |
|
local structure P = OuterParse and K = OuterKeyword in

(* "fixrec": optional keyword "permissive" (which turns strict off),
   followed by the fixes and the equations *)
val _ =
  let
    val strict_flag = P.opt_keyword "permissive" >> not;
    val parser =
      strict_flag -- P.fixes -- SpecParse.where_alt_specs
        >> (fn ((strict, fixes), specs) => add_fixrec_cmd strict fixes specs);
  in
    OuterSyntax.local_theory "fixrec" "define recursive functions (HOLCF)" K.thy_decl
      parser
  end;

(* "fixpat": a named list of patterns *)
val _ =
  let
    val parser = SpecParse.specs >> (Toplevel.theory o add_fixpat_cmd);
  in
    OuterSyntax.command "fixpat" "define rewrites for fixrec functions" K.thy_decl
      parser
  end;

end;

(* initializes the matcher table *)
val setup = FixrecMatchData.init;
|
434 |
|
435 end; |