added term postprocessor to Nitpick, to provide custom syntax for typedefs
authorblanchet
Thu, 11 Mar 2010 12:22:11 +0100
changeset 35711 548d3f16404b
parent 35710 58acd48904bc
child 35712 77aa29bf14ee
child 35735 f139a9bb6501
added term postprocessor to Nitpick, to provide custom syntax for typedefs
src/HOL/Nitpick_Examples/Manual_Nits.thy
src/HOL/Tools/Nitpick/nitpick.ML
src/HOL/Tools/Nitpick/nitpick_hol.ML
src/HOL/Tools/Nitpick/nitpick_model.ML
--- a/src/HOL/Nitpick_Examples/Manual_Nits.thy	Thu Mar 11 10:13:24 2010 +0100
+++ b/src/HOL/Nitpick_Examples/Manual_Nits.thy	Thu Mar 11 12:22:11 2010 +0100
@@ -117,6 +117,19 @@
 nitpick [show_datatypes, expect = genuine]
 oops
 
+ML {*
+(* Proof.context -> typ -> term -> term *)
+fun my_int_postproc _ T (Const _ $ (Const _ $ t1 $ t2)) =
+    HOLogic.mk_number T (snd (HOLogic.dest_number t1) - snd (HOLogic.dest_number t2))
+  | my_int_postproc _ _ t = t
+*}
+
+setup {* Nitpick.register_term_postprocessor @{typ my_int} my_int_postproc *}
+
+lemma "add x y = add x x"
+nitpick [show_datatypes]
+oops
+
 record point =
   Xcoord :: int
   Ycoord :: int
--- a/src/HOL/Tools/Nitpick/nitpick.ML	Thu Mar 11 10:13:24 2010 +0100
+++ b/src/HOL/Tools/Nitpick/nitpick.ML	Thu Mar 11 12:22:11 2010 +0100
@@ -8,6 +8,7 @@
 signature NITPICK =
 sig
   type styp = Nitpick_Util.styp
+  type term_postprocessor = Nitpick_Model.term_postprocessor
   type params = {
     cards_assigns: (typ option * int list) list,
     maxes_assigns: (styp option * int list) list,
@@ -58,6 +59,9 @@
   val unregister_frac_type : string -> theory -> theory
   val register_codatatype : typ -> string -> styp list -> theory -> theory
   val unregister_codatatype : typ -> theory -> theory
+  val register_term_postprocessor :
+    typ -> term_postprocessor -> theory -> theory
+  val unregister_term_postprocessor : typ -> theory -> theory
   val pick_nits_in_term :
     Proof.state -> params -> bool -> int -> int -> int -> (term * term) list
     -> term list -> term -> string * Proof.state
--- a/src/HOL/Tools/Nitpick/nitpick_hol.ML	Thu Mar 11 10:13:24 2010 +0100
+++ b/src/HOL/Tools/Nitpick/nitpick_hol.ML	Thu Mar 11 12:22:11 2010 +0100
@@ -72,9 +72,11 @@
   val shortest_name : string -> string
   val short_name : string -> string
   val shorten_names_in_term : term -> term
+  val strict_type_match : theory -> typ * typ -> bool
   val type_match : theory -> typ * typ -> bool
   val const_match : theory -> styp * styp -> bool
   val term_match : theory -> term * term -> bool
+  val frac_from_term_pair : typ -> term -> term -> term
   val is_TFree : typ -> bool
   val is_higher_order_type : typ -> bool
   val is_fun_type : typ -> bool
@@ -457,6 +459,15 @@
     const_match thy ((shortest_name s1, T1), (shortest_name s2, T2))
   | term_match _ (t1, t2) = t1 aconv t2
 
+(* typ -> term -> term -> term *)
+fun frac_from_term_pair T t1 t2 =
+  case snd (HOLogic.dest_number t1) of
+    0 => HOLogic.mk_number T 0
+  | n1 => case snd (HOLogic.dest_number t2) of
+            1 => HOLogic.mk_number T n1
+          | n2 => Const (@{const_name divide}, T --> T --> T)
+                  $ HOLogic.mk_number T n1 $ HOLogic.mk_number T n2
+
 (* typ -> bool *)
 fun is_TFree (TFree _) = true
   | is_TFree _ = false
--- a/src/HOL/Tools/Nitpick/nitpick_model.ML	Thu Mar 11 10:13:24 2010 +0100
+++ b/src/HOL/Tools/Nitpick/nitpick_model.ML	Thu Mar 11 12:22:11 2010 +0100
@@ -16,9 +16,13 @@
     show_skolems: bool,
     show_datatypes: bool,
     show_consts: bool}
+  type term_postprocessor = Proof.context -> typ -> term -> term
 
   structure NameTable : TABLE
 
+  val register_term_postprocessor :
+    typ -> term_postprocessor -> theory -> theory
+  val unregister_term_postprocessor : typ -> theory -> theory
   val tuple_list_for_name :
     nut NameTable.table -> Kodkod.raw_bound list -> nut -> int list list
   val reconstruct_hol_model :
@@ -47,6 +51,14 @@
   show_datatypes: bool,
   show_consts: bool}
 
+type term_postprocessor = Proof.context -> typ -> term -> term
+
+structure Data = Theory_Data(
+  type T = (typ * term_postprocessor) list
+  val empty = []
+  val extend = I
+  fun merge (ps1, ps2) = AList.merge (op =) (K true) (ps1, ps2))
+
 val unknown = "?"
 val unrep = "\<dots>"
 val maybe_mixfix = "_\<^sup>?"
@@ -105,6 +117,11 @@
               | ord => ord)
            | _ => Term_Ord.fast_term_ord tp
 
+(* typ -> term_postprocessor -> theory -> theory *)
+fun register_term_postprocessor T p = Data.map (cons (T, p))
+(* typ -> theory -> theory *)
+fun unregister_term_postprocessor T = Data.map (AList.delete (op =) T)
+
 (* nut NameTable.table -> KK.raw_bound list -> nut -> int list list *)
 fun tuple_list_for_name rel_table bounds name =
   the (AList.lookup (op =) bounds (the_rel rel_table name)) handle NUT _ => [[]]
@@ -271,12 +288,25 @@
   | mk_tuple _ (t :: _) = t
   | mk_tuple T [] = raise TYPE ("Nitpick_Model.mk_tuple", [T], [])
 
+(* theory -> typ * typ -> bool *)
+fun varified_type_match thy (candid_T, pat_T) =
+  strict_type_match thy (candid_T, Logic.varifyT pat_T)
+(* Proof.context -> typ -> term -> term *)
+fun postprocess_term ctxt T =
+  let val thy = ProofContext.theory_of ctxt in
+    if null (Data.get thy) then
+      I
+    else
+      (T |> AList.lookup (varified_type_match thy) (Data.get thy)
+         |> the_default (K (K I))) ctxt T
+  end
+
 (* bool -> atom_pool -> (string * string) * (string * string) -> scope
    -> nut list -> nut list -> nut list -> nut NameTable.table
    -> KK.raw_bound list -> typ -> typ -> rep -> int list list -> term *)
 fun reconstruct_term unfold pool ((maybe_name, abs_name), _)
-        ({hol_ctxt as {thy, stds, ...}, binarize, card_assigns, bits, datatypes,
-          ofs, ...} : scope) sel_names rel_table bounds =
+        ({hol_ctxt as {ctxt, thy, stds, ...}, binarize, card_assigns, bits,
+          datatypes, ofs, ...} : scope) sel_names rel_table bounds =
   let
     val for_auto = (maybe_name = "")
     (* int list list -> int *)
@@ -476,23 +506,11 @@
                     |> dest_n_tuple (length uncur_arg_Ts)
                 val t =
                   if constr_s = @{const_name Abs_Frac} then
-                    let
-                      val num_T = body_type T
-                      (* int -> term *)
-                      val mk_num = HOLogic.mk_number num_T
-                    in
-                      case ts of
-                        [Const (@{const_name Pair}, _) $ t1 $ t2] =>
-                        (case snd (HOLogic.dest_number t1) of
-                           0 => mk_num 0
-                         | n1 => case HOLogic.dest_number t2 |> snd of
-                                   1 => mk_num n1
-                                 | n2 => Const (@{const_name divide},
-                                                num_T --> num_T --> num_T)
-                                         $ mk_num n1 $ mk_num n2)
-                      | _ => raise TERM ("Nitpick_Model.reconstruct_term.\
-                                         \term_for_atom (Abs_Frac)", ts)
-                    end
+                    case ts of
+                      [Const (@{const_name Pair}, _) $ t1 $ t2] =>
+                      frac_from_term_pair (body_type T) t1 t2
+                    | _ => raise TERM ("Nitpick_Model.reconstruct_term.\
+                                       \term_for_atom (Abs_Frac)", ts)
                   else if not for_auto andalso
                           (is_abs_fun thy constr_x orelse
                            constr_s = @{const_name Quot}) then
@@ -523,6 +541,7 @@
                   t
               end
           end
+          |> postprocess_term ctxt T
     (* (typ * int) list -> int -> rep -> typ -> typ -> typ -> int list
        -> term *)
     and term_for_vect seen k R T1 T2 T' js =