|
1 (* Title: ZF/Tools/primrec_package.ML |
|
2 ID: $Id$ |
|
3 Author: Stefan Berghofer and Norbert Voelker |
|
4 Copyright 1998 TU Muenchen |
|
5 ZF version by Lawrence C Paulson (Cambridge) |
|
6 |
|
7 Package for defining functions on datatypes by primitive recursion |
|
8 *) |
|
9 |
|
signature PRIMREC_PACKAGE =
sig
  (*Equations are supplied as terms, each paired with a theorem name
    ("_" when the equation is to stay unnamed).  Returns the extended
    theory together with the proved recursion equations.*)
  val add_primrec_i:
    (string * term) list -> theory -> theory * thm list

  (*Same as add_primrec_i, but every equation is given as a string that is
    first read as a proposition in the signature of the theory.*)
  val add_primrec:
    (string * string) list -> theory -> theory * thm list
end;
|
15 |
|
16 structure PrimrecPackage : PRIMREC_PACKAGE = |
|
17 struct |
|
18 |
|
(*Internal failure while analysing a recursion equation; the string is the
  reason.  process_eqn handles it and reports it via primrec_eq_err.*)
exception RecError of string;
|
20 |
|
21 (* FIXME: move? *) |
|
22 |
|
(*Splits a proposition of the form  Trueprop (lhs = rhs)  into (lhs, rhs);
  any other term raises TERM carrying the offending term.*)
fun dest_eq t =
  (case t of
    Const ("Trueprop", _) $ (Const ("op =", _) $ lhs $ rhs) => (lhs, rhs)
  | _ => raise TERM ("dest_eq", [t]));
|
25 |
|
(*Aborts with the given message, prefixed by a standard primrec header.*)
fun primrec_err msg =
  let val header = "Primrec definition error:\n"
  in error (header ^ msg) end;
|
27 |
|
(*Like primrec_err, but also shows the equation that caused the problem,
  printed with respect to signature sign.*)
fun primrec_eq_err sign s eq =
  let val shown_eq = Sign.string_of_term sign eq
  in primrec_err (s ^ "\nin equation\n" ^ shown_eq) end;
|
30 |
|
31 (* preprocessing of equations *) |
|
32 |
|
(*Validates one recursion equation and merges it into rec_fn_opt, which
  records the equations already noted for this function (None before the
  first equation).  The result has the form
      Some (fname, ftype, ls, rs, con_info, eqns)
  where ls / rs are the free variables to the left / right of the constructor
  pattern and every element of eqns is
      (constructor name, (rhs, constructor argument frees, original equation)).
  Any RecError raised by the checks below is reported through primrec_eq_err
  together with the offending equation.*)
fun process_eqn thy (eq, rec_fn_opt) =
  let
    fun fail msg = raise RecError msg;
    fun require ok msg = if ok then () else fail msg;

    (*schematic variables are rejected before the equation is taken apart*)
    val (lhs, rhs) =
      if not (null (term_vars eq)) then fail "illegal schematic variable(s)"
      else (dest_eq eq handle _ => fail "not a proper equation");

    val (head, args) = strip_comb lhs;
    val (fname, ftype) =
      dest_Const head
        handle _ => fail "function is not declared as constant in theory";

    (*lhs arguments: leading frees, then the pattern, then trailing frees*)
    val (ls_frees, rest) = take_prefix is_Free args;
    val (middle, rs_frees) = take_suffix is_Free rest;

    val (constr, cargs_frees) =
      (case middle of
        [] => fail "constructor missing"
      | pattern :: _ => strip_comb pattern);
    val (cname, _) =
      dest_Const constr handle _ => fail "ill-formed constructor";
    val con_info =
      the (Symtab.lookup (ConstructorsData.get thy, cname))
        handle _ => fail "cannot determine datatype associated with function";

    val (ls, cargs, rs) =
      (map dest_Free ls_frees,
       map dest_Free cargs_frees,
       map dest_Free rs_frees)
        handle _ => fail "illegal argument in pattern";
    val pattern_frees = ls @ rs @ cargs;

    (*constructor name paired with: rhs of the equation, pattern variables,
      full original equation*)
    val new_eqn = (cname, (rhs, cargs, eq));
  in
    require (null (duplicates pattern_frees))
      "repeated variable name in pattern";
    require (map dest_Free (term_frees rhs) subset pattern_frees)
      "extra variables on rhs";
    require (length middle <= 1)
      "more than one non-variable in pattern";
    (case rec_fn_opt of
      None => Some (fname, ftype, ls, rs, con_info, [new_eqn])
    | Some (fname', _, ls', rs', con_info': constructor_info, eqns) =>
        (require (not (is_some (assoc (eqns, cname))))
           "constructor already occurred as pattern";
         require (ls = ls' andalso rs = rs')
           "non-recursive arguments are inconsistent";
         require (#big_rec_name con_info = #big_rec_name con_info')
           ("Mixed datatypes for function " ^ fname);
         require (fname = fname')
           ("inconsistent functions for datatype " ^ #big_rec_name con_info);
         Some (fname, ftype, ls, rs, con_info, new_eqn :: eqns)))
  end
  handle RecError s => primrec_eq_err (sign_of thy) s eq;
|
88 |
|
89 |
|
(*Instantiates the right-hand side of a recursor equation: the arguments of
  the constructor on its left-hand side are replaced, position by position,
  by the free variables cargs' used in the user's pattern.*)
fun inst_recursor ((_ $ constr, rhs), cargs') =
  let
    val (_, constr_args) = strip_comb constr;
    val substitution = constr_args ~~ map Free cargs';
  in
    subst_atomic substitution rhs
  end;
|
93 |
|
94 |
|
(*Convert a list of recursion equations into a recursor call.
  Takes the information collected by process_eqn and returns the pair
  (definition name, definition term), where the term is the meta-equality
      fname (ls, rec_arg, rs) == recursor (case_1, ..., case_n, rec_arg).
  The definition is printed only when Ind_Syntax.trace is set, in line with
  the other diagnostic output of this function.*)
fun process_fun thy (fname, ftype, ls, rs, con_info: constructor_info, eqns) =
  let
    val fconst = Const (fname, ftype)
    (*the function applied to its arguments, with Bound 0 standing for the
      recursive argument*)
    val fabs = list_comb (fconst, map Free ls @ [Bound 0] @ map Free rs)
    and {big_rec_name, constructors, rec_rewrites, ...} = con_info

    (*Replace X_rec(args,t) by fname(ls,t,rs) *)
    fun use_fabs (_ $ t) = subst_bound (t, fabs)
      | use_fabs t = t

    val cnames = map (#1 o dest_Const) constructors
    and recursor_pairs = map (dest_eq o concl_of) rec_rewrites

    (*abstract body over a free variable, or over an arbitrary term*)
    fun absterm (Free (a, T), body) = absfree (a, T, body)
      | absterm (t, body) = Abs ("rec", iT, abstract_over (t, body))

    (*Translate rec equations into function arguments suitable for recursor.
      Missing cases are replaced by 0 and all cases are put into order.*)
    fun add_case ((cname, recursor_pair), cases) =
      let
        val (rhs, recursor_rhs, eq) =
          case assoc (eqns, cname) of
            None =>
              (warning ("no equation for constructor " ^ cname ^
                        "\nin definition of function " ^ fname);
               (Const ("0", iT), #2 recursor_pair, Const ("0", iT)))
          | Some (rhs, cargs', eq) =>
              (rhs, inst_recursor (recursor_pair, cargs'), eq)
        val allowed_terms = map use_fabs (#2 (strip_comb recursor_rhs))
        val abs = foldr absterm (allowed_terms, rhs)
      in
        if !Ind_Syntax.trace then
          writeln ("recursor_rhs = " ^
                   Sign.string_of_term (sign_of thy) recursor_rhs ^
                   "\nabs = " ^ Sign.string_of_term (sign_of thy) abs)
        else ();
        (*an occurrence of the function that survived abstraction is not one
          of the allowed recursive calls*)
        if Logic.occs (fconst, abs) then
          primrec_eq_err (sign_of thy)
            ("illegal recursive occurrences of " ^ fname)
            eq
        else abs :: cases
      end

    val recursor = head_of (#1 (hd recursor_pairs))

    (** make definition **)

    (*the recursive argument; its name is chosen to differ from ls and rs*)
    val rec_arg =
      Free (variant (map #1 (ls @ rs)) (Sign.base_name big_rec_name), iT)

    val def_tm = Logic.mk_equals
      (subst_bound (rec_arg, fabs),
       list_comb (recursor,
                  foldr add_case (cnames ~~ recursor_pairs, []))
       $ rec_arg)

  in
    (*diagnostic only: guarded by the trace flag like the output in add_case*)
    if !Ind_Syntax.trace then
      writeln ("def = " ^ Sign.string_of_term (sign_of thy) def_tm)
    else ();
    (Sign.base_name fname ^ "_" ^ Sign.base_name big_rec_name ^ "_def",
     def_tm)
  end;
|
156 |
|
157 |
|
158 |
|
159 (* prepare functions needed for definitions *) |
|
160 |
|
(*Defines a function by primitive recursion and proves its recursion
  equations.
  Each equation is paired with an optional name, which is "_" (ML wildcard)
  if omitted.
  The equations are analysed by process_eqn, turned into a definition by
  process_fun, and then each one is proved by rewriting with that definition
  and the recursor rewrite rules.  The theorems are stored collectively as
  "simps" (added to the global simpset) and individually under the supplied
  names.  Returns the extended theory and the list of proved equations.
  An empty equation list is reported as a primrec definition error instead
  of failing with an unhandled Bind exception.*)
fun add_primrec_i recursion_eqns thy =
  let
    val (fname, ftype, ls, rs, con_info, eqns) =
      (case foldr (process_eqn thy) (map snd recursion_eqns, None) of
        Some fun_info => fun_info
      | None => primrec_err "no recursion equations given");
    val def = process_fun thy (fname, ftype, ls, rs, con_info, eqns)
    (*theorems are stored under a path named after the definition*)
    val thy' = thy |> Theory.add_path (Sign.base_name (#1 def))
                   |> Theory.add_defs_i [def]
    val rewrites = get_axiom thy' (#1 def) ::
                   map mk_meta_eq (#rec_rewrites con_info)
    val _ = writeln ("Proving equations for primrec function " ^ fname);
    val char_thms =
      map (fn (_, t) =>
            prove_goalw_cterm rewrites
              (Ind_Syntax.traceIt "next primrec equation = "
                 (cterm_of (sign_of thy') t))
              (fn _ => [rtac refl 1]))
          recursion_eqns;
    val tsimps = Attribute.tthms_of char_thms;
    val thy'' = thy'
      |> PureThy.add_tthmss [(("simps", tsimps), [Simplifier.simp_add_global])]
      (*equations named "_" are not stored individually*)
      |> PureThy.add_tthms (map (rpair [])
           (filter_out (equal "_" o fst) (map fst recursion_eqns ~~ tsimps)))
      |> Theory.parent_path;
  in
    (thy'', char_thms)
  end;
|
189 |
|
(*String-based front end: every equation string is read as a proposition in
  the signature of thy, then the work is handed over to add_primrec_i.*)
fun add_primrec eqns thy =
  let
    val read_prop = readtm (sign_of thy) propT;
    val parsed_eqns = map (apsnd read_prop) eqns;
  in
    add_primrec_i parsed_eqns thy
  end;
|
192 |
|
193 end; |