src/HOL/Tools/SMT/smt_datatypes.ML
author blanchet
Mon Dec 15 07:20:49 2014 +0100 (2014-12-15)
changeset 59143 15c342a9a8e0
parent 59142 705f8aea8d60
child 67149 e61557884799
permissions -rw-r--r--
correctly apply type substitution before checking for function types
     1 (*  Title:      HOL/Tools/SMT/smt_datatypes.ML
     2     Author:     Sascha Boehme, TU Muenchen
     3 
     4 Collector functions for common type declarations and their representation
     5 as (co)algebraic datatypes.
     6 *)
     7 
     8 signature SMT_DATATYPES =
     9 sig
    10   val add_decls: BNF_Util.fp_kind list -> typ ->
    11     (BNF_Util.fp_kind * (typ * (term * term list) list)) list list * Proof.context ->
    12     (BNF_Util.fp_kind * (typ * (term * term list) list)) list list * Proof.context
    13 end;
    14 
    15 structure SMT_Datatypes: SMT_DATATYPES =
    16 struct
    17 
    18 fun mk_selectors T Ts sels =
    19   if null sels then
    20     Variable.variant_fixes (replicate (length Ts) "select")
    21     #>> map2 (fn U => fn n => Free (n, T --> U)) Ts
    22   else
    23     pair sels
    24 
    25 
    26 (* free constructor type declarations *)
    27 
    28 fun get_ctr_sugar_decl ({ctrs = ctrs0, selss = selss0, ...} : Ctr_Sugar.ctr_sugar) T Ts ctxt =
    29   let
    30     val selss = map (map (Ctr_Sugar.mk_disc_or_sel Ts)) selss0
    31     val ctrs = map (Ctr_Sugar.mk_ctr Ts) ctrs0
    32 
    33     fun mk_constr ctr sels =
    34       mk_selectors T (binder_types (fastype_of ctr)) sels #>> pair ctr
    35 
    36     val selss' =
    37       (if has_duplicates (op aconv) (flat selss) orelse
    38           exists (exists (can (dest_funT o range_type o fastype_of))) selss then
    39          []
    40        else
    41          selss)
    42       |> Ctr_Sugar_Util.pad_list [] (length ctrs)
    43   in
    44     @{fold_map 2} mk_constr ctrs selss' ctxt
    45     |>> (pair T #> single)
    46   end
    47 
    48 
    49 (* typedef declarations *)
    50 
    51 fun get_typedef_decl (({Abs_name, Rep_name, abs_type, rep_type, ...}, {Abs_inverse, ...})
    52     : Typedef.info) T Ts =
    53   if can (curry (op RS) @{thm UNIV_I}) Abs_inverse then
    54     let
    55       val env = snd (Term.dest_Type abs_type) ~~ Ts
    56       val instT = Term.map_atyps (perhaps (AList.lookup (op =) env))
    57 
    58       val constr = Const (Abs_name, instT (rep_type --> abs_type))
    59       val select = Const (Rep_name, instT (abs_type --> rep_type))
    60     in [(T, [(constr, [select])])] end
    61   else
    62     []
    63 
    64 
    65 (* collection of declarations *)
    66 
    67 val extN = "_ext" (* cf. "HOL/Tools/typedef.ML" *)
    68 
    69 fun get_decls fps T n Ts ctxt =
    70   let
    71     fun maybe_typedef () =
    72       (case Typedef.get_info ctxt n of
    73         [] => ([], ctxt)
    74       | info :: _ => (map (pair (hd fps)) (get_typedef_decl info T Ts), ctxt))
    75   in
    76     (case BNF_FP_Def_Sugar.fp_sugar_of ctxt n of
    77       SOME {fp, fp_res = {Ts = fp_Ts, ...}, fp_ctr_sugar = {ctr_sugar, ...}, ...} =>
    78       if member (op =) fps fp then
    79         let
    80           val ns = map (fst o dest_Type) fp_Ts
    81           val mutual_fp_sugars = map_filter (BNF_FP_Def_Sugar.fp_sugar_of ctxt) ns
    82           val Xs = map #X mutual_fp_sugars
    83           val ctrXs_Tsss = map (#ctrXs_Tss o #fp_ctr_sugar) mutual_fp_sugars
    84 
    85           (* Datatypes nested through datatypes and codatatypes nested through codatatypes are
    86              allowed. So are mutually (co)recursive (co)datatypes. *)
    87           fun is_same_fp s =
    88             (case BNF_FP_Def_Sugar.fp_sugar_of ctxt s of
    89               SOME {fp = fp', ...} => fp' = fp
    90             | NONE => false)
    91           fun is_homogenously_nested_co_recursive (Type (s, Ts)) =
    92               forall (if is_same_fp s then is_homogenously_nested_co_recursive
    93                 else not o BNF_FP_Rec_Sugar_Util.exists_subtype_in Xs) Ts
    94             | is_homogenously_nested_co_recursive _ = true
    95 
    96           val Type (_, As) :: _ = fp_Ts
    97           val substAs = Term.typ_subst_atomic (As ~~ Ts);
    98         in
    99           (* TODO/FIXME: The "bool" check is there to work around a CVC4 bug
   100              (http://church.cims.nyu.edu/bugzilla3/show_bug.cgi?id=597). It should be removed once
   101              the bug is fixed. *)
   102           if forall (forall (forall (is_homogenously_nested_co_recursive))) ctrXs_Tsss andalso
   103              forall (forall (forall (curry (op <>) @{typ bool})))
   104                (map (map (map substAs)) ctrXs_Tsss) then
   105             get_ctr_sugar_decl ctr_sugar T Ts ctxt |>> map (pair fp)
   106           else
   107             maybe_typedef ()
   108         end
   109       else
   110         ([], ctxt)
   111     | NONE =>
   112       if String.isSuffix extN n then
   113         (* for records (FIXME: hack) *)
   114         (case Ctr_Sugar.ctr_sugar_of ctxt n of
   115           SOME ctr_sugar =>
   116           get_ctr_sugar_decl ctr_sugar T Ts ctxt |>> map (pair (hd fps))
   117         | NONE => maybe_typedef ())
   118       else
   119         maybe_typedef ())
   120   end
   121 
   122 fun add_decls fps T (declss, ctxt) =
   123   let
   124     fun declared T = exists (exists (equal T o fst o snd))
   125     fun declared' T = exists (exists (equal T o fst o snd) o snd)
   126     fun depends ds = exists (member (op =) (map (fst o snd) ds))
   127 
   128     fun add (TFree _) = I
   129       | add (TVar _) = I
   130       | add (T as Type (@{type_name fun}, _)) =
   131           fold add (Term.body_type T :: Term.binder_types T)
   132       | add @{typ bool} = I
   133       | add (T as Type (n, Ts)) = (fn (dss, ctxt1) =>
   134           if declared T declss orelse declared' T dss then
   135             (dss, ctxt1)
   136           else if SMT_Builtin.is_builtin_typ_ext ctxt1 T then
   137             (dss, ctxt1)
   138           else
   139             (case get_decls fps T n Ts ctxt1 of
   140               ([], _) => (dss, ctxt1)
   141             | (ds, ctxt2) =>
   142                 let
   143                   val constrTs = maps (map (snd o Term.dest_Const o fst) o snd o snd) ds
   144                   val Us = fold (union (op =) o Term.binder_types) constrTs []
   145 
   146                   fun ins [] = [(Us, ds)]
   147                     | ins ((Uds as (Us', _)) :: Udss) =
   148                         if depends ds Us' then (Us, ds) :: Uds :: Udss else Uds :: ins Udss
   149             in fold add Us (ins dss, ctxt2) end))
   150   in add T ([], ctxt) |>> append declss o map snd end
   151 
   152 end;