fixed lambda-lifting: shift indices of bound variables correctly (after locking the required bound variables) and apply bound variables to the new function symbol in the right order;
authorboehmes
Fri, 17 Dec 2010 14:36:33 +0100
changeset 41232 4ea9f2a8c093
parent 41223 cf5e008d38c4
child 41233 d4cb4d0c14a7
fixed lambda-lifting: shift indices of bound variables correctly (after locking the required bound variables) and apply bound variables to the new function symbol in the right order; fixed introduction of explicit application: use explicit application for every additional argument (grouping of arguments caused confusion when translating into the intermediate format)
src/HOL/Tools/SMT/smt_translate.ML
--- a/src/HOL/Tools/SMT/smt_translate.ML	Fri Dec 17 08:37:35 2010 +0100
+++ b/src/HOL/Tools/SMT/smt_translate.ML	Fri Dec 17 14:36:33 2010 +0100
@@ -249,40 +249,38 @@
       fun mk_all T t = HOLogic.all_const T $ Abs (Name.uu, T, t)
     in fold mk_all Ts (@{const SMT.trigger} $ trigger $ eq) end
 
+  fun mk_abs Ts = fold (fn T => fn t => Abs (Name.uu, T, t)) Ts
+
+  fun dest_abs Ts (Abs (_, T, t)) = dest_abs (T :: Ts) t
+    | dest_abs Ts t = (Ts, t)
+
   fun replace_lambda Us Ts t (cx as (defs, ctxt)) =
     let
+      val t1 = mk_abs Us t
+      val bs = sort int_ord (Term.add_loose_bnos (t1, 0, []))
+      fun rep i k = if member (op =) bs i then (Bound k, k+1) else (Bound i, k)
+      val (rs, _) = fold_map rep (0 upto length Ts - 1) 0
+      val t2 = Term.subst_bounds (rs, t1)
+      val Ts' = map (nth Ts) bs 
+      val (_, t3) = dest_abs [] t2
+      val t4 = mk_abs Ts' t2
+
       val T = Term.fastype_of1 (Us @ Ts, t)
-      val lev = length Us
-      val bs = sort int_ord (Term.add_loose_bnos (t, lev, []))
-      val bss = map_index (fn (i, j) => (j + lev, i + lev)) bs
-      val norm = perhaps (AList.lookup (op =) bss)
-      val t' = Term.map_aterms (fn Bound i => Bound (norm i) | t => t) t
-      val Ts' = map (nth Ts) bs
-
-      fun mk_abs U u = Abs (Name.uu, U, u)
-      val abs_rhs = fold mk_abs Ts' (fold mk_abs Us t')
-
-      fun app f = Term.list_comb (f, map Bound bs)
+      fun app f = Term.list_comb (f, map Bound (rev bs))
     in
-      (case Termtab.lookup defs abs_rhs of
+      (case Termtab.lookup defs t4 of
         SOME (f, _) => (app f, cx)
       | NONE =>
           let
             val (n, ctxt') =
               yield_singleton Variable.variant_fixes Name.uu ctxt
-            val f = Free (n, rev Ts' ---> (rev Us ---> T))
-            fun mk_bapp i t = t $ Bound i
-            val lhs =
-              f
-              |> fold_rev (mk_bapp o snd) bss
-              |> fold_rev mk_bapp (0 upto (length Us - 1))
-            val def = mk_def (Us @ Ts') T lhs t'
-          in (app f, (Termtab.update (abs_rhs, (f, def)) defs, ctxt')) end)
+            val (is, UTs) = split_list (map_index I (Us @ Ts'))
+            val f = Free (n, rev UTs ---> T)
+            val lhs = Term.list_comb (f, map Bound (rev is))
+            val def = mk_def UTs (Term.fastype_of1 (Us @ Ts, t)) lhs t3
+          in (app f, (Termtab.update (t4, (f, def)) defs, ctxt')) end)
     end
 
-  fun dest_abs Ts (Abs (_, T, t)) = dest_abs (T :: Ts) t
-    | dest_abs Ts t = (Ts, t)
-
   fun traverse Ts t =
     (case t of
       (q as Const (@{const_name All}, _)) $ Abs a =>
@@ -317,57 +315,37 @@
     Make application explicit for functions with varying number of arguments.
   *)
 
-  fun add t ts =
-    Termtab.map_default (t, []) (Ord_List.insert int_ord (length ts))
-
-  fun collect t =
-    (case Term.strip_comb t of
-      (u as Const _, ts) => add u ts #> fold collect ts
-    | (u as Free _, ts) => add u ts #> fold collect ts
-    | (Abs (_, _, u), ts) => collect u #> fold collect ts
-    | (_, ts) => fold collect ts)
-
-  fun app ts (t, T) =
-    let val f = Const (@{const_name SMT.fun_app}, T --> T)
-    in (Term.list_comb (f $ t, ts), snd (U.dest_funT (length ts) T)) end 
+  fun add t i = Termtab.map_default (t, i) (Integer.min i)
 
-  fun appl _ _ [] = fst
-    | appl _ [] ts = fst o app ts
-    | appl i (k :: ks) ts =
-        let val (ts1, ts2) = chop (k - i) ts
-        in appl k ks ts2 o app ts1 end
-
-  fun appl0 [_] ts (t, _) = Term.list_comb (t, ts)
-    | appl0 (0 :: ks) ts tT = appl 0 ks ts tT
-    | appl0 ks ts tT = appl 0 ks ts tT
+  fun min_arities t =
+    (case Term.strip_comb t of
+      (u as Const _, ts) => add u (length ts) #> fold min_arities ts
+    | (u as Free _, ts) => add u (length ts) #> fold min_arities ts
+    | (Abs (_, _, u), ts) => min_arities u #> fold min_arities ts
+    | (_, ts) => fold min_arities ts)
 
-  fun apply terms T t ts = appl0 (Termtab.lookup_list terms t) ts (t, T)
+  fun app u (t, T) =
+    (Const (@{const_name SMT.fun_app}, T --> T) $ t $ u, Term.range_type T)
 
-  fun get_arities i t =
-    (case Term.strip_comb t of
-      (Bound j, ts) =>
-        (if i = j then Ord_List.insert int_ord (length ts) else I) #>
-        fold (get_arities i) ts
-    | (Abs (_, _, u), ts) => get_arities (i+1) u #> fold (get_arities i) ts
-    | (_, ts) => fold (get_arities i) ts)
+  fun apply i t T ts =
+    let val (ts1, ts2) = chop i ts
+    in fst (fold app ts2 (Term.list_comb (t, ts1), snd (U.dest_funT i T))) end
 in
 
 fun intro_explicit_application ts =
   let
-    val terms = fold collect ts Termtab.empty
+    val arities = fold min_arities ts Termtab.empty
+    fun apply' t = apply (the (Termtab.lookup arities t)) t
 
-    fun traverse (env as (arities, Ts)) t =
+    fun traverse Ts t =
       (case Term.strip_comb t of
-        (u as Const (_, T), ts) => apply terms T u (map (traverse env) ts)
-      | (u as Free (_, T), ts) => apply terms T u (map (traverse env) ts)
-      | (u as Bound i, ts) =>
-          appl0 (nth arities i) (map (traverse env) ts) (u, nth Ts i)
-      | (Abs (n, T, u), ts) =>
-          let val env' = (get_arities 0 u [0] :: arities, T :: Ts)
-          in traverses env (Abs (n, T, traverse env' u)) ts end
-      | (u, ts) => traverses env u ts)
-    and traverses env t ts = Term.list_comb (t, map (traverse env) ts)
-  in map (traverse ([], [])) ts end
+        (u as Const (_, T), ts) => apply' u T (map (traverse Ts) ts)
+      | (u as Free (_, T), ts) => apply' u T (map (traverse Ts) ts)
+      | (u as Bound i, ts) => apply 0 u (nth Ts i) (map (traverse Ts) ts)
+      | (Abs (n, T, u), ts) => traverses Ts (Abs (n, T, traverse (T::Ts) u)) ts
+      | (u, ts) => traverses Ts u ts)
+    and traverses Ts t ts = Term.list_comb (t, map (traverse Ts) ts)
+  in map (traverse []) ts end
 
 val fun_app_eq = mk_meta_eq @{thm SMT.fun_app_def}
 
@@ -451,16 +429,17 @@
     and in_weight ((c as @{const SMT.weight}) $ w $ t) = c $ w $ in_form t
       | in_weight t = in_form t 
 
-    and in_pat (Const (c as (@{const_name pat}, _)) $ t) =
+    and in_pat (Const (c as (@{const_name SMT.pat}, _)) $ t) =
           Const (func 1 c) $ in_term t
-      | in_pat (Const (c as (@{const_name nopat}, _)) $ t) =
+      | in_pat (Const (c as (@{const_name SMT.nopat}, _)) $ t) =
           Const (func 1 c) $ in_term t
       | in_pat t = raise TERM ("bad pattern", [t])
 
     and in_pats ps =
-      in_list @{typ "pattern list"} (in_list @{typ pattern} in_pat) ps
+      in_list @{typ "SMT.pattern list"} (in_list @{typ SMT.pattern} in_pat) ps
 
-    and in_trig ((c as @{const trigger}) $ p $ t) = c $ in_pats p $ in_weight t
+    and in_trig ((c as @{const SMT.trigger}) $ p $ t) =
+          c $ in_pats p $ in_weight t
       | in_trig t = in_weight t
 
     and in_form t =
@@ -506,8 +485,8 @@
       (SOME (snd (HOLogic.dest_number w)), t)
   | dest_weight t = (NONE, t)
 
-fun dest_pat (Const (@{const_name pat}, _) $ t) = (t, true)
-  | dest_pat (Const (@{const_name nopat}, _) $ t) = (t, false)
+fun dest_pat (Const (@{const_name SMT.pat}, _) $ t) = (t, true)
+  | dest_pat (Const (@{const_name SMT.nopat}, _) $ t) = (t, false)
   | dest_pat t = raise TERM ("bad pattern", [t])
 
 fun dest_pats [] = I
@@ -517,7 +496,7 @@
       | (ps, [false]) => cons (SNoPat ps)
       | _ => raise TERM ("bad multi-pattern", ts))
 
-fun dest_trigger (@{const trigger} $ tl $ t) =
+fun dest_trigger (@{const SMT.trigger} $ tl $ t) =
       (rev (fold (dest_pats o HOLogic.dest_list) (HOLogic.dest_list tl) []), t)
   | dest_trigger t = ([], t)
 
@@ -593,15 +572,19 @@
 
 fun add_config (cs, cfg) = Configs.map (U.dict_update (cs, cfg))
 
+fun get_config ctxt = 
+  let val cs = SMT_Config.solver_class_of ctxt
+  in
+    (case U.dict_get (Configs.get (Context.Proof ctxt)) cs of
+      SOME cfg => cfg ctxt
+    | NONE => error ("SMT: no translation configuration found " ^
+        "for solver class " ^ quote (U.string_of_class cs)))
+  end
+
 fun translate ctxt comments ithms =
   let
-    val cs = SMT_Config.solver_class_of ctxt
-    val {prefixes, is_fol, header, has_datatypes, serialize} =
-      (case U.dict_get (Configs.get (Context.Proof ctxt)) cs of
-        SOME cfg => cfg ctxt
-      | NONE => error ("SMT: no translation configuration found " ^
-          "for solver class " ^ quote (U.string_of_class cs)))
-      
+    val {prefixes, is_fol, header, has_datatypes, serialize} = get_config ctxt
+
     val with_datatypes =
       has_datatypes andalso Config.get ctxt SMT_Config.datatypes