4 Ad-hoc overloading of constants based on their types. |
4 Ad-hoc overloading of constants based on their types. |
5 *) |
5 *) |
6 |
6 |
7 signature ADHOC_OVERLOADING = |
7 signature ADHOC_OVERLOADING = |
8 sig |
8 sig |
9 val add_overloaded: string -> theory -> theory |
9 val is_overloaded: Proof.context -> string -> bool |
10 val add_variant: string -> string -> theory -> theory |
10 val generic_add_overloaded: string -> Context.generic -> Context.generic |
11 |
11 val generic_remove_overloaded: string -> Context.generic -> Context.generic |
|
12 val generic_add_variant: string -> term -> Context.generic -> Context.generic |
|
13 (*If the list of variants is empty at the end of "generic_remove_variant", then |
|
14 "generic_remove_overloaded" is called implicitly.*) |
|
15 val generic_remove_variant: string -> term -> Context.generic -> Context.generic |
12 val show_variants: bool Config.T |
16 val show_variants: bool Config.T |
13 val setup: theory -> theory |
|
14 end |
17 end |
15 |
18 |
16 structure Adhoc_Overloading: ADHOC_OVERLOADING = |
19 structure Adhoc_Overloading: ADHOC_OVERLOADING = |
17 struct |
20 struct |
18 |
21 |
19 val show_variants = Attrib.setup_config_bool @{binding show_variants} (K false); |
22 val show_variants = Attrib.setup_config_bool @{binding show_variants} (K false); |
20 |
23 |
21 |
|
22 (* errors *) |
24 (* errors *) |
23 |
25 |
24 fun duplicate_variant_err int_name ext_name = |
26 fun duplicate_variant_error oconst = |
25 error ("Constant " ^ quote int_name ^ " is already a variant of " ^ quote ext_name); |
27 error ("Duplicate variant of " ^ quote oconst); |
26 |
28 |
27 fun not_overloaded_err name = |
29 fun not_a_variant_error oconst = |
28 error ("Constant " ^ quote name ^ " is not declared as overloaded"); |
30 error ("Not a variant of " ^ quote oconst); |
29 |
31 |
30 fun already_overloaded_err name = |
32 fun not_overloaded_error oconst = |
31 error ("Constant " ^ quote name ^ " is already declared as overloaded"); |
33 error ("Constant " ^ quote oconst ^ " is not declared as overloaded"); |
32 |
34 |
33 fun unresolved_err ctxt (c, T) t reason = |
35 fun unresolved_overloading_error ctxt (c, T) t reason = |
34 error ("Unresolved overloading of " ^ quote c ^ " :: " ^ |
36 error ("Unresolved overloading of " ^ quote c ^ " :: " ^ |
35 quote (Syntax.string_of_typ ctxt T) ^ " in " ^ |
37 quote (Syntax.string_of_typ ctxt T) ^ " in " ^ |
36 quote (Syntax.string_of_term ctxt t) ^ " (" ^ reason ^ ")"); |
38 quote (Syntax.string_of_term ctxt t) ^ " (" ^ reason ^ ")"); |
37 |
39 |
38 |
40 (* generic data *) |
39 (* theory data *) |
41 |
40 |
42 fun variants_eq ((v1, T1), (v2, T2)) = |
41 structure Overload_Data = Theory_Data |
43 Term.aconv_untyped (v1, v2) andalso T1 = T2; |
|
44 |
|
45 structure Overload_Data = Generic_Data |
42 ( |
46 ( |
43 type T = |
47 type T = |
44 { internalize : (string * typ) list Symtab.table, |
48 {variants : (term * typ) list Symtab.table, |
45 externalize : string Symtab.table }; |
49 oconsts : string Termtab.table}; |
46 val empty = {internalize=Symtab.empty, externalize=Symtab.empty}; |
50 val empty = {variants = Symtab.empty, oconsts = Termtab.empty}; |
47 val extend = I; |
51 val extend = I; |
48 |
52 |
49 fun merge_ext int_name (ext_name1, ext_name2) = |
|
50 if ext_name1 = ext_name2 then ext_name1 |
|
51 else duplicate_variant_err int_name ext_name1; |
|
52 |
|
53 fun merge |
53 fun merge |
54 ({internalize = int1, externalize = ext1}, |
54 ({variants = vtab1, oconsts = otab1}, |
55 {internalize = int2, externalize = ext2}) : T = |
55 {variants = vtab2, oconsts = otab2}) : T = |
56 {internalize = Symtab.merge_list (op =) (int1, int2), |
56 let |
57 externalize = Symtab.join merge_ext (ext1, ext2)}; |
57 fun merge_oconsts _ (oconst1, oconst2) = |
|
58 if oconst1 = oconst2 then oconst1 |
|
59 else duplicate_variant_error oconst1; |
|
60 in |
|
61 {variants = Symtab.merge_list variants_eq (vtab1, vtab2), |
|
62 oconsts = Termtab.join merge_oconsts (otab1, otab2)} |
|
63 end; |
58 ); |
64 ); |
59 |
65 |
60 fun map_tables f g = |
66 fun map_tables f g = |
61 Overload_Data.map (fn {internalize=int, externalize=ext} => |
67 Overload_Data.map (fn {variants = vtab, oconsts = otab} => |
62 {internalize=f int, externalize=g ext}); |
68 {variants = f vtab, oconsts = g otab}); |
63 |
69 |
64 val is_overloaded = Symtab.defined o #internalize o Overload_Data.get; |
70 val is_overloaded = Symtab.defined o #variants o Overload_Data.get o Context.Proof; |
65 val get_variants = Symtab.lookup o #internalize o Overload_Data.get; |
71 val get_variants = Symtab.lookup o #variants o Overload_Data.get o Context.Proof; |
66 val get_external = Symtab.lookup o #externalize o Overload_Data.get; |
72 val get_overloaded = Termtab.lookup o #oconsts o Overload_Data.get o Context.Proof; |
67 |
73 |
68 fun add_overloaded ext_name thy = |
74 fun generic_add_overloaded oconst context = |
69 let val _ = not (is_overloaded thy ext_name) orelse already_overloaded_err ext_name; |
75 if is_overloaded (Context.proof_of context) oconst then context |
70 in map_tables (Symtab.update (ext_name, [])) I thy end; |
76 else map_tables (Symtab.update (oconst, [])) I context; |
71 |
77 |
72 fun add_variant ext_name name thy = |
78 fun generic_remove_overloaded oconst context = |
73 let |
79 let |
74 val _ = is_overloaded thy ext_name orelse not_overloaded_err ext_name; |
80 fun remove_oconst_and_variants context oconst = |
75 val _ = |
81 let |
76 (case get_external thy name of |
82 val remove_variants = |
77 NONE => () |
83 (case get_variants (Context.proof_of context) oconst of |
78 | SOME gen' => duplicate_variant_err name gen'); |
84 NONE => I |
79 val T = Sign.the_const_type thy name; |
85 | SOME vs => fold (Termtab.remove (op =) o rpair oconst o fst) vs); |
80 in |
86 in map_tables (Symtab.delete_safe oconst) remove_variants context end; |
81 map_tables (Symtab.cons_list (ext_name, (name, T))) |
87 in |
82 (Symtab.update (name, ext_name)) thy |
88 if is_overloaded (Context.proof_of context) oconst then remove_oconst_and_variants context oconst |
83 end |
89 else not_overloaded_error oconst |
84 |
90 end; |
|
91 |
|
92 local |
|
93 fun generic_variant add oconst t context = |
|
94 let |
|
95 val ctxt = Context.proof_of context; |
|
96 val _ = if is_overloaded ctxt oconst then () else not_overloaded_error oconst; |
|
97 val T = t |> singleton (Variable.polymorphic ctxt) |> fastype_of; |
|
98 val t' = Term.map_types (K dummyT) t; |
|
99 in |
|
100 if add then |
|
101 let |
|
102 val _ = |
|
103 (case get_overloaded ctxt t' of |
|
104 NONE => () |
|
105 | SOME oconst' => duplicate_variant_error oconst'); |
|
106 in |
|
107 map_tables (Symtab.cons_list (oconst, (t', T))) (Termtab.update (t', oconst)) context |
|
108 end |
|
109 else |
|
110 let |
|
111 val _ = |
|
112 if member variants_eq (the (get_variants ctxt oconst)) (t', T) then () |
|
113 else not_a_variant_error oconst; |
|
114 in |
|
115 map_tables (Symtab.map_entry oconst (remove1 variants_eq (t', T))) |
|
116 (Termtab.delete_safe t') context |
|
117 |> (fn context => |
|
118 (case get_variants (Context.proof_of context) oconst of |
|
119 SOME [] => generic_remove_overloaded oconst context |
|
120 | _ => context)) |
|
121 end |
|
122 end; |
|
123 in |
|
124 val generic_add_variant = generic_variant true; |
|
125 val generic_remove_variant = generic_variant false; |
|
126 end; |
85 |
127 |
86 (* check / uncheck *) |
128 (* check / uncheck *) |
87 |
129 |
88 fun unifiable_with ctxt T1 (c, T2) = |
130 fun unifiable_with thy T1 (t, T2) = |
89 let |
131 let |
90 val thy = Proof_Context.theory_of ctxt; |
|
91 val maxidx1 = Term.maxidx_of_typ T1; |
132 val maxidx1 = Term.maxidx_of_typ T1; |
92 val T2' = Logic.incr_tvar (maxidx1 + 1) T2; |
133 val T2' = Logic.incr_tvar (maxidx1 + 1) T2; |
93 val maxidx2 = Int.max (maxidx1, Term.maxidx_of_typ T2'); |
134 val maxidx2 = Term.maxidx_typ T2' maxidx1; |
94 in |
135 in |
95 (Sign.typ_unify thy (T1, T2') (Vartab.empty, maxidx2); SOME c) |
136 (Sign.typ_unify thy (T1, T2') (Vartab.empty, maxidx2); SOME t) |
96 handle Type.TUNIFY => NONE |
137 handle Type.TUNIFY => NONE |
97 end; |
138 end; |
98 |
139 |
99 fun insert_internal_same ctxt t (Const (c, T)) = |
140 fun insert_variants_same ctxt t (Const (c, T)) = |
100 (case map_filter (unifiable_with ctxt T) |
141 (case map_filter (unifiable_with (Proof_Context.theory_of ctxt) T) |
101 (Same.function (get_variants (Proof_Context.theory_of ctxt)) c) of |
142 (Same.function (get_variants ctxt) c) of |
102 [] => unresolved_err ctxt (c, T) t "no instances" |
143 [] => unresolved_overloading_error ctxt (c, T) t "no instances" |
103 | [c'] => Const (c', dummyT) |
144 | [variant] => variant |
104 | _ => raise Same.SAME) |
145 | _ => raise Same.SAME) |
105 | insert_internal_same _ _ _ = raise Same.SAME; |
146 | insert_variants_same _ _ _ = raise Same.SAME; |
106 |
147 |
107 fun insert_external_same ctxt _ (Const (c, T)) = |
148 fun insert_overloaded_same ctxt variant = |
108 Const (Same.function (get_external (Proof_Context.theory_of ctxt)) c, T) |
149 let |
109 | insert_external_same _ _ _ = raise Same.SAME; |
150 val thy = Proof_Context.theory_of ctxt; |
|
151 val t = Pattern.rewrite_term thy [] [fn t => |
|
152 Term.map_types (K dummyT) t |
|
153 |> get_overloaded ctxt |
|
154 |> Option.map (Const o rpair (fastype_of variant))] variant; |
|
155 in |
|
156 if Term.aconv_untyped (variant, t) then raise Same.SAME |
|
157 else t |
|
158 end; |
110 |
159 |
111 fun gen_check_uncheck replace ts ctxt = |
160 fun gen_check_uncheck replace ts ctxt = |
112 Same.capture (Same.map (fn t => Term_Subst.map_aterms_same (replace ctxt t) t)) ts |
161 Same.capture (Same.map replace) ts |
113 |> Option.map (rpair ctxt); |
162 |> Option.map (rpair ctxt); |
114 |
163 |
115 val check = gen_check_uncheck insert_internal_same; |
164 fun check ts ctxt = gen_check_uncheck (fn t => |
|
165 Term_Subst.map_aterms_same (insert_variants_same ctxt t) t) ts ctxt; |
116 |
166 |
117 fun uncheck ts ctxt = |
167 fun uncheck ts ctxt = |
118 if Config.get ctxt show_variants then NONE |
168 if Config.get ctxt show_variants then NONE |
119 else gen_check_uncheck insert_external_same ts ctxt; |
169 else gen_check_uncheck (insert_overloaded_same ctxt) ts ctxt; |
120 |
170 |
121 fun reject_unresolved ts ctxt = |
171 fun reject_unresolved ts ctxt = |
122 let |
172 let |
123 val thy = Proof_Context.theory_of ctxt; |
|
124 fun check_unresolved t = |
173 fun check_unresolved t = |
125 (case filter (is_overloaded thy o fst) (Term.add_consts t []) of |
174 (case filter (is_overloaded ctxt o fst) (Term.add_consts t []) of |
126 [] => () |
175 [] => () |
127 | ((c, T) :: _) => unresolved_err ctxt (c, T) t "multiple instances"); |
176 | ((c, T) :: _) => unresolved_overloading_error ctxt (c, T) t "multiple instances"); |
128 val _ = map check_unresolved ts; |
177 val _ = map check_unresolved ts; |
129 in NONE end; |
178 in NONE end; |
130 |
179 |
131 |
|
132 (* setup *) |
180 (* setup *) |
133 |
181 |
134 val setup = Context.theory_map |
182 val _ = Context.>> |
135 (Syntax_Phases.term_check' 0 "adhoc_overloading" check |
183 (Syntax_Phases.term_check' 0 "adhoc_overloading" check |
136 #> Syntax_Phases.term_check' 1 "adhoc_overloading_unresolved_check" reject_unresolved |
184 #> Syntax_Phases.term_check' 1 "adhoc_overloading_unresolved_check" reject_unresolved |
137 #> Syntax_Phases.term_uncheck' 0 "adhoc_overloading" uncheck); |
185 #> Syntax_Phases.term_uncheck' 0 "adhoc_overloading" uncheck); |
138 |
186 |
|
187 (* commands *) |
|
188 |
|
189 fun generic_adhoc_overloading_cmd add = |
|
190 if add then |
|
191 fold (fn (oconst, ts) => |
|
192 generic_add_overloaded oconst |
|
193 #> fold (generic_add_variant oconst) ts) |
|
194 else |
|
195 fold (fn (oconst, ts) => |
|
196 fold (generic_remove_variant oconst) ts); |
|
197 |
|
198 fun adhoc_overloading_cmd' add args phi = |
|
199 let val args' = args |
|
200 |> map (apsnd (map_filter (fn t => |
|
201 let val t' = Morphism.term phi t; |
|
202 in if Term.aconv_untyped (t, t') then SOME t' else NONE end))); |
|
203 in generic_adhoc_overloading_cmd add args' end; |
|
204 |
|
205 fun adhoc_overloading_cmd add raw_args lthy = |
|
206 let |
|
207 fun const_name ctxt = fst o dest_Const o Proof_Context.read_const ctxt false dummyT; |
|
208 val args = |
|
209 raw_args |
|
210 |> map (apfst (const_name lthy)) |
|
211 |> map (apsnd (map (Syntax.read_term lthy))); |
|
212 in |
|
213 Local_Theory.declaration {syntax = true, pervasive = false} |
|
214 (adhoc_overloading_cmd' add args) lthy |
|
215 end; |
|
216 |
|
217 val _ = |
|
218 Outer_Syntax.local_theory @{command_spec "adhoc_overloading"} |
|
219 "add ad-hoc overloading for constants / fixed variables" |
|
220 (Parse.and_list1 (Parse.const -- Scan.repeat Parse.term) >> adhoc_overloading_cmd true); |
|
221 |
|
222 val _ = |
|
223 Outer_Syntax.local_theory @{command_spec "no_adhoc_overloading"} |
|
224 "add ad-hoc overloading for constants / fixed variables" |
|
225 (Parse.and_list1 (Parse.const -- Scan.repeat Parse.term) >> adhoc_overloading_cmd false); |
|
226 |
139 end; |
227 end; |
|
228 |