--- a/src/HOL/Data_Structures/Define_Time_Function.ML Tue Jan 16 13:40:19 2024 +0000
+++ b/src/HOL/Data_Structures/Define_Time_Function.ML Thu Jan 18 14:30:27 2024 +0100
@@ -2,25 +2,25 @@
signature TIMING_FUNCTIONS =
sig
type 'a converter = {
- constc : local_theory -> term -> (term -> 'a) -> term -> 'a,
- funcc : local_theory -> term -> (term -> 'a) -> term -> term list -> 'a,
- ifc : local_theory -> term -> (term -> 'a) -> typ -> term -> term -> term -> 'a,
- casec : local_theory -> term -> (term -> 'a) -> term -> term list -> 'a,
- letc : local_theory -> term -> (term -> 'a) -> typ -> term -> string list -> typ list -> term -> 'a
+ constc : local_theory -> term list -> (term -> 'a) -> term -> 'a,
+ funcc : local_theory -> term list -> (term -> 'a) -> term -> term list -> 'a,
+ ifc : local_theory -> term list -> (term -> 'a) -> typ -> term -> term -> term -> 'a,
+ casec : local_theory -> term list -> (term -> 'a) -> term -> term list -> 'a,
+ letc : local_theory -> term list -> (term -> 'a) -> typ -> term -> string list -> typ list -> term -> 'a
};
-val walk : local_theory -> term -> 'a converter -> term -> 'a
+val walk : local_theory -> term list -> 'a converter -> term -> 'a
-type pfunc = { name : string, terms : term list, typ : typ }
+type pfunc = { names : string list, terms : term list, typs : typ list }
val fun_pretty': Proof.context -> pfunc -> Pretty.T
val fun_pretty: Proof.context -> Function.info -> Pretty.T
val print_timing': Proof.context -> pfunc -> pfunc -> unit
val print_timing: Proof.context -> Function.info -> Function.info -> unit
-val reg_and_proove_time_func: theory -> term -> term list -> term option converter
- -> (bool -> local_theory -> term -> term option -> term)
+val reg_and_proove_time_func: theory -> term list -> term list -> term option converter
+ -> (bool -> local_theory -> term list -> term option -> term)
-> bool -> Function.info * theory
-val reg_time_func: theory -> term -> term list -> term option converter
- -> (bool -> local_theory -> term -> term option -> term) -> bool -> theory
+val reg_time_func: theory -> term list -> term list -> term option converter
+ -> (bool -> local_theory -> term list -> term option -> term) -> bool -> theory
val time_dom_tac: Proof.context -> thm -> thm list -> int -> tactic
@@ -43,25 +43,29 @@
end;
type pfunc = {
- name : string,
+ names : string list,
terms : term list,
- typ : typ
+ typs : typ list
}
fun info_pfunc (info: Function.info): pfunc =
let
val {defname, fs, ...} = info;
val T = case hd fs of (Const (_,T)) => T | _ => error "Internal error: Invalid info to print"
in
- { name=Binding.name_of defname, terms=terms_of_info info, typ=T }
+ { names=[Binding.name_of defname], terms=terms_of_info info, typs=[T] }
end
(* Auxiliary functions for printing functions *)
fun fun_pretty' ctxt (pfunc: pfunc) =
let
- val {name, terms, typ} = pfunc;
- val header_beg = Pretty.str ("fun " ^ name ^ " :: ");
- val header_end = Pretty.str (" where\n ");
- val header = [header_beg, Pretty.quote (Syntax.pretty_typ ctxt typ), header_end];
+ val {names, terms, typs} = pfunc;
+ val header_beg = Pretty.str "fun ";
+ fun prepHeadCont (nm,T) = [Pretty.str (nm ^ " :: "), (Pretty.quote (Syntax.pretty_typ ctxt T))]
+ val header_content =
+ List.concat (prepHeadCont (hd names,hd typs) :: map ((fn l => Pretty.str "\nand " :: l) o prepHeadCont) (ListPair.zip (tl names, tl typs)));
+ val header_end = Pretty.str " where\n ";
+ val _ = List.map
+ val header = [header_beg] @ header_content @ [header_end];
fun separate sep prts =
flat (Library.separate [Pretty.str sep] (map single prts));
val ptrms = (separate "\n| " (map (Syntax.pretty_term ctxt) terms));
@@ -71,11 +75,11 @@
fun fun_pretty ctxt = fun_pretty' ctxt o info_pfunc
fun print_timing' ctxt (opfunc: pfunc) (tpfunc: pfunc) =
let
- val {name, ...} = opfunc;
+ val {names, ...} = opfunc;
val poriginal = Pretty.item [Pretty.str "Original function:\n", fun_pretty' ctxt opfunc]
val ptiming = Pretty.item [Pretty.str ("Running time function:\n"), fun_pretty' ctxt tpfunc]
in
- Pretty.writeln (Pretty.text_fold [Pretty.str ("Converting " ^ name ^ "...\n"), poriginal, Pretty.str "\n", ptiming])
+ Pretty.writeln (Pretty.text_fold [Pretty.str ("Converting " ^ (hd names) ^ (String.concat (map (fn nm => ", " ^ nm) (tl names))) ^ "\n"), poriginal, Pretty.str "\n", ptiming])
end
fun print_timing ctxt (oinfo: Function.info) (tinfo: Function.info) =
print_timing' ctxt (info_pfunc oinfo) (info_pfunc tinfo)
@@ -149,11 +153,11 @@
| time_term _ _ = error "Internal error: No valid function given"
type 'a converter = {
- constc : local_theory -> term -> (term -> 'a) -> term -> 'a,
- funcc : local_theory -> term -> (term -> 'a) -> term -> term list -> 'a,
- ifc : local_theory -> term -> (term -> 'a) -> typ -> term -> term -> term -> 'a,
- casec : local_theory -> term -> (term -> 'a) -> term -> term list -> 'a,
- letc : local_theory -> term -> (term -> 'a) -> typ -> term -> string list -> typ list -> term -> 'a
+ constc : local_theory -> term list -> (term -> 'a) -> term -> 'a,
+ funcc : local_theory -> term list -> (term -> 'a) -> term -> term list -> 'a,
+ ifc : local_theory -> term list -> (term -> 'a) -> typ -> term -> term -> term -> 'a,
+ casec : local_theory -> term list -> (term -> 'a) -> term -> term list -> 'a,
+ letc : local_theory -> term list -> (term -> 'a) -> typ -> term -> string list -> typ list -> term -> 'a
};
(* Walks over term and calls given converter *)
@@ -166,7 +170,7 @@
fun build_abs t (nm::nms) (T::Ts) = build_abs (Abs (nm,T,t)) nms Ts
| build_abs t [] [] = t
| build_abs _ _ _ = error "Internal error: Invalid terms to build abs"
-fun walk ctxt origin (conv as {ifc, casec, funcc, letc, ...} : 'a converter) (t as _ $ _) =
+fun walk ctxt (origin: term list) (conv as {ifc, casec, funcc, letc, ...} : 'a converter) (t as _ $ _) =
let
val (f, args) = walk_func t []
val this = (walk ctxt origin conv)
@@ -210,7 +214,7 @@
(check_args "cases" (Syntax.read_term ctxt n,args); build_func (t,fixCasecCases ctxt f T args))
| fixCasec _ _ _ _ _ = error "Internal error: invalid case term"
-fun fixPartTerms ctxt term t =
+fun fixPartTerms ctxt (term: term list) t =
let
val _ = check_args "args" (walk_func (get_l t) [])
in
@@ -234,13 +238,13 @@
(* 2. Check if function is recursive *)
fun or f (a,b) = f a orelse b
fun find_rec ctxt term = (walk ctxt term {
- funcc = (fn _ => fn _ => fn f => fn t => fn args => (Const_name t = Const_name term) orelse List.foldr (or f) false args),
+ funcc = (fn _ => fn _ => fn f => fn t => fn args => List.exists (fn term => Const_name t = Const_name term) term orelse List.foldr (or f) false args),
constc = (K o K o K o K) false,
ifc = (fn _ => fn _ => fn f => fn _ => fn cond => fn tt => fn tf => f cond orelse f tt orelse f tf),
casec = (fn _ => fn _ => fn f => fn t => fn cs => f t orelse List.foldr (or (rem_abs f)) false cs),
letc = (fn _ => fn _ => fn f => fn _ => fn exp => fn _ => fn _ => fn t => f exp orelse f t)
}) o get_r
-fun is_rec ctxt term = List.foldr (or (find_rec ctxt term)) false
+fun is_rec ctxt (term: term list) = List.foldr (or (find_rec ctxt term)) false
(* 3. Convert equations *)
(* Some Helper *)
@@ -253,13 +257,12 @@
| opt_term (SOME t) = t
(* Converting of function term *)
-fun fun_to_time ctxt origin (func as Const (nm,T)) =
+fun fun_to_time ctxt (origin: term list) (func as Const (nm,T)) =
let
+ val full_name_origin = map (fst o dest_Const) origin
val prefix = Config.get ctxt bprefix
- val timing_name_origin = prefix ^ Term.term_name origin
- val full_name_origin = origin |> dest_Const |> fst
in
- if nm = full_name_origin then SOME (Free (timing_name_origin, change_typ T)) else
+ if List.exists (fn nm_orig => nm = nm_orig) full_name_origin then SOME (Free (prefix ^ Term.term_name func, change_typ T)) else
if Zero_Funcs.is_zero (Proof_Context.theory_of ctxt) (nm,T) then NONE else
time_term ctxt func
end
@@ -289,7 +292,7 @@
$ f $ (Option.getOpt (fun_to_time ctxt origin f, build_zero T)))
| funcc_conv_arg _ _ t = t
fun funcc_conv_args ctxt origin = map (funcc_conv_arg ctxt origin)
-fun funcc ctxt origin f func args = List.foldr (I #-> plus)
+fun funcc ctxt (origin: term list) f func args = List.foldr (I #-> plus)
(case fun_to_time ctxt origin func of SOME t => SOME (build_func (t,funcc_conv_args ctxt origin args))
| NONE => NONE)
(map f args)
@@ -333,14 +336,17 @@
(SOME ((Const (If_name, @{typ "bool \<Rightarrow> nat \<Rightarrow> nat \<Rightarrow> nat"})) $ rcond $ (opt_term tt) $ (opt_term ft)))) (f cond)
end
-fun exp_type (Type (_, [T1, _])) = T1
- | exp_type _ = error "Internal errror: no valid let type"
+fun letc_change_typ (Type ("fun", [T1, Type ("fun", [T2, _])])) = (Type ("fun", [T1, Type ("fun", [change_typ T2, HOLogic.natT])]))
+ | letc_change_typ _ = error "Internal error: invalid let type"
fun letc _ _ f expT exp nms Ts t =
- if length nms = 0
- then plus (f exp) (case f (t $ Bound 0) of SOME t' => SOME (Abs ("x", exp_type expT, t') $ exp)
- | NONE => NONE)
- else plus (f exp) (case f t of SOME t' => SOME (build_abs t' nms Ts $ exp)
- | NONE => NONE)
+ plus (f exp)
+ (if length nms = 0 (* In case of "length nms = 0" a case expression is used to split up a type *)
+ (* Add (Bound 0) to receive a fully evaluated function, which can be handled by casec
+ Strip of (Bound 0) after conversion *)
+ then (case f (t $ Bound 0) of SOME (t' $ Bound 0) => SOME (Const (Let_name, letc_change_typ expT) $ exp $ t')
+ | _ => NONE)
+ else (case f t of SOME t' => SOME (Const (Let_name, letc_change_typ expT) $ exp $ build_abs t' nms Ts)
+ | NONE => NONE))
(* The converter for timing functions given to the walker *)
val converter : term option converter = {
@@ -357,7 +363,7 @@
top_converter ctxt origin (walk ctxt origin converter term)
(* Converts a term to its running time version *)
-fun convert_term ctxt origin conv topConv (pT $ (Const (eqN, _) $ l $ r)) =
+fun convert_term ctxt (origin: term list) conv topConv (pT $ (Const (eqN, _) $ l $ r)) =
pT
$ (Const (eqN, @{typ "nat \<Rightarrow> nat \<Rightarrow> bool"})
$ (build_func ((walk_func l []) |>> (fun_to_time ctxt origin) |>> Option.valOf ||> conv_args ctxt origin))
@@ -379,21 +385,17 @@
handle Empty => error "Function or terms of function not found"
(* Register timing function of a given function *)
-fun reg_and_proove_time_func (theory: theory) (term: term) (terms: term list) conv topConv print =
+fun reg_and_proove_time_func (theory: theory) (term: term list) (terms: term list) conv topConv print =
reg_time_func theory term terms conv topConv false
|> proove_termination term terms print
-and reg_time_func (theory: theory) (term: term) (terms: term list) conv topConv print =
+and reg_time_func (theory: theory) (term: term list) (terms: term list) conv topConv print =
let
val lthy = Named_Target.theory_init theory
val _ =
- case time_term lthy term
+ case time_term lthy (hd term)
handle (ERROR _) => NONE
- of SOME _ => error ("Timing function already declared: " ^ (Term.term_name term))
+ of SOME _ => error ("Timing function already declared: " ^ (Term.term_name (hd term)))
| NONE => ()
-
-
- val info = SOME (Function.get_info lthy term) handle Empty => NONE
- val is_partial = case info of SOME {is_partial, ...} => is_partial | _ => false
(* 1. Fix all terms *)
(* Exchange Var in types and terms to Free and check constraints *)
@@ -410,9 +412,9 @@
val timing_terms = map (convert_term lthy term conv (topConv is_rec)) terms
(* 4. Register function and prove termination *)
- val name = Term.term_name term
- val timing_name = fun_name_to_time lthy name
- val bindings = [(Binding.name timing_name, NONE, NoSyn)]
+ val names = map Term.term_name term
+ val timing_names = map (fun_name_to_time lthy) names
+ val bindings = map (fn nm => (Binding.name nm, NONE, NoSyn)) timing_names
fun pat_completeness_auto ctxt =
Pat_Completeness.pat_completeness_tac ctxt 1 THEN auto_tac ctxt
val specs = map (fn eq => (((Binding.empty, []), eq), [], [])) timing_terms
@@ -424,7 +426,7 @@
let
val _ = (if seq then warning "Falling back on sequential function..." else ())
val fun_config = Function_Common.FunctionConfig
- {sequential=seq, default=NONE, domintros=true, partials=is_partial}
+ {sequential=seq, default=NONE, domintros=true, partials=false}
in
Function.add_function bindings specs fun_config pat_completeness_auto lthy
end
@@ -436,9 +438,10 @@
(* Print result if print *)
val _ = if not print then () else
let
- val (nm,T) = case term of Const t => t | _ => error "Internal error: invalid term to print"
+ val nms = map (fst o dest_Const) term
+ val typs = map (snd o dest_Const) term
in
- print_timing' print_ctxt { name=nm, terms=terms, typ=T } { name=timing_name, terms=timing_terms, typ=change_typ T }
+ print_timing' print_ctxt { names=nms, terms=terms, typs=typs } { names=timing_names, terms=timing_terms, typs=map change_typ typs }
end
(* Register function *)
@@ -451,15 +454,13 @@
in
Local_Theory.exit_global lthy
end
-and proove_termination term terms print (theory: theory) =
+and proove_termination (term: term list) terms print (theory: theory) =
let
val lthy = Named_Target.theory_init theory
(* Start proving the termination *)
- val info = SOME (Function.get_info lthy term) handle Empty => NONE
- val timing_name = term
- |> Term.term_name
- |> fun_name_to_time lthy
+ val infos = SOME (map (Function.get_info lthy) term) handle Empty => NONE
+ val timing_names = map (fun_name_to_time lthy o Term.term_name) term
(* Proof by lexicographic_order_tac *)
val (time_info, lthy') =
@@ -468,6 +469,8 @@
handle (ERROR _) =>
let
val _ = warning "Falling back on proof over dom..."
+ val _ = (if length term > 1 then error "Proof over dom not supported for mutual recursive functions" else ())
+
fun args (a$(Var ((nm,_),T))) = args a |> (fn ar => (nm,T)::ar)
| args (a$(Const (_,T))) = args a |> (fn ar => ("x",T)::ar)
| args _ = []
@@ -476,11 +479,11 @@
|> Variable.variant_frees lthy []
|> map fst
- val {inducts, ...} = case info of SOME i => i | _ => error "Proof over dom failed as no induct rule was found"
+ val {inducts, ...} = case infos of SOME [i] => i | _ => error "Proof over dom failed as no induct rule was found"
val induct = (Option.valOf inducts |> hd)
- val domintros = Proof_Context.get_fact lthy (Facts.named (timing_name ^ ".domintros"))
- val prop = (timing_name ^ "_dom (" ^ (String.concatWith "," dom_args) ^ ")")
+ val domintros = Proof_Context.get_fact lthy (Facts.named (hd timing_names ^ ".domintros"))
+ val prop = (hd timing_names ^ "_dom (" ^ (String.concatWith "," dom_args) ^ ")")
|> Syntax.read_prop lthy
(* Prove a helper lemma *)
@@ -501,39 +504,40 @@
(* Print result if print *)
val _ = if not print then () else
let
- val (nm,T) = case term of Const t => t | _ => error "Internal error: invalid term to print"
+ val nms = map (fst o dest_Const) term
+ val typs = map (snd o dest_Const) term
in
- print_timing' print_ctxt { name=nm, terms=terms, typ=T } (info_pfunc time_info)
+ print_timing' print_ctxt { names=nms, terms=terms, typs=typs } (info_pfunc time_info)
end
in
(time_info, Local_Theory.exit_global lthy')
end
(* Convert function into its timing function (called by command) *)
-fun reg_time_fun_cmd (func, thms) conv topConv (theory: theory) =
+fun reg_time_fun_cmd (funcs, thms) conv topConv (theory: theory) =
let
val ctxt = Proof_Context.init_global theory
- val fterm = Syntax.read_term ctxt func
- val (_, lthy') = reg_and_proove_time_func theory fterm
- (case thms of NONE => get_terms theory fterm
+ val fterms = map (Syntax.read_term ctxt) funcs
+ val (_, lthy') = reg_and_proove_time_func theory fterms
+ (case thms of NONE => get_terms theory (hd fterms)
| SOME thms => thms |> Attrib.eval_thms ctxt |> List.map Thm.prop_of)
conv topConv true
in lthy'
end
(* Convert function into its timing function (called by command) with termination proof provided by user*)
-fun reg_time_function_cmd (func, thms) conv topConv (theory: theory) =
+fun reg_time_function_cmd (funcs, thms) conv topConv (theory: theory) =
let
val ctxt = Proof_Context.init_global theory
- val fterm = Syntax.read_term ctxt func
- val theory = reg_time_func theory fterm
- (case thms of NONE => get_terms theory fterm
+ val fterms = map (Syntax.read_term ctxt) funcs
+ val theory = reg_time_func theory fterms
+ (case thms of NONE => get_terms theory (hd fterms)
| SOME thms => thms |> Attrib.eval_thms ctxt |> List.map Thm.prop_of)
conv topConv true
in theory
end
-val parser = Parse.prop -- (Scan.option (Parse.keyword_improper "equations" -- Parse.thms1 >> snd))
+val parser = (Scan.repeat1 Parse.prop) -- (Scan.option (Parse.keyword_improper "equations" -- Parse.thms1 >> snd))
val _ = Outer_Syntax.command @{command_keyword "define_time_fun"}
"Defines runtime function of a function"