5 *) |
5 *) |
6 |
6 |
7 signature PREDICATE_COMPILE = |
7 signature PREDICATE_COMPILE = |
8 sig |
8 sig |
9 type mode = int list option list * int list |
9 type mode = int list option list * int list |
10 val prove_equation: string -> mode option -> theory -> theory |
10 val add_equations_of: string list -> theory -> theory |
11 val intro_rule: theory -> string -> mode -> thm |
11 val register_predicate : (thm list * thm * int) -> theory -> theory |
12 val elim_rule: theory -> string -> mode -> thm |
12 val predfun_intro_of: theory -> string -> mode -> thm |
13 val strip_intro_concl: term -> int -> term * (term list * term list) |
13 val predfun_elim_of: theory -> string -> mode -> thm |
14 val modename_of: theory -> string -> mode -> string |
14 val strip_intro_concl: int -> term -> term * (term list * term list) |
|
15 val predfun_name_of: theory -> string -> mode -> string |
|
16 val all_preds_of : theory -> string list |
15 val modes_of: theory -> string -> mode list |
17 val modes_of: theory -> string -> mode list |
16 val pred_intros: theory -> string -> thm list |
18 val intros_of: theory -> string -> thm list |
17 val get_nparams: theory -> string -> int |
19 val nparams_of: theory -> string -> int |
18 val setup: theory -> theory |
20 val setup: theory -> theory |
19 val code_pred: string -> Proof.context -> Proof.state |
21 val code_pred: string -> Proof.context -> Proof.state |
20 val code_pred_cmd: string -> Proof.context -> Proof.state |
22 val code_pred_cmd: string -> Proof.context -> Proof.state |
21 val print_alternative_rules: theory -> theory (*FIXME diagnostic command?*) |
23 (* val print_alternative_rules: theory -> theory (*FIXME diagnostic command?*) *) |
22 val do_proofs: bool ref |
24 val do_proofs: bool ref |
23 val analyze_compr: theory -> term -> term |
25 val analyze_compr: theory -> term -> term |
24 val eval_ref: (unit -> term Predicate.pred) option ref |
26 val eval_ref: (unit -> term Predicate.pred) option ref |
|
27 (* val extend : (key -> 'a * key list) -> key list -> 'a Graph.T -> 'a Graph.T *) |
|
28 val add_equations : string -> theory -> theory |
25 end; |
29 end; |
26 |
30 |
27 structure Predicate_Compile : PREDICATE_COMPILE = |
31 structure Predicate_Compile : PREDICATE_COMPILE = |
28 struct |
32 struct |
29 |
33 |
30 (** auxiliary **) |
34 (** auxiliary **) |
31 |
35 |
32 (* debug stuff *) |
36 (* debug stuff *) |
33 |
37 (* |
34 fun makestring _ = "?"; (* FIXME dummy *) |
38 fun makestring _ = "?"; |
|
39 *) (* FIXME dummy *) |
35 |
40 |
36 fun tracing s = (if ! Toplevel.debug then Output.tracing s else ()); |
41 fun tracing s = (if ! Toplevel.debug then Output.tracing s else ()); |
37 |
42 |
38 fun print_tac s = (if ! Toplevel.debug then Tactical.print_tac s else Seq.single); |
43 fun print_tac s = (if ! Toplevel.debug then Tactical.print_tac s else Seq.single); |
39 fun debug_tac msg = (fn st => (tracing msg; Seq.single st)); |
44 fun debug_tac msg = (fn st => (tracing msg; Seq.single st)); |
40 |
45 |
41 val do_proofs = ref true; |
46 val do_proofs = ref true; |
42 |
47 |
|
48 fun mycheat_tac thy i st = |
|
49 (Tactic.rtac (SkipProof.make_thm thy (Var (("A", 0), propT))) i) st |
|
50 |
|
51 (* reference to preprocessing of InductiveSet package *) |
|
52 |
|
53 val ind_set_codegen_preproc = InductiveSetPackage.codegen_preproc; |
43 |
54 |
44 (** fundamentals **) |
55 (** fundamentals **) |
45 |
56 |
46 (* syntactic operations *) |
57 (* syntactic operations *) |
47 |
58 |
96 HOLogic.boolT --> mk_pred_enumT HOLogic.unitT) $ cond; |
111 HOLogic.boolT --> mk_pred_enumT HOLogic.unitT) $ cond; |
97 |
112 |
98 fun mk_not_pred t = let val T = mk_pred_enumT HOLogic.unitT |
113 fun mk_not_pred t = let val T = mk_pred_enumT HOLogic.unitT |
99 in Const (@{const_name Predicate.not_pred}, T --> T) $ t end |
114 in Const (@{const_name Predicate.not_pred}, T --> T) $ t end |
100 |
115 |
|
116 (* destruction of intro rules *) |
|
117 |
|
118 (* FIXME: look for other place where this functionality was used before *) |
|
119 fun strip_intro_concl nparams intro = let |
|
120 val _ $ u = Logic.strip_imp_concl intro |
|
121 val (pred, all_args) = strip_comb u |
|
122 val (params, args) = chop nparams all_args |
|
123 in (pred, (params, args)) end |
101 |
124 |
102 (* data structures *) |
125 (* data structures *) |
103 |
126 |
104 type mode = int list option list * int list; |
127 type mode = int list option list * int list; (*pmode FIMXE*) |
105 |
128 |
106 val mode_ord = prod_ord (list_ord (option_ord (list_ord int_ord))) (list_ord int_ord); |
129 fun string_of_mode (iss, is) = space_implode " -> " (map |
107 |
130 (fn NONE => "X" |
108 structure PredModetab = TableFun( |
131 | SOME js => enclose "[" "]" (commas (map string_of_int js))) |
109 type key = string * mode |
132 (iss @ [SOME is])); |
110 val ord = prod_ord fast_string_ord mode_ord |
133 |
111 ); |
134 fun print_modes modes = Output.tracing ("Inferred modes:\n" ^ |
112 |
135 cat_lines (map (fn (s, ms) => s ^ ": " ^ commas (map |
113 |
136 string_of_mode ms)) modes)); |
114 (*FIXME scrap boilerplate*) |
137 |
115 |
138 datatype predfun_data = PredfunData of { |
116 structure IndCodegenData = TheoryDataFun |
139 name : string, |
|
140 definition : thm, |
|
141 intro : thm, |
|
142 elim : thm |
|
143 }; |
|
144 |
|
145 fun rep_predfun_data (PredfunData data) = data; |
|
146 fun mk_predfun_data (name, definition, intro, elim) = |
|
147 PredfunData {name = name, definition = definition, intro = intro, elim = elim} |
|
148 |
|
149 datatype pred_data = PredData of { |
|
150 intros : thm list, |
|
151 elim : thm option, |
|
152 nparams : int, |
|
153 functions : (mode * predfun_data) list |
|
154 }; |
|
155 |
|
156 fun rep_pred_data (PredData data) = data; |
|
157 fun mk_pred_data ((intros, elim, nparams), functions) = |
|
158 PredData {intros = intros, elim = elim, nparams = nparams, functions = functions} |
|
159 fun map_pred_data f (PredData {intros, elim, nparams, functions}) = |
|
160 mk_pred_data (f ((intros, elim, nparams), functions)) |
|
161 |
|
162 fun eq_option eq (NONE, NONE) = true |
|
163 | eq_option eq (SOME x, SOME y) = eq (x, y) |
|
164 | eq_option eq _ = false |
|
165 |
|
166 fun eq_pred_data (PredData d1, PredData d2) = |
|
167 eq_list (Thm.eq_thm) (#intros d1, #intros d2) andalso |
|
168 eq_option (Thm.eq_thm) (#elim d1, #elim d2) andalso |
|
169 #nparams d1 = #nparams d2 |
|
170 |
|
171 structure PredData = TheoryDataFun |
117 ( |
172 ( |
118 type T = {names : string PredModetab.table, |
173 type T = pred_data Graph.T; |
119 modes : mode list Symtab.table, |
174 val empty = Graph.empty; |
120 function_defs : Thm.thm Symtab.table, |
|
121 function_intros : Thm.thm Symtab.table, |
|
122 function_elims : Thm.thm Symtab.table, |
|
123 intro_rules : Thm.thm list Symtab.table, |
|
124 elim_rules : Thm.thm Symtab.table, |
|
125 nparams : int Symtab.table |
|
126 }; (*FIXME: better group tables according to key*) |
|
127 (* names: map from inductive predicate and mode to function name (string). |
|
128 modes: map from inductive predicates to modes |
|
129 function_defs: map from function name to definition |
|
130 function_intros: map from function name to intro rule |
|
131 function_elims: map from function name to elim rule |
|
132 intro_rules: map from inductive predicate to alternative intro rules |
|
133 elim_rules: map from inductive predicate to alternative elimination rule |
|
134 nparams: map from const name to number of parameters (* assuming there exist intro and elimination rules *) |
|
135 *) |
|
136 val empty = {names = PredModetab.empty, |
|
137 modes = Symtab.empty, |
|
138 function_defs = Symtab.empty, |
|
139 function_intros = Symtab.empty, |
|
140 function_elims = Symtab.empty, |
|
141 intro_rules = Symtab.empty, |
|
142 elim_rules = Symtab.empty, |
|
143 nparams = Symtab.empty}; |
|
144 val copy = I; |
175 val copy = I; |
145 val extend = I; |
176 val extend = I; |
146 fun merge _ (r : T * T) = {names = PredModetab.merge (op =) (pairself #names r), |
177 fun merge _ = Graph.merge eq_pred_data; |
147 modes = Symtab.merge (op =) (pairself #modes r), |
|
148 function_defs = Symtab.merge Thm.eq_thm (pairself #function_defs r), |
|
149 function_intros = Symtab.merge Thm.eq_thm (pairself #function_intros r), |
|
150 function_elims = Symtab.merge Thm.eq_thm (pairself #function_elims r), |
|
151 intro_rules = Symtab.merge ((forall Thm.eq_thm) o (op ~~)) (pairself #intro_rules r), |
|
152 elim_rules = Symtab.merge Thm.eq_thm (pairself #elim_rules r), |
|
153 nparams = Symtab.merge (op =) (pairself #nparams r)}; |
|
154 ); |
178 ); |
155 |
179 |
156 fun map_names f thy = IndCodegenData.map |
180 (* queries *) |
157 (fn x => {names = f (#names x), modes = #modes x, function_defs = #function_defs x, |
181 |
158 function_intros = #function_intros x, function_elims = #function_elims x, |
182 val lookup_pred_data = try rep_pred_data oo Graph.get_node o PredData.get; |
159 intro_rules = #intro_rules x, elim_rules = #elim_rules x, |
183 |
160 nparams = #nparams x}) thy |
184 fun the_pred_data thy name = case lookup_pred_data thy name |
161 |
185 of NONE => error ("No such predicate " ^ quote name) |
162 fun map_modes f thy = IndCodegenData.map |
186 | SOME data => data; |
163 (fn x => {names = #names x, modes = f (#modes x), function_defs = #function_defs x, |
187 |
164 function_intros = #function_intros x, function_elims = #function_elims x, |
188 val is_pred = is_some oo lookup_pred_data |
165 intro_rules = #intro_rules x, elim_rules = #elim_rules x, |
189 |
166 nparams = #nparams x}) thy |
190 val all_preds_of = Graph.keys o PredData.get |
167 |
191 |
168 fun map_function_defs f thy = IndCodegenData.map |
192 val intros_of = #intros oo the_pred_data |
169 (fn x => {names = #names x, modes = #modes x, function_defs = f (#function_defs x), |
193 |
170 function_intros = #function_intros x, function_elims = #function_elims x, |
194 fun the_elim_of thy name = case #elim (the_pred_data thy name) |
171 intro_rules = #intro_rules x, elim_rules = #elim_rules x, |
195 of NONE => error ("No elimination rule for predicate " ^ quote name) |
172 nparams = #nparams x}) thy |
196 | SOME thm => thm |
173 |
197 |
174 fun map_function_elims f thy = IndCodegenData.map |
198 val has_elim = is_some o #elim oo the_pred_data; |
175 (fn x => {names = #names x, modes = #modes x, function_defs = #function_defs x, |
199 |
176 function_intros = #function_intros x, function_elims = f (#function_elims x), |
200 val nparams_of = #nparams oo the_pred_data |
177 intro_rules = #intro_rules x, elim_rules = #elim_rules x, |
201 |
178 nparams = #nparams x}) thy |
202 val modes_of = (map fst) o #functions oo the_pred_data |
179 |
203 |
180 fun map_function_intros f thy = IndCodegenData.map |
204 fun all_modes_of thy = map (fn name => (name, modes_of thy name)) (all_preds_of thy) |
181 (fn x => {names = #names x, modes = #modes x, function_defs = #function_defs x, |
205 |
182 function_intros = f (#function_intros x), function_elims = #function_elims x, |
206 val is_compiled = not o null o #functions oo the_pred_data |
183 intro_rules = #intro_rules x, elim_rules = #elim_rules x, |
207 |
184 nparams = #nparams x}) thy |
208 fun lookup_predfun_data thy name mode = |
185 |
209 Option.map rep_predfun_data (AList.lookup (op =) |
186 fun map_intro_rules f thy = IndCodegenData.map |
210 (#functions (the_pred_data thy name)) mode) |
187 (fn x => {names = #names x, modes = #modes x, function_defs = #function_defs x, |
211 |
188 function_intros = #function_intros x, function_elims = #function_elims x, |
212 fun the_predfun_data thy name mode = case lookup_predfun_data thy name mode |
189 intro_rules = f (#intro_rules x), elim_rules = #elim_rules x, |
213 of NONE => error ("No such mode" ^ string_of_mode mode) |
190 nparams = #nparams x}) thy |
214 | SOME data => data; |
|
215 |
|
216 val predfun_name_of = #name ooo the_predfun_data |
|
217 |
|
218 val predfun_definition_of = #definition ooo the_predfun_data |
|
219 |
|
220 val predfun_intro_of = #intro ooo the_predfun_data |
|
221 |
|
222 val predfun_elim_of = #elim ooo the_predfun_data |
|
223 |
|
224 |
|
225 (* replaces print_alternative_rules *) |
|
226 (* TODO: |
|
227 fun print_alternative_rules thy = let |
|
228 val d = IndCodegenData.get thy |
|
229 val preds = (Symtab.keys (#intro_rules d)) union (Symtab.keys (#elim_rules d)) |
|
230 val _ = tracing ("preds: " ^ (makestring preds)) |
|
231 fun print pred = let |
|
232 val _ = tracing ("predicate: " ^ pred) |
|
233 val _ = tracing ("introrules: ") |
|
234 val _ = fold (fn thm => fn u => tracing (makestring thm)) |
|
235 (rev (Symtab.lookup_list (#intro_rules d) pred)) () |
|
236 val _ = tracing ("casesrule: ") |
|
237 val _ = tracing (makestring (Symtab.lookup (#elim_rules d) pred)) |
|
238 in () end |
|
239 val _ = map print preds |
|
240 in thy end; |
|
241 *) |
|
242 |
|
243 (* updaters *) |
|
244 |
|
245 fun add_predfun name mode data = let |
|
246 val add = apsnd (cons (mode, mk_predfun_data data)) |
|
247 in PredData.map (Graph.map_node name (map_pred_data add)) end |
|
248 |
|
249 fun add_intro thm = let |
|
250 val (name, _) = dest_Const (fst (strip_intro_concl 0 (prop_of thm))) |
|
251 fun set (intros, elim, nparams) = (thm::intros, elim, nparams) |
|
252 in PredData.map (Graph.map_node name (map_pred_data (apfst set))) end |
|
253 |
|
254 fun set_elim thm = let |
|
255 val (name, _) = dest_Const (fst |
|
256 (strip_comb (HOLogic.dest_Trueprop (hd (prems_of thm))))) |
|
257 fun set (intros, _, nparams) = (intros, SOME thm, nparams) |
|
258 in PredData.map (Graph.map_node name (map_pred_data (apfst set))) end |
|
259 |
|
260 fun set_nparams name nparams = let |
|
261 fun set (intros, elim, _ ) = (intros, elim, nparams) |
|
262 in PredData.map (Graph.map_node name (map_pred_data (apfst set))) end |
|
263 |
|
264 fun register_predicate (intros, elim, nparams) = let |
|
265 val (name, _) = dest_Const (fst (strip_intro_concl nparams (prop_of (hd intros)))) |
|
266 fun set _ = (intros, SOME elim, nparams) |
|
267 in PredData.map (Graph.new_node (name, mk_pred_data ((intros, SOME elim, nparams), []))) end |
|
268 |
191 |
269 |
192 fun map_elim_rules f thy = IndCodegenData.map |
270 (* Mode analysis *) |
193 (fn x => {names = #names x, modes = #modes x, function_defs = #function_defs x, |
271 |
194 function_intros = #function_intros x, function_elims = #function_elims x, |
272 (*** check if a term contains only constructor functions ***) |
195 intro_rules = #intro_rules x, elim_rules = f (#elim_rules x), |
|
196 nparams = #nparams x}) thy |
|
197 |
|
198 fun map_nparams f thy = IndCodegenData.map |
|
199 (fn x => {names = #names x, modes = #modes x, function_defs = #function_defs x, |
|
200 function_intros = #function_intros x, function_elims = #function_elims x, |
|
201 intro_rules = #intro_rules x, elim_rules = #elim_rules x, |
|
202 nparams = f (#nparams x)}) thy |
|
203 |
|
204 (* removes first subgoal *) |
|
205 fun mycheat_tac thy i st = |
|
206 (Tactic.rtac (SkipProof.make_thm thy (Var (("A", 0), propT))) i) st |
|
207 |
|
208 (* Lightweight mode analysis **********************************************) |
|
209 |
|
210 (**************************************************************************) |
|
211 (* source code from old code generator ************************************) |
|
212 |
|
213 (**** check if a term contains only constructor functions ****) |
|
214 |
|
215 fun is_constrt thy = |
273 fun is_constrt thy = |
216 let |
274 let |
217 val cnstrs = flat (maps |
275 val cnstrs = flat (maps |
218 (map (fn (_, (Tname, _, cs)) => map (apsnd (rpair Tname o length)) cs) o #descr o snd) |
276 (map (fn (_, (Tname, _, cs)) => map (apsnd (rpair Tname o length)) cs) o #descr o snd) |
219 (Symtab.dest (DatatypePackage.get_datatypes thy))); |
277 (Symtab.dest (DatatypePackage.get_datatypes thy))); |
263 |
309 |
264 fun subsets i j = if i <= j then |
310 fun subsets i j = if i <= j then |
265 let val is = subsets (i+1) j |
311 let val is = subsets (i+1) j |
266 in merge (map (fn ks => i::ks) is) is end |
312 in merge (map (fn ks => i::ks) is) is end |
267 else [[]]; |
313 else [[]]; |
268 |
314 |
|
315 (* FIXME: should be in library - map_prod *) |
269 fun cprod ([], ys) = [] |
316 fun cprod ([], ys) = [] |
270 | cprod (x :: xs, ys) = map (pair x) ys @ cprod (xs, ys); |
317 | cprod (x :: xs, ys) = map (pair x) ys @ cprod (xs, ys); |
271 |
318 |
272 fun cprods xss = foldr (map op :: o cprod) [[]] xss; |
319 fun cprods xss = foldr (map op :: o cprod) [[]] xss; |
273 |
320 |
274 datatype hmode = Mode of mode * int list * hmode option list; (*FIXME don't understand |
321 datatype hmode = Mode of mode * int list * hmode option list; (*FIXME don't understand |
275 why there is another mode type!?*) |
322 why there is another mode type tmode !?*) |
276 |
323 |
277 fun modes_of_term modes t = |
324 fun string_of_term (t : term) = makestring t |
|
325 fun string_of_terms (ts : term list) = commas (map string_of_term ts) |
|
326 |
|
327 (*TODO: cleanup function and put together with modes_of_term *) |
|
328 fun modes_of_param default modes t = let |
|
329 val (vs, t') = strip_abs t |
|
330 val b = length vs |
|
331 fun mk_modes name args = Option.map (maps (fn (m as (iss, is)) => |
|
332 let |
|
333 val (args1, args2) = |
|
334 if length args < length iss then |
|
335 error ("Too few arguments for inductive predicate " ^ name) |
|
336 else chop (length iss) args; |
|
337 val k = length args2; |
|
338 val _ = Output.tracing ("args2:" ^ string_of_terms args2) |
|
339 val perm = map (fn i => (find_index_eq (Bound (b - i)) args2) + 1) |
|
340 (1 upto b) |
|
341 val _ = Output.tracing ("perm: " ^ (makestring perm)) |
|
342 val partial_mode = (1 upto k) \\ perm |
|
343 in |
|
344 if not (partial_mode subset is) then [] else |
|
345 let |
|
346 val is' = |
|
347 (fold_index (fn (i, j) => if j mem is then cons (i + 1) else I) perm []) |
|
348 |> fold (fn i => if i > k then cons (i - k + b) else I) is |
|
349 |
|
350 val res = map (fn x => Mode (m, is', x)) (cprods (map |
|
351 (fn (NONE, _) => [NONE] |
|
352 | (SOME js, arg) => map SOME (filter |
|
353 (fn Mode (_, js', _) => js=js') (modes_of_term modes arg))) |
|
354 (iss ~~ args1))) |
|
355 val _ = Output.tracing ("is = " ^ (makestring is)) |
|
356 val _ = Output.tracing ("is' = " ^ (makestring is')) |
|
357 in res end |
|
358 end)) (AList.lookup op = modes name) |
|
359 in case strip_comb t' of |
|
360 (Const (name, _), args) => the_default default (mk_modes name args) |
|
361 | (Var ((name, _), _), args) => the (mk_modes name args) |
|
362 | (Free (name, _), args) => the (mk_modes name args) |
|
363 | _ => default end |
|
364 |
|
365 and modes_of_term modes t = |
278 let |
366 let |
279 val ks = 1 upto length (binder_types (fastype_of t)); |
367 val ks = 1 upto length (binder_types (fastype_of t)); |
280 val default = [Mode (([], ks), ks, [])]; |
368 val default = [Mode (([], ks), ks, [])]; |
281 fun mk_modes name args = Option.map (maps (fn (m as (iss, is)) => |
369 fun mk_modes name args = Option.map (maps (fn (m as (iss, is)) => |
282 let |
370 let |
457 (Name.variant_list names (replicate (length Ts) "x") ~~ Ts) |
534 (Name.variant_list names (replicate (length Ts) "x") ~~ Ts) |
458 in |
535 in |
459 fold_rev lambda vs (f (list_comb (t, vs))) |
536 fold_rev lambda vs (f (list_comb (t, vs))) |
460 end; |
537 end; |
461 |
538 |
462 fun compile_param thy modes (NONE, t) = t |
539 |
463 | compile_param thy modes (m as SOME (Mode ((iss, is'), is, ms)), t) = let |
540 |
464 val (f, args) = strip_comb t |
541 fun compile_param_ext thy modes (NONE, t) = t |
465 val (params, args') = chop (length ms) args |
542 | compile_param_ext thy modes (m as SOME (Mode ((iss, is'), is, ms)), t) = |
466 val params' = map (compile_param thy modes) (ms ~~ params) |
543 let |
467 val f' = case f of |
544 val (vs, u) = strip_abs t |
|
545 val (ivs, ovs) = get_args is vs |
|
546 val _ = Output.tracing ("ivs = " ^ (makestring ivs)) |
|
547 val _ = Output.tracing ("ovs = " ^ (makestring ovs)) |
|
548 val (f, args) = strip_comb u |
|
549 val (params, args') = chop (length ms) args |
|
550 val (inargs, outargs) = get_args is' args' |
|
551 val b = length vs |
|
552 val perm = map (fn i => (find_index_eq (Bound (b - i)) args') + 1) (1 upto b) |
|
553 val _ = Output.tracing ("perm (compile) = " ^ (makestring perm)) |
|
554 val outp_perm = |
|
555 snd (get_args is perm) |
|
556 |> map (fn i => i - length (filter (fn x => x < i) is')) |
|
557 val _ = Output.tracing ("outp_perm = " ^ (makestring outp_perm)) |
|
558 val names = [] (* TODO *) |
|
559 val out_names = Name.variant_list names (replicate (length outargs) "x") |
|
560 val f' = case f of |
|
561 Const (name, T) => |
|
562 if AList.defined op = modes name then |
|
563 Const (predfun_name_of thy name (iss, is'), funT'_of (iss, is') T) |
|
564 else error "compile param: Not an inductive predicate with correct mode" |
|
565 | Free (name, T) => Free (name, funT_of T (SOME is')) |
|
566 val outTs = dest_tupleT (dest_pred_enumT (body_type (fastype_of f'))) |
|
567 val _ = Output.tracing ("outTs = " ^ (makestring outTs)) |
|
568 val out_vs = map Free (out_names ~~ outTs) |
|
569 val _ = Output.tracing "dipp" |
|
570 val params' = map (compile_param thy modes) (ms ~~ params) |
|
571 val f_app = list_comb (f', params' @ inargs) |
|
572 val single_t = (mk_single (mk_tuple (map (fn i => nth out_vs (i - 1)) outp_perm))) |
|
573 val match_t = compile_match thy [] [] out_vs single_t |
|
574 val _ = Output.tracing "dupp" |
|
575 in list_abs (ivs, |
|
576 mk_bind (f_app, match_t)) |
|
577 |> tap (fn r => Output.tracing ("compile_param: " ^ (Syntax.string_of_term_global thy r))) |
|
578 end |
|
579 | compile_param_ext _ _ _ = error "compile params" |
|
580 |
|
581 and compile_param thy modes (NONE, t) = t |
|
582 | compile_param thy modes (m as SOME (Mode ((iss, is'), is, ms)), t) = |
|
583 (case t of |
|
584 Abs _ => compile_param_ext thy modes (m, t) |
|
585 | _ => let |
|
586 val (f, args) = strip_comb t |
|
587 val (params, args') = chop (length ms) args |
|
588 val params' = map (compile_param thy modes) (ms ~~ params) |
|
589 val f' = case f of |
468 Const (name, T) => |
590 Const (name, T) => |
469 if AList.defined op = modes name then |
591 if AList.defined op = modes name then |
470 Const (modename_of thy name (iss, is'), funT'_of (iss, is') T) |
592 Const (predfun_name_of thy name (iss, is'), funT'_of (iss, is') T) |
471 else error "compile param: Not an inductive predicate with correct mode" |
593 else error "compile param: Not an inductive predicate with correct mode" |
472 | Free (name, T) => Free (name, funT_of T (SOME is')) |
594 | Free (name, T) => Free (name, funT_of T (SOME is')) |
473 in list_comb (f', params' @ args') end |
595 in list_comb (f', params' @ args') end) |
474 | compile_param _ _ _ = error "compile params" |
596 | compile_param _ _ _ = error "compile params" |
475 |
597 |
|
598 |
476 fun compile_expr thy modes (SOME (Mode (mode, is, ms)), t) = |
599 fun compile_expr thy modes (SOME (Mode (mode, is, ms)), t) = |
477 (case strip_comb t of |
600 (case strip_comb t of |
478 (Const (name, T), params) => |
601 (Const (name, T), params) => |
479 if AList.defined op = modes name then |
602 if AList.defined op = modes name then |
480 let |
603 let |
481 val (Ts, Us) = get_args is |
604 val (Ts, Us) = get_args is |
482 (curry Library.drop (length ms) (fst (strip_type T))) |
605 (curry Library.drop (length ms) (fst (strip_type T))) |
483 val params' = map (compile_param thy modes) (ms ~~ params) |
606 val params' = map (compile_param thy modes) (ms ~~ params) |
484 val mode_id = modename_of thy name mode |
607 in list_comb (Const (predfun_name_of thy name mode, ((map fastype_of params') @ Ts) ---> |
485 in list_comb (Const (mode_id, ((map fastype_of params') @ Ts) ---> |
|
486 mk_pred_enumT (mk_tupleT Us)), params') |
608 mk_pred_enumT (mk_tupleT Us)), params') |
487 end |
609 end |
488 else error "not a valid inductive expression" |
610 else error "not a valid inductive expression" |
489 | (Free (name, T), args) => |
611 | (Free (name, T), args) => |
490 (*if name mem param_vs then *) |
612 (*if name mem param_vs then *) |
497 fun compile_clause thy all_vs param_vs modes (iss, is) (ts, ps) inp = |
619 fun compile_clause thy all_vs param_vs modes (iss, is) (ts, ps) inp = |
498 let |
620 let |
499 val modes' = modes @ List.mapPartial |
621 val modes' = modes @ List.mapPartial |
500 (fn (_, NONE) => NONE | (v, SOME js) => SOME (v, [([], js)])) |
622 (fn (_, NONE) => NONE | (v, SOME js) => SOME (v, [([], js)])) |
501 (param_vs ~~ iss); |
623 (param_vs ~~ iss); |
502 fun check_constrt ((names, eqs), t) = |
624 fun check_constrt t (names, eqs) = |
503 if is_constrt thy t then ((names, eqs), t) else |
625 if is_constrt thy t then (t, (names, eqs)) else |
504 let |
626 let |
505 val s = Name.variant names "x"; |
627 val s = Name.variant names "x"; |
506 val v = Free (s, fastype_of t) |
628 val v = Free (s, fastype_of t) |
507 in ((s::names, HOLogic.mk_eq (v, t)::eqs), v) end; |
629 in (v, (s::names, HOLogic.mk_eq (v, t)::eqs)) end; |
508 |
630 |
509 val (in_ts, out_ts) = get_args is ts; |
631 val (in_ts, out_ts) = get_args is ts; |
510 val ((all_vs', eqs), in_ts') = |
632 val (in_ts', (all_vs', eqs)) = |
511 (*FIXME*) Library.foldl_map check_constrt ((all_vs, []), in_ts); |
633 fold_map check_constrt in_ts (all_vs, []); |
512 |
634 |
513 fun compile_prems out_ts' vs names [] = |
635 fun compile_prems out_ts' vs names [] = |
514 let |
636 let |
515 val ((names', eqs'), out_ts'') = |
637 val (out_ts'', (names', eqs')) = |
516 (*FIXME*) Library.foldl_map check_constrt ((names, []), out_ts'); |
638 fold_map check_constrt out_ts' (names, []); |
517 val (nvs, out_ts''') = (*FIXME*) Library.foldl_map distinct_v |
639 val (out_ts''', (names'', constr_vs)) = fold_map distinct_v |
518 ((names', map (rpair []) vs), out_ts''); |
640 out_ts'' (names', map (rpair []) vs); |
519 in |
641 in |
520 compile_match thy (snd nvs) (eqs @ eqs') out_ts''' |
642 compile_match thy constr_vs (eqs @ eqs') out_ts''' |
521 (mk_single (mk_tuple out_ts)) |
643 (mk_single (mk_tuple out_ts)) |
522 end |
644 end |
523 | compile_prems out_ts vs names ps = |
645 | compile_prems out_ts vs names ps = |
524 let |
646 let |
525 val vs' = distinct (op =) (flat (vs :: map term_vs out_ts)); |
647 val vs' = distinct (op =) (flat (vs :: map term_vs out_ts)); |
526 val SOME (p, mode as SOME (Mode (_, js, _))) = |
648 val SOME (p, mode as SOME (Mode (_, js, _))) = |
527 select_mode_prem thy modes' vs' ps |
649 select_mode_prem thy modes' vs' ps |
528 val ps' = filter_out (equal p) ps |
650 val ps' = filter_out (equal p) ps |
529 val ((names', eqs), out_ts') = |
651 val (out_ts', (names', eqs)) = |
530 (*FIXME*) Library.foldl_map check_constrt ((names, []), out_ts) |
652 fold_map check_constrt out_ts (names, []) |
531 val (nvs, out_ts'') = (*FIXME*) Library.foldl_map distinct_v |
653 val (out_ts'', (names'', constr_vs')) = fold_map distinct_v |
532 ((names', map (rpair []) vs), out_ts') |
654 out_ts' ((names', map (rpair []) vs)) |
533 val (compiled_clause, rest) = case p of |
655 val (compiled_clause, rest) = case p of |
534 Prem (us, t) => |
656 Prem (us, t) => |
535 let |
657 let |
536 val (in_ts, out_ts''') = get_args js us; |
658 val (in_ts, out_ts''') = get_args js us; |
537 val u = list_comb (compile_expr thy modes (mode, t), in_ts) |
659 val u = list_comb (compile_expr thy modes (mode, t), in_ts) |
538 val rest = compile_prems out_ts''' vs' (fst nvs) ps' |
660 val rest = compile_prems out_ts''' vs' names'' ps' |
539 in |
661 in |
540 (u, rest) |
662 (u, rest) |
541 end |
663 end |
542 | Negprem (us, t) => |
664 | Negprem (us, t) => |
543 let |
665 let |
544 val (in_ts, out_ts''') = get_args js us |
666 val (in_ts, out_ts''') = get_args js us |
545 val u = list_comb (compile_expr thy modes (mode, t), in_ts) |
667 val u = list_comb (compile_expr thy modes (mode, t), in_ts) |
546 val rest = compile_prems out_ts''' vs' (fst nvs) ps' |
668 val rest = compile_prems out_ts''' vs' names'' ps' |
547 in |
669 in |
548 (mk_not_pred u, rest) |
670 (mk_not_pred u, rest) |
549 end |
671 end |
550 | Sidecond t => |
672 | Sidecond t => |
551 let |
673 let |
552 val rest = compile_prems [] vs' (fst nvs) ps'; |
674 val rest = compile_prems [] vs' names'' ps'; |
553 in |
675 in |
554 (mk_if_predenum t, rest) |
676 (mk_if_predenum t, rest) |
555 end |
677 end |
556 in |
678 in |
557 compile_match thy (snd nvs) eqs out_ts'' |
679 compile_match thy constr_vs' eqs out_ts'' |
558 (mk_bind (compiled_clause, rest)) |
680 (mk_bind (compiled_clause, rest)) |
559 end |
681 end |
560 val prem_t = compile_prems in_ts' param_vs all_vs' ps; |
682 val prem_t = compile_prems in_ts' param_vs all_vs' ps; |
561 in |
683 in |
562 mk_bind (mk_single inp, prem_t) |
684 mk_bind (mk_single inp, prem_t) |
692 val predterm = mk_Enum (mk_split_lambda xouts (list_comb (Const (name, T), xparams' @ xargs))) |
798 val predterm = mk_Enum (mk_split_lambda xouts (list_comb (Const (name, T), xparams' @ xargs))) |
693 val funT = (Ts1' @ Us1) ---> (mk_pred_enumT (mk_tupleT Us2)) |
799 val funT = (Ts1' @ Us1) ---> (mk_pred_enumT (mk_tupleT Us2)) |
694 val mode_id = Sign.full_bname thy (Long_Name.base_name mode_id) |
800 val mode_id = Sign.full_bname thy (Long_Name.base_name mode_id) |
695 val lhs = list_comb (Const (mode_id, funT), xparams @ xins) |
801 val lhs = list_comb (Const (mode_id, funT), xparams @ xins) |
696 val def = Logic.mk_equals (lhs, predterm) |
802 val def = Logic.mk_equals (lhs, predterm) |
697 val ([defthm], thy') = thy |> |
803 val ([definition], thy') = thy |> |
698 Sign.add_consts_i [(Binding.name (Long_Name.base_name mode_id), funT, NoSyn)] |> |
804 Sign.add_consts_i [(Binding.name (Long_Name.base_name mode_id), funT, NoSyn)] |> |
699 PureThy.add_defs false [((Binding.name (Long_Name.base_name mode_id ^ "_def"), def), [])] |
805 PureThy.add_defs false [((Binding.name (Long_Name.base_name mode_id ^ "_def"), def), [])] |
700 in thy' |> map_names (PredModetab.update_new ((name, mode), mode_id)) |
806 val (intro, elim) = create_intro_elim_rule nparams mode definition mode_id funT (Const (name, T)) thy' |
701 |> map_function_defs (Symtab.update_new (mode_id, defthm)) |
807 in thy' |> add_predfun name mode (mode_id, definition, intro, elim) |
702 |> create_intro_rule nparams mode defthm mode_id funT (Const (name, T)) |
808 |> PureThy.store_thm (Binding.name (Long_Name.base_name mode_id ^ "I"), intro) |> snd |
|
809 |> PureThy.store_thm (Binding.name (Long_Name.base_name mode_id ^ "E"), elim) |> snd |
703 end; |
810 end; |
704 in |
811 in |
705 fold create_definition modes thy |
812 fold create_definition modes thy |
706 end; |
813 end; |
707 |
814 |
708 (**************************************************************************************) |
815 (**************************************************************************************) |
709 (* Proving equivalence of term *) |
816 (* Proving equivalence of term *) |
710 |
|
711 |
|
712 fun intro_rule thy pred mode = modename_of thy pred mode |
|
713 |> Symtab.lookup (#function_intros (IndCodegenData.get thy)) |> the |
|
714 |
|
715 fun elim_rule thy pred mode = modename_of thy pred mode |
|
716 |> Symtab.lookup (#function_elims (IndCodegenData.get thy)) |> the |
|
717 |
|
718 fun pred_intros thy predname = let |
|
719 fun is_intro_of pred intro = let |
|
720 val const = fst (strip_comb (HOLogic.dest_Trueprop (concl_of intro))) |
|
721 in (fst (dest_Const const) = pred) end; |
|
722 val d = IndCodegenData.get thy |
|
723 in |
|
724 if (Symtab.defined (#intro_rules d) predname) then |
|
725 rev (Symtab.lookup_list (#intro_rules d) predname) |
|
726 else |
|
727 InductivePackage.the_inductive (ProofContext.init thy) predname |
|
728 |> snd |> #intrs |> filter (is_intro_of predname) |
|
729 end |
|
730 |
|
731 fun function_definition thy pred mode = |
|
732 modename_of thy pred mode |> Symtab.lookup (#function_defs (IndCodegenData.get thy)) |> the |
|
733 |
817 |
734 fun is_Type (Type _) = true |
818 fun is_Type (Type _) = true |
735 | is_Type _ = false |
819 | is_Type _ = false |
736 |
820 |
737 fun imp_prems_conv cv ct = |
821 fun imp_prems_conv cv ct = |
965 |
1051 |
966 fun select_sup 1 1 = [] |
1052 fun select_sup 1 1 = [] |
967 | select_sup _ 1 = [rtac @{thm supI1}] |
1053 | select_sup _ 1 = [rtac @{thm supI1}] |
968 | select_sup n i = (rtac @{thm supI2})::(select_sup (n - 1) (i - 1)); |
1054 | select_sup n i = (rtac @{thm supI2})::(select_sup (n - 1) (i - 1)); |
969 |
1055 |
970 (* FIXME: This function relies on the derivation of an induction rule *) |
|
971 fun get_nparams thy s = let |
|
972 val _ = tracing ("get_nparams: " ^ s) |
|
973 in |
|
974 if Symtab.defined (#nparams (IndCodegenData.get thy)) s then |
|
975 the (Symtab.lookup (#nparams (IndCodegenData.get thy)) s) |
|
976 else |
|
977 case try (InductivePackage.the_inductive (ProofContext.init thy)) s of |
|
978 SOME info => info |> snd |> #raw_induct |> Thm.unvarify |
|
979 |> InductivePackage.params_of |> length |
|
980 | NONE => 0 (* default value *) |
|
981 end |
|
982 |
|
983 val ind_set_codegen_preproc = InductiveSetPackage.codegen_preproc; |
|
984 |
|
985 fun pred_elim thy predname = |
|
986 if (Symtab.defined (#elim_rules (IndCodegenData.get thy)) predname) then |
|
987 the (Symtab.lookup (#elim_rules (IndCodegenData.get thy)) predname) |
|
988 else |
|
989 (let |
|
990 val ind_result = InductivePackage.the_inductive (ProofContext.init thy) predname |
|
991 val index = find_index (fn s => s = predname) (#names (fst ind_result)) |
|
992 in nth (#elims (snd ind_result)) index end) |
|
993 |
|
994 fun prove_one_direction thy all_vs param_vs modes clauses ((pred, T), mode) = let |
1056 fun prove_one_direction thy all_vs param_vs modes clauses ((pred, T), mode) = let |
995 val elim_rule = the (Symtab.lookup (#function_elims (IndCodegenData.get thy)) (modename_of thy pred mode)) |
|
996 (* val ind_result = InductivePackage.the_inductive (ProofContext.init thy) pred |
1057 (* val ind_result = InductivePackage.the_inductive (ProofContext.init thy) pred |
997 val index = find_index (fn s => s = pred) (#names (fst ind_result)) |
1058 val index = find_index (fn s => s = pred) (#names (fst ind_result)) |
998 val (_, T) = dest_Const (nth (#preds (snd ind_result)) index) *) |
1059 val (_, T) = dest_Const (nth (#preds (snd ind_result)) index) *) |
999 val nargs = length (binder_types T) - get_nparams thy pred |
1060 val nargs = length (binder_types T) - nparams_of thy pred |
1000 val pred_case_rule = singleton (ind_set_codegen_preproc thy) |
1061 val pred_case_rule = singleton (ind_set_codegen_preproc thy) |
1001 (preprocess_elim thy nargs (pred_elim thy pred)) |
1062 (preprocess_elim thy nargs (the_elim_of thy pred)) |
1002 (* FIXME preprocessor |> Simplifier.full_simplify (HOL_basic_ss addsimps [@ {thm Predicate.memb_code}])*) |
1063 (* FIXME preprocessor |> Simplifier.full_simplify (HOL_basic_ss addsimps [@ {thm Predicate.memb_code}])*) |
1003 val _ = tracing ("pred_case_rule " ^ (makestring pred_case_rule)) |
1064 val _ = tracing ("pred_case_rule " ^ (makestring pred_case_rule)) |
1004 in |
1065 in |
1005 REPEAT_DETERM (CHANGED (rewtac @{thm "split_paired_all"})) |
1066 REPEAT_DETERM (CHANGED (rewtac @{thm "split_paired_all"})) |
1006 THEN etac elim_rule 1 |
1067 THEN etac (predfun_elim_of thy pred mode) 1 |
1007 THEN etac pred_case_rule 1 |
1068 THEN etac pred_case_rule 1 |
1008 THEN (EVERY (map |
1069 THEN (EVERY (map |
1009 (fn i => EVERY' (select_sup (length clauses) i) i) |
1070 (fn i => EVERY' (select_sup (length clauses) i) i) |
1010 (1 upto (length clauses)))) |
1071 (1 upto (length clauses)))) |
1011 THEN (EVERY (map (prove_clause thy nargs all_vs param_vs modes mode) clauses)) |
1072 THEN (EVERY (map (prove_clause thy nargs all_vs param_vs modes mode) clauses)) |
1200 end; |
1262 end; |
1201 |
1263 |
1202 fun prove_preds thy all_vs param_vs modes clauses pmts = |
1264 fun prove_preds thy all_vs param_vs modes clauses pmts = |
1203 map (prove_pred thy all_vs param_vs modes clauses) pmts |
1265 map (prove_pred thy all_vs param_vs modes clauses) pmts |
1204 |
1266 |
1205 (* look for other place where this functionality was used before *) |
|
1206 fun strip_intro_concl intro nparams = let |
|
1207 val _ $ u = Logic.strip_imp_concl intro |
|
1208 val (pred, all_args) = strip_comb u |
|
1209 val (params, args) = chop nparams all_args |
|
1210 in (pred, (params, args)) end |
|
1211 |
|
1212 (* setup for alternative introduction and elimination rules *) |
|
1213 |
|
1214 fun add_intro_thm thm thy = let |
|
1215 val (pred, _) = dest_Const (fst (strip_intro_concl (prop_of thm) 0)) |
|
1216 in map_intro_rules (Symtab.insert_list Thm.eq_thm (pred, thm)) thy end |
|
1217 |
|
1218 fun add_elim_thm thm thy = let |
|
1219 val (pred, _) = dest_Const (fst |
|
1220 (strip_comb (HOLogic.dest_Trueprop (hd (prems_of thm))))) |
|
1221 in map_elim_rules (Symtab.update (pred, thm)) thy end |
|
1222 |
|
1223 |
|
1224 (* special case: inductive predicate with no clauses *) |
1267 (* special case: inductive predicate with no clauses *) |
1225 fun noclause (predname, T) thy = let |
1268 fun noclause (predname, T) thy = let |
1226 val Ts = binder_types T |
1269 val Ts = binder_types T |
1227 val names = Name.variant_list [] |
1270 val names = Name.variant_list [] |
1228 (map (fn i => "x" ^ (string_of_int i)) (1 upto (length Ts))) |
1271 (map (fn i => "x" ^ (string_of_int i)) (1 upto (length Ts))) |
1229 val vs = map2 (curry Free) names Ts |
1272 val vs = map2 (curry Free) names Ts |
1230 val clausehd = HOLogic.mk_Trueprop (list_comb(Const (predname, T), vs)) |
1273 val clausehd = HOLogic.mk_Trueprop (list_comb(Const (predname, T), vs)) |
1231 val intro_t = Logic.mk_implies (@{prop False}, clausehd) |
1274 val intro_t = Logic.mk_implies (@{prop False}, clausehd) |
1232 val P = HOLogic.mk_Trueprop (Free ("P", HOLogic.boolT)) |
1275 val P = HOLogic.mk_Trueprop (Free ("P", HOLogic.boolT)) |
1233 val elim_t = Logic.list_implies ([clausehd, Logic.mk_implies (@{prop False}, P)], P) |
1276 val elim_t = Logic.list_implies ([clausehd, Logic.mk_implies (@{prop False}, P)], P) |
1234 val intro_thm = Goal.prove (ProofContext.init thy) names [] intro_t |
1277 val intro = Goal.prove (ProofContext.init thy) names [] intro_t |
1235 (fn {...} => etac @{thm FalseE} 1) |
1278 (fn {...} => etac @{thm FalseE} 1) |
1236 val elim_thm = Goal.prove (ProofContext.init thy) ("P" :: names) [] elim_t |
1279 val elim = Goal.prove (ProofContext.init thy) ("P" :: names) [] elim_t |
1237 (fn {...} => etac (pred_elim thy predname) 1) |
1280 (fn {...} => etac (the_elim_of thy predname) 1) |
1238 in |
1281 in |
1239 add_intro_thm intro_thm thy |
1282 add_intro intro thy |
1240 |> add_elim_thm elim_thm |
1283 |> set_elim elim |
1241 end |
1284 end |
1242 |
1285 |
1243 (*************************************************************************************) |
1286 fun prepare_intrs thy prednames = |
1244 (* main function *********************************************************************) |
1287 let |
1245 (*************************************************************************************) |
1288 val intrs = map (preprocess_intro thy) (maps (intros_of thy) prednames) |
1246 |
1289 |> ind_set_codegen_preproc thy (*FIXME preprocessor |
1247 fun prove_equation ind_name mode thy = |
1290 |> map (Simplifier.full_simplify (HOL_basic_ss addsimps [@ {thm Predicate.memb_code}]))*) |
1248 let |
1291 |> map (Logic.unvarify o prop_of) |
1249 val _ = tracing ("starting prove_equation' with " ^ ind_name) |
1292 val nparams = nparams_of thy (hd prednames) |
1250 val (prednames, preds) = |
1293 val preds = distinct (op =) (map (dest_Const o fst o (strip_intro_concl nparams)) intrs) |
1251 case (try (InductivePackage.the_inductive (ProofContext.init thy)) ind_name) of |
1294 val extra_modes = all_modes_of thy |> filter_out (fn (name, _) => member (op =) prednames name) |
1252 SOME info => let val preds = info |> snd |> #preds |
1295 val _ $ u = Logic.strip_imp_concl (hd intrs); |
1253 in (map (fst o dest_Const) preds, map ((apsnd Logic.unvarifyT) o dest_Const) preds) end |
1296 val params = List.take (snd (strip_comb u), nparams); |
1254 | NONE => let |
1297 val param_vs = maps term_vs params |
1255 val pred = Symtab.lookup (#intro_rules (IndCodegenData.get thy)) ind_name |
1298 val all_vs = terms_vs intrs |
1256 |> the |> hd |> prop_of |
1299 fun dest_prem t = |
1257 |> Logic.strip_imp_concl |> HOLogic.dest_Trueprop |> strip_comb |
|
1258 |> fst |> dest_Const |> apsnd Logic.unvarifyT |
|
1259 in ([ind_name], [pred]) end |
|
1260 val thy' = fold (fn pred as (predname, T) => fn thy => |
|
1261 if null (pred_intros thy predname) then noclause pred thy else thy) preds thy |
|
1262 val intrs = map (preprocess_intro thy') (maps (pred_intros thy') prednames) |
|
1263 |> ind_set_codegen_preproc thy' (*FIXME preprocessor |
|
1264 |> map (Simplifier.full_simplify (HOL_basic_ss addsimps [@ {thm Predicate.memb_code}]))*) |
|
1265 |> map (Logic.unvarify o prop_of) |
|
1266 val _ = tracing ("preprocessed intro rules:" ^ (makestring (map (cterm_of thy') intrs))) |
|
1267 val name_of_calls = get_name_of_ind_calls_of_clauses thy' prednames intrs |
|
1268 val _ = tracing ("calling preds: " ^ makestring name_of_calls) |
|
1269 val _ = tracing "starting recursive compilations" |
|
1270 fun rec_call name thy = |
|
1271 (*FIXME use member instead of infix mem*) |
|
1272 if not (name mem (Symtab.keys (#modes (IndCodegenData.get thy)))) then |
|
1273 prove_equation name NONE thy else thy |
|
1274 val thy'' = fold rec_call name_of_calls thy' |
|
1275 val _ = tracing "returning from recursive calls" |
|
1276 val _ = tracing "starting mode inference" |
|
1277 val extra_modes = Symtab.dest (#modes (IndCodegenData.get thy'')) |
|
1278 val nparams = get_nparams thy'' ind_name |
|
1279 val _ $ u = Logic.strip_imp_concl (hd intrs); |
|
1280 val params = List.take (snd (strip_comb u), nparams); |
|
1281 val param_vs = maps term_vs params |
|
1282 val all_vs = terms_vs intrs |
|
1283 fun dest_prem t = |
|
1284 (case strip_comb t of |
1300 (case strip_comb t of |
1285 (v as Free _, ts) => if v mem params then Prem (ts, v) else Sidecond t |
1301 (v as Free _, ts) => if v mem params then Prem (ts, v) else Sidecond t |
1286 | (c as Const (@{const_name Not}, _), [t]) => (case dest_prem t of |
1302 | (c as Const (@{const_name Not}, _), [t]) => (case dest_prem t of |
1287 Prem (ts, t) => Negprem (ts, t) |
1303 Prem (ts, t) => Negprem (ts, t) |
1288 | Negprem _ => error ("Double negation not allowed in premise: " ^ (makestring (c $ t))) |
1304 | Negprem _ => error ("Double negation not allowed in premise: " ^ (makestring (c $ t))) |
1289 | Sidecond t => Sidecond (c $ t)) |
1305 | Sidecond t => Sidecond (c $ t)) |
1290 | (c as Const (s, _), ts) => |
1306 | (c as Const (s, _), ts) => |
1291 if is_ind_pred thy'' s then |
1307 if is_pred thy s then |
1292 let val (ts1, ts2) = chop (get_nparams thy'' s) ts |
1308 let val (ts1, ts2) = chop (nparams_of thy s) ts |
1293 in Prem (ts2, list_comb (c, ts1)) end |
1309 in Prem (ts2, list_comb (c, ts1)) end |
1294 else Sidecond t |
1310 else Sidecond t |
1295 | _ => Sidecond t) |
1311 | _ => Sidecond t) |
1296 fun add_clause intr (clauses, arities) = |
1312 fun add_clause intr (clauses, arities) = |
1297 let |
1313 let |
1298 val _ $ t = Logic.strip_imp_concl intr; |
1314 val _ $ t = Logic.strip_imp_concl intr; |
1299 val (Const (name, T), ts) = strip_comb t; |
1315 val (Const (name, T), ts) = strip_comb t; |
1300 val (ts1, ts2) = chop nparams ts; |
1316 val (ts1, ts2) = chop nparams ts; |
1301 val prems = map (dest_prem o HOLogic.dest_Trueprop) (Logic.strip_imp_prems intr); |
1317 val prems = map (dest_prem o HOLogic.dest_Trueprop) (Logic.strip_imp_prems intr); |
1302 val (Ts, Us) = chop nparams (binder_types T) |
1318 val (Ts, Us) = chop nparams (binder_types T) |
1303 in |
1319 in |
1304 (AList.update op = (name, these (AList.lookup op = clauses name) @ |
1320 (AList.update op = (name, these (AList.lookup op = clauses name) @ |
1305 [(ts2, prems)]) clauses, |
1321 [(ts2, prems)]) clauses, |
1306 AList.update op = (name, (map (fn U => (case strip_type U of |
1322 AList.update op = (name, (map (fn U => (case strip_type U of |
1307 (Rs as _ :: _, Type ("bool", [])) => SOME (length Rs) |
1323 (Rs as _ :: _, Type ("bool", [])) => SOME (length Rs) |
1308 | _ => NONE)) Ts, |
1324 | _ => NONE)) Ts, |
1309 length Us)) arities) |
1325 length Us)) arities) |
1310 end; |
1326 end; |
1311 val (clauses, arities) = fold add_clause intrs ([], []); |
1327 val (clauses, arities) = fold add_clause intrs ([], []); |
1312 val modes = infer_modes thy'' extra_modes arities param_vs clauses |
1328 in (preds, nparams, all_vs, param_vs, extra_modes, clauses, arities) end; |
1313 val _ = print_arities arities; |
1329 |
1314 val _ = print_modes modes; |
1330 fun arrange kvs = |
1315 val modes = if (is_some mode) then AList.update (op =) (ind_name, [the mode]) modes else modes |
1331 let |
1316 val _ = print_modes modes |
1332 fun add (key, value) table = |
1317 val thy''' = fold (create_definitions preds nparams) modes thy'' |
1333 AList.update op = (key, these (AList.lookup op = table key) @ [value]) table |
1318 |> map_modes (fold Symtab.update_new modes) |
1334 in fold add kvs [] end; |
|
1335 |
|
1336 (* main function *) |
|
1337 |
|
1338 fun add_equations_of prednames thy = |
|
1339 let |
|
1340 val _ = tracing ("starting add_equations with " ^ commas prednames ^ "...") |
|
1341 (* null clause handling *) |
|
1342 (* |
|
1343 val thy' = fold (fn pred as (predname, T) => fn thy => |
|
1344 if null (intros_of thy predname) then noclause pred thy else thy) preds thy |
|
1345 *) |
|
1346 val (preds, nparams, all_vs, param_vs, extra_modes, clauses, arities) = |
|
1347 prepare_intrs thy prednames |
|
1348 val _ = tracing "Infering modes..." |
|
1349 val modes = infer_modes thy extra_modes arities param_vs clauses |
|
1350 val _ = tracing "Defining executable functions..." |
|
1351 val thy' = fold (create_definitions preds nparams) modes thy |
1319 val clauses' = map (fn (s, cls) => (s, (the (AList.lookup (op =) preds s), cls))) clauses |
1352 val clauses' = map (fn (s, cls) => (s, (the (AList.lookup (op =) preds s), cls))) clauses |
1320 val _ = tracing "compiling predicates..." |
1353 val _ = tracing "Compiling equations..." |
1321 val ts = compile_preds thy''' all_vs param_vs (extra_modes @ modes) clauses' |
1354 val ts = compile_preds thy' all_vs param_vs (extra_modes @ modes) clauses' |
1322 val _ = tracing "returned term from compile_preds" |
1355 val pred_mode = |
1323 val pred_mode = maps (fn (s, (T, _)) => map (pair (s, T)) ((the o AList.lookup (op =) modes) s)) clauses' |
1356 maps (fn (s, (T, _)) => map (pair (s, T)) ((the o AList.lookup (op =) modes) s)) clauses' |
1324 val _ = tracing "starting proof" |
1357 val _ = tracing "Proving equations..." |
1325 val result_thms = prove_preds thy''' all_vs param_vs (extra_modes @ modes) clauses (pred_mode ~~ (flat ts)) |
1358 val result_thms = |
1326 val (_, thy'''') = yield_singleton PureThy.add_thmss |
1359 prove_preds thy' all_vs param_vs (extra_modes @ modes) clauses (pred_mode ~~ (flat ts)) |
1327 ((Binding.qualify true (Long_Name.base_name ind_name) (Binding.name "equation"), result_thms), |
1360 val thy'' = fold (fn (name, result_thms) => fn thy => snd (PureThy.add_thmss |
1328 [Attrib.attribute_i thy''' Code.add_default_eqn_attrib]) thy''' |
1361 [((Binding.qualify true (Long_Name.base_name name) (Binding.name "equation"), result_thms), |
|
1362 [Attrib.attribute_i thy Code.add_default_eqn_attrib])] thy)) |
|
1363 (arrange ((map (fn ((name, _), _) => name) pred_mode) ~~ result_thms)) thy' |
1329 in |
1364 in |
1330 thy'''' |
1365 thy'' |
1331 end |
1366 end |
1332 |
|
1333 fun set_nparams (pred, nparams) thy = map_nparams (Symtab.update (pred, nparams)) thy |
|
1334 |
|
1335 fun print_alternative_rules thy = let |
|
1336 val d = IndCodegenData.get thy |
|
1337 val preds = (Symtab.keys (#intro_rules d)) union (Symtab.keys (#elim_rules d)) |
|
1338 val _ = tracing ("preds: " ^ (makestring preds)) |
|
1339 fun print pred = let |
|
1340 val _ = tracing ("predicate: " ^ pred) |
|
1341 val _ = tracing ("introrules: ") |
|
1342 val _ = fold (fn thm => fn u => tracing (makestring thm)) |
|
1343 (rev (Symtab.lookup_list (#intro_rules d) pred)) () |
|
1344 val _ = tracing ("casesrule: ") |
|
1345 val _ = tracing (makestring (Symtab.lookup (#elim_rules d) pred)) |
|
1346 in () end |
|
1347 val _ = map print preds |
|
1348 in thy end; |
|
1349 |
|
1350 |
1367 |
1351 (* generation of case rules from user-given introduction rules *) |
1368 (* generation of case rules from user-given introduction rules *) |
1352 |
1369 |
1353 fun mk_casesrule introrules nparams ctxt = |
1370 fun mk_casesrule introrules nparams ctxt = |
1354 let |
1371 let |
1355 val intros = map prop_of introrules |
1372 val intros = map prop_of introrules |
1356 val (pred, (params, args)) = strip_intro_concl (hd intros) nparams |
1373 val (pred, (params, args)) = strip_intro_concl nparams (hd intros) |
1357 val ([propname], ctxt1) = Variable.variant_fixes ["thesis"] ctxt |
1374 val ([propname], ctxt1) = Variable.variant_fixes ["thesis"] ctxt |
1358 val prop = HOLogic.mk_Trueprop (Free (propname, HOLogic.boolT)) |
1375 val prop = HOLogic.mk_Trueprop (Free (propname, HOLogic.boolT)) |
1359 val (argnames, ctxt2) = Variable.variant_fixes |
1376 val (argnames, ctxt2) = Variable.variant_fixes |
1360 (map (fn i => "a" ^ string_of_int i) (1 upto (length args))) ctxt1 |
1377 (map (fn i => "a" ^ string_of_int i) (1 upto (length args))) ctxt1 |
1361 val argvs = map Free (argnames ~~ (map fastype_of args)) |
1378 val argvs = map2 (curry Free) argnames (map fastype_of args) |
1362 (*FIXME map2*) |
|
1363 fun mk_case intro = let |
1379 fun mk_case intro = let |
1364 val (_, (_, args)) = strip_intro_concl intro nparams |
1380 val (_, (_, args)) = strip_intro_concl nparams intro |
1365 val prems = Logic.strip_imp_prems intro |
1381 val prems = Logic.strip_imp_prems intro |
1366 val eqprems = map (HOLogic.mk_Trueprop o HOLogic.mk_eq) (argvs ~~ args) |
1382 val eqprems = map (HOLogic.mk_Trueprop o HOLogic.mk_eq) (argvs ~~ args) |
1367 val frees = (fold o fold_aterms) |
1383 val frees = (fold o fold_aterms) |
1368 (fn t as Free _ => |
1384 (fn t as Free _ => |
1369 if member (op aconv) params t then I else insert (op aconv) t |
1385 if member (op aconv) params t then I else insert (op aconv) t |
1374 val (_, ctxt3) = ProofContext.add_assms_i Assumption.assume_export |
1390 val (_, ctxt3) = ProofContext.add_assms_i Assumption.assume_export |
1375 [((Binding.name AutoBind.assmsN, []), map (fn t => (t, [])) (assm :: cases))] |
1391 [((Binding.name AutoBind.assmsN, []), map (fn t => (t, [])) (assm :: cases))] |
1376 ctxt2 |
1392 ctxt2 |
1377 in (pred, prop, ctxt3) end; |
1393 in (pred, prop, ctxt3) end; |
1378 |
1394 |
|
1395 (* code dependency graph *) |
|
1396 |
|
1397 fun fetch_pred_data thy name = |
|
1398 case try (InductivePackage.the_inductive (ProofContext.init thy)) name of |
|
1399 SOME (info as (_, result)) => |
|
1400 let |
|
1401 fun is_intro_of intro = |
|
1402 let |
|
1403 val (const, _) = strip_comb (HOLogic.dest_Trueprop (concl_of intro)) |
|
1404 in (fst (dest_Const const) = name) end; |
|
1405 val intros = filter is_intro_of (#intrs result) |
|
1406 val elim = nth (#elims result) (find_index (fn s => s = name) (#names (fst info))) |
|
1407 val nparams = length (InductivePackage.params_of (#raw_induct result)) |
|
1408 in mk_pred_data ((intros, SOME elim, nparams), []) end |
|
1409 | NONE => error ("No such predicate: " ^ quote name) |
|
1410 |
|
1411 fun dependencies_of (thy : theory) name = |
|
1412 let |
|
1413 fun is_inductive_predicate thy name = |
|
1414 is_some (try (InductivePackage.the_inductive (ProofContext.init thy)) name) |
|
1415 val data = fetch_pred_data thy name |
|
1416 val intros = map Thm.prop_of (#intros (rep_pred_data data)) |
|
1417 val keys = fold Term.add_consts intros [] |> map fst |
|
1418 |> filter (is_inductive_predicate thy) |
|
1419 in |
|
1420 (data, keys) |
|
1421 end; |
|
1422 |
|
1423 fun extend explore keys gr = |
|
1424 let |
|
1425 fun contains_node gr key = member (op =) (Graph.keys gr) key |
|
1426 fun extend' key gr = |
|
1427 let |
|
1428 val (node, preds) = explore key |
|
1429 in |
|
1430 gr |> (not (contains_node gr key)) ? |
|
1431 (Graph.new_node (key, node) |
|
1432 #> fold extend' preds |
|
1433 #> fold (Graph.add_edge o (pair key)) preds) |
|
1434 end |
|
1435 in fold extend' keys gr end |
|
1436 |
|
1437 fun mk_graph explore keys = extend explore keys Graph.empty |
|
1438 |
|
1439 fun add_equations name thy = |
|
1440 let |
|
1441 val thy' = PredData.map (extend (dependencies_of thy) [name]) thy; |
|
1442 (*val preds = Graph.all_preds (PredData.get thy') [name] |> filter_out (has_elim thy') *) |
|
1443 fun strong_conn_of gr keys = |
|
1444 Graph.strong_conn (Graph.subgraph (member (op =) (Graph.all_succs gr keys)) gr) |
|
1445 val scc = strong_conn_of (PredData.get thy') [name] |
|
1446 val thy'' = fold_rev add_equations_of scc thy' |
|
1447 in thy'' end |
1379 |
1448 |
1380 (** user interface **) |
1449 (** user interface **) |
1381 |
1450 |
1382 local |
1451 local |
1383 |
1452 |
1384 fun attrib f = Thm.declaration_attribute (fn thm => Context.mapping (f thm) I); |
1453 fun attrib f = Thm.declaration_attribute (fn thm => Context.mapping (f thm) I); |
1385 |
1454 |
1386 val add_elim_attrib = attrib add_elim_thm; |
1455 (* |
1387 |
1456 val add_elim_attrib = attrib set_elim; |
|
1457 *) |
|
1458 |
|
1459 |
|
1460 (* TODO: make TheoryDataFun to GenericDataFun & remove duplication of local theory and theory *) |
|
1461 (* TODO: must create state to prove multiple cases *) |
1388 fun generic_code_pred prep_const raw_const lthy = |
1462 fun generic_code_pred prep_const raw_const lthy = |
1389 let |
1463 let |
1390 val thy = ProofContext.theory_of lthy |
1464 val thy = (ProofContext.theory_of lthy) |
1391 val const = prep_const thy raw_const |
1465 val const = prep_const thy raw_const |
1392 val nparams = get_nparams thy const |
1466 val lthy' = lthy |
1393 val intro_rules = pred_intros thy const |
1467 val thy' = PredData.map (extend (dependencies_of thy) [const]) thy |
1394 val (((tfrees, frees), fact), lthy') = |
1468 val preds = Graph.all_preds (PredData.get thy') [const] |> filter_out (has_elim thy') |
1395 Variable.import_thms true intro_rules lthy; |
1469 val _ = Output.tracing ("preds: " ^ commas preds) |
1396 val (pred, prop, lthy'') = mk_casesrule fact nparams lthy' |
1470 (* |
1397 val (predname, _) = dest_Const pred |
1471 fun mk_elim pred = |
1398 fun after_qed [[th]] lthy'' = |
1472 let |
1399 lthy'' |
1473 val nparams = nparams_of thy pred |
|
1474 val intros = intros_of thy pred |
|
1475 val (((tfrees, frees), fact), lthy'') = |
|
1476 Variable.import_thms true intros lthy'; |
|
1477 val (pred, prop, lthy''') = mk_casesrule fact nparams lthy'' |
|
1478 in (pred, prop, lthy''') end; |
|
1479 |
|
1480 val (predname, _) = dest_Const pred |
|
1481 *) |
|
1482 val nparams = nparams_of thy' const |
|
1483 val intros = intros_of thy' const |
|
1484 val (((tfrees, frees), fact), lthy'') = |
|
1485 Variable.import_thms true intros lthy'; |
|
1486 val (pred, prop, lthy''') = mk_casesrule fact nparams lthy'' |
|
1487 val (predname, _) = dest_Const pred |
|
1488 fun after_qed [[th]] lthy''' = |
|
1489 lthy''' |
1400 |> LocalTheory.note Thm.generatedK |
1490 |> LocalTheory.note Thm.generatedK |
1401 ((Binding.empty, [Attrib.internal (K add_elim_attrib)]), [th]) |
1491 ((Binding.empty, []), [th]) |
1402 |> snd |
1492 |> snd |
1403 |> LocalTheory.theory (prove_equation predname NONE) |
1493 |> LocalTheory.theory (add_equations_of [predname]) |
1404 in |
1494 in |
1405 Proof.theorem_i NONE after_qed [[(prop, [])]] lthy'' |
1495 Proof.theorem_i NONE after_qed [[(prop, [])]] lthy''' |
1406 end; |
1496 end; |
1407 |
1497 |
1408 structure P = OuterParse |
1498 structure P = OuterParse |
1409 |
1499 |
1410 in |
1500 in |
1411 |
1501 |
1412 val code_pred = generic_code_pred (K I); |
1502 val code_pred = generic_code_pred (K I); |
1413 val code_pred_cmd = generic_code_pred Code.read_const |
1503 val code_pred_cmd = generic_code_pred Code.read_const |
1414 |
1504 |
1415 val setup = |
1505 val setup = PredData.put (Graph.empty) #> |
1416 Attrib.setup @{binding code_ind_intros} (Scan.succeed (attrib add_intro_thm)) |
1506 Attrib.setup @{binding code_pred_intros} (Scan.succeed (attrib add_intro)) |
1417 "adding alternative introduction rules for code generation of inductive predicates" #> |
1507 "adding alternative introduction rules for code generation of inductive predicates" |
1418 Attrib.setup @{binding code_ind_cases} (Scan.succeed add_elim_attrib) |
1508 (* Attrib.setup @{binding code_ind_cases} (Scan.succeed add_elim_attrib) |
1419 "adding alternative elimination rules for code generation of inductive predicates"; |
1509 "adding alternative elimination rules for code generation of inductive predicates"; |
|
1510 *) |
1420 (*FIXME name discrepancy in attribs and ML code*) |
1511 (*FIXME name discrepancy in attribs and ML code*) |
1421 (*FIXME intros should be better named intro*) |
1512 (*FIXME intros should be better named intro*) |
1422 (*FIXME why distinguished atribute for cases?*) |
1513 (*FIXME why distinguished attribute for cases?*) |
1423 |
1514 |
1424 val _ = OuterSyntax.local_theory_to_proof "code_pred" |
1515 val _ = OuterSyntax.local_theory_to_proof "code_pred" |
1425 "prove equations for predicate specified by intro/elim rules" |
1516 "prove equations for predicate specified by intro/elim rules" |
1426 OuterKeyword.thy_goal (P.term_group >> code_pred_cmd) |
1517 OuterKeyword.thy_goal (P.term_group >> code_pred_cmd) |
1427 |
1518 |