src/HOL/Tools/SMT/smt_datatypes.ML
author boehmes
Mon Jan 03 16:22:08 2011 +0100 (2011-01-03)
changeset 41426 09615ed31f04
child 42361 23f352990944
permissions -rw-r--r--
re-implemented support for datatypes (including records and typedefs);
added test cases for datatypes, records and typedefs
(*  Title:      HOL/Tools/SMT/smt_datatypes.ML
    Author:     Sascha Boehme, TU Muenchen

Collector functions for common type declarations and their representation
as algebraic datatypes.
*)

signature SMT_DATATYPES =
sig
  (* add_decls T (declss, ctxt) extends the declaration groups declss by the
     groups needed to represent T and the types it depends on. Each group is
     a list of entries (type, [(constructor, selectors)]). The returned
     context may contain additional fixed names (for the selectors). *)
  val add_decls: typ ->
    (typ * (term * term list) list) list list * Proof.context ->
    (typ * (term * term list) list) list list * Proof.context
end

structure SMT_Datatypes: SMT_DATATYPES =
struct

(* head of the left-hand side of a theorem whose proposition is a
   meta-equality (Logic.dest_equals fails on anything else) *)
val lhs_head_of = Term.head_of o fst o Logic.dest_equals o Thm.prop_of

(* One selector per argument type U in Ts, represented as a Free of type
   T --> U. The names are produced by Variable.variant_fixes from the base
   name select, so they are fresh in ctxt and fixed in the returned
   context. *)
fun mk_selectors T Ts ctxt =
  let
    val (sels, ctxt') =
      Variable.variant_fixes (replicate (length Ts) "select") ctxt
  in (map2 (fn n => fn U => Free (n, T --> U)) sels Ts, ctxt') end



(* datatype declarations *)

(* Declarations for every entry of descr: the datatype n together with the
   types recorded alongside it in its datatype info (presumably mutually
   recursive and nested types -- confirm against the datatype package).
   The type parameters of n are instantiated positionally with Ts. Each
   result element pairs a type with its constructors and their selectors. *)
fun get_datatype_decl ({descr, ...} : Datatype.info) n Ts ctxt =
  let
    fun get_vars (_, (m, vs, _)) = if m = n then SOME vs else NONE
    val vars = the (get_first get_vars descr) ~~ Ts
    val lookup_var = the o AList.lookup (op =) vars

    (* The argument dtyps stored for an entry of descr need not all be
       DtTFree; for nested datatypes they can presumably contain DtRec
       (NOTE(review): confirm against Datatype_Aux). Looking them up in
       vars, which only holds the DtTFree parameters of n, would then fail
       with Option. Resolving DtRec recursively through typ_of handles both
       cases and coincides with the plain lookup when every argument is a
       DtTFree. *)
    fun typ_of (dt as Datatype.DtTFree _) = lookup_var dt
      | typ_of (Datatype.DtType (m, dts)) = Type (m, map typ_of dts)
      | typ_of (Datatype.DtRec i) =
          the (AList.lookup (op =) descr i)
          |> (fn (m, dts, _) => Type (m, map typ_of dts))

    (* constructor m of result type T with argument types taken from dts,
       plus one fresh selector per argument *)
    fun mk_constr T (m, dts) ctxt =
      let
        val Ts = map typ_of dts
        val constr = Const (m, Ts ---> T)
        val (selects, ctxt') = mk_selectors T Ts ctxt
      in ((constr, selects), ctxt') end

    (* the declaration of the i-th entry of descr *)
    fun mk_decl (i, (_, _, constrs)) ctxt =
      let
        val T = typ_of (Datatype.DtRec i)
        val (css, ctxt') = fold_map (mk_constr T) constrs ctxt
      in ((T, css), ctxt') end

  in fold_map mk_decl descr ctxt end



(* record declarations *)

(* Drops the last component of a qualified name. get_decls applies it to a
   type name to obtain the name under which Record.get_info is queried
   (presumably: extension type name -> record name). *)
val record_name_of = Long_Name.implode o fst o split_last o Long_Name.explode

(* A record type T is declared as a datatype with a single constructor: the
   constant at the head of the left-hand side of ext_def. It takes one
   argument per field reported by Record.get_extT_fields, followed by the
   more part, and each argument gets its own selector. *)
fun get_record_decl ({ext_def, ...} : Record.info) T ctxt =
  let
    val (con, _) = Term.dest_Const (lhs_head_of ext_def)
    val (fields, more) = Record.get_extT_fields (ProofContext.theory_of ctxt) T
    val fieldTs = map snd fields @ [snd more]

    val constr = Const (con, fieldTs ---> T)
    val (selects, ctxt') = mk_selectors T fieldTs ctxt
  in ((T, [(constr, selects)]), ctxt') end



(* typedef declarations *)

(* A typedef T is declared as a datatype with the single constructor
   Abs_name : rep_type --> abs_type and the single selector
   Rep_name : abs_type --> rep_type. The type arguments of abs_type are
   replaced positionally by Ts; other atomic types are kept unchanged. *)
fun get_typedef_decl (info : Typedef.info) T Ts =
  let
    val ({Abs_name, Rep_name, abs_type, rep_type, ...}, _) = info

    val env = snd (Term.dest_Type abs_type) ~~ Ts
    val instT = Term.map_atyps (perhaps (AList.lookup (op =) env))

    val constr = Const (Abs_name, instT (rep_type --> abs_type))
    val select = Const (Rep_name, instT (abs_type --> rep_type))
  in (T, [(constr, [select])]) end



(* collection of declarations *)

(* whether some group in declss already declares T *)
fun declared declss T = exists (exists (equal T o fst)) declss

(* Declarations for T = Type (n, Ts), tried as a datatype, then as a record,
   then as a typedef; ([], ctxt) if none of them applies. Only the datatype
   case can yield more than one declaration. *)
fun get_decls T n Ts ctxt =
  let val thy = ProofContext.theory_of ctxt
  in
    (case Datatype.get_info thy n of
      SOME info => get_datatype_decl info n Ts ctxt
    | NONE =>
        (case Record.get_info thy (record_name_of n) of
          SOME info => get_record_decl info T ctxt |>> single
        | NONE =>
            (case Typedef.get_info ctxt n of
              [] => ([], ctxt)
            | info :: _ => ([get_typedef_decl info T Ts], ctxt))))
  end

(* add_decls T (declss, ctxt) returns declss followed by the declaration
   groups newly needed for T: the group of T itself and, transitively, the
   groups of all types occurring as constructor arguments in a newly
   collected group. The following are skipped:
   - type variables;
   - bool;
   - types already declared in declss or collected before;
   - types accepted by SMT_Builtin.is_builtin_typ_ext.
   Function types are decomposed into their argument and result types.
   Each new group is inserted in front of the first already collected group
   whose constructor argument types mention one of its types, so that used
   types are declared before the types using them. *)
fun add_decls T (declss, ctxt) =
  let
    (* collected groups are tagged with the constructor argument types they
       use: (Us, ds) *)
    fun declared' dss T = exists (exists (equal T o fst) o snd) dss

    (* whether a group using the types Us depends on the group ds *)
    fun depends Us ds = exists (member (op =) (map fst ds)) Us

    fun add (TFree _) = I
      | add (TVar _) = I
      | add (T as Type (@{type_name fun}, _)) =
          fold add (Term.body_type T :: Term.binder_types T)
      | add @{typ bool} = I
      | add (T as Type (n, Ts)) = (fn (dss, ctxt1) =>
          if declared declss T orelse declared' dss T then (dss, ctxt1)
          else if SMT_Builtin.is_builtin_typ_ext ctxt1 T then (dss, ctxt1)
          else
            (case get_decls T n Ts ctxt1 of
              ([], _) => (dss, ctxt1)
            | (ds, ctxt2) =>
                let
                  val constrTs =
                    maps (map (snd o Term.dest_Const o fst) o snd) ds
                  val Us = fold (union (op =) o Term.binder_types) constrTs []

                  (* place ds before the first collected group that uses one
                     of its types, or at the end if there is none *)
                  fun ins [] = [(Us, ds)]
                    | ins ((Uds as (Us', _)) :: Udss) =
                        if depends Us' ds then (Us, ds) :: Uds :: Udss
                        else Uds :: ins Udss
                in fold add Us (ins dss, ctxt2) end))
  in add T ([], ctxt) |>> (append declss o map snd) end

end