src/HOL/Tools/SMT/smt_datatypes.ML
changeset 41426 09615ed31f04
child 42361 23f352990944
     1.1 --- /dev/null	Thu Jan 01 00:00:00 1970 +0000
     1.2 +++ b/src/HOL/Tools/SMT/smt_datatypes.ML	Mon Jan 03 16:22:08 2011 +0100
     1.3 @@ -0,0 +1,126 @@
     1.4 +(*  Title:      HOL/Tools/SMT/smt_datatypes.ML
     1.5 +    Author:     Sascha Boehme, TU Muenchen
     1.6 +
     1.7 +Collector functions for common type declarations and their representation
     1.8 +as algebraic datatypes.
     1.9 +*)
    1.10 +
    1.11 +signature SMT_DATATYPES =
    1.12 +sig
    1.13 +  val add_decls: typ ->
    1.14 +    (typ * (term * term list) list) list list * Proof.context ->
    1.15 +    (typ * (term * term list) list) list list * Proof.context
    1.16 +end
    1.17 +
    1.18 +structure SMT_Datatypes: SMT_DATATYPES =
    1.19 +struct
    1.20 +
    1.21 +val lhs_head_of = Term.head_of o fst o Logic.dest_equals o Thm.prop_of
    1.22 +
    1.23 +fun mk_selectors T Ts ctxt =
    1.24 +  let
    1.25 +    val (sels, ctxt') =
    1.26 +      Variable.variant_fixes (replicate (length Ts) "select") ctxt
    1.27 +  in (map2 (fn n => fn U => Free (n, T --> U)) sels Ts, ctxt') end
    1.28 +
    1.29 +
    1.30 +(* datatype declarations *)
    1.31 +
    1.32 +fun get_datatype_decl ({descr, ...} : Datatype.info) n Ts ctxt =
    1.33 +  let
    1.34 +    fun get_vars (_, (m, vs, _)) = if m = n then SOME vs else NONE
    1.35 +    val vars = the (get_first get_vars descr) ~~ Ts
    1.36 +    val lookup_var = the o AList.lookup (op =) vars
    1.37 +
    1.38 +    val dTs = map (apsnd (fn (m, vs, _) => Type (m, map lookup_var vs))) descr
    1.39 +    val lookup_typ = the o AList.lookup (op =) dTs
    1.40 +
    1.41 +    fun typ_of (dt as Datatype.DtTFree _) = lookup_var dt
    1.42 +      | typ_of (Datatype.DtType (n, dts)) = Type (n, map typ_of dts)
    1.43 +      | typ_of (Datatype.DtRec i) = lookup_typ i
    1.44 +
    1.45 +    fun mk_constr T (m, dts) ctxt =
    1.46 +      let
    1.47 +        val Ts = map typ_of dts
    1.48 +        val constr = Const (m, Ts ---> T)
    1.49 +        val (selects, ctxt') = mk_selectors T Ts ctxt
    1.50 +      in ((constr, selects), ctxt') end
    1.51 +
    1.52 +    fun mk_decl (i, (_, _, constrs)) ctxt =
    1.53 +      let
    1.54 +        val T = lookup_typ i
    1.55 +        val (css, ctxt') = fold_map (mk_constr T) constrs ctxt
    1.56 +      in ((T, css), ctxt') end
    1.57 +
    1.58 +  in fold_map mk_decl descr ctxt end
    1.59 +
    1.60 +
    1.61 +(* record declarations *)
    1.62 +
    1.63 +val record_name_of = Long_Name.implode o fst o split_last o Long_Name.explode
    1.64 +
    1.65 +fun get_record_decl ({ext_def, ...} : Record.info) T ctxt =
    1.66 +  let
    1.67 +    val (con, _) = Term.dest_Const (lhs_head_of ext_def)
    1.68 +    val (fields, more) = Record.get_extT_fields (ProofContext.theory_of ctxt) T
    1.69 +    val fieldTs = map snd fields @ [snd more]
    1.70 +
    1.71 +    val constr = Const (con, fieldTs ---> T)
    1.72 +    val (selects, ctxt') = mk_selectors T fieldTs ctxt
    1.73 +  in ((T, [(constr, selects)]), ctxt') end
    1.74 +
    1.75 +
    1.76 +(* typedef declarations *)
    1.77 +
    1.78 +fun get_typedef_decl (info : Typedef.info) T Ts =
    1.79 +  let
    1.80 +    val ({Abs_name, Rep_name, abs_type, rep_type, ...}, _) = info
    1.81 +
    1.82 +    val env = snd (Term.dest_Type abs_type) ~~ Ts
    1.83 +    val instT = Term.map_atyps (perhaps (AList.lookup (op =) env))
    1.84 +
    1.85 +    val constr = Const (Abs_name, instT (rep_type --> abs_type))
    1.86 +    val select = Const (Rep_name, instT (abs_type --> rep_type))
    1.87 +  in (T, [(constr, [select])]) end
    1.88 +
    1.89 +
    1.90 +(* collection of declarations *)
    1.91 +
    1.92 +fun declared declss T = exists (exists (equal T o fst)) declss
    1.93 +
    1.94 +fun get_decls T n Ts ctxt =
    1.95 +  let val thy = ProofContext.theory_of ctxt
    1.96 +  in
    1.97 +    (case Datatype.get_info thy n of
    1.98 +      SOME info => get_datatype_decl info n Ts ctxt
    1.99 +    | NONE =>
   1.100 +        (case Record.get_info thy (record_name_of n) of
   1.101 +          SOME info => get_record_decl info T ctxt |>> single
   1.102 +        | NONE =>
   1.103 +            (case Typedef.get_info ctxt n of
   1.104 +              [] => ([], ctxt)
   1.105 +            | info :: _ => ([get_typedef_decl info T Ts], ctxt))))
   1.106 +  end
   1.107 +
   1.108 +fun add_decls T (declss, ctxt) =
   1.109 +  let
   1.110 +    fun add (TFree _) = I
   1.111 +      | add (TVar _) = I
   1.112 +      | add (T as Type (@{type_name fun}, _)) =
   1.113 +          fold add (Term.body_type T :: Term.binder_types T)
   1.114 +      | add @{typ bool} = I
   1.115 +      | add (T as Type (n, Ts)) = (fn (dss, ctxt1) =>
   1.116 +          if declared declss T orelse declared dss T then (dss, ctxt1)
   1.117 +          else if SMT_Builtin.is_builtin_typ_ext ctxt1 T then (dss, ctxt1)
   1.118 +          else
   1.119 +            (case get_decls T n Ts ctxt1 of
   1.120 +              ([], _) => (dss, ctxt1)
   1.121 +            | (ds, ctxt2) =>
   1.122 +                let
   1.123 +                  val constrTs =
   1.124 +                    maps (map (snd o Term.dest_Const o fst) o snd) ds
   1.125 +                  val Us = fold (union (op =) o Term.binder_types) constrTs []
   1.126 +            in fold add Us (ds :: dss, ctxt2) end))
   1.127 +  in add T ([], ctxt) |>> append declss end
   1.128 +
   1.129 +end