# HG changeset patch # User nipkow # Date 1724265676 -7200 # Node ID 0c406b9469ab15678d44e8648a40a9c838b9bb0f # Parent 17d8b3f6d744330eabb0af85b306bff59c6d5005# Parent 7054a1bc8347085be0db11fa8287a75a70121051 merged diff -r 17d8b3f6d744 -r 0c406b9469ab src/HOL/Data_Structures/Define_Time_Function.ML --- a/src/HOL/Data_Structures/Define_Time_Function.ML Wed Aug 21 14:09:44 2024 +0100 +++ b/src/HOL/Data_Structures/Define_Time_Function.ML Wed Aug 21 20:41:16 2024 +0200 @@ -1,13 +1,18 @@ signature TIMING_FUNCTIONS = sig +type 'a wctxt = { + ctxt: local_theory, + origins: term list, + f: term -> 'a +} type 'a converter = { - constc : local_theory -> term list -> (term -> 'a) -> term -> 'a, - funcc : local_theory -> term list -> (term -> 'a) -> term -> term list -> 'a, - ifc : local_theory -> term list -> (term -> 'a) -> typ -> term -> term -> term -> 'a, - casec : local_theory -> term list -> (term -> 'a) -> term -> term list -> 'a, - letc : local_theory -> term list -> (term -> 'a) -> typ -> term -> string list -> typ list -> term -> 'a -}; + constc : 'a wctxt -> term -> 'a, + funcc : 'a wctxt -> term -> term list -> 'a, + ifc : 'a wctxt -> typ -> term -> term -> term -> 'a, + casec : 'a wctxt -> term -> term list -> 'a, + letc : 'a wctxt -> typ -> term -> string list -> typ list -> term -> 'a +} val walk : local_theory -> term list -> 'a converter -> term -> 'a type pfunc = { names : string list, terms : term list, typs : typ list } @@ -16,10 +21,10 @@ val print_timing': Proof.context -> pfunc -> pfunc -> unit val print_timing: Proof.context -> Function.info -> Function.info -> unit -val reg_and_proove_time_func: theory -> term list -> term list - -> bool -> Function.info * theory -val reg_time_func: theory -> term list -> term list - -> bool -> theory +val reg_and_proove_time_func: local_theory -> term list -> term list + -> bool -> Function.info * local_theory +val reg_time_func: local_theory -> term list -> term list + -> bool -> Function.info * local_theory val time_dom_tac: Proof.context -> thm -> thm list -> int -> tactic @@ -29,17 +34,16 @@ struct (* Configure config variable to adjust the prefix *) val bprefix = Attrib.setup_config_string @{binding "time_prefix"} (K "T_") +(* Configure config variable to adjust the suffix *) +val bsuffix = Attrib.setup_config_string @{binding "time_suffix"} (K "") (* some default values to build terms easier *) val zero = Const (@{const_name "Groups.zero"}, HOLogic.natT) val one = Const (@{const_name "Groups.one"}, HOLogic.natT) (* Extracts terms from function info *) fun terms_of_info (info: Function.info) = -let - val {simps, ...} = info -in - map Thm.prop_of (case simps of SOME s => s | NONE => error "No terms of function found in info") -end; + map Thm.prop_of (case #simps info of SOME s => s + | NONE => error "No terms of function found in info") type pfunc = { names : string list, @@ -49,7 +53,9 @@ fun info_pfunc (info: Function.info): pfunc = let val {defname, fs, ...} = info; - val T = case hd fs of (Const (_,T)) => T | _ => error "Internal error: Invalid info to print" + val T = case hd fs of (Const (_,T)) => T + | (Free (_,T)) => T + | _ => error "Internal error: Invalid info to print" in { names=[Binding.name_of defname], terms=terms_of_info info, typs=[T] } end @@ -82,43 +88,41 @@ fun print_timing ctxt (oinfo: Function.info) (tinfo: Function.info) = print_timing' ctxt (info_pfunc oinfo) (info_pfunc tinfo) -val If_name = @{const_name "HOL.If"} -val Let_name = @{const_name "HOL.Let"} - fun contains l e = exists (fn e' => e' = e) l -fun zip [] [] = [] - | zip (x::xs) (y::ys) = (x, y) :: zip xs ys - | zip _ _ = error "Internal error: Cannot zip lists with differing size" +fun contains' comp l e = exists (comp e) l fun index [] _ = 0 | index (x::xs) el = (if x = el then 0 else 1 + index xs el) fun used_for_const orig_used t i = orig_used (t,i) +(* Split name by . *) +val split_name = String.fields (fn s => s = #".") (* returns true if it's an if term *) -fun is_if (Const (n,_)) = (n = If_name) +fun is_if (Const (@{const_name "HOL.If"},_)) = true | is_if _ = false (* returns true if it's a case term *) -fun is_case (Const (n,_)) = String.isPrefix "case_" (List.last (String.fields (fn s => s = #".") n)) +fun is_case (Const (n,_)) = n |> split_name |> List.last |> String.isPrefix "case_" | is_case _ = false (* returns true if it's a let term *) -fun is_let (Const (n,_)) = (n = Let_name) +fun is_let (Const (@{const_name "HOL.Let"},_)) = true | is_let _ = false (* change type of original function to new type (_ \ ... \ _ to _ \ ... \ nat) - and replace all function arguments f with (t*T_f) *) -fun change_typ' used i (Type ("fun", [T1, T2])) = - Type ("fun", [check_for_fun' (used i) T1, change_typ' used (i+1) T2]) - | change_typ' _ _ _ = HOLogic.natT -and check_for_fun' true (f as Type ("fun", [_,_])) = HOLogic.mk_prodT (f, change_typ' (K false) 0 f) - | check_for_fun' false (f as Type ("fun", [_,_])) = change_typ' (K false) 0 f + and replace all function arguments f with (t*T_f) if used *) +fun change_typ' used (Type ("fun", [T1, T2])) = + Type ("fun", [check_for_fun' (used 0) T1, change_typ' (fn i => used (i+1)) T2]) + | change_typ' _ _ = HOLogic.natT +and check_for_fun' true (f as Type ("fun", [_,_])) = HOLogic.mk_prodT (f, change_typ' (K false) f) + | check_for_fun' false (f as Type ("fun", [_,_])) = change_typ' (K false) f | check_for_fun' _ t = t -val change_typ = change_typ' (K false) 0 +val change_typ = change_typ' (K false) (* Convert string name of function to its timing equivalent *) -fun fun_name_to_time ctxt name = +fun fun_name_to_time ctxt s name = let val prefix = Config.get ctxt bprefix - fun replace_last_name [n] = [prefix ^ n] + val suffix = (if s then Config.get ctxt bsuffix else "") + fun replace_last_name [n] = [prefix ^ n ^ suffix] | replace_last_name (n::ns) = n :: (replace_last_name ns) | replace_last_name _ = error "Internal error: Invalid function name to convert" - val parts = String.fields (fn s => s = #".") name + val parts = split_name name in String.concatWith "." (replace_last_name parts) end @@ -126,11 +130,10 @@ fun count_args (Type (n, [_,res])) = (if n = "fun" then 1 + count_args res else 0) | count_args _ = 0 (* Check if number of arguments matches function *) -fun check_args s (Const (_,T), args) = - (if length args = count_args T then () else error ("Partial applications/Lambdas not allowed (" ^ s ^ ")")) - | check_args s (Free (_,T), args) = - (if length args = count_args T then () else error ("Partial applications/Lambdas not allowed (" ^ s ^ ")")) - | check_args s _ = error ("Partial applications/Lambdas not allowed (" ^ s ^ ")") +val _ = dest_Const +fun check_args s (t, args) = + (if length args = count_args (type_of t) then () + else error ("Partial applications/Lambdas not allowed (" ^ s ^ ")")) (* Removes Abs *) fun rem_abs f (Abs (_,_,t)) = rem_abs f t | rem_abs f t = f t @@ -148,39 +151,53 @@ | Const_name _ = NONE fun is_Used (Type ("Product_Type.prod", _)) = true | is_Used _ = false +(* Custom compare function for types ignoring variable names *) +fun typ_comp (Type (A,a)) (Type (B,b)) = (A = B) andalso List.foldl (fn ((c,i),s) => typ_comp c i andalso s) true (ListPair.zip (a, b)) + | typ_comp (Type _) _ = false + | typ_comp _ (Type _) = false + | typ_comp _ _ = true +fun const_comp (Const (nm,T)) (Const (nm',T')) = nm = nm' andalso typ_comp T T' + | const_comp _ _ = false -fun time_term ctxt (Const (nm,T)) = +fun time_term ctxt s (Const (nm,T)) = let - val T_nm = fun_name_to_time ctxt nm + val T_nm = fun_name_to_time ctxt s nm val T_T = change_typ T in (SOME (Syntax.check_term ctxt (Const (T_nm,T_T)))) handle (ERROR _) => case Syntax.read_term ctxt (Long_Name.base_name T_nm) - of (Const (nm,T_T)) => + of (Const (T_nm,T_T)) => let fun col_Used i (Type ("fun", [Type ("fun", _), Ts])) (Type ("fun", [T', Ts'])) = (if is_Used T' then [i] else []) @ col_Used (i+1) Ts Ts' | col_Used i (Type ("fun", [_, Ts])) (Type ("fun", [_, Ts'])) = col_Used (i+1) Ts Ts' | col_Used _ _ _ = [] in - SOME (Const (nm,change_typ' (contains (col_Used 0 T T_T)) 0 T)) + SOME (Const (T_nm,change_typ' (contains (col_Used 0 T T_T)) T)) end | _ => error ("Timing function of " ^ nm ^ " is not defined") end - | time_term _ _ = error "Internal error: No valid function given" + | time_term _ _ _ = error "Internal error: No valid function given" + +type 'a wctxt = { + ctxt: local_theory, + origins: term list, + f: term -> 'a +} type 'a converter = { - constc : local_theory -> term list -> (term -> 'a) -> term -> 'a, - funcc : local_theory -> term list -> (term -> 'a) -> term -> term list -> 'a, - ifc : local_theory -> term list -> (term -> 'a) -> typ -> term -> term -> term -> 'a, - casec : local_theory -> term list -> (term -> 'a) -> term -> term list -> 'a, - letc : local_theory -> term list -> (term -> 'a) -> typ -> term -> string list -> typ list -> term -> 'a -}; + constc : 'a wctxt -> term -> 'a, + funcc : 'a wctxt -> term -> term list -> 'a, + ifc : 'a wctxt -> typ -> term -> term -> term -> 'a, + casec : 'a wctxt -> term -> term list -> 'a, + letc : 'a wctxt -> typ -> term -> string list -> typ list -> term -> 'a +} (* Walks over term and calls given converter *) fun walk_func (t1 $ t2) ts = walk_func t1 (t2::ts) | walk_func t ts = (t, ts) +fun walk_func' t = walk_func t [] fun build_func (f, []) = f | build_func (f, (t::ts)) = build_func (f$t, ts) fun walk_abs (Abs (nm,T,t)) nms Ts = walk_abs t (nm::nms) (T::Ts) @@ -193,23 +210,24 @@ val (f, args) = walk_func t [] val this = (walk ctxt origin conv) val _ = (case f of Abs _ => error "Lambdas not supported" | _ => ()) + val wctxt = {ctxt = ctxt, origins = origin, f = this} in (if is_if f then (case f of (Const (_,T)) => - (case args of [cond, t, f] => ifc ctxt origin this T cond t f + (case args of [cond, t, f] => ifc wctxt T cond t f | _ => error "Partial applications not supported (if)") | _ => error "Internal error: invalid if term") - else if is_case f then casec ctxt origin this f args + else if is_case f then casec wctxt f args else if is_let f then (case f of (Const (_,lT)) => (case args of [exp, t] => - let val (t,nms,Ts) = walk_abs t [] [] in letc ctxt origin this lT exp nms Ts t end + let val (t,nms,Ts) = walk_abs t [] [] in letc wctxt lT exp nms Ts t end | _ => error "Partial applications not allowed (let)") | _ => error "Internal error: invalid let term") - else funcc ctxt origin this f args) + else funcc wctxt f args) end | walk ctxt origin (conv as {constc, ...}) c = - constc ctxt origin (walk ctxt origin conv) c + constc {ctxt = ctxt, origins = origin, f = walk ctxt origin conv} c (* 1. Fix all terms *) (* Exchange Var in types and terms to Free *) @@ -221,36 +239,34 @@ fun noFun (Type ("fun", _)) = error "Functions in datatypes are not allowed in case constructions" | noFun T = T fun casecBuildBounds n t = if n > 0 then casecBuildBounds (n-1) (t $ (Bound (n-1))) else t -fun casecAbs ctxt f n (Type ("fun",[T,Tr])) (Abs (v,Ta,t)) = ( map_atyps noFun T; Abs (v,Ta,casecAbs ctxt f n Tr t)) - | casecAbs ctxt f n (Type ("fun",[T,Tr])) t = - (map_atyps noFun T; case Variable.variant_fixes ["x"] ctxt of ([v],ctxt) => - (Abs (v,T,casecAbs ctxt f (n + 1) Tr t)) - | _ => error "Internal error: could not fix variable") - | casecAbs _ f n _ t = f (casecBuildBounds n (Term.incr_bv n 0 t)) -fun fixCasecCases _ _ _ [t] = [t] - | fixCasecCases ctxt f (Type (_,[T,Tr])) (t::ts) = casecAbs ctxt f 0 T t :: fixCasecCases ctxt f Tr ts - | fixCasecCases _ _ _ _ = error "Internal error: invalid case types/terms" -fun fixCasec ctxt _ f (t as Const (_,T)) args = - (check_args "cases" (t,args); build_func (t,fixCasecCases ctxt f T args)) - | fixCasec _ _ _ _ _ = error "Internal error: invalid case term" +fun casecAbs wctxt n (Type ("fun",[T,Tr])) (Abs (v,Ta,t)) = (map_atyps noFun T; Abs (v,Ta,casecAbs wctxt n Tr t)) + | casecAbs wctxt n (Type ("fun",[T,Tr])) t = + (map_atyps noFun T; Abs ("uu",T,casecAbs wctxt (n + 1) Tr t)) + | casecAbs wctxt n _ t = (#f wctxt) (casecBuildBounds n (Term.incr_bv n 0 t)) +fun fixCasecCases _ _ [t] = [t] + | fixCasecCases wctxt (Type (_,[T,Tr])) (t::ts) = casecAbs wctxt 0 T t :: fixCasecCases wctxt Tr ts + | fixCasecCases _ _ _ = error "Internal error: invalid case types/terms" +fun fixCasec wctxt (t as Const (_,T)) args = + (check_args "cases" (t,args); build_func (t,fixCasecCases wctxt T args)) + | fixCasec _ _ _ = error "Internal error: invalid case term" fun fixPartTerms ctxt (term: term list) t = let val _ = check_args "args" (walk_func (get_l t) []) in map_r (walk ctxt term { - funcc = (fn _ => fn _ => fn f => fn t => fn args => - (check_args "func" (t,args); build_func (t, map f args))), - constc = (fn _ => fn _ => fn _ => fn c => (case c of Abs _ => error "Lambdas not supported" | _ => c)), - ifc = (fn _ => fn _ => fn f => fn T => fn cond => fn tt => fn tf => - ((Const (If_name, T)) $ f cond $ (f tt) $ (f tf))), + funcc = (fn wctxt => fn t => fn args => + (check_args "func" (t,args); build_func (t, map (#f wctxt) args))), + constc = (fn _ => fn c => (case c of Abs _ => error "Lambdas not supported" | _ => c)), + ifc = (fn wctxt => fn T => fn cond => fn tt => fn tf => + ((Const (@{const_name "HOL.If"}, T)) $ (#f wctxt) cond $ ((#f wctxt) tt) $ ((#f wctxt) tf))), casec = fixCasec, - letc = (fn _ => fn _ => fn f => fn expT => fn exp => fn nms => fn Ts => fn t => + letc = (fn wctxt => fn expT => fn exp => fn nms => fn Ts => fn t => let val f' = if length nms = 0 then - (case f (t$exp) of t$_ => t | _ => error "Internal error: case could not be fixed (let)") - else f t - in (Const (Let_name,expT) $ (f exp) $ build_abs f' nms Ts) end) + (case (#f wctxt) (t$exp) of t$_ => t | _ => error "Internal error: case could not be fixed (let)") + else (#f wctxt) t + in (Const (@{const_name "HOL.Let"},expT) $ ((#f wctxt) exp) $ build_abs f' nms Ts) end) }) t end @@ -258,11 +274,16 @@ (* 2.1 Check if function is recursive *) fun or f (a,b) = f a orelse b fun find_rec ctxt term = (walk ctxt term { - funcc = (fn _ => fn _ => fn f => fn t => fn args => List.exists (fn term => Const_name t = Const_name term) term orelse List.foldr (or f) false args), - constc = (K o K o K o K) false, - ifc = (fn _ => fn _ => fn f => fn _ => fn cond => fn tt => fn tf => f cond orelse f tt orelse f tf), - casec = (fn _ => fn _ => fn f => fn t => fn cs => f t orelse List.foldr (or (rem_abs f)) false cs), - letc = (fn _ => fn _ => fn f => fn _ => fn exp => fn _ => fn _ => fn t => f exp orelse f t) + funcc = (fn wctxt => fn t => fn args => + List.exists (fn term => (Const_name t) = (Const_name term)) term + orelse List.foldr (or (#f wctxt)) false args), + constc = (K o K) false, + ifc = (fn wctxt => fn _ => fn cond => fn tt => fn tf => + (#f wctxt) cond orelse (#f wctxt) tt orelse (#f wctxt) tf), + casec = (fn wctxt => fn t => fn cs => + (#f wctxt) t orelse List.foldr (or (rem_abs (#f wctxt))) false cs), + letc = (fn wctxt => fn _ => fn exp => fn _ => fn _ => fn t => + (#f wctxt) exp orelse (#f wctxt) t) }) o get_r fun is_rec ctxt (term: term list) = List.foldr (or (find_rec ctxt term)) false @@ -277,14 +298,14 @@ f :: filter_passed args | filter_passed (_::args) = filter_passed args val frees' = (walk ctxt term { - funcc = (fn _ => fn _ => fn f => fn t => fn args => + funcc = (fn wctxt => fn t => fn args => (case t of (Const ("Product_Type.prod.snd", _)) => [] | _ => (if t = T_ident then [] else filter_passed args) - @ List.foldr (fn (l,r) => f l @ r) [] args)), - constc = (K o K o K o K) [], - ifc = (fn _ => fn _ => fn f => fn _ => fn cond => fn tt => fn tf => f cond @ f tt @ f tf), - casec = (fn _ => fn _ => fn f => fn _ => fn cs => List.foldr (fn (l,r) => f l @ r) [] cs), - letc = (fn _ => fn _ => fn f => fn _ => fn exp => fn _ => fn _ => fn t => f exp @ f t) + @ List.foldr (fn (l,r) => (#f wctxt) l @ r) [] args)), + constc = (K o K) [], + ifc = (fn wctxt => fn _ => fn cond => fn tt => fn tf => (#f wctxt) cond @ (#f wctxt) tt @ (#f wctxt) tf), + casec = (fn wctxt => fn _ => fn cs => List.foldr (fn (l,r) => (#f wctxt) l @ r) [] cs), + letc = (fn wctxt => fn _ => fn exp => fn _ => fn _ => fn t => (#f wctxt) exp @ (#f wctxt) t) }) (get_r T_t) fun build _ [] _ = false | build i (a::args) item = @@ -293,7 +314,7 @@ build 0 T_args end fun find_used ctxt term terms T_terms = - zip terms T_terms + ListPair.zip (terms, T_terms) |> List.map (fn (t, T_t) => find_used' ctxt term t T_t) |> List.foldr (fn (f,g) => fn item => f item orelse g item) (K false) @@ -310,57 +331,54 @@ fun use_origin (Free (nm, T as Type ("fun",_))) = HOLogic.mk_fst (Free (nm,HOLogic.mk_prodT (T, change_typ T))) | use_origin t = t -(* Converting of function term *) +(* Conversion of function term *) fun fun_to_time ctxt orig_used _ (origin: term list) (func as Const (nm,T)) = let - val prefix = Config.get ctxt bprefix val used' = used_for_const orig_used func in - if contains origin func then SOME (Free (prefix ^ Term.term_name func, change_typ' used' 0 T)) else + if contains' const_comp origin func then SOME (Free (func |> Term.term_name |> fun_name_to_time ctxt true, change_typ' used' T)) else if Zero_Funcs.is_zero (Proof_Context.theory_of ctxt) (nm,T) then NONE else - time_term ctxt func + time_term ctxt false func end | fun_to_time ctxt _ used _ (f as Free (nm,T)) = SOME ( if used f then HOLogic.mk_snd (Free (nm,HOLogic.mk_prodT (T,change_typ T))) - else Free (Config.get ctxt bprefix ^ nm, change_typ T) + else Free (fun_name_to_time ctxt false nm, change_typ T) ) | fun_to_time _ _ _ _ _ = error "Internal error: invalid function to convert" (* Convert arguments of left side of a term *) fun conv_arg ctxt used _ (f as Free (nm,T as Type("fun",_))) = - if used f then Free (nm, HOLogic.mk_prodT (T, change_typ' (K false) 0 T)) - else Free (Config.get ctxt bprefix ^ nm, change_typ' (K false) 0 T) - | conv_arg ctxt _ origin (f as Const (_, Type("fun",_))) = - (error "weird case i don't understand TODO"; HOLogic.mk_prod (f, fun_to_time ctxt (K false) (K false) origin f |> Option.valOf)) + if used f then Free (nm, HOLogic.mk_prodT (T, change_typ' (K false) T)) + else Free (fun_name_to_time ctxt false nm, change_typ' (K false) T) | conv_arg _ _ _ x = x fun conv_args ctxt used origin = map (conv_arg ctxt used origin) (* Handle function calls *) -fun build_zero (Type ("fun", [T, R])) = Abs ("x", T, build_zero R) +fun build_zero (Type ("fun", [T, R])) = Abs ("uu", T, build_zero R) | build_zero _ = zero -fun funcc_use_origin _ _ used (f as Free (nm, T as Type ("fun",_))) = +fun funcc_use_origin used (f as Free (nm, T as Type ("fun",_))) = if used f then HOLogic.mk_fst (Free (nm,HOLogic.mk_prodT (T, change_typ T))) else error "Internal error: Error in used detection" - | funcc_use_origin _ _ _ t = t -fun funcc_conv_arg ctxt origin used _ (t as (_ $ _)) = map_aterms (funcc_use_origin ctxt origin used) t - | funcc_conv_arg ctxt _ used u (f as Free (nm, T as Type ("fun",_))) = + | funcc_use_origin _ t = t +fun funcc_conv_arg _ used _ (t as (_ $ _)) = map_aterms (funcc_use_origin used) t + | funcc_conv_arg wctxt used u (f as Free (nm, T as Type ("fun",_))) = if used f then if u then Free (nm, HOLogic.mk_prodT (T, change_typ T)) else HOLogic.mk_snd (Free (nm,HOLogic.mk_prodT (T,change_typ T))) - else Free (Config.get ctxt bprefix ^ nm, change_typ T) - | funcc_conv_arg ctxt origin _ true (f as Const (_,T as Type ("fun",_))) = + else Free (fun_name_to_time (#ctxt wctxt) false nm, change_typ T) + | funcc_conv_arg wctxt _ true (f as Const (_,T as Type ("fun",_))) = (Const (@{const_name "Product_Type.Pair"}, Type ("fun", [T,Type ("fun", [change_typ T, HOLogic.mk_prodT (T,change_typ T)])])) - $ f $ (Option.getOpt (fun_to_time ctxt (K false) (K false) origin f, build_zero T))) - | funcc_conv_arg ctxt origin _ false (f as Const (_,T as Type ("fun",_))) = - Option.getOpt (fun_to_time ctxt (K false) (K false) origin f, build_zero T) - | funcc_conv_arg _ _ _ _ t = t + $ f $ (Option.getOpt (fun_to_time (#ctxt wctxt) (K false) (K false) (#origins wctxt) f, build_zero T))) + | funcc_conv_arg wctxt _ false (f as Const (_,T as Type ("fun",_))) = + Option.getOpt (fun_to_time (#ctxt wctxt) (K false) (K false) (#origins wctxt) f, build_zero T) + | funcc_conv_arg _ _ _ t = t -fun funcc_conv_args _ _ _ _ [] = [] - | funcc_conv_args ctxt origin used (Type ("fun", [t, ts])) (a::args) = - funcc_conv_arg ctxt origin used (is_Used t) a :: funcc_conv_args ctxt origin used ts args - | funcc_conv_args _ _ _ _ _ = error "Internal error: Non matching type" -fun funcc orig_used used ctxt (origin: term list) f func args = +fun funcc_conv_args _ _ _ [] = [] + | funcc_conv_args wctxt used (Type ("fun", [t, ts])) (a::args) = + funcc_conv_arg wctxt used (is_Used t) a :: funcc_conv_args wctxt used ts args + | funcc_conv_args _ _ _ _ = error "Internal error: Non matching type" +fun funcc orig_used used wctxt func args = let fun get_T (Free (_,T)) = T | get_T (Const (_,T)) = T @@ -368,10 +386,10 @@ | get_T _ = error "Internal error: Forgotten type" in List.foldr (I #-> plus) - (case fun_to_time ctxt orig_used used origin func - of SOME t => SOME (build_func (t,funcc_conv_args ctxt origin used (get_T t) args)) + (case fun_to_time (#ctxt wctxt) orig_used used (#origins wctxt) func + of SOME t => SOME (build_func (t,funcc_conv_args wctxt used (get_T t) args)) | NONE => NONE) - (map f args) + (map (#f wctxt) args) end (* Handle case terms *) @@ -389,48 +407,49 @@ (case casecAbs f t of (nconst, tt) => casecArgs f ar ||> (fn ar => tt :: ar) |>> (if nconst then K true else I)) | casecArgs _ _ = error "Internal error: Invalid case term" -fun casec _ _ f (Const (t,T)) args = +fun casec wctxt (Const (t,T)) args = if not (casecIsCase T) then error "Internal error: Invalid case type" else - let val (nconst, args') = casecArgs f args in + let val (nconst, args') = casecArgs (#f wctxt) args in plus - (f (List.last args)) + ((#f wctxt) (List.last args)) (if nconst then SOME (build_func (Const (t,casecTyp T), args')) else NONE) end - | casec _ _ _ _ _ = error "Internal error: Invalid case term" + | casec _ _ _ = error "Internal error: Invalid case term" (* Handle if terms -> drop the term if true and false terms are zero *) -fun ifc _ _ f _ cond tt ft = +fun ifc wctxt _ cond tt ft = let + val f = #f wctxt val rcond = map_aterms use_origin cond val tt = f tt val ft = f ft in plus (f cond) (case (tt,ft) of (NONE, NONE) => NONE | _ => if tt = ft then tt else - (SOME ((Const (If_name, @{typ "bool \ nat \ nat \ nat"})) $ rcond $ (opt_term tt) $ (opt_term ft)))) + (SOME ((Const (@{const_name "HOL.If"}, @{typ "bool \ nat \ nat \ nat"})) $ rcond $ (opt_term tt) $ (opt_term ft)))) end fun letc_change_typ (Type ("fun", [T1, Type ("fun", [T2, _])])) = (Type ("fun", [T1, Type ("fun", [change_typ T2, HOLogic.natT])])) | letc_change_typ _ = error "Internal error: invalid let type" -fun letc _ _ f expT exp nms Ts t = - plus (f exp) +fun letc wctxt expT exp nms Ts t = + plus (#f wctxt exp) (if length nms = 0 (* In case of "length nms = 0" the expression got reducted Here we need Bound 0 to gain non-partial application *) - then (case f (t $ Bound 0) of SOME (t' $ Bound 0) => - (SOME (Const (Let_name, letc_change_typ expT) $ (map_aterms use_origin exp) $ t')) + then (case #f wctxt (t $ Bound 0) of SOME (t' $ Bound 0) => + (SOME (Const (@{const_name "HOL.Let"}, letc_change_typ expT) $ (map_aterms use_origin exp) $ t')) (* Expression is not used and can therefore let be dropped *) | SOME t' => SOME t' | NONE => NONE) - else (case f t of SOME t' => - SOME (if Term.is_dependent t' then Const (Let_name, letc_change_typ expT) $ (map_aterms use_origin exp) $ build_abs t' nms Ts + else (case #f wctxt t of SOME t' => + SOME (if Term.is_dependent t' then Const (@{const_name "HOL.Let"}, letc_change_typ expT) $ (map_aterms use_origin exp) $ build_abs t' nms Ts else Term.subst_bounds([exp],t')) | NONE => NONE)) (* The converter for timing functions given to the walker *) fun converter orig_used used : term option converter = { - constc = fn _ => fn _ => fn _ => fn t => + constc = fn _ => fn t => (case t of Const ("HOL.undefined", _) => SOME (Const ("HOL.undefined", @{typ "nat"})) | _ => NONE), funcc = (funcc orig_used used), @@ -469,20 +488,22 @@ fun get_terms theory (term: term) = - Spec_Rules.retrieve_global theory term - |> hd |> #rules - |> map Thm.prop_of +let + val equations = Spec_Rules.retrieve theory term + |> map #rules + |> map (map Thm.prop_of) handle Empty => error "Function or terms of function not found" +in + equations + |> filter (fn ts => typ_comp (ts |> hd |> get_l |> walk_func' |> fst |> dest_Const |> snd) (term |> dest_Const |> snd)) + |> hd +end (* Register timing function of a given function *) -fun reg_and_proove_time_func (theory: theory) (term: term list) (terms: term list) print = - reg_time_func theory term terms false - |> proove_termination term terms print -and reg_time_func (theory: theory) (term: term list) (terms: term list) print = +fun reg_time_func (lthy: local_theory) (term: term list) (terms: term list) print = let - val lthy = Named_Target.theory_init theory val _ = - case time_term lthy (hd term) + case time_term lthy true (hd term) handle (ERROR _) => NONE of SOME _ => error ("Timing function already declared: " ^ (Term.term_name (hd term))) | NONE => () @@ -517,7 +538,7 @@ (* 4. Register function and prove termination *) val names = map Term.term_name term - val timing_names = map (fun_name_to_time lthy) names + val timing_names = map (fun_name_to_time lthy true) names val bindings = map (fn nm => (Binding.name nm, NONE, NoSyn)) timing_names fun pat_completeness_auto ctxt = Pat_Completeness.pat_completeness_tac ctxt 1 THEN auto_tac ctxt @@ -539,34 +560,30 @@ val print_ctxt = lthy |> Config.put show_question_marks false |> Config.put show_sorts false (* Change it for debugging *) - val print_ctxt = List.foldl (fn (term, ctxt) => Variable.add_fixes_implicit term ctxt) print_ctxt (term @ T_terms) + val print_ctxt = List.foldl (fn (term, ctxt) => Variable.add_fixes_implicit term ctxt) print_ctxt (term @ T_terms) (* Print result if print *) val _ = if not print then () else let - val nms = map dest_Const_name term + val nms = map (fst o dest_Const) term val used = map (used_for_const orig_used) term - val typs = map dest_Const_type term + val typs = map (snd o dest_Const) term in - print_timing' print_ctxt { names=nms, terms=terms, typs=typs } { names=timing_names, terms=T_terms, typs=map (fn (used, typ) => change_typ' used 0 typ) (zip used typs) } + print_timing' print_ctxt { names=nms, terms=terms, typs=typs } + { names=timing_names, terms=T_terms, typs=map (fn (used, typ) => change_typ' used typ) (ListPair.zip (used, typs)) } end - (* Register function *) - val (_, lthy) = - register false + in + register false handle (ERROR _) => register true | Match => register true - in - Local_Theory.exit_global lthy end -and proove_termination (term: term list) terms print (theory: theory) = +fun proove_termination (term: term list) terms print (T_info: Function.info, lthy: local_theory) = let - val lthy = Named_Target.theory_init theory - (* Start proving the termination *) val infos = SOME (map (Function.get_info lthy) term) handle Empty => NONE - val timing_names = map (fun_name_to_time lthy o Term.term_name) term + val timing_names = map (fun_name_to_time lthy true o Term.term_name) term (* Proof by lexicographic_order_tac *) val (time_info, lthy') = @@ -578,22 +595,22 @@ val _ = (if length term > 1 then error "Proof over dom not supported for mutual recursive functions" else ()) fun args (a$(Var ((nm,_),T))) = args a |> (fn ar => (nm,T)::ar) - | args (a$(Const (_,T))) = args a |> (fn ar => ("x",T)::ar) + | args (a$(Const (_,T))) = args a |> (fn ar => ("uu",T)::ar) | args _ = [] - val dom_args = - terms |> hd |> get_l |> args - |> Variable.variant_frees lthy [] - |> map fst + val dom_vars = + terms |> hd |> get_l |> map_types (map_atyps fixTypes) + |> args |> Variable.variant_frees lthy [] + val dom_args = + List.foldl (fn (t,p) => HOLogic.mk_prod ((Free t),p)) (Free (hd dom_vars)) (tl dom_vars) val {inducts, ...} = case infos of SOME [i] => i | _ => error "Proof over dom failed as no induct rule was found" val induct = (Option.valOf inducts |> hd) val domintros = Proof_Context.get_fact lthy (Facts.named (hd timing_names ^ ".domintros")) - val prop = (hd timing_names ^ "_dom (" ^ (String.concatWith "," dom_args) ^ ")") - |> Syntax.read_prop lthy + val prop = HOLogic.mk_Trueprop (#dom T_info $ dom_args) (* Prove a helper lemma *) - val dom_lemma = Goal.prove lthy dom_args [] prop + val dom_lemma = Goal.prove lthy (map fst dom_vars) [] prop (fn {context, ...} => HEADGOAL (time_dom_tac context induct domintros)) (* Add dom_lemma to simplification set *) val simp_lthy = Simplifier.add_simp dom_lemma lthy @@ -602,7 +619,7 @@ Function.prove_termination NONE (auto_tac simp_lthy) lthy end - + (* Context for printing without showing question marks *) val print_ctxt = lthy' |> Config.put show_question_marks false @@ -616,63 +633,98 @@ print_timing' print_ctxt { names=nms, terms=terms, typs=typs } (info_pfunc time_info) end in - (time_info, Local_Theory.exit_global lthy') + (time_info, lthy') end +fun reg_and_proove_time_func (lthy: local_theory) (term: term list) (terms: term list) print = + reg_time_func lthy term terms false + |> proove_termination term terms print fun fix_definition (Const ("Pure.eq", _) $ l $ r) = Const ("HOL.Trueprop", @{typ "bool \ prop"}) $ (Const ("HOL.eq", @{typ "bool \ bool \ bool"}) $ l $ r) | fix_definition t = t fun check_definition [t] = [t] - | check_definition _ = error "Only a single defnition is allowed" + | check_definition _ = error "Only a single definition is allowed" + +fun isTypeClass' (Const (nm,_)) = + (case split_name nm |> rev + of (_::nm::_) => String.isSuffix "_class" nm + | _ => false) + | isTypeClass' _ = false +val isTypeClass = + (List.foldr (fn (a,b) => a orelse b) false) o (map isTypeClass') + +fun detect_typ (ctxt: local_theory) (term: term) = +let + val class_term = (case term of Const (nm,_) => Syntax.read_term ctxt nm + | _ => error "Could not find term of class") + fun find_free (Type (_,class)) (Type (_,inst)) = + List.foldl (fn ((c,i),s) => (case s of NONE => find_free c i | t => t)) (NONE) (ListPair.zip (class, inst)) + | find_free (TFree _) (TFree _) = NONE + | find_free (TFree _) (Type (nm,_)) = SOME nm + | find_free _ _ = error "Unhandled case in detecting type" +in + find_free (type_of class_term) (type_of term) + |> Option.map (hd o rev o split_name) +end + +fun set_suffix (fterms: term list) ctxt = +let + val isTypeClass = isTypeClass fterms + val _ = (if length fterms > 1 andalso isTypeClass then error "No mutual recursion inside instantiation allowed" else ()) + val suffix = (if isTypeClass then detect_typ ctxt (hd fterms) else NONE) +in + (case suffix of NONE => I | SOME s => Config.put bsuffix ("_" ^ s)) ctxt +end (* Convert function into its timing function (called by command) *) -fun reg_time_fun_cmd (funcs, thms) (theory: theory) = +fun reg_time_fun_cmd (funcs, thms) (ctxt: local_theory) = let - val ctxt = Proof_Context.init_global theory val fterms = map (Syntax.read_term ctxt) funcs - val (_, lthy') = reg_and_proove_time_func theory fterms - (case thms of NONE => get_terms theory (hd fterms) + val ctxt = set_suffix fterms ctxt + val (_, ctxt') = reg_and_proove_time_func ctxt fterms + (case thms of NONE => get_terms ctxt (hd fterms) | SOME thms => thms |> Attrib.eval_thms ctxt |> List.map Thm.prop_of) true -in lthy' +in ctxt' end (* Convert function into its timing function (called by command) with termination proof provided by user*) -fun reg_time_function_cmd (funcs, thms) (theory: theory) = +fun reg_time_function_cmd (funcs, thms) (ctxt: local_theory) = let - val ctxt = Proof_Context.init_global theory val fterms = map (Syntax.read_term ctxt) funcs - val theory = reg_time_func theory fterms - (case thms of NONE => get_terms theory (hd fterms) + val ctxt = set_suffix fterms ctxt + val ctxt' = reg_time_func ctxt fterms + (case thms of NONE => get_terms ctxt (hd fterms) | SOME thms => thms |> Attrib.eval_thms ctxt |> List.map Thm.prop_of) true -in theory + |> snd +in ctxt' end (* Convert function into its timing function (called by command) *) -fun reg_time_definition_cmd (funcs, thms) (theory: theory) = +fun reg_time_definition_cmd (funcs, thms) (ctxt: local_theory) = let - val ctxt = Proof_Context.init_global theory val fterms = map (Syntax.read_term ctxt) funcs - val (_, lthy') = reg_and_proove_time_func theory fterms - (case thms of NONE => get_terms theory (hd fterms) |> check_definition |> map fix_definition + val ctxt = set_suffix fterms ctxt + val (_, ctxt') = reg_and_proove_time_func ctxt fterms + (case thms of NONE => get_terms ctxt (hd fterms) |> check_definition |> map fix_definition | SOME thms => thms |> Attrib.eval_thms ctxt |> List.map Thm.prop_of) true -in lthy' +in ctxt' end val parser = (Scan.repeat1 Parse.prop) -- (Scan.option (Parse.keyword_improper "equations" -- Parse.thms1 >> snd)) -val _ = Outer_Syntax.command @{command_keyword "time_fun"} +val _ = Outer_Syntax.local_theory @{command_keyword "time_fun"} "Defines runtime function of a function" - (parser >> (fn p => Toplevel.theory (reg_time_fun_cmd p))) + (parser >> reg_time_fun_cmd) -val _ = Outer_Syntax.command @{command_keyword "time_function"} +val _ = Outer_Syntax.local_theory @{command_keyword "time_function"} "Defines runtime function of a function" - (parser >> (fn p => Toplevel.theory (reg_time_function_cmd p))) + (parser >> reg_time_function_cmd) -val _ = Outer_Syntax.command @{command_keyword "time_definition"} +val _ = Outer_Syntax.local_theory @{command_keyword "time_definition"} "Defines runtime function of a definition" - (parser >> (fn p => Toplevel.theory (reg_time_definition_cmd p))) + (parser >> reg_time_definition_cmd) end diff -r 17d8b3f6d744 -r 0c406b9469ab src/HOL/Data_Structures/Selection.thy --- a/src/HOL/Data_Structures/Selection.thy Wed Aug 21 14:09:44 2024 +0100 +++ b/src/HOL/Data_Structures/Selection.thy Wed Aug 21 20:41:16 2024 +0200 @@ -645,9 +645,7 @@ lemmas T_slow_select_def [simp del] = T_slow_select.simps - -definition T_slow_median :: "'a :: linorder list \ nat" where - "T_slow_median xs = T_length xs + T_slow_select ((length xs - 1) div 2) xs" +time_fun slow_median lemma T_slow_select_le: assumes "k < length xs" @@ -671,7 +669,7 @@ shows "T_slow_median xs \ length xs ^ 2 + 4 * length xs + 2" proof - have "T_slow_median xs = length xs + T_slow_select ((length xs - 1) div 2) xs + 1" - by (simp add: T_slow_median_def T_length_eq) + by (simp add: T_length_eq) also from assms have "length xs > 0" by simp hence "(length xs - 1) div 2 < length xs" diff -r 17d8b3f6d744 -r 0c406b9469ab src/HOL/Data_Structures/Time_Funs.thy --- a/src/HOL/Data_Structures/Time_Funs.thy Wed Aug 21 14:09:44 2024 +0100 +++ b/src/HOL/Data_Structures/Time_Funs.thy Wed Aug 21 20:41:16 2024 +0200 @@ -12,16 +12,25 @@ lemma T_append: "T_append xs ys = length xs + 1" by(induction xs) auto -text \Automatic definition of \T_length\ is cumbersome because of the type class for \size\.\ +class T_size = + fixes T_size :: "'a \ nat" + +instantiation list :: (_) T_size +begin -fun T_length :: "'a list \ nat" where - "T_length [] = 1" -| "T_length (x # xs) = T_length xs + 1" +time_fun length + +instance .. + +end + +abbreviation T_length :: "'a list \ nat" where +"T_length \ T_size" lemma T_length_eq: "T_length xs = length xs + 1" by (induction xs) auto -lemmas [simp del] = T_length.simps +lemmas [simp del] = T_size_list.simps time_fun map