added codatatype support for CVC4
authorblanchet
Wed, 17 Sep 2014 17:32:27 +0200
changeset 58361 7f2b3b6f6ad1
parent 58360 dee1fd1cc631
child 58362 cf32eb8001b8
added codatatype support for CVC4
src/HOL/Tools/SMT/smt_datatypes.ML
src/HOL/Tools/SMT/smt_translate.ML
src/HOL/Tools/SMT/smtlib_interface.ML
--- a/src/HOL/Tools/SMT/smt_datatypes.ML	Wed Sep 17 16:53:39 2014 +0200
+++ b/src/HOL/Tools/SMT/smt_datatypes.ML	Wed Sep 17 17:32:27 2014 +0200
@@ -2,7 +2,7 @@
     Author:     Sascha Boehme, TU Muenchen
 
 Collector functions for common type declarations and their representation
-as algebraic datatypes.
+as (co)algebraic datatypes.
 *)
 
 signature SMT_DATATYPES =
@@ -55,13 +55,25 @@
 fun declared declss T = exists (exists (equal T o fst)) declss
 fun declared' dss T = exists (exists (equal T o fst) o snd) dss
 
-fun get_decls T n Ts ctxt =
-  (case Ctr_Sugar.ctr_sugar_of ctxt n of
-    SOME ctr_sugar => get_ctr_sugar_decl ctr_sugar T Ts ctxt
-  | NONE =>
+(* Simplification: We assume that every type that is not a codatatype is a datatype (or a
+   record). *)
+fun fp_kind_of ctxt n =
+  (case BNF_FP_Def_Sugar.fp_sugar_of ctxt n of
+    SOME {fp, ...} => fp
+  | NONE => BNF_Util.Least_FP)
+
+fun get_decls fp T n Ts ctxt =
+  let
+    fun fallback () =
       (case Typedef.get_info ctxt n of
         [] => ([], ctxt)
-      | info :: _ => (get_typedef_decl info T Ts, ctxt)))
+      | info :: _ => (get_typedef_decl info T Ts, ctxt))
+  in
+    (case Ctr_Sugar.ctr_sugar_of ctxt n of
+      SOME ctr_sugar =>
+      if fp_kind_of ctxt n = fp then get_ctr_sugar_decl ctr_sugar T Ts ctxt else fallback ()
+    | NONE => fallback ())
+  end
 
 fun add_decls fp T (declss, ctxt) =
   let
@@ -76,7 +88,7 @@
           if declared declss T orelse declared' dss T then (dss, ctxt1)
           else if SMT_Builtin.is_builtin_typ_ext ctxt1 T then (dss, ctxt1)
           else
-            (case get_decls T n Ts ctxt1 of
+            (case get_decls fp T n Ts ctxt1 of
               ([], _) => (dss, ctxt1)
             | (ds, ctxt2) =>
                 let
--- a/src/HOL/Tools/SMT/smt_translate.ML	Wed Sep 17 16:53:39 2014 +0200
+++ b/src/HOL/Tools/SMT/smt_translate.ML	Wed Sep 17 17:32:27 2014 +0200
@@ -19,7 +19,8 @@
   type sign = {
     logic: string,
     sorts: string list,
-    dtyps: (string * (string * (string * string) list) list) list list,
+    lfp_dtyps: (string * (string * (string * string) list) list) list list,
+    gfp_dtyps: (string * (string * (string * string) list) list) list list,
     funcs: (string * (string list * string)) list }
   type config = {
     logic: term list -> string,
@@ -61,7 +62,8 @@
 type sign = {
   logic: string,
   sorts: string list,
-  dtyps: (string * (string * (string * string) list) list) list list,
+  lfp_dtyps: (string * (string * (string * string) list) list) list list,
+  gfp_dtyps: (string * (string * (string * string) list) list) list list,
   funcs: (string * (string list * string)) list }
 
 type config = {
@@ -114,10 +116,11 @@
         val terms' = Termtab.update (t, (name, sort)) terms
       in (name, (names', typs, terms')) end)
 
-fun sign_of logic dtyps (_, typs, terms) = {
+fun sign_of logic lfp_dtyps gfp_dtyps (_, typs, terms) = {
   logic = logic,
   sorts = Typtab.fold (fn (_, (n, true)) => cons n | _ => I) typs [],
-  dtyps = dtyps,
+  lfp_dtyps = lfp_dtyps,
+  gfp_dtyps = gfp_dtyps,
   funcs = Termtab.fold (fn (_, (n, SOME ss)) => cons (n,ss) | _ => I) terms []}
 
 fun replay_data_of ctxt ll_defs rules assms (_, typs, terms) =
@@ -135,17 +138,25 @@
 
 (* preprocessing *)
 
-(** datatype declarations **)
+(** (co)datatype declarations **)
 
-fun collect_datatypes_and_records (tr_context, ctxt) ts =
+fun collect_co_datatypes fp_kinds (tr_context, ctxt) ts =
   let
-    val (declss, ctxt') =
-      fold (Term.fold_types (SMT_Datatypes.add_decls BNF_Util.Least_FP)) ts ([], ctxt)
+    val (lfp_declss, ctxt') =
+      ([], ctxt)
+      |> member (op =) fp_kinds BNF_Util.Least_FP
+        ? fold (Term.fold_types (SMT_Datatypes.add_decls BNF_Util.Least_FP)) ts
+    val (gfp_declss, ctxt'') =
+      ([], ctxt')
+      |> member (op =) fp_kinds BNF_Util.Greatest_FP
+        ? fold (Term.fold_types (SMT_Datatypes.add_decls BNF_Util.Greatest_FP)) ts
 
-    fun is_decl_typ T = exists (exists (equal T o fst)) declss
+    val fp_declsss = [lfp_declss, gfp_declss]
+
+    fun is_decl_typ T = exists (exists (exists (equal T o fst))) fp_declsss
 
     fun add_typ' T proper =
-      (case SMT_Builtin.dest_builtin_typ ctxt' T of
+      (case SMT_Builtin.dest_builtin_typ ctxt'' T of
         SOME n => pair n
       | NONE => add_typ T proper)
 
@@ -155,14 +166,17 @@
     fun tr_constr (constr, selects) =
       add_fun constr NONE ##>> fold_map tr_select selects
     fun tr_typ (T, cases) = add_typ' T false ##>> fold_map tr_constr cases
-    val (declss', tr_context') = fold_map (fold_map tr_typ) declss tr_context
+
+    val (lfp_declss', tr_context') = fold_map (fold_map tr_typ) lfp_declss tr_context
+    val (gfp_declss', tr_context'') = fold_map (fold_map tr_typ) gfp_declss tr_context'
 
     fun add (constr, selects) =
       Termtab.update (constr, length selects) #>
       fold (Termtab.update o rpair 1) selects
-    val funcs = fold (fold (fold add o snd)) declss Termtab.empty
-  in ((funcs, declss', tr_context', ctxt'), ts) end
-    (* FIXME: also return necessary datatype and record theorems *)
+
+    val funcs = fold (fold (fold (fold add o snd))) fp_declsss Termtab.empty
+  in ((funcs, lfp_declss', gfp_declss', tr_context'', ctxt''), ts) end
+    (* FIXME: also return necessary (co)datatype theorems *)
 
 
 (** eta-expand quantifiers, let expressions and built-ins *)
@@ -416,7 +430,7 @@
 
 (** translation from Isabelle terms into SMT intermediate terms **)
 
-fun intermediate logic dtyps builtin ctxt ts trx =
+fun intermediate logic lfp_dtyps gfp_dtyps builtin ctxt ts trx =
   let
     fun transT (T as TFree _) = add_typ T true
       | transT (T as TVar _) = (fn _ => raise TYPE ("bad SMT type", [T], []))
@@ -453,7 +467,7 @@
       end
 
     val (us, trx') = fold_map trans ts trx
-  in ((sign_of (logic ts) dtyps trx', us), trx') end
+  in ((sign_of (logic ts) lfp_dtyps gfp_dtyps trx', us), trx') end
 
 
 (* translation *)
@@ -482,14 +496,13 @@
     val {logic, fp_kinds, serialize} = get_config ctxt
 
     fun no_dtyps (tr_context, ctxt) ts =
-      ((Termtab.empty, [], tr_context, ctxt), ts)
+      ((Termtab.empty, [], [], tr_context, ctxt), ts)
 
     val ts1 = map (Envir.beta_eta_contract o SMT_Util.prop_of o snd) ithms
 
-    val ((funcs, dtyps, tr_context, ctxt1), ts2) =
+    val ((funcs, lfp_dtyps, gfp_dtyps, tr_context, ctxt1), ts2) =
       ((empty_tr_context, ctxt), ts1)
-      |-> (if member (op =) fp_kinds BNF_Util.Least_FP then collect_datatypes_and_records
-        else no_dtyps)
+      |-> (if null fp_kinds then no_dtyps else collect_co_datatypes fp_kinds)
 
     fun is_binder (Const (@{const_name Let}, _) $ _) = true
       | is_binder t = Lambda_Lifting.is_quantifier t
@@ -516,7 +529,7 @@
       |>> apfst (cons fun_app_eq)
   in
     (ts4, tr_context)
-    |-> intermediate logic dtyps (builtin SMT_Builtin.dest_builtin) ctxt2
+    |-> intermediate logic lfp_dtyps gfp_dtyps (builtin SMT_Builtin.dest_builtin) ctxt2
     |>> uncurry (serialize smt_options comments)
     ||> replay_data_of ctxt2 ll_defs rewrite_rules ithms
   end
--- a/src/HOL/Tools/SMT/smtlib_interface.ML	Wed Sep 17 16:53:39 2014 +0200
+++ b/src/HOL/Tools/SMT/smtlib_interface.ML	Wed Sep 17 17:32:27 2014 +0200
@@ -133,15 +133,18 @@
 fun assert_name_of_index i = assert_prefix ^ string_of_int i
 fun assert_index_of_name s = the_default ~1 (Int.fromString (unprefix assert_prefix s))
 
-fun serialize smt_options comments {logic, sorts, dtyps, funcs} ts =
+fun serialize smt_options comments {logic, sorts, lfp_dtyps, gfp_dtyps, funcs} ts =
   Buffer.empty
   |> fold (Buffer.add o enclose "; " "\n") comments
   |> fold (fn (k, v) => Buffer.add ("(set-option " ^ k ^ " " ^ v ^ ")\n")) smt_options
   |> Buffer.add logic
   |> fold (Buffer.add o enclose "(declare-sort " " 0)\n") (sort fast_string_ord sorts)
-  |> (if null dtyps then I
+  |> (if null lfp_dtyps then I
     else Buffer.add (enclose "(declare-datatypes () (" "))\n"
-      (space_implode "\n  " (maps (map sdatatype) dtyps))))
+      (space_implode "\n  " (maps (map sdatatype) lfp_dtyps))))
+  |> (if null gfp_dtyps then I
+    else Buffer.add (enclose "(declare-codatatypes () (" "))\n"
+      (space_implode "\n  " (maps (map sdatatype) gfp_dtyps))))
   |> fold (Buffer.add o enclose "(declare-fun " ")\n" o string_of_fun)
       (sort (fast_string_ord o pairself fst) funcs)
   |> fold (fn (i, t) => Buffer.add (enclose "(assert " ")\n"