9 type 'a converter = { |
9 type 'a converter = { |
10 constc : 'a wctxt -> term -> 'a, |
10 constc : 'a wctxt -> term -> 'a, |
11 funcc : 'a wctxt -> term -> term list -> 'a, |
11 funcc : 'a wctxt -> term -> term list -> 'a, |
12 ifc : 'a wctxt -> typ -> term -> term -> term -> 'a, |
12 ifc : 'a wctxt -> typ -> term -> term -> term -> 'a, |
13 casec : 'a wctxt -> term -> term list -> 'a, |
13 casec : 'a wctxt -> term -> term list -> 'a, |
14 letc : 'a wctxt -> typ -> term -> string list -> typ list -> term -> 'a |
14 letc : 'a wctxt -> typ -> term -> (string * typ) list -> term -> 'a |
15 } |
15 } |
16 val walk : local_theory -> term list -> 'a converter -> term -> 'a |
16 val walk : local_theory -> term list -> 'a converter -> term -> 'a |
|
17 val Iconst : term wctxt -> term -> term |
|
18 val Ifunc : term wctxt -> term -> term list -> term |
|
19 val Iif : term wctxt -> typ -> term -> term -> term -> term |
|
20 val Icase : term wctxt -> term -> term list -> term |
|
21 val Ilet : term wctxt -> typ -> term -> (string * typ) list -> term -> term |
17 |
22 |
18 type pfunc = { names : string list, terms : term list, typs : typ list } |
23 type pfunc = { names : string list, terms : term list, typs : typ list } |
19 val fun_pretty': Proof.context -> pfunc -> Pretty.T |
24 val fun_pretty': Proof.context -> pfunc -> Pretty.T |
20 val fun_pretty: Proof.context -> Function.info -> Pretty.T |
25 val fun_pretty: Proof.context -> Function.info -> Pretty.T |
21 val print_timing': Proof.context -> pfunc -> pfunc -> unit |
26 val print_timing': Proof.context -> pfunc -> pfunc -> unit |
22 val print_timing: Proof.context -> Function.info -> Function.info -> unit |
27 val print_timing: Proof.context -> Function.info -> Function.info -> unit |
23 |
28 |
24 val reg_and_proove_time_func: local_theory -> term list -> term list |
29 val reg_and_proove_time_func: local_theory -> term list -> term list |
25 -> bool -> Function.info * local_theory |
30 -> bool -> bool -> Function.info * local_theory |
26 val reg_time_func: local_theory -> term list -> term list |
31 val reg_time_func: local_theory -> term list -> term list |
27 -> bool -> Function.info * local_theory |
32 -> bool -> bool -> Function.info * local_theory |
28 |
33 |
29 val time_dom_tac: Proof.context -> thm -> thm list -> int -> tactic |
34 val time_dom_tac: Proof.context -> thm -> thm list -> int -> tactic |
30 |
35 |
31 end |
36 end |
32 |
37 |
33 structure Timing_Functions : TIMING_FUNCTIONS = |
38 structure Timing_Functions : TIMING_FUNCTIONS = |
34 struct |
39 struct |
35 (* Configure config variable to adjust the prefix *) |
40 (* Configure config variable to adjust the prefix *) |
36 val bprefix = Attrib.setup_config_string @{binding "time_prefix"} (K "T_") |
41 val bprefix = Attrib.setup_config_string @{binding "time_prefix"} (K "T_") |
|
42 val bprefix_snd = Attrib.setup_config_string @{binding "time_prefix_snd"} (K "T2_") |
37 (* Configure config variable to adjust the suffix *) |
43 (* Configure config variable to adjust the suffix *) |
38 val bsuffix = Attrib.setup_config_string @{binding "time_suffix"} (K "") |
44 val bsuffix = Attrib.setup_config_string @{binding "time_suffix"} (K "") |
39 |
45 |
40 (* some default values to build terms easier *) |
46 (* some default values to build terms easier *) |
41 val zero = Const (@{const_name "Groups.zero"}, HOLogic.natT) |
47 val zero = Const (@{const_name "Groups.zero"}, HOLogic.natT) |
109 and replace all function arguments f with (t*T_f) if used *) |
136 and replace all function arguments f with (t*T_f) if used *) |
110 fun change_typ' used (Type ("fun", [T1, T2])) = |
137 fun change_typ' used (Type ("fun", [T1, T2])) = |
111 Type ("fun", [check_for_fun' (used 0) T1, change_typ' (fn i => used (i+1)) T2]) |
138 Type ("fun", [check_for_fun' (used 0) T1, change_typ' (fn i => used (i+1)) T2]) |
112 | change_typ' _ _ = HOLogic.natT |
139 | change_typ' _ _ = HOLogic.natT |
113 and check_for_fun' true (f as Type ("fun", [_,_])) = HOLogic.mk_prodT (f, change_typ' (K false) f) |
140 and check_for_fun' true (f as Type ("fun", [_,_])) = HOLogic.mk_prodT (f, change_typ' (K false) f) |
114 | check_for_fun' false (f as Type ("fun", [_,_])) = change_typ' (K false) f |
141 | check_for_fun' false (f as Type ("fun", [_,_])) = change_typ' (K true) f |
115 | check_for_fun' _ t = t |
142 | check_for_fun' _ t = t |
116 val change_typ = change_typ' (K false) |
143 val change_typ = change_typ' (K true) |
117 (* Convert string name of function to its timing equivalent *) |
144 (* Convert string name of function to its timing equivalent *) |
118 fun fun_name_to_time ctxt s name = |
145 fun fun_name_to_time' ctxt s second name = |
119 let |
146 let |
120 val prefix = Config.get ctxt bprefix |
147 val prefix = Config.get ctxt (if second then bprefix_snd else bprefix) |
121 val suffix = (if s then Config.get ctxt bsuffix else "") |
148 val suffix = (if s then Config.get ctxt bsuffix else "") |
122 fun replace_last_name [n] = [prefix ^ n ^ suffix] |
149 fun replace_last_name [n] = [prefix ^ n ^ suffix] |
123 | replace_last_name (n::ns) = n :: (replace_last_name ns) |
150 | replace_last_name (n::ns) = n :: (replace_last_name ns) |
124 | replace_last_name _ = error "Internal error: Invalid function name to convert" |
151 | replace_last_name _ = error "Internal error: Invalid function name to convert" |
125 val parts = split_name name |
152 val parts = split_name name |
126 in |
153 in |
127 String.concatWith "." (replace_last_name parts) |
154 String.concatWith "." (replace_last_name parts) |
128 end |
155 end |
|
156 fun fun_name_to_time ctxt s name = fun_name_to_time' ctxt s false name |
129 (* Count number of arguments of a function *) |
157 (* Count number of arguments of a function *) |
130 fun count_args (Type (n, [_,res])) = (if n = "fun" then 1 + count_args res else 0) |
158 fun count_args (Type (n, [_,res])) = (if n = "fun" then 1 + count_args res else 0) |
131 | count_args _ = 0 |
159 | count_args _ = 0 |
132 (* Check if number of arguments matches function *) |
160 (* Check if number of arguments matches function *) |
133 val _ = dest_Const |
|
134 fun check_args s (t, args) = |
161 fun check_args s (t, args) = |
135 (if length args = count_args (type_of t) then () |
162 (if length args = count_args (type_of t) then () |
136 else error ("Partial applications/Lambdas not allowed (" ^ s ^ ")")) |
163 else error ("Partial applications/Lambdas not allowed (" ^ s ^ ")")) |
137 (* Removes Abs *) |
164 (* Removes Abs *) |
138 fun rem_abs f (Abs (_,_,t)) = rem_abs f t |
165 fun rem_abs f (Abs (_,_,t)) = rem_abs f t |
189 type 'a converter = { |
216 type 'a converter = { |
190 constc : 'a wctxt -> term -> 'a, |
217 constc : 'a wctxt -> term -> 'a, |
191 funcc : 'a wctxt -> term -> term list -> 'a, |
218 funcc : 'a wctxt -> term -> term list -> 'a, |
192 ifc : 'a wctxt -> typ -> term -> term -> term -> 'a, |
219 ifc : 'a wctxt -> typ -> term -> term -> term -> 'a, |
193 casec : 'a wctxt -> term -> term list -> 'a, |
220 casec : 'a wctxt -> term -> term list -> 'a, |
194 letc : 'a wctxt -> typ -> term -> string list -> typ list -> term -> 'a |
221 letc : 'a wctxt -> typ -> term -> (string * typ) list -> term -> 'a |
195 } |
222 } |
196 |
223 |
197 (* Walks over term and calls given converter *) |
224 (* Walks over term and calls given converter *) |
198 fun walk_func (t1 $ t2) ts = walk_func t1 (t2::ts) |
225 (* get rid and use Term.strip_abs.eta especially for lambdas *) |
199 | walk_func t ts = (t, ts) |
226 fun build_abs t ((nm,T)::abs) = build_abs (Abs (nm,T,t)) abs |
200 fun walk_func' t = walk_func t [] |
227 | build_abs t [] = t |
201 fun build_func (f, []) = f |
|
202 | build_func (f, (t::ts)) = build_func (f$t, ts) |
|
203 fun walk_abs (Abs (nm,T,t)) nms Ts = walk_abs t (nm::nms) (T::Ts) |
|
204 | walk_abs t nms Ts = (t, nms, Ts) |
|
205 fun build_abs t (nm::nms) (T::Ts) = build_abs (Abs (nm,T,t)) nms Ts |
|
206 | build_abs t [] [] = t |
|
207 | build_abs _ _ _ = error "Internal error: Invalid terms to build abs" |
|
208 fun walk ctxt (origin: term list) (conv as {ifc, casec, funcc, letc, ...} : 'a converter) (t as _ $ _) = |
228 fun walk ctxt (origin: term list) (conv as {ifc, casec, funcc, letc, ...} : 'a converter) (t as _ $ _) = |
209 let |
229 let |
210 val (f, args) = walk_func t [] |
230 val (f, args) = strip_comb t |
211 val this = (walk ctxt origin conv) |
231 val this = (walk ctxt origin conv) |
212 val _ = (case f of Abs _ => error "Lambdas not supported" | _ => ()) |
232 val _ = (case f of Abs _ => error "Lambdas not supported" | _ => ()) |
213 val wctxt = {ctxt = ctxt, origins = origin, f = this} |
233 val wctxt = {ctxt = ctxt, origins = origin, f = this} |
214 in |
234 in |
215 (if is_if f then |
235 (if is_if f then |
218 | _ => error "Partial applications not supported (if)") |
238 | _ => error "Partial applications not supported (if)") |
219 | _ => error "Internal error: invalid if term") |
239 | _ => error "Internal error: invalid if term") |
220 else if is_case f then casec wctxt f args |
240 else if is_case f then casec wctxt f args |
221 else if is_let f then |
241 else if is_let f then |
222 (case f of (Const (_,lT)) => |
242 (case f of (Const (_,lT)) => |
223 (case args of [exp, t] => |
243 (case args of [exp, t] => |
224 let val (t,nms,Ts) = walk_abs t [] [] in letc wctxt lT exp nms Ts t end |
244 let val (abs,t) = strip_abs t in letc wctxt lT exp abs t end |
225 | _ => error "Partial applications not allowed (let)") |
245 | _ => error "Partial applications not allowed (let)") |
226 | _ => error "Internal error: invalid let term") |
246 | _ => error "Internal error: invalid let term") |
227 else funcc wctxt f args) |
247 else funcc wctxt f args) |
228 end |
248 end |
229 | walk ctxt origin (conv as {constc, ...}) c = |
249 | walk ctxt origin (conv as {constc, ...}) c = |
230 constc {ctxt = ctxt, origins = origin, f = walk ctxt origin conv} c |
250 constc {ctxt = ctxt, origins = origin, f = walk ctxt origin conv} c |
|
251 fun Ifunc (wctxt: term wctxt) t args = list_comb (#f wctxt t,map (#f wctxt) args) |
|
252 val Iconst = K I |
|
253 fun Iif (wctxt: term wctxt) T cond tt tf = |
|
254 Const (@{const_name "HOL.If"}, T) $ (#f wctxt cond) $ (#f wctxt tt) $ (#f wctxt tf) |
|
255 fun Icase (wctxt: term wctxt) t cs = list_comb (#f wctxt t,map (#f wctxt) cs) |
|
256 fun Ilet (wctxt: term wctxt) lT exp abs t = |
|
257 Const (@{const_name "HOL.Let"},lT) $ (#f wctxt exp) $ build_abs (#f wctxt t) abs |
231 |
258 |
232 (* 1. Fix all terms *) |
259 (* 1. Fix all terms *) |
233 (* Exchange Var in types and terms to Free *) |
260 (* Exchange Var in types and terms to Free *) |
234 fun fixTerms (Var(ixn,T)) = Free (fst ixn, T) |
261 fun freeTerms (Var(ixn,T)) = Free (fst ixn, T) |
235 | fixTerms t = t |
262 | freeTerms t = t |
236 fun fixTypes (TVar ((t, _), T)) = TFree (t, T) |
263 fun freeTypes (TVar ((t, _), T)) = TFree (t, T) |
237 | fixTypes t = t |
264 | freeTypes t = t |
238 |
265 |
239 fun noFun (Type ("fun", _)) = error "Functions in datatypes are not allowed in case constructions" |
266 fun noFun (Type ("fun", _)) = error "Functions in datatypes are not allowed in case constructions" |
240 | noFun T = T |
267 | noFun T = T |
241 fun casecBuildBounds n t = if n > 0 then casecBuildBounds (n-1) (t $ (Bound (n-1))) else t |
268 fun casecBuildBounds n t = if n > 0 then casecBuildBounds (n-1) (t $ (Bound (n-1))) else t |
242 fun casecAbs wctxt n (Type ("fun",[T,Tr])) (Abs (v,Ta,t)) = (map_atyps noFun T; Abs (v,Ta,casecAbs wctxt n Tr t)) |
269 fun casecAbs wctxt n (Type ("fun",[T,Tr])) (Abs (v,Ta,t)) = (map_atyps noFun T; Abs (v,Ta,casecAbs wctxt n Tr t)) |
245 | casecAbs wctxt n _ t = (#f wctxt) (casecBuildBounds n (Term.incr_bv n 0 t)) |
272 | casecAbs wctxt n _ t = (#f wctxt) (casecBuildBounds n (Term.incr_bv n 0 t)) |
246 fun fixCasecCases _ _ [t] = [t] |
273 fun fixCasecCases _ _ [t] = [t] |
247 | fixCasecCases wctxt (Type (_,[T,Tr])) (t::ts) = casecAbs wctxt 0 T t :: fixCasecCases wctxt Tr ts |
274 | fixCasecCases wctxt (Type (_,[T,Tr])) (t::ts) = casecAbs wctxt 0 T t :: fixCasecCases wctxt Tr ts |
248 | fixCasecCases _ _ _ = error "Internal error: invalid case types/terms" |
275 | fixCasecCases _ _ _ = error "Internal error: invalid case types/terms" |
249 fun fixCasec wctxt (t as Const (_,T)) args = |
276 fun fixCasec wctxt (t as Const (_,T)) args = |
250 (check_args "cases" (t,args); build_func (t,fixCasecCases wctxt T args)) |
277 (check_args "cases" (t,args); list_comb (t,fixCasecCases wctxt T args)) |
251 | fixCasec _ _ _ = error "Internal error: invalid case term" |
278 | fixCasec _ _ _ = error "Internal error: invalid case term" |
252 |
279 |
253 fun fixPartTerms ctxt (term: term list) t = |
280 fun shortFunc fixedNum (Const (nm,T)) = |
|
281 Const (nm,T |> strip_type |>> drop fixedNum |> (op --->)) |
|
282 | shortFunc _ _ = error "Internal error: Invalid term" |
|
283 fun shortApp fixedNum (c, args) = |
|
284 (shortFunc fixedNum c, drop fixedNum args) |
|
285 fun shortOriginFunc (term: term list) fixedNum (f as (c as Const (_,_), _)) = |
|
286 if contains' const_comp term c then shortApp fixedNum f else f |
|
287 | shortOriginFunc _ _ t = t |
|
288 fun fixTerms ctxt (term: term list) (fixedNum: int) (t as pT $ (eq $ l $ r)) = |
254 let |
289 let |
255 val _ = check_args "args" (walk_func (get_l t) []) |
290 val _ = check_args "args" (strip_comb (get_l t)) |
256 in |
291 val l' = shortApp fixedNum (strip_comb l) |> list_comb |
257 map_r (walk ctxt term { |
292 val shortOriginFunc' = shortOriginFunc (term |> map (fst o strip_comb)) fixedNum |
|
293 val r' = walk ctxt term { |
258 funcc = (fn wctxt => fn t => fn args => |
294 funcc = (fn wctxt => fn t => fn args => |
259 (check_args "func" (t,args); build_func (t, map (#f wctxt) args))), |
295 (check_args "func" (t,args); (t, map (#f wctxt) args) |> shortOriginFunc' |> list_comb)), |
260 constc = (fn _ => fn c => (case c of Abs _ => error "Lambdas not supported" | _ => c)), |
296 constc = (fn _ => fn c => (case c of Abs _ => error "Lambdas not supported" | _ => c)), |
261 ifc = (fn wctxt => fn T => fn cond => fn tt => fn tf => |
297 ifc = Iif, |
262 ((Const (@{const_name "HOL.If"}, T)) $ (#f wctxt) cond $ ((#f wctxt) tt) $ ((#f wctxt) tf))), |
|
263 casec = fixCasec, |
298 casec = fixCasec, |
264 letc = (fn wctxt => fn expT => fn exp => fn nms => fn Ts => fn t => |
299 letc = (fn wctxt => fn expT => fn exp => fn abs => fn t => |
265 let |
300 let |
266 val f' = if length nms = 0 then |
301 val f' = if length abs = 0 then |
267 (case (#f wctxt) (t$exp) of t$_ => t | _ => error "Internal error: case could not be fixed (let)") |
302 (case (#f wctxt) (t$exp) of t$_ => t | _ => error "Internal error: case could not be fixed (let)") |
268 else (#f wctxt) t |
303 else (#f wctxt) t |
269 in (Const (@{const_name "HOL.Let"},expT) $ ((#f wctxt) exp) $ build_abs f' nms Ts) end) |
304 in (Const (@{const_name "HOL.Let"},expT) $ ((#f wctxt) exp) $ build_abs f' abs) end) |
270 }) t |
305 } r |
|
306 in |
|
307 pT $ (eq $ l' $ r') |
271 end |
308 end |
|
309 | fixTerms _ _ _ _ = error "Internal error: invalid term" |
272 |
310 |
273 (* 2. Check for properties about the function *) |
311 (* 2. Check for properties about the function *) |
274 (* 2.1 Check if function is recursive *) |
312 (* 2.1 Check if function is recursive *) |
275 fun or f (a,b) = f a orelse b |
313 fun or f (a,b) = f a orelse b |
276 fun find_rec ctxt term = (walk ctxt term { |
314 fun find_rec ctxt term = (walk ctxt term { |
280 constc = (K o K) false, |
318 constc = (K o K) false, |
281 ifc = (fn wctxt => fn _ => fn cond => fn tt => fn tf => |
319 ifc = (fn wctxt => fn _ => fn cond => fn tt => fn tf => |
282 (#f wctxt) cond orelse (#f wctxt) tt orelse (#f wctxt) tf), |
320 (#f wctxt) cond orelse (#f wctxt) tt orelse (#f wctxt) tf), |
283 casec = (fn wctxt => fn t => fn cs => |
321 casec = (fn wctxt => fn t => fn cs => |
284 (#f wctxt) t orelse List.foldr (or (rem_abs (#f wctxt))) false cs), |
322 (#f wctxt) t orelse List.foldr (or (rem_abs (#f wctxt))) false cs), |
285 letc = (fn wctxt => fn _ => fn exp => fn _ => fn _ => fn t => |
323 letc = (fn wctxt => fn _ => fn exp => fn _ => fn t => |
286 (#f wctxt) exp orelse (#f wctxt) t) |
324 (#f wctxt) exp orelse (#f wctxt) t) |
287 }) o get_r |
325 }) o get_r |
288 fun is_rec ctxt (term: term list) = List.foldr (or (find_rec ctxt term)) false |
326 fun is_rec ctxt (term: term list) = List.foldr (or (find_rec ctxt term)) false |
289 |
|
290 (* 2.2 Check for higher-order function if original function is used *) |
|
291 fun find_used' ctxt term t T_t = |
|
292 let |
|
293 val (ident, _) = walk_func (get_l t) [] |
|
294 val (T_ident, T_args) = walk_func (get_l T_t) [] |
|
295 |
|
296 fun filter_passed [] = [] |
|
297 | filter_passed ((f as Free (_, Type ("Product_Type.prod",[Type ("fun",_), Type ("fun", _)])))::args) = |
|
298 f :: filter_passed args |
|
299 | filter_passed (_::args) = filter_passed args |
|
300 val frees' = (walk ctxt term { |
|
301 funcc = (fn wctxt => fn t => fn args => |
|
302 (case t of (Const ("Product_Type.prod.snd", _)) => [] |
|
303 | _ => (if t = T_ident then [] else filter_passed args) |
|
304 @ List.foldr (fn (l,r) => (#f wctxt) l @ r) [] args)), |
|
305 constc = (K o K) [], |
|
306 ifc = (fn wctxt => fn _ => fn cond => fn tt => fn tf => (#f wctxt) cond @ (#f wctxt) tt @ (#f wctxt) tf), |
|
307 casec = (fn wctxt => fn _ => fn cs => List.foldr (fn (l,r) => (#f wctxt) l @ r) [] cs), |
|
308 letc = (fn wctxt => fn _ => fn exp => fn _ => fn _ => fn t => (#f wctxt) exp @ (#f wctxt) t) |
|
309 }) (get_r T_t) |
|
310 fun build _ [] _ = false |
|
311 | build i (a::args) item = |
|
312 (if item = (ident,i) then contains frees' a else build (i+1) args item) |
|
313 in |
|
314 build 0 T_args |
|
315 end |
|
316 fun find_used ctxt term terms T_terms = |
|
317 ListPair.zip (terms, T_terms) |
|
318 |> List.map (fn (t, T_t) => find_used' ctxt term t T_t) |
|
319 |> List.foldr (fn (f,g) => fn item => f item orelse g item) (K false) |
|
320 |
|
321 |
327 |
322 (* 3. Convert equations *) |
328 (* 3. Convert equations *) |
323 (* Some Helper *) |
329 (* Some Helper *) |
324 val plusTyp = @{typ "nat => nat => nat"} |
330 val plusTyp = @{typ "nat => nat => nat"} |
325 fun plus (SOME a) (SOME b) = SOME (Const (@{const_name "Groups.plus"}, plusTyp) $ a $ b) |
331 fun plus (SOME a) (SOME b) = SOME (Const (@{const_name "Groups.plus"}, plusTyp) $ a $ b) |
330 | opt_term (SOME t) = t |
336 | opt_term (SOME t) = t |
331 fun use_origin (Free (nm, T as Type ("fun",_))) = HOLogic.mk_fst (Free (nm,HOLogic.mk_prodT (T, change_typ T))) |
337 fun use_origin (Free (nm, T as Type ("fun",_))) = HOLogic.mk_fst (Free (nm,HOLogic.mk_prodT (T, change_typ T))) |
332 | use_origin t = t |
338 | use_origin t = t |
333 |
339 |
334 (* Conversion of function term *) |
340 (* Conversion of function term *) |
335 fun fun_to_time ctxt orig_used _ (origin: term list) (func as Const (nm,T)) = |
341 fun fun_to_time' ctxt (origin: term list) second (func as Const (nm,T)) = |
336 let |
342 let |
337 val used' = used_for_const orig_used func |
343 val origin' = map (fst o strip_comb) origin |
338 in |
344 in |
339 if contains' const_comp origin func then SOME (Free (func |> Term.term_name |> fun_name_to_time ctxt true, change_typ' used' T)) else |
345 if contains' const_comp origin' func then SOME (Free (func |> Term.term_name |> fun_name_to_time' ctxt true second, change_typ T)) else |
340 if Zero_Funcs.is_zero (Proof_Context.theory_of ctxt) (nm,T) then NONE else |
346 if Zero_Funcs.is_zero (Proof_Context.theory_of ctxt) (nm,T) then NONE else |
341 time_term ctxt false func |
347 time_term ctxt false func |
342 end |
348 end |
343 | fun_to_time ctxt _ used _ (f as Free (nm,T)) = SOME ( |
349 | fun_to_time' _ _ _ (Free (nm,T)) = |
344 if used f then HOLogic.mk_snd (Free (nm,HOLogic.mk_prodT (T,change_typ T))) |
350 SOME (HOLogic.mk_snd (Free (nm,HOLogic.mk_prodT (T,change_typ T)))) |
345 else Free (fun_name_to_time ctxt false nm, change_typ T) |
351 | fun_to_time' _ _ _ _ = error "Internal error: invalid function to convert" |
346 ) |
352 fun fun_to_time context origin func = fun_to_time' context origin false func |
347 | fun_to_time _ _ _ _ _ = error "Internal error: invalid function to convert" |
|
348 |
353 |
349 (* Convert arguments of left side of a term *) |
354 (* Convert arguments of left side of a term *) |
350 fun conv_arg ctxt used _ (f as Free (nm,T as Type("fun",_))) = |
355 fun conv_arg _ (Free (nm,T as Type("fun",_))) = |
351 if used f then Free (nm, HOLogic.mk_prodT (T, change_typ' (K false) T)) |
356 Free (nm, HOLogic.mk_prodT (T, change_typ' (K false) T)) |
352 else Free (fun_name_to_time ctxt false nm, change_typ' (K false) T) |
357 | conv_arg _ x = x |
353 | conv_arg _ _ _ x = x |
358 fun conv_args ctxt = map (conv_arg ctxt) |
354 fun conv_args ctxt used origin = map (conv_arg ctxt used origin) |
|
355 |
359 |
356 (* Handle function calls *) |
360 (* Handle function calls *) |
357 fun build_zero (Type ("fun", [T, R])) = Abs ("uu", T, build_zero R) |
361 fun build_zero (Type ("fun", [T, R])) = Abs ("uu", T, build_zero R) |
358 | build_zero _ = zero |
362 | build_zero _ = zero |
359 fun funcc_use_origin used (f as Free (nm, T as Type ("fun",_))) = |
363 fun funcc_use_origin (Free (nm, T as Type ("fun",_))) = |
360 if used f then HOLogic.mk_fst (Free (nm,HOLogic.mk_prodT (T, change_typ T))) |
364 HOLogic.mk_fst (Free (nm,HOLogic.mk_prodT (T, change_typ T))) |
361 else error "Internal error: Error in used detection" |
365 | funcc_use_origin t = t |
362 | funcc_use_origin _ t = t |
366 fun funcc_conv_arg _ _ (t as (_ $ _)) = map_aterms funcc_use_origin t |
363 fun funcc_conv_arg _ used _ (t as (_ $ _)) = map_aterms (funcc_use_origin used) t |
367 | funcc_conv_arg _ u (Free (nm, T as Type ("fun",_))) = |
364 | funcc_conv_arg wctxt used u (f as Free (nm, T as Type ("fun",_))) = |
368 if u then Free (nm, HOLogic.mk_prodT (T, change_typ T)) |
365 if used f then |
369 else HOLogic.mk_snd (Free (nm,HOLogic.mk_prodT (T,change_typ T))) |
366 if u then Free (nm, HOLogic.mk_prodT (T, change_typ T)) |
370 | funcc_conv_arg wctxt true (f as Const (_,T as Type ("fun",_))) = |
367 else HOLogic.mk_snd (Free (nm,HOLogic.mk_prodT (T,change_typ T))) |
|
368 else Free (fun_name_to_time (#ctxt wctxt) false nm, change_typ T) |
|
369 | funcc_conv_arg wctxt _ true (f as Const (_,T as Type ("fun",_))) = |
|
370 (Const (@{const_name "Product_Type.Pair"}, |
371 (Const (@{const_name "Product_Type.Pair"}, |
371 Type ("fun", [T,Type ("fun", [change_typ T, HOLogic.mk_prodT (T,change_typ T)])])) |
372 Type ("fun", [T,Type ("fun", [change_typ T, HOLogic.mk_prodT (T,change_typ T)])])) |
372 $ f $ (Option.getOpt (fun_to_time (#ctxt wctxt) (K false) (K false) (#origins wctxt) f, build_zero T))) |
373 $ f $ (Option.getOpt (fun_to_time (#ctxt wctxt) (#origins wctxt) f, build_zero T))) |
373 | funcc_conv_arg wctxt _ false (f as Const (_,T as Type ("fun",_))) = |
374 | funcc_conv_arg wctxt false (f as Const (_,T as Type ("fun",_))) = |
374 Option.getOpt (fun_to_time (#ctxt wctxt) (K false) (K false) (#origins wctxt) f, build_zero T) |
375 Option.getOpt (fun_to_time (#ctxt wctxt) (#origins wctxt) f, build_zero T) |
375 | funcc_conv_arg _ _ _ t = t |
376 | funcc_conv_arg _ _ t = t |
376 |
377 |
377 fun funcc_conv_args _ _ _ [] = [] |
378 fun funcc_conv_args _ _ [] = [] |
378 | funcc_conv_args wctxt used (Type ("fun", [t, ts])) (a::args) = |
379 | funcc_conv_args wctxt (Type ("fun", [t, ts])) (a::args) = |
379 funcc_conv_arg wctxt used (is_Used t) a :: funcc_conv_args wctxt used ts args |
380 funcc_conv_arg wctxt (is_Used t) a :: funcc_conv_args wctxt ts args |
380 | funcc_conv_args _ _ _ _ = error "Internal error: Non matching type" |
381 | funcc_conv_args _ _ _ = error "Internal error: Non matching type" |
381 fun funcc orig_used used wctxt func args = |
382 fun funcc wctxt func args = |
382 let |
383 let |
383 fun get_T (Free (_,T)) = T |
384 fun get_T (Free (_,T)) = T |
384 | get_T (Const (_,T)) = T |
385 | get_T (Const (_,T)) = T |
385 | get_T (_ $ (Free (_,Type (_, [_, T])))) = T (* Case of snd was constructed *) |
386 | get_T (_ $ (Free (_,Type (_, [_, T])))) = T (* Case of snd was constructed *) |
386 | get_T _ = error "Internal error: Forgotten type" |
387 | get_T _ = error "Internal error: Forgotten type" |
387 in |
388 in |
388 List.foldr (I #-> plus) |
389 List.foldr (I #-> plus) |
389 (case fun_to_time (#ctxt wctxt) orig_used used (#origins wctxt) func |
390 (case fun_to_time (#ctxt wctxt) (#origins wctxt) func |
390 of SOME t => SOME (build_func (t,funcc_conv_args wctxt used (get_T t) args)) |
391 of SOME t => SOME (list_comb (t,funcc_conv_args wctxt (get_T t) args)) |
391 | NONE => NONE) |
392 | NONE => NONE) |
392 (map (#f wctxt) args) |
393 (map (#f wctxt) args) |
393 end |
394 end |
394 |
395 |
395 (* Handle case terms *) |
396 (* Handle case terms *) |
431 (SOME ((Const (@{const_name "HOL.If"}, @{typ "bool \<Rightarrow> nat \<Rightarrow> nat \<Rightarrow> nat"})) $ rcond $ (opt_term tt) $ (opt_term ft)))) |
432 (SOME ((Const (@{const_name "HOL.If"}, @{typ "bool \<Rightarrow> nat \<Rightarrow> nat \<Rightarrow> nat"})) $ rcond $ (opt_term tt) $ (opt_term ft)))) |
432 end |
433 end |
433 |
434 |
434 fun letc_change_typ (Type ("fun", [T1, Type ("fun", [T2, _])])) = (Type ("fun", [T1, Type ("fun", [change_typ T2, HOLogic.natT])])) |
435 fun letc_change_typ (Type ("fun", [T1, Type ("fun", [T2, _])])) = (Type ("fun", [T1, Type ("fun", [change_typ T2, HOLogic.natT])])) |
435 | letc_change_typ _ = error "Internal error: invalid let type" |
436 | letc_change_typ _ = error "Internal error: invalid let type" |
436 fun letc wctxt expT exp nms Ts t = |
437 fun letc wctxt expT exp abs t = |
437 plus (#f wctxt exp) |
438 plus (#f wctxt exp) |
438 (if length nms = 0 (* In case of "length nms = 0" the expression got reducted |
439 (if length abs = 0 (* In case of "length nms = 0" the expression got reducted |
439 Here we need Bound 0 to gain non-partial application *) |
440 Here we need Bound 0 to gain non-partial application *) |
440 then (case #f wctxt (t $ Bound 0) of SOME (t' $ Bound 0) => |
441 then (case #f wctxt (t $ Bound 0) of SOME (t' $ Bound 0) => |
441 (SOME (Const (@{const_name "HOL.Let"}, letc_change_typ expT) $ (map_aterms use_origin exp) $ t')) |
442 (SOME (Const (@{const_name "HOL.Let"}, letc_change_typ expT) $ (map_aterms use_origin exp) $ t')) |
442 (* Expression is not used and can therefore let be dropped *) |
443 (* Expression is not used and can therefore let be dropped *) |
443 | SOME t' => SOME t' |
444 | SOME t' => SOME t' |
444 | NONE => NONE) |
445 | NONE => NONE) |
445 else (case #f wctxt t of SOME t' => |
446 else (case #f wctxt t of SOME t' => |
446 SOME (if Term.is_dependent t' then Const (@{const_name "HOL.Let"}, letc_change_typ expT) $ (map_aterms use_origin exp) $ build_abs t' nms Ts |
447 SOME (if Term.is_dependent t' then Const (@{const_name "HOL.Let"}, letc_change_typ expT) $ (map_aterms use_origin exp) $ build_abs t' abs |
447 else Term.subst_bounds([exp],t')) |
448 else Term.subst_bounds([exp],t')) |
448 | NONE => NONE)) |
449 | NONE => NONE)) |
449 |
450 |
450 (* The converter for timing functions given to the walker *) |
451 (* The converter for timing functions given to the walker *) |
451 fun converter orig_used used : term option converter = { |
452 val converter : term option converter = { |
452 constc = fn _ => fn t => |
453 constc = fn _ => fn t => |
453 (case t of Const ("HOL.undefined", _) => SOME (Const ("HOL.undefined", @{typ "nat"})) |
454 (case t of Const ("HOL.undefined", _) => SOME (Const ("HOL.undefined", @{typ "nat"})) |
454 | _ => NONE), |
455 | _ => NONE), |
455 funcc = (funcc orig_used used), |
456 funcc = funcc, |
456 ifc = ifc, |
457 ifc = ifc, |
457 casec = casec, |
458 casec = casec, |
458 letc = letc |
459 letc = letc |
459 } |
460 } |
460 fun top_converter is_rec _ _ = opt_term o (fn exp => plus exp (if is_rec then SOME one else NONE)) |
461 fun top_converter is_rec _ _ = opt_term o (fn exp => plus exp (if is_rec then SOME one else NONE)) |
461 |
462 |
462 (* Use converter to convert right side of a term *) |
463 (* Use converter to convert right side of a term *) |
463 fun to_time ctxt origin is_rec orig_used used term = |
464 fun to_time ctxt origin is_rec term = |
464 top_converter is_rec ctxt origin (walk ctxt origin (converter orig_used used) term) |
465 top_converter is_rec ctxt origin (walk ctxt origin converter term) |
465 |
466 |
466 (* Converts a term to its running time version *) |
467 (* Converts a term to its running time version *) |
467 fun convert_term ctxt (origin: term list) is_rec orig_used (pT $ (Const (eqN, _) $ l $ r)) = |
468 fun convert_term ctxt (origin: term list) is_rec (pT $ (Const (eqN, _) $ l $ r)) = |
468 let |
469 let |
469 val (l' as (l_const, l_params)) = walk_func l [] |
470 val (l_const, l_params) = strip_comb l |
470 val used = |
471 in |
471 l_const |
472 pT |
472 |> used_for_const orig_used |
473 $ (Const (eqN, @{typ "nat \<Rightarrow> nat \<Rightarrow> bool"}) |
473 |> (fn f => fn n => f (index l_params n)) |
474 $ (list_comb (l_const |> fun_to_time ctxt origin |> Option.valOf, l_params |> conv_args ctxt)) |
474 in |
475 $ (to_time ctxt origin is_rec r)) |
475 pT |
476 end |
476 $ (Const (eqN, @{typ "nat \<Rightarrow> nat \<Rightarrow> bool"}) |
477 | convert_term _ _ _ _ = error "Internal error: invalid term to convert" |
477 $ (build_func (l' |>> (fun_to_time ctxt orig_used used origin) |>> Option.valOf ||> conv_args ctxt used origin)) |
478 |
478 $ (to_time ctxt origin is_rec orig_used used r)) |
479 (* 3.5 Support for locales *) |
479 end |
480 fun replaceFstSndFree ctxt (origin: term list) (rfst: term -> term) (rsnd: term -> term) = |
480 | convert_term _ _ _ _ _ = error "Internal error: invalid term to convert" |
481 (walk ctxt origin { |
|
482 funcc = fn wctxt => fn t => fn args => |
|
483 case args of |
|
484 (f as Free _)::args => |
|
485 (case t of |
|
486 Const ("Product_Type.prod.fst", _) => |
|
487 list_comb (rfst (t $ f), map (#f wctxt) args) |
|
488 | Const ("Product_Type.prod.snd", _) => |
|
489 list_comb (rsnd (t $ f), map (#f wctxt) args) |
|
490 | t => list_comb (t, map (#f wctxt) (f :: args))) |
|
491 | args => list_comb (t, map (#f wctxt) args), |
|
492 constc = Iconst, |
|
493 ifc = Iif, |
|
494 casec = Icase, |
|
495 letc = Ilet |
|
496 }) |
481 |
497 |
482 (* 4. Tactic to prove "f_dom n" *) |
498 (* 4. Tactic to prove "f_dom n" *) |
483 fun time_dom_tac ctxt induct_rule domintros = |
499 fun time_dom_tac ctxt induct_rule domintros = |
484 (Induction.induction_tac ctxt true [] [[]] [] (SOME [induct_rule]) [] |
500 (Induction.induction_tac ctxt true [] [[]] [] (SOME [induct_rule]) [] |
485 THEN_ALL_NEW ((K (auto_tac ctxt)) THEN' (fn i => FIRST' ( |
501 THEN_ALL_NEW ((K (auto_tac ctxt)) THEN' (fn i => FIRST' ( |
493 |> map #rules |
509 |> map #rules |
494 |> map (map Thm.prop_of) |
510 |> map (map Thm.prop_of) |
495 handle Empty => error "Function or terms of function not found" |
511 handle Empty => error "Function or terms of function not found" |
496 in |
512 in |
497 equations |
513 equations |
498 |> filter (fn ts => typ_comp (ts |> hd |> get_l |> walk_func' |> fst |> dest_Const |> snd) (term |> dest_Const |> snd)) |
514 |> filter (List.exists |
|
515 (fn t => typ_comp (t |> get_l |> strip_comb |> fst |> dest_Const |> snd) (term |> strip_comb |> fst |> dest_Const |> snd))) |
499 |> hd |
516 |> hd |
500 end |
517 end |
501 |
518 |
|
(* 5. Higher-order parameters: detect which pair-typed parameters are passed
      on as a whole, so that simplification opportunities can be found. *)
(* For one converted equation T_t, return every (T-function head, argument index)
   such that the left-hand-side argument at that index occurs among the
   (function * function)-pair Free variables that are handed on as arguments
   to calls on the right-hand side (calls of the head itself and arguments of
   prod.snd applications are not counted). *)
fun find_used' T_t =
  let
    val (T_ident, T_args) = strip_comb (get_l T_t)

    (* Free variable whose type is a product of two function types *)
    fun is_fun_pair (Free (_, Type ("Product_Type.prod", [Type ("fun", _), Type ("fun", _)]))) = true
      | is_fun_pair _ = false

    (* concatenate f over xs, evaluating f from the last element to the first *)
    fun collect f xs = fold_rev (fn x => fn acc => f x @ acc) xs []

    val frees = walk @{context} [] {
      funcc = fn wctxt => fn t => fn args =>
        (case t of
          Const ("Product_Type.prod.snd", _) => []
        | _ =>
            (if t = T_ident then [] else filter is_fun_pair args)
            @ collect (#f wctxt) args),
      constc = fn _ => fn _ => [],
      ifc = fn wctxt => fn _ => fn cond => fn tt => fn tf => maps (#f wctxt) [cond, tt, tf],
      casec = fn wctxt => fn _ => fn cs => collect (#f wctxt) cs,
      letc = fn wctxt => fn _ => fn exp => fn _ => fn t => maps (#f wctxt) [exp, t]
    } (get_r T_t)
  in
    T_args
    |> map_index I
    |> map_filter (fn (i, a) => if contains frees a then SOME (T_ident, i) else NONE)
  end
|
(* For every function in term, list (in ascending order) the indices of its
   function-typed parameters whose (time function, index) pair is not reported
   by find_used' for any of the equations in terms. *)
fun find_simplifyble ctxt term terms =
  let
    val used = maps find_used' terms
    val to_time = Option.valOf o fun_to_time ctxt term
    fun is_fun_type (Type ("fun", _)) = true
      | is_fun_type _ = false
    (* to_time is only consulted for function-typed parameters *)
    fun simplifiable_params t =
      t |> type_of |> strip_type |> fst
        |> map_index I
        |> map_filter (fn (i, T) =>
             if is_fun_type T andalso not (contains used (to_time t, i))
             then SOME i
             else NONE)
  in
    map simplifiable_params term
  end
|
559 |
|
(* Define the reduced time function of the constant term.
   Parameters listed in simplifyable (indices of function-typed parameters)
   take only the change_typ component; other function-typed parameters take
   the pair (T, change_typ T); remaining parameters keep their type.
   The defining equation applies the constant named by
   fun_name_to_time' ctxt true true (qualified by locale or theory name) to the
   target-fixed parameters and to the arguments, passing HOL.undefined as the
   first pair component of every reduced parameter.
   Prints the equation and returns ((definition name, original base name), context). *)
fun define_simp' term simplifyable ctxt =
  let
    val base_name =
      (case Named_Target.locale_of ctxt of
        SOME nm => nm
      | NONE => Context.theory_base_name (Proof_Context.theory_of ctxt))

    val orig_name = List.last (split_name (dest_Const_name term))
    val red_name = fun_name_to_time ctxt false orig_name
    val name = fun_name_to_time' ctxt true true orig_name
    val full_name = base_name ^ "." ^ name
    val def_name = red_name ^ "_def"

    (* NOTE(review): name is read back in the exited target context; its
       Free arguments are presumably the parameters fixed by the target *)
    val (canon_head, canonFrees) = strip_comb (Syntax.read_term (Local_Theory.exit ctxt) name)
    val canonType = take (length canonFrees) (fst (strip_type (dest_Const_type canon_head)))

    val types = fst (strip_type (dest_Const_type term))
    val vars = fst (Variable.variant_fixes (map (K "") types) ctxt)

    fun is_fun_type (Type ("fun", _)) = true
      | is_fun_type _ = false

    (* type of the i-th parameter of the reduced function *)
    fun lhs_typ (i, T) =
      if not (is_fun_type T) then T
      else if contains simplifyable i then change_typ T
      else HOLogic.mk_prodT (T, change_typ T)
    val l_typs = map_index lhs_typ types
    val lhs =
      list_comb (Free (red_name, l_typs ---> HOLogic.natT), map Free (ListPair.zip (vars, l_typs)))

    (* a type variable in result position is replaced by nat *)
    fun fixUnspecified T =
      let
        val (arg_typs, res_typ) = strip_type T
      in
        arg_typs ---> (case res_typ of TFree _ => HOLogic.natT | _ => res_typ)
      end

    (* i-th argument passed to the full time function *)
    fun rhs_arg (i, (v, T)) =
      if not (is_fun_type T) then Free (v, T)
      else if contains simplifyable i
      then HOLogic.mk_prod (Const ("HOL.undefined", fixUnspecified T), Free (v, change_typ T))
      else Free (v, HOLogic.mk_prodT (T, change_typ T))
    val r_terms = map_index rhs_arg (ListPair.zip (vars, types))

    val full_type = map type_of r_terms ---> HOLogic.natT
    val full = list_comb (Const (full_name, canonType ---> full_type), canonFrees)
    val eq = HOLogic.mk_Trueprop (HOLogic.mk_eq (lhs, list_comb (full, r_terms)))
    val _ = Pretty.writeln (Pretty.block [Pretty.str "Defining simplified version:\n",
      Syntax.pretty_term ctxt eq])

    val (_, ctxt') = Specification.definition NONE [] [] ((Binding.name def_name, []), eq) ctxt
  in
    ((def_name, orig_name), ctxt')
  end
|
(* Run define_simp' for every (term, simplifiable indices) pair, threading the
   context. Pairs are processed from the last to the first; the returned
   definition list keeps the order of simpables. *)
fun define_simp simpables ctxt =
  fold_rev
    (fn (term, simplifyable) => fn (defs, lthy) =>
      let
        val (def, lthy') = define_simp' term simplifyable lthy
      in
        (def :: defs, lthy')
      end)
    simpables ([], ctxt)
|
619 |
|
620 |
|
(* Replace every atomic subterm equal to from by to, in each term of the list. *)
fun replace from to ts =
  map (map_aterms (fn a => if a = from then to else a)) ts
(* Apply the (from, to) replacements one after another, first pair first. *)
fun replaceAll replacements ts =
  fold (fn (from, to) => replace from to) replacements ts
|
(* Register the simplified (reduced) time functions and print the resulting lemma.
   ctxt: context before the reduced definitions exist.
   T_terms: converted time-function equations.
   term: the original function constants.
   simpables: each constant paired with the indices of its function-typed
     parameters of which only the time-function component is used.
   Returns the context extended by the reduced definitions. *)
fun calculateSimplifications ctxt T_terms term simpables =
  let
    (* Show where a simplification can take place (reported in the old context) *)
    fun reportReductions (t, is) =
      List.app (fn i =>
        Pretty.writeln (Pretty.str
          ((Term.term_name t |> fun_name_to_time ctxt true)
          ^ " can be simplified because only the time-function component of parameter "
          ^ Int.toString (i + 1) ^ " is used. ")))
        is
    (* side effects only: List.app instead of building a discarded unit list *)
    val _ = List.app reportReductions simpables

    (* Register definitions for simplified function; everything below uses ctxt' *)
    val (reds, ctxt') = define_simp simpables ctxt

    (* retype the Free of each T-function so its reduced parameters lose the pair *)
    fun genRetype (Const (nm, T), is) =
          let
            val T_name = fun_name_to_time ctxt' true nm |> split_name |> List.last
          in
            (Free (T_name, change_typ T), Free (T_name, change_typ' (not o contains is) T))
          end
      | genRetype _ = error "Internal error: invalid term"
    val retyping = map genRetype simpables

    (* rewrite one equation Trueprop (lhs = rhs) to use the reduced parameters *)
    fun replaceArgs (pT $ (eq $ l $ r)) =
          let
            val (T_head, params) = strip_comb l
            (* simpables entry whose T-function name matches the lhs head *)
            fun match (Const (f_nm, _), _) =
                  (fun_name_to_time ctxt' true f_nm |> Long_Name.base_name) = (dest_Free T_head |> fst)
              | match _ = false
            val simps = List.find match simpables |> Option.valOf |> snd

            (* pair-typed Free  ~>  Free of its second component type, renamed *)
            fun dest_Prod_snd (Free (nm, Type (_, [_, T2]))) =
                  Free (fun_name_to_time ctxt' false nm, T2)
              | dest_Prod_snd _ = error "Internal error: Argument is not a pair"
            (* rs: the reduced parameters; params': the new parameter list *)
            fun rep _ [] = ([], [])
              | rep i (x :: xs) =
                  let
                    val (rs, args) = rep (i + 1) xs
                  in
                    if contains simps i
                    then (x :: rs, dest_Prod_snd x :: args)
                    else (rs, x :: args)
                  end
            val (rs, params') = rep 0 params
            fun fFst _ = error "Internal error: Invalid term to simplify"
            fun fSnd (t as (Const _ $ f)) =
                  (if contains rs f
                   then dest_Prod_snd f
                   else t)
              | fSnd t = t
            val r' =
              replaceFstSndFree ctxt' term fFst fSnd r
              |> (fn rhs => replaceAll (map (fn x => (x, dest_Prod_snd x)) rs) [rhs])
              |> hd
          in
            pT $ (eq $ list_comb (T_head, params') $ r')
          end
      | replaceArgs _ = error "Internal error: Invalid term"

    (* Calculate reduced terms *)
    val T_terms_red = T_terms
      |> replaceAll retyping
      |> map replaceArgs

    val _ = print_lemma ctxt' reds T_terms_red
    val _ =
      Pretty.writeln (Pretty.str "If you do not want the simplified T function, use \"time_fun [no_simp]\"")
  in
    ctxt'
  end
|
701 |
502 (* Register timing function of a given function *) |
702 (* Register timing function of a given function *) |
503 fun reg_time_func (lthy: local_theory) (term: term list) (terms: term list) print = |
703 fun reg_time_func (lthy: local_theory) (term: term list) (terms: term list) print simp = |
504 let |
704 let |
505 val _ = |
705 val _ = |
506 case time_term lthy true (hd term) |
706 case time_term lthy true (hd term) |
507 handle (ERROR _) => NONE |
707 handle (ERROR _) => NONE |
508 of SOME _ => error ("Timing function already declared: " ^ (Term.term_name (hd term))) |
708 of SOME _ => error ("Timing function already declared: " ^ (Term.term_name (hd term))) |
509 | NONE => () |
709 | NONE => () |
510 |
710 |
|
711 (* Number of terms fixed by locale *) |
|
712 val fixedNum = term |
|
713 |> hd |
|
714 |> strip_comb |> snd |
|
715 |> length |
|
716 |
511 (* 1. Fix all terms *) |
717 (* 1. Fix all terms *) |
512 (* Exchange Var in types and terms to Free and check constraints *) |
718 (* Exchange Var in types and terms to Free and check constraints *) |
513 val terms = map |
719 val terms = map |
514 (map_aterms fixTerms |
720 (map_aterms freeTerms |
515 #> map_types (map_atyps fixTypes) |
721 #> map_types (map_atyps freeTypes) |
516 #> fixPartTerms lthy term) |
722 #> fixTerms lthy term fixedNum) |
517 terms |
723 terms |
|
724 val fixedFrees = (hd term) |> strip_comb |> snd |> take fixedNum |
|
725 val fixedFreesNames = map (fst o dest_Free) fixedFrees |
|
726 val term = map (shortFunc fixedNum o fst o strip_comb) term |
|
727 |
518 |
728 |
519 (* 2. Find properties about the function *) |
729 (* 2. Find properties about the function *) |
520 (* 2.1 Check if function is recursive *) |
730 (* 2.1 Check if function is recursive *) |
521 val is_rec = is_rec lthy term terms |
731 val is_rec = is_rec lthy term terms |
522 |
732 |
523 (* 3. Convert every equation |
733 (* 3. Convert every equation |
524 - Change type of toplevel equation from _ \<Rightarrow> _ \<Rightarrow> bool to nat \<Rightarrow> nat \<Rightarrow> bool |
734 - Change type of toplevel equation from _ \<Rightarrow> _ \<Rightarrow> bool to nat \<Rightarrow> nat \<Rightarrow> bool |
525 - On left side change name of function to timing function |
735 - On left side change name of function to timing function |
526 - Convert right side of equation with conversion schema |
736 - Convert right side of equation with conversion schema |
527 *) |
737 *) |
528 fun convert used = map (convert_term lthy term is_rec used) |
738 fun fFst (t as (Const (_,T) $ Free (nm,_))) = |
529 fun repeat T_terms = |
739 (if contains fixedFreesNames nm |
|
740 then Free (nm,strip_type T |>> tl |> (op --->)) |
|
741 else t) |
|
742 | fFst t = t |
|
743 fun fSnd (t as (Const (_,T) $ Free (nm,_))) = |
|
744 (if contains fixedFreesNames nm |
|
745 then Free (fun_name_to_time lthy false nm,strip_type T |>> tl |> (op --->)) |
|
746 else t) |
|
747 | fSnd t = t |
|
748 val T_terms = map (convert_term lthy term is_rec) terms |
|
749 |> map (map_r (replaceFstSndFree lthy term fFst fSnd)) |
|
750 |
|
751 val simpables = (if simp |
|
752 then find_simplifyble lthy term T_terms |
|
753 else map (K []) term) |
|
754 |> (fn s => ListPair.zip (term,s)) |
|
755 (* Determine if something is simpable, if so rename everything *) |
|
756 val simpable = simpables |> map snd |> exists (not o null) |
|
757 (* Rename to secondary if simpable *) |
|
758 fun genRename (t,_) = |
530 let |
759 let |
531 val orig_used = find_used lthy term terms T_terms |
760 val old = fun_to_time lthy term t |> Option.valOf |
532 val T_terms' = convert orig_used terms |
761 val new = fun_to_time' lthy term true t |> Option.valOf |
533 in |
762 in |
534 if T_terms' <> T_terms then repeat T_terms' else T_terms' |
763 (old,new) |
535 end |
764 end |
536 val T_terms = repeat (convert (K true) terms) |
765 val can_T_terms = if simpable |
537 val orig_used = find_used lthy term terms T_terms |
766 then replaceAll (map genRename simpables) T_terms |
538 |
767 else T_terms |
539 (* 4. Register function and prove termination *) |
768 |
|
769 (* 4. Register function and prove completeness *) |
540 val names = map Term.term_name term |
770 val names = map Term.term_name term |
541 val timing_names = map (fun_name_to_time lthy true) names |
771 val timing_names = map (fun_name_to_time' lthy true simpable) names |
542 val bindings = map (fn nm => (Binding.name nm, NONE, NoSyn)) timing_names |
772 val bindings = map (fn nm => (Binding.name nm, NONE, NoSyn)) timing_names |
543 fun pat_completeness_auto ctxt = |
773 fun pat_completeness_auto ctxt = |
544 Pat_Completeness.pat_completeness_tac ctxt 1 THEN auto_tac ctxt |
774 Pat_Completeness.pat_completeness_tac ctxt 1 THEN auto_tac ctxt |
545 val specs = map (fn eq => (((Binding.empty, []), eq), [], [])) T_terms |
775 val specs = map (fn eq => (((Binding.empty, []), eq), [], [])) can_T_terms |
546 |
776 |
|
777 (* Context for printing without showing question marks *) |
|
778 val print_ctxt = lthy |
|
779 |> Config.put show_question_marks false |
|
780 |> Config.put show_sorts false (* Change it for debugging *) |
|
781 val print_ctxt = List.foldl (fn (term, ctxt) => Variable.add_fixes_implicit term ctxt) print_ctxt (term @ T_terms) |
|
782 (* Print result if print *) |
|
783 val _ = if not print then () else |
|
784 let |
|
785 val nms = map (dest_Const_name) term |
|
786 val typs = map (dest_Const_type) term |
|
787 in |
|
788 print_timing' print_ctxt { names=nms, terms=terms, typs=typs } |
|
789 { names=timing_names, terms=can_T_terms, typs=map change_typ typs } |
|
790 end |
|
791 |
547 (* For partial functions sequential=true is needed in order to support them |
792 (* For partial functions sequential=true is needed in order to support them |
548 We need sequential=false to support the automatic proof of termination over dom |
793 We need sequential=false to support the automatic proof of termination over dom |
549 *) |
794 *) |
550 fun register seq = |
795 fun register seq = |
551 let |
796 let |
674 val suffix = (if isTypeClass then detect_typ ctxt (hd fterms) else NONE) |
894 val suffix = (if isTypeClass then detect_typ ctxt (hd fterms) else NONE) |
675 in |
895 in |
676 (case suffix of NONE => I | SOME s => Config.put bsuffix ("_" ^ s)) ctxt |
896 (case suffix of NONE => I | SOME s => Config.put bsuffix ("_" ^ s)) ctxt |
677 end |
897 end |
678 |
898 |
|
(* Interpret the command options.
   Returns true iff the option "no_simp" was given. Any other option raises an
   error naming the first unknown option (previously the error named the first
   option in the list, even when that was the valid "no_simp"). *)
fun check_opts opts =
  (case List.filter (fn opt => opt <> "no_simp") opts of
    [] => not (null opts)
  | unknown :: _ => error ("Option " ^ unknown ^ " is not defined"))
|
902 |
(* Convert function into its timing function (called by command "time_fun").
   Equations come from get_terms unless given explicitly; the option
   "no_simp" disables the simplified time function. *)
fun reg_time_fun_cmd ((opts, funcs), thms) (ctxt: local_theory) =
  let
    val simp = not (check_opts opts)
    val fterms = map (Syntax.read_term ctxt) funcs
    val lthy = set_suffix fterms ctxt
    val equations =
      (case thms of
        SOME ths => List.map Thm.prop_of (Attrib.eval_thms lthy ths)
      | NONE => get_terms lthy (hd fterms))
  in
    snd (reg_and_proove_time_func lthy fterms equations true simp)
  end
690 |
915 |
(* Convert function into its timing function (called by command) with the
   termination proof provided by the user: registers via reg_time_func instead
   of reg_and_proove_time_func. Equations come from get_terms unless given
   explicitly; the option "no_simp" disables the simplified time function. *)
fun reg_time_function_cmd ((opts, funcs), thms) (ctxt: local_theory) =
  let
    val simp = not (check_opts opts)
    val fterms = map (Syntax.read_term ctxt) funcs
    val lthy = set_suffix fterms ctxt
    val equations =
      (case thms of
        SOME ths => List.map Thm.prop_of (Attrib.eval_thms lthy ths)
      | NONE => get_terms lthy (hd fterms))
  in
    snd (reg_time_func lthy fterms equations true simp)
  end
703 |
929 |
(* Convert function into its timing function (called by command).
   Unlike reg_time_fun_cmd, equations obtained from get_terms are first
   validated by check_definition and transformed by fix_definition; explicitly
   given equations are used unchanged. The option "no_simp" disables the
   simplified time function. *)
fun reg_time_definition_cmd ((opts, funcs), thms) (ctxt: local_theory) =
  let
    val simp = not (check_opts opts)
    val fterms = map (Syntax.read_term ctxt) funcs
    val lthy = set_suffix fterms ctxt
    val equations =
      (case thms of
        SOME ths => List.map Thm.prop_of (Attrib.eval_thms lthy ths)
      | NONE => map fix_definition (check_definition (get_terms lthy (hd fterms))))
  in
    snd (reg_and_proove_time_func lthy fterms equations true simp)
  end
715 |
942 |
(* Command syntax: optional bracketed option names, one or more function
   terms, and optionally "equations" followed by theorems.
   Produces ((option names, function terms), equation theorems option). *)
val parser =
  let
    val options = Parse.opt_attribs >> map (fst o Token.name_of_src)
    val functions = Scan.repeat1 Parse.prop
    val equations = Scan.option (Parse.keyword_improper "equations" |-- Parse.thms1)
  in
    options -- functions -- equations
  end
|
(* Register the "time_fun" command. (A stray `val _ = Toplevel.local_theory`
   that only evaluated and discarded a function value was removed.) *)
val _ = Outer_Syntax.local_theory @{command_keyword "time_fun"}
  "Defines runtime function of a function"
  (parser >> reg_time_fun_cmd)
721 |
950 |
722 val _ = Outer_Syntax.local_theory @{command_keyword "time_function"} |
951 val _ = Outer_Syntax.local_theory @{command_keyword "time_function"} |