src/HOL/Data_Structures/Define_Time_Function.ML
changeset 81147 503e5280ba72
parent 80734 7054a1bc8347
child 81255 47530e9a7c33
equal deleted inserted replaced
81146:87f173836d56 81147:503e5280ba72
     9 type 'a converter = {
     9 type 'a converter = {
    10   constc : 'a wctxt -> term -> 'a,
    10   constc : 'a wctxt -> term -> 'a,
    11   funcc : 'a wctxt -> term -> term list -> 'a,
    11   funcc : 'a wctxt -> term -> term list -> 'a,
    12   ifc : 'a wctxt -> typ -> term -> term -> term -> 'a,
    12   ifc : 'a wctxt -> typ -> term -> term -> term -> 'a,
    13   casec : 'a wctxt -> term -> term list -> 'a,
    13   casec : 'a wctxt -> term -> term list -> 'a,
    14   letc : 'a wctxt -> typ -> term -> string list -> typ list -> term -> 'a
    14   letc : 'a wctxt -> typ -> term -> (string * typ) list -> term -> 'a
    15 }
    15 }
    16 val walk : local_theory -> term list -> 'a converter -> term -> 'a
    16 val walk : local_theory -> term list -> 'a converter -> term -> 'a
       
    17 val Iconst : term wctxt -> term -> term
       
    18 val Ifunc : term wctxt -> term -> term list -> term
       
    19 val Iif : term wctxt -> typ -> term -> term -> term -> term
       
    20 val Icase : term wctxt -> term -> term list -> term
       
    21 val Ilet : term wctxt -> typ -> term -> (string * typ) list -> term -> term
    17 
    22 
    18 type pfunc = { names : string list, terms : term list, typs : typ list }
    23 type pfunc = { names : string list, terms : term list, typs : typ list }
    19 val fun_pretty':  Proof.context -> pfunc -> Pretty.T
    24 val fun_pretty':  Proof.context -> pfunc -> Pretty.T
    20 val fun_pretty:  Proof.context -> Function.info -> Pretty.T
    25 val fun_pretty:  Proof.context -> Function.info -> Pretty.T
    21 val print_timing':  Proof.context -> pfunc -> pfunc -> unit
    26 val print_timing':  Proof.context -> pfunc -> pfunc -> unit
    22 val print_timing:  Proof.context -> Function.info -> Function.info -> unit
    27 val print_timing:  Proof.context -> Function.info -> Function.info -> unit
    23 
    28 
    24 val reg_and_proove_time_func: local_theory -> term list -> term list
    29 val reg_and_proove_time_func: local_theory -> term list -> term list
    25       -> bool -> Function.info * local_theory
    30       -> bool -> bool -> Function.info * local_theory
    26 val reg_time_func: local_theory -> term list -> term list
    31 val reg_time_func: local_theory -> term list -> term list
    27       -> bool -> Function.info * local_theory
    32       -> bool -> bool -> Function.info * local_theory
    28 
    33 
    29 val time_dom_tac: Proof.context -> thm -> thm list -> int -> tactic
    34 val time_dom_tac: Proof.context -> thm -> thm list -> int -> tactic
    30 
    35 
    31 end
    36 end
    32 
    37 
    33 structure Timing_Functions : TIMING_FUNCTIONS =
    38 structure Timing_Functions : TIMING_FUNCTIONS =
    34 struct
    39 struct
    35 (* Configure config variable to adjust the prefix *)
    40 (* Configure config variable to adjust the prefix *)
    36 val bprefix = Attrib.setup_config_string @{binding "time_prefix"} (K "T_")
    41 val bprefix = Attrib.setup_config_string @{binding "time_prefix"} (K "T_")
       
    42 val bprefix_snd = Attrib.setup_config_string @{binding "time_prefix_snd"} (K "T2_")
    37 (* Configure config variable to adjust the suffix *)
    43 (* Configure config variable to adjust the suffix *)
    38 val bsuffix = Attrib.setup_config_string @{binding "time_suffix"} (K "")
    44 val bsuffix = Attrib.setup_config_string @{binding "time_suffix"} (K "")
    39 
    45 
    40 (* some default values to build terms easier *)
    46 (* some default values to build terms easier *)
    41 val zero = Const (@{const_name "Groups.zero"}, HOLogic.natT)
    47 val zero = Const (@{const_name "Groups.zero"}, HOLogic.natT)
    81 let
    87 let
    82   val {names, ...} = opfunc;
    88   val {names, ...} = opfunc;
    83   val poriginal = Pretty.item [Pretty.str "Original function:\n", fun_pretty' ctxt opfunc]
    89   val poriginal = Pretty.item [Pretty.str "Original function:\n", fun_pretty' ctxt opfunc]
    84   val ptiming = Pretty.item [Pretty.str ("Running time function:\n"), fun_pretty' ctxt tpfunc]
    90   val ptiming = Pretty.item [Pretty.str ("Running time function:\n"), fun_pretty' ctxt tpfunc]
    85 in
    91 in
    86   Pretty.writeln (Pretty.text_fold [Pretty.str ("Converting " ^ (hd names) ^ (String.concat (map (fn nm => ", " ^ nm) (tl names))) ^ "\n"), poriginal, Pretty.str "\n", ptiming])
    92   Pretty.writeln (Pretty.text_fold [
       
    93       Pretty.str ("Converting " ^ (hd names) ^ (String.concat (map (fn nm => ", " ^ nm) (tl names))) ^ "\n"),
       
    94       poriginal, Pretty.str "\n", ptiming])
    87 end
    95 end
    88 fun print_timing ctxt (oinfo: Function.info) (tinfo: Function.info) =
    96 fun print_timing ctxt (oinfo: Function.info) (tinfo: Function.info) =
    89   print_timing' ctxt (info_pfunc oinfo) (info_pfunc tinfo)
    97   print_timing' ctxt (info_pfunc oinfo) (info_pfunc tinfo)
    90 
    98 
       
    99 fun print_lemma ctxt defs (T_terms: term list) =
       
   100 let
       
   101   val names =
       
   102     defs
       
   103     |> map snd
       
   104     |> map (fn s => "_" ^ s)
       
   105     |> List.foldr (op ^) ""
       
   106   val begin = "lemma T" ^ names ^ "_simps [simp,code]:\n"
       
   107   fun convLine T_term =
       
   108     "  \"" ^ Syntax.string_of_term ctxt T_term ^ "\"\n"
       
   109   val lines = map convLine T_terms
       
   110   fun convDefs def = " " ^ (fst def)
       
   111   val proof = "  by (simp_all add:" :: (map convDefs defs) @ [")"]
       
   112   val _ = Pretty.writeln (Pretty.str "Characteristic recursion equations can be derived:")
       
   113 in
       
   114   (begin :: lines @ proof)
       
   115   |> String.concat
       
   116   (* |> Active.sendback_markup_properties [Markup.padding_fun] *)
       
   117   |> Pretty.str
       
   118   |> Pretty.writeln
       
   119 end
       
   120 
    91 fun contains l e = exists (fn e' => e' = e) l
   121 fun contains l e = exists (fn e' => e' = e) l
    92 fun contains' comp l e = exists (comp e) l
   122 fun contains' comp l e = exists (comp e) l
    93 fun index [] _ = 0
       
    94   | index (x::xs) el = (if x = el then 0 else 1 + index xs el)
       
    95 fun used_for_const orig_used t i = orig_used (t,i)
       
    96 (* Split name by . *)
   123 (* Split name by . *)
    97 val split_name = String.fields (fn s => s = #".")
   124 val split_name = String.fields (fn s => s = #".")
    98 
   125 
    99 (* returns true if it's an if term *)
   126 (* returns true if it's an if term *)
   100 fun is_if (Const (@{const_name "HOL.If"},_)) = true
   127 fun is_if (Const (@{const_name "HOL.If"},_)) = true
   109     and replace all function arguments f with (t*T_f) if used *)
   136     and replace all function arguments f with (t*T_f) if used *)
   110 fun change_typ' used (Type ("fun", [T1, T2])) = 
   137 fun change_typ' used (Type ("fun", [T1, T2])) = 
   111       Type ("fun", [check_for_fun' (used 0) T1, change_typ' (fn i => used (i+1)) T2])
   138       Type ("fun", [check_for_fun' (used 0) T1, change_typ' (fn i => used (i+1)) T2])
   112   | change_typ' _ _ = HOLogic.natT
   139   | change_typ' _ _ = HOLogic.natT
   113 and check_for_fun' true (f as Type ("fun", [_,_])) = HOLogic.mk_prodT (f, change_typ' (K false) f)
   140 and check_for_fun' true (f as Type ("fun", [_,_])) = HOLogic.mk_prodT (f, change_typ' (K false) f)
   114   | check_for_fun' false (f as Type ("fun", [_,_])) = change_typ' (K false) f
   141   | check_for_fun' false (f as Type ("fun", [_,_])) = change_typ' (K true) f
   115   | check_for_fun' _ t = t
   142   | check_for_fun' _ t = t
   116 val change_typ = change_typ' (K false)
   143 val change_typ = change_typ' (K true)
   117 (* Convert string name of function to its timing equivalent *)
   144 (* Convert string name of function to its timing equivalent *)
   118 fun fun_name_to_time ctxt s name =
   145 fun fun_name_to_time' ctxt s second name =
   119 let
   146 let
   120   val prefix = Config.get ctxt bprefix
   147   val prefix = Config.get ctxt (if second then bprefix_snd else bprefix)
   121   val suffix = (if s then Config.get ctxt bsuffix else "")
   148   val suffix = (if s then Config.get ctxt bsuffix else "")
   122   fun replace_last_name [n] = [prefix ^ n ^ suffix]
   149   fun replace_last_name [n] = [prefix ^ n ^ suffix]
   123     | replace_last_name (n::ns) = n :: (replace_last_name ns)
   150     | replace_last_name (n::ns) = n :: (replace_last_name ns)
   124     | replace_last_name _ = error "Internal error: Invalid function name to convert"
   151     | replace_last_name _ = error "Internal error: Invalid function name to convert"
   125   val parts = split_name name
   152   val parts = split_name name
   126 in
   153 in
   127   String.concatWith "." (replace_last_name parts)
   154   String.concatWith "." (replace_last_name parts)
   128 end
   155 end
       
   156 fun fun_name_to_time ctxt s name = fun_name_to_time' ctxt s false name
   129 (* Count number of arguments of a function *)
   157 (* Count number of arguments of a function *)
   130 fun count_args (Type (n, [_,res])) = (if n = "fun" then 1 + count_args res else 0)
   158 fun count_args (Type (n, [_,res])) = (if n = "fun" then 1 + count_args res else 0)
   131   | count_args _ = 0
   159   | count_args _ = 0
   132 (* Check if number of arguments matches function *)
   160 (* Check if number of arguments matches function *)
   133 val _ = dest_Const
       
   134 fun check_args s (t, args) =
   161 fun check_args s (t, args) =
   135     (if length args = count_args (type_of t) then ()
   162     (if length args = count_args (type_of t) then ()
   136      else error ("Partial applications/Lambdas not allowed (" ^ s ^ ")"))
   163      else error ("Partial applications/Lambdas not allowed (" ^ s ^ ")"))
   137 (* Removes Abs *)
   164 (* Removes Abs *)
   138 fun rem_abs f (Abs (_,_,t)) = rem_abs f t
   165 fun rem_abs f (Abs (_,_,t)) = rem_abs f t
   189 type 'a converter = {
   216 type 'a converter = {
   190   constc : 'a wctxt -> term -> 'a,
   217   constc : 'a wctxt -> term -> 'a,
   191   funcc : 'a wctxt -> term -> term list -> 'a,
   218   funcc : 'a wctxt -> term -> term list -> 'a,
   192   ifc : 'a wctxt -> typ -> term -> term -> term -> 'a,
   219   ifc : 'a wctxt -> typ -> term -> term -> term -> 'a,
   193   casec : 'a wctxt -> term -> term list -> 'a,
   220   casec : 'a wctxt -> term -> term list -> 'a,
   194   letc : 'a wctxt -> typ -> term -> string list -> typ list -> term -> 'a
   221   letc : 'a wctxt -> typ -> term -> (string * typ) list -> term -> 'a
   195 }
   222 }
   196 
   223 
   197 (* Walks over term and calls given converter *)
   224 (* Walks over term and calls given converter *)
   198 fun walk_func (t1 $ t2) ts = walk_func t1 (t2::ts)
   225 (* get rid and use Term.strip_abs.eta especially for lambdas *)
   199   | walk_func t ts = (t, ts)
   226 fun build_abs t ((nm,T)::abs) = build_abs (Abs (nm,T,t)) abs
   200 fun walk_func' t = walk_func t []
   227   | build_abs t [] = t
   201 fun build_func (f, []) = f
       
   202   | build_func (f, (t::ts)) = build_func (f$t, ts)
       
   203 fun walk_abs (Abs (nm,T,t)) nms Ts = walk_abs t (nm::nms) (T::Ts)
       
   204   | walk_abs t nms Ts = (t, nms, Ts)
       
   205 fun build_abs t (nm::nms) (T::Ts) = build_abs (Abs (nm,T,t)) nms Ts
       
   206   | build_abs t [] [] = t
       
   207   | build_abs _ _ _ = error "Internal error: Invalid terms to build abs"
       
   208 fun walk ctxt (origin: term list) (conv as {ifc, casec, funcc, letc, ...} : 'a converter) (t as _ $ _) =
   228 fun walk ctxt (origin: term list) (conv as {ifc, casec, funcc, letc, ...} : 'a converter) (t as _ $ _) =
   209   let
   229   let
   210     val (f, args) = walk_func t []
   230     val (f, args) = strip_comb t
   211     val this = (walk ctxt origin conv)
   231     val this = (walk ctxt origin conv)
   212     val _ = (case f of Abs _ => error "Lambdas not supported" | _ => ())
   232     val _ = (case f of Abs _ => error "Lambdas not supported" | _ => ())
   213     val wctxt = {ctxt = ctxt, origins = origin, f = this}
   233     val wctxt = {ctxt = ctxt, origins = origin, f = this}
   214   in
   234   in
   215     (if is_if f then
   235     (if is_if f then
   218                    | _ => error "Partial applications not supported (if)")
   238                    | _ => error "Partial applications not supported (if)")
   219                | _ => error "Internal error: invalid if term")
   239                | _ => error "Internal error: invalid if term")
   220       else if is_case f then casec wctxt f args
   240       else if is_case f then casec wctxt f args
   221       else if is_let f then
   241       else if is_let f then
   222       (case f of (Const (_,lT)) =>
   242       (case f of (Const (_,lT)) =>
   223          (case args of [exp, t] => 
   243          (case args of [exp, t] =>
   224             let val (t,nms,Ts) = walk_abs t [] [] in letc wctxt lT exp nms Ts t end
   244             let val (abs,t) = strip_abs t in letc wctxt lT exp abs t end
   225                      | _ => error "Partial applications not allowed (let)")
   245                      | _ => error "Partial applications not allowed (let)")
   226                | _ => error "Internal error: invalid let term")
   246                | _ => error "Internal error: invalid let term")
   227       else funcc wctxt f args)
   247       else funcc wctxt f args)
   228   end
   248   end
   229   | walk ctxt origin (conv as {constc, ...}) c = 
   249   | walk ctxt origin (conv as {constc, ...}) c = 
   230       constc {ctxt = ctxt, origins = origin, f = walk ctxt origin conv} c
   250       constc {ctxt = ctxt, origins = origin, f = walk ctxt origin conv} c
       
   251 fun Ifunc (wctxt: term wctxt) t args = list_comb (#f wctxt t,map (#f wctxt) args)
       
   252 val Iconst = K I
       
   253 fun Iif (wctxt: term wctxt) T cond tt tf =
       
   254   Const (@{const_name "HOL.If"}, T) $ (#f wctxt cond) $ (#f wctxt tt) $ (#f wctxt tf)
       
   255 fun Icase (wctxt: term wctxt) t cs = list_comb (#f wctxt t,map (#f wctxt) cs)
       
   256 fun Ilet (wctxt: term wctxt) lT exp abs t =
       
   257   Const (@{const_name "HOL.Let"},lT) $ (#f wctxt exp) $ build_abs (#f wctxt t) abs
   231 
   258 
   232 (* 1. Fix all terms *)
   259 (* 1. Fix all terms *)
   233 (* Exchange Var in types and terms to Free *)
   260 (* Exchange Var in types and terms to Free *)
   234 fun fixTerms (Var(ixn,T)) = Free (fst ixn, T)
   261 fun freeTerms (Var(ixn,T)) = Free (fst ixn, T)
   235   | fixTerms t = t
   262   | freeTerms t = t
   236 fun fixTypes (TVar ((t, _), T)) = TFree (t, T)
   263 fun freeTypes (TVar ((t, _), T)) = TFree (t, T)
   237   | fixTypes t = t
   264   | freeTypes t = t
   238 
   265 
   239 fun noFun (Type ("fun", _)) = error "Functions in datatypes are not allowed in case constructions"
   266 fun noFun (Type ("fun", _)) = error "Functions in datatypes are not allowed in case constructions"
   240   | noFun T = T
   267   | noFun T = T
   241 fun casecBuildBounds n t = if n > 0 then casecBuildBounds (n-1) (t $ (Bound (n-1))) else t
   268 fun casecBuildBounds n t = if n > 0 then casecBuildBounds (n-1) (t $ (Bound (n-1))) else t
   242 fun casecAbs wctxt n (Type ("fun",[T,Tr])) (Abs (v,Ta,t)) = (map_atyps noFun T; Abs (v,Ta,casecAbs wctxt n Tr t))
   269 fun casecAbs wctxt n (Type ("fun",[T,Tr])) (Abs (v,Ta,t)) = (map_atyps noFun T; Abs (v,Ta,casecAbs wctxt n Tr t))
   245   | casecAbs wctxt n _ t = (#f wctxt) (casecBuildBounds n (Term.incr_bv n 0 t))
   272   | casecAbs wctxt n _ t = (#f wctxt) (casecBuildBounds n (Term.incr_bv n 0 t))
   246 fun fixCasecCases _ _ [t] = [t]
   273 fun fixCasecCases _ _ [t] = [t]
   247   | fixCasecCases wctxt (Type (_,[T,Tr])) (t::ts) = casecAbs wctxt 0 T t :: fixCasecCases wctxt Tr ts
   274   | fixCasecCases wctxt (Type (_,[T,Tr])) (t::ts) = casecAbs wctxt 0 T t :: fixCasecCases wctxt Tr ts
   248   | fixCasecCases _ _ _ = error "Internal error: invalid case types/terms"
   275   | fixCasecCases _ _ _ = error "Internal error: invalid case types/terms"
   249 fun fixCasec wctxt (t as Const (_,T)) args =
   276 fun fixCasec wctxt (t as Const (_,T)) args =
   250       (check_args "cases" (t,args); build_func (t,fixCasecCases wctxt T args))
   277       (check_args "cases" (t,args); list_comb (t,fixCasecCases wctxt T args))
   251   | fixCasec _ _ _ = error "Internal error: invalid case term"
   278   | fixCasec _ _ _ = error "Internal error: invalid case term"
   252 
   279 
   253 fun fixPartTerms ctxt (term: term list) t =
   280 fun shortFunc fixedNum (Const (nm,T)) = 
       
   281     Const (nm,T |> strip_type |>> drop fixedNum |> (op --->))
       
   282   | shortFunc _ _ = error "Internal error: Invalid term"
       
   283 fun shortApp fixedNum (c, args) =
       
   284   (shortFunc fixedNum c, drop fixedNum args)
       
   285 fun shortOriginFunc (term: term list) fixedNum (f as (c as Const (_,_), _))  =
       
   286   if contains' const_comp term c then shortApp fixedNum f else f
       
   287   | shortOriginFunc _ _ t = t
       
   288 fun fixTerms ctxt (term: term list) (fixedNum: int) (t as pT $ (eq $ l $ r)) =
   254   let
   289   let
   255     val _ = check_args "args" (walk_func (get_l t) [])
   290     val _ = check_args "args" (strip_comb (get_l t))
   256   in
   291     val l' = shortApp fixedNum (strip_comb l) |> list_comb
   257     map_r (walk ctxt term {
   292     val shortOriginFunc' = shortOriginFunc (term |> map (fst o strip_comb)) fixedNum
       
   293     val r' = walk ctxt term {
   258           funcc = (fn wctxt => fn t => fn args =>
   294           funcc = (fn wctxt => fn t => fn args =>
   259               (check_args "func" (t,args); build_func (t, map (#f wctxt) args))),
   295               (check_args "func" (t,args); (t, map (#f wctxt) args) |> shortOriginFunc' |> list_comb)),
   260           constc = (fn _ => fn c => (case c of Abs _ => error "Lambdas not supported" | _ => c)),
   296           constc = (fn _ => fn c => (case c of Abs _ => error "Lambdas not supported" | _ => c)),
   261           ifc = (fn wctxt => fn T => fn cond => fn tt => fn tf =>
   297           ifc = Iif,
   262             ((Const (@{const_name "HOL.If"}, T)) $ (#f wctxt) cond $ ((#f wctxt) tt) $ ((#f wctxt) tf))),
       
   263           casec = fixCasec,
   298           casec = fixCasec,
   264           letc = (fn wctxt => fn expT => fn exp => fn nms => fn Ts => fn t =>
   299           letc = (fn wctxt => fn expT => fn exp => fn abs => fn t =>
   265               let
   300               let
   266                 val f' = if length nms = 0 then
   301                 val f' = if length abs = 0 then
   267                 (case (#f wctxt) (t$exp) of t$_ => t | _ => error "Internal error: case could not be fixed (let)")
   302                 (case (#f wctxt) (t$exp) of t$_ => t | _ => error "Internal error: case could not be fixed (let)")
   268                 else (#f wctxt) t
   303                 else (#f wctxt) t
   269               in (Const (@{const_name "HOL.Let"},expT) $ ((#f wctxt) exp) $ build_abs f' nms Ts) end)
   304               in (Const (@{const_name "HOL.Let"},expT) $ ((#f wctxt) exp) $ build_abs f' abs) end)
   270       }) t
   305       } r
       
   306   in
       
   307     pT $ (eq $ l' $ r')
   271   end
   308   end
       
   309   | fixTerms _ _ _ _ = error "Internal error: invalid term"
   272 
   310 
   273 (* 2. Check for properties about the function *)
   311 (* 2. Check for properties about the function *)
   274 (* 2.1 Check if function is recursive *)
   312 (* 2.1 Check if function is recursive *)
   275 fun or f (a,b) = f a orelse b
   313 fun or f (a,b) = f a orelse b
   276 fun find_rec ctxt term = (walk ctxt term {
   314 fun find_rec ctxt term = (walk ctxt term {
   280           constc = (K o K) false,
   318           constc = (K o K) false,
   281           ifc = (fn wctxt => fn _ => fn cond => fn tt => fn tf =>
   319           ifc = (fn wctxt => fn _ => fn cond => fn tt => fn tf =>
   282             (#f wctxt) cond orelse (#f wctxt) tt orelse (#f wctxt) tf),
   320             (#f wctxt) cond orelse (#f wctxt) tt orelse (#f wctxt) tf),
   283           casec = (fn wctxt => fn t => fn cs =>
   321           casec = (fn wctxt => fn t => fn cs =>
   284             (#f wctxt) t orelse List.foldr (or (rem_abs (#f wctxt))) false cs),
   322             (#f wctxt) t orelse List.foldr (or (rem_abs (#f wctxt))) false cs),
   285           letc = (fn wctxt => fn _ => fn exp => fn _ => fn _ => fn t =>
   323           letc = (fn wctxt => fn _ => fn exp => fn _ => fn t =>
   286             (#f wctxt) exp orelse (#f wctxt) t)
   324             (#f wctxt) exp orelse (#f wctxt) t)
   287       }) o get_r
   325       }) o get_r
   288 fun is_rec ctxt (term: term list) = List.foldr (or (find_rec ctxt term)) false
   326 fun is_rec ctxt (term: term list) = List.foldr (or (find_rec ctxt term)) false
   289 
       
   290 (* 2.2 Check for higher-order function if original function is used *)
       
   291 fun find_used' ctxt term t T_t =
       
   292 let
       
   293   val (ident, _) = walk_func (get_l t) []
       
   294   val (T_ident, T_args) = walk_func (get_l T_t) []
       
   295 
       
   296   fun filter_passed [] = []
       
   297     | filter_passed ((f as Free (_, Type ("Product_Type.prod",[Type ("fun",_), Type ("fun", _)])))::args) = 
       
   298         f :: filter_passed args
       
   299     | filter_passed (_::args) = filter_passed args
       
   300   val frees' = (walk ctxt term {
       
   301           funcc = (fn wctxt => fn t => fn args =>
       
   302               (case t of (Const ("Product_Type.prod.snd", _)) => []
       
   303                   | _ => (if t = T_ident then [] else filter_passed args)
       
   304                     @ List.foldr (fn (l,r) => (#f wctxt) l @ r) [] args)),
       
   305           constc = (K o K) [],
       
   306           ifc = (fn wctxt => fn _ => fn cond => fn tt => fn tf => (#f wctxt) cond @ (#f wctxt) tt @ (#f wctxt) tf),
       
   307           casec = (fn wctxt => fn _ => fn cs => List.foldr (fn (l,r) => (#f wctxt) l @ r) [] cs),
       
   308           letc = (fn wctxt => fn _ => fn exp => fn _ => fn _ => fn t => (#f wctxt) exp @ (#f wctxt) t)
       
   309       }) (get_r T_t)
       
   310   fun build _ [] _ = false
       
   311     | build i (a::args) item =
       
   312         (if item = (ident,i) then contains frees' a else build (i+1) args item)
       
   313 in
       
   314   build 0 T_args
       
   315 end
       
   316 fun find_used ctxt term terms T_terms =
       
   317   ListPair.zip (terms, T_terms)
       
   318   |> List.map (fn (t, T_t) => find_used' ctxt term t T_t)
       
   319   |> List.foldr (fn (f,g) => fn item => f item orelse g item) (K false)
       
   320 
       
   321 
   327 
   322 (* 3. Convert equations *)
   328 (* 3. Convert equations *)
   323 (* Some Helper *)
   329 (* Some Helper *)
   324 val plusTyp = @{typ "nat => nat => nat"}
   330 val plusTyp = @{typ "nat => nat => nat"}
   325 fun plus (SOME a) (SOME b) = SOME (Const (@{const_name "Groups.plus"}, plusTyp) $ a $ b)
   331 fun plus (SOME a) (SOME b) = SOME (Const (@{const_name "Groups.plus"}, plusTyp) $ a $ b)
   330   | opt_term (SOME t) = t
   336   | opt_term (SOME t) = t
   331 fun use_origin (Free (nm, T as Type ("fun",_))) = HOLogic.mk_fst (Free (nm,HOLogic.mk_prodT (T, change_typ T)))
   337 fun use_origin (Free (nm, T as Type ("fun",_))) = HOLogic.mk_fst (Free (nm,HOLogic.mk_prodT (T, change_typ T)))
   332   | use_origin t = t
   338   | use_origin t = t
   333 
   339 
   334 (* Conversion of function term *)
   340 (* Conversion of function term *)
   335 fun fun_to_time ctxt orig_used _ (origin: term list) (func as Const (nm,T)) =
   341 fun fun_to_time' ctxt (origin: term list) second (func as Const (nm,T)) =
   336 let
   342 let
   337   val used' = used_for_const orig_used func
   343   val origin' = map (fst o strip_comb) origin
   338 in
   344 in
   339   if contains' const_comp origin func then SOME (Free (func |> Term.term_name |> fun_name_to_time ctxt true, change_typ' used' T)) else
   345   if contains' const_comp origin' func then SOME (Free (func |> Term.term_name |> fun_name_to_time' ctxt true second, change_typ T)) else
   340   if Zero_Funcs.is_zero (Proof_Context.theory_of ctxt) (nm,T) then NONE else
   346   if Zero_Funcs.is_zero (Proof_Context.theory_of ctxt) (nm,T) then NONE else
   341     time_term ctxt false func
   347     time_term ctxt false func
   342 end
   348 end
   343   | fun_to_time ctxt _ used _ (f as Free (nm,T)) = SOME (
   349   | fun_to_time' _ _ _ (Free (nm,T)) =
   344       if used f then HOLogic.mk_snd (Free (nm,HOLogic.mk_prodT (T,change_typ T)))
   350       SOME (HOLogic.mk_snd (Free (nm,HOLogic.mk_prodT (T,change_typ T))))
   345       else Free (fun_name_to_time ctxt false nm, change_typ T)
   351   | fun_to_time' _ _ _ _ = error "Internal error: invalid function to convert"
   346       )
   352 fun fun_to_time context origin func = fun_to_time' context origin false func
   347   | fun_to_time _ _ _ _ _ = error "Internal error: invalid function to convert"
       
   348 
   353 
   349 (* Convert arguments of left side of a term *)
   354 (* Convert arguments of left side of a term *)
   350 fun conv_arg ctxt used _ (f as Free (nm,T as Type("fun",_))) =
   355 fun conv_arg _ (Free (nm,T as Type("fun",_))) =
   351     if used f then Free (nm, HOLogic.mk_prodT (T, change_typ' (K false) T))
   356     Free (nm, HOLogic.mk_prodT (T, change_typ' (K false) T))
   352     else Free (fun_name_to_time ctxt false nm, change_typ' (K false) T)
   357   | conv_arg _ x = x
   353   | conv_arg _ _ _ x = x
   358 fun conv_args ctxt = map (conv_arg ctxt)
   354 fun conv_args ctxt used origin = map (conv_arg ctxt used origin)
       
   355 
   359 
   356 (* Handle function calls *)
   360 (* Handle function calls *)
   357 fun build_zero (Type ("fun", [T, R])) = Abs ("uu", T, build_zero R)
   361 fun build_zero (Type ("fun", [T, R])) = Abs ("uu", T, build_zero R)
   358   | build_zero _ = zero
   362   | build_zero _ = zero
   359 fun funcc_use_origin used (f as Free (nm, T as Type ("fun",_))) =
   363 fun funcc_use_origin (Free (nm, T as Type ("fun",_))) =
   360     if used f then HOLogic.mk_fst (Free (nm,HOLogic.mk_prodT (T, change_typ T)))
   364     HOLogic.mk_fst (Free (nm,HOLogic.mk_prodT (T, change_typ T)))
   361     else error "Internal error: Error in used detection"
   365   | funcc_use_origin t = t
   362   | funcc_use_origin _ t = t
   366 fun funcc_conv_arg _ _ (t as (_ $ _)) = map_aterms funcc_use_origin t
   363 fun funcc_conv_arg _ used _ (t as (_ $ _)) = map_aterms (funcc_use_origin used) t
   367   | funcc_conv_arg _ u (Free (nm, T as Type ("fun",_))) =
   364   | funcc_conv_arg wctxt used u (f as Free (nm, T as Type ("fun",_))) =
   368       if u then Free (nm, HOLogic.mk_prodT (T, change_typ T))
   365       if used f then
   369       else HOLogic.mk_snd (Free (nm,HOLogic.mk_prodT (T,change_typ T)))
   366         if u then Free (nm, HOLogic.mk_prodT (T, change_typ T))
   370   | funcc_conv_arg wctxt true (f as Const (_,T as Type ("fun",_))) =
   367         else HOLogic.mk_snd (Free (nm,HOLogic.mk_prodT (T,change_typ T)))
       
   368       else Free (fun_name_to_time (#ctxt wctxt) false nm, change_typ T)
       
   369   | funcc_conv_arg wctxt _ true (f as Const (_,T as Type ("fun",_))) =
       
   370   (Const (@{const_name "Product_Type.Pair"},
   371   (Const (@{const_name "Product_Type.Pair"},
   371       Type ("fun", [T,Type ("fun", [change_typ T, HOLogic.mk_prodT (T,change_typ T)])]))
   372       Type ("fun", [T,Type ("fun", [change_typ T, HOLogic.mk_prodT (T,change_typ T)])]))
   372     $ f $ (Option.getOpt (fun_to_time (#ctxt wctxt) (K false) (K false) (#origins wctxt) f, build_zero T)))
   373     $ f $ (Option.getOpt (fun_to_time (#ctxt wctxt) (#origins wctxt) f, build_zero T)))
   373   | funcc_conv_arg wctxt _ false (f as Const (_,T as Type ("fun",_))) =
   374   | funcc_conv_arg wctxt false (f as Const (_,T as Type ("fun",_))) =
   374       Option.getOpt (fun_to_time (#ctxt wctxt) (K false) (K false) (#origins wctxt) f, build_zero T)
   375       Option.getOpt (fun_to_time (#ctxt wctxt) (#origins wctxt) f, build_zero T)
   375   | funcc_conv_arg _ _ _ t = t
   376   | funcc_conv_arg _ _ t = t
   376 
   377 
   377 fun funcc_conv_args _ _ _ [] = []
   378 fun funcc_conv_args _ _ [] = []
   378   | funcc_conv_args wctxt used (Type ("fun", [t, ts])) (a::args) =
   379   | funcc_conv_args wctxt (Type ("fun", [t, ts])) (a::args) =
   379       funcc_conv_arg wctxt used (is_Used t) a :: funcc_conv_args wctxt used ts args
   380       funcc_conv_arg wctxt (is_Used t) a :: funcc_conv_args wctxt ts args
   380   | funcc_conv_args _ _ _ _ = error "Internal error: Non matching type"
   381   | funcc_conv_args _ _ _ = error "Internal error: Non matching type"
   381 fun funcc orig_used used wctxt func args =
   382 fun funcc wctxt func args =
   382 let
   383 let
   383   fun get_T (Free (_,T)) = T
   384   fun get_T (Free (_,T)) = T
   384     | get_T (Const (_,T)) = T
   385     | get_T (Const (_,T)) = T
   385     | get_T (_ $ (Free (_,Type (_, [_, T])))) = T (* Case of snd was constructed *)
   386     | get_T (_ $ (Free (_,Type (_, [_, T])))) = T (* Case of snd was constructed *)
   386     | get_T _ = error "Internal error: Forgotten type"
   387     | get_T _ = error "Internal error: Forgotten type"
   387 in
   388 in
   388   List.foldr (I #-> plus)
   389   List.foldr (I #-> plus)
   389   (case fun_to_time (#ctxt wctxt) orig_used used (#origins wctxt) func
   390   (case fun_to_time (#ctxt wctxt) (#origins wctxt) func
   390     of SOME t => SOME (build_func (t,funcc_conv_args wctxt used (get_T t) args))
   391     of SOME t => SOME (list_comb (t,funcc_conv_args wctxt (get_T t) args))
   391     | NONE => NONE)
   392     | NONE => NONE)
   392   (map (#f wctxt) args)
   393   (map (#f wctxt) args)
   393 end
   394 end
   394 
   395 
   395 (* Handle case terms *)
   396 (* Handle case terms *)
   411   if not (casecIsCase T) then error "Internal error: Invalid case type" else
   412   if not (casecIsCase T) then error "Internal error: Invalid case type" else
   412     let val (nconst, args') = casecArgs (#f wctxt) args in
   413     let val (nconst, args') = casecArgs (#f wctxt) args in
   413       plus
   414       plus
   414         ((#f wctxt) (List.last args))
   415         ((#f wctxt) (List.last args))
   415         (if nconst then
   416         (if nconst then
   416           SOME (build_func (Const (t,casecTyp T), args'))
   417           SOME (list_comb (Const (t,casecTyp T), args'))
   417          else NONE)
   418          else NONE)
   418     end
   419     end
   419   | casec _ _ _ = error "Internal error: Invalid case term"
   420   | casec _ _ _ = error "Internal error: Invalid case term"
   420 
   421 
   421 (* Handle if terms -> drop the term if true and false terms are zero *)
   422 (* Handle if terms -> drop the term if true and false terms are zero *)
   431        (SOME ((Const (@{const_name "HOL.If"}, @{typ "bool \<Rightarrow> nat \<Rightarrow> nat \<Rightarrow> nat"})) $ rcond $ (opt_term tt) $ (opt_term ft))))
   432        (SOME ((Const (@{const_name "HOL.If"}, @{typ "bool \<Rightarrow> nat \<Rightarrow> nat \<Rightarrow> nat"})) $ rcond $ (opt_term tt) $ (opt_term ft))))
   432   end
   433   end
   433 
   434 
   434 fun letc_change_typ (Type ("fun", [T1, Type ("fun", [T2, _])])) = (Type ("fun", [T1, Type ("fun", [change_typ T2, HOLogic.natT])]))
   435 fun letc_change_typ (Type ("fun", [T1, Type ("fun", [T2, _])])) = (Type ("fun", [T1, Type ("fun", [change_typ T2, HOLogic.natT])]))
   435   | letc_change_typ _ = error "Internal error: invalid let type"
   436   | letc_change_typ _ = error "Internal error: invalid let type"
   436 fun letc wctxt expT exp nms Ts t =
   437 fun letc wctxt expT exp abs t =
   437     plus (#f wctxt exp)
   438     plus (#f wctxt exp)
   438     (if length nms = 0 (* In case of "length nms = 0" the expression got reducted
   439     (if length abs = 0 (* In case of "length nms = 0" the expression got reducted
   439                           Here we need Bound 0 to gain non-partial application *)
   440                           Here we need Bound 0 to gain non-partial application *)
   440     then (case #f wctxt (t $ Bound 0) of SOME (t' $ Bound 0) =>
   441     then (case #f wctxt (t $ Bound 0) of SOME (t' $ Bound 0) =>
   441                                  (SOME (Const (@{const_name "HOL.Let"}, letc_change_typ expT) $ (map_aterms use_origin exp) $ t'))
   442                                  (SOME (Const (@{const_name "HOL.Let"}, letc_change_typ expT) $ (map_aterms use_origin exp) $ t'))
   442                                   (* Expression is not used and can therefore let be dropped *)
   443                                   (* Expression is not used and can therefore let be dropped *)
   443                                 | SOME t' => SOME t'
   444                                 | SOME t' => SOME t'
   444                                 | NONE => NONE)
   445                                 | NONE => NONE)
   445     else (case #f wctxt t of SOME t' =>
   446     else (case #f wctxt t of SOME t' =>
   446       SOME (if Term.is_dependent t' then Const (@{const_name "HOL.Let"}, letc_change_typ expT) $ (map_aterms use_origin exp) $ build_abs t' nms Ts
   447       SOME (if Term.is_dependent t' then Const (@{const_name "HOL.Let"}, letc_change_typ expT) $ (map_aterms use_origin exp) $ build_abs t' abs
   447                                     else Term.subst_bounds([exp],t'))
   448                                     else Term.subst_bounds([exp],t'))
   448     | NONE => NONE))
   449     | NONE => NONE))
   449 
   450 
   450 (* The converter for timing functions given to the walker *)
   451 (* The converter for timing functions given to the walker *)
   451 fun converter orig_used used : term option converter = {
   452 val converter : term option converter = {
   452         constc = fn _ => fn t =>
   453         constc = fn _ => fn t =>
   453           (case t of Const ("HOL.undefined", _) => SOME (Const ("HOL.undefined", @{typ "nat"}))
   454           (case t of Const ("HOL.undefined", _) => SOME (Const ("HOL.undefined", @{typ "nat"}))
   454                    | _ => NONE),
   455                    | _ => NONE),
   455         funcc = (funcc orig_used used),
   456         funcc = funcc,
   456         ifc = ifc,
   457         ifc = ifc,
   457         casec = casec,
   458         casec = casec,
   458         letc = letc
   459         letc = letc
   459     }
   460     }
   460 fun top_converter is_rec _ _ = opt_term o (fn exp => plus exp (if is_rec then SOME one else NONE))
   461 fun top_converter is_rec _ _ = opt_term o (fn exp => plus exp (if is_rec then SOME one else NONE))
   461 
   462 
   462 (* Use converter to convert right side of a term *)
   463 (* Use converter to convert right side of a term *)
   463 fun to_time ctxt origin is_rec orig_used used term =
   464 fun to_time ctxt origin is_rec term =
   464   top_converter is_rec ctxt origin (walk ctxt origin (converter orig_used used) term)
   465   top_converter is_rec ctxt origin (walk ctxt origin converter term)
   465 
   466 
   466 (* Converts a term to its running time version *)
   467 (* Converts a term to its running time version *)
   467 fun convert_term ctxt (origin: term list) is_rec orig_used (pT $ (Const (eqN, _) $ l $ r)) =
   468 fun convert_term ctxt (origin: term list) is_rec (pT $ (Const (eqN, _) $ l $ r)) =
   468 let
   469 let
   469   val (l' as (l_const, l_params)) = walk_func l []
   470   val (l_const, l_params) = strip_comb l
   470   val used =
   471 in
   471     l_const
   472     pT
   472     |> used_for_const orig_used
   473     $ (Const (eqN, @{typ "nat \<Rightarrow> nat \<Rightarrow> bool"})
   473     |> (fn f => fn n => f (index l_params n))
   474       $ (list_comb (l_const |> fun_to_time ctxt origin |> Option.valOf, l_params |> conv_args ctxt))
   474 in
   475       $ (to_time ctxt origin is_rec r))
   475       pT
   476 end
   476       $ (Const (eqN, @{typ "nat \<Rightarrow> nat \<Rightarrow> bool"})
   477   | convert_term _ _ _ _ = error "Internal error: invalid term to convert"
   477         $ (build_func (l' |>> (fun_to_time ctxt orig_used used origin) |>> Option.valOf ||> conv_args ctxt used origin))
   478 
   478         $ (to_time ctxt origin is_rec orig_used used r))
   479 (* 3.5 Support for locales *)
   479 end
   480 fun replaceFstSndFree ctxt (origin: term list) (rfst: term -> term) (rsnd: term -> term) =
   480   | convert_term _ _ _ _ _ = error "Internal error: invalid term to convert"
   481   (walk ctxt origin {
       
   482           funcc = fn wctxt => fn t => fn args =>
       
   483             case args of
       
   484                  (f as Free _)::args =>
       
   485                    (case t of
       
   486                        Const ("Product_Type.prod.fst", _) =>
       
   487                         list_comb (rfst (t $ f), map (#f wctxt) args)
       
   488                      | Const ("Product_Type.prod.snd", _) =>
       
   489                         list_comb (rsnd (t $ f), map (#f wctxt) args)
       
   490                      | t => list_comb (t, map (#f wctxt) (f :: args)))
       
   491                | args => list_comb (t, map (#f wctxt) args),
       
   492           constc = Iconst,
       
   493           ifc = Iif,
       
   494           casec = Icase,
       
   495           letc = Ilet
       
   496       })
   481 
   497 
   482 (* 4. Tactic to prove "f_dom n" *)
   498 (* 4. Tactic to prove "f_dom n" *)
   483 fun time_dom_tac ctxt induct_rule domintros =
   499 fun time_dom_tac ctxt induct_rule domintros =
   484   (Induction.induction_tac ctxt true [] [[]] [] (SOME [induct_rule]) []
   500   (Induction.induction_tac ctxt true [] [[]] [] (SOME [induct_rule]) []
   485     THEN_ALL_NEW ((K (auto_tac ctxt)) THEN' (fn i => FIRST' (
   501     THEN_ALL_NEW ((K (auto_tac ctxt)) THEN' (fn i => FIRST' (
   493       |> map #rules
   509       |> map #rules
   494       |> map (map Thm.prop_of)
   510       |> map (map Thm.prop_of)
   495    handle Empty => error "Function or terms of function not found"
   511    handle Empty => error "Function or terms of function not found"
   496 in
   512 in
   497   equations
   513   equations
   498     |> filter (fn ts => typ_comp (ts |> hd |> get_l |> walk_func' |> fst |> dest_Const |> snd) (term |> dest_Const |> snd))
   514     |> filter (List.exists
       
   515         (fn t => typ_comp (t |> get_l |> strip_comb |> fst |> dest_Const |> snd) (term |> strip_comb |> fst |> dest_Const |> snd)))
   499     |> hd
   516     |> hd
   500 end
   517 end
   501 
   518 
       
   519 (* 5. Check for higher-order function if original function is used \<rightarrow> find simplifications *)
       
   520 fun find_used' T_t =
       
   521 let
       
   522   val (T_ident, T_args) = strip_comb (get_l T_t)
       
   523 
       
   524   fun filter_passed [] = []
       
   525     | filter_passed ((f as Free (_, Type ("Product_Type.prod",[Type ("fun",_), Type ("fun", _)])))::args) = 
       
   526         f :: filter_passed args
       
   527     | filter_passed (_::args) = filter_passed args
       
   528   val frees = (walk @{context} [] {
       
   529           funcc = (fn wctxt => fn t => fn args =>
       
   530               (case t of (Const ("Product_Type.prod.snd", _)) => []
       
   531                   | _ => (if t = T_ident then [] else filter_passed args)
       
   532                     @ List.foldr (fn (l,r) => (#f wctxt) l @ r) [] args)),
       
   533           constc = (K o K) [],
       
   534           ifc = (fn wctxt => fn _ => fn cond => fn tt => fn tf => (#f wctxt) cond @ (#f wctxt) tt @ (#f wctxt) tf),
       
   535           casec = (fn wctxt => fn _ => fn cs => List.foldr (fn (l,r) => (#f wctxt) l @ r) [] cs),
       
   536           letc = (fn wctxt => fn _ => fn exp => fn _ => fn t => (#f wctxt) exp @ (#f wctxt) t)
       
   537       }) (get_r T_t)
       
   538   fun build _ [] = []
       
   539     | build i (a::args) =
       
   540         (if contains frees a then [(T_ident,i)] else []) @ build (i+1) args
       
   541 in
       
   542   build 0 T_args
       
   543 end
       
   544 fun find_simplifyble ctxt term terms =
       
   545 let
       
   546   val used =
       
   547     terms
       
   548     |> List.map find_used'
       
   549     |> List.foldr (op @) []
       
   550   val change =
       
   551     Option.valOf o fun_to_time ctxt term
       
   552   fun detect t i (Type ("fun",_)::args) = 
       
   553     (if contains used (change t,i) then [] else [i]) @ detect t (i+1) args
       
   554     | detect t i (_::args) = detect t (i+1) args
       
   555     | detect _ _ [] = []
       
   556 in
       
   557   map (fn t => t |> type_of |> strip_type |> fst |> detect t 0) term
       
   558 end
       
   559 
       
   560 fun define_simp' term simplifyable ctxt =
       
   561 let
       
   562   val base_name = case Named_Target.locale_of ctxt of
       
   563           NONE => ctxt |> Proof_Context.theory_of |> Context.theory_base_name
       
   564         | SOME nm => nm
       
   565   
       
   566   val orig_name = term |> dest_Const_name |> split_name |> List.last
       
   567   val red_name = fun_name_to_time ctxt false orig_name
       
   568   val name = fun_name_to_time' ctxt true true orig_name
       
   569   val full_name = base_name ^ "." ^ name
       
   570   val def_name = red_name ^ "_def"
       
   571   val def = Binding.name def_name
       
   572 
       
   573   val canon = Syntax.read_term (Local_Theory.exit ctxt) name |> strip_comb
       
   574   val canonFrees = canon |> snd
       
   575   val canonType = canon |> fst |> dest_Const_type |> strip_type |> fst |> take (length canonFrees)
       
   576 
       
   577   val types = term |> dest_Const_type |> strip_type |> fst
       
   578   val vars = Variable.variant_fixes (map (K "") types) ctxt |> fst
       
   579   fun l_typs' i ((T as (Type ("fun",_)))::types) =
       
   580     (if contains simplifyable i
       
   581      then change_typ T
       
   582      else HOLogic.mk_prodT (T,change_typ T))
       
   583     :: l_typs' (i+1) types
       
   584     | l_typs' i (T::types) = T :: l_typs' (i+1) types
       
   585     | l_typs' _ [] = []
       
   586   val l_typs = l_typs' 0 types
       
   587   val lhs =
       
   588     List.foldl (fn ((v,T),t) => t $ Free (v,T)) (Free (red_name,l_typs ---> HOLogic.natT)) (ListPair.zip (vars,l_typs))
       
   589   fun fixType (TFree _) = HOLogic.natT
       
   590     | fixType T = T
       
   591   fun fixUnspecified T = T |> strip_type ||> fixType |> (op --->)
       
   592   fun r_terms' i (v::vars) ((T as (Type ("fun",_)))::types) =
       
   593     (if contains simplifyable i
       
   594     then HOLogic.mk_prod (Const ("HOL.undefined", fixUnspecified T), Free (v,change_typ T))
       
   595     else Free (v,HOLogic.mk_prodT (T,change_typ T)))
       
   596     :: r_terms' (i+1) vars types
       
   597     | r_terms' i (v::vars) (T::types) = Free (v,T) :: r_terms' (i+1) vars types
       
   598     | r_terms' _ _ _ = []
       
   599   val r_terms = r_terms' 0 vars types
       
   600   val full_type = (r_terms |> map (type_of) ---> HOLogic.natT)
       
   601   val full = list_comb (Const (full_name,canonType ---> full_type), canonFrees)
       
   602   val rhs = list_comb (full, r_terms)
       
   603   val eq = (lhs, rhs) |> HOLogic.mk_eq |> HOLogic.mk_Trueprop
       
   604   val _ = Pretty.writeln (Pretty.block [Pretty.str "Defining simplified version:\n",
       
   605                                         Syntax.pretty_term ctxt eq])
       
   606 
       
   607   val (_, ctxt') = Specification.definition NONE [] [] ((def, []), eq) ctxt
       
   608 
       
   609 in
       
   610   ((def_name, orig_name), ctxt')
       
   611 end
       
   612 fun define_simp simpables ctxt =
       
   613 let
       
   614   fun cond ((term,simplifyable),(defs,ctxt)) =
       
   615     define_simp' term simplifyable ctxt |>> (fn def => def :: defs)
       
   616 in
       
   617   List.foldr cond ([], ctxt) simpables
       
   618 end
       
   619 
       
   620 
       
   621 fun replace from to =
       
   622   map (map_aterms (fn t => if t = from then to else t))
       
   623 fun replaceAll [] = I
       
   624   | replaceAll ((from,to)::xs) = replaceAll xs o replace from to
       
   625 fun calculateSimplifications ctxt T_terms term simpables =
       
   626 let
       
   627   (* Show where a simplification can take place *)
       
   628     fun reportReductions (t,(i::is)) =
       
   629     (Pretty.writeln (Pretty.str
       
   630       ((Term.term_name t |> fun_name_to_time ctxt true)
       
   631         ^ " can be simplified because only the time-function component of parameter "
       
   632         ^ (Int.toString (i + 1)) ^ " is used. "));
       
   633         reportReductions (t,is))
       
   634       | reportReductions (_,[]) = ()
       
   635     val _ = simpables
       
   636       |> map reportReductions
       
   637 
       
   638     (* Register definitions for simplified function *)
       
   639     val (reds, ctxt) = define_simp simpables ctxt
       
   640 
       
   641     fun genRetype (Const (nm,T),is) =
       
   642     let
       
   643       val T_name = fun_name_to_time ctxt true nm |> split_name |> List.last
       
   644       val from = Free (T_name,change_typ T)
       
   645       val to = Free (T_name,change_typ' (not o contains is) T)
       
   646     in
       
   647       (from,to)
       
   648     end
       
   649       | genRetype _ = error "Internal error: invalid term"
       
   650     val retyping = map genRetype simpables
       
   651 
       
   652     fun replaceArgs (pT $ (eq $ l $ r)) =
       
   653     let
       
   654       val (t,params) = strip_comb l
       
   655       fun match (Const (f_nm,_),_) = 
       
   656             (fun_name_to_time ctxt true f_nm |> Long_Name.base_name) = (dest_Free t |> fst)
       
   657         | match _ = false
       
   658       val simps = List.find match simpables |> Option.valOf |> snd
       
   659 
       
   660       fun dest_Prod_snd (Free (nm, Type (_, [_, T2]))) =
       
   661             Free (fun_name_to_time ctxt false nm, T2)
       
   662         | dest_Prod_snd _ = error "Internal error: Argument is not a pair"
       
   663       fun rep _ [] = ([],[])
       
   664         | rep i (x::xs) =
       
   665       let 
       
   666         val (rs,args) = rep (i+1) xs
       
   667       in
       
   668         if contains simps i
       
   669           then (x::rs,dest_Prod_snd x::args)
       
   670           else (rs,x::args)
       
   671       end
       
   672       val (rs,params) = rep 0 params
       
   673       fun fFst _ = error "Internal error: Invalid term to simplify"
       
   674       fun fSnd (t as (Const _ $ f)) =
       
   675         (if contains rs f
       
   676           then dest_Prod_snd f
       
   677           else t)
       
   678         | fSnd t = t
       
   679     in
       
   680       (pT $ (eq
       
   681           $ (list_comb (t,params))
       
   682           $ (replaceFstSndFree ctxt term fFst fSnd r
       
   683               |> (fn t => replaceAll (map (fn t => (t,dest_Prod_snd t)) rs) [t])
       
   684               |> hd
       
   685             )
       
   686       ))
       
   687     end
       
   688     | replaceArgs _ = error "Internal error: Invalid term"
       
   689 
       
   690     (* Calculate reduced terms *)
       
   691     val T_terms_red = T_terms
       
   692       |> replaceAll retyping
       
   693       |> map replaceArgs
       
   694 
       
   695     val _ = print_lemma ctxt reds T_terms_red
       
   696     val _ = 
       
   697         Pretty.writeln (Pretty.str "If you do not want the simplified T function, use \"time_fun [no_simp]\"")
       
   698 in
       
   699   ctxt
       
   700 end
       
   701 
   502 (* Register timing function of a given function *)
   702 (* Register timing function of a given function *)
   503 fun reg_time_func (lthy: local_theory) (term: term list) (terms: term list) print =
   703 fun reg_time_func (lthy: local_theory) (term: term list) (terms: term list) print simp =
   504   let
   704   let
   505     val _ =
   705     val _ =
   506       case time_term lthy true (hd term)
   706       case time_term lthy true (hd term)
   507             handle (ERROR _) => NONE
   707             handle (ERROR _) => NONE
   508         of SOME _ => error ("Timing function already declared: " ^ (Term.term_name (hd term)))
   708         of SOME _ => error ("Timing function already declared: " ^ (Term.term_name (hd term)))
   509          | NONE => ()
   709          | NONE => ()
   510 
   710 
       
   711     (* Number of terms fixed by locale *)
       
   712     val fixedNum = term
       
   713       |> hd
       
   714       |> strip_comb |> snd
       
   715       |> length
       
   716 
   511     (* 1. Fix all terms *)
   717     (* 1. Fix all terms *)
   512     (* Exchange Var in types and terms to Free and check constraints *)
   718     (* Exchange Var in types and terms to Free and check constraints *)
   513     val terms = map
   719     val terms = map
   514       (map_aterms fixTerms
   720       (map_aterms freeTerms
   515         #> map_types (map_atyps fixTypes)
   721         #> map_types (map_atyps freeTypes)
   516         #> fixPartTerms lthy term)
   722         #> fixTerms lthy term fixedNum)
   517       terms
   723       terms
       
   724     val fixedFrees = (hd term) |> strip_comb |> snd |> take fixedNum 
       
   725     val fixedFreesNames = map (fst o dest_Free) fixedFrees
       
   726     val term = map (shortFunc fixedNum o fst o strip_comb) term
       
   727 
   518 
   728 
   519     (* 2. Find properties about the function *)
   729     (* 2. Find properties about the function *)
   520     (* 2.1 Check if function is recursive *)
   730     (* 2.1 Check if function is recursive *)
   521     val is_rec = is_rec lthy term terms
   731     val is_rec = is_rec lthy term terms
   522 
   732 
   523     (* 3. Convert every equation
   733     (* 3. Convert every equation
   524       - Change type of toplevel equation from _ \<Rightarrow> _ \<Rightarrow> bool to nat \<Rightarrow> nat \<Rightarrow> bool
   734       - Change type of toplevel equation from _ \<Rightarrow> _ \<Rightarrow> bool to nat \<Rightarrow> nat \<Rightarrow> bool
   525       - On left side change name of function to timing function
   735       - On left side change name of function to timing function
   526       - Convert right side of equation with conversion schema
   736       - Convert right side of equation with conversion schema
   527     *)
   737     *)
   528     fun convert used = map (convert_term lthy term is_rec used)
   738     fun fFst (t as (Const (_,T) $ Free (nm,_))) =
   529     fun repeat T_terms =
   739       (if contains fixedFreesNames nm
       
   740         then Free (nm,strip_type T |>> tl |> (op --->))
       
   741         else t)
       
   742       | fFst t = t
       
   743     fun fSnd (t as (Const (_,T) $ Free (nm,_))) =
       
   744       (if contains fixedFreesNames nm
       
   745         then Free (fun_name_to_time lthy false nm,strip_type T |>> tl |> (op --->))
       
   746         else t)
       
   747       | fSnd t = t
       
   748     val T_terms = map (convert_term lthy term is_rec) terms
       
   749       |> map (map_r (replaceFstSndFree lthy term fFst fSnd))
       
   750 
       
   751     val simpables = (if simp
       
   752       then find_simplifyble lthy term T_terms
       
   753       else map (K []) term)
       
   754       |> (fn s => ListPair.zip (term,s))
       
   755     (* Determine if something is simpable, if so rename everything *)
       
   756     val simpable = simpables |> map snd |> exists (not o null)
       
   757     (* Rename to secondary if simpable *)
       
   758     fun genRename (t,_) =
   530       let
   759       let
   531         val orig_used = find_used lthy term terms T_terms
   760         val old = fun_to_time lthy term t |> Option.valOf
   532         val T_terms' = convert orig_used terms
   761         val new = fun_to_time' lthy term true t |> Option.valOf
   533       in
   762       in
   534         if T_terms' <> T_terms then repeat T_terms' else T_terms'
   763         (old,new)
   535       end
   764       end
   536     val T_terms = repeat (convert (K true) terms)
   765     val can_T_terms = if simpable 
   537     val orig_used = find_used lthy term terms T_terms
   766       then replaceAll (map genRename simpables) T_terms
   538 
   767       else T_terms
   539     (* 4. Register function and prove termination *)
   768 
       
   769     (* 4. Register function and prove completeness *)
   540     val names = map Term.term_name term
   770     val names = map Term.term_name term
   541     val timing_names = map (fun_name_to_time lthy true) names
   771     val timing_names = map (fun_name_to_time' lthy true simpable) names
   542     val bindings = map (fn nm => (Binding.name nm, NONE, NoSyn)) timing_names
   772     val bindings = map (fn nm => (Binding.name nm, NONE, NoSyn)) timing_names
   543     fun pat_completeness_auto ctxt =
   773     fun pat_completeness_auto ctxt =
   544       Pat_Completeness.pat_completeness_tac ctxt 1 THEN auto_tac ctxt
   774       Pat_Completeness.pat_completeness_tac ctxt 1 THEN auto_tac ctxt
   545     val specs = map (fn eq => (((Binding.empty, []), eq), [], [])) T_terms
   775     val specs = map (fn eq => (((Binding.empty, []), eq), [], [])) can_T_terms
   546 
   776 
       
   777     (* Context for printing without showing question marks *)
       
   778     val print_ctxt = lthy
       
   779       |> Config.put show_question_marks false
       
   780       |> Config.put show_sorts false (* Change it for debugging *)
       
   781     val print_ctxt = List.foldl (fn (term, ctxt) => Variable.add_fixes_implicit term ctxt) print_ctxt (term @ T_terms)
       
   782     (* Print result if print *)
       
   783     val _ = if not print then () else
       
   784         let
       
   785           val nms = map (dest_Const_name) term
       
   786           val typs = map (dest_Const_type) term
       
   787         in
       
   788           print_timing' print_ctxt { names=nms, terms=terms, typs=typs }
       
   789             { names=timing_names, terms=can_T_terms, typs=map change_typ typs }
       
   790         end
       
   791     
   547     (* For partial functions sequential=true is needed in order to support them
   792     (* For partial functions sequential=true is needed in order to support them
   548        We need sequential=false to support the automatic proof of termination over dom
   793        We need sequential=false to support the automatic proof of termination over dom
   549     *)
   794     *)
   550     fun register seq =
   795     fun register seq =
   551       let
   796       let
   554           {sequential=seq, default=NONE, domintros=true, partials=false}
   799           {sequential=seq, default=NONE, domintros=true, partials=false}
   555       in
   800       in
   556         Function.add_function bindings specs fun_config pat_completeness_auto lthy
   801         Function.add_function bindings specs fun_config pat_completeness_auto lthy
   557       end
   802       end
   558 
   803 
   559     (* Context for printing without showing question marks *)
   804     val (info,ctxt) = 
   560     val print_ctxt = lthy
   805       register false
   561       |> Config.put show_question_marks false
   806         handle (ERROR _) =>
   562       |> Config.put show_sorts false (* Change it for debugging *)
   807           register true
   563     val print_ctxt = List.foldl (fn (term, ctxt) => Variable.add_fixes_implicit term ctxt) print_ctxt (term @ T_terms)
   808              | Match =>
   564     (* Print result if print *)
   809           register true
   565     val _ = if not print then () else
   810 
   566         let
   811     val ctxt = if simpable then calculateSimplifications ctxt T_terms term simpables else ctxt
   567           val nms = map (fst o dest_Const) term
       
   568           val used = map (used_for_const orig_used) term
       
   569           val typs = map (snd o dest_Const) term
       
   570         in
       
   571           print_timing' print_ctxt { names=nms, terms=terms, typs=typs }
       
   572             { names=timing_names, terms=T_terms, typs=map (fn (used, typ) => change_typ' used typ) (ListPair.zip (used, typs)) }
       
   573         end
       
   574 
       
   575   in
   812   in
   576     register false
   813     (info,ctxt)
   577       handle (ERROR _) =>
       
   578         register true
       
   579            | Match =>
       
   580         register true
       
   581   end
   814   end
   582 fun proove_termination (term: term list) terms print (T_info: Function.info, lthy: local_theory) =
   815 fun proove_termination (term: term list) terms (T_info: Function.info, lthy: local_theory) =
   583   let
   816   let
   584     (* Start proving the termination *)  
   817     (* Start proving the termination *)  
   585     val infos = SOME (map (Function.get_info lthy) term) handle Empty => NONE
   818     val infos = SOME (map (Function.get_info lthy) term) handle Empty => NONE
   586     val timing_names = map (fun_name_to_time lthy true o Term.term_name) term
   819     val timing_names = map (fun_name_to_time lthy true o Term.term_name) term
   587 
   820 
   596 
   829 
   597           fun args (a$(Var ((nm,_),T))) = args a |> (fn ar => (nm,T)::ar)
   830           fun args (a$(Var ((nm,_),T))) = args a |> (fn ar => (nm,T)::ar)
   598             | args (a$(Const (_,T))) = args a |> (fn ar => ("uu",T)::ar)
   831             | args (a$(Const (_,T))) = args a |> (fn ar => ("uu",T)::ar)
   599             | args _ = []
   832             | args _ = []
   600           val dom_vars =
   833           val dom_vars =
   601             terms |> hd |> get_l |> map_types (map_atyps fixTypes)
   834             terms |> hd |> get_l |> map_types (map_atyps freeTypes)
   602             |> args |> Variable.variant_frees lthy []
   835             |> args |> Variable.variant_frees lthy []
   603           val dom_args = 
   836           val dom_args = 
   604             List.foldl (fn (t,p) => HOLogic.mk_prod ((Free t),p)) (Free (hd dom_vars)) (tl dom_vars)
   837             List.foldl (fn (t,p) => HOLogic.mk_prod ((Free t),p)) (Free (hd dom_vars)) (tl dom_vars)
   605 
   838 
   606           val {inducts, ...} = case infos of SOME [i] => i | _ => error "Proof over dom failed as no induct rule was found"
   839           val {inducts, ...} = case infos of SOME [i] => i | _ => error "Proof over dom failed as no induct rule was found"
   617         in
   850         in
   618           (* Use lemma to prove termination *)
   851           (* Use lemma to prove termination *)
   619           Function.prove_termination NONE
   852           Function.prove_termination NONE
   620             (auto_tac simp_lthy) lthy
   853             (auto_tac simp_lthy) lthy
   621         end
   854         end
   622 
       
   623     (* Context for printing without showing question marks *)
       
   624     val print_ctxt = lthy'
       
   625       |> Config.put show_question_marks false
       
   626       |> Config.put show_sorts false (* Change it for debugging *)
       
   627     (* Print result if print *)
       
   628     val _ = if not print then () else
       
   629         let
       
   630           val nms = map (fst o dest_Const) term
       
   631           val typs = map (snd o dest_Const) term
       
   632         in
       
   633           print_timing' print_ctxt { names=nms, terms=terms, typs=typs } (info_pfunc time_info)
       
   634         end
       
   635   in
   855   in
   636     (time_info, lthy')
   856     (time_info, lthy')
   637   end
   857   end
   638 fun reg_and_proove_time_func (lthy: local_theory) (term: term list) (terms: term list) print =
   858 fun reg_and_proove_time_func (lthy: local_theory) (term: term list) (terms: term list) print simp =
   639   reg_time_func lthy term terms false
   859   reg_time_func lthy term terms print simp
   640   |> proove_termination term terms print
   860   |> proove_termination term terms
   641 
   861 
   642 fun fix_definition (Const ("Pure.eq", _) $ l $ r) = Const ("HOL.Trueprop", @{typ "bool \<Rightarrow> prop"})
   862 fun fix_definition (Const ("Pure.eq", _) $ l $ r) = Const ("HOL.Trueprop", @{typ "bool \<Rightarrow> prop"})
   643       $ (Const ("HOL.eq", @{typ "bool \<Rightarrow> bool \<Rightarrow> bool"}) $ l $ r)
   863       $ (Const ("HOL.eq", @{typ "bool \<Rightarrow> bool \<Rightarrow> bool"}) $ l $ r)
   644   | fix_definition t = t
   864   | fix_definition t = t
   645 fun check_definition [t] = [t]
   865 fun check_definition [t] = [t]
   674   val suffix = (if isTypeClass then detect_typ ctxt (hd fterms) else NONE)
   894   val suffix = (if isTypeClass then detect_typ ctxt (hd fterms) else NONE)
   675 in
   895 in
   676   (case suffix of NONE => I | SOME s => Config.put bsuffix ("_" ^ s)) ctxt
   896   (case suffix of NONE => I | SOME s => Config.put bsuffix ("_" ^ s)) ctxt
   677 end
   897 end
   678 
   898 
       
   899 fun check_opts [] = false
       
   900   | check_opts ["no_simp"] = true
       
   901   | check_opts (a::_) = error ("Option " ^ a ^ " is not defined")
       
   902 
   (* Implementation of the "time_fun" command.
      Validates the options (check_opts), reads the function terms, and adjusts
      the context suffix via set_suffix. It then builds and registers the timing
      function through reg_and_proove_time_func. The equations are taken from the
      user's "equations" facts when given, otherwise from get_terms on the first
      function term. The simp flag passed on is the negation of "no_simp".
      NOTE(review): the meaning of the literal `true` argument is defined by
      reg_and_proove_time_func, which is not visible here. *)
   fun reg_time_fun_cmd ((opts, funcs), thms) (ctxt: local_theory) =
     let
       val no_simp = check_opts opts
       val fterms = map (Syntax.read_term ctxt) funcs
       val ctxt = set_suffix fterms ctxt
       val equations =
         (case thms of
           NONE => get_terms ctxt (hd fterms)
         | SOME facts => List.map Thm.prop_of (Attrib.eval_thms ctxt facts))
       val simp = not no_simp
     in
       reg_and_proove_time_func ctxt fterms equations true simp |> snd
     end
   690 
   915 
   (* Implementation of the "time_function" command. It works like time_fun,
      but goes through reg_time_func, so the termination proof is left to the
      user.
      It validates the options, reads the function terms, and adjusts the context
      suffix. Equations come from the "equations" facts when given, otherwise from
      get_terms on the first function term. The result is the second component of
      reg_time_func's result. *)
   fun reg_time_function_cmd ((opts, funcs), thms) (ctxt: local_theory) =
     let
       val no_simp = check_opts opts
       val fterms = map (Syntax.read_term ctxt) funcs
       val ctxt = set_suffix fterms ctxt
       val equations =
         (case thms of
           NONE => get_terms ctxt (hd fterms)
         | SOME facts => List.map Thm.prop_of (Attrib.eval_thms ctxt facts))
       val simp = not no_simp
       val result = reg_time_func ctxt fterms equations true simp
     in
       #2 result
     end
   703 
   929 
   (* Implementation of the command for non-recursive definitions. It is the
      same pipeline as time_fun, except that equations obtained from get_terms
      are first checked by check_definition and then rewritten from "l == r" to
      "Trueprop (l = r)" by fix_definition. User-supplied "equations" facts are
      used as they are. *)
   fun reg_time_definition_cmd ((opts, funcs), thms) (ctxt: local_theory) =
     let
       val no_simp = check_opts opts
       val fterms = map (Syntax.read_term ctxt) funcs
       val ctxt = set_suffix fterms ctxt
       fun from_definition () =
         map fix_definition (check_definition (get_terms ctxt (hd fterms)))
       val equations =
         (case thms of
           NONE => from_definition ()
         | SOME facts => List.map Thm.prop_of (Attrib.eval_thms ctxt facts))
       val simp = not no_simp
     in
       reg_and_proove_time_func ctxt fterms equations true simp |> snd
     end
   715 
   942 
   (* Argument syntax of the time_* commands:
        [opt, ...]  term ...  equations thm ...
      Options are parsed as attributes, and only their names are kept (any
      attribute arguments are discarded). At least one term is required, and the
      "equations" clause is optional. The result has shape ((opts, funcs), thms),
      matching the reg_time_*_cmd entry points.
      The stray top-level "val _ = Toplevel.local_theory" that followed here was
      dead code (it only bound a function value to _) and has been removed. *)
   val parser = (Parse.opt_attribs >> map (fst o Token.name_of_src))
                -- Scan.repeat1 Parse.prop
                -- Scan.option (Parse.keyword_improper "equations" -- Parse.thms1 >> snd)
   (* Register the outer-syntax command "time_fun": its arguments are read with
      parser and passed to reg_time_fun_cmd. *)
   val _ =
     let
       val command = parser >> reg_time_fun_cmd
     in
       Outer_Syntax.local_theory @{command_keyword "time_fun"}
         "Defines runtime function of a function" command
     end
   721 
   950 
   722 val _ = Outer_Syntax.local_theory @{command_keyword "time_function"}
   951 val _ = Outer_Syntax.local_theory @{command_keyword "time_function"}