time_fun: lambdas and lets work now
authornipkow
Fri, 25 Oct 2024 16:57:17 +0200
changeset 81255 47530e9a7c33
parent 81253 bbed9f218158
child 81256 7e86118f4791
time_fun: lambdas and lets work now
src/HOL/Data_Structures/Define_Time_Function.ML
--- a/src/HOL/Data_Structures/Define_Time_Function.ML	Thu Oct 24 22:05:57 2024 +0200
+++ b/src/HOL/Data_Structures/Define_Time_Function.ML	Fri Oct 25 16:57:17 2024 +0200
@@ -223,8 +223,8 @@
 
 (* Walks over term and calls given converter *)
 (* get rid and use Term.strip_abs.eta especially for lambdas *)
-fun build_abs t ((nm,T)::abs) = build_abs (Abs (nm,T,t)) abs
-  | build_abs t [] = t
+fun list_abs ([], t) = t
+  | list_abs (a::abs,t) = list_abs (abs,t) |> absfree a
 fun walk ctxt (origin: term list) (conv as {ifc, casec, funcc, letc, ...} : 'a converter) (t as _ $ _) =
   let
     val (f, args) = strip_comb t
@@ -241,7 +241,7 @@
       else if is_let f then
       (case f of (Const (_,lT)) =>
          (case args of [exp, t] =>
-            let val (abs,t) = strip_abs t in letc wctxt lT exp abs t end
+            let val (abs,t) = Term.strip_abs_eta 1 t in letc wctxt lT exp abs t end
                      | _ => error "Partial applications not allowed (let)")
                | _ => error "Internal error: invalid let term")
       else funcc wctxt f args)
@@ -254,7 +254,7 @@
   Const (@{const_name "HOL.If"}, T) $ (#f wctxt cond) $ (#f wctxt tt) $ (#f wctxt tf)
 fun Icase (wctxt: term wctxt) t cs = list_comb (#f wctxt t,map (#f wctxt) cs)
 fun Ilet (wctxt: term wctxt) lT exp abs t =
-  Const (@{const_name "HOL.Let"},lT) $ (#f wctxt exp) $ build_abs (#f wctxt t) abs
+  Const (@{const_name "HOL.Let"},lT) $ (#f wctxt exp) $ list_abs (abs, #f wctxt t)
 
 (* 1. Fix all terms *)
 (* Exchange Var in types and terms to Free *)
@@ -285,6 +285,9 @@
 fun shortOriginFunc (term: term list) fixedNum (f as (c as Const (_,_), _))  =
   if contains' const_comp term c then shortApp fixedNum f else f
   | shortOriginFunc _ _ t = t
+val _ = strip_abs
+fun map_abs f (t as Abs _) = t |> strip_abs ||> f |> list_abs
+  | map_abs _ t = t
 fun fixTerms ctxt (term: term list) (fixedNum: int) (t as pT $ (eq $ l $ r)) =
   let
     val _ = check_args "args" (strip_comb (get_l t))
@@ -292,16 +295,12 @@
     val shortOriginFunc' = shortOriginFunc (term |> map (fst o strip_comb)) fixedNum
     val r' = walk ctxt term {
           funcc = (fn wctxt => fn t => fn args =>
-              (check_args "func" (t,args); (t, map (#f wctxt) args) |> shortOriginFunc' |> list_comb)),
-          constc = (fn _ => fn c => (case c of Abs _ => error "Lambdas not supported" | _ => c)),
+              (check_args "func" (t,args); (#f wctxt t, map (#f wctxt) args) |> shortOriginFunc' |> list_comb)),
+          constc = fn wctxt => map_abs (#f wctxt),
           ifc = Iif,
           casec = fixCasec,
           letc = (fn wctxt => fn expT => fn exp => fn abs => fn t =>
-              let
-                val f' = if length abs = 0 then
-                (case (#f wctxt) (t$exp) of t$_ => t | _ => error "Internal error: case could not be fixed (let)")
-                else (#f wctxt) t
-              in (Const (@{const_name "HOL.Let"},expT) $ ((#f wctxt) exp) $ build_abs f' abs) end)
+              (Const (@{const_name "HOL.Let"},expT) $ (#f wctxt exp) $ list_abs (abs, #f wctxt t)))
       } r
   in
     pT $ (eq $ l' $ r')
@@ -315,7 +314,9 @@
           funcc = (fn wctxt => fn t => fn args =>
             List.exists (fn term => (Const_name t) = (Const_name term)) term
              orelse List.foldr (or (#f wctxt)) false args),
-          constc = (K o K) false,
+          constc = fn wctxt => fn t => case t of
+                Abs _ => t |> strip_abs |> snd |> (#f wctxt)
+              | _ => false,
           ifc = (fn wctxt => fn _ => fn cond => fn tt => fn tf =>
             (#f wctxt) cond orelse (#f wctxt) tt orelse (#f wctxt) tf),
           casec = (fn wctxt => fn t => fn cs =>
@@ -367,12 +368,24 @@
   | funcc_conv_arg _ u (Free (nm, T as Type ("fun",_))) =
       if u then Free (nm, HOLogic.mk_prodT (T, change_typ T))
       else HOLogic.mk_snd (Free (nm,HOLogic.mk_prodT (T,change_typ T)))
-  | funcc_conv_arg wctxt true (f as Const (_,T as Type ("fun",_))) =
-  (Const (@{const_name "Product_Type.Pair"},
-      Type ("fun", [T,Type ("fun", [change_typ T, HOLogic.mk_prodT (T,change_typ T)])]))
-    $ f $ (Option.getOpt (fun_to_time (#ctxt wctxt) (#origins wctxt) f, build_zero T)))
+  | funcc_conv_arg wctxt true (f as Const (_,Type ("fun",_))) =
+      HOLogic.mk_prod (f, funcc_conv_arg wctxt false f)
   | funcc_conv_arg wctxt false (f as Const (_,T as Type ("fun",_))) =
       Option.getOpt (fun_to_time (#ctxt wctxt) (#origins wctxt) f, build_zero T)
+  | funcc_conv_arg wctxt false (f as Abs _) =
+       f
+       |> Term.strip_abs_eta ((length o fst o strip_type o type_of) f)
+       ||> #f wctxt ||> opt_term
+       |> list_abs
+  | funcc_conv_arg wctxt true (f as Abs _) =
+    let
+      val f' = f
+       |> Term.strip_abs_eta ((length o fst o strip_type o type_of) f)
+       ||> map_aterms funcc_use_origin
+       |> list_abs
+    in
+      HOLogic.mk_prod (f', funcc_conv_arg wctxt false f)
+    end
   | funcc_conv_arg _ _ t = t
 
 fun funcc_conv_args _ _ [] = []
@@ -387,9 +400,9 @@
     | get_T _ = error "Internal error: Forgotten type"
 in
   List.foldr (I #-> plus)
-  (case fun_to_time (#ctxt wctxt) (#origins wctxt) func
-    of SOME t => SOME (list_comb (t,funcc_conv_args wctxt (get_T t) args))
-    | NONE => NONE)
+  (case fun_to_time (#ctxt wctxt) (#origins wctxt) func (* add case for abs *)
+    of SOME t => SOME (list_comb (t, funcc_conv_args wctxt (get_T t) args))
+     | NONE => NONE)
   (map (#f wctxt) args)
 end
 
@@ -432,27 +445,31 @@
        (SOME ((Const (@{const_name "HOL.If"}, @{typ "bool \<Rightarrow> nat \<Rightarrow> nat \<Rightarrow> nat"})) $ rcond $ (opt_term tt) $ (opt_term ft))))
   end
 
-fun letc_change_typ (Type ("fun", [T1, Type ("fun", [T2, _])])) = (Type ("fun", [T1, Type ("fun", [change_typ T2, HOLogic.natT])]))
-  | letc_change_typ _ = error "Internal error: invalid let type"
-fun letc wctxt expT exp abs t =
-    plus (#f wctxt exp)
-    (if length abs = 0 (* In case of "length nms = 0" the expression got reducted
-                          Here we need Bound 0 to gain non-partial application *)
-    then (case #f wctxt (t $ Bound 0) of SOME (t' $ Bound 0) =>
-                                 (SOME (Const (@{const_name "HOL.Let"}, letc_change_typ expT) $ (map_aterms use_origin exp) $ t'))
-                                  (* Expression is not used and can therefore let be dropped *)
-                                | SOME t' => SOME t'
-                                | NONE => NONE)
-    else (case #f wctxt t of SOME t' =>
-      SOME (if Term.is_dependent t' then Const (@{const_name "HOL.Let"}, letc_change_typ expT) $ (map_aterms use_origin exp) $ build_abs t' abs
-                                    else Term.subst_bounds([exp],t'))
-    | NONE => NONE))
+fun letc_lambda wctxt T (t as Abs _) =
+      HOLogic.mk_prod (map_aterms use_origin t, 
+       Term.strip_abs_eta (strip_type T |> fst |> length) t ||> #f wctxt ||> opt_term |> list_abs)
+  | letc_lambda _ _ t = map_aterms use_origin t
+fun letc wctxt expT exp ([(nm,_)]) t =
+      plus (#f wctxt exp)
+      (case #f wctxt t of SOME t' =>
+        (if Term.used_free nm t'
+         then
+          let
+            val exp' = letc_lambda wctxt expT exp
+            val t' = list_abs ([(nm,type_of exp')], t')
+          in
+            Const (@{const_name "HOL.Let"}, [type_of exp', type_of t'] ---> HOLogic.natT) $ exp' $ t'
+          end
+         else t') |> SOME
+      | NONE => NONE)
+  | letc _ _ _ _ _ = error "Unknown let state"
+
+fun constc _ (Const ("HOL.undefined", _)) = SOME (Const ("HOL.undefined", @{typ "nat"}))
+  | constc _ _ = NONE
 
 (* The converter for timing functions given to the walker *)
 val converter : term option converter = {
-        constc = fn _ => fn t =>
-          (case t of Const ("HOL.undefined", _) => SOME (Const ("HOL.undefined", @{typ "nat"}))
-                   | _ => NONE),
+        constc = constc,
         funcc = funcc,
         ifc = ifc,
         casec = casec,
@@ -503,6 +520,10 @@
     [Metis_Tactic.metis_tac [] ATP_Problem_Generate.combsN ctxt domintros]) i)))
 
 
+fun fix_definition (Const ("Pure.eq", _) $ l $ r) = HOLogic.mk_Trueprop (HOLogic.mk_eq (l,r))
+  | fix_definition t = t
+fun check_definition [t] = [t]
+  | check_definition _ = error "Only a single definition is allowed"
 fun get_terms theory (term: term) =
 let
   val equations = Spec_Rules.retrieve theory term
@@ -511,6 +532,7 @@
    handle Empty => error "Function or terms of function not found"
 in
   equations
+    |> map (map fix_definition)
     |> filter (List.exists
         (fn t => typ_comp (t |> get_l |> strip_comb |> fst |> dest_Const |> snd) (term |> strip_comb |> fst |> dest_Const |> snd)))
     |> hd
@@ -725,6 +747,14 @@
     val fixedFreesNames = map (fst o dest_Free) fixedFrees
     val term = map (shortFunc fixedNum o fst o strip_comb) term
 
+    fun correctTerm term =
+    let
+      val get_f = fst o strip_comb o get_l
+    in
+      List.find (fn t => (dest_Const_name o get_f) t = dest_Const_name term) terms
+        |> Option.valOf |> get_f
+    end
+    val term = map correctTerm term
 
     (* 2. Find properties about the function *)
     (* 2.1 Check if function is recursive *)
@@ -859,12 +889,6 @@
   reg_time_func lthy term terms print simp
   |> proove_termination term terms
 
-fun fix_definition (Const ("Pure.eq", _) $ l $ r) = Const ("HOL.Trueprop", @{typ "bool \<Rightarrow> prop"})
-      $ (Const ("HOL.eq", @{typ "bool \<Rightarrow> bool \<Rightarrow> bool"}) $ l $ r)
-  | fix_definition t = t
-fun check_definition [t] = [t]
-  | check_definition _ = error "Only a single definition is allowed"
-
 fun isTypeClass' (Const (nm,_)) =
   (case split_name nm |> rev
     of (_::nm::_) => String.isSuffix "_class" nm
@@ -934,7 +958,7 @@
   val fterms = map (Syntax.read_term ctxt) funcs
   val ctxt = set_suffix fterms ctxt
   val (_, ctxt') = reg_and_proove_time_func ctxt fterms
-    (case thms of NONE => get_terms ctxt (hd fterms) |> check_definition |> map fix_definition
+    (case thms of NONE => get_terms ctxt (hd fterms) |> check_definition
                 | SOME thms => thms |> Attrib.eval_thms ctxt |> List.map Thm.prop_of)
     true (not no_simp)
 in ctxt'