interleave (co)datatypes in the right order w.r.t. dependencies
authorblanchet
Wed, 24 Sep 2014 15:46:25 +0200
changeset 58429 0b94858325a5
parent 58428 e4e34dfc3e68
child 58430 73df5884edcf
interleave (co)datatypes in the right order w.r.t. dependencies
src/HOL/Tools/SMT/smt_datatypes.ML
src/HOL/Tools/SMT/smt_translate.ML
src/HOL/Tools/SMT/smtlib_interface.ML
--- a/src/HOL/Tools/SMT/smt_datatypes.ML	Wed Sep 24 15:46:24 2014 +0200
+++ b/src/HOL/Tools/SMT/smt_datatypes.ML	Wed Sep 24 15:46:25 2014 +0200
@@ -7,9 +7,9 @@
 
 signature SMT_DATATYPES =
 sig
-  val add_decls: BNF_Util.fp_kind -> typ ->
-    (typ * (term * term list) list) list list * Proof.context ->
-    (typ * (term * term list) list) list list * Proof.context
+  val add_decls: BNF_Util.fp_kind list -> typ ->
+    (BNF_Util.fp_kind * (typ * (term * term list) list)) list list * Proof.context ->
+    (BNF_Util.fp_kind * (typ * (term * term list) list)) list list * Proof.context
 end;
 
 structure SMT_Datatypes: SMT_DATATYPES =
@@ -63,49 +63,47 @@
 
 val extN = "_ext" (* cf. "HOL/Tools/typedef.ML" *)
 
-fun get_decls fp T n Ts ctxt =
+fun get_decls fps T n Ts ctxt =
   let
     fun maybe_typedef () =
       (case Typedef.get_info ctxt n of
         [] => ([], ctxt)
-      | info :: _ => (get_typedef_decl info T Ts, ctxt))
+      | info :: _ => (map (pair (hd fps)) (get_typedef_decl info T Ts), ctxt))
   in
     (case BNF_FP_Def_Sugar.fp_sugar_of ctxt n of
-      SOME {fp = fp', fp_res = {Ts = fp_Ts, ...}, ctr_sugar, ...} =>
-      if fp' = fp then
+      SOME {fp, fp_res = {Ts = fp_Ts, ...}, ctr_sugar, ...} =>
+      if member (op =) fps fp then
         let
           val ns = map (fst o dest_Type) fp_Ts
           val mutual_fp_sugars = map_filter (BNF_FP_Def_Sugar.fp_sugar_of ctxt) ns
           val Xs = map #X mutual_fp_sugars
           val ctrXs_Tsss = map #ctrXs_Tss mutual_fp_sugars
 
-          fun is_nested_co_recursive (T as Type _) =
-              BNF_FP_Rec_Sugar_Util.exists_subtype_in Xs T
+          (* FIXME: allow nested recursion to same FP kind *)
+          fun is_nested_co_recursive (T as Type _) = BNF_FP_Rec_Sugar_Util.exists_subtype_in Xs T
             | is_nested_co_recursive _ = false
         in
           if exists (exists (exists is_nested_co_recursive)) ctrXs_Tsss then maybe_typedef ()
-          else get_ctr_sugar_decl ctr_sugar T Ts ctxt
+          else get_ctr_sugar_decl ctr_sugar T Ts ctxt |>> map (pair fp)
         end
       else
         ([], ctxt)
     | NONE =>
-      if fp = BNF_Util.Least_FP then
-        if String.isSuffix extN n then
-          (* for records (FIXME: hack) *)
-          (case Ctr_Sugar.ctr_sugar_of ctxt n of
-            SOME ctr_sugar => get_ctr_sugar_decl ctr_sugar T Ts ctxt
-          | NONE => maybe_typedef ())
-        else
-          maybe_typedef ()
+      if String.isSuffix extN n then
+        (* for records (FIXME: hack) *)
+        (case Ctr_Sugar.ctr_sugar_of ctxt n of
+          SOME ctr_sugar =>
+          get_ctr_sugar_decl ctr_sugar T Ts ctxt |>> map (pair (hd fps))
+        | NONE => maybe_typedef ())
       else
-        ([], ctxt))
+        maybe_typedef ())
   end
 
-fun add_decls fp T (declss, ctxt) =
+fun add_decls fps T (declss, ctxt) =
   let
-    fun declared T = exists (exists (equal T o fst))
-    fun declared' T = exists (exists (equal T o fst) o snd)
-    fun depends ds = exists (member (op =) (map fst ds))
+    fun declared T = exists (exists (equal T o fst o snd))
+    fun declared' T = exists (exists (equal T o fst o snd) o snd)
+    fun depends ds = exists (member (op =) (map (fst o snd) ds))
 
     fun add (TFree _) = I
       | add (TVar _) = I
@@ -113,14 +111,16 @@
           fold add (Term.body_type T :: Term.binder_types T)
       | add @{typ bool} = I
       | add (T as Type (n, Ts)) = (fn (dss, ctxt1) =>
-          if declared T declss orelse declared' T dss then (dss, ctxt1)
-          else if SMT_Builtin.is_builtin_typ_ext ctxt1 T then (dss, ctxt1)
+          if declared T declss orelse declared' T dss then
+            (dss, ctxt1)
+          else if SMT_Builtin.is_builtin_typ_ext ctxt1 T then
+            (dss, ctxt1)
           else
-            (case get_decls fp T n Ts ctxt1 of
+            (case get_decls fps T n Ts ctxt1 of
               ([], _) => (dss, ctxt1)
             | (ds, ctxt2) =>
                 let
-                  val constrTs = maps (map (snd o Term.dest_Const o fst) o snd) ds
+                  val constrTs = maps (map (snd o Term.dest_Const o fst) o snd o snd) ds
                   val Us = fold (union (op =) o Term.binder_types) constrTs []
 
                   fun ins [] = [(Us, ds)]
--- a/src/HOL/Tools/SMT/smt_translate.ML	Wed Sep 24 15:46:24 2014 +0200
+++ b/src/HOL/Tools/SMT/smt_translate.ML	Wed Sep 24 15:46:25 2014 +0200
@@ -19,8 +19,7 @@
   type sign = {
     logic: string,
     sorts: string list,
-    lfp_dtyps: (string * (string * (string * string) list) list) list list,
-    gfp_dtyps: (string * (string * (string * string) list) list) list list,
+    dtyps: (BNF_Util.fp_kind * (string * (string * (string * string) list) list)) list,
     funcs: (string * (string list * string)) list }
   type config = {
     logic: term list -> string,
@@ -62,8 +61,7 @@
 type sign = {
   logic: string,
   sorts: string list,
-  lfp_dtyps: (string * (string * (string * string) list) list) list list,
-  gfp_dtyps: (string * (string * (string * string) list) list) list list,
+  dtyps: (BNF_Util.fp_kind * (string * (string * (string * string) list) list)) list,
   funcs: (string * (string list * string)) list }
 
 type config = {
@@ -116,11 +114,10 @@
         val terms' = Termtab.update (t, (name, sort)) terms
       in (name, (names', typs, terms')) end)
 
-fun sign_of logic lfp_dtyps gfp_dtyps (_, typs, terms) = {
+fun sign_of logic dtyps (_, typs, terms) = {
   logic = logic,
   sorts = Typtab.fold (fn (_, (n, true)) => cons n | _ => I) typs [],
-  lfp_dtyps = lfp_dtyps,
-  gfp_dtyps = gfp_dtyps,
+  dtyps = dtyps,
   funcs = Termtab.fold (fn (_, (n, SOME ss)) => cons (n,ss) | _ => I) terms []}
 
 fun replay_data_of ctxt ll_defs rules assms (_, typs, terms) =
@@ -142,21 +139,15 @@
 
 fun collect_co_datatypes fp_kinds (tr_context, ctxt) ts =
   let
-    val (lfp_declss, ctxt') =
+    val (fp_decls, ctxt') =
       ([], ctxt)
-      |> member (op =) fp_kinds BNF_Util.Least_FP
-        ? fold (Term.fold_types (SMT_Datatypes.add_decls BNF_Util.Least_FP)) ts
-    val (gfp_declss, ctxt'') =
-      ([], ctxt')
-      |> member (op =) fp_kinds BNF_Util.Greatest_FP
-        ? fold (Term.fold_types (SMT_Datatypes.add_decls BNF_Util.Greatest_FP)) ts
+      |> fold (Term.fold_types (SMT_Datatypes.add_decls fp_kinds)) ts
+      |>> flat
 
-    val fp_declsss = [lfp_declss, gfp_declss]
-
-    fun is_decl_typ T = exists (exists (exists (equal T o fst))) fp_declsss
+    fun is_decl_typ T = exists (equal T o fst o snd) fp_decls
 
     fun add_typ' T proper =
-      (case SMT_Builtin.dest_builtin_typ ctxt'' T of
+      (case SMT_Builtin.dest_builtin_typ ctxt' T of
         SOME n => pair n
       | NONE => add_typ T proper)
 
@@ -165,17 +156,18 @@
       in add_fun sel NONE ##>> add_typ' T (not (is_decl_typ T)) end
     fun tr_constr (constr, selects) =
       add_fun constr NONE ##>> fold_map tr_select selects
-    fun tr_typ (T, cases) = add_typ' T false ##>> fold_map tr_constr cases
+    fun tr_typ (fp, (T, cases)) =
+      add_typ' T false ##>> fold_map tr_constr cases #>> pair fp
 
-    val (lfp_declss', tr_context') = fold_map (fold_map tr_typ) lfp_declss tr_context
-    val (gfp_declss', tr_context'') = fold_map (fold_map tr_typ) gfp_declss tr_context'
+    val (fp_decls', tr_context') = fold_map tr_typ fp_decls tr_context
 
     fun add (constr, selects) =
       Termtab.update (constr, length selects) #>
       fold (Termtab.update o rpair 1) selects
 
-    val funcs = fold (fold (fold (fold add o snd))) fp_declsss Termtab.empty
-  in ((funcs, lfp_declss', gfp_declss', tr_context'', ctxt''), ts) end
+    val funcs = fold (fold add o snd o snd) fp_decls Termtab.empty
+
+  in ((funcs, fp_decls', tr_context', ctxt'), ts) end
     (* FIXME: also return necessary (co)datatype theorems *)
 
 
@@ -430,7 +422,7 @@
 
 (** translation from Isabelle terms into SMT intermediate terms **)
 
-fun intermediate logic lfp_dtyps gfp_dtyps builtin ctxt ts trx =
+fun intermediate logic dtyps builtin ctxt ts trx =
   let
     fun transT (T as TFree _) = add_typ T true
       | transT (T as TVar _) = (fn _ => raise TYPE ("bad SMT type", [T], []))
@@ -467,7 +459,7 @@
       end
 
     val (us, trx') = fold_map trans ts trx
-  in ((sign_of (logic ts) lfp_dtyps gfp_dtyps trx', us), trx') end
+  in ((sign_of (logic ts) dtyps trx', us), trx') end
 
 
 (* translation *)
@@ -496,11 +488,11 @@
     val {logic, fp_kinds, serialize} = get_config ctxt
 
     fun no_dtyps (tr_context, ctxt) ts =
-      ((Termtab.empty, [], [], tr_context, ctxt), ts)
+      ((Termtab.empty, [], tr_context, ctxt), ts)
 
     val ts1 = map (Envir.beta_eta_contract o SMT_Util.prop_of o snd) ithms
 
-    val ((funcs, lfp_dtyps, gfp_dtyps, tr_context, ctxt1), ts2) =
+    val ((funcs, dtyps, tr_context, ctxt1), ts2) =
       ((empty_tr_context, ctxt), ts1)
       |-> (if null fp_kinds then no_dtyps else collect_co_datatypes fp_kinds)
 
@@ -527,9 +519,10 @@
 
     val ((rewrite_rules, builtin), ts4) = folify ctxt2 ts3
       |>> apfst (cons fun_app_eq)
+val _ = dtyps : (BNF_Util.fp_kind * (string * (string * (string * string) list) list)) list (*###*)
   in
     (ts4, tr_context)
-    |-> intermediate logic lfp_dtyps gfp_dtyps (builtin SMT_Builtin.dest_builtin) ctxt2
+    |-> intermediate logic dtyps (builtin SMT_Builtin.dest_builtin) ctxt2
     |>> uncurry (serialize smt_options comments)
     ||> replay_data_of ctxt2 ll_defs rewrite_rules ithms
   end
--- a/src/HOL/Tools/SMT/smtlib_interface.ML	Wed Sep 24 15:46:24 2014 +0200
+++ b/src/HOL/Tools/SMT/smtlib_interface.ML	Wed Sep 24 15:46:25 2014 +0200
@@ -133,18 +133,17 @@
 fun assert_name_of_index i = assert_prefix ^ string_of_int i
 fun assert_index_of_name s = the_default ~1 (Int.fromString (unprefix assert_prefix s))
 
-fun serialize smt_options comments {logic, sorts, lfp_dtyps, gfp_dtyps, funcs} ts =
+fun sdtyp (fp, dtyps) =
+  Buffer.add (enclose ("(declare-" ^ BNF_FP_Util.co_prefix fp ^ "datatypes () (") "))\n"
+    (space_implode "\n  " (map sdatatype dtyps)))
+
+fun serialize smt_options comments {logic, sorts, dtyps, funcs} ts =
   Buffer.empty
   |> fold (Buffer.add o enclose "; " "\n") comments
   |> fold (fn (k, v) => Buffer.add ("(set-option " ^ k ^ " " ^ v ^ ")\n")) smt_options
   |> Buffer.add logic
   |> fold (Buffer.add o enclose "(declare-sort " " 0)\n") (sort fast_string_ord sorts)
-  |> (if null lfp_dtyps then I
-    else Buffer.add (enclose "(declare-datatypes () (" "))\n"
-      (space_implode "\n  " (maps (map sdatatype) lfp_dtyps))))
-  |> (if null gfp_dtyps then I
-    else Buffer.add (enclose "(declare-codatatypes () (" "))\n"
-      (space_implode "\n  " (maps (map sdatatype) gfp_dtyps))))
+  |> fold sdtyp (AList.coalesce (op =) dtyps)
   |> fold (Buffer.add o enclose "(declare-fun " ")\n" o string_of_fun)
       (sort (fast_string_ord o pairself fst) funcs)
   |> fold (fn (i, t) => Buffer.add (enclose "(assert " ")\n"