4 Deriving specialised predicates and their intro rules |
4 Deriving specialised predicates and their intro rules |
5 *) |
5 *) |
6 |
6 |
7 signature PREDICATE_COMPILE_SPECIALISATION = |
7 signature PREDICATE_COMPILE_SPECIALISATION = |
8 sig |
8 sig |
9 val find_specialisations : string list -> (string * thm list) list -> theory -> (string * thm list) list * theory |
9 val find_specialisations : string list -> (string * thm list) list -> |
|
10 theory -> (string * thm list) list * theory |
10 end; |
11 end; |
11 |
12 |
12 structure Predicate_Compile_Specialisation : PREDICATE_COMPILE_SPECIALISATION = |
13 structure Predicate_Compile_Specialisation : PREDICATE_COMPILE_SPECIALISATION = |
13 struct |
14 struct |
14 |
15 |
15 open Predicate_Compile_Aux; |
16 open Predicate_Compile_Aux; |
16 |
17 |
17 (* table of specialisations *) |
18 (* table of specialisations *) |
18 structure Specialisations = Theory_Data |
19 structure Specialisations = Theory_Data |
19 ( |
20 ( |
20 type T = (term * term) Item_Net.T; |
21 type T = (term * term) Item_Net.T |
21 val empty : T = Item_Net.init (op aconv o pairself fst) (single o fst); |
22 val empty : T = Item_Net.init (op aconv o pairself fst) (single o fst) |
22 val extend = I; |
23 val extend = I |
23 val merge = Item_Net.merge; |
24 val merge = Item_Net.merge |
24 ) |
25 ) |
25 |
26 |
26 fun specialisation_of thy atom = |
27 fun specialisation_of thy atom = |
27 Item_Net.retrieve (Specialisations.get thy) atom |
28 Item_Net.retrieve (Specialisations.get thy) atom |
28 |
29 |
29 fun import (_, intros) args ctxt = |
30 fun import (_, intros) args ctxt = |
30 let |
31 let |
31 val ((_, intros'), ctxt') = Variable.importT intros ctxt |
32 val ((_, intros'), ctxt') = Variable.importT intros ctxt |
32 val pred' = fst (strip_comb (HOLogic.dest_Trueprop (Logic.strip_imp_concl (prop_of (hd intros'))))) |
33 val pred' = |
|
34 fst (strip_comb (HOLogic.dest_Trueprop (Logic.strip_imp_concl (prop_of (hd intros'))))) |
33 val Ts = binder_types (fastype_of pred') |
35 val Ts = binder_types (fastype_of pred') |
34 val argTs = map fastype_of args |
36 val argTs = map fastype_of args |
35 val Tsubst = Type.raw_matches (argTs, Ts) Vartab.empty |
37 val Tsubst = Type.raw_matches (argTs, Ts) Vartab.empty |
36 val args' = map (Envir.subst_term_types Tsubst) args |
38 val args' = map (Envir.subst_term_types Tsubst) args |
37 in |
39 in |
40 |
42 |
41 (* patterns only constructed of variables and pairs/tuples are trivial constructor terms*) |
43 (* patterns only constructed of variables and pairs/tuples are trivial constructor terms*) |
42 fun is_nontrivial_constrt thy t = |
44 fun is_nontrivial_constrt thy t = |
43 let |
45 let |
44 val cnstrs = get_constrs thy |
46 val cnstrs = get_constrs thy |
45 fun check t = (case strip_comb t of |
47 fun check t = |
|
48 (case strip_comb t of |
46 (Var _, []) => (true, true) |
49 (Var _, []) => (true, true) |
47 | (Free _, []) => (true, true) |
50 | (Free _, []) => (true, true) |
48 | (Const (@{const_name Pair}, _), ts) => |
51 | (Const (@{const_name Pair}, _), ts) => |
49 pairself (forall I) (split_list (map check ts)) |
52 pairself (forall I) (split_list (map check ts)) |
50 | (Const (s, T), ts) => (case (AList.lookup (op =) cnstrs s, body_type T) of |
53 | (Const (s, T), ts) => |
|
54 (case (AList.lookup (op =) cnstrs s, body_type T) of |
51 (SOME (i, Tname), Type (Tname', _)) => (false, |
55 (SOME (i, Tname), Type (Tname', _)) => (false, |
52 length ts = i andalso Tname = Tname' andalso forall (snd o check) ts) |
56 length ts = i andalso Tname = Tname' andalso forall (snd o check) ts) |
53 | _ => (false, false)) |
57 | _ => (false, false)) |
54 | _ => (false, false)) |
58 | _ => (false, false)) |
55 in check t = (false, true) end; |
59 in check t = (false, true) end |
56 |
60 |
57 fun specialise_intros black_list (pred, intros) pats thy = |
61 fun specialise_intros black_list (pred, intros) pats thy = |
58 let |
62 let |
59 val ctxt = Proof_Context.init_global thy |
63 val ctxt = Proof_Context.init_global thy |
60 val maxidx = fold (Term.maxidx_term o prop_of) intros ~1 |
64 val maxidx = fold (Term.maxidx_term o prop_of) intros ~1 |
87 val intro = Logic.list_implies (prems, concl) |
91 val intro = Logic.list_implies (prems, concl) |
88 in |
92 in |
89 SOME intro |
93 SOME intro |
90 end handle Pattern.Unif => NONE) |
94 end handle Pattern.Unif => NONE) |
91 val specialised_intros_t = map_filter I (map specialise_intro intros) |
95 val specialised_intros_t = map_filter I (map specialise_intro intros) |
92 val thy' = Sign.add_consts_i [(Binding.name (Long_Name.base_name constname), constT, NoSyn)] thy |
96 val thy' = |
|
97 Sign.add_consts_i [(Binding.name (Long_Name.base_name constname), constT, NoSyn)] thy |
93 val specialised_intros = map (Skip_Proof.make_thm thy') specialised_intros_t |
98 val specialised_intros = map (Skip_Proof.make_thm thy') specialised_intros_t |
94 val exported_intros = Variable.exportT ctxt' ctxt specialised_intros |
99 val exported_intros = Variable.exportT ctxt' ctxt specialised_intros |
95 val [t, specialised_t] = Variable.exportT_terms ctxt' ctxt |
100 val [t, specialised_t] = Variable.exportT_terms ctxt' ctxt |
96 [list_comb (pred, pats), list_comb (specialised_const, result_pats)] |
101 [list_comb (pred, pats), list_comb (specialised_const, result_pats)] |
97 val thy'' = Specialisations.map (Item_Net.update (t, specialised_t)) thy' |
102 val thy'' = Specialisations.map (Item_Net.update (t, specialised_t)) thy' |
121 in |
126 in |
122 (free :: ts', free_names'') |
127 (free :: ts', free_names'') |
123 end |
128 end |
124 and restrict_pattern' thy [] free_names = ([], free_names) |
129 and restrict_pattern' thy [] free_names = ([], free_names) |
125 | restrict_pattern' thy ((T, Free (x, _)) :: Tts) free_names = |
130 | restrict_pattern' thy ((T, Free (x, _)) :: Tts) free_names = |
126 let |
131 let |
127 val (ts', free_names') = restrict_pattern' thy Tts free_names |
132 val (ts', free_names') = restrict_pattern' thy Tts free_names |
128 in |
133 in |
129 (Free (x, T) :: ts', free_names') |
134 (Free (x, T) :: ts', free_names') |
130 end |
135 end |
131 | restrict_pattern' thy ((T as TFree _, t) :: Tts) free_names = |
136 | restrict_pattern' thy ((T as TFree _, t) :: Tts) free_names = |
132 replace_term_and_restrict thy T t Tts free_names |
137 replace_term_and_restrict thy T t Tts free_names |
133 | restrict_pattern' thy ((T as Type (Tcon, _), t) :: Tts) free_names = |
138 | restrict_pattern' thy ((T as Type (Tcon, _), t) :: Tts) free_names = |
134 case Ctr_Sugar.ctr_sugar_of ctxt Tcon of |
139 case Ctr_Sugar.ctr_sugar_of ctxt Tcon of |
135 NONE => replace_term_and_restrict thy T t Tts free_names |
140 NONE => replace_term_and_restrict thy T t Tts free_names |
136 | SOME {ctrs, ...} => (case strip_comb t of |
141 | SOME {ctrs, ...} => |
137 (Const (s, _), ats) => |
142 (case strip_comb t of |
138 (case AList.lookup (op =) (map_filter (try dest_Const) ctrs) s of |
143 (Const (s, _), ats) => |
139 SOME constr_T => |
144 (case AList.lookup (op =) (map_filter (try dest_Const) ctrs) s of |
140 let |
145 SOME constr_T => |
141 val (Ts', T') = strip_type constr_T |
146 let |
142 val Tsubst = Type.raw_match (T', T) Vartab.empty |
147 val (Ts', T') = strip_type constr_T |
143 val Ts = map (Envir.subst_type Tsubst) Ts' |
148 val Tsubst = Type.raw_match (T', T) Vartab.empty |
144 val (bts', free_names') = restrict_pattern' thy ((Ts ~~ ats) @ Tts) free_names |
149 val Ts = map (Envir.subst_type Tsubst) Ts' |
145 val (ats', ts') = chop (length ats) bts' |
150 val (bts', free_names') = restrict_pattern' thy ((Ts ~~ ats) @ Tts) free_names |
146 in |
151 val (ats', ts') = chop (length ats) bts' |
147 (list_comb (Const (s, map fastype_of ats' ---> T), ats') :: ts', free_names') |
152 in |
148 end |
153 (list_comb (Const (s, map fastype_of ats' ---> T), ats') :: ts', free_names') |
149 | NONE => replace_term_and_restrict thy T t Tts free_names)) |
154 end |
|
155 | NONE => replace_term_and_restrict thy T t Tts free_names)) |
150 fun restrict_pattern thy Ts args = |
156 fun restrict_pattern thy Ts args = |
151 let |
157 let |
152 val args = map Logic.unvarify_global args |
158 val args = map Logic.unvarify_global args |
153 val Ts = map Logic.unvarifyT_global Ts |
159 val Ts = map Logic.unvarifyT_global Ts |
154 val free_names = fold Term.add_free_names args [] |
160 val free_names = fold Term.add_free_names args [] |
155 val (pat, _) = restrict_pattern' thy (Ts ~~ args) free_names |
161 val (pat, _) = restrict_pattern' thy (Ts ~~ args) free_names |
156 in map Logic.varify_global pat end |
162 in map Logic.varify_global pat end |
157 fun detect' atom thy = |
163 fun detect' atom thy = |
158 case strip_comb atom of |
164 (case strip_comb atom of |
159 (pred as Const (pred_name, _), args) => |
165 (pred as Const (pred_name, _), args) => |
160 let |
166 let |
161 val Ts = binder_types (Sign.the_const_type thy pred_name) |
167 val Ts = binder_types (Sign.the_const_type thy pred_name) |
162 val pats = restrict_pattern thy Ts args |
168 val pats = restrict_pattern thy Ts args |
163 in |
169 in |
164 if (exists (is_nontrivial_constrt thy) pats) |
170 if (exists (is_nontrivial_constrt thy) pats) |
165 orelse (has_duplicates (op =) (fold add_vars pats [])) then |
171 orelse (has_duplicates (op =) (fold add_vars pats [])) then |
166 let |
172 let |
167 val thy' = |
173 val thy' = |
168 case specialisation_of thy atom of |
174 (case specialisation_of thy atom of |
169 [] => |
175 [] => |
170 if member (op =) ((map fst specs) @ black_list) pred_name then |
176 if member (op =) ((map fst specs) @ black_list) pred_name then |
171 thy |
177 thy |
172 else |
178 else |
173 (case try (Core_Data.intros_of (Proof_Context.init_global thy)) pred_name of |
179 (case try (Core_Data.intros_of (Proof_Context.init_global thy)) pred_name of |
174 NONE => thy |
180 NONE => thy |
175 | SOME [] => thy |
181 | SOME [] => thy |
176 | SOME intros => |
182 | SOME intros => |
177 specialise_intros ((map fst specs) @ (pred_name :: black_list)) |
183 specialise_intros ((map fst specs) @ (pred_name :: black_list)) |
178 (pred, intros) pats thy) |
184 (pred, intros) pats thy) |
179 | _ :: _ => thy |
185 | _ :: _ => thy) |
180 val atom' = |
186 val atom' = |
181 case specialisation_of thy' atom of |
187 (case specialisation_of thy' atom of |
182 [] => atom |
188 [] => atom |
183 | (t, specialised_t) :: _ => |
189 | (t, specialised_t) :: _ => |
184 let |
190 let |
185 val subst = Pattern.match thy' (t, atom) (Vartab.empty, Vartab.empty) |
191 val subst = Pattern.match thy' (t, atom) (Vartab.empty, Vartab.empty) |
186 in Envir.subst_term subst specialised_t end handle Pattern.MATCH => atom |
192 in Envir.subst_term subst specialised_t end handle Pattern.MATCH => atom) |
187 (*FIXME: this exception could be caught earlier in specialisation_of *) |
193 (*FIXME: this exception could be handled earlier in specialisation_of *) |
188 in |
194 in |
189 (atom', thy') |
195 (atom', thy') |
190 end |
196 end |
191 else (atom, thy) |
197 else (atom, thy) |
192 end |
198 end |
193 | _ => (atom, thy) |
199 | _ => (atom, thy)) |
194 fun specialise' (constname, intros) thy = |
200 fun specialise' (constname, intros) thy = |
195 let |
201 let |
196 (* FIXME: only necessary because of sloppy Logic.unvarify in restrict_pattern *) |
202 (* FIXME: only necessary because of sloppy Logic.unvarify in restrict_pattern *) |
197 val intros = Drule.zero_var_indexes_list intros |
203 val intros = Drule.zero_var_indexes_list intros |
198 val (intros_t', thy') = (fold_map o fold_map_atoms) detect' (map prop_of intros) thy |
204 val (intros_t', thy') = (fold_map o fold_map_atoms) detect' (map prop_of intros) thy |