slightly more general treatment of mutually recursive datatypes;
authorboehmes
Tue Jun 14 13:50:54 2011 +0200 (2011-06-14)
changeset 433859cd4b4ecb4dd
parent 43380 809de915155f
child 43389 328dcc5cc43f
slightly more general treatment of mutually recursive datatypes;
treat datatype constructors and selectors similarly to built-in constants wrt. introduction of explicit application (in the same way as what is already done for eta-expansion)
src/HOL/Tools/SMT/smt_datatypes.ML
src/HOL/Tools/SMT/smt_translate.ML
     1.1 --- a/src/HOL/Tools/SMT/smt_datatypes.ML	Tue Jun 14 08:33:51 2011 +0200
     1.2 +++ b/src/HOL/Tools/SMT/smt_datatypes.ML	Tue Jun 14 13:50:54 2011 +0200
     1.3 @@ -32,12 +32,11 @@
     1.4      val vars = the (get_first get_vars descr) ~~ Ts
     1.5      val lookup_var = the o AList.lookup (op =) vars
     1.6  
     1.7 -    val dTs = map (apsnd (fn (m, vs, _) => Type (m, map lookup_var vs))) descr
     1.8 -    val lookup_typ = the o AList.lookup (op =) dTs
     1.9 -
    1.10      fun typ_of (dt as Datatype.DtTFree _) = lookup_var dt
    1.11 -      | typ_of (Datatype.DtType (n, dts)) = Type (n, map typ_of dts)
    1.12 -      | typ_of (Datatype.DtRec i) = lookup_typ i
    1.13 +      | typ_of (Datatype.DtType (m, dts)) = Type (m, map typ_of dts)
    1.14 +      | typ_of (Datatype.DtRec i) =
    1.15 +          the (AList.lookup (op =) descr i)
    1.16 +          |> (fn (m, dts, _) => Type (m, map typ_of dts))
    1.17  
    1.18      fun mk_constr T (m, dts) ctxt =
    1.19        let
    1.20 @@ -48,7 +47,7 @@
    1.21  
    1.22      fun mk_decl (i, (_, _, constrs)) ctxt =
    1.23        let
    1.24 -        val T = lookup_typ i
    1.25 +        val T = typ_of (Datatype.DtRec i)
    1.26          val (css, ctxt') = fold_map (mk_constr T) constrs ctxt
    1.27        in ((T, css), ctxt') end
    1.28  
    1.29 @@ -87,6 +86,7 @@
    1.30  (* collection of declarations *)
    1.31  
    1.32  fun declared declss T = exists (exists (equal T o fst)) declss
    1.33 +fun declared' dss T = exists (exists (equal T o fst) o snd) dss
    1.34  
    1.35  fun get_decls T n Ts ctxt =
    1.36    let val thy = Proof_Context.theory_of ctxt
    1.37 @@ -104,13 +104,15 @@
    1.38  
    1.39  fun add_decls T (declss, ctxt) =
    1.40    let
    1.41 +    fun depends Ts ds = exists (member (op =) (map fst ds)) Ts
    1.42 +
    1.43      fun add (TFree _) = I
    1.44        | add (TVar _) = I
    1.45        | add (T as Type (@{type_name fun}, _)) =
    1.46            fold add (Term.body_type T :: Term.binder_types T)
    1.47        | add @{typ bool} = I
    1.48        | add (T as Type (n, Ts)) = (fn (dss, ctxt1) =>
    1.49 -          if declared declss T orelse declared dss T then (dss, ctxt1)
    1.50 +          if declared declss T orelse declared' dss T then (dss, ctxt1)
    1.51            else if SMT_Builtin.is_builtin_typ_ext ctxt1 T then (dss, ctxt1)
    1.52            else
    1.53              (case get_decls T n Ts ctxt1 of
    1.54 @@ -120,7 +122,13 @@
    1.55                    val constrTs =
    1.56                      maps (map (snd o Term.dest_Const o fst) o snd) ds
    1.57                    val Us = fold (union (op =) o Term.binder_types) constrTs []
    1.58 -            in fold add Us (ds :: dss, ctxt2) end))
    1.59 -  in add T ([], ctxt) |>> append declss end
    1.60 +
    1.61 +                  fun ins [] = [(Us, ds)]
    1.62 +                    | ins ((Uds as (Us', _)) :: Udss) =
    1.63 +                        if depends Us' ds then (Us, ds) :: Uds :: Udss
    1.64 +                        else Uds :: ins Udss
    1.65 +            in fold add Us (ins dss, ctxt2) end))
    1.66 +  in add T ([], ctxt) |>> append declss o map snd end
    1.67 +
    1.68  
    1.69  end
     2.1 --- a/src/HOL/Tools/SMT/smt_translate.ML	Tue Jun 14 08:33:51 2011 +0200
     2.2 +++ b/src/HOL/Tools/SMT/smt_translate.ML	Tue Jun 14 13:50:54 2011 +0200
     2.3 @@ -160,7 +160,6 @@
     2.4        Termtab.update (constr, length selects) #>
     2.5        fold (Termtab.update o rpair 1) selects
     2.6      val funcs = fold (fold (fold add o snd)) declss Termtab.empty
     2.7 -
     2.8    in ((funcs, declss', tr_context', ctxt'), ts) end
     2.9      (* FIXME: also return necessary datatype and record theorems *)
    2.10  
    2.11 @@ -344,11 +343,14 @@
    2.12      in fst (fold app ts2 (Term.list_comb (t, ts1), U)) end
    2.13  in
    2.14  
    2.15 -fun intro_explicit_application ctxt ts =
    2.16 +fun intro_explicit_application ctxt funcs ts =
    2.17    let
    2.18      val (arities, types) = fold min_arities ts (Termtab.empty, Typtab.empty)
    2.19      val arities' = Termtab.map (minimize types) arities
    2.20 -    fun apply' t = apply (the (Termtab.lookup arities' t)) t
    2.21 +
    2.22 +    fun app_func t T ts =
    2.23 +      if is_some (Termtab.lookup funcs t) then Term.list_comb (t, ts)
    2.24 +      else apply (the (Termtab.lookup arities' t)) t T ts
    2.25  
    2.26      fun traverse Ts t =
    2.27        (case Term.strip_comb t of
    2.28 @@ -359,8 +361,8 @@
    2.29        | (u as Const (c as (_, T)), ts) =>
    2.30            (case SMT_Builtin.dest_builtin ctxt c ts of
    2.31              SOME (_, _, us, mk) => mk (map (traverse Ts) us)
    2.32 -          | NONE => apply' u T (map (traverse Ts) ts))
    2.33 -      | (u as Free (_, T), ts) => apply' u T (map (traverse Ts) ts)
    2.34 +          | NONE => app_func u T (map (traverse Ts) ts))
    2.35 +      | (u as Free (_, T), ts) => app_func u T (map (traverse Ts) ts)
    2.36        | (u as Bound i, ts) => apply 0 u (nth Ts i) (map (traverse Ts) ts)
    2.37        | (Abs (n, T, u), ts) => traverses Ts (Abs (n, T, traverse (T::Ts) u)) ts
    2.38        | (u, ts) => traverses Ts u ts)
    2.39 @@ -586,7 +588,7 @@
    2.40        ts2
    2.41        |> eta_expand ctxt1 is_fol funcs
    2.42        |> lift_lambdas ctxt1
    2.43 -      |-> (fn ctxt1' => pair ctxt1' o intro_explicit_application ctxt1)
    2.44 +      |-> (fn ctxt1' => pair ctxt1' o intro_explicit_application ctxt1 funcs)
    2.45  
    2.46      val ((rewrite_rules, extra_thms, builtin), ts4) =
    2.47        (if is_fol then folify ctxt2 else pair ([], [], I)) ts3