src/HOL/Tools/SMT/smt_translate.ML
changeset 41232 4ea9f2a8c093
parent 41198 aa627a799e8e
child 41250 41f86829e22f
     1.1 --- a/src/HOL/Tools/SMT/smt_translate.ML	Fri Dec 17 08:37:35 2010 +0100
     1.2 +++ b/src/HOL/Tools/SMT/smt_translate.ML	Fri Dec 17 14:36:33 2010 +0100
     1.3 @@ -249,40 +249,38 @@
     1.4        fun mk_all T t = HOLogic.all_const T $ Abs (Name.uu, T, t)
     1.5      in fold mk_all Ts (@{const SMT.trigger} $ trigger $ eq) end
     1.6  
     1.7 +  fun mk_abs Ts = fold (fn T => fn t => Abs (Name.uu, T, t)) Ts
     1.8 +
     1.9 +  fun dest_abs Ts (Abs (_, T, t)) = dest_abs (T :: Ts) t
    1.10 +    | dest_abs Ts t = (Ts, t)
    1.11 +
    1.12    fun replace_lambda Us Ts t (cx as (defs, ctxt)) =
    1.13      let
    1.14 +      val t1 = mk_abs Us t
    1.15 +      val bs = sort int_ord (Term.add_loose_bnos (t1, 0, []))
    1.16 +      fun rep i k = if member (op =) bs i then (Bound k, k+1) else (Bound i, k)
    1.17 +      val (rs, _) = fold_map rep (0 upto length Ts - 1) 0
    1.18 +      val t2 = Term.subst_bounds (rs, t1)
    1.19 +      val Ts' = map (nth Ts) bs 
    1.20 +      val (_, t3) = dest_abs [] t2
    1.21 +      val t4 = mk_abs Ts' t2
    1.22 +
    1.23        val T = Term.fastype_of1 (Us @ Ts, t)
    1.24 -      val lev = length Us
    1.25 -      val bs = sort int_ord (Term.add_loose_bnos (t, lev, []))
    1.26 -      val bss = map_index (fn (i, j) => (j + lev, i + lev)) bs
    1.27 -      val norm = perhaps (AList.lookup (op =) bss)
    1.28 -      val t' = Term.map_aterms (fn Bound i => Bound (norm i) | t => t) t
    1.29 -      val Ts' = map (nth Ts) bs
    1.30 -
    1.31 -      fun mk_abs U u = Abs (Name.uu, U, u)
    1.32 -      val abs_rhs = fold mk_abs Ts' (fold mk_abs Us t')
    1.33 -
    1.34 -      fun app f = Term.list_comb (f, map Bound bs)
    1.35 +      fun app f = Term.list_comb (f, map Bound (rev bs))
    1.36      in
    1.37 -      (case Termtab.lookup defs abs_rhs of
    1.38 +      (case Termtab.lookup defs t4 of
    1.39          SOME (f, _) => (app f, cx)
    1.40        | NONE =>
    1.41            let
    1.42              val (n, ctxt') =
    1.43                yield_singleton Variable.variant_fixes Name.uu ctxt
    1.44 -            val f = Free (n, rev Ts' ---> (rev Us ---> T))
    1.45 -            fun mk_bapp i t = t $ Bound i
    1.46 -            val lhs =
    1.47 -              f
    1.48 -              |> fold_rev (mk_bapp o snd) bss
    1.49 -              |> fold_rev mk_bapp (0 upto (length Us - 1))
    1.50 -            val def = mk_def (Us @ Ts') T lhs t'
    1.51 -          in (app f, (Termtab.update (abs_rhs, (f, def)) defs, ctxt')) end)
    1.52 +            val (is, UTs) = split_list (map_index I (Us @ Ts'))
    1.53 +            val f = Free (n, rev UTs ---> T)
    1.54 +            val lhs = Term.list_comb (f, map Bound (rev is))
    1.55 +            val def = mk_def UTs (Term.fastype_of1 (Us @ Ts, t)) lhs t3
    1.56 +          in (app f, (Termtab.update (t4, (f, def)) defs, ctxt')) end)
    1.57      end
    1.58  
    1.59 -  fun dest_abs Ts (Abs (_, T, t)) = dest_abs (T :: Ts) t
    1.60 -    | dest_abs Ts t = (Ts, t)
    1.61 -
    1.62    fun traverse Ts t =
    1.63      (case t of
    1.64        (q as Const (@{const_name All}, _)) $ Abs a =>
    1.65 @@ -317,57 +315,37 @@
    1.66      Make application explicit for functions with varying number of arguments.
    1.67    *)
    1.68  
    1.69 -  fun add t ts =
    1.70 -    Termtab.map_default (t, []) (Ord_List.insert int_ord (length ts))
    1.71 -
    1.72 -  fun collect t =
    1.73 -    (case Term.strip_comb t of
    1.74 -      (u as Const _, ts) => add u ts #> fold collect ts
    1.75 -    | (u as Free _, ts) => add u ts #> fold collect ts
    1.76 -    | (Abs (_, _, u), ts) => collect u #> fold collect ts
    1.77 -    | (_, ts) => fold collect ts)
    1.78 -
    1.79 -  fun app ts (t, T) =
    1.80 -    let val f = Const (@{const_name SMT.fun_app}, T --> T)
    1.81 -    in (Term.list_comb (f $ t, ts), snd (U.dest_funT (length ts) T)) end 
    1.82 +  fun add t i = Termtab.map_default (t, i) (Integer.min i)
    1.83  
    1.84 -  fun appl _ _ [] = fst
    1.85 -    | appl _ [] ts = fst o app ts
    1.86 -    | appl i (k :: ks) ts =
    1.87 -        let val (ts1, ts2) = chop (k - i) ts
    1.88 -        in appl k ks ts2 o app ts1 end
    1.89 -
    1.90 -  fun appl0 [_] ts (t, _) = Term.list_comb (t, ts)
    1.91 -    | appl0 (0 :: ks) ts tT = appl 0 ks ts tT
    1.92 -    | appl0 ks ts tT = appl 0 ks ts tT
    1.93 +  fun min_arities t =
    1.94 +    (case Term.strip_comb t of
    1.95 +      (u as Const _, ts) => add u (length ts) #> fold min_arities ts
    1.96 +    | (u as Free _, ts) => add u (length ts) #> fold min_arities ts
    1.97 +    | (Abs (_, _, u), ts) => min_arities u #> fold min_arities ts
    1.98 +    | (_, ts) => fold min_arities ts)
    1.99  
   1.100 -  fun apply terms T t ts = appl0 (Termtab.lookup_list terms t) ts (t, T)
   1.101 +  fun app u (t, T) =
   1.102 +    (Const (@{const_name SMT.fun_app}, T --> T) $ t $ u, Term.range_type T)
   1.103  
   1.104 -  fun get_arities i t =
   1.105 -    (case Term.strip_comb t of
   1.106 -      (Bound j, ts) =>
   1.107 -        (if i = j then Ord_List.insert int_ord (length ts) else I) #>
   1.108 -        fold (get_arities i) ts
   1.109 -    | (Abs (_, _, u), ts) => get_arities (i+1) u #> fold (get_arities i) ts
   1.110 -    | (_, ts) => fold (get_arities i) ts)
   1.111 +  fun apply i t T ts =
   1.112 +    let val (ts1, ts2) = chop i ts
   1.113 +    in fst (fold app ts2 (Term.list_comb (t, ts1), snd (U.dest_funT i T))) end
   1.114  in
   1.115  
   1.116  fun intro_explicit_application ts =
   1.117    let
   1.118 -    val terms = fold collect ts Termtab.empty
   1.119 +    val arities = fold min_arities ts Termtab.empty
   1.120 +    fun apply' t = apply (the (Termtab.lookup arities t)) t
   1.121  
   1.122 -    fun traverse (env as (arities, Ts)) t =
   1.123 +    fun traverse Ts t =
   1.124        (case Term.strip_comb t of
   1.125 -        (u as Const (_, T), ts) => apply terms T u (map (traverse env) ts)
   1.126 -      | (u as Free (_, T), ts) => apply terms T u (map (traverse env) ts)
   1.127 -      | (u as Bound i, ts) =>
   1.128 -          appl0 (nth arities i) (map (traverse env) ts) (u, nth Ts i)
   1.129 -      | (Abs (n, T, u), ts) =>
   1.130 -          let val env' = (get_arities 0 u [0] :: arities, T :: Ts)
   1.131 -          in traverses env (Abs (n, T, traverse env' u)) ts end
   1.132 -      | (u, ts) => traverses env u ts)
   1.133 -    and traverses env t ts = Term.list_comb (t, map (traverse env) ts)
   1.134 -  in map (traverse ([], [])) ts end
   1.135 +        (u as Const (_, T), ts) => apply' u T (map (traverse Ts) ts)
   1.136 +      | (u as Free (_, T), ts) => apply' u T (map (traverse Ts) ts)
   1.137 +      | (u as Bound i, ts) => apply 0 u (nth Ts i) (map (traverse Ts) ts)
   1.138 +      | (Abs (n, T, u), ts) => traverses Ts (Abs (n, T, traverse (T::Ts) u)) ts
   1.139 +      | (u, ts) => traverses Ts u ts)
   1.140 +    and traverses Ts t ts = Term.list_comb (t, map (traverse Ts) ts)
   1.141 +  in map (traverse []) ts end
   1.142  
   1.143  val fun_app_eq = mk_meta_eq @{thm SMT.fun_app_def}
   1.144  
   1.145 @@ -451,16 +429,17 @@
   1.146      and in_weight ((c as @{const SMT.weight}) $ w $ t) = c $ w $ in_form t
   1.147        | in_weight t = in_form t 
   1.148  
   1.149 -    and in_pat (Const (c as (@{const_name pat}, _)) $ t) =
   1.150 +    and in_pat (Const (c as (@{const_name SMT.pat}, _)) $ t) =
   1.151            Const (func 1 c) $ in_term t
   1.152 -      | in_pat (Const (c as (@{const_name nopat}, _)) $ t) =
   1.153 +      | in_pat (Const (c as (@{const_name SMT.nopat}, _)) $ t) =
   1.154            Const (func 1 c) $ in_term t
   1.155        | in_pat t = raise TERM ("bad pattern", [t])
   1.156  
   1.157      and in_pats ps =
   1.158 -      in_list @{typ "pattern list"} (in_list @{typ pattern} in_pat) ps
   1.159 +      in_list @{typ "SMT.pattern list"} (in_list @{typ SMT.pattern} in_pat) ps
   1.160  
   1.161 -    and in_trig ((c as @{const trigger}) $ p $ t) = c $ in_pats p $ in_weight t
   1.162 +    and in_trig ((c as @{const SMT.trigger}) $ p $ t) =
   1.163 +          c $ in_pats p $ in_weight t
   1.164        | in_trig t = in_weight t
   1.165  
   1.166      and in_form t =
   1.167 @@ -506,8 +485,8 @@
   1.168        (SOME (snd (HOLogic.dest_number w)), t)
   1.169    | dest_weight t = (NONE, t)
   1.170  
   1.171 -fun dest_pat (Const (@{const_name pat}, _) $ t) = (t, true)
   1.172 -  | dest_pat (Const (@{const_name nopat}, _) $ t) = (t, false)
   1.173 +fun dest_pat (Const (@{const_name SMT.pat}, _) $ t) = (t, true)
   1.174 +  | dest_pat (Const (@{const_name SMT.nopat}, _) $ t) = (t, false)
   1.175    | dest_pat t = raise TERM ("bad pattern", [t])
   1.176  
   1.177  fun dest_pats [] = I
   1.178 @@ -517,7 +496,7 @@
   1.179        | (ps, [false]) => cons (SNoPat ps)
   1.180        | _ => raise TERM ("bad multi-pattern", ts))
   1.181  
   1.182 -fun dest_trigger (@{const trigger} $ tl $ t) =
   1.183 +fun dest_trigger (@{const SMT.trigger} $ tl $ t) =
   1.184        (rev (fold (dest_pats o HOLogic.dest_list) (HOLogic.dest_list tl) []), t)
   1.185    | dest_trigger t = ([], t)
   1.186  
   1.187 @@ -593,15 +572,19 @@
   1.188  
   1.189  fun add_config (cs, cfg) = Configs.map (U.dict_update (cs, cfg))
   1.190  
   1.191 +fun get_config ctxt = 
   1.192 +  let val cs = SMT_Config.solver_class_of ctxt
   1.193 +  in
   1.194 +    (case U.dict_get (Configs.get (Context.Proof ctxt)) cs of
   1.195 +      SOME cfg => cfg ctxt
   1.196 +    | NONE => error ("SMT: no translation configuration found " ^
   1.197 +        "for solver class " ^ quote (U.string_of_class cs)))
   1.198 +  end
   1.199 +
   1.200  fun translate ctxt comments ithms =
   1.201    let
   1.202 -    val cs = SMT_Config.solver_class_of ctxt
   1.203 -    val {prefixes, is_fol, header, has_datatypes, serialize} =
   1.204 -      (case U.dict_get (Configs.get (Context.Proof ctxt)) cs of
   1.205 -        SOME cfg => cfg ctxt
   1.206 -      | NONE => error ("SMT: no translation configuration found " ^
   1.207 -          "for solver class " ^ quote (U.string_of_class cs)))
   1.208 -      
   1.209 +    val {prefixes, is_fol, header, has_datatypes, serialize} = get_config ctxt
   1.210 +
   1.211      val with_datatypes =
   1.212        has_datatypes andalso Config.get ctxt SMT_Config.datatypes
   1.213