first steps in implementing "fix_datatype_vals" optimization
authorblanchet
Mon, 21 Feb 2011 15:45:44 +0100
changeset 41801 ed77524f3429
parent 41800 7f333b59d5fb
child 41802 7592a165fa0b
first steps in implementing "fix_datatype_vals" optimization
src/HOL/Tools/Nitpick/nitpick.ML
src/HOL/Tools/Nitpick/nitpick_isar.ML
src/HOL/Tools/Nitpick/nitpick_kodkod.ML
--- a/src/HOL/Tools/Nitpick/nitpick.ML	Mon Feb 21 14:02:07 2011 +0100
+++ b/src/HOL/Tools/Nitpick/nitpick.ML	Mon Feb 21 15:45:44 2011 +0100
@@ -35,6 +35,7 @@
      specialize: bool,
      star_linear_preds: bool,
      peephole_optim: bool,
+     fix_datatype_vals: bool,
      datatype_sym_break: int,
      kodkod_sym_break: int,
      timeout: Time.time option,
@@ -108,6 +109,7 @@
    specialize: bool,
    star_linear_preds: bool,
    peephole_optim: bool,
+   fix_datatype_vals: bool,
    datatype_sym_break: int,
    kodkod_sym_break: int,
    timeout: Time.time option,
@@ -209,10 +211,10 @@
          boxes, finitizes, monos, stds, wfs, sat_solver, falsify, debug,
          verbose, overlord, user_axioms, assms, whacks, merge_type_vars,
          binary_ints, destroy_constrs, specialize, star_linear_preds,
-         peephole_optim, datatype_sym_break, kodkod_sym_break, tac_timeout,
-         max_threads, show_datatypes, show_consts, evals, formats, atomss,
-         max_potential, max_genuine, check_potential, check_genuine, batch_size,
-         ...} = params
+         peephole_optim, fix_datatype_vals, datatype_sym_break,
+         kodkod_sym_break, tac_timeout, max_threads, show_datatypes,
+         show_consts, evals, formats, atomss, max_potential, max_genuine,
+         check_potential, check_genuine, batch_size, ...} = params
     val state_ref = Unsynchronized.ref state
     val pprint =
       if auto then
@@ -320,8 +322,15 @@
             handle TYPE (_, Ts, ts) =>
                    raise TYPE ("Nitpick.pick_them_nits_in_term", Ts, ts)
 
-    val nondef_us = map (nut_from_term hol_ctxt Eq) nondef_ts
-    val def_us = map (nut_from_term hol_ctxt DefEq) def_ts
+    val nondef_us = nondef_ts |> map (nut_from_term hol_ctxt Eq)
+    val def_us = def_ts |> map (nut_from_term hol_ctxt DefEq)
+    val needed_us =
+      if fix_datatype_vals then
+        [@{term "[A, B, C, A]"}, @{term "[C, B, A]"}]
+        |> map (nut_from_term hol_ctxt Eq)
+        (* infer_needed_constructs ### *)
+      else
+        []
     val (free_names, const_names) =
       fold add_free_and_const_names (nondef_us @ def_us) ([], [])
     val (sel_names, nonsel_names) =
@@ -548,8 +557,8 @@
         val plain_axioms = map (declarative_axiom_for_plain_rel kk) plain_rels
         val sel_bounds = map (bound_for_sel_rel ctxt debug datatypes) sel_rels
         val dtype_axioms =
-          declarative_axioms_for_datatypes hol_ctxt binarize datatype_sym_break
-                                           bits ofs kk rel_table datatypes
+          declarative_axioms_for_datatypes hol_ctxt binarize needed_us
+              datatype_sym_break bits ofs kk rel_table datatypes
         val declarative_axioms = plain_axioms @ dtype_axioms
         val univ_card = Int.max (univ_card nat_card int_card main_j0
                                      (plain_bounds @ sel_bounds) formula,
--- a/src/HOL/Tools/Nitpick/nitpick_isar.ML	Mon Feb 21 14:02:07 2011 +0100
+++ b/src/HOL/Tools/Nitpick/nitpick_isar.ML	Mon Feb 21 15:45:44 2011 +0100
@@ -59,6 +59,7 @@
    ("specialize", "true"),
    ("star_linear_preds", "true"),
    ("peephole_optim", "true"),
+   ("fix_datatype_vals", "true"),
    ("datatype_sym_break", "5"),
    ("kodkod_sym_break", "15"),
    ("timeout", "30"),
@@ -91,6 +92,7 @@
    ("dont_specialize", "specialize"),
    ("dont_star_linear_preds", "star_linear_preds"),
    ("no_peephole_optim", "peephole_optim"),
+   ("fix_datatype_vals", "dont_fix_datatype_vals"),
    ("no_debug", "debug"),
    ("quiet", "verbose"),
    ("no_overlord", "overlord"),
@@ -252,6 +254,7 @@
     val specialize = lookup_bool "specialize"
     val star_linear_preds = lookup_bool "star_linear_preds"
     val peephole_optim = lookup_bool "peephole_optim"
+    val fix_datatype_vals = lookup_bool "fix_datatype_vals"
     val datatype_sym_break = lookup_int "datatype_sym_break"
     val kodkod_sym_break = lookup_int "kodkod_sym_break"
     val timeout = if auto then NONE else lookup_time "timeout"
@@ -282,6 +285,7 @@
      merge_type_vars = merge_type_vars, binary_ints = binary_ints,
      destroy_constrs = destroy_constrs, specialize = specialize,
      star_linear_preds = star_linear_preds, peephole_optim = peephole_optim,
+     fix_datatype_vals = fix_datatype_vals,
      datatype_sym_break = datatype_sym_break,
      kodkod_sym_break = kodkod_sym_break, timeout = timeout,
      tac_timeout = tac_timeout, max_threads = max_threads,
--- a/src/HOL/Tools/Nitpick/nitpick_kodkod.ML	Mon Feb 21 14:02:07 2011 +0100
+++ b/src/HOL/Tools/Nitpick/nitpick_kodkod.ML	Mon Feb 21 15:45:44 2011 +0100
@@ -31,8 +31,9 @@
   val merge_bounds : Kodkod.bound list -> Kodkod.bound list
   val declarative_axiom_for_plain_rel : kodkod_constrs -> nut -> Kodkod.formula
   val declarative_axioms_for_datatypes :
-    hol_context -> bool -> int -> int -> int Typtab.table -> kodkod_constrs
-    -> nut NameTable.table -> datatype_spec list -> Kodkod.formula list
+    hol_context -> bool -> nut list -> int -> int -> int Typtab.table
+    -> kodkod_constrs -> nut NameTable.table -> datatype_spec list
+    -> Kodkod.formula list
   val kodkod_formula_from_nut :
     int Typtab.table -> kodkod_constrs -> nut -> Kodkod.formula
 end;
@@ -740,6 +741,43 @@
   | acyclicity_axioms_for_datatypes kk nfas =
     maps (fn nfa => map (acyclicity_axiom_for_datatype kk nfa o fst) nfa) nfas
 
+fun needed_value_axioms_for_datatype [] _ _ = []
+  | needed_value_axioms_for_datatype needed_us ofs
+        ({typ, card, constrs, ...} : datatype_spec) =
+    let
+      fun aux (u as Construct (ConstName (s, _, _) :: _, T, _, us)) =
+          fold aux us
+          #> (fn NONE => NONE
+               | accum as SOME (loose, fixed) =>
+                 if T = typ then
+                   case AList.lookup (op =) fixed u of
+                     SOME _ => accum
+                   | NONE =>
+                     let
+                       val constr_s = constr_name_for_sel_like s
+                       val {delta, epsilon, ...} =
+                         constrs
+                         |> List.find (fn {const, ...} => fst const = constr_s)
+                         |> the
+                       val j0 = offset_of_type ofs T
+                     in
+                       case find_first (fn j => j >= delta andalso
+                                        j < delta + epsilon) loose of
+                         SOME j =>
+                         SOME (remove (op =) j loose, (u, j0 + j) :: fixed)
+                       | NONE => NONE
+                     end
+                 else
+                   accum)
+        | aux u =
+          raise NUT ("Nitpick_Kodkod.needed_value_axioms_for_datatype.aux", [u])
+    in
+      case SOME (index_seq 0 card, []) |> fold aux needed_us of
+        SOME (_, fixed) =>
+         (* fixed |> map () *) [] (*###*)
+      | NONE => [KK.False]
+    end
+
 fun all_ge ({kk_join, kk_reflexive_closure, ...} : kodkod_constrs) z r =
   kk_join r (kk_reflexive_closure (KK.Rel (suc_rel_for_atom_seq z)))
 fun gt ({kk_subset, kk_join, kk_closure, ...} : kodkod_constrs) z r1 r2 =
@@ -879,10 +917,14 @@
                                               nfas dtypes)
   end
 
+fun is_datatype_in_needed_value T (Construct (_, T', _, us)) =
+    T = T' orelse exists (is_datatype_in_needed_value T) us
+  | is_datatype_in_needed_value _ _ = false
+
 val min_sym_break_card = 7
 
-fun sym_break_axioms_for_datatypes hol_ctxt binarize datatype_sym_break kk
-                                   rel_table nfas dtypes =
+fun sym_break_axioms_for_datatypes hol_ctxt binarize needed_us
+                                   datatype_sym_break kk rel_table nfas dtypes =
   if datatype_sym_break = 0 then
     []
   else
@@ -894,6 +936,9 @@
                                card >= min_sym_break_card andalso
                                forall (forall (not o is_higher_order_type)
                                        o binder_types o snd o #const) constrs)
+                 |> filter_out (fn {typ, ...} =>
+                                   exists (is_datatype_in_needed_value typ)
+                                          needed_us)
                  |> (fn dtypes' =>
                         dtypes'
                         |> length dtypes' > datatype_sym_break
@@ -996,8 +1041,8 @@
       partition_axioms_for_datatype j0 kk rel_table dtype
     end
 
-fun declarative_axioms_for_datatypes hol_ctxt binarize datatype_sym_break bits
-                                     ofs kk rel_table dtypes =
+fun declarative_axioms_for_datatypes hol_ctxt binarize needed_us
+        datatype_sym_break bits ofs kk rel_table dtypes =
   let
     val nfas =
       dtypes |> map_filter (nfa_entry_for_datatype hol_ctxt binarize kk
@@ -1005,8 +1050,9 @@
              |> strongly_connected_sub_nfas
   in
     acyclicity_axioms_for_datatypes kk nfas @
-    sym_break_axioms_for_datatypes hol_ctxt binarize datatype_sym_break kk
-                                   rel_table nfas dtypes @
+    maps (needed_value_axioms_for_datatype needed_us ofs) dtypes @
+    sym_break_axioms_for_datatypes hol_ctxt binarize needed_us
+        datatype_sym_break kk rel_table nfas dtypes @
     maps (other_axioms_for_datatype hol_ctxt binarize bits ofs kk rel_table)
          dtypes
   end