(* Title: HOL/Tools/SMT/smt_datatypes.ML
Author: Sascha Boehme, TU Muenchen
Collector functions for common type declarations and their representation
as (co)algebraic datatypes.
*)
signature SMT_DATATYPES =
sig
(*
  add_decls fp T (declss, ctxt):
  extends the declaration groups declss with groups for the type T and for
  the types reachable from it, and returns them together with a possibly
  extended proof context.

  Each group is a list of entries (type, [(constructor, selectors), ...]).
  The argument fp selects which kind of fixpoint (see BNF_Util.fp_kind) is
  looked up via constructor sugar.
*)
val add_decls: BNF_Util.fp_kind -> typ ->
(typ * (term * term list) list) list list * Proof.context ->
(typ * (term * term list) list) list list * Proof.context
end;
structure SMT_Datatypes: SMT_DATATYPES =
struct
(*
  Invents one selector per argument type in Ts: fresh fixed names are derived
  from "select" in the context, and each becomes a free variable of type
  T --> U for the corresponding argument type U.
  Returns the selectors together with the extended context.
*)
fun mk_selectors T Ts ctxt =
  let
    val (names, ctxt') = Variable.variant_fixes (map (K "select") Ts) ctxt
    fun mk_sel (U, name) = Free (name, T --> U)
  in (map mk_sel (Ts ~~ names), ctxt') end
(* free constructor type declarations *)
(*
  Builds the declaration of T from its registered constructors: every
  constructor is instantiated with the type arguments Ts and paired with
  freshly invented selectors, one per constructor argument.
  Returns the singleton list [(T, [(ctr, selectors), ...])] and the context
  extended by the selector names.
*)
fun get_ctr_sugar_decl ({ctrs, ...} : Ctr_Sugar.ctr_sugar) T Ts ctxt =
  let
    fun declare_ctr raw_ctr ctxt1 =
      let
        val ctr = Ctr_Sugar.mk_ctr Ts raw_ctr
        val arg_Ts = binder_types (fastype_of ctr)
        val (sels, ctxt2) = mk_selectors T arg_Ts ctxt1
      in ((ctr, sels), ctxt2) end

    val (ctr_decls, ctxt') = fold_map declare_ctr ctrs ctxt
  in ([(T, ctr_decls)], ctxt') end
(* typedef declarations *)
(*
  Declaration of a typedef type T with type arguments Ts: the Abs constant
  acts as the only constructor and the Rep constant as its only selector,
  both instantiated by mapping the parameters of abs_type to Ts.
  The declaration is only produced when UNIV_I can be resolved against
  Abs_inverse (presumably: the representing set is the whole type -- to be
  confirmed); otherwise the result is empty.
*)
fun get_typedef_decl (({Abs_name, Rep_name, abs_type, rep_type, ...}, {Abs_inverse, ...})
    : Typedef.info) T Ts =
  let
    val resolvable = can (fn thm => @{thm UNIV_I} RS thm) Abs_inverse
  in
    if not resolvable then []
    else
      let
        val (_, params) = Term.dest_Type abs_type
        val env = params ~~ Ts

        (* replace a type parameter by its actual argument, keep anything else *)
        fun inst_atyp U =
          (case AList.lookup (op =) env U of
            SOME U' => U'
          | NONE => U)
        val instT = Term.map_atyps inst_atyp

        val abs_T = instT (rep_type --> abs_type)
        val rep_T = instT (abs_type --> rep_type)
      in [(T, [(Const (Abs_name, abs_T), [Const (Rep_name, rep_T)])])] end
  end
(* collection of declarations *)
(* Is T the first component of some entry in some group of declss? *)
fun declared declss T =
  let fun declares_T (U, _) = (T = U)
  in exists (exists declares_T) declss end
(*
  Variant of the check for tagged groups: dss is a list of pairs whose second
  component is a group; is T the first component of some entry in one of them?
*)
fun declared' dss T =
  let fun declares_T (U, _) = (T = U)
  in exists (fn (_, ds) => exists declares_T ds) dss end
(*
  Fixpoint kind registered for the type name n.
  Simplification: every type without registered fixpoint sugar is treated as
  a least fixpoint, i.e. anything that is not a codatatype counts as a
  datatype (or a record).
*)
fun fp_kind_of ctxt n =
  BNF_FP_Def_Sugar.fp_sugar_of ctxt n
  |> Option.map #fp
  |> the_default BNF_Util.Least_FP
(*
  Looks up the declaration of type T (name n, type arguments Ts).
  Constructor sugar is used when it is registered for n and the fixpoint
  kind of n equals fp; in every other case the first typedef info for n is
  tried, which leaves the context unchanged.
  Returns the (possibly empty) list of declarations and the resulting context.
*)
fun get_decls fp T n Ts ctxt =
  let
    fun typedef_decls [] = []
      | typedef_decls (info :: _) = get_typedef_decl info T Ts

    fun via_typedef () = (typedef_decls (Typedef.get_info ctxt n), ctxt)
  in
    (case Ctr_Sugar.ctr_sugar_of ctxt n of
      NONE => via_typedef ()
    | SOME ctr_sugar =>
        if fp_kind_of ctxt n <> fp then via_typedef ()
        else get_ctr_sugar_decl ctr_sugar T Ts ctxt)
  end
(*
  add_decls fp T (declss, ctxt):
  collects declarations for T and for all types occurring as constructor
  argument types of the collected declarations, skipping types that are
  already declared or built-in.  The newly found groups are appended after
  the given groups declss; the context is threaded through get_decls.
*)
fun add_decls fp T (declss, ctxt) =
let
(* does some type in Ts have a declaration in the group ds? *)
fun depends Ts ds = exists (member (op =) (map fst ds)) Ts
(* type variables contribute no declarations *)
fun add (TFree _) = I
| add (TVar _) = I
(* function types: visit the result type and all argument types *)
| add (T as Type (@{type_name fun}, _)) =
fold add (Term.body_type T :: Term.binder_types T)
(* bool is never declared *)
| add @{typ bool} = I
(* any other type constructor n applied to Ts; the state is a list dss of
   pairs (argument types, group) plus the current context *)
| add (T as Type (n, Ts)) = (fn (dss, ctxt1) =>
(* nothing to do if T is known already, either from the input or from this run *)
if declared declss T orelse declared' dss T then (dss, ctxt1)
else if SMT_Builtin.is_builtin_typ_ext ctxt1 T then (dss, ctxt1)
else
(case get_decls fp T n Ts ctxt1 of
(* no declaration found: keep the state, including the old context *)
([], _) => (dss, ctxt1)
| (ds, ctxt2) =>
let
(* types of all constructors in ds; dest_Const requires each of them to be a Const *)
val constrTs = maps (map (snd o Term.dest_Const o fst) o snd) ds
(* argument types of those constructors, without duplicates *)
val Us = fold (union (op =) o Term.binder_types) constrTs []
(* put (Us, ds) directly in front of the first group whose argument types
   mention a type declared in ds, or at the very end if there is none *)
fun ins [] = [(Us, ds)]
| ins ((Uds as (Us', _)) :: Udss) =
if depends Us' ds then (Us, ds) :: Uds :: Udss
else Uds :: ins Udss
(* continue with the argument types of the new constructors *)
in fold add Us (ins dss, ctxt2) end))
(* drop the argument-type tags and append the new groups to declss *)
in add T ([], ctxt) |>> append declss o map snd end
end;