--- a/src/HOL/Tools/Nitpick/nitpick_hol.ML Mon Jun 21 09:38:20 2010 +0200
+++ b/src/HOL/Tools/Nitpick/nitpick_hol.ML Mon Jun 21 11:15:21 2010 +0200
@@ -64,6 +64,10 @@
val iter_var_prefix : string
val strip_first_name_sep : string -> string * string
val original_name : string -> string
+ val abs_var : indexname * typ -> term -> term
+ val s_let : string -> int -> typ -> typ -> (term -> term) -> term -> term
+ val s_betapply : typ list -> term * term -> term
+ val s_betapplys : typ list -> term * term list -> term
val s_conj : term * term -> term
val s_disj : term * term -> term
val strip_any_connective : term -> term list * term
@@ -162,7 +166,6 @@
val is_finite_type : hol_context -> typ -> bool
val is_small_finite_type : hol_context -> typ -> bool
val special_bounds : term list -> (indexname * typ) list
- val abs_var : indexname * typ -> term -> term
val is_funky_typedef : theory -> typ -> bool
val all_axioms_of :
Proof.context -> (term * term) list -> term list * term list * term list
@@ -302,10 +305,55 @@
else
s
-fun s_betapply (Const (@{const_name If}, _) $ @{const True} $ t, _) = t
- | s_betapply (Const (@{const_name If}, _) $ @{const False} $ _, t) = t
- | s_betapply p = betapply p
-val s_betapplys = Library.foldl s_betapply
+fun abs_var ((s, j), T) body = Abs (s, T, abstract_over (Var ((s, j), T), body))
+
+fun let_var s = (nitpick_prefix ^ s, 999)
+val let_inline_threshold = 20
+
+fun s_let s n abs_T body_T f t =
+ if (n - 1) * (size_of_term t - 1) <= let_inline_threshold then
+ f t
+ else
+ let val z = (let_var s, abs_T) in
+ Const (@{const_name Let}, abs_T --> (abs_T --> body_T) --> body_T)
+ $ t $ abs_var z (incr_boundvars 1 (f (Var z)))
+ end
+
+fun loose_bvar1_count (Bound i, k) = if i = k then 1 else 0
+ | loose_bvar1_count (t1 $ t2, k) =
+ loose_bvar1_count (t1, k) + loose_bvar1_count (t2, k)
+ | loose_bvar1_count (Abs (_, _, t), k) = loose_bvar1_count (t, k + 1)
+ | loose_bvar1_count _ = 0
+
+fun s_betapply _ (Const (@{const_name If}, _) $ @{const True} $ t1', _) = t1'
+ | s_betapply _ (Const (@{const_name If}, _) $ @{const False} $ _, t2) = t2
+ | s_betapply Ts (Const (@{const_name Let},
+ Type (_, [bound_T, Type (_, [_, body_T])]))
+ $ t12 $ Abs (s, T, t13'), t2) =
+ let val body_T' = range_type body_T in
+ Const (@{const_name Let}, bound_T --> (bound_T --> body_T') --> body_T')
+ $ t12 $ Abs (s, T, s_betapply (T :: Ts) (t13', incr_boundvars 1 t2))
+ end
+ | s_betapply Ts (t1 as Abs (s1, T1, t1'), t2) =
+ (s_let s1 (loose_bvar1_count (t1', 0)) T1 (fastype_of1 (T1 :: Ts, t1'))
+ (curry betapply t1) t2
+ handle TERM _ => betapply (t1, t2)) (* FIXME: fix all uses *)
+ | s_betapply _ (t1, t2) = t1 $ t2
+fun s_betapplys Ts = Library.foldl (s_betapply Ts)
+
+fun s_beta_norm Ts t =
+ let
+ fun aux _ (Var _) = raise Same.SAME
+ | aux Ts (Abs (s, T, t')) = Abs (s, T, aux (T :: Ts) t')
+ | aux Ts ((t1 as Abs _) $ t2) =
+ Same.commit (aux Ts) (s_betapply Ts (t1, t2))
+ | aux Ts (t1 $ t2) =
+ ((case aux Ts t1 of
+ t1 as Abs _ => Same.commit (aux Ts) (s_betapply Ts (t1, t2))
+ | t1 => t1 $ Same.commit (aux Ts) t2)
+ handle Same.SAME => t1 $ aux Ts t2)
+ | aux _ _ = raise Same.SAME
+ in aux Ts t handle Same.SAME => t end
fun s_conj (t1, @{const True}) = t1
| s_conj (@{const True}, t2) = t2
@@ -344,7 +392,7 @@
(@{const_name True}, 0),
(@{const_name All}, 1),
(@{const_name Ex}, 1),
- (@{const_name "op ="}, 2),
+ (@{const_name "op ="}, 1),
(@{const_name "op &"}, 2),
(@{const_name "op |"}, 2),
(@{const_name "op -->"}, 2),
@@ -355,7 +403,6 @@
(@{const_name fst}, 1),
(@{const_name snd}, 1),
(@{const_name Id}, 0),
- (@{const_name insert}, 2),
(@{const_name converse}, 1),
(@{const_name trancl}, 1),
(@{const_name rel_comp}, 2),
@@ -396,10 +443,7 @@
((@{const_name ord_class.less_eq}, nat_T --> nat_T --> bool_T), 2),
((@{const_name of_nat}, nat_T --> int_T), 0)]
val built_in_set_consts =
- [(@{const_name semilattice_inf_class.inf}, 2),
- (@{const_name semilattice_sup_class.sup}, 2),
- (@{const_name minus_class.minus}, 2),
- (@{const_name ord_class.less_eq}, 2)]
+ [(@{const_name ord_class.less_eq}, 2)]
fun unarize_type @{typ "unsigned_bit word"} = nat_T
| unarize_type @{typ "signed_bit word"} = int_T
@@ -924,8 +968,8 @@
Const x' =>
if x = x' then @{const True}
else if is_constr_like ctxt x' then @{const False}
- else betapply (discr_term_for_constr hol_ctxt x, t)
- | _ => betapply (discr_term_for_constr hol_ctxt x, t)
+ else s_betapply [] (discr_term_for_constr hol_ctxt x, t)
+ | _ => s_betapply [] (discr_term_for_constr hol_ctxt x, t)
fun nth_arg_sel_term_for_constr thy stds (x as (s, T)) n =
let val (arg_Ts, dataT) = strip_type T in
@@ -955,7 +999,8 @@
else if is_constr_like ctxt x' then Const (@{const_name unknown}, res_T)
else raise SAME ()
| _ => raise SAME())
- handle SAME () => betapply (nth_arg_sel_term_for_constr thy stds x n, t)
+ handle SAME () =>
+ s_betapply [] (nth_arg_sel_term_for_constr thy stds x n, t)
end
fun construct_value _ _ x [] = Const x
@@ -1150,8 +1195,6 @@
fun special_bounds ts =
fold Term.add_vars ts [] |> sort (Term_Ord.fast_indexname_ord o pairself fst)
-fun abs_var ((s, j), T) body = Abs (s, T, abstract_over (Var ((s, j), T), body))
-
fun is_funky_typedef_name thy s =
member (op =) [@{type_name unit}, @{type_name "*"}, @{type_name "+"},
@{type_name int}] s orelse
@@ -1309,7 +1352,8 @@
original_name s <> s then
NONE
else
- x |> def_props_for_const thy [(NONE, false)] false table |> List.last
+ x |> def_props_for_const thy [(NONE, false)] false table
+ |> List.last
|> normalized_rhs_of |> Option.map (prefix_abs_vars s)
handle List.Empty => NONE
@@ -1458,25 +1502,44 @@
(** Constant unfolding **)
-fun constr_case_body ctxt stds (j, (x as (_, T))) =
+fun constr_case_body ctxt stds (func_t, (x as (_, T))) =
let val arg_Ts = binder_types T in
- list_comb (Bound j, map2 (select_nth_constr_arg ctxt stds x (Bound 0))
- (index_seq 0 (length arg_Ts)) arg_Ts)
+ s_betapplys [] (func_t, map2 (select_nth_constr_arg ctxt stds x (Bound 0))
+ (index_seq 0 (length arg_Ts)) arg_Ts)
end
-fun add_constr_case (hol_ctxt as {ctxt, stds, ...}) res_T (j, x) res_t =
- Const (@{const_name If}, bool_T --> res_T --> res_T --> res_T)
- $ discriminate_value hol_ctxt x (Bound 0) $ constr_case_body ctxt stds (j, x)
- $ res_t
-fun optimized_case_def (hol_ctxt as {ctxt, stds, ...}) dataT res_T =
+fun add_constr_case res_T (body_t, guard_t) res_t =
+ if res_T = bool_T then
+ s_conj (HOLogic.mk_imp (guard_t, body_t), res_t)
+ else
+ Const (@{const_name If}, bool_T --> res_T --> res_T --> res_T)
+ $ guard_t $ body_t $ res_t
+fun optimized_case_def (hol_ctxt as {ctxt, stds, ...}) dataT res_T func_ts =
let
val xs = datatype_constrs hol_ctxt dataT
- val func_Ts = map ((fn T => binder_types T ---> res_T) o snd) xs
- val (xs', x) = split_last xs
+ val cases =
+ func_ts ~~ xs
+ |> map (fn (func_t, x) =>
+ (constr_case_body ctxt stds (incr_boundvars 1 func_t, x),
+ discriminate_value hol_ctxt x (Bound 0)))
+ |> AList.group (op aconv)
+ |> map (apsnd (List.foldl s_disj @{const False}))
+ |> sort (int_ord o pairself (size_of_term o fst))
+ |> rev
in
- constr_case_body ctxt stds (1, x)
- |> fold_rev (add_constr_case hol_ctxt res_T) (length xs downto 2 ~~ xs')
- |> fold_rev (curry absdummy) (func_Ts @ [dataT])
+ if res_T = bool_T then
+ if forall (member (op =) [@{const False}, @{const True}] o fst) cases then
+ case cases of
+ [(body_t, _)] => body_t
+ | [_, (@{const True}, head_t2)] => head_t2
+ | [_, (@{const False}, head_t2)] => @{const Not} $ head_t2
+ | _ => raise BAD ("Nitpick_HOL.optimized_case_def", "impossible cases")
+ else
+ @{const True} |> fold_rev (add_constr_case res_T) cases
+ else
+ fst (hd cases) |> fold_rev (add_constr_case res_T) (tl cases)
end
+ |> curry absdummy dataT
+
fun optimized_record_get (hol_ctxt as {thy, ctxt, stds, ...}) s rec_T res_T t =
let val constr_x = hd (datatype_constrs hol_ctxt rec_T) in
case no_of_record_field thy s rec_T of
@@ -1504,7 +1567,7 @@
map2 (fn j => fn T =>
let val t = select_nth_constr_arg ctxt stds constr_x rec_t j T in
if j = special_j then
- betapply (fun_t, t)
+ s_betapply [] (fun_t, t)
else if j = n - 1 andalso special_j = ~1 then
optimized_record_update hol_ctxt s
(rec_T |> dest_Type |> snd |> List.last) fun_t t
@@ -1542,12 +1605,13 @@
handle TERM _ => raise SAME ()
else
raise SAME ())
- handle SAME () => betapply (do_term depth Ts t0, do_term depth Ts t1))
+ handle SAME () =>
+ s_betapply [] (do_term depth Ts t0, do_term depth Ts t1))
| Const (@{const_name refl_on}, T) $ Const (@{const_name top}, _) $ t2 =>
do_const depth Ts t (@{const_name refl'}, range_type T) [t2]
| (t0 as Const (@{const_name Sigma}, _)) $ t1 $ (t2 as Abs (_, _, t2')) =>
- betapplys (t0 |> loose_bvar1 (t2', 0) ? do_term depth Ts,
- map (do_term depth Ts) [t1, t2])
+ s_betapplys Ts (t0 |> loose_bvar1 (t2', 0) ? do_term depth Ts,
+ map (do_term depth Ts) [t1, t2])
| Const (x as (@{const_name distinct},
Type (@{type_name fun}, [Type (@{type_name list}, [T']), _])))
$ (t1 as _ $ _) =>
@@ -1560,11 +1624,11 @@
do_term depth Ts t2
else
do_const depth Ts t x [t1, t2, t3]
- | Const x $ t1 $ t2 $ t3 => do_const depth Ts t x [t1, t2, t3]
- | Const x $ t1 $ t2 => do_const depth Ts t x [t1, t2]
- | Const x $ t1 => do_const depth Ts t x [t1]
| Const x => do_const depth Ts t x []
- | t1 $ t2 => betapply (do_term depth Ts t1, do_term depth Ts t2)
+ | t1 $ t2 =>
+ (case strip_comb t of
+ (Const x, ts) => do_const depth Ts t x ts
+ | _ => s_betapply [] (do_term depth Ts t1, do_term depth Ts t2))
| Free _ => t
| Var _ => t
| Bound _ => t
@@ -1585,13 +1649,17 @@
(Const x, ts)
else case AList.lookup (op =) case_names s of
SOME n =>
- let
- val (dataT, res_T) = nth_range_type n T
- |> pairf domain_type range_type
- in
- (optimized_case_def hol_ctxt dataT res_T
- |> do_term (depth + 1) Ts, ts)
- end
+ if length ts < n then
+ (do_term depth Ts (eta_expand Ts t (n - length ts)), [])
+ else
+ let
+ val (dataT, res_T) = nth_range_type n T
+ |> pairf domain_type range_type
+ in
+ (optimized_case_def hol_ctxt dataT res_T
+ (map (do_term depth Ts) (take n ts)),
+ drop n ts)
+ end
| _ =>
if is_constr ctxt stds x then
(Const x, ts)
@@ -1645,11 +1713,14 @@
string_of_int depth ^ ") while expanding " ^
quote s)
else if s = @{const_name wfrec'} then
- (do_term (depth + 1) Ts (betapplys (def, ts)), [])
+ (do_term (depth + 1) Ts (s_betapplys Ts (def, ts)), [])
else
(do_term (depth + 1) Ts def, ts)
| NONE => (Const x, ts)
- in s_betapplys (const, map (do_term depth Ts) ts) |> Envir.beta_norm end
+ in
+ s_betapplys Ts (const, map (do_term depth Ts) ts)
+ |> s_beta_norm Ts
+ end
in do_term 0 [] end
(** Axiom extraction/generation **)
@@ -1796,8 +1867,9 @@
in
[HOLogic.eq_const bool_T $ (bisim_const $ n_var $ x_var $ y_var)
$ (@{term "op |"} $ (HOLogic.eq_const iter_T $ n_var $ zero_const iter_T)
- $ (betapplys (optimized_case_def hol_ctxt T bool_T,
- map case_func xs @ [x_var]))),
+ $ (s_betapply []
+ (optimized_case_def hol_ctxt T bool_T (map case_func xs),
+ x_var))),
HOLogic.eq_const set_T $ (bisim_const $ bisim_max $ x_var)
$ (Const (@{const_name insert}, T --> set_T --> set_T)
$ x_var $ Const (@{const_name bot_class.bot}, set_T))]
@@ -2036,11 +2108,12 @@
val outer_bounds = map Bound (length outer - 1 downto 0)
val cur = Var ((iter_var_prefix, j + 1), iter_T)
val next = suc_const iter_T $ cur
- val rhs = case fp_app of
- Const _ $ t =>
- betapply (t, list_comb (Const x', next :: outer_bounds))
- | _ => raise TERM ("Nitpick_HOL.unrolled_inductive_pred_\
- \const", [fp_app])
+ val rhs =
+ case fp_app of
+ Const _ $ t =>
+ s_betapply [] (t, list_comb (Const x', next :: outer_bounds))
+ | _ => raise TERM ("Nitpick_HOL.unrolled_inductive_pred_const",
+ [fp_app])
val (inner, naked_rhs) = strip_abs rhs
val all = outer @ inner
val bounds = map Bound (length all - 1 downto 0)
@@ -2056,10 +2129,10 @@
val def = the (def_of_const thy def_table x)
val (outer, fp_app) = strip_abs def
val outer_bounds = map Bound (length outer - 1 downto 0)
- val rhs = case fp_app of
- Const _ $ t => betapply (t, list_comb (Const x, outer_bounds))
- | _ => raise TERM ("Nitpick_HOL.raw_inductive_pred_axiom",
- [fp_app])
+ val rhs =
+ case fp_app of
+ Const _ $ t => s_betapply [] (t, list_comb (Const x, outer_bounds))
+ | _ => raise TERM ("Nitpick_HOL.raw_inductive_pred_axiom", [fp_app])
val (inner, naked_rhs) = strip_abs rhs
val all = outer @ inner
val bounds = map Bound (length all - 1 downto 0)