--- a/src/HOL/Data_Structures/Define_Time_Function.ML Mon Jul 29 10:49:17 2024 +0100
+++ b/src/HOL/Data_Structures/Define_Time_Function.ML Mon Jul 29 15:26:56 2024 +0200
@@ -85,6 +85,14 @@
val If_name = @{const_name "HOL.If"}
val Let_name = @{const_name "HOL.Let"}
+fun contains l e = exists (fn e' => e' = e) l
+fun zip [] [] = []
+ | zip (x::xs) (y::ys) = (x, y) :: zip xs ys
+ | zip _ _ = error "Internal error: Cannot zip lists with differing size"
+fun index [] _ = 0
+ | index (x::xs) el = (if x = el then 0 else 1 + index xs el)
+fun used_for_const orig_used t i = orig_used (t,i)
+
(* returns true if it's an if term *)
fun is_if (Const (n,_)) = (n = If_name)
| is_if _ = false
@@ -96,11 +104,13 @@
| is_let _ = false
(* change type of original function to new type (_ \<Rightarrow> ... \<Rightarrow> _ to _ \<Rightarrow> ... \<Rightarrow> nat)
and replace all function arguments f with (t*T_f) *)
-fun change_typ (Type ("fun", [T1, T2])) = Type ("fun", [check_for_fun T1, change_typ T2])
- | change_typ _ = HOLogic.natT
-and check_for_fun (f as Type ("fun", [_,_])) = HOLogic.mk_prodT (f, change_typ f)
- | check_for_fun (Type ("Product_Type.prod", [t1,t2])) = HOLogic.mk_prodT (check_for_fun t1, check_for_fun t2)
- | check_for_fun f = f
+fun change_typ' used i (Type ("fun", [T1, T2])) =
+ Type ("fun", [check_for_fun' (used i) T1, change_typ' used (i+1) T2])
+ | change_typ' _ _ _ = HOLogic.natT
+and check_for_fun' true (f as Type ("fun", [_,_])) = HOLogic.mk_prodT (f, change_typ' (K false) 0 f)
+ | check_for_fun' false (f as Type ("fun", [_,_])) = change_typ' (K false) 0 f
+ | check_for_fun' _ t = t
+val change_typ = change_typ' (K false) 0
(* Convert string name of function to its timing equivalent *)
fun fun_name_to_time ctxt name =
let
@@ -136,6 +146,8 @@
(* Return name of Const *)
fun Const_name (Const (nm,_)) = SOME nm
| Const_name _ = NONE
+fun is_Used (Type ("Product_Type.prod", _)) = true
+ | is_Used _ = false
fun time_term ctxt (Const (nm,T)) =
let
@@ -145,7 +157,15 @@
(SOME (Syntax.check_term ctxt (Const (T_nm,T_T))))
handle (ERROR _) =>
case Syntax.read_term ctxt (Long_Name.base_name T_nm)
- of (Const (nm,_)) => SOME (Const (nm,T_T))
+ of (Const (nm,T_T)) =>
+ let
+ fun col_Used i (Type ("fun", [Type ("fun", _), Ts])) (Type ("fun", [T', Ts'])) =
+ (if is_Used T' then [i] else []) @ col_Used (i+1) Ts Ts'
+ | col_Used i (Type ("fun", [_, Ts])) (Type ("fun", [_, Ts'])) = col_Used (i+1) Ts Ts'
+ | col_Used _ _ _ = []
+ in
+ SOME (Const (nm,change_typ' (contains (col_Used 0 T T_T)) 0 T))
+ end
| _ => error ("Timing function of " ^ nm ^ " is not defined")
end
| time_term _ _ = error "Internal error: No valid function given"
@@ -198,7 +218,7 @@
fun fixTypes (TVar ((t, _), T)) = TFree (t, T)
| fixTypes t = t
-fun noFun (Type ("fun",_)) = error "Functions in datatypes are not allowed in case constructions"
+fun noFun (Type ("fun", _)) = error "Functions in datatypes are not allowed in case constructions"
| noFun _ = ()
fun casecBuildBounds n t = if n > 0 then casecBuildBounds (n-1) (t $ (Bound (n-1))) else t
fun casecAbs ctxt f n (Type (_,[T,Tr])) (Abs (v,Ta,t)) = (noFun T; Abs (v,Ta,casecAbs ctxt f n Tr t))
@@ -234,7 +254,8 @@
}) t
end
-(* 2. Check if function is recursive *)
+(* 2. Check for properties about the function *)
+(* 2.1 Check if function is recursive *)
fun or f (a,b) = f a orelse b
fun find_rec ctxt term = (walk ctxt term {
funcc = (fn _ => fn _ => fn f => fn t => fn args => List.exists (fn term => Const_name t = Const_name term) term orelse List.foldr (or f) false args),
@@ -245,6 +266,38 @@
}) o get_r
fun is_rec ctxt (term: term list) = List.foldr (or (find_rec ctxt term)) false
+(* 2.2 Check for higher-order function if original function is used *)
+fun find_used' ctxt term t T_t =
+let
+ val (ident, _) = walk_func (get_l t) []
+ val (T_ident, T_args) = walk_func (get_l T_t) []
+
+ fun filter_passed [] = []
+ | filter_passed ((f as Free (_, Type ("Product_Type.prod",[Type ("fun",_), Type ("fun", _)])))::args) =
+ f :: filter_passed args
+ | filter_passed (_::args) = filter_passed args
+ val frees' = (walk ctxt term {
+ funcc = (fn _ => fn _ => fn f => fn t => fn args =>
+ (case t of (Const ("Product_Type.prod.snd", _)) => []
+ | _ => (if t = T_ident then [] else filter_passed args)
+ @ List.foldr (fn (l,r) => f l @ r) [] args)),
+ constc = (K o K o K o K) [],
+ ifc = (fn _ => fn _ => fn f => fn _ => fn cond => fn tt => fn tf => f cond @ f tt @ f tf),
+ casec = (fn _ => fn _ => fn f => fn _ => fn cs => List.foldr (fn (l,r) => f l @ r) [] cs),
+ letc = (fn _ => fn _ => fn f => fn _ => fn exp => fn _ => fn _ => fn t => f exp @ f t)
+ }) (get_r T_t)
+ fun build _ [] _ = false
+ | build i (a::args) item =
+ (if item = (ident,i) then contains frees' a else build (i+1) args item)
+in
+ build 0 T_args
+end
+fun find_used ctxt term terms T_terms =
+ zip terms T_terms
+ |> List.map (fn (t, T_t) => find_used' ctxt term t T_t)
+ |> List.foldr (fn (f,g) => fn item => f item orelse g item) (K false)
+
+
(* 3. Convert equations *)
(* Some Helper *)
val plusTyp = @{typ "nat => nat => nat"}
@@ -254,65 +307,90 @@
| plus NONE NONE = NONE
fun opt_term NONE = HOLogic.zero
| opt_term (SOME t) = t
+fun use_origin (Free (nm, T as Type ("fun",_))) = HOLogic.mk_fst (Free (nm,HOLogic.mk_prodT (T, change_typ T)))
+ | use_origin t = t
(* Converting of function term *)
-fun fun_to_time ctxt (origin: term list) (func as Const (nm,T)) =
+fun fun_to_time ctxt orig_used _ (origin: term list) (func as Const (nm,T)) =
let
- val full_name_origin = map (fst o dest_Const) origin
val prefix = Config.get ctxt bprefix
+ val used' = used_for_const orig_used func
in
- if List.exists (fn nm_orig => nm = nm_orig) full_name_origin then SOME (Free (prefix ^ Term.term_name func, change_typ T)) else
+ if contains origin func then SOME (Free (prefix ^ Term.term_name func, change_typ' used' 0 T)) else
if Zero_Funcs.is_zero (Proof_Context.theory_of ctxt) (nm,T) then NONE else
time_term ctxt func
end
- | fun_to_time _ _ (Free (nm,T)) = SOME (HOLogic.mk_snd (Free (nm,HOLogic.mk_prodT (T,change_typ T))))
- | fun_to_time _ _ _ = error "Internal error: invalid function to convert"
+ | fun_to_time ctxt _ used _ (f as Free (nm,T)) = SOME (
+ if used f then HOLogic.mk_snd (Free (nm,HOLogic.mk_prodT (T,change_typ T)))
+ else Free (Config.get ctxt bprefix ^ nm, change_typ T)
+ )
+ | fun_to_time _ _ _ _ _ = error "Internal error: invalid function to convert"
(* Convert arguments of left side of a term *)
-fun conv_arg _ _ (Free (nm,T as Type("fun",_))) = Free (nm, HOLogic.mk_prodT (T, change_typ T))
- | conv_arg ctxt origin (f as Const (_,T as Type("fun",_))) =
- (Const (@{const_name "Product_Type.Pair"},
- Type ("fun", [T,Type ("fun",[change_typ T, HOLogic.mk_prodT (T,change_typ T)])]))
- $ f $ (fun_to_time ctxt origin f |> Option.valOf))
- | conv_arg ctxt origin ((Const ("Product_Type.Pair", _)) $ l $ r) = HOLogic.mk_prod (conv_arg ctxt origin l, conv_arg ctxt origin r)
- | conv_arg _ _ x = x
-fun conv_args ctxt origin = map (conv_arg ctxt origin)
+fun conv_arg ctxt used _ (f as Free (nm,T as Type("fun",_))) =
+ if used f then Free (nm, HOLogic.mk_prodT (T, change_typ' (K false) 0 T))
+ else Free (Config.get ctxt bprefix ^ nm, change_typ' (K false) 0 T)
+ | conv_arg ctxt _ origin (f as Const (_, Type("fun",_))) =
+ (error "weird case i don't understand TODO"; HOLogic.mk_prod (f, fun_to_time ctxt (K false) (K false) origin f |> Option.valOf))
+ | conv_arg _ _ _ x = x
+fun conv_args ctxt used origin = map (conv_arg ctxt used origin)
(* Handle function calls *)
fun build_zero (Type ("fun", [T, R])) = Abs ("x", T, build_zero R)
| build_zero _ = zero
-fun funcc_use_origin _ _ (Free (nm, T as Type ("fun",_))) = HOLogic.mk_fst (Free (nm,HOLogic.mk_prodT (T, change_typ T)))
- | funcc_use_origin _ _ t = t
-fun funcc_conv_arg ctxt origin (t as (_ $ _)) = map_aterms (funcc_use_origin ctxt origin) t
- | funcc_conv_arg _ _ (Free (nm, T as Type ("fun",_))) = (Free (nm, HOLogic.mk_prodT (T, change_typ T)))
- | funcc_conv_arg ctxt origin (f as Const (_,T as Type ("fun",_))) =
+fun funcc_use_origin _ _ used (f as Free (nm, T as Type ("fun",_))) =
+ if used f then HOLogic.mk_fst (Free (nm,HOLogic.mk_prodT (T, change_typ T)))
+ else error "Internal error: Error in used detection"
+ | funcc_use_origin _ _ _ t = t
+fun funcc_conv_arg ctxt origin used _ (t as (_ $ _)) = map_aterms (funcc_use_origin ctxt origin used) t
+ | funcc_conv_arg ctxt _ used u (f as Free (nm, T as Type ("fun",_))) =
+ if used f then
+ if u then Free (nm, HOLogic.mk_prodT (T, change_typ T))
+ else HOLogic.mk_snd (Free (nm,HOLogic.mk_prodT (T,change_typ T)))
+ else Free (Config.get ctxt bprefix ^ nm, change_typ T)
+ | funcc_conv_arg ctxt origin _ true (f as Const (_,T as Type ("fun",_))) =
(Const (@{const_name "Product_Type.Pair"},
- Type ("fun", [T,Type ("fun",[change_typ T, HOLogic.mk_prodT (T,change_typ T)])]))
- $ f $ (Option.getOpt (fun_to_time ctxt origin f, build_zero T)))
- | funcc_conv_arg _ _ t = t
-fun funcc_conv_args ctxt origin = map (funcc_conv_arg ctxt origin)
-fun funcc ctxt (origin: term list) f func args = List.foldr (I #-> plus)
- (case fun_to_time ctxt origin func of SOME t => SOME (build_func (t,funcc_conv_args ctxt origin args))
- | NONE => NONE)
+ Type ("fun", [T,Type ("fun", [change_typ T, HOLogic.mk_prodT (T,change_typ T)])]))
+ $ f $ (Option.getOpt (fun_to_time ctxt (K false) (K false) origin f, build_zero T)))
+ | funcc_conv_arg ctxt origin _ false (f as Const (_,T as Type ("fun",_))) =
+ Option.getOpt (fun_to_time ctxt (K false) (K false) origin f, build_zero T)
+ | funcc_conv_arg _ _ _ _ t = t
+
+fun funcc_conv_args _ _ _ _ [] = []
+ | funcc_conv_args ctxt origin used (Type ("fun", [t, ts])) (a::args) =
+ funcc_conv_arg ctxt origin used (is_Used t) a :: funcc_conv_args ctxt origin used ts args
+ | funcc_conv_args _ _ _ _ _ = error "Internal error: Non matching type"
+fun funcc orig_used used ctxt (origin: term list) f func args =
+let
+ fun get_T (Free (_,T)) = T
+ | get_T (Const (_,T)) = T
+ | get_T (_ $ (Free (_,Type (_, [_, T])))) = T (* Case of snd was constructed *)
+ | get_T _ = error "Internal error: Forgotten type"
+in
+ List.foldr (I #-> plus)
+ (case fun_to_time ctxt orig_used used origin func
+ of SOME t => SOME (build_func (t,funcc_conv_args ctxt origin used (get_T t) args))
+ | NONE => NONE)
(map f args)
+end
(* Handle case terms *)
fun casecIsCase (Type (n1, [_,Type (n2, _)])) = (n1 = "fun" andalso n2 = "fun")
| casecIsCase _ = false
fun casecLastTyp (Type (n, [T1,T2])) = Type (n, [T1, change_typ T2])
- | casecLastTyp _ = error "Internal error: invalid case type"
+ | casecLastTyp _ = error "Internal error: Invalid case type"
fun casecTyp (Type (n, [T1, T2])) =
Type (n, [change_typ T1, (if casecIsCase T2 then casecTyp else casecLastTyp) T2])
- | casecTyp _ = error "Internal error: invalid case type"
+ | casecTyp _ = error "Internal error: Invalid case type"
fun casecAbs f (Abs (v,Ta,t)) = (case casecAbs f t of (nconst,t) => (nconst,Abs (v,Ta,t)))
| casecAbs f t = (case f t of NONE => (false,HOLogic.zero) | SOME t => (true,t))
-fun casecArgs _ [t] = (false, [t])
+fun casecArgs _ [t] = (false, [map_aterms use_origin t])
| casecArgs f (t::ar) =
(case casecAbs f t of (nconst, tt) =>
casecArgs f ar ||> (fn ar => tt :: ar) |>> (if nconst then K true else I))
- | casecArgs _ _ = error "Internal error: invalid case term"
+ | casecArgs _ _ = error "Internal error: Invalid case term"
fun casec _ _ f (Const (t,T)) args =
- if not (casecIsCase T) then error "Internal error: invalid case type" else
+ if not (casecIsCase T) then error "Internal error: Invalid case type" else
let val (nconst, args') = casecArgs f args in
plus
(f (List.last args))
@@ -320,18 +398,17 @@
SOME (build_func (Const (t,casecTyp T), args'))
else NONE)
end
- | casec _ _ _ _ _ = error "Internal error: invalid case term"
+ | casec _ _ _ _ _ = error "Internal error: Invalid case term"
(* Handle if terms -> drop the term if true and false terms are zero *)
-fun ifc ctxt origin f _ cond tt ft =
+fun ifc _ _ f _ cond tt ft =
let
- fun use_origin _ _ (Free (nm, T as Type ("fun",_))) = HOLogic.mk_fst (Free (nm,HOLogic.mk_prodT (T, change_typ T)))
- | use_origin _ _ t = t
- val rcond = map_aterms (use_origin ctxt origin) cond
+ val rcond = map_aterms use_origin cond
val tt = f tt
val ft = f ft
in
plus (f cond) (case (tt,ft) of (NONE, NONE) => NONE | _ =>
+ if tt = ft then tt else
(SOME ((Const (If_name, @{typ "bool \<Rightarrow> nat \<Rightarrow> nat \<Rightarrow> nat"})) $ rcond $ (opt_term tt) $ (opt_term ft))))
end
@@ -342,21 +419,21 @@
(if length nms = 0 (* In case of "length nms = 0" the expression got reducted
Here we need Bound 0 to gain non-partial application *)
then (case f (t $ Bound 0) of SOME (t' $ Bound 0) =>
- (SOME (Const (Let_name, letc_change_typ expT) $ exp $ t'))
+ (SOME (Const (Let_name, letc_change_typ expT) $ (map_aterms use_origin exp) $ t'))
(* Expression is not used and can therefore let be dropped *)
| SOME t' => SOME t'
| NONE => NONE)
else (case f t of SOME t' =>
- SOME (if Term.is_dependent t' then Const (Let_name, letc_change_typ expT) $ exp $ build_abs t' nms Ts
+ SOME (if Term.is_dependent t' then Const (Let_name, letc_change_typ expT) $ (map_aterms use_origin exp) $ build_abs t' nms Ts
else Term.subst_bounds([exp],t'))
- | NONE => NONE))
+ | NONE => NONE))
(* The converter for timing functions given to the walker *)
-val converter : term option converter = {
+fun converter orig_used used : term option converter = {
constc = fn _ => fn _ => fn _ => fn t =>
(case t of Const ("HOL.undefined", _) => SOME (Const ("HOL.undefined", @{typ "nat"}))
| _ => NONE),
- funcc = funcc,
+ funcc = (funcc orig_used used),
ifc = ifc,
casec = casec,
letc = letc
@@ -364,16 +441,24 @@
fun top_converter is_rec _ _ = opt_term o (fn exp => plus exp (if is_rec then SOME one else NONE))
(* Use converter to convert right side of a term *)
-fun to_time ctxt origin is_rec term =
- top_converter is_rec ctxt origin (walk ctxt origin converter term)
+fun to_time ctxt origin is_rec orig_used used term =
+ top_converter is_rec ctxt origin (walk ctxt origin (converter orig_used used) term)
(* Converts a term to its running time version *)
-fun convert_term ctxt (origin: term list) is_rec (pT $ (Const (eqN, _) $ l $ r)) =
+fun convert_term ctxt (origin: term list) is_rec orig_used (pT $ (Const (eqN, _) $ l $ r)) =
+let
+ val (l' as (l_const, l_params)) = walk_func l []
+ val used =
+ l_const
+ |> used_for_const orig_used
+ |> (fn f => fn n => f (index l_params n))
+in
pT
$ (Const (eqN, @{typ "nat \<Rightarrow> nat \<Rightarrow> bool"})
- $ (build_func ((walk_func l []) |>> (fun_to_time ctxt origin) |>> Option.valOf ||> conv_args ctxt origin))
- $ (to_time ctxt origin is_rec r))
- | convert_term _ _ _ _ = error "Internal error: invalid term to convert"
+ $ (build_func (l' |>> (fun_to_time ctxt orig_used used origin) |>> Option.valOf ||> conv_args ctxt used origin))
+ $ (to_time ctxt origin is_rec orig_used used r))
+end
+ | convert_term _ _ _ _ _ = error "Internal error: invalid term to convert"
(* 4. Tactic to prove "f_dom n" *)
fun time_dom_tac ctxt induct_rule domintros =
@@ -410,7 +495,8 @@
#> fixPartTerms lthy term)
terms
- (* 2. Check if function is recursive *)
+ (* 2. Find properties about the function *)
+ (* 2.1 Check if function is recursive *)
val is_rec = is_rec lthy term terms
(* 3. Convert every equation
@@ -418,7 +504,16 @@
- On left side change name of function to timing function
- Convert right side of equation with conversion schema
*)
- val timing_terms = map (convert_term lthy term is_rec) terms
+ fun convert used = map (convert_term lthy term is_rec used)
+ fun repeat T_terms =
+ let
+ val orig_used = find_used lthy term terms T_terms
+ val T_terms' = convert orig_used terms
+ in
+ if T_terms' <> T_terms then repeat T_terms' else T_terms'
+ end
+ val T_terms = repeat (convert (K true) terms)
+ val orig_used = find_used lthy term terms T_terms
(* 4. Register function and prove termination *)
val names = map Term.term_name term
@@ -426,7 +521,7 @@
val bindings = map (fn nm => (Binding.name nm, NONE, NoSyn)) timing_names
fun pat_completeness_auto ctxt =
Pat_Completeness.pat_completeness_tac ctxt 1 THEN auto_tac ctxt
- val specs = map (fn eq => (((Binding.empty, []), eq), [], [])) timing_terms
+ val specs = map (fn eq => (((Binding.empty, []), eq), [], [])) T_terms
(* For partial functions sequential=true is needed in order to support them
We need sequential=false to support the automatic proof of termination over dom
@@ -444,14 +539,15 @@
val print_ctxt = lthy
|> Config.put show_question_marks false
|> Config.put show_sorts false (* Change it for debugging *)
- val print_ctxt = List.foldl (fn (term, ctxt) => Variable.add_fixes_implicit term ctxt) print_ctxt (term@timing_terms)
+ val print_ctxt = List.foldl (fn (term, ctxt) => Variable.add_fixes_implicit term ctxt) print_ctxt (term @ T_terms)
(* Print result if print *)
val _ = if not print then () else
let
val nms = map (fst o dest_Const) term
+ val used = map (used_for_const orig_used) term
val typs = map (snd o dest_Const) term
in
- print_timing' print_ctxt { names=nms, terms=terms, typs=typs } { names=timing_names, terms=timing_terms, typs=map change_typ typs }
+ print_timing' print_ctxt { names=nms, terms=terms, typs=typs } { names=timing_names, terms=T_terms, typs=map (fn (used, typ) => change_typ' used 0 typ) (zip used typs) }
end
(* Register function *)