# HG changeset patch # User nipkow # Date 1728704729 -32400 # Node ID 503e5280ba726401ebfacaa6b6f460594c45ab45 # Parent 87f173836d5667137b5fe2b1f4ff878449c93704 new HO time functions diff -r 87f173836d56 -r 503e5280ba72 src/HOL/Data_Structures/Define_Time_Function.ML --- a/src/HOL/Data_Structures/Define_Time_Function.ML Thu Oct 10 14:13:18 2024 +0200 +++ b/src/HOL/Data_Structures/Define_Time_Function.ML Sat Oct 12 12:45:29 2024 +0900 @@ -11,9 +11,14 @@ funcc : 'a wctxt -> term -> term list -> 'a, ifc : 'a wctxt -> typ -> term -> term -> term -> 'a, casec : 'a wctxt -> term -> term list -> 'a, - letc : 'a wctxt -> typ -> term -> string list -> typ list -> term -> 'a + letc : 'a wctxt -> typ -> term -> (string * typ) list -> term -> 'a } val walk : local_theory -> term list -> 'a converter -> term -> 'a +val Iconst : term wctxt -> term -> term +val Ifunc : term wctxt -> term -> term list -> term +val Iif : term wctxt -> typ -> term -> term -> term -> term +val Icase : term wctxt -> term -> term list -> term +val Ilet : term wctxt -> typ -> term -> (string * typ) list -> term -> term type pfunc = { names : string list, terms : term list, typs : typ list } val fun_pretty': Proof.context -> pfunc -> Pretty.T @@ -22,9 +27,9 @@ val print_timing: Proof.context -> Function.info -> Function.info -> unit val reg_and_proove_time_func: local_theory -> term list -> term list - -> bool -> Function.info * local_theory + -> bool -> bool -> Function.info * local_theory val reg_time_func: local_theory -> term list -> term list - -> bool -> Function.info * local_theory + -> bool -> bool -> Function.info * local_theory val time_dom_tac: Proof.context -> thm -> thm list -> int -> tactic @@ -34,6 +39,7 @@ struct (* Configure config variable to adjust the prefix *) val bprefix = Attrib.setup_config_string @{binding "time_prefix"} (K "T_") +val bprefix_snd = Attrib.setup_config_string @{binding "time_prefix_snd"} (K "T2_") (* Configure config variable to adjust the suffix *) val bsuffix = Attrib.setup_config_string @{binding "time_suffix"} (K "") @@ -83,16 +89,37 @@ val poriginal = Pretty.item [Pretty.str "Original function:\n", fun_pretty' ctxt opfunc] val ptiming = Pretty.item [Pretty.str ("Running time function:\n"), fun_pretty' ctxt tpfunc] in - Pretty.writeln (Pretty.text_fold [Pretty.str ("Converting " ^ (hd names) ^ (String.concat (map (fn nm => ", " ^ nm) (tl names))) ^ "\n"), poriginal, Pretty.str "\n", ptiming]) + Pretty.writeln (Pretty.text_fold [ + Pretty.str ("Converting " ^ (hd names) ^ (String.concat (map (fn nm => ", " ^ nm) (tl names))) ^ "\n"), + poriginal, Pretty.str "\n", ptiming]) end fun print_timing ctxt (oinfo: Function.info) (tinfo: Function.info) = print_timing' ctxt (info_pfunc oinfo) (info_pfunc tinfo) +fun print_lemma ctxt defs (T_terms: term list) = +let + val names = + defs + |> map snd + |> map (fn s => "_" ^ s) + |> List.foldr (op ^) "" + val begin = "lemma T" ^ names ^ "_simps [simp,code]:\n" + fun convLine T_term = + " \"" ^ Syntax.string_of_term ctxt T_term ^ "\"\n" + val lines = map convLine T_terms + fun convDefs def = " " ^ (fst def) + val proof = " by (simp_all add:" :: (map convDefs defs) @ [")"] + val _ = Pretty.writeln (Pretty.str "Characteristic recursion equations can be derived:") +in + (begin :: lines @ proof) + |> String.concat + (* |> Active.sendback_markup_properties [Markup.padding_fun] *) + |> Pretty.str + |> Pretty.writeln +end + fun contains l e = exists (fn e' => e' = e) l fun contains' comp l e = exists (comp e) l -fun index [] _ = 0 - | index (x::xs) el = (if x = el then 0 else 1 + index xs el) -fun used_for_const orig_used t i = orig_used (t,i) (* Split name by . *) val split_name = String.fields (fn s => s = #".") @@ -111,13 +138,13 @@ Type ("fun", [check_for_fun' (used 0) T1, change_typ' (fn i => used (i+1)) T2]) | change_typ' _ _ = HOLogic.natT and check_for_fun' true (f as Type ("fun", [_,_])) = HOLogic.mk_prodT (f, change_typ' (K false) f) - | check_for_fun' false (f as Type ("fun", [_,_])) = change_typ' (K false) f + | check_for_fun' false (f as Type ("fun", [_,_])) = change_typ' (K true) f | check_for_fun' _ t = t -val change_typ = change_typ' (K false) +val change_typ = change_typ' (K true) (* Convert string name of function to its timing equivalent *) -fun fun_name_to_time ctxt s name = +fun fun_name_to_time' ctxt s second name = let - val prefix = Config.get ctxt bprefix + val prefix = Config.get ctxt (if second then bprefix_snd else bprefix) val suffix = (if s then Config.get ctxt bsuffix else "") fun replace_last_name [n] = [prefix ^ n ^ suffix] | replace_last_name (n::ns) = n :: (replace_last_name ns) @@ -126,11 +153,11 @@ in String.concatWith "." (replace_last_name parts) end +fun fun_name_to_time ctxt s name = fun_name_to_time' ctxt s false name (* Count number of arguments of a function *) fun count_args (Type (n, [_,res])) = (if n = "fun" then 1 + count_args res else 0) | count_args _ = 0 (* Check if number of arguments matches function *) -val _ = dest_Const fun check_args s (t, args) = (if length args = count_args (type_of t) then () else error ("Partial applications/Lambdas not allowed (" ^ s ^ ")")) @@ -191,23 +218,16 @@ funcc : 'a wctxt -> term -> term list -> 'a, ifc : 'a wctxt -> typ -> term -> term -> term -> 'a, casec : 'a wctxt -> term -> term list -> 'a, - letc : 'a wctxt -> typ -> term -> string list -> typ list -> term -> 'a + letc : 'a wctxt -> typ -> term -> (string * typ) list -> term -> 'a } (* Walks over term and calls given converter *) -fun walk_func (t1 $ t2) ts = walk_func t1 (t2::ts) - | walk_func t ts = (t, ts) -fun walk_func' t = walk_func t [] -fun build_func (f, []) = f - | build_func (f, (t::ts)) = build_func (f$t, ts) -fun walk_abs (Abs (nm,T,t)) nms Ts = walk_abs t (nm::nms) (T::Ts) - | walk_abs t nms Ts = (t, nms, Ts) -fun build_abs t (nm::nms) (T::Ts) = build_abs (Abs (nm,T,t)) nms Ts - | build_abs t [] [] = t - | build_abs _ _ _ = error "Internal error: Invalid terms to build abs" +(* get rid and use Term.strip_abs.eta especially for lambdas *) +fun build_abs t ((nm,T)::abs) = build_abs (Abs (nm,T,t)) abs + | build_abs t [] = t fun walk ctxt (origin: term list) (conv as {ifc, casec, funcc, letc, ...} : 'a converter) (t as _ $ _) = let - val (f, args) = walk_func t [] + val (f, args) = strip_comb t val this = (walk ctxt origin conv) val _ = (case f of Abs _ => error "Lambdas not supported" | _ => ()) val wctxt = {ctxt = ctxt, origins = origin, f = this} @@ -220,21 +240,28 @@ else if is_case f then casec wctxt f args else if is_let f then (case f of (Const (_,lT)) => - (case args of [exp, t] => - let val (t,nms,Ts) = walk_abs t [] [] in letc wctxt lT exp nms Ts t end + (case args of [exp, t] => + let val (abs,t) = strip_abs t in letc wctxt lT exp abs t end | _ => error "Partial applications not allowed (let)") | _ => error "Internal error: invalid let term") else funcc wctxt f args) end | walk ctxt origin (conv as {constc, ...}) c = constc {ctxt = ctxt, origins = origin, f = walk ctxt origin conv} c +fun Ifunc (wctxt: term wctxt) t args = list_comb (#f wctxt t,map (#f wctxt) args) +val Iconst = K I +fun Iif (wctxt: term wctxt) T cond tt tf = + Const (@{const_name "HOL.If"}, T) $ (#f wctxt cond) $ (#f wctxt tt) $ (#f wctxt tf) +fun Icase (wctxt: term wctxt) t cs = list_comb (#f wctxt t,map (#f wctxt) cs) +fun Ilet (wctxt: term wctxt) lT exp abs t = + Const (@{const_name "HOL.Let"},lT) $ (#f wctxt exp) $ build_abs (#f wctxt t) abs (* 1. Fix all terms *) (* Exchange Var in types and terms to Free *) -fun fixTerms (Var(ixn,T)) = Free (fst ixn, T) - | fixTerms t = t -fun fixTypes (TVar ((t, _), T)) = TFree (t, T) - | fixTypes t = t +fun freeTerms (Var(ixn,T)) = Free (fst ixn, T) + | freeTerms t = t +fun freeTypes (TVar ((t, _), T)) = TFree (t, T) + | freeTypes t = t fun noFun (Type ("fun", _)) = error "Functions in datatypes are not allowed in case constructions" | noFun T = T @@ -247,28 +274,39 @@ | fixCasecCases wctxt (Type (_,[T,Tr])) (t::ts) = casecAbs wctxt 0 T t :: fixCasecCases wctxt Tr ts | fixCasecCases _ _ _ = error "Internal error: invalid case types/terms" fun fixCasec wctxt (t as Const (_,T)) args = - (check_args "cases" (t,args); build_func (t,fixCasecCases wctxt T args)) + (check_args "cases" (t,args); list_comb (t,fixCasecCases wctxt T args)) | fixCasec _ _ _ = error "Internal error: invalid case term" -fun fixPartTerms ctxt (term: term list) t = +fun shortFunc fixedNum (Const (nm,T)) = + Const (nm,T |> strip_type |>> drop fixedNum |> (op --->)) + | shortFunc _ _ = error "Internal error: Invalid term" +fun shortApp fixedNum (c, args) = + (shortFunc fixedNum c, drop fixedNum args) +fun shortOriginFunc (term: term list) fixedNum (f as (c as Const (_,_), _)) = + if contains' const_comp term c then shortApp fixedNum f else f + | shortOriginFunc _ _ t = t +fun fixTerms ctxt (term: term list) (fixedNum: int) (t as pT $ (eq $ l $ r)) = let - val _ = check_args "args" (walk_func (get_l t) []) - in - map_r (walk ctxt term { + val _ = check_args "args" (strip_comb (get_l t)) + val l' = shortApp fixedNum (strip_comb l) |> list_comb + val shortOriginFunc' = shortOriginFunc (term |> map (fst o strip_comb)) fixedNum + val r' = walk ctxt term { funcc = (fn wctxt => fn t => fn args => - (check_args "func" (t,args); build_func (t, map (#f wctxt) args))), + (check_args "func" (t,args); (t, map (#f wctxt) args) |> shortOriginFunc' |> list_comb)), constc = (fn _ => fn c => (case c of Abs _ => error "Lambdas not supported" | _ => c)), - ifc = (fn wctxt => fn T => fn cond => fn tt => fn tf => - ((Const (@{const_name "HOL.If"}, T)) $ (#f wctxt) cond $ ((#f wctxt) tt) $ ((#f wctxt) tf))), + ifc = Iif, casec = fixCasec, - letc = (fn wctxt => fn expT => fn exp => fn nms => fn Ts => fn t => + letc = (fn wctxt => fn expT => fn exp => fn abs => fn t => let - val f' = if length nms = 0 then + val f' = if length abs = 0 then (case (#f wctxt) (t$exp) of t$_ => t | _ => error "Internal error: case could not be fixed (let)") else (#f wctxt) t - in (Const (@{const_name "HOL.Let"},expT) $ ((#f wctxt) exp) $ build_abs f' nms Ts) end) - }) t + in (Const (@{const_name "HOL.Let"},expT) $ ((#f wctxt) exp) $ build_abs f' abs) end) + } r + in + pT $ (eq $ l' $ r') end + | fixTerms _ _ _ _ = error "Internal error: invalid term" (* 2. Check for properties about the function *) (* 2.1 Check if function is recursive *) @@ -282,43 +320,11 @@ (#f wctxt) cond orelse (#f wctxt) tt orelse (#f wctxt) tf), casec = (fn wctxt => fn t => fn cs => (#f wctxt) t orelse List.foldr (or (rem_abs (#f wctxt))) false cs), - letc = (fn wctxt => fn _ => fn exp => fn _ => fn _ => fn t => + letc = (fn wctxt => fn _ => fn exp => fn _ => fn t => (#f wctxt) exp orelse (#f wctxt) t) }) o get_r fun is_rec ctxt (term: term list) = List.foldr (or (find_rec ctxt term)) false -(* 2.2 Check for higher-order function if original function is used *) -fun find_used' ctxt term t T_t = -let - val (ident, _) = walk_func (get_l t) [] - val (T_ident, T_args) = walk_func (get_l T_t) [] - - fun filter_passed [] = [] - | filter_passed ((f as Free (_, Type ("Product_Type.prod",[Type ("fun",_), Type ("fun", _)])))::args) = - f :: filter_passed args - | filter_passed (_::args) = filter_passed args - val frees' = (walk ctxt term { - funcc = (fn wctxt => fn t => fn args => - (case t of (Const ("Product_Type.prod.snd", _)) => [] - | _ => (if t = T_ident then [] else filter_passed args) - @ List.foldr (fn (l,r) => (#f wctxt) l @ r) [] args)), - constc = (K o K) [], - ifc = (fn wctxt => fn _ => fn cond => fn tt => fn tf => (#f wctxt) cond @ (#f wctxt) tt @ (#f wctxt) tf), - casec = (fn wctxt => fn _ => fn cs => List.foldr (fn (l,r) => (#f wctxt) l @ r) [] cs), - letc = (fn wctxt => fn _ => fn exp => fn _ => fn _ => fn t => (#f wctxt) exp @ (#f wctxt) t) - }) (get_r T_t) - fun build _ [] _ = false - | build i (a::args) item = - (if item = (ident,i) then contains frees' a else build (i+1) args item) -in - build 0 T_args -end -fun find_used ctxt term terms T_terms = - ListPair.zip (terms, T_terms) - |> List.map (fn (t, T_t) => find_used' ctxt term t T_t) - |> List.foldr (fn (f,g) => fn item => f item orelse g item) (K false) - - (* 3. Convert equations *) (* Some Helper *) val plusTyp = @{typ "nat => nat => nat"} @@ -332,53 +338,48 @@ | use_origin t = t (* Conversion of function term *) -fun fun_to_time ctxt orig_used _ (origin: term list) (func as Const (nm,T)) = +fun fun_to_time' ctxt (origin: term list) second (func as Const (nm,T)) = let - val used' = used_for_const orig_used func + val origin' = map (fst o strip_comb) origin in - if contains' const_comp origin func then SOME (Free (func |> Term.term_name |> fun_name_to_time ctxt true, change_typ' used' T)) else + if contains' const_comp origin' func then SOME (Free (func |> Term.term_name |> fun_name_to_time' ctxt true second, change_typ T)) else if Zero_Funcs.is_zero (Proof_Context.theory_of ctxt) (nm,T) then NONE else time_term ctxt false func end - | fun_to_time ctxt _ used _ (f as Free (nm,T)) = SOME ( - if used f then HOLogic.mk_snd (Free (nm,HOLogic.mk_prodT (T,change_typ T))) - else Free (fun_name_to_time ctxt false nm, change_typ T) - ) - | fun_to_time _ _ _ _ _ = error "Internal error: invalid function to convert" + | fun_to_time' _ _ _ (Free (nm,T)) = + SOME (HOLogic.mk_snd (Free (nm,HOLogic.mk_prodT (T,change_typ T)))) + | fun_to_time' _ _ _ _ = error "Internal error: invalid function to convert" +fun fun_to_time context origin func = fun_to_time' context origin false func (* Convert arguments of left side of a term *) -fun conv_arg ctxt used _ (f as Free (nm,T as Type("fun",_))) = - if used f then Free (nm, HOLogic.mk_prodT (T, change_typ' (K false) T)) - else Free (fun_name_to_time ctxt false nm, change_typ' (K false) T) - | conv_arg _ _ _ x = x -fun conv_args ctxt used origin = map (conv_arg ctxt used origin) +fun conv_arg _ (Free (nm,T as Type("fun",_))) = + Free (nm, HOLogic.mk_prodT (T, change_typ' (K false) T)) + | conv_arg _ x = x +fun conv_args ctxt = map (conv_arg ctxt) (* Handle function calls *) fun build_zero (Type ("fun", [T, R])) = Abs ("uu", T, build_zero R) | build_zero _ = zero -fun funcc_use_origin used (f as Free (nm, T as Type ("fun",_))) = - if used f then HOLogic.mk_fst (Free (nm,HOLogic.mk_prodT (T, change_typ T))) - else error "Internal error: Error in used detection" - | funcc_use_origin _ t = t -fun funcc_conv_arg _ used _ (t as (_ $ _)) = map_aterms (funcc_use_origin used) t - | funcc_conv_arg wctxt used u (f as Free (nm, T as Type ("fun",_))) = - if used f then - if u then Free (nm, HOLogic.mk_prodT (T, change_typ T)) - else HOLogic.mk_snd (Free (nm,HOLogic.mk_prodT (T,change_typ T))) - else Free (fun_name_to_time (#ctxt wctxt) false nm, change_typ T) - | funcc_conv_arg wctxt _ true (f as Const (_,T as Type ("fun",_))) = +fun funcc_use_origin (Free (nm, T as Type ("fun",_))) = + HOLogic.mk_fst (Free (nm,HOLogic.mk_prodT (T, change_typ T))) + | funcc_use_origin t = t +fun funcc_conv_arg _ _ (t as (_ $ _)) = map_aterms funcc_use_origin t + | funcc_conv_arg _ u (Free (nm, T as Type ("fun",_))) = + if u then Free (nm, HOLogic.mk_prodT (T, change_typ T)) + else HOLogic.mk_snd (Free (nm,HOLogic.mk_prodT (T,change_typ T))) + | funcc_conv_arg wctxt true (f as Const (_,T as Type ("fun",_))) = (Const (@{const_name "Product_Type.Pair"}, Type ("fun", [T,Type ("fun", [change_typ T, HOLogic.mk_prodT (T,change_typ T)])])) - $ f $ (Option.getOpt (fun_to_time (#ctxt wctxt) (K false) (K false) (#origins wctxt) f, build_zero T))) - | funcc_conv_arg wctxt _ false (f as Const (_,T as Type ("fun",_))) = - Option.getOpt (fun_to_time (#ctxt wctxt) (K false) (K false) (#origins wctxt) f, build_zero T) - | funcc_conv_arg _ _ _ t = t + $ f $ (Option.getOpt (fun_to_time (#ctxt wctxt) (#origins wctxt) f, build_zero T))) + | funcc_conv_arg wctxt false (f as Const (_,T as Type ("fun",_))) = + Option.getOpt (fun_to_time (#ctxt wctxt) (#origins wctxt) f, build_zero T) + | funcc_conv_arg _ _ t = t -fun funcc_conv_args _ _ _ [] = [] - | funcc_conv_args wctxt used (Type ("fun", [t, ts])) (a::args) = - funcc_conv_arg wctxt used (is_Used t) a :: funcc_conv_args wctxt used ts args - | funcc_conv_args _ _ _ _ = error "Internal error: Non matching type" -fun funcc orig_used used wctxt func args = +fun funcc_conv_args _ _ [] = [] + | funcc_conv_args wctxt (Type ("fun", [t, ts])) (a::args) = + funcc_conv_arg wctxt (is_Used t) a :: funcc_conv_args wctxt ts args + | funcc_conv_args _ _ _ = error "Internal error: Non matching type" +fun funcc wctxt func args = let fun get_T (Free (_,T)) = T | get_T (Const (_,T)) = T @@ -386,8 +387,8 @@ | get_T _ = error "Internal error: Forgotten type" in List.foldr (I #-> plus) - (case fun_to_time (#ctxt wctxt) orig_used used (#origins wctxt) func - of SOME t => SOME (build_func (t,funcc_conv_args wctxt used (get_T t) args)) + (case fun_to_time (#ctxt wctxt) (#origins wctxt) func + of SOME t => SOME (list_comb (t,funcc_conv_args wctxt (get_T t) args)) | NONE => NONE) (map (#f wctxt) args) end @@ -413,7 +414,7 @@ plus ((#f wctxt) (List.last args)) (if nconst then - SOME (build_func (Const (t,casecTyp T), args')) + SOME (list_comb (Const (t,casecTyp T), args')) else NONE) end | casec _ _ _ = error "Internal error: Invalid case term" @@ -433,9 +434,9 @@ fun letc_change_typ (Type ("fun", [T1, Type ("fun", [T2, _])])) = (Type ("fun", [T1, Type ("fun", [change_typ T2, HOLogic.natT])])) | letc_change_typ _ = error "Internal error: invalid let type" -fun letc wctxt expT exp nms Ts t = +fun letc wctxt expT exp abs t = plus (#f wctxt exp) - (if length nms = 0 (* In case of "length nms = 0" the expression got reducted + (if length abs = 0 (* In case of "length nms = 0" the expression got reducted Here we need Bound 0 to gain non-partial application *) then (case #f wctxt (t $ Bound 0) of SOME (t' $ Bound 0) => (SOME (Const (@{const_name "HOL.Let"}, letc_change_typ expT) $ (map_aterms use_origin exp) $ t')) @@ -443,16 +444,16 @@ | SOME t' => SOME t' | NONE => NONE) else (case #f wctxt t of SOME t' => - SOME (if Term.is_dependent t' then Const (@{const_name "HOL.Let"}, letc_change_typ expT) $ (map_aterms use_origin exp) $ build_abs t' nms Ts + SOME (if Term.is_dependent t' then Const (@{const_name "HOL.Let"}, letc_change_typ expT) $ (map_aterms use_origin exp) $ build_abs t' abs else Term.subst_bounds([exp],t')) | NONE => NONE)) (* The converter for timing functions given to the walker *) -fun converter orig_used used : term option converter = { +val converter : term option converter = { constc = fn _ => fn t => (case t of Const ("HOL.undefined", _) => SOME (Const ("HOL.undefined", @{typ "nat"})) | _ => NONE), - funcc = (funcc orig_used used), + funcc = funcc, ifc = ifc, casec = casec, letc = letc @@ -460,24 +461,39 @@ fun top_converter is_rec _ _ = opt_term o (fn exp => plus exp (if is_rec then SOME one else NONE)) (* Use converter to convert right side of a term *) -fun to_time ctxt origin is_rec orig_used used term = - top_converter is_rec ctxt origin (walk ctxt origin (converter orig_used used) term) +fun to_time ctxt origin is_rec term = + top_converter is_rec ctxt origin (walk ctxt origin converter term) (* Converts a term to its running time version *) -fun convert_term ctxt (origin: term list) is_rec orig_used (pT $ (Const (eqN, _) $ l $ r)) = +fun convert_term ctxt (origin: term list) is_rec (pT $ (Const (eqN, _) $ l $ r)) = let - val (l' as (l_const, l_params)) = walk_func l [] - val used = - l_const - |> used_for_const orig_used - |> (fn f => fn n => f (index l_params n)) + val (l_const, l_params) = strip_comb l in - pT - $ (Const (eqN, @{typ "nat \ nat \ bool"}) - $ (build_func (l' |>> (fun_to_time ctxt orig_used used origin) |>> Option.valOf ||> conv_args ctxt used origin)) - $ (to_time ctxt origin is_rec orig_used used r)) + pT + $ (Const (eqN, @{typ "nat \ nat \ bool"}) + $ (list_comb (l_const |> fun_to_time ctxt origin |> Option.valOf, l_params |> conv_args ctxt)) + $ (to_time ctxt origin is_rec r)) end - | convert_term _ _ _ _ _ = error "Internal error: invalid term to convert" + | convert_term _ _ _ _ = error "Internal error: invalid term to convert" + +(* 3.5 Support for locales *) +fun replaceFstSndFree ctxt (origin: term list) (rfst: term -> term) (rsnd: term -> term) = + (walk ctxt origin { + funcc = fn wctxt => fn t => fn args => + case args of + (f as Free _)::args => + (case t of + Const ("Product_Type.prod.fst", _) => + list_comb (rfst (t $ f), map (#f wctxt) args) + | Const ("Product_Type.prod.snd", _) => + list_comb (rsnd (t $ f), map (#f wctxt) args) + | t => list_comb (t, map (#f wctxt) (f :: args))) + | args => list_comb (t, map (#f wctxt) args), + constc = Iconst, + ifc = Iif, + casec = Icase, + letc = Ilet + }) (* 4. Tactic to prove "f_dom n" *) fun time_dom_tac ctxt induct_rule domintros = @@ -495,12 +511,196 @@ handle Empty => error "Function or terms of function not found" in equations - |> filter (fn ts => typ_comp (ts |> hd |> get_l |> walk_func' |> fst |> dest_Const |> snd) (term |> dest_Const |> snd)) + |> filter (List.exists + (fn t => typ_comp (t |> get_l |> strip_comb |> fst |> dest_Const |> snd) (term |> strip_comb |> fst |> dest_Const |> snd))) |> hd end +(* 5. Check for higher-order function if original function is used \ find simplifications *) +fun find_used' T_t = +let + val (T_ident, T_args) = strip_comb (get_l T_t) + + fun filter_passed [] = [] + | filter_passed ((f as Free (_, Type ("Product_Type.prod",[Type ("fun",_), Type ("fun", _)])))::args) = + f :: filter_passed args + | filter_passed (_::args) = filter_passed args + val frees = (walk @{context} [] { + funcc = (fn wctxt => fn t => fn args => + (case t of (Const ("Product_Type.prod.snd", _)) => [] + | _ => (if t = T_ident then [] else filter_passed args) + @ List.foldr (fn (l,r) => (#f wctxt) l @ r) [] args)), + constc = (K o K) [], + ifc = (fn wctxt => fn _ => fn cond => fn tt => fn tf => (#f wctxt) cond @ (#f wctxt) tt @ (#f wctxt) tf), + casec = (fn wctxt => fn _ => fn cs => List.foldr (fn (l,r) => (#f wctxt) l @ r) [] cs), + letc = (fn wctxt => fn _ => fn exp => fn _ => fn t => (#f wctxt) exp @ (#f wctxt) t) + }) (get_r T_t) + fun build _ [] = [] + | build i (a::args) = + (if contains frees a then [(T_ident,i)] else []) @ build (i+1) args +in + build 0 T_args +end +fun find_simplifyble ctxt term terms = +let + val used = + terms + |> List.map find_used' + |> List.foldr (op @) [] + val change = + Option.valOf o fun_to_time ctxt term + fun detect t i (Type ("fun",_)::args) = + (if contains used (change t,i) then [] else [i]) @ detect t (i+1) args + | detect t i (_::args) = detect t (i+1) args + | detect _ _ [] = [] +in + map (fn t => t |> type_of |> strip_type |> fst |> detect t 0) term +end + +fun define_simp' term simplifyable ctxt = +let + val base_name = case Named_Target.locale_of ctxt of + NONE => ctxt |> Proof_Context.theory_of |> Context.theory_base_name + | SOME nm => nm + + val orig_name = term |> dest_Const_name |> split_name |> List.last + val red_name = fun_name_to_time ctxt false orig_name + val name = fun_name_to_time' ctxt true true orig_name + val full_name = base_name ^ "." ^ name + val def_name = red_name ^ "_def" + val def = Binding.name def_name + + val canon = Syntax.read_term (Local_Theory.exit ctxt) name |> strip_comb + val canonFrees = canon |> snd + val canonType = canon |> fst |> dest_Const_type |> strip_type |> fst |> take (length canonFrees) + + val types = term |> dest_Const_type |> strip_type |> fst + val vars = Variable.variant_fixes (map (K "") types) ctxt |> fst + fun l_typs' i ((T as (Type ("fun",_)))::types) = + (if contains simplifyable i + then change_typ T + else HOLogic.mk_prodT (T,change_typ T)) + :: l_typs' (i+1) types + | l_typs' i (T::types) = T :: l_typs' (i+1) types + | l_typs' _ [] = [] + val l_typs = l_typs' 0 types + val lhs = + List.foldl (fn ((v,T),t) => t $ Free (v,T)) (Free (red_name,l_typs ---> HOLogic.natT)) (ListPair.zip (vars,l_typs)) + fun fixType (TFree _) = HOLogic.natT + | fixType T = T + fun fixUnspecified T = T |> strip_type ||> fixType |> (op --->) + fun r_terms' i (v::vars) ((T as (Type ("fun",_)))::types) = + (if contains simplifyable i + then HOLogic.mk_prod (Const ("HOL.undefined", fixUnspecified T), Free (v,change_typ T)) + else Free (v,HOLogic.mk_prodT (T,change_typ T))) + :: r_terms' (i+1) vars types + | r_terms' i (v::vars) (T::types) = Free (v,T) :: r_terms' (i+1) vars types + | r_terms' _ _ _ = [] + val r_terms = r_terms' 0 vars types + val full_type = (r_terms |> map (type_of) ---> HOLogic.natT) + val full = list_comb (Const (full_name,canonType ---> full_type), canonFrees) + val rhs = list_comb (full, r_terms) + val eq = (lhs, rhs) |> HOLogic.mk_eq |> HOLogic.mk_Trueprop + val _ = Pretty.writeln (Pretty.block [Pretty.str "Defining simplified version:\n", + Syntax.pretty_term ctxt eq]) + + val (_, ctxt') = Specification.definition NONE [] [] ((def, []), eq) ctxt + +in + ((def_name, orig_name), ctxt') +end +fun define_simp simpables ctxt = +let + fun cond ((term,simplifyable),(defs,ctxt)) = + define_simp' term simplifyable ctxt |>> (fn def => def :: defs) +in + List.foldr cond ([], ctxt) simpables +end + + +fun replace from to = + map (map_aterms (fn t => if t = from then to else t)) +fun replaceAll [] = I + | replaceAll ((from,to)::xs) = replaceAll xs o replace from to +fun calculateSimplifications ctxt T_terms term simpables = +let + (* Show where a simplification can take place *) + fun reportReductions (t,(i::is)) = + (Pretty.writeln (Pretty.str + ((Term.term_name t |> fun_name_to_time ctxt true) + ^ " can be simplified because only the time-function component of parameter " + ^ (Int.toString (i + 1)) ^ " is used. ")); + reportReductions (t,is)) + | reportReductions (_,[]) = () + val _ = simpables + |> map reportReductions + + (* Register definitions for simplified function *) + val (reds, ctxt) = define_simp simpables ctxt + + fun genRetype (Const (nm,T),is) = + let + val T_name = fun_name_to_time ctxt true nm |> split_name |> List.last + val from = Free (T_name,change_typ T) + val to = Free (T_name,change_typ' (not o contains is) T) + in + (from,to) + end + | genRetype _ = error "Internal error: invalid term" + val retyping = map genRetype simpables + + fun replaceArgs (pT $ (eq $ l $ r)) = + let + val (t,params) = strip_comb l + fun match (Const (f_nm,_),_) = + (fun_name_to_time ctxt true f_nm |> Long_Name.base_name) = (dest_Free t |> fst) + | match _ = false + val simps = List.find match simpables |> Option.valOf |> snd + + fun dest_Prod_snd (Free (nm, Type (_, [_, T2]))) = + Free (fun_name_to_time ctxt false nm, T2) + | dest_Prod_snd _ = error "Internal error: Argument is not a pair" + fun rep _ [] = ([],[]) + | rep i (x::xs) = + let + val (rs,args) = rep (i+1) xs + in + if contains simps i + then (x::rs,dest_Prod_snd x::args) + else (rs,x::args) + end + val (rs,params) = rep 0 params + fun fFst _ = error "Internal error: Invalid term to simplify" + fun fSnd (t as (Const _ $ f)) = + (if contains rs f + then dest_Prod_snd f + else t) + | fSnd t = t + in + (pT $ (eq + $ (list_comb (t,params)) + $ (replaceFstSndFree ctxt term fFst fSnd r + |> (fn t => replaceAll (map (fn t => (t,dest_Prod_snd t)) rs) [t]) + |> hd + ) + )) + end + | replaceArgs _ = error "Internal error: Invalid term" + + (* Calculate reduced terms *) + val T_terms_red = T_terms + |> replaceAll retyping + |> map replaceArgs + + val _ = print_lemma ctxt reds T_terms_red + val _ = + Pretty.writeln (Pretty.str "If you do not want the simplified T function, use \"time_fun [no_simp]\"") +in + ctxt +end + (* Register timing function of a given function *) -fun reg_time_func (lthy: local_theory) (term: term list) (terms: term list) print = +fun reg_time_func (lthy: local_theory) (term: term list) (terms: term list) print simp = let val _ = case time_term lthy true (hd term) @@ -508,13 +708,23 @@ of SOME _ => error ("Timing function already declared: " ^ (Term.term_name (hd term))) | NONE => () + (* Number of terms fixed by locale *) + val fixedNum = term + |> hd + |> strip_comb |> snd + |> length + (* 1. Fix all terms *) (* Exchange Var in types and terms to Free and check constraints *) val terms = map - (map_aterms fixTerms - #> map_types (map_atyps fixTypes) - #> fixPartTerms lthy term) + (map_aterms freeTerms + #> map_types (map_atyps freeTypes) + #> fixTerms lthy term fixedNum) terms + val fixedFrees = (hd term) |> strip_comb |> snd |> take fixedNum + val fixedFreesNames = map (fst o dest_Free) fixedFrees + val term = map (shortFunc fixedNum o fst o strip_comb) term + (* 2. Find properties about the function *) (* 2.1 Check if function is recursive *) @@ -525,25 +735,60 @@ - On left side change name of function to timing function - Convert right side of equation with conversion schema *) - fun convert used = map (convert_term lthy term is_rec used) - fun repeat T_terms = + fun fFst (t as (Const (_,T) $ Free (nm,_))) = + (if contains fixedFreesNames nm + then Free (nm,strip_type T |>> tl |> (op --->)) + else t) + | fFst t = t + fun fSnd (t as (Const (_,T) $ Free (nm,_))) = + (if contains fixedFreesNames nm + then Free (fun_name_to_time lthy false nm,strip_type T |>> tl |> (op --->)) + else t) + | fSnd t = t + val T_terms = map (convert_term lthy term is_rec) terms + |> map (map_r (replaceFstSndFree lthy term fFst fSnd)) + + val simpables = (if simp + then find_simplifyble lthy term T_terms + else map (K []) term) + |> (fn s => ListPair.zip (term,s)) + (* Determine if something is simpable, if so rename everything *) + val simpable = simpables |> map snd |> exists (not o null) + (* Rename to secondary if simpable *) + fun genRename (t,_) = let - val orig_used = find_used lthy term terms T_terms - val T_terms' = convert orig_used terms + val old = fun_to_time lthy term t |> Option.valOf + val new = fun_to_time' lthy term true t |> Option.valOf in - if T_terms' <> T_terms then repeat T_terms' else T_terms' + (old,new) end - val T_terms = repeat (convert (K true) terms) - val orig_used = find_used lthy term terms T_terms + val can_T_terms = if simpable + then replaceAll (map genRename simpables) T_terms + else T_terms - (* 4. Register function and prove termination *) + (* 4. Register function and prove completeness *) val names = map Term.term_name term - val timing_names = map (fun_name_to_time lthy true) names + val timing_names = map (fun_name_to_time' lthy true simpable) names val bindings = map (fn nm => (Binding.name nm, NONE, NoSyn)) timing_names fun pat_completeness_auto ctxt = Pat_Completeness.pat_completeness_tac ctxt 1 THEN auto_tac ctxt - val specs = map (fn eq => (((Binding.empty, []), eq), [], [])) T_terms + val specs = map (fn eq => (((Binding.empty, []), eq), [], [])) can_T_terms + (* Context for printing without showing question marks *) + val print_ctxt = lthy + |> Config.put show_question_marks false + |> Config.put show_sorts false (* Change it for debugging *) + val print_ctxt = List.foldl (fn (term, ctxt) => Variable.add_fixes_implicit term ctxt) print_ctxt (term @ T_terms) + (* Print result if print *) + val _ = if not print then () else + let + val nms = map (dest_Const_name) term + val typs = map (dest_Const_type) term + in + print_timing' print_ctxt { names=nms, terms=terms, typs=typs } + { names=timing_names, terms=can_T_terms, typs=map change_typ typs } + end + (* For partial functions sequential=true is needed in order to support them We need sequential=false to support the automatic proof of termination over dom *) @@ -556,30 +801,18 @@ Function.add_function bindings specs fun_config pat_completeness_auto lthy end - (* Context for printing without showing question marks *) - val print_ctxt = lthy - |> Config.put show_question_marks false - |> Config.put show_sorts false (* Change it for debugging *) - val print_ctxt = List.foldl (fn (term, ctxt) => Variable.add_fixes_implicit term ctxt) print_ctxt (term @ T_terms) - (* Print result if print *) - val _ = if not print then () else - let - val nms = map (fst o dest_Const) term - val used = map (used_for_const orig_used) term - val typs = map (snd o dest_Const) term - in - print_timing' print_ctxt { names=nms, terms=terms, typs=typs } - { names=timing_names, terms=T_terms, typs=map (fn (used, typ) => change_typ' used typ) (ListPair.zip (used, typs)) } - end + val (info,ctxt) = + register false + handle (ERROR _) => + register true + | Match => + register true + val ctxt = if simpable then calculateSimplifications ctxt T_terms term simpables else ctxt in - register false - handle (ERROR _) => - register true - | Match => - register true + (info,ctxt) end -fun proove_termination (term: term list) terms print (T_info: Function.info, lthy: local_theory) = +fun proove_termination (term: term list) terms (T_info: Function.info, lthy: local_theory) = let (* Start proving the termination *) val infos = SOME (map (Function.get_info lthy) term) handle Empty => NONE @@ -598,7 +831,7 @@ | args (a$(Const (_,T))) = args a |> (fn ar => ("uu",T)::ar) | args _ = [] val dom_vars = - terms |> hd |> get_l |> map_types (map_atyps fixTypes) + terms |> hd |> get_l |> map_types (map_atyps freeTypes) |> args |> Variable.variant_frees lthy [] val dom_args = List.foldl (fn (t,p) => HOLogic.mk_prod ((Free t),p)) (Free (hd dom_vars)) (tl dom_vars) @@ -619,25 +852,12 @@ Function.prove_termination NONE (auto_tac simp_lthy) lthy end - - (* Context for printing without showing question marks *) - val print_ctxt = lthy' - |> Config.put show_question_marks false - |> Config.put show_sorts false (* Change it for debugging *) - (* Print result if print *) - val _ = if not print then () else - let - val nms = map (fst o dest_Const) term - val typs = map (snd o dest_Const) term - in - print_timing' print_ctxt { names=nms, terms=terms, typs=typs } (info_pfunc time_info) - end in (time_info, lthy') end -fun reg_and_proove_time_func (lthy: local_theory) (term: term list) (terms: term list) print = - reg_time_func lthy term terms false - |> proove_termination term terms print +fun reg_and_proove_time_func (lthy: local_theory) (term: term list) (terms: term list) print simp = + reg_time_func lthy term terms print simp + |> proove_termination term terms fun fix_definition (Const ("Pure.eq", _) $ l $ r) = Const ("HOL.Trueprop", @{typ "bool \ prop"}) $ (Const ("HOL.eq", @{typ "bool \ bool \ bool"}) $ l $ r) @@ -676,45 +896,54 @@ (case suffix of NONE => I | SOME s => Config.put bsuffix ("_" ^ s)) ctxt end +fun check_opts [] = false + | check_opts ["no_simp"] = true + | check_opts (a::_) = error ("Option " ^ a ^ " is not defined") + (* Convert function into its timing function (called by command) *) -fun reg_time_fun_cmd (funcs, thms) (ctxt: local_theory) = +fun reg_time_fun_cmd ((opts, funcs), thms) (ctxt: local_theory) = let + val no_simp = check_opts opts val fterms = map (Syntax.read_term ctxt) funcs val ctxt = set_suffix fterms ctxt val (_, ctxt') = reg_and_proove_time_func ctxt fterms (case thms of NONE => get_terms ctxt (hd fterms) | SOME thms => thms |> Attrib.eval_thms ctxt |> List.map Thm.prop_of) - true + true (not no_simp) in ctxt' end (* Convert function into its timing function (called by command) with termination proof provided by user*) -fun reg_time_function_cmd (funcs, thms) (ctxt: local_theory) = +fun reg_time_function_cmd ((opts, funcs), thms) (ctxt: local_theory) = let + val no_simp = check_opts opts val fterms = map (Syntax.read_term ctxt) funcs val ctxt = set_suffix fterms ctxt val ctxt' = reg_time_func ctxt fterms (case thms of NONE => get_terms ctxt (hd fterms) | SOME thms => thms |> Attrib.eval_thms ctxt |> List.map Thm.prop_of) - true + true (not no_simp) |> snd in ctxt' end (* Convert function into its timing function (called by command) *) -fun reg_time_definition_cmd (funcs, thms) (ctxt: local_theory) = +fun reg_time_definition_cmd ((opts, funcs), thms) (ctxt: local_theory) = let + val no_simp = check_opts opts val fterms = map (Syntax.read_term ctxt) funcs val ctxt = set_suffix fterms ctxt val (_, ctxt') = reg_and_proove_time_func ctxt fterms (case thms of NONE => get_terms ctxt (hd fterms) |> check_definition |> map fix_definition | SOME thms => thms |> Attrib.eval_thms ctxt |> List.map Thm.prop_of) - true + true (not no_simp) in ctxt' end -val parser = (Scan.repeat1 Parse.prop) -- (Scan.option (Parse.keyword_improper "equations" -- Parse.thms1 >> snd)) - +val parser = (Parse.opt_attribs >> map (fst o Token.name_of_src)) + -- Scan.repeat1 Parse.prop + -- Scan.option (Parse.keyword_improper "equations" -- Parse.thms1 >> snd) +val _ = Toplevel.local_theory val _ = Outer_Syntax.local_theory @{command_keyword "time_fun"} "Defines runtime function of a function" (parser >> reg_time_fun_cmd) diff -r 87f173836d56 -r 503e5280ba72 src/HOL/Data_Structures/Define_Time_Function.thy --- a/src/HOL/Data_Structures/Define_Time_Function.thy Thu Oct 10 14:13:18 2024 +0200 +++ b/src/HOL/Data_Structures/Define_Time_Function.thy Sat Oct 12 12:45:29 2024 +0900 @@ -34,16 +34,6 @@ use \time_function\ accompanied by a \termination\ command. Limitation: The commands do not work properly in locales yet. -If \f\ is defined in a type class, one needs to set up a corresponding type class, say \T_f\, -that fixes a function \T_f\ of the corresponding type (ending in type \nat\). -For every instance of \f :: \\ one needs to create a corresponding instance of the class \T_f\. -Inside that instance one can now use \time_fun "f :: \"\ to define an instance of the function \T_f\. -For example, see the definition and instantiation of class \T_size\ in theory \Time_Funs\. -Note that we can just write \time_fun length\ because \length\ is an abbreviation for \size\ on type \list\. - -If \f\ has an argument of function type, the corresponding argument of \T_f\ is a pair -of that function argument \g\ and its corresponding time function \T_g\. - The pre-defined functions below are assumed to have constant running time. In fact, we make that constant 0. This does not change the asymptotic running time of user-defined functions using the diff -r 87f173836d56 -r 503e5280ba72 src/HOL/Data_Structures/Time_Funs.thy --- a/src/HOL/Data_Structures/Time_Funs.thy Thu Oct 10 14:13:18 2024 +0200 +++ b/src/HOL/Data_Structures/Time_Funs.thy Sat Oct 12 12:45:29 2024 +0900 @@ -34,20 +34,25 @@ time_fun map +lemma T_map_simps [simp,code]: + "T_map T_f [] = 1" + "T_map T_f (x # xs) = T_f x + T_map T_f xs + 1" +by (simp_all add: T_map_def) + lemma T_map_eq: "T_map T_f xs = (\x\xs. T_f x) + length xs + 1" by (induction xs) auto -lemmas [simp del] = T_map.simps +lemmas [simp del] = T_map_simps +time_fun filter -fun T_filter :: "('a \ nat) \ 'a list \ nat" where - "T_filter T_p [] = 1" -| "T_filter T_p (x # xs) = T_p x + T_filter T_p xs + 1" +lemma T_filter_simps [code]: + "T_filter T_P [] = 1" + "T_filter T_P (x # xs) = T_P x + T_filter T_P xs + 1" +by (simp_all add: T_filter_def) -lemma T_filter_eq: "T_filter T_p xs = (\x\xs. T_p x) + length xs + 1" - by (induction xs) auto - -lemmas [simp del] = T_filter.simps +lemma T_filter_eq: "T_filter T_P xs = (\x\xs. T_P x) + length xs + 1" +by (induction xs) (auto simp: T_filter_simps) time_fun nth