src/HOL/Tools/SMT/smt_datatypes.ML
author boehmes
Tue Jun 14 13:50:54 2011 +0200 (2011-06-14)
changeset 43385 9cd4b4ecb4dd
parent 42361 23f352990944
child 57213 9daec42f6784
permissions -rw-r--r--
slightly more general treatment of mutually recursive datatypes;
treat datatype constructors and selectors similarly to built-in constants wrt. introduction of explicit application (in the same way as what is already done for eta-expansion)
     1 (*  Title:      HOL/Tools/SMT/smt_datatypes.ML
     2     Author:     Sascha Boehme, TU Muenchen
     3 
     4 Collector functions for common type declarations and their representation
     5 as algebraic datatypes.
     6 *)
     7 
signature SMT_DATATYPES =
sig
  (* add_decls T (declss, ctxt) collects declarations for the type T and for
     the types occurring as constructor argument types of newly declared
     types, and returns declss with the new declaration groups appended after
     the existing ones.  A declaration pairs a type with its constructors;
     each constructor is paired with the list of its selector terms.  Types
     already declared in declss are not declared again.  The returned
     context additionally fixes the fresh names chosen for selectors. *)
  val add_decls: typ ->
    (typ * (term * term list) list) list list * Proof.context ->
    (typ * (term * term list) list) list list * Proof.context
end
    14 
    15 structure SMT_Datatypes: SMT_DATATYPES =
    16 struct
    17 
    18 val lhs_head_of = Term.head_of o fst o Logic.dest_equals o Thm.prop_of
    19 
    20 fun mk_selectors T Ts ctxt =
    21   let
    22     val (sels, ctxt') =
    23       Variable.variant_fixes (replicate (length Ts) "select") ctxt
    24   in (map2 (fn n => fn U => Free (n, T --> U)) sels Ts, ctxt') end
    25 
    26 
    27 (* datatype declarations *)
    28 
    29 fun get_datatype_decl ({descr, ...} : Datatype.info) n Ts ctxt =
    30   let
    31     fun get_vars (_, (m, vs, _)) = if m = n then SOME vs else NONE
    32     val vars = the (get_first get_vars descr) ~~ Ts
    33     val lookup_var = the o AList.lookup (op =) vars
    34 
    35     fun typ_of (dt as Datatype.DtTFree _) = lookup_var dt
    36       | typ_of (Datatype.DtType (m, dts)) = Type (m, map typ_of dts)
    37       | typ_of (Datatype.DtRec i) =
    38           the (AList.lookup (op =) descr i)
    39           |> (fn (m, dts, _) => Type (m, map typ_of dts))
    40 
    41     fun mk_constr T (m, dts) ctxt =
    42       let
    43         val Ts = map typ_of dts
    44         val constr = Const (m, Ts ---> T)
    45         val (selects, ctxt') = mk_selectors T Ts ctxt
    46       in ((constr, selects), ctxt') end
    47 
    48     fun mk_decl (i, (_, _, constrs)) ctxt =
    49       let
    50         val T = typ_of (Datatype.DtRec i)
    51         val (css, ctxt') = fold_map (mk_constr T) constrs ctxt
    52       in ((T, css), ctxt') end
    53 
    54   in fold_map mk_decl descr ctxt end
    55 
    56 
    57 (* record declarations *)
    58 
    59 val record_name_of = Long_Name.implode o fst o split_last o Long_Name.explode
    60 
    61 fun get_record_decl ({ext_def, ...} : Record.info) T ctxt =
    62   let
    63     val (con, _) = Term.dest_Const (lhs_head_of ext_def)
    64     val (fields, more) = Record.get_extT_fields (Proof_Context.theory_of ctxt) T
    65     val fieldTs = map snd fields @ [snd more]
    66 
    67     val constr = Const (con, fieldTs ---> T)
    68     val (selects, ctxt') = mk_selectors T fieldTs ctxt
    69   in ((T, [(constr, selects)]), ctxt') end
    70 
    71 
    72 (* typedef declarations *)
    73 
    74 fun get_typedef_decl (info : Typedef.info) T Ts =
    75   let
    76     val ({Abs_name, Rep_name, abs_type, rep_type, ...}, _) = info
    77 
    78     val env = snd (Term.dest_Type abs_type) ~~ Ts
    79     val instT = Term.map_atyps (perhaps (AList.lookup (op =) env))
    80 
    81     val constr = Const (Abs_name, instT (rep_type --> abs_type))
    82     val select = Const (Rep_name, instT (abs_type --> rep_type))
    83   in (T, [(constr, [select])]) end
    84 
    85 
    86 (* collection of declarations *)
    87 
    88 fun declared declss T = exists (exists (equal T o fst)) declss
    89 fun declared' dss T = exists (exists (equal T o fst) o snd) dss
    90 
    91 fun get_decls T n Ts ctxt =
    92   let val thy = Proof_Context.theory_of ctxt
    93   in
    94     (case Datatype.get_info thy n of
    95       SOME info => get_datatype_decl info n Ts ctxt
    96     | NONE =>
    97         (case Record.get_info thy (record_name_of n) of
    98           SOME info => get_record_decl info T ctxt |>> single
    99         | NONE =>
   100             (case Typedef.get_info ctxt n of
   101               [] => ([], ctxt)
   102             | info :: _ => ([get_typedef_decl info T Ts], ctxt))))
   103   end
   104 
(* Collects declarations for T and for the types reachable from it through
   constructor argument types.  Returns declss with the newly collected
   groups appended after it, together with the updated context.
   The following types are skipped:
     - type variables and bool;
     - types already declared, in declss or among the groups collected so far;
     - types accepted by SMT_Builtin.is_builtin_typ_ext;
     - types for which get_decls finds nothing.
   Function types are decomposed into their result and argument types.
   While collecting, each group is kept as a pair (Us, ds): ds is the result
   of one get_decls call, and Us are the argument types of its constructors. *)
fun add_decls T (declss, ctxt) =
  let
    (* true if some type in Ts is one of the types declared in ds *)
    fun depends Ts ds = exists (member (op =) (map fst ds)) Ts

    fun add (TFree _) = I
      | add (TVar _) = I
      (* function types: visit the result type, then the argument types *)
      | add (T as Type (@{type_name fun}, _)) =
          fold add (Term.body_type T :: Term.binder_types T)
      | add @{typ bool} = I
      | add (T as Type (n, Ts)) = (fn (dss, ctxt1) =>
          if declared declss T orelse declared' dss T then (dss, ctxt1)
          else if SMT_Builtin.is_builtin_typ_ext ctxt1 T then (dss, ctxt1)
          else
            (case get_decls T n Ts ctxt1 of
              (* nothing to declare: the context from get_decls is dropped *)
              ([], _) => (dss, ctxt1)
            | (ds, ctxt2) =>
                let
                  (* types of all constructors of the new group ... *)
                  val constrTs =
                    maps (map (snd o Term.dest_Const o fst) o snd) ds
                  (* ... and the union of their argument types *)
                  val Us = fold (union (op =) o Term.binder_types) constrTs []

                  (* insert the new group right before the first collected
                     group whose argument types include one of the new
                     group's types, or at the end if there is no such group;
                     presumably so that groups precede their users *)
                  fun ins [] = [(Us, ds)]
                    | ins ((Uds as (Us', _)) :: Udss) =
                        if depends Us' ds then (Us, ds) :: Uds :: Udss
                        else Uds :: ins Udss
            (* then collect declarations for the argument types themselves *)
            in fold add Us (ins dss, ctxt2) end))
  (* drop the Us components and append the new groups after declss *)
  in add T ([], ctxt) |>> append declss o map snd end
   132 
   133 
   134 end