src/HOL/Tools/Nitpick/nitpick_kodkod.ML
changeset 38126 8031d099379a
parent 38124 6538e25cf5dd
child 38127 9f9f696fc4e8
--- a/src/HOL/Tools/Nitpick/nitpick_kodkod.ML	Sat Jul 31 22:02:54 2010 +0200
+++ b/src/HOL/Tools/Nitpick/nitpick_kodkod.ML	Sun Aug 01 15:51:25 2010 +0200
@@ -8,12 +8,9 @@
 signature NITPICK_KODKOD =
 sig
   type hol_context = Nitpick_HOL.hol_context
-  type dtype_spec = Nitpick_Scope.dtype_spec
+  type datatype_spec = Nitpick_Scope.datatype_spec
   type kodkod_constrs = Nitpick_Peephole.kodkod_constrs
   type nut = Nitpick_Nut.nut
-  type nfa_transition = Kodkod.rel_expr * typ
-  type nfa_entry = typ * nfa_transition list
-  type nfa_table = nfa_entry list
 
   structure NameTable : TABLE
 
@@ -25,17 +22,17 @@
   val tuple_set_from_atom_schema : (int * int) list -> Kodkod.tuple_set
   val sequential_int_bounds : int -> Kodkod.int_bound list
   val pow_of_two_int_bounds : int -> int -> Kodkod.int_bound list
-  val bounds_for_built_in_rels_in_formula :
-    bool -> int Typtab.table -> int -> int -> int -> int -> Kodkod.formula
-    -> Kodkod.bound list
+  val bounds_and_axioms_for_built_in_rels_in_formulas :
+    bool -> int Typtab.table -> int -> int -> int -> int -> Kodkod.formula list
+    -> Kodkod.bound list * Kodkod.formula list
   val bound_for_plain_rel : Proof.context -> bool -> nut -> Kodkod.bound
   val bound_for_sel_rel :
-    Proof.context -> bool -> dtype_spec list -> nut -> Kodkod.bound
+    Proof.context -> bool -> datatype_spec list -> nut -> Kodkod.bound
   val merge_bounds : Kodkod.bound list -> Kodkod.bound list
   val declarative_axiom_for_plain_rel : kodkod_constrs -> nut -> Kodkod.formula
   val declarative_axioms_for_datatypes :
-    hol_context -> bool -> int -> int Typtab.table -> kodkod_constrs
-    -> nut NameTable.table -> dtype_spec list -> Kodkod.formula list
+    hol_context -> bool -> int -> int -> int Typtab.table -> kodkod_constrs
+    -> nut NameTable.table -> datatype_spec list -> Kodkod.formula list
   val kodkod_formula_from_nut :
     int Typtab.table -> kodkod_constrs -> nut -> Kodkod.formula
 end;
@@ -52,11 +49,13 @@
 
 structure KK = Kodkod
 
-type nfa_transition = KK.rel_expr * typ
-type nfa_entry = typ * nfa_transition list
-type nfa_table = nfa_entry list
+structure NfaGraph = Typ_Graph
+
+(* Puts "x" at the head of the result: every element of "xs" equal to "x" is
+   dropped and a single "x" is consed onto what remains. *)
+fun pull x xs = x :: filter_out (fn y => x = y) xs
 
-structure NfaGraph = Typ_Graph
+(* Holds exactly for datatype specs whose "co" field is false and whose
+   "standard" and "deep" fields are both true. *)
+fun is_datatype_good_and_old ({co, standard, deep, ...} : datatype_spec) =
+  not co andalso standard andalso deep
 
 fun flip_nums n = index_seq 1 n @ [0] |> map KK.Num
 
@@ -123,19 +122,16 @@
         aux (iter - 1) (2 * pow_of_two) (j + 1)
   in aux (bits + 1) 1 j0 end
 
-fun built_in_rels_in_formula formula =
+(* Collects, without duplicates, the relations "KK.Rel (n, j)" that occur in
+   "formulas" and have a negative index "j", skipping the unsigned and signed
+   bit-word selector relations. Relation expressions of any other shape, as
+   well as formulas and integer expressions themselves, contribute nothing. *)
+fun built_in_rels_in_formulas formulas =
   let
     fun rel_expr_func (KK.Rel (x as (n, j))) =
-        if x = unsigned_bit_word_sel_rel orelse x = signed_bit_word_sel_rel then
-          I
-        else
-          (case AList.lookup (op =) (#rels initial_pool) n of
-             SOME k => j < k ? insert (op =) x
-           | NONE => I)
+        (j < 0 andalso x <> unsigned_bit_word_sel_rel andalso
+         x <> signed_bit_word_sel_rel)
+        ? insert (op =) x
       | rel_expr_func _ = I
     val expr_F = {formula_func = K I, rel_expr_func = rel_expr_func,
                   int_expr_func = K I}
-  in KK.fold_formula expr_F formula [] end
+  in fold (KK.fold_formula expr_F) formulas [] end
 
 val max_table_size = 65536
 
@@ -202,24 +198,14 @@
   else if m = 0 orelse n = 0 then (0, 1)
   else let val p = isa_zgcd (m, n) in (isa_div (m, p), isa_div (n, p)) end
 
-fun tabulate_suc debug ofs univ_card main_j0 =
-  let
-    val j0s = Typtab.fold (insert (op =) o snd) ofs [main_j0] |> sort int_ord
-    val ks = map (op -) (tl (j0s @ [univ_card]) ~~ j0s)
-  in
-    map2 (fn j0 => fn k =>
-             tabulate_func1 debug univ_card (k - 1, j0) (Integer.add 1))
-         j0s ks
-    |> List.concat
-  end
-
 fun tabulate_built_in_rel debug ofs univ_card nat_card int_card j0
                           (x as (n, _)) =
   (check_arity univ_card n;
    if x = not3_rel then
      ("not3", tabulate_func1 debug univ_card (2, j0) (curry (op -) 1))
    else if x = suc_rel then
-     ("suc", tabulate_suc debug ofs univ_card j0)
+     ("suc", tabulate_func1 debug univ_card (univ_card - j0 - 1, j0)
+                            (Integer.add 1))
    else if x = nat_add_rel then
      ("nat_add", tabulate_nat_op2 debug univ_card (nat_card, j0) (op +))
    else if x = int_add_rel then
@@ -253,16 +239,40 @@
    else
      raise ARG ("Nitpick_Kodkod.tabulate_built_in_rel", "unknown relation"))
 
-fun bound_for_built_in_rel debug ofs univ_card nat_card int_card j0 x =
-  let
-    val (nick, ts) = tabulate_built_in_rel debug ofs univ_card nat_card int_card
-                                           j0 x
-  in ([(x, nick)], [KK.TupleSet ts]) end
+(* Computes the Kodkod bound for the built-in relation "x = (n, j)".
+
+   Binary relations with "j <= suc_rels_base" are handled as per-atom-sequence
+   "suc" relations: "atom_seq_for_suc_rel" yields the atom sequence "(k, j0)"
+   and a flag telling whether the relation is to be tabulated. If so, the
+   bound is the single tuple set tabulated by "tabulate_func1" with
+   "Integer.add 1"; otherwise two tuple sets are given, the empty one and the
+   one built from the schema "[y, y]" (presumably the lower and upper bounds;
+   the Kodkod bound type is declared outside this view).
+
+   All other relations are tabulated by "tabulate_built_in_rel". *)
+fun bound_for_built_in_rel debug ofs univ_card nat_card int_card main_j0
+                           (x as (n, j)) =
+  if n = 2 andalso j <= suc_rels_base then
+    let val (y as (k, j0), tabulate) = atom_seq_for_suc_rel x in
+      ([(x, "suc")],
+       if tabulate then
+         [KK.TupleSet (tabulate_func1 debug univ_card (k - 1, j0)
+                       (Integer.add 1))]
+       else
+         [KK.TupleSet [], tuple_set_from_atom_schema [y, y]])
+    end
+  else
+    let
+      val (nick, ts) = tabulate_built_in_rel debug ofs univ_card nat_card
+                                             int_card main_j0 x
+    in ([(x, nick)], [KK.TupleSet ts]) end
 
-fun bounds_for_built_in_rels_in_formula debug ofs univ_card nat_card int_card
-                                        j0 =
-  map (bound_for_built_in_rel debug ofs univ_card nat_card int_card j0)
-  o built_in_rels_in_formula
+(* Returns the extra axiom needed for the built-in relation "x = (n, j)", if
+   any. Only binary relations with "j <= suc_rels_base" that are not tabulated
+   and whose atom sequence has at least two atoms get one, namely a
+   "KK.TotalOrdering" constraint over that atom sequence. *)
+fun axiom_for_built_in_rel (x as (n, j)) =
+  if n <> 2 orelse j > suc_rels_base then
+    NONE
+  else
+    case atom_seq_for_suc_rel x of
+      (y as (k, j0), tabulate) =>
+      if not tabulate andalso k >= 2 then
+        SOME (KK.TotalOrdering (x, KK.AtomSeq y, KK.Atom j0, KK.Atom (j0 + 1)))
+      else
+        NONE
+(* Gathers the built-in relations occurring in "formulas" and returns a pair:
+   one bound per relation, and the axioms that "axiom_for_built_in_rel"
+   produces for those relations. *)
+fun bounds_and_axioms_for_built_in_rels_in_formulas debug ofs univ_card nat_card
+                                                    int_card main_j0 formulas =
+  let
+    val rels = built_in_rels_in_formulas formulas
+    val bound_for =
+      bound_for_built_in_rel debug ofs univ_card nat_card int_card main_j0
+  in (map bound_for rels, map_filter axiom_for_built_in_rel rels) end
 
 fun bound_comment ctxt debug nick T R =
   short_name nick ^
@@ -297,11 +307,10 @@
            else
              [KK.TupleSet [],
               if T1 = T2 andalso epsilon > delta andalso
-                 (datatype_spec dtypes T1 |> the |> pairf #co #standard)
-                 = (false, true) then
+                 is_datatype_good_and_old (the (datatype_spec dtypes T1)) then
                 index_seq delta (epsilon - delta)
                 |> map (fn j =>
-                           KK.TupleProduct (KK.TupleSet [Kodkod.Tuple [j + j0]],
+                           KK.TupleProduct (KK.TupleSet [KK.Tuple [j + j0]],
                                             KK.TupleAtomSeq (j, j0)))
                 |> foldl1 KK.TupleUnion
               else
@@ -347,8 +356,6 @@
   map2 (fn j => fn x => KK.DeclOne ((1, j), KK.AtomSeq x))
        (index_seq j0 (length schema)) schema
 
-(* The type constraint below is a workaround for a Poly/ML bug. *)
-
 fun d_n_ary_function ({kk_all, kk_join, kk_lone, kk_one, ...} : kodkod_constrs)
                      R r =
   let val body_R = body_rep R in
@@ -683,7 +690,7 @@
 
 fun nfa_transitions_for_sel hol_ctxt binarize
                             ({kk_project, ...} : kodkod_constrs) rel_table
-                            (dtypes : dtype_spec list) constr_x n =
+                            (dtypes : datatype_spec list) constr_x n =
   let
     val x as (_, T) =
       binarized_and_boxed_nth_sel_for_constr hol_ctxt binarize constr_x n
@@ -692,14 +699,14 @@
   in
     map_filter (fn (j, T) =>
                    if forall (not_equal T o #typ) dtypes then NONE
-                   else SOME (kk_project r (map KK.Num [0, j]), T))
+                   else SOME ((x, kk_project r (map KK.Num [0, j])), T))
                (index_seq 1 (arity - 1) ~~ tl type_schema)
   end
 fun nfa_transitions_for_constr hol_ctxt binarize kk rel_table dtypes
                                (x as (_, T)) =
+  (* Concatenates the transitions contributed by each selector of the
+     constructor "x", for the selector indices given by
+     "index_seq 0 (num_sels_for_constr_type T)". *)
   maps (nfa_transitions_for_sel hol_ctxt binarize kk rel_table dtypes x)
        (index_seq 0 (num_sels_for_constr_type T))
-fun nfa_entry_for_datatype _ _ _ _ _ ({co = true, ...} : dtype_spec) = NONE
+fun nfa_entry_for_datatype _ _ _ _ _ ({co = true, ...} : datatype_spec) = NONE
   | nfa_entry_for_datatype _ _ _ _ _ {standard = false, ...} = NONE
   | nfa_entry_for_datatype _ _ _ _ _ {deep = false, ...} = NONE
   | nfa_entry_for_datatype hol_ctxt binarize kk rel_table dtypes
@@ -711,7 +718,7 @@
 
 fun direct_path_rel_exprs nfa start_T final_T =
+  (* Looks up the transitions stored under "final_T" and returns the relational
+     expressions of those whose associated type equals "start_T". Each
+     transition has the shape "((x, rel_expr), T)", hence "snd o fst". Yields
+     the empty list if "final_T" has no entry in "nfa". *)
   case AList.lookup (op =) nfa final_T of
-    SOME trans => map fst (filter (curry (op =) start_T o snd) trans)
+    SOME trans => map (snd o fst) (filter (curry (op =) start_T o snd) trans)
   | NONE => []
 and any_path_rel_expr ({kk_union, ...} : kodkod_constrs) nfa [] start_T
                       final_T =
@@ -749,14 +756,164 @@
   nfa |> graph_for_nfa |> NfaGraph.strong_conn
       |> map (fn keys => filter (member (op =) keys o fst) nfa)
 
-fun acyclicity_axiom_for_datatype kk nfa start_T =
-  #kk_no kk (#kk_intersect kk
-                 (loop_path_rel_expr kk nfa (map fst nfa) start_T) KK.Iden)
-fun acyclicity_axioms_for_datatypes hol_ctxt binarize kk rel_table dtypes =
-  map_filter (nfa_entry_for_datatype hol_ctxt binarize kk rel_table dtypes)
-             dtypes
-  |> strongly_connected_sub_nfas
-  |> maps (fn nfa => map (acyclicity_axiom_for_datatype kk nfa o fst) nfa)
+(* States that the loop path expression for "start_T" in "nfa" has an empty
+   intersection with "KK.Iden". The list of types handed to
+   "loop_path_rel_expr" is the NFA's key list with "start_T" pulled to the
+   front. *)
+fun acyclicity_axiom_for_datatype (kk : kodkod_constrs) nfa start_T =
+  let
+    val Ts = pull start_T (map fst nfa)
+    val loop_r = loop_path_rel_expr kk nfa Ts start_T
+  in #kk_no kk (#kk_intersect kk loop_r KK.Iden) end
+(* Produces one acyclicity axiom per entry of every NFA in "nfas". *)
+fun acyclicity_axioms_for_datatypes kk nfas =
+  let
+    fun axioms_for_nfa nfa =
+      map (fn (T, _) => acyclicity_axiom_for_datatype kk nfa T) nfa
+  in maps axioms_for_nfa nfas end
+
+(* Joins "r" with the reflexive closure of the "suc" relation associated with
+   "z" (going by the name, the atoms greater than or equal to those in "r"). *)
+fun all_ge ({kk_join, kk_reflexive_closure, ...} : kodkod_constrs) z r =
+  let val suc_r = KK.Rel (suc_rel_for_atom_seq z) in
+    kk_join r (kk_reflexive_closure suc_r)
+  end
+(* Formula stating that "r1" is contained in the join of "r2" with the
+   (non-reflexive) closure of the "suc" relation associated with "z". *)
+fun gt ({kk_subset, kk_join, kk_closure, ...} : kodkod_constrs) z r1 r2 =
+  let val suc_plus = kk_closure (KK.Rel (suc_rel_for_atom_seq z)) in
+    kk_subset r1 (kk_join r2 suc_plus)
+  end
+
+(* Ordering on constructor specifications: compares the "delta" fields first,
+   then the "epsilon" fields, then the names of the constants. *)
+fun constr_ord (({const = (s1, _), delta = delta1, epsilon = epsilon1, ...},
+                 {const = (s2, _), delta = delta2, epsilon = epsilon2, ...})
+                : constr_spec * constr_spec) =
+  prod_ord int_ord (prod_ord int_ord string_ord)
+           (* Each key must carry its own "epsilon"; using "epsilon2" on both
+              sides would make the middle comparison always return EQUAL. *)
+           ((delta1, (epsilon1, s1)), (delta2, (epsilon2, s2)))
+
+(* Ordering on datatype specifications: compares "card" first, then
+   "self_rec", then the number of constructors. *)
+fun datatype_ord (dtype1 : datatype_spec, dtype2 : datatype_spec) =
+  let
+    fun key ({card, self_rec, constrs, ...} : datatype_spec) =
+      (card, (self_rec, length constrs))
+  in
+    prod_ord int_ord (prod_ord bool_ord int_ord) (key dtype1, key dtype2)
+  end
+
+(* "suc" must be tabulated for every datatype whose selector bounds break
+   cycles; otherwise two incompatible symmetry breaking orders could coexist
+   and lead to spurious models. The decision is read off the "self_rec" field
+   of the type's datatype spec; a type without a spec yields false. *)
+fun should_tabulate_suc_for_type dtypes T =
+  the_default false (Option.map #self_rec (datatype_spec dtypes T))
+
+(* Builds a formula that compares the selector values of "KK.Var (1, 1)" with
+   those of "KK.Var (1, 0)", one selector at a time in list order.
+
+   Only entries whose representation is "Func (Atom _, Atom x)" with arity 2
+   take part. If such an entry is the last element of the list, it yields a
+   "gt" constraint between the two selector values. Otherwise it yields an
+   "all_ge" constraint, conjoined with an implication whose conclusion is the
+   comparison of the remaining selectors. The empty list yields "KK.True". *)
+fun lex_order_rel_expr (kk as {kk_implies, kk_and, kk_subset, kk_join, ...})
+                       dtypes sel_quadruples =
+  case sel_quadruples of
+    [] => KK.True
+  | ((r, Func (Atom _, Atom x), 2), (_, Type (_, [_, T]))) :: sel_quadruples' =>
+    let val z = (x, should_tabulate_suc_for_type dtypes T) in
+      if null sel_quadruples' then
+        gt kk z (kk_join (KK.Var (1, 1)) r) (kk_join (KK.Var (1, 0)) r)
+      else
+        kk_and (kk_subset (kk_join (KK.Var (1, 1)) r)
+                          (all_ge kk z (kk_join (KK.Var (1, 0)) r)))
+               (kk_implies (kk_subset (kk_join (KK.Var (1, 1)) r)
+                                      (kk_join (KK.Var (1, 0)) r))
+                           (lex_order_rel_expr kk dtypes sel_quadruples'))
+    end
+    (* Skip constructor components that aren't atoms, since we cannot compare
+       these easily. *)
+  | _ :: sel_quadruples' => lex_order_rel_expr kk dtypes sel_quadruples'
+
+(* Holds if, after discarding the constructors whose type is self-recursive,
+   exactly one constructor remains and its type is "T" itself. Raises whatever
+   "the" raises if "T" has no datatype spec. *)
+fun has_nil_like_constr dtypes T =
+  let
+    val {constrs, ...} = the (datatype_spec dtypes T)
+    val nonrec_constrs =
+      filter_out (fn {const = (_, T'), ...} =>
+                     is_self_recursive_constr_type T') constrs
+  in
+    case nonrec_constrs of
+      [{const = (_, T'), ...}] => T = T'
+    | _ => false
+  end
+
+(* Builds the symmetry breaking axiom for the pair of constructors "constr1"
+   and "constr2" (possibly the same constructor twice).
+
+   The result is the empty list if both constructors are the same and have no
+   selectors. Otherwise it is a single formula, universally quantified over
+   "KK.Var (1, 0)" and "KK.Var (1, 1)", which range over the discriminator
+   relations of "const1" and "const2" respectively.
+
+   Raises REP if the discriminator of "const1" does not have a representation
+   of the form "Func (Atom _, Formula _)". *)
+fun sym_break_axioms_for_constr_pair hol_ctxt binarize
+       (kk as {kk_all, kk_or, kk_iff, kk_implies, kk_and, kk_some, kk_subset,
+               kk_intersect, kk_join, kk_project, ...}) rel_table nfas dtypes
+       (constr1 as {const = const1 as (_, T1), delta = delta1,
+                    epsilon = epsilon1, ...},
+        constr2 as {const = const2 as (_, T2), delta = delta2,
+                    epsilon = epsilon2, ...}) =
+  let
+    val dataT = body_type T1
+    (* The sub-NFA that has an entry for "dataT", or [] if there is none. *)
+    val nfa = nfas |> find_first (exists (curry (op =) dataT o fst)) |> these
+    val rec_Ts = nfa |> map fst
+    val same_constr = (const1 = const2)
+    (* Splits the selectors of a constructor into those whose range type is
+       among "rec_Ts" and the remaining ones. *)
+    fun rec_and_nonrec_sels (x as (_, T)) =
+      index_seq 0 (num_sels_for_constr_type T)
+      |> map (binarized_and_boxed_nth_sel_for_constr hol_ctxt binarize x)
+      |> List.partition (member (op =) rec_Ts o range_type o snd)
+    val sel_xs1 = rec_and_nonrec_sels const1 |> op @
+  in
+    if same_constr andalso null sel_xs1 then
+      []
+    else
+      let
+        (* "z" pairs the argument of the "Atom" representation of the
+           discriminator's domain with the "suc" tabulation flag of "dataT". *)
+        val z =
+          (case #2 (const_triple rel_table (discr_for_constr const1)) of
+             Func (Atom x, Formula _) => x
+           | R => raise REP ("Nitpick_Kodkod.sym_break_axioms_for_constr_pair",
+                             [R]), should_tabulate_suc_for_type dtypes dataT)
+        val (rec_sel_xs2, nonrec_sel_xs2) = rec_and_nonrec_sels const2
+        val sel_xs2 = rec_sel_xs2 @ nonrec_sel_xs2
+        fun sel_quadruples2 () = sel_xs2 |> map (`(const_triple rel_table))
+        (* If the two constructors are the same, we drop the first selector
+           because that one is always checked by the lexicographic order.
+           We sometimes also filter out direct subterms, because those are
+           already handled by the acyclicity breaking in the bound
+           declarations. *)
+        fun filter_out_sels no_direct sel_xs =
+          apsnd (filter_out
+                     (fn ((x, _), T) =>
+                         (same_constr andalso x = hd sel_xs) orelse
+                         (T = dataT andalso
+                          (no_direct orelse not (member (op =) sel_xs x)))))
+        (* The loop path expression for "dataT" over the filtered NFA, joined
+           with the quantified variable "KK.Var (1, j)". *)
+        fun subterms_r no_direct sel_xs j =
+          loop_path_rel_expr kk (map (filter_out_sels no_direct sel_xs) nfa)
+                           (filter_out (curry (op =) dataT) (map fst nfa)) dataT
+          |> kk_join (KK.Var (1, j))
+      in
+        [kk_all [KK.DeclOne ((1, 0), discr_rel_expr rel_table const1),
+                 KK.DeclOne ((1, 1), discr_rel_expr rel_table const2)]
+             ((if same_constr then kk_implies else kk_iff)
+                 (if delta2 >= epsilon1 then KK.True
+                  else gt kk z (KK.Var (1, 1)) (KK.Var (1, 0)))
+                 (kk_or
+                      (if has_nil_like_constr dtypes dataT andalso
+                          T1 = dataT then
+                         KK.True
+                       else
+                         kk_some (kk_intersect (subterms_r false sel_xs2 1)
+                                               (all_ge kk z (KK.Var (1, 0)))))
+                      (if same_constr then
+                         kk_and
+                             (lex_order_rel_expr kk dtypes (sel_quadruples2 ()))
+                             (if length rec_sel_xs2 > 1 then
+                                kk_all [KK.DeclOne ((1, 2),
+                                                    subterms_r true sel_xs1 0)]
+                                       (gt kk z (KK.Var (1, 1)) (KK.Var (1, 2)))
+                              else
+                                KK.True)
+                       else
+                         kk_all [KK.DeclOne ((1, 2),
+                                 subterms_r false sel_xs1 0)]
+                                (gt kk z (KK.Var (1, 1)) (KK.Var (1, 2))))))]
+      end
+  end
+
+(* Sorts the datatype's constructors with "constr_ord" and collects the
+   symmetry breaking axioms, first for each constructor paired with itself and
+   then for every pair from "all_distinct_unordered_pairs_of". *)
+fun sym_break_axioms_for_datatype hol_ctxt binarize kk rel_table nfas dtypes
+                                  ({constrs, ...} : datatype_spec) =
+  let
+    val sorted_constrs = sort constr_ord constrs
+    val same_pairs = map (fn constr => (constr, constr)) sorted_constrs
+    val distinct_pairs = all_distinct_unordered_pairs_of sorted_constrs
+    val axioms_for_pair =
+      sym_break_axioms_for_constr_pair hol_ctxt binarize kk rel_table nfas
+                                       dtypes
+  in maps axioms_for_pair (same_pairs @ distinct_pairs) end
+
+(* Datatypes whose "card" is below this value are excluded from symmetry
+   breaking by "sym_break_axioms_for_datatypes". *)
+val min_sym_break_card = 7
+
+(* Collects the symmetry breaking axioms for the datatypes in "dtypes".
+
+   Nothing is generated if "datatype_sym_break" is 0. Otherwise only datatypes
+   accepted by "is_datatype_good_and_old", with a number of constructors other
+   than one and a "card" of at least "min_sym_break_card", are considered. If
+   more than "datatype_sym_break" of them qualify, only the first
+   "datatype_sym_break" in descending "datatype_ord" order are kept. *)
+fun sym_break_axioms_for_datatypes hol_ctxt binarize datatype_sym_break kk
+                                   rel_table nfas dtypes =
+  if datatype_sym_break = 0 then
+    []
+  else
+    let
+      fun is_candidate (dtype as {constrs, card, ...} : datatype_spec) =
+        is_datatype_good_and_old dtype andalso length constrs <> 1 andalso
+        card >= min_sym_break_card
+      val candidates = filter is_candidate dtypes
+      val selected =
+        if length candidates > datatype_sym_break then
+          take datatype_sym_break (sort (rev_order o datatype_ord) candidates)
+        else
+          candidates
+    in
+      maps (sym_break_axioms_for_datatype hol_ctxt binarize kk rel_table nfas
+                                          dtypes)
+           selected
+    end
 
 fun sel_axiom_for_sel hol_ctxt binarize j0
         (kk as {kk_all, kk_formula_if, kk_subset, kk_no, kk_join, ...})
@@ -805,7 +962,7 @@
       end
   end
 fun sel_axioms_for_datatype hol_ctxt binarize bits j0 kk rel_table
-                            ({constrs, ...} : dtype_spec) =
+                            ({constrs, ...} : datatype_spec) =
+  (* Concatenates the selector axioms of every constructor of the datatype. *)
   maps (sel_axioms_for_constr hol_ctxt binarize bits j0 kk rel_table) constrs
 
 fun uniqueness_axiom_for_constr hol_ctxt binarize
@@ -830,13 +987,13 @@
                   (kk_rel_eq (KK.Var (1, 0)) (KK.Var (1, 1))))
   end
 fun uniqueness_axioms_for_datatype hol_ctxt binarize kk rel_table
-                                   ({constrs, ...} : dtype_spec) =
+                                   ({constrs, ...} : datatype_spec) =
+  (* Produces one uniqueness axiom per constructor of the datatype. *)
   map (uniqueness_axiom_for_constr hol_ctxt binarize kk rel_table) constrs
 
 fun effective_constr_max ({delta, epsilon, ...} : constr_spec) = epsilon - delta
 fun partition_axioms_for_datatype j0 (kk as {kk_rel_eq, kk_union, ...})
                                   rel_table
-                                  ({card, constrs, ...} : dtype_spec) =
+                                  ({card, constrs, ...} : datatype_spec) =
   if forall #exclusive constrs then
     [Integer.sum (map effective_constr_max constrs) = card |> formula_for_bool]
   else
@@ -854,11 +1011,20 @@
       partition_axioms_for_datatype j0 kk rel_table dtype
     end
 
-fun declarative_axioms_for_datatypes hol_ctxt binarize bits ofs kk rel_table
-                                     dtypes =
-  acyclicity_axioms_for_datatypes hol_ctxt binarize kk rel_table dtypes @
-  maps (other_axioms_for_datatype hol_ctxt binarize bits ofs kk rel_table)
-       dtypes
+(* Assembles all declarative axioms for "dtypes". The NFA entries of the
+   datatypes are grouped into strongly connected sub-NFAs once, and these are
+   shared by the acyclicity and the symmetry breaking axioms. The result lists
+   the acyclicity axioms first, then the symmetry breaking axioms, then the
+   remaining per-datatype axioms. *)
+fun declarative_axioms_for_datatypes hol_ctxt binarize datatype_sym_break bits
+                                     ofs kk rel_table dtypes =
+  let
+    val nfa_entries =
+      map_filter (nfa_entry_for_datatype hol_ctxt binarize kk rel_table dtypes)
+                 dtypes
+    val nfas = strongly_connected_sub_nfas nfa_entries
+    val acyclicity_axioms = acyclicity_axioms_for_datatypes kk nfas
+    val sym_break_axioms =
+      sym_break_axioms_for_datatypes hol_ctxt binarize datatype_sym_break kk
+                                     rel_table nfas dtypes
+    val other_axioms =
+      maps (other_axioms_for_datatype hol_ctxt binarize bits ofs kk rel_table)
+           dtypes
+  in acyclicity_axioms @ sym_break_axioms @ other_axioms end
 
 fun kodkod_formula_from_nut ofs
         (kk as {kk_all, kk_exist, kk_formula_let, kk_formula_if, kk_or, kk_not,