merged
authornipkow
Wed, 21 Aug 2024 20:41:16 +0200
changeset 80735 0c406b9469ab
parent 80733 17d8b3f6d744 (current diff)
parent 80734 7054a1bc8347 (diff)
child 80737 6984640568b9
child 80739 60801e5fae26
merged
--- a/src/HOL/Data_Structures/Define_Time_Function.ML	Wed Aug 21 14:09:44 2024 +0100
+++ b/src/HOL/Data_Structures/Define_Time_Function.ML	Wed Aug 21 20:41:16 2024 +0200
@@ -1,13 +1,18 @@
 
 signature TIMING_FUNCTIONS =
 sig
+type 'a wctxt = {
+  ctxt: local_theory,
+  origins: term list,
+  f: term -> 'a
+}
 type 'a converter = {
-  constc : local_theory -> term list -> (term -> 'a) -> term -> 'a,
-  funcc : local_theory -> term list -> (term -> 'a) -> term -> term list -> 'a,
-  ifc : local_theory -> term list -> (term -> 'a) -> typ -> term -> term -> term -> 'a,
-  casec : local_theory -> term list -> (term -> 'a) -> term -> term list -> 'a,
-  letc : local_theory -> term list -> (term -> 'a) -> typ -> term -> string list -> typ list -> term -> 'a
-};
+  constc : 'a wctxt -> term -> 'a,
+  funcc : 'a wctxt -> term -> term list -> 'a,
+  ifc : 'a wctxt -> typ -> term -> term -> term -> 'a,
+  casec : 'a wctxt -> term -> term list -> 'a,
+  letc : 'a wctxt -> typ -> term -> string list -> typ list -> term -> 'a
+}
 val walk : local_theory -> term list -> 'a converter -> term -> 'a
 
 type pfunc = { names : string list, terms : term list, typs : typ list }
@@ -16,10 +21,10 @@
 val print_timing':  Proof.context -> pfunc -> pfunc -> unit
 val print_timing:  Proof.context -> Function.info -> Function.info -> unit
 
-val reg_and_proove_time_func: theory -> term list -> term list
-      -> bool -> Function.info * theory
-val reg_time_func: theory -> term list -> term list
-      -> bool -> theory
+val reg_and_proove_time_func: local_theory -> term list -> term list
+      -> bool -> Function.info * local_theory
+val reg_time_func: local_theory -> term list -> term list
+      -> bool -> Function.info * local_theory
 
 val time_dom_tac: Proof.context -> thm -> thm list -> int -> tactic
 
@@ -29,17 +34,16 @@
 struct
 (* Configure config variable to adjust the prefix *)
 val bprefix = Attrib.setup_config_string @{binding "time_prefix"} (K "T_")
+(* Configure config variable to adjust the suffix *)
+val bsuffix = Attrib.setup_config_string @{binding "time_suffix"} (K "")
 
 (* some default values to build terms easier *)
 val zero = Const (@{const_name "Groups.zero"}, HOLogic.natT)
 val one = Const (@{const_name "Groups.one"}, HOLogic.natT)
 (* Extracts terms from function info *)
 fun terms_of_info (info: Function.info) =
-let
-  val {simps, ...} = info
-in
-  map Thm.prop_of (case simps of SOME s => s | NONE => error "No terms of function found in info")
-end;
+  map Thm.prop_of (case #simps info of SOME s => s
+                                     | NONE => error "No terms of function found in info")
 
 type pfunc = {
   names : string list,
@@ -49,7 +53,9 @@
 fun info_pfunc (info: Function.info): pfunc =
 let
   val {defname, fs, ...} = info;
-  val T = case hd fs of (Const (_,T)) => T | _ => error "Internal error: Invalid info to print"
+  val T = case hd fs of (Const (_,T)) => T
+                      | (Free (_,T)) => T
+                      | _ => error "Internal error: Invalid info to print"
 in
   { names=[Binding.name_of defname], terms=terms_of_info info, typs=[T] }
 end
@@ -82,43 +88,41 @@
 fun print_timing ctxt (oinfo: Function.info) (tinfo: Function.info) =
   print_timing' ctxt (info_pfunc oinfo) (info_pfunc tinfo)
 
-val If_name = @{const_name "HOL.If"}
-val Let_name = @{const_name "HOL.Let"}
-
 fun contains l e = exists (fn e' => e' = e) l
-fun zip [] [] = []
-  | zip (x::xs) (y::ys) = (x, y) :: zip xs ys
-  | zip _ _ = error "Internal error: Cannot zip lists with differing size"
+fun contains' comp l e = exists (comp e) l
 fun index [] _ = 0
   | index (x::xs) el = (if x = el then 0 else 1 + index xs el)
 fun used_for_const orig_used t i = orig_used (t,i)
+(* Split name by . *)
+val split_name = String.fields (fn s => s = #".")
 
 (* returns true if it's an if term *)
-fun is_if (Const (n,_)) = (n = If_name)
+fun is_if (Const (@{const_name "HOL.If"},_)) = true
   | is_if _ = false
 (* returns true if it's a case term *)
-fun is_case (Const (n,_)) = String.isPrefix "case_" (List.last (String.fields (fn s => s = #".") n))
+fun is_case (Const (n,_)) = n |> split_name |> List.last |> String.isPrefix "case_"
   | is_case _ = false
 (* returns true if it's a let term *)
-fun is_let (Const (n,_)) = (n = Let_name)
+fun is_let (Const (@{const_name "HOL.Let"},_)) = true
   | is_let _ = false
 (* change type of original function to new type (_ \<Rightarrow> ... \<Rightarrow> _ to _ \<Rightarrow> ... \<Rightarrow> nat)
-    and replace all function arguments f with (t*T_f) *)
-fun change_typ' used i (Type ("fun", [T1, T2])) = 
-      Type ("fun", [check_for_fun' (used i) T1, change_typ' used (i+1) T2])
-  | change_typ' _ _ _ = HOLogic.natT
-and check_for_fun' true (f as Type ("fun", [_,_])) = HOLogic.mk_prodT (f, change_typ' (K false) 0 f)
-  | check_for_fun' false (f as Type ("fun", [_,_])) = change_typ' (K false) 0 f
+    and replace all function arguments f with (t*T_f) if used *)
+fun change_typ' used (Type ("fun", [T1, T2])) = 
+      Type ("fun", [check_for_fun' (used 0) T1, change_typ' (fn i => used (i+1)) T2])
+  | change_typ' _ _ = HOLogic.natT
+and check_for_fun' true (f as Type ("fun", [_,_])) = HOLogic.mk_prodT (f, change_typ' (K false) f)
+  | check_for_fun' false (f as Type ("fun", [_,_])) = change_typ' (K false) f
   | check_for_fun' _ t = t
-val change_typ = change_typ' (K false) 0
+val change_typ = change_typ' (K false)
 (* Convert string name of function to its timing equivalent *)
-fun fun_name_to_time ctxt name =
+fun fun_name_to_time ctxt s name =
 let
   val prefix = Config.get ctxt bprefix
-  fun replace_last_name [n] = [prefix ^ n]
+  val suffix = (if s then Config.get ctxt bsuffix else "")
+  fun replace_last_name [n] = [prefix ^ n ^ suffix]
     | replace_last_name (n::ns) = n :: (replace_last_name ns)
     | replace_last_name _ = error "Internal error: Invalid function name to convert"
-  val parts = String.fields (fn s => s = #".") name
+  val parts = split_name name
 in
   String.concatWith "." (replace_last_name parts)
 end
@@ -126,11 +130,10 @@
 fun count_args (Type (n, [_,res])) = (if n = "fun" then 1 + count_args res else 0)
   | count_args _ = 0
 (* Check if number of arguments matches function *)
-fun check_args s (Const (_,T), args) =
-    (if length args = count_args T then () else error ("Partial applications/Lambdas not allowed (" ^ s ^ ")"))
-  | check_args s (Free (_,T), args) =
-    (if length args = count_args T then () else error ("Partial applications/Lambdas not allowed (" ^ s ^ ")"))
-  | check_args s _ = error ("Partial applications/Lambdas not allowed (" ^ s ^ ")")
+val _ = dest_Const
+fun check_args s (t, args) =
+    (if length args = count_args (type_of t) then ()
+     else error ("Partial applications/Lambdas not allowed (" ^ s ^ ")"))
 (* Removes Abs *)
 fun rem_abs f (Abs (_,_,t)) = rem_abs f t
   | rem_abs f t = f t
@@ -148,39 +151,53 @@
   | Const_name _ = NONE
 fun is_Used (Type ("Product_Type.prod", _)) = true
   | is_Used _ = false
+(* Custom compare function for types ignoring variable names *)
+fun typ_comp (Type (A,a)) (Type (B,b)) = (A = B) andalso List.foldl (fn ((c,i),s) => typ_comp c i andalso s) true (ListPair.zip (a, b))
+  | typ_comp (Type _) _ = false
+  | typ_comp _ (Type _) = false
+  | typ_comp _ _ = true
+fun const_comp (Const (nm,T)) (Const (nm',T')) = nm = nm' andalso typ_comp T T'
+  | const_comp _ _ = false
 
-fun time_term ctxt (Const (nm,T)) =
+fun time_term ctxt s (Const (nm,T)) =
 let
-  val T_nm = fun_name_to_time ctxt nm
+  val T_nm = fun_name_to_time ctxt s nm
   val T_T = change_typ T
 in
 (SOME (Syntax.check_term ctxt (Const (T_nm,T_T))))
   handle (ERROR _) =>
     case Syntax.read_term ctxt (Long_Name.base_name T_nm)
-      of (Const (nm,T_T)) =>
+      of (Const (T_nm,T_T)) =>
         let
           fun col_Used i (Type ("fun", [Type ("fun", _), Ts])) (Type ("fun", [T', Ts'])) =
             (if is_Used T' then [i] else []) @ col_Used (i+1) Ts Ts'
             | col_Used i (Type ("fun", [_, Ts])) (Type ("fun", [_, Ts'])) = col_Used (i+1) Ts Ts'
             | col_Used _ _ _ = []
         in
-          SOME (Const (nm,change_typ' (contains (col_Used 0 T T_T)) 0 T))
+          SOME (Const (T_nm,change_typ' (contains (col_Used 0 T T_T)) T))
         end
        | _ => error ("Timing function of " ^ nm ^ " is not defined")
 end
-  | time_term _ _ = error "Internal error: No valid function given"
+  | time_term _ _ _ = error "Internal error: No valid function given"
+
 
+type 'a wctxt = {
+  ctxt: local_theory,
+  origins: term list,
+  f: term -> 'a
+}
 type 'a converter = {
-  constc : local_theory -> term list -> (term -> 'a) -> term -> 'a,
-  funcc : local_theory -> term list -> (term -> 'a) -> term -> term list -> 'a,
-  ifc : local_theory -> term list -> (term -> 'a) -> typ -> term -> term -> term -> 'a,
-  casec : local_theory -> term list -> (term -> 'a) -> term -> term list -> 'a,
-  letc : local_theory -> term list -> (term -> 'a) -> typ -> term -> string list -> typ list -> term -> 'a
-};
+  constc : 'a wctxt -> term -> 'a,
+  funcc : 'a wctxt -> term -> term list -> 'a,
+  ifc : 'a wctxt -> typ -> term -> term -> term -> 'a,
+  casec : 'a wctxt -> term -> term list -> 'a,
+  letc : 'a wctxt -> typ -> term -> string list -> typ list -> term -> 'a
+}
 
 (* Walks over term and calls given converter *)
 fun walk_func (t1 $ t2) ts = walk_func t1 (t2::ts)
   | walk_func t ts = (t, ts)
+fun walk_func' t = walk_func t []
 fun build_func (f, []) = f
   | build_func (f, (t::ts)) = build_func (f$t, ts)
 fun walk_abs (Abs (nm,T,t)) nms Ts = walk_abs t (nm::nms) (T::Ts)
@@ -193,23 +210,24 @@
     val (f, args) = walk_func t []
     val this = (walk ctxt origin conv)
     val _ = (case f of Abs _ => error "Lambdas not supported" | _ => ())
+    val wctxt = {ctxt = ctxt, origins = origin, f = this}
   in
     (if is_if f then
       (case f of (Const (_,T)) =>
-        (case args of [cond, t, f] => ifc ctxt origin this T cond t f
+        (case args of [cond, t, f] => ifc wctxt T cond t f
                    | _ => error "Partial applications not supported (if)")
                | _ => error "Internal error: invalid if term")
-      else if is_case f then casec ctxt origin this f args
+      else if is_case f then casec wctxt f args
       else if is_let f then
       (case f of (Const (_,lT)) =>
          (case args of [exp, t] => 
-            let val (t,nms,Ts) = walk_abs t [] [] in letc ctxt origin this lT exp nms Ts t end
+            let val (t,nms,Ts) = walk_abs t [] [] in letc wctxt lT exp nms Ts t end
                      | _ => error "Partial applications not allowed (let)")
                | _ => error "Internal error: invalid let term")
-      else funcc ctxt origin this f args)
+      else funcc wctxt f args)
   end
   | walk ctxt origin (conv as {constc, ...}) c = 
-      constc ctxt origin (walk ctxt origin conv) c
+      constc {ctxt = ctxt, origins = origin, f = walk ctxt origin conv} c
 
 (* 1. Fix all terms *)
 (* Exchange Var in types and terms to Free *)
@@ -221,36 +239,34 @@
 fun noFun (Type ("fun", _)) = error "Functions in datatypes are not allowed in case constructions"
   | noFun T = T
 fun casecBuildBounds n t = if n > 0 then casecBuildBounds (n-1) (t $ (Bound (n-1))) else t
-fun casecAbs ctxt f n (Type ("fun",[T,Tr])) (Abs (v,Ta,t)) = ( map_atyps noFun T; Abs (v,Ta,casecAbs ctxt f n Tr t))
-  | casecAbs ctxt f n (Type ("fun",[T,Tr])) t =
-    (map_atyps noFun T; case Variable.variant_fixes ["x"] ctxt of ([v],ctxt) =>
-    (Abs (v,T,casecAbs ctxt f (n + 1) Tr t))
-    | _ => error "Internal error: could not fix variable")
-  | casecAbs _ f n _ t = f (casecBuildBounds n (Term.incr_bv n 0 t))
-fun fixCasecCases _ _ _ [t] = [t]
-  | fixCasecCases ctxt f (Type (_,[T,Tr])) (t::ts) = casecAbs ctxt f 0 T t :: fixCasecCases ctxt f Tr ts
-  | fixCasecCases _ _ _ _ = error "Internal error: invalid case types/terms"
-fun fixCasec ctxt _ f (t as Const (_,T)) args =
-      (check_args "cases" (t,args); build_func (t,fixCasecCases ctxt f T args))
-  | fixCasec _ _ _ _ _ = error "Internal error: invalid case term"
+fun casecAbs wctxt n (Type ("fun",[T,Tr])) (Abs (v,Ta,t)) = (map_atyps noFun T; Abs (v,Ta,casecAbs wctxt n Tr t))
+  | casecAbs wctxt n (Type ("fun",[T,Tr])) t =
+    (map_atyps noFun T; Abs ("uu",T,casecAbs wctxt (n + 1) Tr t))
+  | casecAbs wctxt n _ t = (#f wctxt) (casecBuildBounds n (Term.incr_bv n 0 t))
+fun fixCasecCases _ _ [t] = [t]
+  | fixCasecCases wctxt (Type (_,[T,Tr])) (t::ts) = casecAbs wctxt 0 T t :: fixCasecCases wctxt Tr ts
+  | fixCasecCases _ _ _ = error "Internal error: invalid case types/terms"
+fun fixCasec wctxt (t as Const (_,T)) args =
+      (check_args "cases" (t,args); build_func (t,fixCasecCases wctxt T args))
+  | fixCasec _ _ _ = error "Internal error: invalid case term"
 
 fun fixPartTerms ctxt (term: term list) t =
   let
     val _ = check_args "args" (walk_func (get_l t) [])
   in
     map_r (walk ctxt term {
-          funcc = (fn _ => fn _ => fn f => fn t => fn args =>
-              (check_args "func" (t,args); build_func (t, map f args))),
-          constc = (fn _ => fn _ => fn _ => fn c => (case c of Abs _ => error "Lambdas not supported" | _ => c)),
-          ifc = (fn _ => fn _ => fn f => fn T => fn cond => fn tt => fn tf =>
-            ((Const (If_name, T)) $ f cond $ (f tt) $ (f tf))),
+          funcc = (fn wctxt => fn t => fn args =>
+              (check_args "func" (t,args); build_func (t, map (#f wctxt) args))),
+          constc = (fn _ => fn c => (case c of Abs _ => error "Lambdas not supported" | _ => c)),
+          ifc = (fn wctxt => fn T => fn cond => fn tt => fn tf =>
+            ((Const (@{const_name "HOL.If"}, T)) $ (#f wctxt) cond $ ((#f wctxt) tt) $ ((#f wctxt) tf))),
           casec = fixCasec,
-          letc = (fn _ => fn _ => fn f => fn expT => fn exp => fn nms => fn Ts => fn t =>
+          letc = (fn wctxt => fn expT => fn exp => fn nms => fn Ts => fn t =>
               let
                 val f' = if length nms = 0 then
-                (case f (t$exp) of t$_ => t | _ => error "Internal error: case could not be fixed (let)")
-                else f t
-              in (Const (Let_name,expT) $ (f exp) $ build_abs f' nms Ts) end)
+                (case (#f wctxt) (t$exp) of t$_ => t | _ => error "Internal error: case could not be fixed (let)")
+                else (#f wctxt) t
+              in (Const (@{const_name "HOL.Let"},expT) $ ((#f wctxt) exp) $ build_abs f' nms Ts) end)
       }) t
   end
 
@@ -258,11 +274,16 @@
 (* 2.1 Check if function is recursive *)
 fun or f (a,b) = f a orelse b
 fun find_rec ctxt term = (walk ctxt term {
-          funcc = (fn _ => fn _ => fn f => fn t => fn args => List.exists (fn term => Const_name t = Const_name term) term orelse List.foldr (or f) false args),
-          constc = (K o K o K o K) false,
-          ifc = (fn _ => fn _ => fn f => fn _ => fn cond => fn tt => fn tf => f cond orelse f tt orelse f tf),
-          casec = (fn _ => fn _ => fn f => fn t => fn cs => f t orelse List.foldr (or (rem_abs f)) false cs),
-          letc = (fn _ => fn _ => fn f => fn _ => fn exp => fn _ => fn _ => fn t => f exp orelse f t)
+          funcc = (fn wctxt => fn t => fn args =>
+            List.exists (fn term => (Const_name t) = (Const_name term)) term
+             orelse List.foldr (or (#f wctxt)) false args),
+          constc = (K o K) false,
+          ifc = (fn wctxt => fn _ => fn cond => fn tt => fn tf =>
+            (#f wctxt) cond orelse (#f wctxt) tt orelse (#f wctxt) tf),
+          casec = (fn wctxt => fn t => fn cs =>
+            (#f wctxt) t orelse List.foldr (or (rem_abs (#f wctxt))) false cs),
+          letc = (fn wctxt => fn _ => fn exp => fn _ => fn _ => fn t =>
+            (#f wctxt) exp orelse (#f wctxt) t)
       }) o get_r
 fun is_rec ctxt (term: term list) = List.foldr (or (find_rec ctxt term)) false
 
@@ -277,14 +298,14 @@
         f :: filter_passed args
     | filter_passed (_::args) = filter_passed args
   val frees' = (walk ctxt term {
-          funcc = (fn _ => fn _ => fn f => fn t => fn args =>
+          funcc = (fn wctxt => fn t => fn args =>
               (case t of (Const ("Product_Type.prod.snd", _)) => []
                   | _ => (if t = T_ident then [] else filter_passed args)
-                    @ List.foldr (fn (l,r) => f l @ r) [] args)),
-          constc = (K o K o K o K) [],
-          ifc = (fn _ => fn _ => fn f => fn _ => fn cond => fn tt => fn tf => f cond @ f tt @ f tf),
-          casec = (fn _ => fn _ => fn f => fn _ => fn cs => List.foldr (fn (l,r) => f l @ r) [] cs),
-          letc = (fn _ => fn _ => fn f => fn _ => fn exp => fn _ => fn _ => fn t => f exp @ f t)
+                    @ List.foldr (fn (l,r) => (#f wctxt) l @ r) [] args)),
+          constc = (K o K) [],
+          ifc = (fn wctxt => fn _ => fn cond => fn tt => fn tf => (#f wctxt) cond @ (#f wctxt) tt @ (#f wctxt) tf),
+          casec = (fn wctxt => fn _ => fn cs => List.foldr (fn (l,r) => (#f wctxt) l @ r) [] cs),
+          letc = (fn wctxt => fn _ => fn exp => fn _ => fn _ => fn t => (#f wctxt) exp @ (#f wctxt) t)
       }) (get_r T_t)
   fun build _ [] _ = false
     | build i (a::args) item =
@@ -293,7 +314,7 @@
   build 0 T_args
 end
 fun find_used ctxt term terms T_terms =
-  zip terms T_terms
+  ListPair.zip (terms, T_terms)
   |> List.map (fn (t, T_t) => find_used' ctxt term t T_t)
   |> List.foldr (fn (f,g) => fn item => f item orelse g item) (K false)
 
@@ -310,57 +331,54 @@
 fun use_origin (Free (nm, T as Type ("fun",_))) = HOLogic.mk_fst (Free (nm,HOLogic.mk_prodT (T, change_typ T)))
   | use_origin t = t
 
-(* Converting of function term *)
+(* Conversion of function term *)
 fun fun_to_time ctxt orig_used _ (origin: term list) (func as Const (nm,T)) =
 let
-  val prefix = Config.get ctxt bprefix
   val used' = used_for_const orig_used func
 in
-  if contains origin func then SOME (Free (prefix ^ Term.term_name func, change_typ' used' 0 T)) else
+  if contains' const_comp origin func then SOME (Free (func |> Term.term_name |> fun_name_to_time ctxt true, change_typ' used' T)) else
   if Zero_Funcs.is_zero (Proof_Context.theory_of ctxt) (nm,T) then NONE else
-    time_term ctxt func
+    time_term ctxt false func
 end
   | fun_to_time ctxt _ used _ (f as Free (nm,T)) = SOME (
       if used f then HOLogic.mk_snd (Free (nm,HOLogic.mk_prodT (T,change_typ T)))
-      else Free (Config.get ctxt bprefix ^ nm, change_typ T)
+      else Free (fun_name_to_time ctxt false nm, change_typ T)
       )
   | fun_to_time _ _ _ _ _ = error "Internal error: invalid function to convert"
 
 (* Convert arguments of left side of a term *)
 fun conv_arg ctxt used _ (f as Free (nm,T as Type("fun",_))) =
-    if used f then Free (nm, HOLogic.mk_prodT (T, change_typ' (K false) 0 T))
-    else Free (Config.get ctxt bprefix ^ nm, change_typ' (K false) 0 T)
-  | conv_arg ctxt _ origin (f as Const (_, Type("fun",_))) =
-      (error "weird case i don't understand TODO"; HOLogic.mk_prod (f, fun_to_time ctxt (K false) (K false) origin f |> Option.valOf))
+    if used f then Free (nm, HOLogic.mk_prodT (T, change_typ' (K false) T))
+    else Free (fun_name_to_time ctxt false nm, change_typ' (K false) T)
   | conv_arg _ _ _ x = x
 fun conv_args ctxt used origin = map (conv_arg ctxt used origin)
 
 (* Handle function calls *)
-fun build_zero (Type ("fun", [T, R])) = Abs ("x", T, build_zero R)
+fun build_zero (Type ("fun", [T, R])) = Abs ("uu", T, build_zero R)
   | build_zero _ = zero
-fun funcc_use_origin _ _ used (f as Free (nm, T as Type ("fun",_))) =
+fun funcc_use_origin used (f as Free (nm, T as Type ("fun",_))) =
     if used f then HOLogic.mk_fst (Free (nm,HOLogic.mk_prodT (T, change_typ T)))
     else error "Internal error: Error in used detection"
-  | funcc_use_origin _ _ _ t = t
-fun funcc_conv_arg ctxt origin used _ (t as (_ $ _)) = map_aterms (funcc_use_origin ctxt origin used) t
-  | funcc_conv_arg ctxt _ used u (f as Free (nm, T as Type ("fun",_))) =
+  | funcc_use_origin _ t = t
+fun funcc_conv_arg _ used _ (t as (_ $ _)) = map_aterms (funcc_use_origin used) t
+  | funcc_conv_arg wctxt used u (f as Free (nm, T as Type ("fun",_))) =
       if used f then
         if u then Free (nm, HOLogic.mk_prodT (T, change_typ T))
         else HOLogic.mk_snd (Free (nm,HOLogic.mk_prodT (T,change_typ T)))
-      else Free (Config.get ctxt bprefix ^ nm, change_typ T)
-  | funcc_conv_arg ctxt origin _ true (f as Const (_,T as Type ("fun",_))) =
+      else Free (fun_name_to_time (#ctxt wctxt) false nm, change_typ T)
+  | funcc_conv_arg wctxt _ true (f as Const (_,T as Type ("fun",_))) =
   (Const (@{const_name "Product_Type.Pair"},
       Type ("fun", [T,Type ("fun", [change_typ T, HOLogic.mk_prodT (T,change_typ T)])]))
-    $ f $ (Option.getOpt (fun_to_time ctxt (K false) (K false) origin f, build_zero T)))
-  | funcc_conv_arg ctxt origin _ false (f as Const (_,T as Type ("fun",_))) =
-      Option.getOpt (fun_to_time ctxt (K false) (K false) origin f, build_zero T)
-  | funcc_conv_arg _ _ _ _ t = t
+    $ f $ (Option.getOpt (fun_to_time (#ctxt wctxt) (K false) (K false) (#origins wctxt) f, build_zero T)))
+  | funcc_conv_arg wctxt _ false (f as Const (_,T as Type ("fun",_))) =
+      Option.getOpt (fun_to_time (#ctxt wctxt) (K false) (K false) (#origins wctxt) f, build_zero T)
+  | funcc_conv_arg _ _ _ t = t
 
-fun funcc_conv_args _ _ _ _ [] = []
-  | funcc_conv_args ctxt origin used (Type ("fun", [t, ts])) (a::args) =
-      funcc_conv_arg ctxt origin used (is_Used t) a :: funcc_conv_args ctxt origin used ts args
-  | funcc_conv_args _ _ _ _ _ = error "Internal error: Non matching type"
-fun funcc orig_used used ctxt (origin: term list) f func args =
+fun funcc_conv_args _ _ _ [] = []
+  | funcc_conv_args wctxt used (Type ("fun", [t, ts])) (a::args) =
+      funcc_conv_arg wctxt used (is_Used t) a :: funcc_conv_args wctxt used ts args
+  | funcc_conv_args _ _ _ _ = error "Internal error: Non matching type"
+fun funcc orig_used used wctxt func args =
 let
   fun get_T (Free (_,T)) = T
     | get_T (Const (_,T)) = T
@@ -368,10 +386,10 @@
     | get_T _ = error "Internal error: Forgotten type"
 in
   List.foldr (I #-> plus)
-  (case fun_to_time ctxt orig_used used origin func
-    of SOME t => SOME (build_func (t,funcc_conv_args ctxt origin used (get_T t) args))
+  (case fun_to_time (#ctxt wctxt) orig_used used (#origins wctxt) func
+    of SOME t => SOME (build_func (t,funcc_conv_args wctxt used (get_T t) args))
     | NONE => NONE)
-  (map f args)
+  (map (#f wctxt) args)
 end
 
 (* Handle case terms *)
@@ -389,48 +407,49 @@
     (case casecAbs f t of (nconst, tt) => 
       casecArgs f ar ||> (fn ar => tt :: ar) |>> (if nconst then K true else I))
   | casecArgs _ _ = error "Internal error: Invalid case term"
-fun casec _ _ f (Const (t,T)) args =
+fun casec wctxt (Const (t,T)) args =
   if not (casecIsCase T) then error "Internal error: Invalid case type" else
-    let val (nconst, args') = casecArgs f args in
+    let val (nconst, args') = casecArgs (#f wctxt) args in
       plus
-        (f (List.last args))
+        ((#f wctxt) (List.last args))
         (if nconst then
           SOME (build_func (Const (t,casecTyp T), args'))
          else NONE)
     end
-  | casec _ _ _ _ _ = error "Internal error: Invalid case term"
+  | casec _ _ _ = error "Internal error: Invalid case term"
 
 (* Handle if terms -> drop the term if true and false terms are zero *)
-fun ifc _ _ f _ cond tt ft =
+fun ifc wctxt _ cond tt ft =
   let
+    val f = #f wctxt
     val rcond = map_aterms use_origin cond
     val tt = f tt
     val ft = f ft
   in
     plus (f cond) (case (tt,ft) of (NONE, NONE) => NONE | _ =>
        if tt = ft then tt else
-       (SOME ((Const (If_name, @{typ "bool \<Rightarrow> nat \<Rightarrow> nat \<Rightarrow> nat"})) $ rcond $ (opt_term tt) $ (opt_term ft))))
+       (SOME ((Const (@{const_name "HOL.If"}, @{typ "bool \<Rightarrow> nat \<Rightarrow> nat \<Rightarrow> nat"})) $ rcond $ (opt_term tt) $ (opt_term ft))))
   end
 
 fun letc_change_typ (Type ("fun", [T1, Type ("fun", [T2, _])])) = (Type ("fun", [T1, Type ("fun", [change_typ T2, HOLogic.natT])]))
   | letc_change_typ _ = error "Internal error: invalid let type"
-fun letc _ _ f expT exp nms Ts t =
-    plus (f exp)
+fun letc wctxt expT exp nms Ts t =
+    plus (#f wctxt exp)
     (if length nms = 0 (* In case of "length nms = 0" the expression got reducted
                           Here we need Bound 0 to gain non-partial application *)
-    then (case f (t $ Bound 0) of SOME (t' $ Bound 0) =>
-                                 (SOME (Const (Let_name, letc_change_typ expT) $ (map_aterms use_origin exp) $ t'))
+    then (case #f wctxt (t $ Bound 0) of SOME (t' $ Bound 0) =>
+                                 (SOME (Const (@{const_name "HOL.Let"}, letc_change_typ expT) $ (map_aterms use_origin exp) $ t'))
                                   (* Expression is not used and can therefore let be dropped *)
                                 | SOME t' => SOME t'
                                 | NONE => NONE)
-    else (case f t of SOME t' =>
-      SOME (if Term.is_dependent t' then Const (Let_name, letc_change_typ expT) $ (map_aterms use_origin exp) $ build_abs t' nms Ts
+    else (case #f wctxt t of SOME t' =>
+      SOME (if Term.is_dependent t' then Const (@{const_name "HOL.Let"}, letc_change_typ expT) $ (map_aterms use_origin exp) $ build_abs t' nms Ts
                                     else Term.subst_bounds([exp],t'))
     | NONE => NONE))
 
 (* The converter for timing functions given to the walker *)
 fun converter orig_used used : term option converter = {
-        constc = fn _ => fn _ => fn _ => fn t =>
+        constc = fn _ => fn t =>
           (case t of Const ("HOL.undefined", _) => SOME (Const ("HOL.undefined", @{typ "nat"}))
                    | _ => NONE),
         funcc = (funcc orig_used used),
@@ -469,20 +488,22 @@
 
 
 fun get_terms theory (term: term) =
-  Spec_Rules.retrieve_global theory term
-      |> hd |> #rules
-      |> map Thm.prop_of
+let
+  val equations = Spec_Rules.retrieve theory term
+      |> map #rules
+      |> map (map Thm.prop_of)
    handle Empty => error "Function or terms of function not found"
+in
+  equations
+    |> filter (fn ts => typ_comp (ts |> hd |> get_l |> walk_func' |> fst |> dest_Const |> snd) (term |> dest_Const |> snd))
+    |> hd
+end
 
 (* Register timing function of a given function *)
-fun reg_and_proove_time_func (theory: theory) (term: term list) (terms: term list) print =
-  reg_time_func theory term terms false
-  |> proove_termination term terms print
-and reg_time_func (theory: theory) (term: term list) (terms: term list) print =
+fun reg_time_func (lthy: local_theory) (term: term list) (terms: term list) print =
   let
-    val lthy = Named_Target.theory_init theory
     val _ =
-      case time_term lthy (hd term)
+      case time_term lthy true (hd term)
             handle (ERROR _) => NONE
         of SOME _ => error ("Timing function already declared: " ^ (Term.term_name (hd term)))
          | NONE => ()
@@ -517,7 +538,7 @@
 
     (* 4. Register function and prove termination *)
     val names = map Term.term_name term
-    val timing_names = map (fun_name_to_time lthy) names
+    val timing_names = map (fun_name_to_time lthy true) names
     val bindings = map (fn nm => (Binding.name nm, NONE, NoSyn)) timing_names
     fun pat_completeness_auto ctxt =
       Pat_Completeness.pat_completeness_tac ctxt 1 THEN auto_tac ctxt
@@ -539,34 +560,30 @@
     val print_ctxt = lthy
       |> Config.put show_question_marks false
       |> Config.put show_sorts false (* Change it for debugging *)
-    val print_ctxt =  List.foldl (fn (term, ctxt) => Variable.add_fixes_implicit term ctxt) print_ctxt (term @ T_terms)
+    val print_ctxt = List.foldl (fn (term, ctxt) => Variable.add_fixes_implicit term ctxt) print_ctxt (term @ T_terms)
     (* Print result if print *)
     val _ = if not print then () else
         let
-          val nms = map dest_Const_name term
+          val nms = map (fst o dest_Const) term
           val used = map (used_for_const orig_used) term
-          val typs = map dest_Const_type term
+          val typs = map (snd o dest_Const) term
         in
-          print_timing' print_ctxt { names=nms, terms=terms, typs=typs } { names=timing_names, terms=T_terms, typs=map (fn (used, typ) => change_typ' used 0 typ) (zip used typs) }
+          print_timing' print_ctxt { names=nms, terms=terms, typs=typs }
+            { names=timing_names, terms=T_terms, typs=map (fn (used, typ) => change_typ' used typ) (ListPair.zip (used, typs)) }
         end
 
-    (* Register function *)
-    val (_, lthy) =
-      register false
+  in
+    register false
       handle (ERROR _) =>
         register true
            | Match =>
         register true
-  in
-    Local_Theory.exit_global lthy
   end
-and proove_termination (term: term list) terms print (theory: theory) =
+fun proove_termination (term: term list) terms print (T_info: Function.info, lthy: local_theory) =
   let
-    val lthy = Named_Target.theory_init theory
-
     (* Start proving the termination *)  
     val infos = SOME (map (Function.get_info lthy) term) handle Empty => NONE
-    val timing_names = map (fun_name_to_time lthy o Term.term_name) term
+    val timing_names = map (fun_name_to_time lthy true o Term.term_name) term
 
     (* Proof by lexicographic_order_tac *)
     val (time_info, lthy') =
@@ -578,22 +595,22 @@
           val _ = (if length term > 1 then error "Proof over dom not supported for mutual recursive functions" else ())
 
           fun args (a$(Var ((nm,_),T))) = args a |> (fn ar => (nm,T)::ar)
-            | args (a$(Const (_,T))) = args a |> (fn ar => ("x",T)::ar)
+            | args (a$(Const (_,T))) = args a |> (fn ar => ("uu",T)::ar)
             | args _ = []
-          val dom_args =
-            terms |> hd |> get_l |> args
-            |> Variable.variant_frees lthy []
-            |> map fst
+          val dom_vars =
+            terms |> hd |> get_l |> map_types (map_atyps fixTypes)
+            |> args |> Variable.variant_frees lthy []
+          val dom_args = 
+            List.foldl (fn (t,p) => HOLogic.mk_prod ((Free t),p)) (Free (hd dom_vars)) (tl dom_vars)
 
           val {inducts, ...} = case infos of SOME [i] => i | _ => error "Proof over dom failed as no induct rule was found"
           val induct = (Option.valOf inducts |> hd)
 
           val domintros = Proof_Context.get_fact lthy (Facts.named (hd timing_names ^ ".domintros"))
-          val prop = (hd timing_names ^ "_dom (" ^ (String.concatWith "," dom_args) ^ ")")
-                      |> Syntax.read_prop lthy
+          val prop = HOLogic.mk_Trueprop (#dom T_info $ dom_args)
 
           (* Prove a helper lemma *)
-          val dom_lemma = Goal.prove lthy dom_args [] prop
+          val dom_lemma = Goal.prove lthy (map fst dom_vars) [] prop
             (fn {context, ...} => HEADGOAL (time_dom_tac context induct domintros))
           (* Add dom_lemma to simplification set *)
           val simp_lthy = Simplifier.add_simp dom_lemma lthy
@@ -602,7 +619,7 @@
           Function.prove_termination NONE
             (auto_tac simp_lthy) lthy
         end
-    
+
     (* Context for printing without showing question marks *)
     val print_ctxt = lthy'
       |> Config.put show_question_marks false
@@ -616,63 +633,98 @@
           print_timing' print_ctxt { names=nms, terms=terms, typs=typs } (info_pfunc time_info)
         end
   in
-    (time_info, Local_Theory.exit_global lthy')
+    (time_info, lthy')
   end
+fun reg_and_proove_time_func (lthy: local_theory) (term: term list) (terms: term list) print =
+  reg_time_func lthy term terms false
+  |> proove_termination term terms print
 
 fun fix_definition (Const ("Pure.eq", _) $ l $ r) = Const ("HOL.Trueprop", @{typ "bool \<Rightarrow> prop"})
       $ (Const ("HOL.eq", @{typ "bool \<Rightarrow> bool \<Rightarrow> bool"}) $ l $ r)
   | fix_definition t = t
 fun check_definition [t] = [t]
-  | check_definition _ = error "Only a single defnition is allowed"
+  | check_definition _ = error "Only a single definition is allowed"
+
+fun isTypeClass' (Const (nm,_)) =
+  (case split_name nm |> rev
+    of (_::nm::_) => String.isSuffix "_class" nm
+     | _ => false)
+  | isTypeClass' _ = false
+val isTypeClass =
+  (List.foldr (fn (a,b) => a orelse b) false) o (map isTypeClass')
+
+fun detect_typ (ctxt: local_theory) (term: term) =
+let
+  val class_term =  (case term of Const (nm,_) => Syntax.read_term ctxt nm
+      | _ => error "Could not find term of class")
+  fun find_free (Type (_,class)) (Type (_,inst)) =
+        List.foldl (fn ((c,i),s) => (case s of NONE => find_free c i | t => t)) (NONE) (ListPair.zip (class, inst))
+    | find_free (TFree _) (TFree _) = NONE
+    | find_free (TFree _) (Type (nm,_)) = SOME nm
+    | find_free  _ _ = error "Unhandled case in detecting type"
+in
+  find_free (type_of class_term) (type_of term)
+    |> Option.map (hd o rev o split_name)
+end
+
+fun set_suffix (fterms: term list) ctxt =
+let
+  val isTypeClass = isTypeClass fterms
+  val _ = (if length fterms > 1 andalso isTypeClass then error "No mutual recursion inside instantiation allowed" else ())
+  val suffix = (if isTypeClass then detect_typ ctxt (hd fterms) else NONE)
+in
+  (case suffix of NONE => I | SOME s => Config.put bsuffix ("_" ^ s)) ctxt
+end
 
 (* Convert function into its timing function (called by command) *)
-fun reg_time_fun_cmd (funcs, thms) (theory: theory) =
+fun reg_time_fun_cmd (funcs, thms) (ctxt: local_theory) =
 let
-  val ctxt = Proof_Context.init_global theory
   val fterms = map (Syntax.read_term ctxt) funcs
-  val (_, lthy') = reg_and_proove_time_func theory fterms
-    (case thms of NONE => get_terms theory (hd fterms)
+  val ctxt = set_suffix fterms ctxt
+  val (_, ctxt') = reg_and_proove_time_func ctxt fterms
+    (case thms of NONE => get_terms ctxt (hd fterms)
                 | SOME thms => thms |> Attrib.eval_thms ctxt |> List.map Thm.prop_of)
     true
-in lthy'
+in ctxt'
 end
 
 (* Convert function into its timing function (called by command) with termination proof provided by user*)
-fun reg_time_function_cmd (funcs, thms) (theory: theory) =
+fun reg_time_function_cmd (funcs, thms) (ctxt: local_theory) =
 let
-  val ctxt = Proof_Context.init_global theory
   val fterms = map (Syntax.read_term ctxt) funcs
-  val theory = reg_time_func theory fterms
-    (case thms of NONE => get_terms theory (hd fterms)
+  val ctxt = set_suffix fterms ctxt
+  val ctxt' = reg_time_func ctxt fterms
+    (case thms of NONE => get_terms ctxt (hd fterms)
                 | SOME thms => thms |> Attrib.eval_thms ctxt |> List.map Thm.prop_of)
     true
-in theory
+    |> snd
+in ctxt'
 end
 
 (* Convert function into its timing function (called by command) *)
-fun reg_time_definition_cmd (funcs, thms) (theory: theory) =
+fun reg_time_definition_cmd (funcs, thms) (ctxt: local_theory) =
 let
-  val ctxt = Proof_Context.init_global theory
   val fterms = map (Syntax.read_term ctxt) funcs
-  val (_, lthy') = reg_and_proove_time_func theory fterms
-    (case thms of NONE => get_terms theory (hd fterms) |> check_definition |> map fix_definition
+  val ctxt = set_suffix fterms ctxt
+  val (_, ctxt') = reg_and_proove_time_func ctxt fterms
+    (case thms of NONE => get_terms ctxt (hd fterms) |> check_definition |> map fix_definition
                 | SOME thms => thms |> Attrib.eval_thms ctxt |> List.map Thm.prop_of)
     true
-in lthy'
+in ctxt'
 end
 
 val parser = (Scan.repeat1 Parse.prop) -- (Scan.option (Parse.keyword_improper "equations" -- Parse.thms1 >> snd))
 
-val _ = Outer_Syntax.command @{command_keyword "time_fun"}
+val _ = Outer_Syntax.local_theory @{command_keyword "time_fun"}
   "Defines runtime function of a function"
-  (parser >> (fn p => Toplevel.theory (reg_time_fun_cmd p)))
+  (parser >> reg_time_fun_cmd)
 
-val _ = Outer_Syntax.command @{command_keyword "time_function"}
+val _ = Outer_Syntax.local_theory @{command_keyword "time_function"}
   "Defines runtime function of a function"
-  (parser >> (fn p => Toplevel.theory (reg_time_function_cmd p)))
+  (parser >> reg_time_function_cmd)
 
-val _ = Outer_Syntax.command @{command_keyword "time_definition"}
+val _ = Outer_Syntax.local_theory @{command_keyword "time_definition"}
   "Defines runtime function of a definition"
-  (parser >> (fn p => Toplevel.theory (reg_time_definition_cmd p)))
+  (parser >> reg_time_definition_cmd)
 
 end
--- a/src/HOL/Data_Structures/Selection.thy	Wed Aug 21 14:09:44 2024 +0100
+++ b/src/HOL/Data_Structures/Selection.thy	Wed Aug 21 20:41:16 2024 +0200
@@ -645,9 +645,7 @@
 
 lemmas T_slow_select_def [simp del] = T_slow_select.simps
 
-
-definition T_slow_median :: "'a :: linorder list \<Rightarrow> nat" where
-  "T_slow_median xs = T_length xs + T_slow_select ((length xs - 1) div 2) xs"
+time_fun slow_median
 
 lemma T_slow_select_le:
   assumes "k < length xs"
@@ -671,7 +669,7 @@
   shows   "T_slow_median xs \<le> length xs ^ 2 + 4 * length xs + 2"
 proof -
   have "T_slow_median xs = length xs + T_slow_select ((length xs - 1) div 2) xs + 1"
-    by (simp add: T_slow_median_def T_length_eq)
+    by (simp add: T_length_eq)
   also from assms have "length xs > 0"
     by simp
   hence "(length xs - 1) div 2 < length xs"
--- a/src/HOL/Data_Structures/Time_Funs.thy	Wed Aug 21 14:09:44 2024 +0100
+++ b/src/HOL/Data_Structures/Time_Funs.thy	Wed Aug 21 20:41:16 2024 +0200
@@ -12,16 +12,25 @@
 lemma T_append: "T_append xs ys = length xs + 1"
 by(induction xs) auto
 
-text \<open>Automatic definition of \<open>T_length\<close> is cumbersome because of the type class for \<open>size\<close>.\<close>
+class T_size =
+  fixes T_size :: "'a \<Rightarrow> nat"
+
+instantiation list :: (_) T_size
+begin
 
-fun T_length :: "'a list \<Rightarrow> nat" where
-  "T_length [] = 1"
-| "T_length (x # xs) = T_length xs + 1"
+time_fun length
+
+instance ..
+
+end
+
+abbreviation T_length :: "'a list \<Rightarrow> nat" where
+"T_length \<equiv> T_size"
 
 lemma T_length_eq: "T_length xs = length xs + 1"
   by (induction xs) auto
 
-lemmas [simp del] = T_length.simps
+lemmas [simp del] = T_size_list.simps
 
 time_fun map