merged
authornipkow
Mon, 29 Jul 2024 15:26:56 +0200
changeset 80625 fbb38db0435d
parent 80623 424b90ba7b6f (current diff)
parent 80624 9f8034d29365 (diff)
child 80627 11382acb0fc4
merged
--- a/src/HOL/Data_Structures/Define_Time_Function.ML	Mon Jul 29 10:49:17 2024 +0100
+++ b/src/HOL/Data_Structures/Define_Time_Function.ML	Mon Jul 29 15:26:56 2024 +0200
@@ -85,6 +85,14 @@
 val If_name = @{const_name "HOL.If"}
 val Let_name = @{const_name "HOL.Let"}
 
+fun contains l e = exists (fn e' => e' = e) l
+fun zip [] [] = []
+  | zip (x::xs) (y::ys) = (x, y) :: zip xs ys
+  | zip _ _ = error "Internal error: Cannot zip lists with differing size"
+fun index [] _ = 0
+  | index (x::xs) el = (if x = el then 0 else 1 + index xs el)
+fun used_for_const orig_used t i = orig_used (t,i)
+
 (* returns true if it's an if term *)
 fun is_if (Const (n,_)) = (n = If_name)
   | is_if _ = false
@@ -96,11 +104,13 @@
   | is_let _ = false
 (* change type of original function to new type (_ \<Rightarrow> ... \<Rightarrow> _ to _ \<Rightarrow> ... \<Rightarrow> nat)
     and replace all function arguments f with (t*T_f) *)
-fun change_typ (Type ("fun", [T1, T2])) = Type ("fun", [check_for_fun T1, change_typ T2])
-  | change_typ _ = HOLogic.natT
-and check_for_fun (f as Type ("fun", [_,_])) = HOLogic.mk_prodT (f, change_typ f)
-  | check_for_fun (Type ("Product_Type.prod", [t1,t2])) = HOLogic.mk_prodT (check_for_fun t1, check_for_fun t2)
-  | check_for_fun f = f
+fun change_typ' used i (Type ("fun", [T1, T2])) = 
+      Type ("fun", [check_for_fun' (used i) T1, change_typ' used (i+1) T2])
+  | change_typ' _ _ _ = HOLogic.natT
+and check_for_fun' true (f as Type ("fun", [_,_])) = HOLogic.mk_prodT (f, change_typ' (K false) 0 f)
+  | check_for_fun' false (f as Type ("fun", [_,_])) = change_typ' (K false) 0 f
+  | check_for_fun' _ t = t
+val change_typ = change_typ' (K false) 0
 (* Convert string name of function to its timing equivalent *)
 fun fun_name_to_time ctxt name =
 let
@@ -136,6 +146,8 @@
 (* Return name of Const *)
 fun Const_name (Const (nm,_)) = SOME nm
   | Const_name _ = NONE
+fun is_Used (Type ("Product_Type.prod", _)) = true
+  | is_Used _ = false
 
 fun time_term ctxt (Const (nm,T)) =
 let
@@ -145,7 +157,15 @@
 (SOME (Syntax.check_term ctxt (Const (T_nm,T_T))))
   handle (ERROR _) =>
     case Syntax.read_term ctxt (Long_Name.base_name T_nm)
-      of (Const (nm,_)) => SOME (Const (nm,T_T))
+      of (Const (nm,T_T)) =>
+        let
+          fun col_Used i (Type ("fun", [Type ("fun", _), Ts])) (Type ("fun", [T', Ts'])) =
+            (if is_Used T' then [i] else []) @ col_Used (i+1) Ts Ts'
+            | col_Used i (Type ("fun", [_, Ts])) (Type ("fun", [_, Ts'])) = col_Used (i+1) Ts Ts'
+            | col_Used _ _ _ = []
+        in
+          SOME (Const (nm,change_typ' (contains (col_Used 0 T T_T)) 0 T))
+        end
        | _ => error ("Timing function of " ^ nm ^ " is not defined")
 end
   | time_term _ _ = error "Internal error: No valid function given"
@@ -198,7 +218,7 @@
 fun fixTypes (TVar ((t, _), T)) = TFree (t, T)
   | fixTypes t = t
 
-fun noFun (Type ("fun",_)) = error "Functions in datatypes are not allowed in case constructions"
+fun noFun (Type ("fun", _)) = error "Functions in datatypes are not allowed in case constructions"
   | noFun _ = ()
 fun casecBuildBounds n t = if n > 0 then casecBuildBounds (n-1) (t $ (Bound (n-1))) else t
 fun casecAbs ctxt f n (Type (_,[T,Tr])) (Abs (v,Ta,t)) = (noFun T; Abs (v,Ta,casecAbs ctxt f n Tr t))
@@ -234,7 +254,8 @@
       }) t
   end
 
-(* 2. Check if function is recursive *)
+(* 2. Check for properties about the function *)
+(* 2.1 Check if function is recursive *)
 fun or f (a,b) = f a orelse b
 fun find_rec ctxt term = (walk ctxt term {
           funcc = (fn _ => fn _ => fn f => fn t => fn args => List.exists (fn term => Const_name t = Const_name term) term orelse List.foldr (or f) false args),
@@ -245,6 +266,38 @@
       }) o get_r
 fun is_rec ctxt (term: term list) = List.foldr (or (find_rec ctxt term)) false
 
+(* 2.2 Check for higher-order function if original function is used *)
+fun find_used' ctxt term t T_t =
+let
+  val (ident, _) = walk_func (get_l t) []
+  val (T_ident, T_args) = walk_func (get_l T_t) []
+
+  fun filter_passed [] = []
+    | filter_passed ((f as Free (_, Type ("Product_Type.prod",[Type ("fun",_), Type ("fun", _)])))::args) = 
+        f :: filter_passed args
+    | filter_passed (_::args) = filter_passed args
+  val frees' = (walk ctxt term {
+          funcc = (fn _ => fn _ => fn f => fn t => fn args =>
+              (case t of (Const ("Product_Type.prod.snd", _)) => []
+                  | _ => (if t = T_ident then [] else filter_passed args)
+                    @ List.foldr (fn (l,r) => f l @ r) [] args)),
+          constc = (K o K o K o K) [],
+          ifc = (fn _ => fn _ => fn f => fn _ => fn cond => fn tt => fn tf => f cond @ f tt @ f tf),
+          casec = (fn _ => fn _ => fn f => fn _ => fn cs => List.foldr (fn (l,r) => f l @ r) [] cs),
+          letc = (fn _ => fn _ => fn f => fn _ => fn exp => fn _ => fn _ => fn t => f exp @ f t)
+      }) (get_r T_t)
+  fun build _ [] _ = false
+    | build i (a::args) item =
+        (if item = (ident,i) then contains frees' a else build (i+1) args item)
+in
+  build 0 T_args
+end
+fun find_used ctxt term terms T_terms =
+  zip terms T_terms
+  |> List.map (fn (t, T_t) => find_used' ctxt term t T_t)
+  |> List.foldr (fn (f,g) => fn item => f item orelse g item) (K false)
+
+
 (* 3. Convert equations *)
 (* Some Helper *)
 val plusTyp = @{typ "nat => nat => nat"}
@@ -254,65 +307,90 @@
   | plus NONE NONE = NONE
 fun opt_term NONE = HOLogic.zero
   | opt_term (SOME t) = t
+fun use_origin (Free (nm, T as Type ("fun",_))) = HOLogic.mk_fst (Free (nm,HOLogic.mk_prodT (T, change_typ T)))
+  | use_origin t = t
 
 (* Converting of function term *)
-fun fun_to_time ctxt (origin: term list) (func as Const (nm,T)) =
+fun fun_to_time ctxt orig_used _ (origin: term list) (func as Const (nm,T)) =
 let
-  val full_name_origin = map (fst o dest_Const) origin
   val prefix = Config.get ctxt bprefix
+  val used' = used_for_const orig_used func
 in
-  if List.exists (fn nm_orig => nm = nm_orig) full_name_origin then SOME (Free (prefix ^ Term.term_name func, change_typ T)) else
+  if contains origin func then SOME (Free (prefix ^ Term.term_name func, change_typ' used' 0 T)) else
   if Zero_Funcs.is_zero (Proof_Context.theory_of ctxt) (nm,T) then NONE else
     time_term ctxt func
 end
-  | fun_to_time _ _ (Free (nm,T)) = SOME (HOLogic.mk_snd (Free (nm,HOLogic.mk_prodT (T,change_typ T))))
-  | fun_to_time _ _ _ = error "Internal error: invalid function to convert"
+  | fun_to_time ctxt _ used _ (f as Free (nm,T)) = SOME (
+      if used f then HOLogic.mk_snd (Free (nm,HOLogic.mk_prodT (T,change_typ T)))
+      else Free (Config.get ctxt bprefix ^ nm, change_typ T)
+      )
+  | fun_to_time _ _ _ _ _ = error "Internal error: invalid function to convert"
 
 (* Convert arguments of left side of a term *)
-fun conv_arg _ _ (Free (nm,T as Type("fun",_))) = Free (nm, HOLogic.mk_prodT (T, change_typ T))
-  | conv_arg ctxt origin (f as Const (_,T as Type("fun",_))) =
-  (Const (@{const_name "Product_Type.Pair"},
-      Type ("fun", [T,Type ("fun",[change_typ T, HOLogic.mk_prodT (T,change_typ T)])]))
-    $ f $ (fun_to_time ctxt origin f |> Option.valOf))
-  | conv_arg ctxt origin ((Const ("Product_Type.Pair", _)) $ l $ r) = HOLogic.mk_prod (conv_arg ctxt origin l, conv_arg ctxt origin r)
-  | conv_arg _ _ x = x
-fun conv_args ctxt origin = map (conv_arg ctxt origin)
+fun conv_arg ctxt used _ (f as Free (nm,T as Type("fun",_))) =
+    if used f then Free (nm, HOLogic.mk_prodT (T, change_typ' (K false) 0 T))
+    else Free (Config.get ctxt bprefix ^ nm, change_typ' (K false) 0 T)
+  | conv_arg ctxt _ origin (f as Const (_, Type("fun",_))) =
+      (error "weird case i don't understand TODO"; HOLogic.mk_prod (f, fun_to_time ctxt (K false) (K false) origin f |> Option.valOf))
+  | conv_arg _ _ _ x = x
+fun conv_args ctxt used origin = map (conv_arg ctxt used origin)
 
 (* Handle function calls *)
 fun build_zero (Type ("fun", [T, R])) = Abs ("x", T, build_zero R)
   | build_zero _ = zero
-fun funcc_use_origin _ _ (Free (nm, T as Type ("fun",_))) = HOLogic.mk_fst (Free (nm,HOLogic.mk_prodT (T, change_typ T)))
-  | funcc_use_origin _ _ t = t
-fun funcc_conv_arg ctxt origin (t as (_ $ _)) = map_aterms (funcc_use_origin ctxt origin) t
-  | funcc_conv_arg _ _ (Free (nm, T as Type ("fun",_))) = (Free (nm, HOLogic.mk_prodT (T, change_typ T)))
-  | funcc_conv_arg ctxt origin (f as Const (_,T as Type ("fun",_))) =
+fun funcc_use_origin _ _ used (f as Free (nm, T as Type ("fun",_))) =
+    if used f then HOLogic.mk_fst (Free (nm,HOLogic.mk_prodT (T, change_typ T)))
+    else error "Internal error: Error in used detection"
+  | funcc_use_origin _ _ _ t = t
+fun funcc_conv_arg ctxt origin used _ (t as (_ $ _)) = map_aterms (funcc_use_origin ctxt origin used) t
+  | funcc_conv_arg ctxt _ used u (f as Free (nm, T as Type ("fun",_))) =
+      if used f then
+        if u then Free (nm, HOLogic.mk_prodT (T, change_typ T))
+        else HOLogic.mk_snd (Free (nm,HOLogic.mk_prodT (T,change_typ T)))
+      else Free (Config.get ctxt bprefix ^ nm, change_typ T)
+  | funcc_conv_arg ctxt origin _ true (f as Const (_,T as Type ("fun",_))) =
   (Const (@{const_name "Product_Type.Pair"},
-      Type ("fun", [T,Type ("fun",[change_typ T, HOLogic.mk_prodT (T,change_typ T)])]))
-    $ f $ (Option.getOpt (fun_to_time ctxt origin f, build_zero T)))
-  | funcc_conv_arg _ _ t = t
-fun funcc_conv_args ctxt origin = map (funcc_conv_arg ctxt origin)
-fun funcc ctxt (origin: term list) f func args = List.foldr (I #-> plus)
-  (case fun_to_time ctxt origin func of SOME t => SOME (build_func (t,funcc_conv_args ctxt origin args))
-                                      | NONE => NONE)
+      Type ("fun", [T,Type ("fun", [change_typ T, HOLogic.mk_prodT (T,change_typ T)])]))
+    $ f $ (Option.getOpt (fun_to_time ctxt (K false) (K false) origin f, build_zero T)))
+  | funcc_conv_arg ctxt origin _ false (f as Const (_,T as Type ("fun",_))) =
+      Option.getOpt (fun_to_time ctxt (K false) (K false) origin f, build_zero T)
+  | funcc_conv_arg _ _ _ _ t = t
+
+fun funcc_conv_args _ _ _ _ [] = []
+  | funcc_conv_args ctxt origin used (Type ("fun", [t, ts])) (a::args) =
+      funcc_conv_arg ctxt origin used (is_Used t) a :: funcc_conv_args ctxt origin used ts args
+  | funcc_conv_args _ _ _ _ _ = error "Internal error: Non matching type"
+fun funcc orig_used used ctxt (origin: term list) f func args =
+let
+  fun get_T (Free (_,T)) = T
+    | get_T (Const (_,T)) = T
+    | get_T (_ $ (Free (_,Type (_, [_, T])))) = T (* Case of snd was constructed *)
+    | get_T _ = error "Internal error: Forgotten type"
+in
+  List.foldr (I #-> plus)
+  (case fun_to_time ctxt orig_used used origin func
+    of SOME t => SOME (build_func (t,funcc_conv_args ctxt origin used (get_T t) args))
+    | NONE => NONE)
   (map f args)
+end
 
 (* Handle case terms *)
 fun casecIsCase (Type (n1, [_,Type (n2, _)])) = (n1 = "fun" andalso n2 = "fun")
   | casecIsCase _ = false
 fun casecLastTyp (Type (n, [T1,T2])) = Type (n, [T1, change_typ T2])
-  | casecLastTyp _ = error "Internal error: invalid case type"
+  | casecLastTyp _ = error "Internal error: Invalid case type"
 fun casecTyp (Type (n, [T1, T2])) =
       Type (n, [change_typ T1, (if casecIsCase T2 then casecTyp else casecLastTyp) T2])
-  | casecTyp _ = error "Internal error: invalid case type"
+  | casecTyp _ = error "Internal error: Invalid case type"
 fun casecAbs f (Abs (v,Ta,t)) = (case casecAbs f t of (nconst,t) => (nconst,Abs (v,Ta,t)))
   | casecAbs f t = (case f t of NONE => (false,HOLogic.zero) | SOME t => (true,t))
-fun casecArgs _ [t] = (false, [t])
+fun casecArgs _ [t] = (false, [map_aterms use_origin t])
   | casecArgs f (t::ar) =
     (case casecAbs f t of (nconst, tt) => 
       casecArgs f ar ||> (fn ar => tt :: ar) |>> (if nconst then K true else I))
-  | casecArgs _ _ = error "Internal error: invalid case term"
+  | casecArgs _ _ = error "Internal error: Invalid case term"
 fun casec _ _ f (Const (t,T)) args =
-  if not (casecIsCase T) then error "Internal error: invalid case type" else
+  if not (casecIsCase T) then error "Internal error: Invalid case type" else
     let val (nconst, args') = casecArgs f args in
       plus
         (f (List.last args))
@@ -320,18 +398,17 @@
           SOME (build_func (Const (t,casecTyp T), args'))
          else NONE)
     end
-  | casec _ _ _ _ _ = error "Internal error: invalid case term"
+  | casec _ _ _ _ _ = error "Internal error: Invalid case term"
 
 (* Handle if terms -> drop the term if true and false terms are zero *)
-fun ifc ctxt origin f _ cond tt ft =
+fun ifc _ _ f _ cond tt ft =
   let
-    fun use_origin _ _ (Free (nm, T as Type ("fun",_))) = HOLogic.mk_fst (Free (nm,HOLogic.mk_prodT (T, change_typ T)))
-      | use_origin _ _ t = t
-    val rcond = map_aterms (use_origin ctxt origin) cond
+    val rcond = map_aterms use_origin cond
     val tt = f tt
     val ft = f ft
   in
     plus (f cond) (case (tt,ft) of (NONE, NONE) => NONE | _ =>
+       if tt = ft then tt else
        (SOME ((Const (If_name, @{typ "bool \<Rightarrow> nat \<Rightarrow> nat \<Rightarrow> nat"})) $ rcond $ (opt_term tt) $ (opt_term ft))))
   end
 
@@ -342,21 +419,21 @@
     (if length nms = 0 (* In case of "length nms = 0" the expression got reducted
                           Here we need Bound 0 to gain non-partial application *)
     then (case f (t $ Bound 0) of SOME (t' $ Bound 0) =>
-                                 (SOME (Const (Let_name, letc_change_typ expT) $ exp $ t'))
+                                 (SOME (Const (Let_name, letc_change_typ expT) $ (map_aterms use_origin exp) $ t'))
                                   (* Expression is not used and can therefore let be dropped *)
                                 | SOME t' => SOME t'
                                 | NONE => NONE)
     else (case f t of SOME t' =>
-      SOME (if Term.is_dependent t' then Const (Let_name, letc_change_typ expT) $ exp $ build_abs t' nms Ts
+      SOME (if Term.is_dependent t' then Const (Let_name, letc_change_typ expT) $ (map_aterms use_origin exp) $ build_abs t' nms Ts
                                     else Term.subst_bounds([exp],t'))
-                               | NONE => NONE))
+    | NONE => NONE))
 
 (* The converter for timing functions given to the walker *)
-val converter : term option converter = {
+fun converter orig_used used : term option converter = {
         constc = fn _ => fn _ => fn _ => fn t =>
           (case t of Const ("HOL.undefined", _) => SOME (Const ("HOL.undefined", @{typ "nat"}))
                    | _ => NONE),
-        funcc = funcc,
+        funcc = (funcc orig_used used),
         ifc = ifc,
         casec = casec,
         letc = letc
@@ -364,16 +441,24 @@
 fun top_converter is_rec _ _ = opt_term o (fn exp => plus exp (if is_rec then SOME one else NONE))
 
 (* Use converter to convert right side of a term *)
-fun to_time ctxt origin is_rec term =
-  top_converter is_rec ctxt origin (walk ctxt origin converter term)
+fun to_time ctxt origin is_rec orig_used used term =
+  top_converter is_rec ctxt origin (walk ctxt origin (converter orig_used used) term)
 
 (* Converts a term to its running time version *)
-fun convert_term ctxt (origin: term list) is_rec (pT $ (Const (eqN, _) $ l $ r)) =
+fun convert_term ctxt (origin: term list) is_rec orig_used (pT $ (Const (eqN, _) $ l $ r)) =
+let
+  val (l' as (l_const, l_params)) = walk_func l []
+  val used =
+    l_const
+    |> used_for_const orig_used
+    |> (fn f => fn n => f (index l_params n))
+in
       pT
       $ (Const (eqN, @{typ "nat \<Rightarrow> nat \<Rightarrow> bool"})
-        $ (build_func ((walk_func l []) |>> (fun_to_time ctxt origin) |>> Option.valOf ||> conv_args ctxt origin))
-        $ (to_time ctxt origin is_rec r))
-  | convert_term _ _ _ _ = error "Internal error: invalid term to convert"
+        $ (build_func (l' |>> (fun_to_time ctxt orig_used used origin) |>> Option.valOf ||> conv_args ctxt used origin))
+        $ (to_time ctxt origin is_rec orig_used used r))
+end
+  | convert_term _ _ _ _ _ = error "Internal error: invalid term to convert"
 
 (* 4. Tactic to prove "f_dom n" *)
 fun time_dom_tac ctxt induct_rule domintros =
@@ -410,7 +495,8 @@
         #> fixPartTerms lthy term)
       terms
 
-    (* 2. Check if function is recursive *)
+    (* 2. Find properties about the function *)
+    (* 2.1 Check if function is recursive *)
     val is_rec = is_rec lthy term terms
 
     (* 3. Convert every equation
@@ -418,7 +504,16 @@
       - On left side change name of function to timing function
       - Convert right side of equation with conversion schema
     *)
-    val timing_terms = map (convert_term lthy term is_rec) terms
+    fun convert used = map (convert_term lthy term is_rec used)
+    fun repeat T_terms =
+      let
+        val orig_used = find_used lthy term terms T_terms
+        val T_terms' = convert orig_used terms
+      in
+        if T_terms' <> T_terms then repeat T_terms' else T_terms'
+      end
+    val T_terms = repeat (convert (K true) terms)
+    val orig_used = find_used lthy term terms T_terms
 
     (* 4. Register function and prove termination *)
     val names = map Term.term_name term
@@ -426,7 +521,7 @@
     val bindings = map (fn nm => (Binding.name nm, NONE, NoSyn)) timing_names
     fun pat_completeness_auto ctxt =
       Pat_Completeness.pat_completeness_tac ctxt 1 THEN auto_tac ctxt
-    val specs = map (fn eq => (((Binding.empty, []), eq), [], [])) timing_terms
+    val specs = map (fn eq => (((Binding.empty, []), eq), [], [])) T_terms
 
     (* For partial functions sequential=true is needed in order to support them
        We need sequential=false to support the automatic proof of termination over dom
@@ -444,14 +539,15 @@
     val print_ctxt = lthy
       |> Config.put show_question_marks false
       |> Config.put show_sorts false (* Change it for debugging *)
-    val print_ctxt =  List.foldl (fn (term, ctxt) => Variable.add_fixes_implicit term ctxt) print_ctxt (term@timing_terms)
+    val print_ctxt =  List.foldl (fn (term, ctxt) => Variable.add_fixes_implicit term ctxt) print_ctxt (term @ T_terms)
     (* Print result if print *)
     val _ = if not print then () else
         let
           val nms = map (fst o dest_Const) term
+          val used = map (used_for_const orig_used) term
           val typs = map (snd o dest_Const) term
         in
-          print_timing' print_ctxt { names=nms, terms=terms, typs=typs } { names=timing_names, terms=timing_terms, typs=map change_typ typs }
+          print_timing' print_ctxt { names=nms, terms=terms, typs=typs } { names=timing_names, terms=T_terms, typs=map (fn (used, typ) => change_typ' used 0 typ) (zip used typs) }
         end
 
     (* Register function *)
--- a/src/HOL/Data_Structures/Time_Funs.thy	Mon Jul 29 10:49:17 2024 +0100
+++ b/src/HOL/Data_Structures/Time_Funs.thy	Mon Jul 29 15:26:56 2024 +0200
@@ -23,10 +23,7 @@
 
 lemmas [simp del] = T_length.simps
 
-
-fun T_map  :: "('a \<Rightarrow> nat) \<Rightarrow> 'a list \<Rightarrow> nat" where
-  "T_map T_f [] = 1"
-| "T_map T_f (x # xs) = T_f x + T_map T_f xs + 1"
+time_fun map
 
 lemma T_map_eq: "T_map T_f xs = (\<Sum>x\<leftarrow>xs. T_f x) + length xs + 1"
   by (induction xs) auto