--- a/src/HOL/Tools/SMT/smt_datatypes.ML Wed Sep 17 16:53:39 2014 +0200
+++ b/src/HOL/Tools/SMT/smt_datatypes.ML Wed Sep 17 17:32:27 2014 +0200
@@ -2,7 +2,7 @@
Author: Sascha Boehme, TU Muenchen
Collector functions for common type declarations and their representation
-as algebraic datatypes.
+as (co)algebraic datatypes.
*)
signature SMT_DATATYPES =
@@ -55,13 +55,25 @@
fun declared declss T = exists (exists (equal T o fst)) declss
fun declared' dss T = exists (exists (equal T o fst) o snd) dss
-fun get_decls T n Ts ctxt =
- (case Ctr_Sugar.ctr_sugar_of ctxt n of
- SOME ctr_sugar => get_ctr_sugar_decl ctr_sugar T Ts ctxt
- | NONE =>
+(* Simplification: We assume that every type that is not a codatatype is a datatype (or a
+ record). *)
+fun fp_kind_of ctxt n =
+ (case BNF_FP_Def_Sugar.fp_sugar_of ctxt n of
+ SOME {fp, ...} => fp
+ | NONE => BNF_Util.Least_FP)
+
+fun get_decls fp T n Ts ctxt =
+ let
+ fun fallback () =
(case Typedef.get_info ctxt n of
[] => ([], ctxt)
- | info :: _ => (get_typedef_decl info T Ts, ctxt)))
+ | info :: _ => (get_typedef_decl info T Ts, ctxt))
+ in
+ (case Ctr_Sugar.ctr_sugar_of ctxt n of
+ SOME ctr_sugar =>
+ if fp_kind_of ctxt n = fp then get_ctr_sugar_decl ctr_sugar T Ts ctxt else fallback ()
+ | NONE => fallback ())
+ end
fun add_decls fp T (declss, ctxt) =
let
@@ -76,7 +88,7 @@
if declared declss T orelse declared' dss T then (dss, ctxt1)
else if SMT_Builtin.is_builtin_typ_ext ctxt1 T then (dss, ctxt1)
else
- (case get_decls T n Ts ctxt1 of
+ (case get_decls fp T n Ts ctxt1 of
([], _) => (dss, ctxt1)
| (ds, ctxt2) =>
let
--- a/src/HOL/Tools/SMT/smt_translate.ML Wed Sep 17 16:53:39 2014 +0200
+++ b/src/HOL/Tools/SMT/smt_translate.ML Wed Sep 17 17:32:27 2014 +0200
@@ -19,7 +19,8 @@
type sign = {
logic: string,
sorts: string list,
- dtyps: (string * (string * (string * string) list) list) list list,
+ lfp_dtyps: (string * (string * (string * string) list) list) list list,
+ gfp_dtyps: (string * (string * (string * string) list) list) list list,
funcs: (string * (string list * string)) list }
type config = {
logic: term list -> string,
@@ -61,7 +62,8 @@
type sign = {
logic: string,
sorts: string list,
- dtyps: (string * (string * (string * string) list) list) list list,
+ lfp_dtyps: (string * (string * (string * string) list) list) list list,
+ gfp_dtyps: (string * (string * (string * string) list) list) list list,
funcs: (string * (string list * string)) list }
type config = {
@@ -114,10 +116,11 @@
val terms' = Termtab.update (t, (name, sort)) terms
in (name, (names', typs, terms')) end)
-fun sign_of logic dtyps (_, typs, terms) = {
+fun sign_of logic lfp_dtyps gfp_dtyps (_, typs, terms) = {
logic = logic,
sorts = Typtab.fold (fn (_, (n, true)) => cons n | _ => I) typs [],
- dtyps = dtyps,
+ lfp_dtyps = lfp_dtyps,
+ gfp_dtyps = gfp_dtyps,
funcs = Termtab.fold (fn (_, (n, SOME ss)) => cons (n,ss) | _ => I) terms []}
fun replay_data_of ctxt ll_defs rules assms (_, typs, terms) =
@@ -135,17 +138,25 @@
(* preprocessing *)
-(** datatype declarations **)
+(** (co)datatype declarations **)
-fun collect_datatypes_and_records (tr_context, ctxt) ts =
+fun collect_co_datatypes fp_kinds (tr_context, ctxt) ts =
let
- val (declss, ctxt') =
- fold (Term.fold_types (SMT_Datatypes.add_decls BNF_Util.Least_FP)) ts ([], ctxt)
+ val (lfp_declss, ctxt') =
+ ([], ctxt)
+ |> member (op =) fp_kinds BNF_Util.Least_FP
+ ? fold (Term.fold_types (SMT_Datatypes.add_decls BNF_Util.Least_FP)) ts
+ val (gfp_declss, ctxt'') =
+ ([], ctxt')
+ |> member (op =) fp_kinds BNF_Util.Greatest_FP
+ ? fold (Term.fold_types (SMT_Datatypes.add_decls BNF_Util.Greatest_FP)) ts
- fun is_decl_typ T = exists (exists (equal T o fst)) declss
+ val fp_declsss = [lfp_declss, gfp_declss]
+
+ fun is_decl_typ T = exists (exists (exists (equal T o fst))) fp_declsss
fun add_typ' T proper =
- (case SMT_Builtin.dest_builtin_typ ctxt' T of
+ (case SMT_Builtin.dest_builtin_typ ctxt'' T of
SOME n => pair n
| NONE => add_typ T proper)
@@ -155,14 +166,17 @@
fun tr_constr (constr, selects) =
add_fun constr NONE ##>> fold_map tr_select selects
fun tr_typ (T, cases) = add_typ' T false ##>> fold_map tr_constr cases
- val (declss', tr_context') = fold_map (fold_map tr_typ) declss tr_context
+
+ val (lfp_declss', tr_context') = fold_map (fold_map tr_typ) lfp_declss tr_context
+ val (gfp_declss', tr_context'') = fold_map (fold_map tr_typ) gfp_declss tr_context'
fun add (constr, selects) =
Termtab.update (constr, length selects) #>
fold (Termtab.update o rpair 1) selects
- val funcs = fold (fold (fold add o snd)) declss Termtab.empty
- in ((funcs, declss', tr_context', ctxt'), ts) end
- (* FIXME: also return necessary datatype and record theorems *)
+
+ val funcs = fold (fold (fold (fold add o snd))) fp_declsss Termtab.empty
+ in ((funcs, lfp_declss', gfp_declss', tr_context'', ctxt''), ts) end
+ (* FIXME: also return necessary (co)datatype theorems *)
(** eta-expand quantifiers, let expressions and built-ins *)
@@ -416,7 +430,7 @@
(** translation from Isabelle terms into SMT intermediate terms **)
-fun intermediate logic dtyps builtin ctxt ts trx =
+fun intermediate logic lfp_dtyps gfp_dtyps builtin ctxt ts trx =
let
fun transT (T as TFree _) = add_typ T true
| transT (T as TVar _) = (fn _ => raise TYPE ("bad SMT type", [T], []))
@@ -453,7 +467,7 @@
end
val (us, trx') = fold_map trans ts trx
- in ((sign_of (logic ts) dtyps trx', us), trx') end
+ in ((sign_of (logic ts) lfp_dtyps gfp_dtyps trx', us), trx') end
(* translation *)
@@ -482,14 +496,13 @@
val {logic, fp_kinds, serialize} = get_config ctxt
fun no_dtyps (tr_context, ctxt) ts =
- ((Termtab.empty, [], tr_context, ctxt), ts)
+ ((Termtab.empty, [], [], tr_context, ctxt), ts)
val ts1 = map (Envir.beta_eta_contract o SMT_Util.prop_of o snd) ithms
- val ((funcs, dtyps, tr_context, ctxt1), ts2) =
+ val ((funcs, lfp_dtyps, gfp_dtyps, tr_context, ctxt1), ts2) =
((empty_tr_context, ctxt), ts1)
- |-> (if member (op =) fp_kinds BNF_Util.Least_FP then collect_datatypes_and_records
- else no_dtyps)
+ |-> (if null fp_kinds then no_dtyps else collect_co_datatypes fp_kinds)
fun is_binder (Const (@{const_name Let}, _) $ _) = true
| is_binder t = Lambda_Lifting.is_quantifier t
@@ -516,7 +529,7 @@
|>> apfst (cons fun_app_eq)
in
(ts4, tr_context)
- |-> intermediate logic dtyps (builtin SMT_Builtin.dest_builtin) ctxt2
+ |-> intermediate logic lfp_dtyps gfp_dtyps (builtin SMT_Builtin.dest_builtin) ctxt2
|>> uncurry (serialize smt_options comments)
||> replay_data_of ctxt2 ll_defs rewrite_rules ithms
end