src/HOL/Tools/SMT/smt_translate.ML
changeset 58361 7f2b3b6f6ad1
parent 58360 dee1fd1cc631
child 58429 0b94858325a5
--- a/src/HOL/Tools/SMT/smt_translate.ML	Wed Sep 17 16:53:39 2014 +0200
+++ b/src/HOL/Tools/SMT/smt_translate.ML	Wed Sep 17 17:32:27 2014 +0200
@@ -19,7 +19,8 @@
   type sign = {
     logic: string,
     sorts: string list,
-    dtyps: (string * (string * (string * string) list) list) list list,
+    lfp_dtyps: (string * (string * (string * string) list) list) list list,
+    gfp_dtyps: (string * (string * (string * string) list) list) list list,
     funcs: (string * (string list * string)) list }
   type config = {
     logic: term list -> string,
@@ -61,7 +62,8 @@
 type sign = {
   logic: string,
   sorts: string list,
-  dtyps: (string * (string * (string * string) list) list) list list,
+  lfp_dtyps: (string * (string * (string * string) list) list) list list,
+  gfp_dtyps: (string * (string * (string * string) list) list) list list,
   funcs: (string * (string list * string)) list }
 
 type config = {
@@ -114,10 +116,11 @@
         val terms' = Termtab.update (t, (name, sort)) terms
       in (name, (names', typs, terms')) end)
 
-fun sign_of logic dtyps (_, typs, terms) = {
+fun sign_of logic lfp_dtyps gfp_dtyps (_, typs, terms) = {
   logic = logic,
   sorts = Typtab.fold (fn (_, (n, true)) => cons n | _ => I) typs [],
-  dtyps = dtyps,
+  lfp_dtyps = lfp_dtyps,
+  gfp_dtyps = gfp_dtyps,
   funcs = Termtab.fold (fn (_, (n, SOME ss)) => cons (n,ss) | _ => I) terms []}
 
 fun replay_data_of ctxt ll_defs rules assms (_, typs, terms) =
@@ -135,17 +138,25 @@
 
 (* preprocessing *)
 
-(** datatype declarations **)
+(** (co)datatype declarations **)
 
-fun collect_datatypes_and_records (tr_context, ctxt) ts =
+fun collect_co_datatypes fp_kinds (tr_context, ctxt) ts =
   let
-    val (declss, ctxt') =
-      fold (Term.fold_types (SMT_Datatypes.add_decls BNF_Util.Least_FP)) ts ([], ctxt)
+    val (lfp_declss, ctxt') =
+      ([], ctxt)
+      |> member (op =) fp_kinds BNF_Util.Least_FP
+        ? fold (Term.fold_types (SMT_Datatypes.add_decls BNF_Util.Least_FP)) ts
+    val (gfp_declss, ctxt'') =
+      ([], ctxt')
+      |> member (op =) fp_kinds BNF_Util.Greatest_FP
+        ? fold (Term.fold_types (SMT_Datatypes.add_decls BNF_Util.Greatest_FP)) ts
 
-    fun is_decl_typ T = exists (exists (equal T o fst)) declss
+    val fp_declsss = [lfp_declss, gfp_declss]
+
+    fun is_decl_typ T = exists (exists (exists (equal T o fst))) fp_declsss
 
     fun add_typ' T proper =
-      (case SMT_Builtin.dest_builtin_typ ctxt' T of
+      (case SMT_Builtin.dest_builtin_typ ctxt'' T of
         SOME n => pair n
       | NONE => add_typ T proper)
 
@@ -155,14 +166,17 @@
     fun tr_constr (constr, selects) =
       add_fun constr NONE ##>> fold_map tr_select selects
     fun tr_typ (T, cases) = add_typ' T false ##>> fold_map tr_constr cases
-    val (declss', tr_context') = fold_map (fold_map tr_typ) declss tr_context
+
+    val (lfp_declss', tr_context') = fold_map (fold_map tr_typ) lfp_declss tr_context
+    val (gfp_declss', tr_context'') = fold_map (fold_map tr_typ) gfp_declss tr_context'
 
     fun add (constr, selects) =
       Termtab.update (constr, length selects) #>
       fold (Termtab.update o rpair 1) selects
-    val funcs = fold (fold (fold add o snd)) declss Termtab.empty
-  in ((funcs, declss', tr_context', ctxt'), ts) end
-    (* FIXME: also return necessary datatype and record theorems *)
+
+    val funcs = fold (fold (fold (fold add o snd))) fp_declsss Termtab.empty
+  in ((funcs, lfp_declss', gfp_declss', tr_context'', ctxt''), ts) end
+    (* FIXME: also return necessary (co)datatype theorems *)
 
 
 (** eta-expand quantifiers, let expressions and built-ins *)
@@ -416,7 +430,7 @@
 
 (** translation from Isabelle terms into SMT intermediate terms **)
 
-fun intermediate logic dtyps builtin ctxt ts trx =
+fun intermediate logic lfp_dtyps gfp_dtyps builtin ctxt ts trx =
   let
     fun transT (T as TFree _) = add_typ T true
       | transT (T as TVar _) = (fn _ => raise TYPE ("bad SMT type", [T], []))
@@ -453,7 +467,7 @@
       end
 
     val (us, trx') = fold_map trans ts trx
-  in ((sign_of (logic ts) dtyps trx', us), trx') end
+  in ((sign_of (logic ts) lfp_dtyps gfp_dtyps trx', us), trx') end
 
 
 (* translation *)
@@ -482,14 +496,13 @@
     val {logic, fp_kinds, serialize} = get_config ctxt
 
     fun no_dtyps (tr_context, ctxt) ts =
-      ((Termtab.empty, [], tr_context, ctxt), ts)
+      ((Termtab.empty, [], [], tr_context, ctxt), ts)
 
     val ts1 = map (Envir.beta_eta_contract o SMT_Util.prop_of o snd) ithms
 
-    val ((funcs, dtyps, tr_context, ctxt1), ts2) =
+    val ((funcs, lfp_dtyps, gfp_dtyps, tr_context, ctxt1), ts2) =
       ((empty_tr_context, ctxt), ts1)
-      |-> (if member (op =) fp_kinds BNF_Util.Least_FP then collect_datatypes_and_records
-        else no_dtyps)
+      |-> (if null fp_kinds then no_dtyps else collect_co_datatypes fp_kinds)
 
     fun is_binder (Const (@{const_name Let}, _) $ _) = true
       | is_binder t = Lambda_Lifting.is_quantifier t
@@ -516,7 +529,7 @@
       |>> apfst (cons fun_app_eq)
   in
     (ts4, tr_context)
-    |-> intermediate logic dtyps (builtin SMT_Builtin.dest_builtin) ctxt2
+    |-> intermediate logic lfp_dtyps gfp_dtyps (builtin SMT_Builtin.dest_builtin) ctxt2
     |>> uncurry (serialize smt_options comments)
     ||> replay_data_of ctxt2 ll_defs rewrite_rules ithms
   end