merged
authordesharna
Tue, 30 Sep 2025 08:56:11 +0000
changeset 83234 afcabf75f807
parent 83233 4f15c5c3781f (current diff)
parent 83232 9121e3b0b0a0 (diff)
child 83235 a90bb622445b
merged
--- a/src/HOL/Data_Structures/Define_Time_Function.ML	Sun Sep 28 17:49:34 2025 +0000
+++ b/src/HOL/Data_Structures/Define_Time_Function.ML	Tue Sep 30 08:56:11 2025 +0000
@@ -258,13 +258,13 @@
   | fixCasecCases wctxt (t::ts) =
     let
       val num = fastype_of t |> strip_type |> fst |> length
-      val c' = Term.strip_abs_eta num t |> list_abs
+      val c' = Term.strip_abs_eta num t ||> #f wctxt |> list_abs
     in
       c' :: fixCasecCases wctxt ts
     end
   | fixCasecCases _ _ = error "Internal error: invalid case types/terms"
 fun fixCasec wctxt t args =
-      (check_args "cases" (t,args); list_comb (t,fixCasecCases wctxt args))
+      (check_args "cases" (t,args); list_comb ((#f wctxt) t,fixCasecCases wctxt args))
 
 fun shortFunc fixedNum (Const (nm,T)) = 
     Const (nm,T |> strip_type |>> drop fixedNum |> (op --->))
@@ -281,34 +281,10 @@
     val _ = check_args "args" (strip_comb (get_l t))
     val l' = shortApp fixedNum (strip_comb l) |> list_comb
     val shortOriginFunc' = shortOriginFunc (term |> map (fst o strip_comb)) fixedNum
-    val consts = Proof_Context.consts_of ctxt
-    val net = Consts.revert_abbrevs consts ["internal"] |> hd |> Item_Net.content
-                (* filter out consts *)
-              |> filter (is_Const o fst o strip_comb o fst)
-                (* filter out abbreviations for locales *)
-              |> filter (fn n => "local"
-                  = (n |> snd |> strip_comb |> fst |> dest_Const_name |> split_name |> hd))
-              |> filter (fn n => (n |> fst |> strip_comb |> fst |> dest_Const_name |> split_name |> List.last) =
-                (n |> snd |> strip_comb |> fst |> dest_Const_name |> split_name |> List.last))
-              |> map (fst #> strip_comb #>> dest_Const_name ##> length)
-    fun n_abbrev (Const (nm,_)) =
-    let
-      val f = filter (fn n => fst n = nm) net
-    in
-      if length f >= 1 then f |> hd |> snd else 0
-    end
-      | n_abbrev _ = 0
     val r' = walk ctxt term {
           funcc = (fn wctxt => fn t => fn args =>
-            let
-              val n_abb = n_abbrev t
-              val t = case t of Const (nm,T) => Const (nm, T |> strip_type |>> drop n_abb |> (op --->))
-                              | t => t
-              val args = drop n_abb args           
-            in
-              (check_args "func" (t,args);
-               (#f wctxt t, map (#f wctxt) args) |> shortOriginFunc' |> list_comb)
-            end),
+            (check_args "func" (t,args);
+               (#f wctxt t, map (#f wctxt) args) |> shortOriginFunc' |> list_comb)),
           constc = fn wctxt => map_abs (#f wctxt),
           ifc = Iif,
           casec = fixCasec,
@@ -319,6 +295,23 @@
     pT $ (eq $ l' $ r')
   end
   | fixTerms _ _ _ _ = error "Internal error: invalid term"
+fun postFixTerms ctxt (term: term list) (pT $ (eq $ l $ r)) =
+  let
+    val r' = walk ctxt term {
+          funcc = (fn wctxt => fn t => fn args =>
+            case List.find (fn el => Term.is_Const t
+              andalso (Term.dest_Const_name (strip_comb el |> fst)) = (Term.dest_Const_name t)) term of
+              SOME t => list_comb (t, map (#f wctxt) args)
+            | NONE => list_comb (#f wctxt t, map (#f wctxt) args)),
+          constc = Iconst,
+          ifc = Iif,
+          casec = Icase,
+          letc = Ilet
+      } r
+  in
+    pT $ (eq $ l $ r')
+  end
+  | postFixTerms _ _ _ = error "Internal error: invalid term"
 
 (* 2. Check for properties about the function *)
 (* 2.1 Check if function is recursive *)
@@ -363,6 +356,36 @@
     val finT = if #partial config then natOptT else HOLogic.natT
     val some = @{term "Some::nat \<Rightarrow> nat option"}
 
+    (* Convert implicit capturing functions in locales to their basic version *)
+    val consts = Proof_Context.consts_of lthy
+    val full_term = term
+    val net = Consts.revert_abbrevs consts ["internal"] |> hd |> Item_Net.content
+                (* filter out consts *)
+              |> filter (is_Const o fst o strip_comb o fst)
+                (* filter out abbreviations for locales *)
+              |> filter (fn n => "local"
+                  = (n |> snd |> strip_comb |> fst |> dest_Const_name |> split_name |> hd))
+              |> filter (fn n => (n |> fst |> strip_comb |> fst |> dest_Const_name |> split_name |> List.last) =
+                (n |> snd |> strip_comb |> fst |> dest_Const_name |> split_name |> List.last))
+              |> map (fst #> strip_comb #>> dest_Const_name ##> length)
+    fun n_abbrev (Const (nm,_)) =
+    let
+      val f = filter (fn n => fst n = nm) net
+    in
+      if length f >= 1 then f |> hd |> snd else 0
+    end
+      | n_abbrev _ = 0
+    fun simpLocFunc (t: term) (args: term list) =
+      let
+        val n_abb = n_abbrev t
+        val simp_t = case t of Const (nm,T) => Const (nm, T |> strip_type |>> drop n_abb |> (op --->))
+                        | t => t
+        val simp_args = drop n_abb args 
+      in
+        if Term.is_Const t andalso contains (term |> map (Term.dest_Const_name o fst o strip_comb)) (t |> Term.dest_Const_name)
+        then (t, args) else (simp_t, simp_args)
+      end
+
     (* change type of original function to new type (_ \<Rightarrow> ... \<Rightarrow> _ to _ \<Rightarrow> ... \<Rightarrow> nat)
         and replace all function arguments f with (t*T_f) if used *)
     fun change_typ' used (Type ("fun", [T1, T2])) = 
@@ -391,6 +414,19 @@
             in
               SOME (Const (T_nm, binderT ---> finT))
             end
+           (* Case for inside of locale, would need type *)
+           | f as (_$_) =>
+            let
+              val ((T_nm,T_T), fixes) = Term.strip_comb f |>> Term.dest_Const
+              val (T_Ts, finT) = Term.strip_type T_T
+              fun col_Used i (Type ("fun", [Type ("fun", _), Ts])) (Type ("fun", [T', Ts'])) =
+                (if is_Used T' then [i] else []) @ col_Used (i+1) Ts Ts'
+                | col_Used i (Type ("fun", [_, Ts])) (Type ("fun", [_, Ts'])) = col_Used (i+1) Ts Ts'
+                | col_Used _ _ _ = []
+              val binderT = change_typ' (contains (col_Used 0 T (drop (length fixes) T_Ts ---> finT))) T |> Term.binder_types
+            in
+              SOME (Term.list_comb (Const (T_nm, (take (length fixes) T_Ts) ---> binderT ---> finT), fixes))
+            end
            | _ => error ("Timing function of " ^ nm ^ " is not defined")
     end
       | time_term _ _ _ = error "Internal error: No valid function given"
@@ -403,9 +439,9 @@
     (* Conversion of function term *)
     fun fun_to_time' ctxt (origin: term list) second (func as Const (nm,T)) =
     let
-      val origin' = map (fst o strip_comb) origin
+      val origin' = map (Term.dest_Const_name o fst o strip_comb) origin
     in
-      if contains' const_comp origin' func then SOME (Free (func |> Term.term_name |> fun_name_to_time' ctxt true second, change_typ T)) else
+      if contains origin' nm then SOME (Free (func |> Term.term_name |> fun_name_to_time' ctxt true second, change_typ T)) else
       if Zero_Funcs.is_zero (Proof_Context.theory_of ctxt) (nm,T) then NONE else
         time_term ctxt false func
     end
@@ -477,9 +513,11 @@
       | funcc_conv_args _ _ _ = error "Internal error: Non matching type"
     fun funcc wctxt func args =
     let
+      val (func, args) = simpLocFunc func args
       fun get_T (Free (_,T)) = T
         | get_T (Const (_,T)) = T
-        | get_T (_ $ (Free (_,Type (_, [_, T])))) = T (* Case of snd was constructed *)
+        | get_T (Const ("Product_Type.prod.snd",_) $ (Free (_,Type (_, [_, T])))) = T (* Case of snd was constructed *)
+        | get_T (h $ _) = (case get_T h of Type ("fun", [_,T]) => T | _ => error "Internal error: Not a locale func")
         | get_T _ = error "Internal error: Forgotten type"
       val func = (case fun_to_time (#ctxt wctxt) (#origins wctxt) func
         of SOME t => SOME (WRAP_FUNCTION (list_comb (t, funcc_conv_args wctxt (get_T t) args)))
@@ -837,6 +875,7 @@
       | fSnd t = t
     val T_terms = map (convert_term lthy term is_rec) terms
       |> map (map_r (replaceFstSndFree lthy term fFst fSnd))
+      |> map (postFixTerms lthy full_term)
 
     val simpables = (if #simp config
       then find_simplifyble lthy term T_terms