translation to time functions now with canonical let.
authornipkow
Thu, 18 Jan 2024 14:30:27 +0100
changeset 79494 c7536609bb9b
parent 79493 d1188818634d
child 79495 8a2511062609
translation to time functions now with canonical let.
src/HOL/Data_Structures/Define_Time_Function.ML
--- a/src/HOL/Data_Structures/Define_Time_Function.ML	Tue Jan 16 13:40:19 2024 +0000
+++ b/src/HOL/Data_Structures/Define_Time_Function.ML	Thu Jan 18 14:30:27 2024 +0100
@@ -2,25 +2,25 @@
 signature TIMING_FUNCTIONS =
 sig
 type 'a converter = {
-  constc : local_theory -> term -> (term -> 'a) -> term -> 'a,
-  funcc : local_theory -> term -> (term -> 'a) -> term -> term list -> 'a,
-  ifc : local_theory -> term -> (term -> 'a) -> typ -> term -> term -> term -> 'a,
-  casec : local_theory -> term -> (term -> 'a) -> term -> term list -> 'a,
-  letc : local_theory -> term -> (term -> 'a) -> typ -> term -> string list -> typ list -> term -> 'a
+  constc : local_theory -> term list -> (term -> 'a) -> term -> 'a,
+  funcc : local_theory -> term list -> (term -> 'a) -> term -> term list -> 'a,
+  ifc : local_theory -> term list -> (term -> 'a) -> typ -> term -> term -> term -> 'a,
+  casec : local_theory -> term list -> (term -> 'a) -> term -> term list -> 'a,
+  letc : local_theory -> term list -> (term -> 'a) -> typ -> term -> string list -> typ list -> term -> 'a
 };
-val walk : local_theory -> term -> 'a converter -> term -> 'a
+val walk : local_theory -> term list -> 'a converter -> term -> 'a
 
-type pfunc = { name : string, terms : term list, typ : typ }
+type pfunc = { names : string list, terms : term list, typs : typ list }
 val fun_pretty':  Proof.context -> pfunc -> Pretty.T
 val fun_pretty:  Proof.context -> Function.info -> Pretty.T
 val print_timing':  Proof.context -> pfunc -> pfunc -> unit
 val print_timing:  Proof.context -> Function.info -> Function.info -> unit
 
-val reg_and_proove_time_func: theory -> term -> term list -> term option converter
-      -> (bool -> local_theory -> term -> term option -> term)
+val reg_and_proove_time_func: theory -> term list -> term list -> term option converter
+      -> (bool -> local_theory -> term list -> term option -> term)
       -> bool -> Function.info * theory
-val reg_time_func: theory -> term -> term list -> term option converter
-      -> (bool -> local_theory -> term -> term option -> term) -> bool -> theory
+val reg_time_func: theory -> term list -> term list -> term option converter
+      -> (bool -> local_theory -> term list -> term option -> term) -> bool -> theory
 
 val time_dom_tac: Proof.context -> thm -> thm list -> int -> tactic
 
@@ -43,25 +43,29 @@
 end;
 
 type pfunc = {
-  name : string,
+  names : string list,
   terms : term list,
-  typ : typ
+  typs : typ list
 }
 fun info_pfunc (info: Function.info): pfunc =
 let
   val {defname, fs, ...} = info;
   val T = case hd fs of (Const (_,T)) => T | _ => error "Internal error: Invalid info to print"
 in
-  { name=Binding.name_of defname, terms=terms_of_info info, typ=T }
+  { names=[Binding.name_of defname], terms=terms_of_info info, typs=[T] }
 end
 
 (* Auxiliary functions for printing functions *)
 fun fun_pretty' ctxt (pfunc: pfunc) =
 let
-  val {name, terms, typ} = pfunc;
-  val header_beg = Pretty.str ("fun " ^ name ^ " :: ");
-  val header_end = Pretty.str (" where\n  ");
-  val header = [header_beg, Pretty.quote (Syntax.pretty_typ ctxt typ), header_end];
+  val {names, terms, typs} = pfunc;
+  val header_beg = Pretty.str "fun ";
+  fun prepHeadCont (nm,T) = [Pretty.str (nm ^ " :: "), (Pretty.quote (Syntax.pretty_typ ctxt T))]
+  val header_content =
+     List.concat (prepHeadCont (hd names,hd typs) :: map ((fn l => Pretty.str "\nand " :: l) o prepHeadCont) (ListPair.zip (tl names, tl typs)));
+  val header_end = Pretty.str " where\n  ";
+  val _ = List.map
+  val header = [header_beg] @ header_content @ [header_end];
   fun separate sep prts =
     flat (Library.separate [Pretty.str sep] (map single prts));
   val ptrms = (separate "\n| " (map (Syntax.pretty_term ctxt) terms));
@@ -71,11 +75,11 @@
 fun fun_pretty ctxt = fun_pretty' ctxt o info_pfunc
 fun print_timing' ctxt (opfunc: pfunc) (tpfunc: pfunc) =
 let
-  val {name, ...} = opfunc;
+  val {names, ...} = opfunc;
   val poriginal = Pretty.item [Pretty.str "Original function:\n", fun_pretty' ctxt opfunc]
   val ptiming = Pretty.item [Pretty.str ("Running time function:\n"), fun_pretty' ctxt tpfunc]
 in
-  Pretty.writeln (Pretty.text_fold [Pretty.str ("Converting " ^ name ^ "...\n"), poriginal, Pretty.str "\n", ptiming])
+  Pretty.writeln (Pretty.text_fold [Pretty.str ("Converting " ^ (hd names) ^ (String.concat (map (fn nm => ", " ^ nm) (tl names))) ^ "\n"), poriginal, Pretty.str "\n", ptiming])
 end
 fun print_timing ctxt (oinfo: Function.info) (tinfo: Function.info) =
   print_timing' ctxt (info_pfunc oinfo) (info_pfunc tinfo)
@@ -149,11 +153,11 @@
   | time_term _ _ = error "Internal error: No valid function given"
 
 type 'a converter = {
-  constc : local_theory -> term -> (term -> 'a) -> term -> 'a,
-  funcc : local_theory -> term -> (term -> 'a) -> term -> term list -> 'a,
-  ifc : local_theory -> term -> (term -> 'a) -> typ -> term -> term -> term -> 'a,
-  casec : local_theory -> term -> (term -> 'a) -> term -> term list -> 'a,
-  letc : local_theory -> term -> (term -> 'a) -> typ -> term -> string list -> typ list -> term -> 'a
+  constc : local_theory -> term list -> (term -> 'a) -> term -> 'a,
+  funcc : local_theory -> term list -> (term -> 'a) -> term -> term list -> 'a,
+  ifc : local_theory -> term list -> (term -> 'a) -> typ -> term -> term -> term -> 'a,
+  casec : local_theory -> term list -> (term -> 'a) -> term -> term list -> 'a,
+  letc : local_theory -> term list -> (term -> 'a) -> typ -> term -> string list -> typ list -> term -> 'a
 };
 
 (* Walks over term and calls given converter *)
@@ -166,7 +170,7 @@
 fun build_abs t (nm::nms) (T::Ts) = build_abs (Abs (nm,T,t)) nms Ts
   | build_abs t [] [] = t
   | build_abs _ _ _ = error "Internal error: Invalid terms to build abs"
-fun walk ctxt origin (conv as {ifc, casec, funcc, letc, ...} : 'a converter) (t as _ $ _) =
+fun walk ctxt (origin: term list) (conv as {ifc, casec, funcc, letc, ...} : 'a converter) (t as _ $ _) =
   let
     val (f, args) = walk_func t []
     val this = (walk ctxt origin conv)
@@ -210,7 +214,7 @@
       (check_args "cases" (Syntax.read_term ctxt n,args); build_func (t,fixCasecCases ctxt f T args))
   | fixCasec _ _ _ _ _ = error "Internal error: invalid case term"
 
-fun fixPartTerms ctxt term t =
+fun fixPartTerms ctxt (term: term list) t =
   let
     val _ = check_args "args" (walk_func (get_l t) [])
   in
@@ -234,13 +238,13 @@
 (* 2. Check if function is recursive *)
 fun or f (a,b) = f a orelse b
 fun find_rec ctxt term = (walk ctxt term {
-          funcc = (fn _ => fn _ => fn f => fn t => fn args => (Const_name t = Const_name term) orelse List.foldr (or f) false args),
+          funcc = (fn _ => fn _ => fn f => fn t => fn args => List.exists (fn term => Const_name t = Const_name term) term orelse List.foldr (or f) false args),
           constc = (K o K o K o K) false,
           ifc = (fn _ => fn _ => fn f => fn _ => fn cond => fn tt => fn tf => f cond orelse f tt orelse f tf),
           casec = (fn _ => fn _ => fn f => fn t => fn cs => f t orelse List.foldr (or (rem_abs f)) false cs),
           letc = (fn _ => fn _ => fn f => fn _ => fn exp => fn _ => fn _ => fn t => f exp orelse f t)
       }) o get_r
-fun is_rec ctxt term = List.foldr (or (find_rec ctxt term)) false
+fun is_rec ctxt (term: term list) = List.foldr (or (find_rec ctxt term)) false
 
 (* 3. Convert equations *)
 (* Some Helper *)
@@ -253,13 +257,12 @@
   | opt_term (SOME t) = t
 
 (* Converting of function term *)
-fun fun_to_time ctxt origin (func as Const (nm,T)) =
+fun fun_to_time ctxt (origin: term list) (func as Const (nm,T)) =
 let
+  val full_name_origin = map (fst o dest_Const) origin
   val prefix = Config.get ctxt bprefix
-  val timing_name_origin = prefix ^ Term.term_name origin
-  val full_name_origin = origin |> dest_Const |> fst
 in
-  if nm = full_name_origin then SOME (Free (timing_name_origin, change_typ T)) else
+  if List.exists (fn nm_orig => nm = nm_orig) full_name_origin then SOME (Free (prefix ^ Term.term_name func, change_typ T)) else
   if Zero_Funcs.is_zero (Proof_Context.theory_of ctxt) (nm,T) then NONE else
     time_term ctxt func
 end
@@ -289,7 +292,7 @@
     $ f $ (Option.getOpt (fun_to_time ctxt origin f, build_zero T)))
   | funcc_conv_arg _ _ t = t
 fun funcc_conv_args ctxt origin = map (funcc_conv_arg ctxt origin)
-fun funcc ctxt origin f func args = List.foldr (I #-> plus)
+fun funcc ctxt (origin: term list) f func args = List.foldr (I #-> plus)
   (case fun_to_time ctxt origin func of SOME t => SOME (build_func (t,funcc_conv_args ctxt origin args))
                                       | NONE => NONE)
   (map f args)
@@ -333,14 +336,17 @@
        (SOME ((Const (If_name, @{typ "bool \<Rightarrow> nat \<Rightarrow> nat \<Rightarrow> nat"})) $ rcond $ (opt_term tt) $ (opt_term ft)))) (f cond)
   end
 
-fun exp_type (Type (_, [T1, _])) = T1
-  | exp_type _ = error "Internal errror: no valid let type"
+fun letc_change_typ (Type ("fun", [T1, Type ("fun", [T2, _])])) = (Type ("fun", [T1, Type ("fun", [change_typ T2, HOLogic.natT])]))
+  | letc_change_typ _ = error "Internal error: invalid let type"
 fun letc _ _ f expT exp nms Ts t =
-    if length nms = 0
-    then plus (f exp) (case f (t $ Bound 0) of SOME t' => SOME (Abs ("x", exp_type expT, t') $ exp)
-                                             | NONE => NONE)
-    else plus (f exp) (case f t of SOME t' => SOME (build_abs t' nms Ts $ exp)
-                                 | NONE => NONE)
+    plus (f exp)
+    (if length nms = 0 (* In case of "length nms = 0" a case expression is used to split up a type *)
+    (* Add (Bound 0) to receive a fully evaluated function, which can be handled by casec
+       Strip of (Bound 0) after conversion *)
+    then (case f (t $ Bound 0) of SOME (t' $ Bound 0) => SOME (Const (Let_name, letc_change_typ expT) $ exp $ t')
+                                | _ => NONE)
+    else (case f t of SOME t' => SOME (Const (Let_name, letc_change_typ expT) $ exp $ build_abs t' nms Ts)
+                               | NONE => NONE))
 
 (* The converter for timing functions given to the walker *)
 val converter : term option converter = {
@@ -357,7 +363,7 @@
   top_converter ctxt origin (walk ctxt origin converter term)
 
 (* Converts a term to its running time version *)
-fun convert_term ctxt origin conv topConv (pT $ (Const (eqN, _) $ l $ r)) =
+fun convert_term ctxt (origin: term list) conv topConv (pT $ (Const (eqN, _) $ l $ r)) =
       pT
       $ (Const (eqN, @{typ "nat \<Rightarrow> nat \<Rightarrow> bool"})
         $ (build_func ((walk_func l []) |>> (fun_to_time ctxt origin) |>> Option.valOf ||> conv_args ctxt origin))
@@ -379,21 +385,17 @@
    handle Empty => error "Function or terms of function not found"
 
 (* Register timing function of a given function *)
-fun reg_and_proove_time_func (theory: theory) (term: term) (terms: term list) conv topConv print =
+fun reg_and_proove_time_func (theory: theory) (term: term list) (terms: term list) conv topConv print =
   reg_time_func theory term terms conv topConv false
   |> proove_termination term terms print
-and reg_time_func (theory: theory) (term: term) (terms: term list) conv topConv print =
+and reg_time_func (theory: theory) (term: term list) (terms: term list) conv topConv print =
   let
     val lthy = Named_Target.theory_init theory
     val _ =
-      case time_term lthy term
+      case time_term lthy (hd term)
             handle (ERROR _) => NONE
-        of SOME _ => error ("Timing function already declared: " ^ (Term.term_name term))
+        of SOME _ => error ("Timing function already declared: " ^ (Term.term_name (hd term)))
          | NONE => ()
-      
-
-    val info = SOME (Function.get_info lthy term) handle Empty => NONE
-    val is_partial = case info of SOME {is_partial, ...} => is_partial | _ => false
 
     (* 1. Fix all terms *)
     (* Exchange Var in types and terms to Free and check constraints *)
@@ -410,9 +412,9 @@
     val timing_terms = map (convert_term lthy term conv (topConv is_rec)) terms
 
     (* 4. Register function and prove termination *)
-    val name = Term.term_name term
-    val timing_name = fun_name_to_time lthy name
-    val bindings = [(Binding.name timing_name, NONE, NoSyn)]
+    val names = map Term.term_name term
+    val timing_names = map (fun_name_to_time lthy) names
+    val bindings = map (fn nm => (Binding.name nm, NONE, NoSyn)) timing_names
     fun pat_completeness_auto ctxt =
       Pat_Completeness.pat_completeness_tac ctxt 1 THEN auto_tac ctxt
     val specs = map (fn eq => (((Binding.empty, []), eq), [], [])) timing_terms
@@ -424,7 +426,7 @@
       let
         val _ = (if seq then warning "Falling back on sequential function..." else ())
         val fun_config = Function_Common.FunctionConfig
-          {sequential=seq, default=NONE, domintros=true, partials=is_partial}
+          {sequential=seq, default=NONE, domintros=true, partials=false}
       in
         Function.add_function bindings specs fun_config pat_completeness_auto lthy
       end
@@ -436,9 +438,10 @@
     (* Print result if print *)
     val _ = if not print then () else
         let
-          val (nm,T) = case term of Const t => t | _ => error "Internal error: invalid term to print"
+          val nms = map (fst o dest_Const) term
+          val typs = map (snd o dest_Const) term
         in
-          print_timing' print_ctxt { name=nm, terms=terms, typ=T } { name=timing_name, terms=timing_terms, typ=change_typ T }
+          print_timing' print_ctxt { names=nms, terms=terms, typs=typs } { names=timing_names, terms=timing_terms, typs=map change_typ typs }
         end
 
     (* Register function *)
@@ -451,15 +454,13 @@
   in
     Local_Theory.exit_global lthy
   end
-and proove_termination term terms print (theory: theory) =
+and proove_termination (term: term list) terms print (theory: theory) =
   let
     val lthy = Named_Target.theory_init theory
 
     (* Start proving the termination *)  
-    val info = SOME (Function.get_info lthy term) handle Empty => NONE
-    val timing_name = term
-      |> Term.term_name
-      |> fun_name_to_time lthy
+    val infos = SOME (map (Function.get_info lthy) term) handle Empty => NONE
+    val timing_names = map (fun_name_to_time lthy o Term.term_name) term
 
     (* Proof by lexicographic_order_tac *)
     val (time_info, lthy') =
@@ -468,6 +469,8 @@
         handle (ERROR _) =>
         let
           val _ = warning "Falling back on proof over dom..."
+          val _ = (if length term > 1 then error "Proof over dom not supported for mutual recursive functions" else ())
+
           fun args (a$(Var ((nm,_),T))) = args a |> (fn ar => (nm,T)::ar)
             | args (a$(Const (_,T))) = args a |> (fn ar => ("x",T)::ar)
             | args _ = []
@@ -476,11 +479,11 @@
             |> Variable.variant_frees lthy []
             |> map fst
 
-          val {inducts, ...} = case info of SOME i => i | _ => error "Proof over dom failed as no induct rule was found"
+          val {inducts, ...} = case infos of SOME [i] => i | _ => error "Proof over dom failed as no induct rule was found"
           val induct = (Option.valOf inducts |> hd)
 
-          val domintros = Proof_Context.get_fact lthy (Facts.named (timing_name ^ ".domintros"))
-          val prop = (timing_name ^ "_dom (" ^ (String.concatWith "," dom_args) ^ ")")
+          val domintros = Proof_Context.get_fact lthy (Facts.named (hd timing_names ^ ".domintros"))
+          val prop = (hd timing_names ^ "_dom (" ^ (String.concatWith "," dom_args) ^ ")")
                       |> Syntax.read_prop lthy
 
           (* Prove a helper lemma *)
@@ -501,39 +504,40 @@
     (* Print result if print *)
     val _ = if not print then () else
         let
-          val (nm,T) = case term of Const t => t | _ => error "Internal error: invalid term to print"
+          val nms = map (fst o dest_Const) term
+          val typs = map (snd o dest_Const) term
         in
-          print_timing' print_ctxt { name=nm, terms=terms, typ=T } (info_pfunc time_info)
+          print_timing' print_ctxt { names=nms, terms=terms, typs=typs } (info_pfunc time_info)
         end
   in
     (time_info, Local_Theory.exit_global lthy')
   end
 
 (* Convert function into its timing function (called by command) *)
-fun reg_time_fun_cmd (func, thms) conv topConv (theory: theory) =
+fun reg_time_fun_cmd (funcs, thms) conv topConv (theory: theory) =
 let
   val ctxt = Proof_Context.init_global theory
-  val fterm = Syntax.read_term ctxt func
-  val (_, lthy') = reg_and_proove_time_func theory fterm
-    (case thms of NONE => get_terms theory fterm
+  val fterms = map (Syntax.read_term ctxt) funcs
+  val (_, lthy') = reg_and_proove_time_func theory fterms
+    (case thms of NONE => get_terms theory (hd fterms)
                 | SOME thms => thms |> Attrib.eval_thms ctxt |> List.map Thm.prop_of)
     conv topConv true
 in lthy'
 end
 
 (* Convert function into its timing function (called by command) with termination proof provided by user*)
-fun reg_time_function_cmd (func, thms) conv topConv (theory: theory) =
+fun reg_time_function_cmd (funcs, thms) conv topConv (theory: theory) =
 let
   val ctxt = Proof_Context.init_global theory
-  val fterm = Syntax.read_term ctxt func
-  val theory = reg_time_func theory fterm
-    (case thms of NONE => get_terms theory fterm
+  val fterms = map (Syntax.read_term ctxt) funcs
+  val theory = reg_time_func theory fterms
+    (case thms of NONE => get_terms theory (hd fterms)
                 | SOME thms => thms |> Attrib.eval_thms ctxt |> List.map Thm.prop_of)
     conv topConv true
 in theory
 end
 
-val parser = Parse.prop -- (Scan.option (Parse.keyword_improper "equations" -- Parse.thms1 >> snd))
+val parser = (Scan.repeat1 Parse.prop) -- (Scan.option (Parse.keyword_improper "equations" -- Parse.thms1 >> snd))
 
 val _ = Outer_Syntax.command @{command_keyword "define_time_fun"}
   "Defines runtime function of a function"