src/HOL/Tools/Nitpick/nitpick_hol.ML
changeset 33580 45c33e97cb86
parent 33578 0c3ba1e010d2
child 33581 e1e77265fb1d
--- a/src/HOL/Tools/Nitpick/nitpick_hol.ML	Thu Nov 05 17:00:28 2009 +0100
+++ b/src/HOL/Tools/Nitpick/nitpick_hol.ML	Thu Nov 05 17:03:22 2009 +0100
@@ -40,7 +40,8 @@
     skolems: (string * string list) list Unsynchronized.ref,
     special_funs: special_fun list Unsynchronized.ref,
     unrolled_preds: unrolled list Unsynchronized.ref,
-    wf_cache: wf_cache Unsynchronized.ref}
+    wf_cache: wf_cache Unsynchronized.ref,
+    constr_cache: (typ * styp list) list Unsynchronized.ref}
 
   val name_sep : string
   val numeral_prefix : string
@@ -100,16 +101,16 @@
   val unregister_frac_type : string -> theory -> theory
   val register_codatatype : typ -> string -> styp list -> theory -> theory
   val unregister_codatatype : typ -> theory -> theory
-  val datatype_constrs : theory -> typ -> styp list
+  val datatype_constrs : extended_context -> typ -> styp list
   val boxed_datatype_constrs : extended_context -> typ -> styp list
-  val num_datatype_constrs : theory -> typ -> int
+  val num_datatype_constrs : extended_context -> typ -> int
   val constr_name_for_sel_like : string -> string
   val boxed_constr_for_sel : extended_context -> styp -> styp
   val card_of_type : (typ * int) list -> typ -> int
   val bounded_card_of_type : int -> int -> (typ * int) list -> typ -> int
   val bounded_precise_card_of_type :
-    theory -> int -> int -> (typ * int) list -> typ -> int
-  val is_finite_type : theory -> typ -> bool
+    extended_context -> int -> int -> (typ * int) list -> typ -> int
+  val is_finite_type : extended_context -> typ -> bool
   val all_axioms_of : theory -> term list * term list * term list
   val arity_of_built_in_const : bool -> styp -> int option
   val is_built_in_const : bool -> styp -> bool
@@ -177,7 +178,8 @@
   skolems: (string * string list) list Unsynchronized.ref,
   special_funs: special_fun list Unsynchronized.ref,
   unrolled_preds: unrolled list Unsynchronized.ref,
-  wf_cache: wf_cache Unsynchronized.ref}
+  wf_cache: wf_cache Unsynchronized.ref,
+  constr_cache: (typ * styp list) list Unsynchronized.ref}
 
 structure TheoryData = TheoryDataFun(
   type T = {frac_types: (string * (string * string) list) list,
@@ -727,7 +729,7 @@
 fun suc_const T = Const (@{const_name Suc}, T --> T)
 
 (* theory -> typ -> styp list *)
-fun datatype_constrs thy (T as Type (s, Ts)) =
+fun uncached_datatype_constrs thy (T as Type (s, Ts)) =
     if is_datatype thy T then
       case Datatype.get_info thy s of
         SOME {index, descr, ...} =>
@@ -757,11 +759,19 @@
               []
     else
       []
-  | datatype_constrs _ _ = []
+  | uncached_datatype_constrs _ _ = []
 (* extended_context -> typ -> styp list *)
-fun boxed_datatype_constrs (ext_ctxt as {thy, ...}) =
-  map (apsnd (box_type ext_ctxt InConstr)) o datatype_constrs thy
-(* theory -> typ -> int *)
+fun datatype_constrs (ext_ctxt as {thy, constr_cache, ...} : extended_context)
+                     T =
+  case AList.lookup (op =) (!constr_cache) T of
+    SOME xs => xs
+  | NONE =>
+    let val xs = uncached_datatype_constrs thy T in
+      (Unsynchronized.change constr_cache (cons (T, xs)); xs)
+    end
+fun boxed_datatype_constrs ext_ctxt =
+  map (apsnd (box_type ext_ctxt InConstr)) o datatype_constrs ext_ctxt
+(* extended_context -> typ -> int *)
 val num_datatype_constrs = length oo datatype_constrs
 
 (* string -> string *)
@@ -774,26 +784,26 @@
     AList.lookup (op =) (boxed_datatype_constrs ext_ctxt (domain_type T')) s
     |> the |> pair s
   end
-(* theory -> styp -> term *)
-fun discr_term_for_constr thy (x as (s, T)) =
+(* extended_context -> styp -> term *)
+fun discr_term_for_constr ext_ctxt (x as (s, T)) =
   let val dataT = body_type T in
     if s = @{const_name Suc} then
       Abs (Name.uu, dataT,
            @{const Not} $ HOLogic.mk_eq (zero_const dataT, Bound 0))
-    else if num_datatype_constrs thy dataT >= 2 then
+    else if num_datatype_constrs ext_ctxt dataT >= 2 then
       Const (discr_for_constr x)
     else
       Abs (Name.uu, dataT, @{const True})
   end
 
-(* theory -> styp -> term -> term *)
-fun discriminate_value thy (x as (_, T)) t =
+(* extended_context -> styp -> term -> term *)
+fun discriminate_value (ext_ctxt as {thy, ...}) (x as (_, T)) t =
   case strip_comb t of
     (Const x', args) =>
     if x = x' then @{const True}
     else if is_constr_like thy x' then @{const False}
-    else betapply (discr_term_for_constr thy x, t)
-  | _ => betapply (discr_term_for_constr thy x, t)
+    else betapply (discr_term_for_constr ext_ctxt x, t)
+  | _ => betapply (discr_term_for_constr ext_ctxt x, t)
 
 (* styp -> term -> term *)
 fun nth_arg_sel_term_for_constr (x as (s, T)) n =
@@ -842,8 +852,8 @@
       | _ => list_comb (Const x, args)
     end
 
-(* theory -> typ -> term -> term *)
-fun constr_expand thy T t =
+(* extended_context -> typ -> term -> term *)
+fun constr_expand (ext_ctxt as {thy, ...}) T t =
   (case head_of t of
      Const x => if is_constr_like thy x then t else raise SAME ()
    | _ => raise SAME ())
@@ -855,7 +865,7 @@
                  (@{const_name Pair}, [T1, T2] ---> T)
                end
              else
-               datatype_constrs thy T |> the_single
+               datatype_constrs ext_ctxt T |> the_single
            val arg_Ts = binder_types T'
          in
            list_comb (Const x', map2 (select_nth_constr_arg thy x' t)
@@ -897,8 +907,8 @@
                     card_of_type asgns T
                     handle TYPE ("Nitpick_HOL.card_of_type", _, _) =>
                            default_card)
-(* theory -> int -> (typ * int) list -> typ -> int *)
-fun bounded_precise_card_of_type thy max default_card asgns T =
+(* extended_context -> int -> (typ * int) list -> typ -> int *)
+fun bounded_precise_card_of_type ext_ctxt max default_card asgns T =
   let
     (* typ list -> typ -> int *)
     fun aux avoid T =
@@ -928,12 +938,12 @@
        | @{typ bool} => 2
        | @{typ unit} => 1
        | Type _ =>
-         (case datatype_constrs thy T of
+         (case datatype_constrs ext_ctxt T of
             [] => if is_integer_type T then 0 else raise SAME ()
           | constrs =>
             let
               val constr_cards =
-                datatype_constrs thy T
+                datatype_constrs ext_ctxt T
                 |> map (Integer.prod o map (aux (T :: avoid)) o binder_types
                         o snd)
             in
@@ -944,8 +954,9 @@
       handle SAME () => AList.lookup (op =) asgns T |> the_default default_card
   in Int.min (max, aux [] T) end
 
-(* theory -> typ -> bool *)
-fun is_finite_type thy = not_equal 0 o bounded_precise_card_of_type thy 1 2 []
+(* extended_context -> typ -> bool *)
+fun is_finite_type ext_ctxt =
+  not_equal 0 o bounded_precise_card_of_type ext_ctxt 1 2 []
 
 (* term -> bool *)
 fun is_ground_term (t1 $ t2) = is_ground_term t1 andalso is_ground_term t2
@@ -1280,25 +1291,26 @@
     list_comb (Bound j, map2 (select_nth_constr_arg thy x (Bound 0))
                              (index_seq 0 (length arg_Ts)) arg_Ts)
   end
-(* theory -> typ -> int * styp -> term -> term *)
-fun add_constr_case thy res_T (j, x) res_t =
+(* extended_context -> typ -> int * styp -> term -> term *)
+fun add_constr_case (ext_ctxt as {thy, ...}) res_T (j, x) res_t =
   Const (@{const_name If}, [bool_T, res_T, res_T] ---> res_T)
-  $ discriminate_value thy x (Bound 0) $ constr_case_body thy (j, x) $ res_t
-(* theory -> typ -> typ -> term *)
-fun optimized_case_def thy dataT res_T =
+  $ discriminate_value ext_ctxt x (Bound 0) $ constr_case_body thy (j, x)
+  $ res_t
+(* extended_context -> typ -> typ -> term *)
+fun optimized_case_def (ext_ctxt as {thy, ...}) dataT res_T =
   let
-    val xs = datatype_constrs thy dataT
+    val xs = datatype_constrs ext_ctxt dataT
     val func_Ts = map ((fn T => binder_types T ---> res_T) o snd) xs
     val (xs', x) = split_last xs
   in
     constr_case_body thy (1, x)
-    |> fold_rev (add_constr_case thy res_T) (length xs downto 2 ~~ xs')
+    |> fold_rev (add_constr_case ext_ctxt res_T) (length xs downto 2 ~~ xs')
     |> fold_rev (curry absdummy) (func_Ts @ [dataT])
   end
 
-(* theory -> string -> typ -> typ -> term -> term *)
-fun optimized_record_get thy s rec_T res_T t =
-  let val constr_x = the_single (datatype_constrs thy rec_T) in
+(* extended_context -> string -> typ -> typ -> term -> term *)
+fun optimized_record_get (ext_ctxt as {thy, ...}) s rec_T res_T t =
+  let val constr_x = the_single (datatype_constrs ext_ctxt rec_T) in
     case no_of_record_field thy s rec_T of
       ~1 => (case rec_T of
                Type (_, Ts as _ :: _) =>
@@ -1307,16 +1319,16 @@
                  val j = num_record_fields thy rec_T - 1
                in
                  select_nth_constr_arg thy constr_x t j res_T
-                 |> optimized_record_get thy s rec_T' res_T
+                 |> optimized_record_get ext_ctxt s rec_T' res_T
                end
              | _ => raise TYPE ("Nitpick_HOL.optimized_record_get", [rec_T],
                                 []))
     | j => select_nth_constr_arg thy constr_x t j res_T
   end
-(* theory -> string -> typ -> term -> term -> term *)
-fun optimized_record_update thy s rec_T fun_t rec_t =
+(* extended_context -> string -> typ -> term -> term -> term *)
+fun optimized_record_update (ext_ctxt as {thy, ...}) s rec_T fun_t rec_t =
   let
-    val constr_x as (_, constr_T) = the_single (datatype_constrs thy rec_T)
+    val constr_x as (_, constr_T) = the_single (datatype_constrs ext_ctxt rec_T)
     val Ts = binder_types constr_T
     val n = length Ts
     val special_j = no_of_record_field thy s rec_T
@@ -1327,7 +1339,7 @@
                         if j = special_j then
                           betapply (fun_t, t)
                         else if j = n - 1 andalso special_j = ~1 then
-                          optimized_record_update thy s
+                          optimized_record_update ext_ctxt s
                               (rec_T |> dest_Type |> snd |> List.last) fun_t t
                         else
                           t
@@ -1471,7 +1483,7 @@
           val (const, ts) =
             if is_built_in_const fast_descrs x then
               if s = @{const_name finite} then
-                if is_finite_type thy (domain_type T) then
+                if is_finite_type ext_ctxt (domain_type T) then
                   (Abs ("A", domain_type T, @{const True}), ts)
                 else case ts of
                   [Const (@{const_name UNIV}, _)] => (@{const False}, [])
@@ -1484,7 +1496,7 @@
                 val (dataT, res_T) = nth_range_type n T
                                      |> domain_type pairf range_type
               in
-                (optimized_case_def thy dataT res_T
+                (optimized_case_def ext_ctxt dataT res_T
                  |> do_term (depth + 1) Ts, ts)
               end
             | _ =>
@@ -1493,15 +1505,14 @@
               else if is_record_get thy x then
                 case length ts of
                   0 => (do_term depth Ts (eta_expand Ts t 1), [])
-                | _ => (optimized_record_get thy s (domain_type T)
+                | _ => (optimized_record_get ext_ctxt s (domain_type T)
                                              (range_type T) (hd ts), tl ts)
               else if is_record_update thy x then
                 case length ts of
-                  2 => (optimized_record_update thy (unsuffix Record.updateN s)
-                                                (nth_range_type 2 T)
-                                                (do_term depth Ts (hd ts))
-                                                (do_term depth Ts (nth ts 1)),
-                        [])
+                  2 => (optimized_record_update ext_ctxt
+                            (unsuffix Record.updateN s) (nth_range_type 2 T)
+                            (do_term depth Ts (hd ts))
+                            (do_term depth Ts (nth ts 1)), [])
                 | n => (do_term depth Ts (eta_expand Ts t (2 - n)), [])
               else if is_rep_fun thy x then
                 let val x' = mate_of_rep_fun thy x in
@@ -1528,10 +1539,10 @@
         in s_betapplys (const, map (do_term depth Ts) ts) |> Envir.beta_norm end
   in do_term 0 [] end
 
-(* theory -> typ -> term list *)
-fun codatatype_bisim_axioms thy T =
+(* extended_context -> typ -> term list *)
+fun codatatype_bisim_axioms (ext_ctxt as {thy, ...}) T =
   let
-    val xs = datatype_constrs thy T
+    val xs = datatype_constrs ext_ctxt T
     val set_T = T --> bool_T
     val iter_T = @{typ bisim_iterator}
     val bisim_const = Const (@{const_name bisim}, [iter_T, T, T] ---> bool_T)
@@ -1554,14 +1565,14 @@
       let
         val arg_Ts = binder_types T
         val core_t =
-          discriminate_value thy x y_var ::
+          discriminate_value ext_ctxt x y_var ::
           map2 (nth_sub_bisim x) (index_seq 0 (length arg_Ts)) arg_Ts
           |> foldr1 s_conj
       in List.foldr absdummy core_t arg_Ts end
   in
     [HOLogic.eq_const bool_T $ (bisim_const $ n_var $ x_var $ y_var)
      $ (@{term "op |"} $ (HOLogic.eq_const iter_T $ n_var $ zero_const iter_T)
-        $ (betapplys (optimized_case_def thy T bool_T,
+        $ (betapplys (optimized_case_def ext_ctxt T bool_T,
                       map case_func xs @ [x_var]))),
      HOLogic.eq_const set_T $ (bisim_const $ bisim_max $ x_var)
      $ (Const (@{const_name insert}, [T, set_T] ---> set_T)
@@ -1621,11 +1632,11 @@
                         ScnpReconstruct.sizechange_tac]
 
 (* extended_context -> const_table -> styp -> bool *)
-fun is_is_well_founded_inductive_pred
+fun uncached_is_well_founded_inductive_pred
         ({thy, ctxt, debug, fast_descrs, tac_timeout, intro_table, ...}
          : extended_context) (x as (_, T)) =
   case def_props_for_const thy fast_descrs intro_table x of
-    [] => raise TERM ("Nitpick_HOL.is_is_well_founded_inductive_pred",
+    [] => raise TERM ("Nitpick_HOL.uncached_is_well_founded_inductive",
                       [Const x])
   | intro_ts =>
     (case map (triple_for_intro_rule thy x) intro_ts
@@ -1677,7 +1688,7 @@
                 | NONE =>
                   let
                     val gfp = (fixpoint_kind_of_const thy def_table x = Gfp)
-                    val wf = is_is_well_founded_inductive_pred ext_ctxt x
+                    val wf = uncached_is_well_founded_inductive_pred ext_ctxt x
                   in
                     Unsynchronized.change wf_cache (cons (x, (gfp, wf))); wf
                   end
@@ -1987,8 +1998,8 @@
                                                          seen, concl)
   end
 
-(* theory -> bool -> term -> term *)
-fun destroy_pulled_out_constrs thy axiom t =
+(* extended_context -> bool -> term -> term *)
+fun destroy_pulled_out_constrs (ext_ctxt as {thy, ...}) axiom t =
   let
     (* styp -> int *)
     val num_occs_of_var =
@@ -2022,7 +2033,7 @@
               andalso (not careful orelse not (is_Var t1)
                        orelse String.isPrefix val_var_prefix
                                               (fst (fst (dest_Var t1)))) then
-             discriminate_value thy x t1 ::
+             discriminate_value ext_ctxt x t1 ::
              map3 (sel_eq x t1) (index_seq 0 (length args)) arg_Ts args
              |> foldr1 s_conj
              |> body_type (type_of t0) = prop_T ? HOLogic.mk_Trueprop
@@ -2711,7 +2722,7 @@
         | (Type (new_s, new_Ts as [new_T1, new_T2]),
            Type (old_s, old_Ts as [old_T1, old_T2])) =>
           if old_s mem [@{type_name fun_box}, @{type_name pair_box}, "*"] then
-            case constr_expand thy old_T t of
+            case constr_expand ext_ctxt old_T t of
               Const (@{const_name FunBox}, _) $ t1 =>
               if new_s = "fun" then
                 coerce_term Ts new_T (Type ("fun", old_Ts)) t1
@@ -3054,7 +3065,7 @@
         #> (if is_pure_typedef thy T then
               fold (add_def_axiom depth) (optimized_typedef_axioms thy z)
             else if max_bisim_depth >= 0 andalso is_codatatype thy T then
-              fold (add_def_axiom depth) (codatatype_bisim_axioms thy T)
+              fold (add_def_axiom depth) (codatatype_bisim_axioms ext_ctxt T)
             else
               I)
     (* int -> typ -> sort -> accumulator -> accumulator *)
@@ -3298,7 +3309,7 @@
       #> maybe_box ? box_fun_and_pair_in_term ext_ctxt def
       #> destroy_constrs ? (pull_out_universal_constrs thy def
                             #> pull_out_existential_constrs thy
-                            #> destroy_pulled_out_constrs thy def)
+                            #> destroy_pulled_out_constrs ext_ctxt def)
       #> curry_assms
       #> destroy_universal_equalities
       #> destroy_existential_equalities thy