more work on "fix_datatype_vals" optimization (renamed "preconstruct")
authorblanchet
Mon, 21 Feb 2011 17:36:32 +0100
changeset 41803 ef13e3b7cbaf
parent 41802 7592a165fa0b
child 41804 90dd5291afd8
more work on "fix_datatype_vals" optimization (renamed "preconstruct")
src/HOL/Tools/Nitpick/nitpick.ML
src/HOL/Tools/Nitpick/nitpick_hol.ML
src/HOL/Tools/Nitpick/nitpick_isar.ML
src/HOL/Tools/Nitpick/nitpick_kodkod.ML
src/HOL/Tools/Nitpick/nitpick_model.ML
src/HOL/Tools/Nitpick/nitpick_preproc.ML
--- a/src/HOL/Tools/Nitpick/nitpick.ML	Mon Feb 21 16:33:21 2011 +0100
+++ b/src/HOL/Tools/Nitpick/nitpick.ML	Mon Feb 21 17:36:32 2011 +0100
@@ -34,8 +34,8 @@
      destroy_constrs: bool,
      specialize: bool,
      star_linear_preds: bool,
+     preconstrs: (term option * bool option) list,
      peephole_optim: bool,
-     fix_datatype_vals: bool,
      datatype_sym_break: int,
      kodkod_sym_break: int,
      timeout: Time.time option,
@@ -108,8 +108,8 @@
    destroy_constrs: bool,
    specialize: bool,
    star_linear_preds: bool,
+   preconstrs: (term option * bool option) list,
    peephole_optim: bool,
-   fix_datatype_vals: bool,
    datatype_sym_break: int,
    kodkod_sym_break: int,
    timeout: Time.time option,
@@ -211,10 +211,10 @@
          boxes, finitizes, monos, stds, wfs, sat_solver, falsify, debug,
          verbose, overlord, user_axioms, assms, whacks, merge_type_vars,
          binary_ints, destroy_constrs, specialize, star_linear_preds,
-         peephole_optim, fix_datatype_vals, datatype_sym_break,
-         kodkod_sym_break, tac_timeout, max_threads, show_datatypes,
-         show_consts, evals, formats, atomss, max_potential, max_genuine,
-         check_potential, check_genuine, batch_size, ...} = params
+         preconstrs, peephole_optim, datatype_sym_break, kodkod_sym_break,
+         tac_timeout, max_threads, show_datatypes, show_consts, evals, formats,
+         atomss, max_potential, max_genuine, check_potential, check_genuine,
+         batch_size, ...} = params
     val state_ref = Unsynchronized.ref state
     val pprint =
       if auto then
@@ -282,21 +282,23 @@
        stds = stds, wfs = wfs, user_axioms = user_axioms, debug = debug,
        whacks = whacks, binary_ints = binary_ints,
        destroy_constrs = destroy_constrs, specialize = specialize,
-       star_linear_preds = star_linear_preds, tac_timeout = tac_timeout,
-       evals = evals, case_names = case_names, def_tables = def_tables,
-       nondef_table = nondef_table, user_nondefs = user_nondefs,
-       simp_table = simp_table, psimp_table = psimp_table,
-       choice_spec_table = choice_spec_table, intro_table = intro_table,
-       ground_thm_table = ground_thm_table, ersatz_table = ersatz_table,
-       skolems = Unsynchronized.ref [], special_funs = Unsynchronized.ref [],
+       star_linear_preds = star_linear_preds, preconstrs = preconstrs,
+       tac_timeout = tac_timeout, evals = evals, case_names = case_names,
+       def_tables = def_tables, nondef_table = nondef_table,
+       user_nondefs = user_nondefs, simp_table = simp_table,
+       psimp_table = psimp_table, choice_spec_table = choice_spec_table,
+       intro_table = intro_table, ground_thm_table = ground_thm_table,
+       ersatz_table = ersatz_table, skolems = Unsynchronized.ref [],
+       special_funs = Unsynchronized.ref [],
        unrolled_preds = Unsynchronized.ref [], wf_cache = Unsynchronized.ref [],
        constr_cache = Unsynchronized.ref []}
     val pseudo_frees = []
     val real_frees = fold Term.add_frees (neg_t :: assm_ts) []
     val _ = null (fold Term.add_tvars (neg_t :: assm_ts) []) orelse
             raise NOT_SUPPORTED "schematic type variables"
-    val (nondef_ts, def_ts, got_all_mono_user_axioms, no_poly_user_axioms,
-         binarize) = preprocess_formulas hol_ctxt assm_ts neg_t
+    val (nondef_ts, def_ts, preconstr_ts, got_all_mono_user_axioms,
+         no_poly_user_axioms, binarize) =
+      preprocess_formulas hol_ctxt assm_ts neg_t
     val got_all_user_axioms =
       got_all_mono_user_axioms andalso no_poly_user_axioms
 
@@ -324,13 +326,7 @@
 
     val nondef_us = nondef_ts |> map (nut_from_term hol_ctxt Eq)
     val def_us = def_ts |> map (nut_from_term hol_ctxt DefEq)
-    val needed_us =
-      if fix_datatype_vals then
-        [@{term "[A, B, C, A]"}, @{term "[C, B, A]"}]
-        |> map (nut_from_term hol_ctxt Eq)
-        (* infer_needed_constructs ### *)
-      else
-        []
+    val preconstr_us = preconstr_ts |> map (nut_from_term hol_ctxt Eq)
     val (free_names, const_names) =
       fold add_free_and_const_names (nondef_us @ def_us) ([], [])
     val (sel_names, nonsel_names) =
@@ -519,8 +515,8 @@
           def_us |> map (choose_reps_in_nut scope unsound rep_table true)
         val nondef_us =
           nondef_us |> map (choose_reps_in_nut scope unsound rep_table false)
-        val needed_us =
-          needed_us |> map (choose_reps_in_nut scope unsound rep_table false)
+        val preconstr_us =
+          preconstr_us |> map (choose_reps_in_nut scope unsound rep_table false)
 (*
         val _ = List.app (print_g o string_for_nut ctxt)
                          (free_names @ sel_names @ nonsel_names @
@@ -534,7 +530,8 @@
           rename_free_vars nonsel_names pool rel_table
         val nondef_us = nondef_us |> map (rename_vars_in_nut pool rel_table)
         val def_us = def_us |> map (rename_vars_in_nut pool rel_table)
-        val needed_us = needed_us |> map (rename_vars_in_nut pool rel_table)
+        val preconstr_us =
+          preconstr_us |> map (rename_vars_in_nut pool rel_table)
         val nondef_fs = map (kodkod_formula_from_nut ofs kk) nondef_us
         val def_fs = map (kodkod_formula_from_nut ofs kk) def_us
         val formula = fold (fold s_and) [def_fs, nondef_fs] KK.True
@@ -560,7 +557,7 @@
         val plain_axioms = map (declarative_axiom_for_plain_rel kk) plain_rels
         val sel_bounds = map (bound_for_sel_rel ctxt debug datatypes) sel_rels
         val dtype_axioms =
-          declarative_axioms_for_datatypes hol_ctxt binarize needed_us
+          declarative_axioms_for_datatypes hol_ctxt binarize preconstr_us
               datatype_sym_break bits ofs kk rel_table datatypes
         val declarative_axioms = plain_axioms @ dtype_axioms
         val univ_card = Int.max (univ_card nat_card int_card main_j0
--- a/src/HOL/Tools/Nitpick/nitpick_hol.ML	Mon Feb 21 16:33:21 2011 +0100
+++ b/src/HOL/Tools/Nitpick/nitpick_hol.ML	Mon Feb 21 17:36:32 2011 +0100
@@ -27,6 +27,7 @@
      destroy_constrs: bool,
      specialize: bool,
      star_linear_preds: bool,
+     preconstrs: (term option * bool option) list,
      tac_timeout: Time.time option,
      evals: term list,
      case_names: (string * int) list,
@@ -257,6 +258,7 @@
    destroy_constrs: bool,
    specialize: bool,
    star_linear_preds: bool,
+   preconstrs: (term option * bool option) list,
    tac_timeout: Time.time option,
    evals: term list,
    case_names: (string * int) list,
--- a/src/HOL/Tools/Nitpick/nitpick_isar.ML	Mon Feb 21 16:33:21 2011 +0100
+++ b/src/HOL/Tools/Nitpick/nitpick_isar.ML	Mon Feb 21 17:36:32 2011 +0100
@@ -58,8 +58,8 @@
    ("destroy_constrs", "true"),
    ("specialize", "true"),
    ("star_linear_preds", "true"),
+   ("preconstr", "smart"),
    ("peephole_optim", "true"),
-   ("fix_datatype_vals", "true"),
    ("datatype_sym_break", "5"),
    ("kodkod_sym_break", "15"),
    ("timeout", "30"),
@@ -91,8 +91,8 @@
    ("dont_destroy_constrs", "destroy_constrs"),
    ("dont_specialize", "specialize"),
    ("dont_star_linear_preds", "star_linear_preds"),
+   ("dont_preconstr", "preconstr"),
    ("no_peephole_optim", "peephole_optim"),
-   ("fix_datatype_vals", "dont_fix_datatype_vals"),
    ("no_debug", "debug"),
    ("quiet", "verbose"),
    ("no_overlord", "overlord"),
@@ -104,11 +104,11 @@
 fun is_known_raw_param s =
   AList.defined (op =) default_default_params s orelse
   AList.defined (op =) negated_params s orelse
-  member (op =) ["max", "show_all", "whack", "atoms", "eval", "expect"] s orelse
+  member (op =) ["max", "show_all", "whack", "eval", "atoms", "expect"] s orelse
   exists (fn p => String.isPrefix (p ^ " ") s)
          ["card", "max", "iter", "box", "dont_box", "finitize", "dont_finitize",
-          "mono", "non_mono", "std", "non_std", "wf", "non_wf", "format",
-          "atoms"]
+          "mono", "non_mono", "std", "non_std", "wf", "non_wf", "preconstr",
+          "dont_preconstr", "format", "atoms"]
 
 fun check_raw_param (s, _) =
   if is_known_raw_param s then ()
@@ -253,8 +253,9 @@
     val destroy_constrs = lookup_bool "destroy_constrs"
     val specialize = lookup_bool "specialize"
     val star_linear_preds = lookup_bool "star_linear_preds"
+    val preconstrs =
+      lookup_bool_option_assigns read_term_polymorphic "preconstr"
     val peephole_optim = lookup_bool "peephole_optim"
-    val fix_datatype_vals = lookup_bool "fix_datatype_vals"
     val datatype_sym_break = lookup_int "datatype_sym_break"
     val kodkod_sym_break = lookup_int "kodkod_sym_break"
     val timeout = if auto then NONE else lookup_time "timeout"
@@ -262,9 +263,9 @@
     val max_threads = if auto then 1 else Int.max (0, lookup_int "max_threads")
     val show_datatypes = debug orelse lookup_bool "show_datatypes"
     val show_consts = debug orelse lookup_bool "show_consts"
+    val evals = lookup_term_list_polymorphic "eval"
     val formats = lookup_ints_assigns read_term_polymorphic "format" 0
     val atomss = lookup_strings_assigns read_type_polymorphic "atoms"
-    val evals = lookup_term_list_polymorphic "eval"
     val max_potential =
       if auto then 0 else Int.max (0, lookup_int "max_potential")
     val max_genuine = Int.max (0, lookup_int "max_genuine")
@@ -284,13 +285,12 @@
      user_axioms = user_axioms, assms = assms, whacks = whacks,
      merge_type_vars = merge_type_vars, binary_ints = binary_ints,
      destroy_constrs = destroy_constrs, specialize = specialize,
-     star_linear_preds = star_linear_preds, peephole_optim = peephole_optim,
-     fix_datatype_vals = fix_datatype_vals,
-     datatype_sym_break = datatype_sym_break,
+     star_linear_preds = star_linear_preds, preconstrs = preconstrs,
+     peephole_optim = peephole_optim, datatype_sym_break = datatype_sym_break,
      kodkod_sym_break = kodkod_sym_break, timeout = timeout,
      tac_timeout = tac_timeout, max_threads = max_threads,
      show_datatypes = show_datatypes, show_consts = show_consts,
-     formats = formats, atomss = atomss, evals = evals,
+     evals = evals, formats = formats, atomss = atomss,
      max_potential = max_potential, max_genuine = max_genuine,
      check_potential = check_potential, check_genuine = check_genuine,
      batch_size = batch_size, expect = expect}
--- a/src/HOL/Tools/Nitpick/nitpick_kodkod.ML	Mon Feb 21 16:33:21 2011 +0100
+++ b/src/HOL/Tools/Nitpick/nitpick_kodkod.ML	Mon Feb 21 17:36:32 2011 +0100
@@ -1478,8 +1478,8 @@
                       "malformed Kodkod formula")
   end
 
-fun needed_value_axioms_for_datatype [] _ _ _ = []
-  | needed_value_axioms_for_datatype needed_us ofs kk
+fun preconstructed_value_axioms_for_datatype [] _ _ _ = []
+  | preconstructed_value_axioms_for_datatype preconstr_us ofs kk
         ({typ, card, constrs, ...} : datatype_spec) =
     let
       fun aux (u as Construct (FreeRel (_, _, _, s) :: _, T, _, us)) =
@@ -1507,9 +1507,10 @@
                  else
                    accum)
         | aux u =
-          raise NUT ("Nitpick_Kodkod.needed_value_axioms_for_datatype.aux", [u])
+          raise NUT ("Nitpick_Kodkod.preconstructed_value_axioms_for_datatype\
+                     \.aux", [u])
     in
-      case SOME (index_seq 0 card, []) |> fold aux needed_us of
+      case SOME (index_seq 0 card, []) |> fold aux preconstr_us of
         SOME (_, fixed) => fixed |> map (atom_equation_for_nut ofs kk)
       | NONE => [KK.False]
     end
@@ -1653,33 +1654,34 @@
                                               nfas dtypes)
   end
 
-fun is_datatype_in_needed_value T (Construct (_, T', _, us)) =
-    T = T' orelse exists (is_datatype_in_needed_value T) us
-  | is_datatype_in_needed_value _ _ = false
+fun is_datatype_in_preconstructed_value T (Construct (_, T', _, us)) =
+    T = T' orelse exists (is_datatype_in_preconstructed_value T) us
+  | is_datatype_in_preconstructed_value _ _ = false
 
 val min_sym_break_card = 7
 
-fun sym_break_axioms_for_datatypes hol_ctxt binarize needed_us
+fun sym_break_axioms_for_datatypes hol_ctxt binarize preconstr_us
                                    datatype_sym_break kk rel_table nfas dtypes =
   if datatype_sym_break = 0 then
     []
   else
-    maps (sym_break_axioms_for_datatype hol_ctxt binarize kk rel_table nfas
-                                        dtypes)
-         (dtypes |> filter is_datatype_acyclic
-                 |> filter (fn {constrs = [_], ...} => false
-                             | {card, constrs, ...} =>
-                               card >= min_sym_break_card andalso
-                               forall (forall (not o is_higher_order_type)
-                                       o binder_types o snd o #const) constrs)
-                 |> filter_out (fn {typ, ...} =>
-                                   exists (is_datatype_in_needed_value typ)
-                                          needed_us)
-                 |> (fn dtypes' =>
-                        dtypes'
-                        |> length dtypes' > datatype_sym_break
-                           ? (sort (datatype_ord o swap)
-                              #> take datatype_sym_break)))
+    dtypes |> filter is_datatype_acyclic
+           |> filter (fn {constrs = [_], ...} => false
+                       | {card, constrs, ...} =>
+                         card >= min_sym_break_card andalso
+                         forall (forall (not o is_higher_order_type)
+                                 o binder_types o snd o #const) constrs)
+           |> filter_out
+                  (fn {typ, ...} =>
+                      exists (is_datatype_in_preconstructed_value typ)
+                             preconstr_us)
+           |> (fn dtypes' =>
+                  dtypes' |> length dtypes' > datatype_sym_break
+                             ? (sort (datatype_ord o swap)
+                                #> take datatype_sym_break))
+           |> maps (sym_break_axioms_for_datatype hol_ctxt binarize kk rel_table
+                                                  nfas dtypes)
+
 
 fun sel_axiom_for_sel hol_ctxt binarize j0
         (kk as {kk_all, kk_formula_if, kk_subset, kk_no, kk_join, ...})
@@ -1777,7 +1779,7 @@
       partition_axioms_for_datatype j0 kk rel_table dtype
     end
 
-fun declarative_axioms_for_datatypes hol_ctxt binarize needed_us
+fun declarative_axioms_for_datatypes hol_ctxt binarize preconstr_us
         datatype_sym_break bits ofs kk rel_table dtypes =
   let
     val nfas =
@@ -1786,8 +1788,8 @@
              |> strongly_connected_sub_nfas
   in
     acyclicity_axioms_for_datatypes kk nfas @
-    maps (needed_value_axioms_for_datatype needed_us ofs kk) dtypes @
-    sym_break_axioms_for_datatypes hol_ctxt binarize needed_us
+    maps (preconstructed_value_axioms_for_datatype preconstr_us ofs kk) dtypes @
+    sym_break_axioms_for_datatypes hol_ctxt binarize preconstr_us
         datatype_sym_break kk rel_table nfas dtypes @
     maps (other_axioms_for_datatype hol_ctxt binarize bits ofs kk rel_table)
          dtypes
--- a/src/HOL/Tools/Nitpick/nitpick_model.ML	Mon Feb 21 16:33:21 2011 +0100
+++ b/src/HOL/Tools/Nitpick/nitpick_model.ML	Mon Feb 21 17:36:32 2011 +0100
@@ -862,12 +862,12 @@
 fun reconstruct_hol_model {show_datatypes, show_consts}
         ({hol_ctxt = {thy, ctxt, max_bisim_depth, boxes, stds, wfs, user_axioms,
                       debug, whacks, binary_ints, destroy_constrs, specialize,
-                      star_linear_preds, tac_timeout, evals, case_names,
-                      def_tables, nondef_table, user_nondefs, simp_table,
-                      psimp_table, choice_spec_table, intro_table,
+                      star_linear_preds, preconstrs, tac_timeout, evals,
+                      case_names, def_tables, nondef_table, user_nondefs,
+                      simp_table, psimp_table, choice_spec_table, intro_table,
                       ground_thm_table, ersatz_table, skolems, special_funs,
-                      unrolled_preds, wf_cache, constr_cache},
-         binarize, card_assigns, bits, bisim_depth, datatypes, ofs} : scope)
+                      unrolled_preds, wf_cache, constr_cache}, binarize,
+                      card_assigns, bits, bisim_depth, datatypes, ofs} : scope)
         formats atomss real_frees pseudo_frees free_names sel_names nonsel_names
         rel_table bounds =
   let
@@ -879,15 +879,15 @@
        stds = stds, wfs = wfs, user_axioms = user_axioms, debug = debug,
        whacks = whacks, binary_ints = binary_ints,
        destroy_constrs = destroy_constrs, specialize = specialize,
-       star_linear_preds = star_linear_preds, tac_timeout = tac_timeout,
-       evals = evals, case_names = case_names, def_tables = def_tables,
-       nondef_table = nondef_table, user_nondefs = user_nondefs,
-       simp_table = simp_table, psimp_table = psimp_table,
-       choice_spec_table = choice_spec_table, intro_table = intro_table,
-       ground_thm_table = ground_thm_table, ersatz_table = ersatz_table,
-       skolems = skolems, special_funs = special_funs,
-       unrolled_preds = unrolled_preds, wf_cache = wf_cache,
-       constr_cache = constr_cache}
+       star_linear_preds = star_linear_preds, preconstrs = preconstrs,
+       tac_timeout = tac_timeout, evals = evals, case_names = case_names,
+       def_tables = def_tables, nondef_table = nondef_table,
+       user_nondefs = user_nondefs, simp_table = simp_table,
+       psimp_table = psimp_table, choice_spec_table = choice_spec_table,
+       intro_table = intro_table, ground_thm_table = ground_thm_table,
+       ersatz_table = ersatz_table, skolems = skolems,
+       special_funs = special_funs, unrolled_preds = unrolled_preds,
+       wf_cache = wf_cache, constr_cache = constr_cache}
     val scope =
       {hol_ctxt = hol_ctxt, binarize = binarize, card_assigns = card_assigns,
        bits = bits, bisim_depth = bisim_depth, datatypes = datatypes, ofs = ofs}
--- a/src/HOL/Tools/Nitpick/nitpick_preproc.ML	Mon Feb 21 16:33:21 2011 +0100
+++ b/src/HOL/Tools/Nitpick/nitpick_preproc.ML	Mon Feb 21 17:36:32 2011 +0100
@@ -10,7 +10,7 @@
   type hol_context = Nitpick_HOL.hol_context
   val preprocess_formulas :
     hol_context -> term list -> term
-    -> term list * term list * bool * bool * bool
+    -> term list * term list * term list * bool * bool * bool
 end;
 
 structure Nitpick_Preproc : NITPICK_PREPROC =
@@ -1246,7 +1246,7 @@
 
 fun preprocess_formulas
         (hol_ctxt as {thy, ctxt, stds, binary_ints, destroy_constrs, boxes,
-                      ...}) assm_ts neg_t =
+                      preconstrs, ...}) assm_ts neg_t =
   let
     val (nondef_ts, def_ts, got_all_mono_user_axioms, no_poly_user_axioms) =
       neg_t |> unfold_defs_in_term hol_ctxt
@@ -1266,13 +1266,14 @@
     val table =
       Termtab.empty
       |> box ? fold (add_to_uncurry_table ctxt) (nondef_ts @ def_ts)
-    fun do_rest def =
+    fun do_middle def =
       binarize ? binarize_nat_and_int_in_term
       #> box ? uncurry_term table
       #> box ? box_fun_and_pair_in_term hol_ctxt def
-      #> destroy_constrs ? (pull_out_universal_constrs hol_ctxt def
-                            #> pull_out_existential_constrs hol_ctxt
-                            #> destroy_pulled_out_constrs hol_ctxt def)
+    fun do_tail def =
+      destroy_constrs ? (pull_out_universal_constrs hol_ctxt def
+                         #> pull_out_existential_constrs hol_ctxt
+                         #> destroy_pulled_out_constrs hol_ctxt def)
       #> curry_assms
       #> destroy_universal_equalities
       #> destroy_existential_equalities hol_ctxt
@@ -1281,10 +1282,17 @@
       #> push_quantifiers_inward
       #> close_form
       #> Term.map_abs_vars shortest_name
-    val nondef_ts = map (do_rest false) nondef_ts
-    val def_ts = map (do_rest true) def_ts
+    val nondef_ts = nondef_ts |> map (do_middle false)
+    val preconstr_ts =
+      (* FIXME: Implement preconstruction inference. *)
+      preconstrs
+      |> map_filter (fn (SOME t, SOME true) => SOME (t |> do_middle false)
+                      | _ => NONE)
+    val nondef_ts = nondef_ts |> map (do_tail false)
+    val def_ts = def_ts |> map (do_middle true #> do_tail true)
   in
-    (nondef_ts, def_ts, got_all_mono_user_axioms, no_poly_user_axioms, binarize)
+    (nondef_ts, def_ts, preconstr_ts, got_all_mono_user_axioms,
+     no_poly_user_axioms, binarize)
   end
 
 end;