src/HOL/Library/SMT/smt_datatypes.ML
changeset 58057 883f3c4c928e
parent 58056 fc6dd578d506
child 58058 1a0b18176548
equal deleted inserted replaced
58056:fc6dd578d506 58057:883f3c4c928e
     1 (*  Title:      HOL/Library/SMT/smt_datatypes.ML
       
     2     Author:     Sascha Boehme, TU Muenchen
       
     3 
       
     4 Collector functions for common type declarations and their representation
       
     5 as algebraic datatypes.
       
     6 *)
       
     7 
       
     8 signature SMT_DATATYPES =
       
     9 sig
       
    10   val add_decls: typ ->
       
    11     (typ * (term * term list) list) list list * Proof.context ->
       
    12     (typ * (term * term list) list) list list * Proof.context
       
    13 end
       
    14 
       
    15 structure SMT_Datatypes: SMT_DATATYPES =
       
    16 struct
       
    17 
       
    18 val lhs_head_of = Term.head_of o fst o Logic.dest_equals o Thm.prop_of
       
    19 
       
    20 fun mk_selectors T Ts =
       
    21   Variable.variant_fixes (replicate (length Ts) "select")
       
    22   #>> map2 (fn U => fn n => Free (n, T --> U)) Ts
       
    23 
       
    24 
       
    25 (* free constructor type declarations *)
       
    26 
       
    27 fun get_ctr_sugar_decl ({ctrs, ...} : Ctr_Sugar.ctr_sugar) T Ts ctxt =
       
    28   let
       
    29     fun mk_constr ctr0 =
       
    30       let val ctr = Ctr_Sugar.mk_ctr Ts ctr0 in
       
    31         mk_selectors T (binder_types (fastype_of ctr)) #>> pair ctr
       
    32       end
       
    33   in
       
    34     fold_map mk_constr ctrs ctxt
       
    35     |>> (pair T #> single)
       
    36   end
       
    37 
       
    38 
       
    39 (* typedef declarations *)
       
    40 
       
    41 fun get_typedef_decl (({Abs_name, Rep_name, abs_type, rep_type, ...}, {Abs_inverse, ...})
       
    42     : Typedef.info) T Ts =
       
    43   if can (curry (op RS) @{thm UNIV_I}) Abs_inverse then
       
    44     let
       
    45       val env = snd (Term.dest_Type abs_type) ~~ Ts
       
    46       val instT = Term.map_atyps (perhaps (AList.lookup (op =) env))
       
    47 
       
    48       val constr = Const (Abs_name, instT (rep_type --> abs_type))
       
    49       val select = Const (Rep_name, instT (abs_type --> rep_type))
       
    50     in [(T, [(constr, [select])])] end
       
    51   else
       
    52     []
       
    53 
       
    54 
       
    55 (* collection of declarations *)
       
    56 
       
    57 fun declared declss T = exists (exists (equal T o fst)) declss
       
    58 fun declared' dss T = exists (exists (equal T o fst) o snd) dss
       
    59 
       
    60 fun get_decls T n Ts ctxt =
       
    61   (case Ctr_Sugar.ctr_sugar_of ctxt n of
       
    62     SOME ctr_sugar => get_ctr_sugar_decl ctr_sugar T Ts ctxt
       
    63   | NONE =>
       
    64       (case Typedef.get_info ctxt n of
       
    65         [] => ([], ctxt)
       
    66       | info :: _ => (get_typedef_decl info T Ts, ctxt)))
       
    67 
       
    68 fun add_decls T (declss, ctxt) =
       
    69   let
       
    70     fun depends Ts ds = exists (member (op =) (map fst ds)) Ts
       
    71 
       
    72     fun add (TFree _) = I
       
    73       | add (TVar _) = I
       
    74       | add (T as Type (@{type_name fun}, _)) =
       
    75           fold add (Term.body_type T :: Term.binder_types T)
       
    76       | add @{typ bool} = I
       
    77       | add (T as Type (n, Ts)) = (fn (dss, ctxt1) =>
       
    78           if declared declss T orelse declared' dss T then (dss, ctxt1)
       
    79           else if SMT_Builtin.is_builtin_typ_ext ctxt1 T then (dss, ctxt1)
       
    80           else
       
    81             (case get_decls T n Ts ctxt1 of
       
    82               ([], _) => (dss, ctxt1)
       
    83             | (ds, ctxt2) =>
       
    84                 let
       
    85                   val constrTs =
       
    86                     maps (map (snd o Term.dest_Const o fst) o snd) ds
       
    87                   val Us = fold (union (op =) o Term.binder_types) constrTs []
       
    88 
       
    89                   fun ins [] = [(Us, ds)]
       
    90                     | ins ((Uds as (Us', _)) :: Udss) =
       
    91                         if depends Us' ds then (Us, ds) :: Uds :: Udss
       
    92                         else Uds :: ins Udss
       
    93             in fold add Us (ins dss, ctxt2) end))
       
    94   in add T ([], ctxt) |>> append declss o map snd end
       
    95 
       
    96 end