src/HOL/Tools/SMT/smt_translate.ML
changeset 41426 09615ed31f04
parent 41328 6792a5c92a58
child 41785 77dcc197df9a
     1.1 --- a/src/HOL/Tools/SMT/smt_translate.ML	Fri Dec 31 00:11:24 2010 +0100
     1.2 +++ b/src/HOL/Tools/SMT/smt_translate.ML	Mon Jan 03 16:22:08 2011 +0100
     1.3 @@ -134,20 +134,35 @@
     1.4  
     1.5  (* preprocessing *)
     1.6  
     1.7 -(** FIXME **)
     1.8 -
     1.9 -local
    1.10 -  (*
    1.11 -    force eta-expansion for constructors and selectors,
    1.12 -    add missing datatype selectors via hypothetical definitions,
    1.13 -    also return necessary datatype and record theorems
    1.14 -  *)
    1.15 -in
    1.16 +(** datatype declarations **)
    1.17  
    1.18  fun collect_datatypes_and_records (tr_context, ctxt) ts =
    1.19 -  (([], tr_context, ctxt), ts)
    1.20 +  let
    1.21 +    val (declss, ctxt') =
    1.22 +      fold (Term.fold_types SMT_Datatypes.add_decls) ts ([], ctxt)
    1.23 +
    1.24 +    fun is_decl_typ T = exists (exists (equal T o fst)) declss
    1.25 +
    1.26 +    fun add_typ' T proper =
    1.27 +      (case SMT_Builtin.dest_builtin_typ ctxt' T of
    1.28 +        SOME n => pair n
    1.29 +      | NONE => add_typ T proper)
    1.30  
    1.31 -end
    1.32 +    fun tr_select sel =
    1.33 +      let val T = Term.range_type (Term.fastype_of sel)
    1.34 +      in add_fun sel NONE ##>> add_typ' T (not (is_decl_typ T)) end
    1.35 +    fun tr_constr (constr, selects) =
    1.36 +      add_fun constr NONE ##>> fold_map tr_select selects
    1.37 +    fun tr_typ (T, cases) = add_typ' T false ##>> fold_map tr_constr cases
    1.38 +    val (declss', tr_context') = fold_map (fold_map tr_typ) declss tr_context
    1.39 +
    1.40 +    fun add (constr, selects) =
    1.41 +      Termtab.update (constr, length selects) #>
    1.42 +      fold (Termtab.update o rpair 1) selects
    1.43 +    val funcs = fold (fold (fold add o snd)) declss Termtab.empty
    1.44 +
    1.45 +  in ((funcs, declss', tr_context', ctxt'), ts) end
    1.46 +    (* FIXME: also return necessary datatype and record theorems *)
    1.47  
    1.48  
    1.49  (** eta-expand quantifiers, let expressions and built-ins *)
    1.50 @@ -174,8 +189,15 @@
    1.51      end
    1.52  in
    1.53  
    1.54 -fun eta_expand ctxt =
    1.55 +fun eta_expand ctxt funcs =
    1.56    let
    1.57 +    fun exp_func t T ts =
    1.58 +      (case Termtab.lookup funcs t of
    1.59 +        SOME k =>
    1.60 +          Term.list_comb (t, ts)
    1.61 +          |> k <> length ts ? expf k (length ts) T
    1.62 +      | NONE => Term.list_comb (t, ts))
    1.63 +
    1.64      fun expand ((q as Const (@{const_name All}, _)) $ Abs a) = q $ abs_expand a
    1.65        | expand ((q as Const (@{const_name All}, T)) $ t) = q $ exp T t
    1.66        | expand (q as Const (@{const_name All}, T)) = exp2 T q
    1.67 @@ -196,7 +218,8 @@
    1.68                  SOME (_, k, us, mk) =>
    1.69                    if k = length us then mk (map expand us)
    1.70                    else expf k (length ts) T (mk (map expand us))
    1.71 -              | NONE => Term.list_comb (u, map expand ts))
    1.72 +              | NONE => exp_func u T (map expand ts))
    1.73 +          | (u as Free (_, T), ts) => exp_func u T (map expand ts)
    1.74            | (Abs a, ts) => Term.list_comb (abs_expand a, map expand ts)
    1.75            | (u, ts) => Term.list_comb (u, map expand ts))
    1.76  
    1.77 @@ -530,17 +553,18 @@
    1.78      val with_datatypes =
    1.79        has_datatypes andalso Config.get ctxt SMT_Config.datatypes
    1.80  
    1.81 -    fun no_dtyps (tr_context, ctxt) ts = (([], tr_context, ctxt), ts)
    1.82 +    fun no_dtyps (tr_context, ctxt) ts =
    1.83 +      ((Termtab.empty, [], tr_context, ctxt), ts)
    1.84  
    1.85      val ts1 = map (Envir.beta_eta_contract o SMT_Utils.prop_of o snd) ithms
    1.86  
    1.87 -    val ((dtyps, tr_context, ctxt1), ts2) =
    1.88 +    val ((funcs, dtyps, tr_context, ctxt1), ts2) =
    1.89        ((make_tr_context prefixes, ctxt), ts1)
    1.90        |-> (if with_datatypes then collect_datatypes_and_records else no_dtyps)
    1.91  
    1.92      val (ctxt2, ts3) =
    1.93        ts2
    1.94 -      |> eta_expand ctxt1
    1.95 +      |> eta_expand ctxt1 funcs
    1.96        |> lift_lambdas ctxt1
    1.97        ||> intro_explicit_application
    1.98