--- a/src/HOL/Data_Structures/Define_Time_Function.ML Sun Sep 28 17:49:34 2025 +0000
+++ b/src/HOL/Data_Structures/Define_Time_Function.ML Tue Sep 30 08:56:11 2025 +0000
@@ -258,13 +258,13 @@
| fixCasecCases wctxt (t::ts) =
let
val num = fastype_of t |> strip_type |> fst |> length
- val c' = Term.strip_abs_eta num t |> list_abs
+ val c' = Term.strip_abs_eta num t ||> #f wctxt |> list_abs
in
c' :: fixCasecCases wctxt ts
end
| fixCasecCases _ _ = error "Internal error: invalid case types/terms"
fun fixCasec wctxt t args =
- (check_args "cases" (t,args); list_comb (t,fixCasecCases wctxt args))
+ (check_args "cases" (t,args); list_comb ((#f wctxt) t,fixCasecCases wctxt args))
fun shortFunc fixedNum (Const (nm,T)) =
Const (nm,T |> strip_type |>> drop fixedNum |> (op --->))
@@ -281,34 +281,10 @@
val _ = check_args "args" (strip_comb (get_l t))
val l' = shortApp fixedNum (strip_comb l) |> list_comb
val shortOriginFunc' = shortOriginFunc (term |> map (fst o strip_comb)) fixedNum
- val consts = Proof_Context.consts_of ctxt
- val net = Consts.revert_abbrevs consts ["internal"] |> hd |> Item_Net.content
- (* filter out consts *)
- |> filter (is_Const o fst o strip_comb o fst)
- (* filter out abbreviations for locales *)
- |> filter (fn n => "local"
- = (n |> snd |> strip_comb |> fst |> dest_Const_name |> split_name |> hd))
- |> filter (fn n => (n |> fst |> strip_comb |> fst |> dest_Const_name |> split_name |> List.last) =
- (n |> snd |> strip_comb |> fst |> dest_Const_name |> split_name |> List.last))
- |> map (fst #> strip_comb #>> dest_Const_name ##> length)
- fun n_abbrev (Const (nm,_)) =
- let
- val f = filter (fn n => fst n = nm) net
- in
- if length f >= 1 then f |> hd |> snd else 0
- end
- | n_abbrev _ = 0
val r' = walk ctxt term {
funcc = (fn wctxt => fn t => fn args =>
- let
- val n_abb = n_abbrev t
- val t = case t of Const (nm,T) => Const (nm, T |> strip_type |>> drop n_abb |> (op --->))
- | t => t
- val args = drop n_abb args
- in
- (check_args "func" (t,args);
- (#f wctxt t, map (#f wctxt) args) |> shortOriginFunc' |> list_comb)
- end),
+ (check_args "func" (t,args);
+ (#f wctxt t, map (#f wctxt) args) |> shortOriginFunc' |> list_comb)),
constc = fn wctxt => map_abs (#f wctxt),
ifc = Iif,
casec = fixCasec,
@@ -319,6 +295,23 @@
pT $ (eq $ l' $ r')
end
| fixTerms _ _ _ _ = error "Internal error: invalid term"
+fun postFixTerms ctxt (term: term list) (pT $ (eq $ l $ r)) =
+ let
+ val r' = walk ctxt term {
+ funcc = (fn wctxt => fn t => fn args =>
+ case List.find (fn el => Term.is_Const t
+ andalso (Term.dest_Const_name (strip_comb el |> fst)) = (Term.dest_Const_name t)) term of
+ SOME t => list_comb (t, map (#f wctxt) args)
+ | NONE => list_comb (#f wctxt t, map (#f wctxt) args)),
+ constc = Iconst,
+ ifc = Iif,
+ casec = Icase,
+ letc = Ilet
+ } r
+ in
+ pT $ (eq $ l $ r')
+ end
+ | postFixTerms _ _ _ = error "Internal error: invalid term"
(* 2. Check for properties about the function *)
(* 2.1 Check if function is recursive *)
@@ -363,6 +356,36 @@
val finT = if #partial config then natOptT else HOLogic.natT
val some = @{term "Some::nat \<Rightarrow> nat option"}
+ (* Convert implicit capturing functions in locales to their basic version *)
+ val consts = Proof_Context.consts_of lthy
+ val full_term = term
+ val net = Consts.revert_abbrevs consts ["internal"] |> hd |> Item_Net.content
+ (* filter out consts *)
+ |> filter (is_Const o fst o strip_comb o fst)
+ (* filter out abbreviations for locales *)
+ |> filter (fn n => "local"
+ = (n |> snd |> strip_comb |> fst |> dest_Const_name |> split_name |> hd))
+ |> filter (fn n => (n |> fst |> strip_comb |> fst |> dest_Const_name |> split_name |> List.last) =
+ (n |> snd |> strip_comb |> fst |> dest_Const_name |> split_name |> List.last))
+ |> map (fst #> strip_comb #>> dest_Const_name ##> length)
+ fun n_abbrev (Const (nm,_)) =
+ let
+ val f = filter (fn n => fst n = nm) net
+ in
+ if length f >= 1 then f |> hd |> snd else 0
+ end
+ | n_abbrev _ = 0
+ fun simpLocFunc (t: term) (args: term list) =
+ let
+ val n_abb = n_abbrev t
+ val simp_t = case t of Const (nm,T) => Const (nm, T |> strip_type |>> drop n_abb |> (op --->))
+ | t => t
+ val simp_args = drop n_abb args
+ in
+ if Term.is_Const t andalso contains (term |> map (Term.dest_Const_name o fst o strip_comb)) (t |> Term.dest_Const_name)
+ then (t, args) else (simp_t, simp_args)
+ end
+
(* change type of original function to new type (_ \<Rightarrow> ... \<Rightarrow> _ to _ \<Rightarrow> ... \<Rightarrow> nat)
and replace all function arguments f with (t*T_f) if used *)
fun change_typ' used (Type ("fun", [T1, T2])) =
@@ -391,6 +414,19 @@
in
SOME (Const (T_nm, binderT ---> finT))
end
+ (* Case for inside of locale, would need type *)
+ | f as (_$_) =>
+ let
+ val ((T_nm,T_T), fixes) = Term.strip_comb f |>> Term.dest_Const
+ val (T_Ts, finT) = Term.strip_type T_T
+ fun col_Used i (Type ("fun", [Type ("fun", _), Ts])) (Type ("fun", [T', Ts'])) =
+ (if is_Used T' then [i] else []) @ col_Used (i+1) Ts Ts'
+ | col_Used i (Type ("fun", [_, Ts])) (Type ("fun", [_, Ts'])) = col_Used (i+1) Ts Ts'
+ | col_Used _ _ _ = []
+ val binderT = change_typ' (contains (col_Used 0 T (drop (length fixes) T_Ts ---> finT))) T |> Term.binder_types
+ in
+ SOME (Term.list_comb (Const (T_nm, (take (length fixes) T_Ts) ---> binderT ---> finT), fixes))
+ end
| _ => error ("Timing function of " ^ nm ^ " is not defined")
end
| time_term _ _ _ = error "Internal error: No valid function given"
@@ -403,9 +439,9 @@
(* Conversion of function term *)
fun fun_to_time' ctxt (origin: term list) second (func as Const (nm,T)) =
let
- val origin' = map (fst o strip_comb) origin
+ val origin' = map (Term.dest_Const_name o fst o strip_comb) origin
in
- if contains' const_comp origin' func then SOME (Free (func |> Term.term_name |> fun_name_to_time' ctxt true second, change_typ T)) else
+ if contains origin' nm then SOME (Free (func |> Term.term_name |> fun_name_to_time' ctxt true second, change_typ T)) else
if Zero_Funcs.is_zero (Proof_Context.theory_of ctxt) (nm,T) then NONE else
time_term ctxt false func
end
@@ -477,9 +513,11 @@
| funcc_conv_args _ _ _ = error "Internal error: Non matching type"
fun funcc wctxt func args =
let
+ val (func, args) = simpLocFunc func args
fun get_T (Free (_,T)) = T
| get_T (Const (_,T)) = T
- | get_T (_ $ (Free (_,Type (_, [_, T])))) = T (* Case of snd was constructed *)
+ | get_T (Const ("Product_Type.prod.snd",_) $ (Free (_,Type (_, [_, T])))) = T (* Case of snd was constructed *)
+ | get_T (h $ _) = (case get_T h of Type ("fun", [_,T]) => T | _ => error "Internal error: Not a locale func")
| get_T _ = error "Internal error: Forgotten type"
val func = (case fun_to_time (#ctxt wctxt) (#origins wctxt) func
of SOME t => SOME (WRAP_FUNCTION (list_comb (t, funcc_conv_args wctxt (get_T t) args)))
@@ -837,6 +875,7 @@
| fSnd t = t
val T_terms = map (convert_term lthy term is_rec) terms
|> map (map_r (replaceFstSndFree lthy term fFst fSnd))
+ |> map (postFixTerms lthy full_term)
val simpables = (if #simp config
then find_simplifyble lthy term T_terms