# HG changeset patch # User nipkow # Date 1738849609 -3600 # Node ID 50dd4fc40fcbdc6f0139f73c80113beed081035f # Parent 0aa2d1c132b2291a7a6d7a32036c07cd8737961d added time_partial_function command diff -r 0aa2d1c132b2 -r 50dd4fc40fcb src/HOL/Data_Structures/Define_Time_Function.ML --- a/src/HOL/Data_Structures/Define_Time_Function.ML Wed Feb 05 16:34:56 2025 +0000 +++ b/src/HOL/Data_Structures/Define_Time_Function.ML Thu Feb 06 14:46:49 2025 +0100 @@ -26,13 +26,21 @@ val print_timing': Proof.context -> pfunc -> pfunc -> unit val print_timing: Proof.context -> Function.info -> Function.info -> unit +type time_config = { + print: bool, + simp: bool, + partial: bool +} +datatype result = Function of Function.info | PartialFunction of thm val reg_and_proove_time_func: local_theory -> term list -> term list - -> bool -> bool -> Function.info * local_theory + -> time_config -> result * local_theory val reg_time_func: local_theory -> term list -> term list - -> bool -> bool -> Function.info * local_theory + -> time_config -> result * local_theory + val time_dom_tac: Proof.context -> thm -> thm list -> int -> tactic + end structure Timing_Functions : TIMING_FUNCTIONS = @@ -43,9 +51,6 @@ (* Configure config variable to adjust the suffix *) val bsuffix = Attrib.setup_config_string @{binding "time_suffix"} (K "") -(* some default values to build terms easier *) -val zero = Const (@{const_name "Groups.zero"}, HOLogic.natT) -val one = Const (@{const_name "Groups.one"}, HOLogic.natT) (* Extracts terms from function info *) fun terms_of_info (info: Function.info) = map Thm.prop_of (case #simps info of SOME s => s @@ -132,15 +137,6 @@ (* returns true if it's a let term *) fun is_let (Const (@{const_name "HOL.Let"},_)) = true | is_let _ = false -(* change type of original function to new type (_ \ ... \ _ to _ \ ... \ nat) - and replace all function arguments f with (t*T_f) if used *) -fun change_typ' used (Type ("fun", [T1, T2])) = - Type ("fun", [check_for_fun' (used 0) T1, change_typ' (fn i => used (i+1)) T2]) - | change_typ' _ _ = HOLogic.natT -and check_for_fun' true (f as Type ("fun", [_,_])) = HOLogic.mk_prodT (f, change_typ' (K false) f) - | check_for_fun' false (f as Type ("fun", [_,_])) = change_typ' (K true) f - | check_for_fun' _ t = t -val change_typ = change_typ' (K true) (* Convert string name of function to its timing equivalent *) fun fun_name_to_time' ctxt s second name = let @@ -186,28 +182,6 @@ fun const_comp (Const (nm,T)) (Const (nm',T')) = nm = nm' andalso typ_comp T T' | const_comp _ _ = false -fun time_term ctxt s (Const (nm,T)) = -let - val T_nm = fun_name_to_time ctxt s nm - val T_T = change_typ T -in -(SOME (Syntax.check_term ctxt (Const (T_nm,T_T)))) - handle (ERROR _) => - case Syntax.read_term ctxt (Long_Name.base_name T_nm) - of (Const (T_nm,T_T)) => - let - fun col_Used i (Type ("fun", [Type ("fun", _), Ts])) (Type ("fun", [T', Ts'])) = - (if is_Used T' then [i] else []) @ col_Used (i+1) Ts Ts' - | col_Used i (Type ("fun", [_, Ts])) (Type ("fun", [_, Ts'])) = col_Used (i+1) Ts Ts' - | col_Used _ _ _ = [] - in - SOME (Const (T_nm,change_typ' (contains (col_Used 0 T T_T)) T)) - end - | _ => error ("Timing function of " ^ nm ^ " is not defined") -end - | time_term _ _ _ = error "Internal error: No valid function given" - - type 'a wctxt = { ctxt: local_theory, origins: term list, @@ -262,6 +236,23 @@ | freeTerms t = t fun freeTypes (TVar ((t, _), T)) = TFree (t, T) | freeTypes t = t +fun fix_definition (Const ("Pure.eq", _) $ l $ r) = HOLogic.mk_Trueprop (HOLogic.mk_eq (l,r)) + | fix_definition t = t +fun check_definition [t] = [t] + | check_definition _ = error "Only a single definition is allowed" +fun get_terms theory (term: term) = +let + val equations = Spec_Rules.retrieve theory term + |> map #rules + |> map (map Thm.prop_of) + handle Empty => error "Function or terms of function not found" +in + equations + |> map (map fix_definition) + |> filter (List.exists + (fn t => typ_comp (t |> get_l |> strip_comb |> fst |> dest_Const |> snd) (term |> strip_comb |> fst |> dest_Const |> snd))) + |> hd +end fun fixCasecCases _ [t] = [t] | fixCasecCases wctxt (t::ts) = @@ -348,193 +339,6 @@ }) o get_r fun is_rec ctxt (term: term list) = List.foldr (or (find_rec ctxt term)) false -(* 3. Convert equations *) -(* Some Helper *) -val plusTyp = @{typ "nat => nat => nat"} -fun plus (SOME a) (SOME b) = SOME (Const (@{const_name "Groups.plus"}, plusTyp) $ a $ b) - | plus (SOME a) NONE = SOME a - | plus NONE (SOME b) = SOME b - | plus NONE NONE = NONE -fun opt_term NONE = HOLogic.zero - | opt_term (SOME t) = t -fun use_origin (Free (nm, T as Type ("fun",_))) = HOLogic.mk_fst (Free (nm,HOLogic.mk_prodT (T, change_typ T))) - | use_origin t = t - -(* Conversion of function term *) -fun fun_to_time' ctxt (origin: term list) second (func as Const (nm,T)) = -let - val origin' = map (fst o strip_comb) origin -in - if contains' const_comp origin' func then SOME (Free (func |> Term.term_name |> fun_name_to_time' ctxt true second, change_typ T)) else - if Zero_Funcs.is_zero (Proof_Context.theory_of ctxt) (nm,T) then NONE else - time_term ctxt false func -end - | fun_to_time' _ _ _ (Free (nm,T)) = - SOME (HOLogic.mk_snd (Free (nm,HOLogic.mk_prodT (T,change_typ T)))) - | fun_to_time' _ _ _ _ = error "Internal error: invalid function to convert" -fun fun_to_time context origin func = fun_to_time' context origin false func - -(* Convert arguments of left side of a term *) -fun conv_arg _ (Free (nm,T as Type("fun",_))) = - Free (nm, HOLogic.mk_prodT (T, change_typ' (K false) T)) - | conv_arg _ x = x -fun conv_args ctxt = map (conv_arg ctxt) - -(* Handle function calls *) -fun build_zero (Type ("fun", [T, R])) = Abs ("uu", T, build_zero R) - | build_zero _ = zero -fun funcc_use_origin (Free (nm, T as Type ("fun",_))) = - HOLogic.mk_fst (Free (nm,HOLogic.mk_prodT (T, change_typ T))) - | funcc_use_origin t = t -fun funcc_conv_arg _ _ (t as (_ $ _)) = map_aterms funcc_use_origin t - | funcc_conv_arg _ u (Free (nm, T as Type ("fun",_))) = - if u then Free (nm, HOLogic.mk_prodT (T, change_typ T)) - else HOLogic.mk_snd (Free (nm,HOLogic.mk_prodT (T,change_typ T))) - | funcc_conv_arg wctxt true (f as Const (_,Type ("fun",_))) = - HOLogic.mk_prod (f, funcc_conv_arg wctxt false f) - | funcc_conv_arg wctxt false (f as Const (_,T as Type ("fun",_))) = - Option.getOpt (fun_to_time (#ctxt wctxt) (#origins wctxt) f, build_zero T) - | funcc_conv_arg wctxt false (f as Abs _) = - f - |> Term.strip_abs_eta ((length o fst o strip_type o type_of) f) - ||> #f wctxt ||> opt_term - |> list_abs - | funcc_conv_arg wctxt true (f as Abs _) = - let - val f' = f - |> Term.strip_abs_eta ((length o fst o strip_type o type_of) f) - ||> map_aterms funcc_use_origin - |> list_abs - in - HOLogic.mk_prod (f', funcc_conv_arg wctxt false f) - end - | funcc_conv_arg _ _ t = t - -fun funcc_conv_args _ _ [] = [] - | funcc_conv_args wctxt (Type ("fun", [t, ts])) (a::args) = - funcc_conv_arg wctxt (is_Used t) a :: funcc_conv_args wctxt ts args - | funcc_conv_args _ _ _ = error "Internal error: Non matching type" -fun funcc wctxt func args = -let - fun get_T (Free (_,T)) = T - | get_T (Const (_,T)) = T - | get_T (_ $ (Free (_,Type (_, [_, T])))) = T (* Case of snd was constructed *) - | get_T _ = error "Internal error: Forgotten type" -in - List.foldr (I #-> plus) - (case fun_to_time (#ctxt wctxt) (#origins wctxt) func (* add case for abs *) - of SOME t => SOME (list_comb (t, funcc_conv_args wctxt (get_T t) args)) - | NONE => NONE) - (map (#f wctxt) args) -end - -(* Handle case terms *) -fun casecIsCase (Type (n1, [_,Type (n2, _)])) = (n1 = "fun" andalso n2 = "fun") - | casecIsCase _ = false -fun casecLastTyp (Type (n, [T1,T2])) = Type (n, [T1, change_typ T2]) - | casecLastTyp _ = error "Internal error: Invalid case type" -fun casecTyp (Type (n, [T1, T2])) = - Type (n, [change_typ T1, (if casecIsCase T2 then casecTyp else casecLastTyp) T2]) - | casecTyp _ = error "Internal error: Invalid case type" -fun casecAbs f (Abs (v,Ta,t)) = (case casecAbs f (subst_bound (Free (v,Ta), t)) - of (nconst,t) => (nconst,absfree (v,Ta) t)) - | casecAbs f t = (case f t of NONE => (false,HOLogic.zero) | SOME t => (true,t)) -fun casecArgs _ [t] = (false, [map_aterms use_origin t]) - | casecArgs f (t::ar) = - (case casecAbs f t of (nconst, tt) => - casecArgs f ar ||> (fn ar => tt :: ar) |>> (if nconst then K true else I)) - | casecArgs _ _ = error "Internal error: Invalid case term" -fun casec wctxt (Const (t,T)) args = - if not (casecIsCase T) then error "Internal error: Invalid case type" else - let val (nconst, args') = casecArgs (#f wctxt) args in - plus - ((#f wctxt) (List.last args)) - (if nconst then - SOME (list_comb (Const (t,casecTyp T), args')) - else NONE) - end - | casec _ _ _ = error "Internal error: Invalid case term" - -(* Handle if terms -> drop the term if true and false terms are zero *) -fun ifc wctxt _ cond tt ft = - let - val f = #f wctxt - val rcond = map_aterms use_origin cond - val tt = f tt - val ft = f ft - in - plus (f cond) (case (tt,ft) of (NONE, NONE) => NONE | _ => - if tt = ft then tt else - (SOME ((Const (@{const_name "HOL.If"}, @{typ "bool \ nat \ nat \ nat"})) $ rcond $ (opt_term tt) $ (opt_term ft)))) - end - -fun letc_lambda wctxt T (t as Abs _) = - HOLogic.mk_prod (map_aterms use_origin t, - Term.strip_abs_eta (strip_type T |> fst |> length) t ||> #f wctxt ||> opt_term |> list_abs) - | letc_lambda _ _ t = map_aterms use_origin t -fun letc wctxt expT exp ([(nm,_)]) t = - plus (#f wctxt exp) - (case #f wctxt t of SOME t' => - (if Term.used_free nm t' - then - let - val exp' = letc_lambda wctxt expT exp - val t' = list_abs ([(nm,fastype_of exp')], t') - in - Const (@{const_name "HOL.Let"}, [fastype_of exp', fastype_of t'] ---> HOLogic.natT) $ exp' $ t' - end - else t') |> SOME - | NONE => NONE) - | letc _ _ _ _ _ = error "Unknown let state" - -fun constc _ (Const ("HOL.undefined", _)) = SOME (Const ("HOL.undefined", @{typ "nat"})) - | constc _ _ = NONE - -(* The converter for timing functions given to the walker *) -val converter : term option converter = { - constc = constc, - funcc = funcc, - ifc = ifc, - casec = casec, - letc = letc - } -fun top_converter is_rec _ _ = opt_term o (fn exp => plus exp (if is_rec then SOME one else NONE)) - -(* Use converter to convert right side of a term *) -fun to_time ctxt origin is_rec term = - top_converter is_rec ctxt origin (walk ctxt origin converter term) - -(* Converts a term to its running time version *) -fun convert_term ctxt (origin: term list) is_rec (pT $ (Const (eqN, _) $ l $ r)) = -let - val (l_const, l_params) = strip_comb l -in - pT - $ (Const (eqN, @{typ "nat \ nat \ bool"}) - $ (list_comb (l_const |> fun_to_time ctxt origin |> Option.valOf, l_params |> conv_args ctxt)) - $ (to_time ctxt origin is_rec r)) -end - | convert_term _ _ _ _ = error "Internal error: invalid term to convert" - -(* 3.5 Support for locales *) -fun replaceFstSndFree ctxt (origin: term list) (rfst: term -> term) (rsnd: term -> term) = - (walk ctxt origin { - funcc = fn wctxt => fn t => fn args => - case args of - (f as Free _)::args => - (case t of - Const ("Product_Type.prod.fst", _) => - list_comb (rfst (t $ f), map (#f wctxt) args) - | Const ("Product_Type.prod.snd", _) => - list_comb (rsnd (t $ f), map (#f wctxt) args) - | t => list_comb (t, map (#f wctxt) (f :: args))) - | args => list_comb (t, map (#f wctxt) args), - constc = Iconst, - ifc = Iif, - casec = Icase, - letc = Ilet - }) - (* 4. Tactic to prove "f_dom n" *) fun time_dom_tac ctxt induct_rule domintros = (Induction.induction_tac ctxt true [] [[]] [] (SOME [induct_rule]) [] @@ -542,211 +346,445 @@ (if i <= length domintros then [Metis_Tactic.metis_tac [] ATP_Problem_Generate.combsN ctxt [List.nth (domintros, i-1)]] else []) @ [Metis_Tactic.metis_tac [] ATP_Problem_Generate.combsN ctxt domintros]) i))) - -fun fix_definition (Const ("Pure.eq", _) $ l $ r) = HOLogic.mk_Trueprop (HOLogic.mk_eq (l,r)) - | fix_definition t = t -fun check_definition [t] = [t] - | check_definition _ = error "Only a single definition is allowed" -fun get_terms theory (term: term) = -let - val equations = Spec_Rules.retrieve theory term - |> map #rules - |> map (map Thm.prop_of) - handle Empty => error "Function or terms of function not found" -in - equations - |> map (map fix_definition) - |> filter (List.exists - (fn t => typ_comp (t |> get_l |> strip_comb |> fst |> dest_Const |> snd) (term |> strip_comb |> fst |> dest_Const |> snd))) - |> hd -end - -(* 5. Check for higher-order function if original function is used \ find simplifications *) -fun find_used' T_t = -let - val (T_ident, T_args) = strip_comb (get_l T_t) +(* Register timing function of a given function *) +type time_config = { + print: bool, + simp: bool, + partial: bool +} +datatype result = Function of Function.info | PartialFunction of thm +fun reg_time_func (lthy: local_theory) (term: term list) (terms: term list) (config: time_config) = + let + (* some default values to build terms easier *) + (* Const (@{const_name "Groups.zero"}, HOLogic.natT) *) + val zero = if #partial config then @{term "Some (0::nat)"} else HOLogic.zero + val one = Const (@{const_name "Groups.one"}, HOLogic.natT) + val natOptT = @{typ "nat option"} + val finT = if #partial config then natOptT else HOLogic.natT + val some = @{term "Some::nat \ nat option"} - fun filter_passed [] = [] - | filter_passed ((f as Free (_, Type ("Product_Type.prod",[Type ("fun",_), Type ("fun", _)])))::args) = - f :: filter_passed args - | filter_passed (_::args) = filter_passed args - val frees = (walk @{context} [] { - funcc = (fn wctxt => fn t => fn args => - (case t of (Const ("Product_Type.prod.snd", _)) => [] - | _ => (if t = T_ident then [] else filter_passed args) - @ List.foldr (fn (l,r) => (#f wctxt) l @ r) [] args)), - constc = (K o K) [], - ifc = (fn wctxt => fn _ => fn cond => fn tt => fn tf => (#f wctxt) cond @ (#f wctxt) tt @ (#f wctxt) tf), - casec = (fn wctxt => fn _ => fn cs => List.foldr (fn (l,r) => (#f wctxt) l @ r) [] cs), - letc = (fn wctxt => fn _ => fn exp => fn _ => fn t => (#f wctxt) exp @ (#f wctxt) t) - }) (get_r T_t) - fun build _ [] = [] - | build i (a::args) = - (if contains frees a then [(T_ident,i)] else []) @ build (i+1) args -in - build 0 T_args -end -fun find_simplifyble ctxt term terms = -let - val used = - terms - |> List.map find_used' - |> List.foldr (op @) [] - val change = - Option.valOf o fun_to_time ctxt term - fun detect t i (Type ("fun",_)::args) = - (if contains used (change t,i) then [] else [i]) @ detect t (i+1) args - | detect t i (_::args) = detect t (i+1) args - | detect _ _ [] = [] -in - map (fn t => t |> type_of |> strip_type |> fst |> detect t 0) term -end + (* change type of original function to new type (_ \ ... \ _ to _ \ ... \ nat) + and replace all function arguments f with (t*T_f) if used *) + fun change_typ' used (Type ("fun", [T1, T2])) = + Type ("fun", [check_for_fun' (used 0) T1, change_typ' (fn i => used (i+1)) T2]) + | change_typ' _ _ = finT + and check_for_fun' true (f as Type ("fun", [_,_])) = HOLogic.mk_prodT (f, change_typ' (K false) f) + | check_for_fun' false (f as Type ("fun", [_,_])) = change_typ' (K true) f + | check_for_fun' _ t = t + val change_typ = change_typ' (K true) + fun time_term ctxt s (Const (nm,T)) = + let + val T_nm = fun_name_to_time ctxt s nm + val T_T = change_typ T + in + (SOME (Syntax.check_term ctxt (Const (T_nm,T_T)))) + handle (ERROR _) => + case Syntax.read_term ctxt (Long_Name.base_name T_nm) + of (Const (T_nm,T_T)) => + let + fun col_Used i (Type ("fun", [Type ("fun", _), Ts])) (Type ("fun", [T', Ts'])) = + (if is_Used T' then [i] else []) @ col_Used (i+1) Ts Ts' + | col_Used i (Type ("fun", [_, Ts])) (Type ("fun", [_, Ts'])) = col_Used (i+1) Ts Ts' + | col_Used _ _ _ = [] + val binderT = change_typ' (contains (col_Used 0 T T_T)) T |> Term.binder_types + val finT = Term.body_type T_T + in + SOME (Const (T_nm, binderT ---> finT)) + end + | _ => error ("Timing function of " ^ nm ^ " is not defined") + end + | time_term _ _ _ = error "Internal error: No valid function given" -fun define_simp' term simplifyable ctxt = -let - val base_name = case Named_Target.locale_of ctxt of - NONE => ctxt |> Proof_Context.theory_of |> Context.theory_base_name - | SOME nm => nm - - val orig_name = term |> dest_Const_name |> split_name |> List.last - val red_name = fun_name_to_time ctxt false orig_name - val name = fun_name_to_time' ctxt true true orig_name - val full_name = base_name ^ "." ^ name - val def_name = red_name ^ "_def" - val def = Binding.name def_name - - val canon = Syntax.read_term (Local_Theory.exit ctxt) name |> strip_comb - val canonFrees = canon |> snd - val canonType = canon |> fst |> dest_Const_type |> strip_type |> fst |> take (length canonFrees) + fun opt_term NONE = zero + | opt_term (SOME t) = t + fun use_origin (Free (nm, T as Type ("fun",_))) = HOLogic.mk_fst (Free (nm,HOLogic.mk_prodT (T, change_typ T))) + | use_origin t = t + + (* Conversion of function term *) + fun fun_to_time' ctxt (origin: term list) second (func as Const (nm,T)) = + let + val origin' = map (fst o strip_comb) origin + in + if contains' const_comp origin' func then SOME (Free (func |> Term.term_name |> fun_name_to_time' ctxt true second, change_typ T)) else + if Zero_Funcs.is_zero (Proof_Context.theory_of ctxt) (nm,T) then NONE else + time_term ctxt false func + end + | fun_to_time' _ _ _ (Free (nm,T)) = + SOME (HOLogic.mk_snd (Free (nm,HOLogic.mk_prodT (T,change_typ' (K true) T)))) + | fun_to_time' _ _ _ _ = error "Internal error: invalid function to convert" + fun fun_to_time context origin func = fun_to_time' context origin false func + + (* Convert arguments of left side of a term *) + fun conv_arg _ (Free (nm,T as Type("fun",_))) = + Free (nm, HOLogic.mk_prodT (T, change_typ' (K false) T)) + | conv_arg _ x = x + fun conv_args ctxt = map (conv_arg ctxt) + + (* 3. Convert equations *) + (* Some Helper *) + val plusTyp = @{typ "nat => nat => nat"} + fun plus (SOME a) (SOME b) = SOME ((if #partial config then @{term part_add} else Const (@{const_name "Groups.plus"}, plusTyp)) $ a $ b) + | plus (SOME a) NONE = SOME a + | plus NONE (SOME b) = SOME b + | plus NONE NONE = NONE + (* Partial helper *) + val OPTION_BIND = @{term "Option.bind::nat option \ (nat \ nat option) \ nat option"} + fun OPTION_ABS_SUC args = Term.absfree ("_uu", @{typ nat}) + (List.foldr (uncurry plus) + (SOME (some $ HOLogic.mk_Suc (Free ("_uu", @{typ nat})))) args |> Option.valOf) + fun build_option_bind term args = + OPTION_BIND $ term $ OPTION_ABS_SUC args + fun WRAP_FUNCTION t = + if (Term.head_of t |> Term.fastype_of |> Term.body_type) = finT + then t + else if #partial config + then some $ t + else @{term "the::nat option \ nat"} $ t - val types = term |> dest_Const_type |> strip_type |> fst - val vars = Variable.variant_fixes (map (K "") types) ctxt |> fst - fun l_typs' i ((T as (Type ("fun",_)))::types) = - (if contains simplifyable i - then change_typ T - else HOLogic.mk_prodT (T,change_typ T)) - :: l_typs' (i+1) types - | l_typs' i (T::types) = T :: l_typs' (i+1) types - | l_typs' _ [] = [] - val l_typs = l_typs' 0 types - val lhs = - List.foldl (fn ((v,T),t) => t $ Free (v,T)) (Free (red_name,l_typs ---> HOLogic.natT)) (ListPair.zip (vars,l_typs)) - fun fixType (TFree _) = HOLogic.natT - | fixType T = T - fun fixUnspecified T = T |> strip_type ||> fixType |> (op --->) - fun r_terms' i (v::vars) ((T as (Type ("fun",_)))::types) = - (if contains simplifyable i - then HOLogic.mk_prod (Const ("HOL.undefined", fixUnspecified T), Free (v,change_typ T)) - else Free (v,HOLogic.mk_prodT (T,change_typ T))) - :: r_terms' (i+1) vars types - | r_terms' i (v::vars) (T::types) = Free (v,T) :: r_terms' (i+1) vars types - | r_terms' _ _ _ = [] - val r_terms = r_terms' 0 vars types - val full_type = (r_terms |> map (type_of) ---> HOLogic.natT) - val full = list_comb (Const (full_name,canonType ---> full_type), canonFrees) - val rhs = list_comb (full, r_terms) - val eq = (lhs, rhs) |> HOLogic.mk_eq |> HOLogic.mk_Trueprop - val _ = Pretty.writeln (Pretty.block [Pretty.str "Defining simplified version:\n", - Syntax.pretty_term ctxt eq]) - - val (_, ctxt') = Specification.definition NONE [] [] ((def, []), eq) ctxt - -in - ((def_name, orig_name), ctxt') -end -fun define_simp simpables ctxt = -let - fun cond ((term,simplifyable),(defs,ctxt)) = - define_simp' term simplifyable ctxt |>> (fn def => def :: defs) -in - List.foldr cond ([], ctxt) simpables -end - - -fun replace from to = - map (map_aterms (fn t => if t = from then to else t)) -fun replaceAll [] = I - | replaceAll ((from,to)::xs) = replaceAll xs o replace from to -fun calculateSimplifications ctxt T_terms term simpables = -let - (* Show where a simplification can take place *) - fun reportReductions (t,(i::is)) = - (Pretty.writeln (Pretty.str - ((Term.term_name t |> fun_name_to_time ctxt true) - ^ " can be simplified because only the time-function component of parameter " - ^ (Int.toString (i + 1)) ^ " is used. ")); - reportReductions (t,is)) - | reportReductions (_,[]) = () - val _ = simpables - |> map reportReductions - - (* Register definitions for simplified function *) - val (reds, ctxt) = define_simp simpables ctxt + (* Handle function calls *) + fun build_zero (Type ("fun", [T, R])) = Abs ("uu", T, build_zero R) + | build_zero _ = zero + fun funcc_use_origin (Free (nm, T as Type ("fun",_))) = + HOLogic.mk_fst (Free (nm,HOLogic.mk_prodT (T, change_typ T))) + | funcc_use_origin t = t + fun funcc_conv_arg _ _ (t as (_ $ _)) = map_aterms funcc_use_origin t + | funcc_conv_arg _ u (Free (nm, T as Type ("fun",_))) = + if u then Free (nm, HOLogic.mk_prodT (T, change_typ T)) + else HOLogic.mk_snd (Free (nm,HOLogic.mk_prodT (T,change_typ T))) + | funcc_conv_arg wctxt true (f as Const (_,Type ("fun",_))) = + HOLogic.mk_prod (f, funcc_conv_arg wctxt false f) + | funcc_conv_arg wctxt false (f as Const (_,T as Type ("fun",_))) = + Option.getOpt (fun_to_time (#ctxt wctxt) (#origins wctxt) f, build_zero T) + | funcc_conv_arg wctxt false (f as Abs _) = + f + |> Term.strip_abs_eta ((length o fst o strip_type o type_of) f) + ||> #f wctxt ||> opt_term + |> list_abs + | funcc_conv_arg wctxt true (f as Abs _) = + let + val f' = f + |> Term.strip_abs_eta ((length o fst o strip_type o type_of) f) + ||> map_aterms funcc_use_origin + |> list_abs + in + HOLogic.mk_prod (f', funcc_conv_arg wctxt false f) + end + | funcc_conv_arg _ _ t = t + + fun funcc_conv_args _ _ [] = [] + | funcc_conv_args wctxt (Type ("fun", [t, ts])) (a::args) = + funcc_conv_arg wctxt (is_Used t) a :: funcc_conv_args wctxt ts args + | funcc_conv_args _ _ _ = error "Internal error: Non matching type" + fun funcc wctxt func args = + let + fun get_T (Free (_,T)) = T + | get_T (Const (_,T)) = T + | get_T (_ $ (Free (_,Type (_, [_, T])))) = T (* Case of snd was constructed *) + | get_T _ = error "Internal error: Forgotten type" + val func = (case fun_to_time (#ctxt wctxt) (#origins wctxt) func + of SOME t => SOME (WRAP_FUNCTION (list_comb (t, funcc_conv_args wctxt (get_T t) args))) + | NONE => NONE) + val args = (map (#f wctxt) args) + in + (if not (#partial config) orelse func = NONE + then List.foldr (uncurry plus) func args + else build_option_bind (Option.valOf func) args |> SOME) + end + + (* Handle case terms *) + fun casecIsCase (Type (n1, [_,Type (n2, _)])) = (n1 = "fun" andalso n2 = "fun") + | casecIsCase _ = false + fun casecLastTyp (Type (n, [T1,T2])) = Type (n, [T1, change_typ T2]) + | casecLastTyp _ = error "Internal error: Invalid case type" + fun casecTyp (Type (n, [T1, T2])) = + Type (n, [change_typ T1, (if casecIsCase T2 then casecTyp else casecLastTyp) T2]) + | casecTyp _ = error "Internal error: Invalid case type" + fun casecAbs f (Abs (v,Ta,t)) = (case casecAbs f (subst_bound (Free (v,Ta), t)) + of (nconst,t) => (nconst,absfree (v,Ta) t)) + | casecAbs f t = (case f t of NONE => (false, opt_term NONE) | SOME t => (true,t)) + fun casecArgs _ [t] = (false, [map_aterms use_origin t]) + | casecArgs f (t::ar) = + (case casecAbs f t of (nconst, tt) => + casecArgs f ar ||> (fn ar => tt :: ar) |>> (if nconst then K true else I)) + | casecArgs _ _ = error "Internal error: Invalid case term" + fun casec wctxt (Const (t,T)) args = + if not (casecIsCase T) then error "Internal error: Invalid case type" else + let val (nconst, args') = casecArgs (#f wctxt) args in + plus + ((#f wctxt) (List.last args)) + (if nconst then + SOME (list_comb (Const (t,casecTyp T), args')) + else NONE) + end + | casec _ _ _ = error "Internal error: Invalid case term" + + (* Handle if terms -> drop the term if true and false terms are zero *) + fun ifc wctxt _ cond tt ft = + let + val f = #f wctxt + val rcond = map_aterms use_origin cond + val tt = f tt + val ft = f ft + in + plus (f cond) (case (tt,ft) of (NONE, NONE) => NONE | _ => + if tt = ft then tt else + (SOME ((Const (@{const_name "HOL.If"}, @{typ "bool"} --> finT --> finT --> finT)) $ rcond + $ (opt_term tt) $ (opt_term ft)))) + end + + fun letc_lambda wctxt T (t as Abs _) = + HOLogic.mk_prod (map_aterms use_origin t, + Term.strip_abs_eta (strip_type T |> fst |> length) t ||> #f wctxt ||> opt_term ||> map_types change_typ |> list_abs) + | letc_lambda _ _ t = map_aterms use_origin t + fun letc wctxt expT exp ([(nm,_)]) t = + plus (#f wctxt exp) + (case #f wctxt t of SOME t' => + (if Term.used_free nm t' + then + let + val exp' = letc_lambda wctxt expT exp + val t' = list_abs ([(nm,fastype_of exp')], t') + in + Const (@{const_name "HOL.Let"}, [fastype_of exp', fastype_of t'] ---> finT) $ exp' $ t' + end + else t') |> SOME + | NONE => NONE) + | letc _ _ _ _ _ = error "Unknown let state" + + fun constc _ (Const ("HOL.undefined", _)) = SOME (Const ("HOL.undefined", finT)) + | constc _ _ = NONE + + (* The converter for timing functions given to the walker *) + val converter : term option converter = { + constc = constc, + funcc = funcc, + ifc = ifc, + casec = casec, + letc = letc + } + fun top_converter is_rec _ _ = + if #partial config + then (fn t => Option.getOpt (t, zero)) + else (opt_term o (fn exp => plus exp (if is_rec then SOME one else NONE))) + + (* Use converter to convert right side of a term *) + fun to_time ctxt origin is_rec term = + top_converter is_rec ctxt origin (walk ctxt origin converter term) + + (* Converts a term to its running time version *) + fun convert_term ctxt (origin: term list) is_rec (pT $ (Const (eqN, _) $ l $ r)) = + let + val (l_const, l_params) = strip_comb l + in + pT + $ (Const (eqN, finT --> finT --> @{typ "bool"}) + $ (list_comb (l_const |> fun_to_time ctxt origin |> Option.valOf, l_params |> conv_args ctxt)) + $ (to_time ctxt origin is_rec r)) + end + | convert_term _ _ _ _ = error "Internal error: invalid term to convert" + + (* 3.5 Support for locales *) + fun replaceFstSndFree ctxt (origin: term list) (rfst: term -> term) (rsnd: term -> term) = + (walk ctxt origin { + funcc = fn wctxt => fn t => fn args => + case args of + (f as Free _)::args => + (case t of + Const ("Product_Type.prod.fst", _) => + list_comb (rfst (t $ f), map (#f wctxt) args) + | Const ("Product_Type.prod.snd", _) => + list_comb (rsnd (t $ f), map (#f wctxt) args) + | t => list_comb (t, map (#f wctxt) (f :: args))) + | args => list_comb (t, map (#f wctxt) args), + constc = Iconst, + ifc = Iif, + casec = Icase, + letc = Ilet + }) + + (* 5. Check for higher-order function if original function is used \ find simplifications *) + fun find_used' T_t = + let + val (T_ident, T_args) = strip_comb (get_l T_t) + + fun filter_passed [] = [] + | filter_passed ((f as Free (_, Type ("Product_Type.prod",[Type ("fun",_), Type ("fun", _)])))::args) = + f :: filter_passed args + | filter_passed (_::args) = filter_passed args + val frees = (walk lthy [] { + funcc = (fn wctxt => fn t => fn args => + (case t of (Const ("Product_Type.prod.snd", _)) => [] + | _ => (if t = T_ident then [] else filter_passed args) + @ List.foldr (fn (l,r) => (#f wctxt) l @ r) [] args)), + constc = (K o K) [], + ifc = (fn wctxt => fn _ => fn cond => fn tt => fn tf => (#f wctxt) cond @ (#f wctxt) tt @ (#f wctxt) tf), + casec = (fn wctxt => fn _ => fn cs => List.foldr (fn (l,r) => (#f wctxt) l @ r) [] cs), + letc = (fn wctxt => fn _ => fn exp => fn _ => fn t => (#f wctxt) exp @ (#f wctxt) t) + }) (get_r T_t) + fun build _ [] = [] + | build i (a::args) = + (if contains frees a then [(T_ident,i)] else []) @ build (i+1) args + in + build 0 T_args + end + fun find_simplifyble ctxt term terms = + let + val used = + terms + |> List.map find_used' + |> List.foldr (op @) [] + val change = + Option.valOf o fun_to_time ctxt term + fun detect t i (Type ("fun",_)::args) = + (if contains used (change t,i) then [] else [i]) @ detect t (i+1) args + | detect t i (_::args) = detect t (i+1) args + | detect _ _ [] = [] + in + map (fn t => t |> type_of |> strip_type |> fst |> detect t 0) term + end + + fun define_simp' term simplifyable ctxt = + let + val base_name = case Named_Target.locale_of ctxt of + NONE => ctxt |> Proof_Context.theory_of |> Context.theory_base_name + | SOME nm => nm + + val orig_name = term |> dest_Const_name |> split_name |> List.last + val red_name = fun_name_to_time ctxt false orig_name + val name = fun_name_to_time' ctxt true true orig_name + val full_name = base_name ^ "." ^ name + val def_name = red_name ^ "_def" + val def = Binding.name def_name + + val canon = Syntax.read_term (Local_Theory.exit ctxt) name |> strip_comb + val canonFrees = canon |> snd + val canonType = canon |> fst |> dest_Const_type |> strip_type |> fst |> take (length canonFrees) + + val types = term |> dest_Const_type |> strip_type |> fst + val vars = Variable.variant_fixes (map (K "") types) ctxt |> fst + fun l_typs' i ((T as (Type ("fun",_)))::types) = + (if contains simplifyable i + then change_typ T + else HOLogic.mk_prodT (T,change_typ T)) + :: l_typs' (i+1) types + | l_typs' i (T::types) = T :: l_typs' (i+1) types + | l_typs' _ [] = [] + val l_typs = l_typs' 0 types + val lhs = + List.foldl (fn ((v,T),t) => t $ Free (v,T)) (Free (red_name,l_typs ---> HOLogic.natT)) (ListPair.zip (vars,l_typs)) + fun fixType (TFree _) = HOLogic.natT + | fixType T = T + fun fixUnspecified T = T |> strip_type ||> fixType |> (op --->) + fun r_terms' i (v::vars) ((T as (Type ("fun",_)))::types) = + (if contains simplifyable i + then HOLogic.mk_prod (Const ("HOL.undefined", fixUnspecified T), Free (v,change_typ T)) + else Free (v,HOLogic.mk_prodT (T,change_typ T))) + :: r_terms' (i+1) vars types + | r_terms' i (v::vars) (T::types) = Free (v,T) :: r_terms' (i+1) vars types + | r_terms' _ _ _ = [] + val r_terms = r_terms' 0 vars types + val full_type = (r_terms |> map (type_of) ---> HOLogic.natT) + val full = list_comb (Const (full_name,canonType ---> full_type), canonFrees) + val rhs = list_comb (full, r_terms) + val eq = (lhs, rhs) |> HOLogic.mk_eq |> HOLogic.mk_Trueprop + val _ = Pretty.writeln (Pretty.block [Pretty.str "Defining simplified version:\n", + Syntax.pretty_term ctxt eq]) + + val (_, ctxt') = Specification.definition NONE [] [] ((def, []), eq) ctxt + + in + ((def_name, orig_name), ctxt') + end + fun define_simp simpables ctxt = + let + fun cond ((term,simplifyable),(defs,ctxt)) = + define_simp' term simplifyable ctxt |>> (fn def => def :: defs) + in + List.foldr cond ([], ctxt) simpables + end + + + fun replace from to = + map (map_aterms (fn t => if t = from then to else t)) + fun replaceAll [] = I + | replaceAll ((from,to)::xs) = replaceAll xs o replace from to + fun calculateSimplifications ctxt T_terms term simpables = + let + (* Show where a simplification can take place *) + fun reportReductions (t,(i::is)) = + (Pretty.writeln (Pretty.str + ((Term.term_name t |> fun_name_to_time ctxt true) + ^ " can be simplified because only the time-function component of parameter " + ^ (Int.toString (i + 1)) ^ " is used. ")); + reportReductions (t,is)) + | reportReductions (_,[]) = () + val _ = simpables + |> map reportReductions + + (* Register definitions for simplified function *) + val (reds, ctxt) = define_simp simpables ctxt + + fun genRetype (Const (nm,T),is) = + let + val T_name = fun_name_to_time ctxt true nm |> split_name |> List.last + val from = Free (T_name,change_typ T) + val to = Free (T_name,change_typ' (not o contains is) T) + in + (from,to) + end + | genRetype _ = error "Internal error: invalid term" + val retyping = map genRetype simpables + + fun replaceArgs (pT $ (eq $ l $ r)) = + let + val (t,params) = strip_comb l + fun match (Const (f_nm,_),_) = + (fun_name_to_time ctxt true f_nm |> Long_Name.base_name) = (dest_Free t |> fst) + | match _ = false + val simps = List.find match simpables |> Option.valOf |> snd + + fun dest_Prod_snd (Free (nm, Type (_, [_, T2]))) = + Free (fun_name_to_time ctxt false nm, T2) + | dest_Prod_snd _ = error "Internal error: Argument is not a pair" + fun rep _ [] = ([],[]) + | rep i (x::xs) = + let + val (rs,args) = rep (i+1) xs + in + if contains simps i + then (x::rs,dest_Prod_snd x::args) + else (rs,x::args) + end + val (rs,params) = rep 0 params + fun fFst _ = error "Internal error: Invalid term to simplify" + fun fSnd (t as (Const _ $ f)) = + (if contains rs f + then dest_Prod_snd f + else t) + | fSnd t = t + in + (pT $ (eq + $ (list_comb (t,params)) + $ (replaceFstSndFree ctxt term fFst fSnd r + |> (fn t => replaceAll (map (fn t => (t,dest_Prod_snd t)) rs) [t]) + |> hd + ) + )) + end + | replaceArgs _ = error "Internal error: Invalid term" + + (* Calculate reduced terms *) + val T_terms_red = T_terms + |> replaceAll retyping + |> map replaceArgs + + val _ = print_lemma ctxt reds T_terms_red + val _ = + Pretty.writeln (Pretty.str "If you do not want the simplified T function, use \"time_fun [no_simp]\"") + in + ctxt + end - fun genRetype (Const (nm,T),is) = - let - val T_name = fun_name_to_time ctxt true nm |> split_name |> List.last - val from = Free (T_name,change_typ T) - val to = Free (T_name,change_typ' (not o contains is) T) - in - (from,to) - end - | genRetype _ = error "Internal error: invalid term" - val retyping = map genRetype simpables - - fun replaceArgs (pT $ (eq $ l $ r)) = - let - val (t,params) = strip_comb l - fun match (Const (f_nm,_),_) = - (fun_name_to_time ctxt true f_nm |> Long_Name.base_name) = (dest_Free t |> fst) - | match _ = false - val simps = List.find match simpables |> Option.valOf |> snd - - fun dest_Prod_snd (Free (nm, Type (_, [_, T2]))) = - Free (fun_name_to_time ctxt false nm, T2) - | dest_Prod_snd _ = error "Internal error: Argument is not a pair" - fun rep _ [] = ([],[]) - | rep i (x::xs) = - let - val (rs,args) = rep (i+1) xs - in - if contains simps i - then (x::rs,dest_Prod_snd x::args) - else (rs,x::args) - end - val (rs,params) = rep 0 params - fun fFst _ = error "Internal error: Invalid term to simplify" - fun fSnd (t as (Const _ $ f)) = - (if contains rs f - then dest_Prod_snd f - else t) - | fSnd t = t - in - (pT $ (eq - $ (list_comb (t,params)) - $ (replaceFstSndFree ctxt term fFst fSnd r - |> (fn t => replaceAll (map (fn t => (t,dest_Prod_snd t)) rs) [t]) - |> hd - ) - )) - end - | replaceArgs _ = error "Internal error: Invalid term" - - (* Calculate reduced terms *) - val T_terms_red = T_terms - |> replaceAll retyping - |> map replaceArgs - - val _ = print_lemma ctxt reds T_terms_red - val _ = - Pretty.writeln (Pretty.str "If you do not want the simplified T function, use \"time_fun [no_simp]\"") -in - ctxt -end - -(* Register timing function of a given function *) -fun reg_time_func (lthy: local_theory) (term: term list) (terms: term list) print simp = - let val _ = case time_term lthy true (hd term) handle (ERROR _) => NONE @@ -758,6 +796,7 @@ |> strip_comb |> snd |> length + (********************* BEGIN OF CONVERSION *********************) (* 1. Fix all terms *) (* Exchange Var in types and terms to Free and check constraints *) val terms = map @@ -799,7 +838,7 @@ val T_terms = map (convert_term lthy term is_rec) terms |> map (map_r (replaceFstSndFree lthy term fFst fSnd)) - val simpables = (if simp + val simpables = (if #simp config then find_simplifyble lthy term T_terms else map (K []) term) |> (fn s => ListPair.zip (term,s)) @@ -808,7 +847,7 @@ (* Rename to secondary if simpable *) fun genRename (t,_) = let - val old = fun_to_time lthy term t |> Option.valOf + val old = fun_to_time' lthy term false t |> Option.valOf val new = fun_to_time' lthy term true t |> Option.valOf in (old,new) @@ -824,6 +863,7 @@ fun pat_completeness_auto ctxt = Pat_Completeness.pat_completeness_tac ctxt 1 THEN auto_tac ctxt val specs = map (fn eq => (((Binding.empty, []), eq), [], [])) can_T_terms + val part_specs = (Binding.empty_atts, hd can_T_terms) (* Context for printing without showing question marks *) val print_ctxt = lthy @@ -831,13 +871,13 @@ |> Config.put show_sorts false (* Change it for debugging *) val print_ctxt = List.foldl (fn (term, ctxt) => Variable.add_fixes_implicit term ctxt) print_ctxt (term @ T_terms) (* Print result if print *) - val _ = if not print then () else + val _ = if not (#print config) then () else let val nms = map (dest_Const_name) term val typs = map (dest_Const_type) term in print_timing' print_ctxt { names=nms, terms=terms, typs=typs } - { names=timing_names, terms=can_T_terms, typs=map change_typ typs } + { names=timing_names, terms=can_T_terms, typs=map change_typ typs } end (* For partial functions sequential=true is needed in order to support them @@ -849,7 +889,9 @@ val fun_config = Function_Common.FunctionConfig {sequential=seq, default=NONE, domintros=true, partials=false} in - Function.add_function bindings specs fun_config pat_completeness_auto lthy + if #partial config + then Partial_Function.add_partial_function "option" bindings part_specs lthy |>> PartialFunction o snd + else Function.add_function bindings specs fun_config pat_completeness_auto lthy |>> Function end val (info,ctxt) = @@ -861,7 +903,7 @@ val ctxt = if simpable then calculateSimplifications ctxt T_terms term simpables else ctxt in - (info,ctxt) + (info, ctxt) end fun proove_termination (term: term list) terms (T_info: Function.info, lthy: local_theory) = let @@ -904,11 +946,12 @@ (auto_tac simp_lthy) lthy end in - (time_info, lthy') + (Function time_info, lthy') end -fun reg_and_proove_time_func (lthy: local_theory) (term: term list) (terms: term list) print simp = - reg_time_func lthy term terms print simp - |> proove_termination term terms +fun reg_and_proove_time_func (lthy: local_theory) (term: term list) (terms: term list) (config: time_config) = + case reg_time_func lthy term terms config + of (Function info, lthy') => proove_termination term terms (info, lthy') + | r => r fun isTypeClass' (Const (nm,_)) = (case split_name nm |> rev @@ -945,43 +988,60 @@ | check_opts ["no_simp"] = true | check_opts (a::_) = error ("Option " ^ a ^ " is not defined") -(* Convert function into its timing function (called by command) *) +(* Converts a function into its timing function using fun *) fun reg_time_fun_cmd ((opts, funcs), thms) (ctxt: local_theory) = let val no_simp = check_opts opts val fterms = map (Syntax.read_term ctxt) funcs val ctxt = set_suffix fterms ctxt + val config = { print = true, simp = not no_simp, partial = false } val (_, ctxt') = reg_and_proove_time_func ctxt fterms (case thms of NONE => get_terms ctxt (hd fterms) | SOME thms => thms |> Attrib.eval_thms ctxt |> List.map Thm.prop_of) - true (not no_simp) + config in ctxt' end -(* Convert function into its timing function (called by command) with termination proof provided by user*) +(* Converts a function into its timing function using function with termination proof provided by user*) fun reg_time_function_cmd ((opts, funcs), thms) (ctxt: local_theory) = let val no_simp = check_opts opts val fterms = map (Syntax.read_term ctxt) funcs val ctxt = set_suffix fterms ctxt + val config = { print = true, simp = not no_simp, partial = false } val ctxt' = reg_time_func ctxt fterms (case thms of NONE => get_terms ctxt (hd fterms) | SOME thms => thms |> Attrib.eval_thms ctxt |> List.map Thm.prop_of) - true (not no_simp) + config |> snd in ctxt' end -(* Convert function into its timing function (called by command) *) +(* Converts a function definition into its timing function using definition *) fun reg_time_definition_cmd ((opts, funcs), thms) (ctxt: local_theory) = let val no_simp = check_opts opts val fterms = map (Syntax.read_term ctxt) funcs val ctxt = set_suffix fterms ctxt + val config = { print = true, simp = not no_simp, partial = false } val (_, ctxt') = reg_and_proove_time_func ctxt fterms (case thms of NONE => get_terms ctxt (hd fterms) |> check_definition | SOME thms => thms |> Attrib.eval_thms ctxt |> List.map Thm.prop_of) - true (not no_simp) + config +in ctxt' +end + +(* Converts a a partial function into its timing function using partial_function *) +fun reg_time_partial_function_cmd ((opts, funcs), thms) (ctxt: local_theory) = +let + val no_simp = check_opts opts + val fterms = map (Syntax.read_term ctxt) funcs + val ctxt = set_suffix fterms ctxt + val config = { print = true, simp = not no_simp, partial = true } + val (_, ctxt') = reg_and_proove_time_func ctxt fterms + (case thms of NONE => get_terms ctxt (hd fterms) |> check_definition + | SOME thms => thms |> Attrib.eval_thms ctxt |> List.map Thm.prop_of) + config in ctxt' end @@ -1001,4 +1061,8 @@ "Defines runtime function of a definition" (parser >> reg_time_definition_cmd) +val _ = Outer_Syntax.local_theory @{command_keyword "time_partial_function"} + "Defines runtime function of a definition" + (parser >> reg_time_partial_function_cmd) + end diff -r 0aa2d1c132b2 -r 50dd4fc40fcb src/HOL/Data_Structures/Define_Time_Function.thy --- a/src/HOL/Data_Structures/Define_Time_Function.thy Wed Feb 05 16:34:56 2025 +0000 +++ b/src/HOL/Data_Structures/Define_Time_Function.thy Thu Feb 06 14:46:49 2025 +0100 @@ -12,6 +12,7 @@ keywords "time_fun" :: thy_decl and "time_function" :: thy_decl and "time_definition" :: thy_decl + and "time_partial_function" :: thy_decl and "equations" and "time_fun_0" :: thy_decl begin