merged
authornipkow
Mon, 29 Jan 2024 21:18:11 +0100
changeset 79543 bbed18f7a522
parent 79541 4f40225936d1 (current diff)
parent 79542 b941924a407d (diff)
child 79544 50ee2921da94
child 79545 b8a6b2ec85a2
merged
--- a/src/HOL/Data_Structures/Define_Time_Function.ML	Mon Jan 29 19:35:07 2024 +0000
+++ b/src/HOL/Data_Structures/Define_Time_Function.ML	Mon Jan 29 21:18:11 2024 +0100
@@ -64,7 +64,6 @@
   val header_content =
      List.concat (prepHeadCont (hd names,hd typs) :: map ((fn l => Pretty.str "\nand " :: l) o prepHeadCont) (ListPair.zip (tl names, tl typs)));
   val header_end = Pretty.str " where\n  ";
-  val _ = List.map
   val header = [header_beg] @ header_content @ [header_end];
   fun separate sep prts =
     flat (Library.separate [Pretty.str sep] (map single prts));
@@ -199,19 +198,21 @@
   | fixTerms t = t
 fun fixTypes (TVar ((t, _), T)) = TFree (t, T)
   | fixTypes t = t
-val _ = Variable.variant_fixes
+
+fun noFun (Type ("fun",_)) = error "Functions in datatypes are not allowed in case constructions"
+  | noFun _ = ()
 fun casecBuildBounds n t = if n > 0 then casecBuildBounds (n-1) (t $ (Bound (n-1))) else t
-fun casecAbs ctxt f n (Type (_,[_,Tr])) (Abs (v,Ta,t)) = Abs (v,Ta,casecAbs ctxt f n Tr t)
+fun casecAbs ctxt f n (Type (_,[T,Tr])) (Abs (v,Ta,t)) = (noFun T; Abs (v,Ta,casecAbs ctxt f n Tr t))
   | casecAbs ctxt f n (Type (Tn,[T,Tr])) t =
-    (case Variable.variant_fixes ["x"] ctxt of ([v],ctxt) =>
+    (noFun T; case Variable.variant_fixes ["x"] ctxt of ([v],ctxt) =>
     (if Tn = "fun" then Abs(v,T,casecAbs ctxt f (n + 1) Tr t) else f t)
     | _ => error "Internal error: could not fix variable")
   | casecAbs _ f n _ t = f (casecBuildBounds n t)
 fun fixCasecCases _ _ _ [t] = [t]
   | fixCasecCases ctxt f (Type (_,[T,Tr])) (t::ts) = casecAbs ctxt f 0 T t :: fixCasecCases ctxt f Tr ts
   | fixCasecCases _ _ _ _ = error "Internal error: invalid case types/terms"
-fun fixCasec ctxt _ f (t as Const (n,T)) args =
-      (check_args "cases" (Syntax.read_term ctxt n,args); build_func (t,fixCasecCases ctxt f T args))
+fun fixCasec ctxt _ f (t as Const (_,T)) args =
+      (check_args "cases" (t,args); build_func (t,fixCasecCases ctxt f T args))
   | fixCasec _ _ _ _ _ = error "Internal error: invalid case term"
 
 fun fixPartTerms ctxt (term: term list) t =
@@ -230,8 +231,7 @@
                 val f' = if length nms = 0 then
                 (case f (t$exp) of t$_ => t | _ => error "Internal error: case could not be fixed (let)")
                 else f t
-              in
-              (Const (Let_name,expT) $ (f exp) $ build_abs f' nms Ts) end)
+              in (Const (Let_name,expT) $ (f exp) $ build_abs f' nms Ts) end)
       }) t
   end
 
@@ -340,12 +340,16 @@
   | letc_change_typ _ = error "Internal error: invalid let type"
 fun letc _ _ f expT exp nms Ts t =
     plus (f exp)
-    (if length nms = 0 (* In case of "length nms = 0" a case expression is used to split up a type *)
-    (* Add (Bound 0) to receive a fully evaluated function, which can be handled by casec
-       Strip of (Bound 0) after conversion *)
-    then (case f (t $ Bound 0) of SOME (t' $ Bound 0) => SOME (Const (Let_name, letc_change_typ expT) $ exp $ t')
-                                | _ => NONE)
-    else (case f t of SOME t' => SOME (Const (Let_name, letc_change_typ expT) $ exp $ build_abs t' nms Ts)
+    (if length nms = 0 (* In case of "length nms = 0" the expression got reducted
+                          Here we need Bound 0 to gain non-partial application *)
+    then (case f (t $ Bound 0) of SOME (t' $ Bound 0) =>
+                                 (SOME (Const (Let_name, letc_change_typ expT) $ exp $ t'))
+                                  (* Expression is not used and can therefore let be dropped *)
+                                | SOME t' => SOME t'
+                                | NONE => NONE)
+    else (case f t of SOME t' =>
+      SOME (if Term.is_dependent t' then Const (Let_name, letc_change_typ expT) $ exp $ build_abs t' nms Ts
+                                    else Term.subst_bounds([exp],t'))
                                | NONE => NONE))
 
 (* The converter for timing functions given to the walker *)
@@ -399,7 +403,11 @@
 
     (* 1. Fix all terms *)
     (* Exchange Var in types and terms to Free and check constraints *)
-    val terms = map (map_aterms fixTerms #> map_types (map_atyps fixTypes) #> fixPartTerms lthy term) terms
+    val terms = map 
+      (map_aterms fixTerms
+        #> map_types (map_atyps fixTypes)
+        #> fixPartTerms lthy term)
+      terms
 
     (* 2. Check if function is recursive *)
     val is_rec = is_rec lthy term terms
@@ -435,6 +443,7 @@
     val print_ctxt = lthy
       |> Config.put show_question_marks false
       |> Config.put show_sorts false (* Change it for debugging *)
+    val print_ctxt =  List.foldl (fn (term, ctxt) => Variable.add_fixes_implicit term ctxt) print_ctxt (term@timing_terms)
     (* Print result if print *)
     val _ = if not print then () else
         let