src/HOL/Tools/SMT/smt_datatypes.ML
author blanchet
Mon Dec 15 07:20:49 2014 +0100 (2014-12-15)
changeset 59143 15c342a9a8e0
parent 59142 705f8aea8d60
child 67149 e61557884799
permissions -rw-r--r--
correctly apply type substitution before checking for function types
blanchet@58061
     1
(*  Title:      HOL/Tools/SMT/smt_datatypes.ML
blanchet@56078
     2
    Author:     Sascha Boehme, TU Muenchen
blanchet@56078
     3
blanchet@56078
     4
Collector functions for common type declarations and their representation
blanchet@58361
     5
as (co)algebraic datatypes.
blanchet@56078
     6
*)
blanchet@56078
     7
blanchet@58061
     8
signature SMT_DATATYPES =
sig
  (* add_decls fps T (declss, ctxt): collect declarations for the type
     constructors occurring in T and append them, as groups, after the given
     groups declss.  Every declaration is tagged with an fp_kind and consists
     of the declared type together with its constructors, each paired with its
     list of selectors.  The context is threaded through because fresh
     selector names may be fixed along the way. *)
  val add_decls:
    BNF_Util.fp_kind list ->
    typ ->
    (BNF_Util.fp_kind * (typ * (term * term list) list)) list list * Proof.context ->
    (BNF_Util.fp_kind * (typ * (term * term list) list)) list list * Proof.context
end;
blanchet@56078
    14
blanchet@58061
    15
structure SMT_Datatypes: SMT_DATATYPES =
blanchet@56078
    16
struct
blanchet@56078
    17
blanchet@58362
    18
(* Selectors for a constructor with argument types Ts and result type T.
   Given selectors are returned unchanged (the context is not touched).  If
   none are given, one variable based on the name "select" is fixed in the
   context per argument type and used as a free selector of type T --> U. *)
fun mk_selectors T Ts [] =
      let
        fun mk_sel U name = Free (name, T --> U)
        val base_names = replicate (length Ts) "select"
      in
        Variable.variant_fixes base_names #>> map2 mk_sel Ts
      end
  | mk_selectors _ _ sels = pair sels
blanchet@56078
    24
blanchet@56078
    25
blanchet@57226
    26
(* free constructor type declarations *)
blanchet@56078
    27
blanchet@58362
    28
(* Declaration of T from its free-constructor information, instantiated with
   the type arguments Ts: a singleton list holding T with its constructors,
   each paired with its selectors.  The registered selectors are discarded
   altogether (and replaced by freshly fixed ones, see mk_selectors) if some
   selector occurs more than once or has a function as its result type. *)
fun get_ctr_sugar_decl ({ctrs = ctrs0, selss = selss0, ...} : Ctr_Sugar.ctr_sugar) T Ts ctxt =
  let
    val inst_selss = map (map (Ctr_Sugar.mk_disc_or_sel Ts)) selss0
    val inst_ctrs = map (Ctr_Sugar.mk_ctr Ts) ctrs0

    fun returns_fun sel = can (dest_funT o range_type o fastype_of) sel
    val unusable =
      has_duplicates (op aconv) (flat inst_selss) orelse
      exists (exists returns_fun) inst_selss

    (* exactly one (possibly empty) selector list per constructor *)
    val selss' =
      Ctr_Sugar_Util.pad_list [] (length inst_ctrs) (if unusable then [] else inst_selss)

    fun mk_constr ctr sels ctxt1 =
      let val (sels', ctxt2) = mk_selectors T (binder_types (fastype_of ctr)) sels ctxt1
      in ((ctr, sels'), ctxt2) end

    val (constrs, ctxt') = @{fold_map 2} mk_constr inst_ctrs selss' ctxt
  in
    ([(T, constrs)], ctxt')
  end
blanchet@56078
    47
blanchet@56078
    48
blanchet@56078
    49
(* typedef declarations *)
blanchet@56078
    50
blanchet@57213
    51
(* Declaration of a typedef'ed type T with type arguments Ts: the abstraction
   function acts as the only constructor and the representation function as
   its selector.  Nothing is declared unless UNIV_I can be resolved against
   Abs_inverse -- presumably this holds exactly when the representing set is
   the whole type (NOTE(review): confirm against the typedef package). *)
fun get_typedef_decl (({Abs_name, Rep_name, abs_type, rep_type, ...}, {Abs_inverse, ...})
    : Typedef.info) T Ts =
  let
    val resolvable = can (fn thm => @{thm UNIV_I} RS thm) Abs_inverse
  in
    if not resolvable then []
    else
      let
        (* instantiate the formal type parameters of abs_type with Ts *)
        val (_, params) = Term.dest_Type abs_type
        val env = params ~~ Ts
        fun inst_atyp A = the_default A (AList.lookup (op =) env A)
        fun mk_const name dom cod = Const (name, Term.map_atyps inst_atyp (dom --> cod))

        val constr = mk_const Abs_name rep_type abs_type
        val select = mk_const Rep_name abs_type rep_type
      in
        [(T, [(constr, [select])])]
      end
  end
blanchet@56078
    63
blanchet@56078
    64
blanchet@56078
    65
(* collection of declarations *)
blanchet@56078
    66
blanchet@58427
    67
val extN = "_ext" (* cf. "HOL/Tools/typedef.ML" *)

(* Declarations for the type T with head n and arguments Ts, restricted to the
   fixpoint kinds fps.  Yields the fp_kind-tagged declarations (empty if
   nothing suitable is known about n) and the resulting context.

   - (co)datatypes registered as fp sugar are used if their kind is in fps and
     their constructor arguments pass the checks below; if the kind is not in
     fps nothing is declared; if the checks fail the typedef fallback is tried;
   - otherwise, names ending in extN are looked up as plain free-constructor
     types (record hack);
   - otherwise the type is tried as a typedef.
   Declarations not originating from fp sugar are tagged with "hd fps". *)
fun get_decls fps T n Ts ctxt =
  let
    fun fp_sugar_of s = BNF_FP_Def_Sugar.fp_sugar_of ctxt s
    fun ctr_sugar_decls ctr_sugar = get_ctr_sugar_decl ctr_sugar T Ts ctxt

    fun maybe_typedef () =
      (case Typedef.get_info ctxt n of
        info :: _ =>
          let val tag = pair (hd fps)
          in (map tag (get_typedef_decl info T Ts), ctxt) end
      | [] => ([], ctxt))
  in
    (case fp_sugar_of n of
      NONE =>
        (* for records (FIXME: hack) *)
        (case (if String.isSuffix extN n then Ctr_Sugar.ctr_sugar_of ctxt n else NONE) of
          SOME ctr_sugar => ctr_sugar_decls ctr_sugar |>> map (pair (hd fps))
        | NONE => maybe_typedef ())
    | SOME {fp, fp_res = {Ts = fp_Ts, ...}, fp_ctr_sugar = {ctr_sugar, ...}, ...} =>
        if not (member (op =) fps fp) then ([], ctxt)
        else
          let
            (* fp sugar of all types defined mutually with n *)
            val mutual_fp_sugars = map_filter (fp_sugar_of o fst o dest_Type) fp_Ts
            val Xs = map #X mutual_fp_sugars
            val ctrXs_Tsss = map (#ctrXs_Tss o #fp_ctr_sugar) mutual_fp_sugars

            (* Datatypes nested through datatypes and codatatypes nested through codatatypes
               are allowed. So are mutually (co)recursive (co)datatypes. *)
            fun has_same_fp s =
              (case fp_sugar_of s of
                NONE => false
              | SOME {fp = fp', ...} => fp' = fp)

            fun homogeneously_nested (Type (s, Us)) =
                  let
                    val ok_arg =
                      if has_same_fp s then homogeneously_nested
                      else not o BNF_FP_Rec_Sugar_Util.exists_subtype_in Xs
                  in forall ok_arg Us end
              | homogeneously_nested _ = true

            val Type (_, As) :: _ = fp_Ts
            val substAs = Term.typ_subst_atomic (As ~~ Ts)

            (* predicate lifted to all constructor argument types of all mutual types *)
            fun all_arg_types pred = forall (forall (forall pred))
            fun not_bool U = substAs U <> @{typ bool}
          in
            (* TODO/FIXME: The "bool" check is there to work around a CVC4 bug
               (http://church.cims.nyu.edu/bugzilla3/show_bug.cgi?id=597). It should be removed
               once the bug is fixed. *)
            if all_arg_types homogeneously_nested ctrXs_Tsss andalso
               all_arg_types not_bool ctrXs_Tsss
            then ctr_sugar_decls ctr_sugar |>> map (pair fp)
            else maybe_typedef ()
          end)
  end
blanchet@56078
   121
blanchet@58429
   122
(* Collect the declarations needed for T and append them to declss.

   Type variables and bool contribute nothing; function types are traversed
   through their result and argument types.  Any other type is skipped if it
   is already declared (in declss or in the groups found so far) or builtin;
   otherwise its declarations are obtained from get_decls and the argument
   types of its constructors are processed in turn.

   While collecting, every group carries the constructor argument types Us it
   refers to.  A new group is inserted in front of the first existing group
   that refers to one of the newly declared types, or at the end if there is
   none. *)
fun add_decls fps T (declss, ctxt) =
  let
    (* does the tagged declaration (fp, (T', constrs)) declare the type U? *)
    fun declares U (_, (T', _)) = (U = T')

    fun add (TFree _) acc = acc
      | add (TVar _) acc = acc
      | add (U as Type (@{type_name fun}, _)) acc =
          fold add (Term.body_type U :: Term.binder_types U) acc
      | add @{typ bool} acc = acc
      | add (U as Type (n, args)) (acc as (dss, ctxt1)) =
          if exists (exists (declares U)) declss orelse
             exists (exists (declares U) o snd) dss orelse
             SMT_Builtin.is_builtin_typ_ext ctxt1 U
          then acc
          else
            (case get_decls fps U n args ctxt1 of
              ([], _) => acc
            | (ds, ctxt2) =>
                let
                  val declared_Ts = map (fst o snd) ds

                  (* argument types of all constructors declared by ds *)
                  fun ctr_typ (ctr, _) = snd (Term.dest_Const ctr)
                  val constrTs = maps (fn (_, (_, constrs)) => map ctr_typ constrs) ds
                  val Us = fold (fn cT => union (op =) (Term.binder_types cT)) constrTs []

                  val group = (Us, ds)
                  fun uses_ds (Us', _) = exists (member (op =) declared_Ts) Us'

                  fun ins [] = [group]
                    | ins (Udss as (Uds :: rest)) =
                        if uses_ds Uds then group :: Udss else Uds :: ins rest
                in
                  fold add Us (ins dss, ctxt2)
                end)

    val (dss', ctxt') = add T ([], ctxt)
  in
    (declss @ map snd dss', ctxt')
  end
blanchet@56078
   151
blanchet@57229
   152
end;