--- a/src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML Thu Nov 19 08:19:57 2009 +0100
+++ b/src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML Thu Nov 19 08:25:47 2009 +0100
@@ -82,10 +82,6 @@
^ "\n" ^ Pretty.string_of (Pretty.chunks
(Goal_Display.pretty_goals_without_context (! Goal_Display.goals_limit) st)));
-(* reference to preprocessing of InductiveSet package *)
-
-val ind_set_codegen_preproc = (fn thy => I) (*Inductive_Set.codegen_preproc;*)
-
(** fundamentals **)
(* syntactic operations *)
@@ -417,22 +413,45 @@
end
(* validity checks *)
+(* EXPECTED MODE and PROPOSED_MODE are largely the same; define a clear semantics for those! *)
-fun check_expected_modes preds (options : Predicate_Compile_Aux.options) modes =
- case expected_modes options of
- SOME (s, ms) => (case AList.lookup (op =) modes s of
- SOME modes =>
- let
- val modes' = map (translate_mode (the (AList.lookup (op =) preds s))) modes
- in
- if not (eq_set eq_mode' (ms, modes')) then
- error ("expected modes were not inferred:\n"
- ^ " inferred modes for " ^ s ^ ": " ^ commas (map string_of_mode' modes') ^ "\n"
- ^ " expected modes for " ^ s ^ ": " ^ commas (map string_of_mode' ms))
- else ()
- end
- | NONE => ())
- | NONE => ()
+fun check_expected_modes preds options modes =
+ case expected_modes options of
+ SOME (s, ms) => (case AList.lookup (op =) modes s of
+ SOME modes =>
+ let
+ val modes' = map (translate_mode (the (AList.lookup (op =) preds s))) modes
+ in
+ if not (eq_set eq_mode' (ms, modes')) then
+ error ("expected modes were not inferred:\n"
+ ^ " inferred modes for " ^ s ^ ": " ^ commas (map string_of_mode' modes') ^ "\n"
+ ^ " expected modes for " ^ s ^ ": " ^ commas (map string_of_mode' ms))
+ else ()
+ end
+ | NONE => ())
+ | NONE => ()
+
+fun check_proposed_modes preds options modes extra_modes errors =
+ case proposed_modes options of
+ SOME (s, ms) => (case AList.lookup (op =) modes s of
+ SOME inferred_ms =>
+ let
+ val preds_without_modes = map fst (filter (null o snd) (modes @ extra_modes))
+ val modes' = map (translate_mode (the (AList.lookup (op =) preds s))) inferred_ms
+ in
+ if not (eq_set eq_mode' (ms, modes')) then
+ error ("expected modes were not inferred:\n"
+ ^ " inferred modes for " ^ s ^ ": " ^ commas (map string_of_mode' modes') ^ "\n"
+ ^ " expected modes for " ^ s ^ ": " ^ commas (map string_of_mode' ms) ^ "\n"
+ ^ "For the following clauses, the following modes could not be inferred: " ^ "\n"
+ ^ cat_lines errors ^
+ (if not (null preds_without_modes) then
+ "\n" ^ "No mode inferred for the predicates " ^ commas preds_without_modes
+ else ""))
+ else ()
+ end
+ | NONE => ())
+ | NONE => ()
(* importing introduction rules *)
@@ -584,13 +603,13 @@
let
val (const, _) = strip_comb (HOLogic.dest_Trueprop (concl_of intro))
in (fst (dest_Const const) = name) end;
- val intros = ind_set_codegen_preproc thy
+ val intros =
(map (expand_tuples thy #> preprocess_intro thy) (filter is_intro_of (#intrs result)))
val index = find_index (fn s => s = name) (#names (fst info))
val pre_elim = nth (#elims result) index
val pred = nth (#preds result) index
val nparams = length (Inductive.params_of (#raw_induct result))
- (*val elim = singleton (ind_set_codegen_preproc thy) (preprocess_elim thy nparams
+ (*val elim = singleton (Inductive_Set.codegen_preproc thy) (preprocess_elim thy nparams
(expand_tuples_elim pre_elim))*)
val elim =
(Drule.standard o Skip_Proof.make_thm thy)
@@ -659,8 +678,8 @@
fun register_predicate (constname, pre_intros, pre_elim, nparams) thy =
let
(* preprocessing *)
- val intros = ind_set_codegen_preproc thy (map (preprocess_intro thy) pre_intros)
- val elim = singleton (ind_set_codegen_preproc thy) (preprocess_elim thy nparams pre_elim)
+ val intros = map (preprocess_intro thy) pre_intros
+ val elim = preprocess_elim thy nparams pre_elim
in
if not (member (op =) (Graph.keys (PredData.get thy)) constname) then
PredData.map
@@ -1042,21 +1061,34 @@
else NONE
end;
-fun print_failed_mode options thy modes p m rs i =
+fun print_failed_mode options thy modes p m rs is =
if show_mode_inference options then
let
- val _ = tracing ("Clause " ^ string_of_int (i + 1) ^ " of " ^
- p ^ " violates mode " ^ string_of_mode thy p m)
+ val _ = tracing ("Clauses " ^ commas (map (fn i => string_of_int (i + 1)) is) ^ " of " ^
+ p ^ " violates mode " ^ string_of_mode thy p m)
in () end
else ()
+fun error_of thy p m is =
+ (" Clauses " ^ commas (map (fn i => string_of_int (i + 1)) is) ^ " of " ^
+ p ^ " violates mode " ^ string_of_mode thy p m)
+
+fun find_indices f xs =
+ map_filter (fn (i, true) => SOME i | (i, false) => NONE) (map_index (apsnd f) xs)
+
fun check_modes_pred options with_generator thy param_vs clauses modes gen_modes (p, ms) =
let
val rs = case AList.lookup (op =) clauses p of SOME rs => rs | NONE => []
- in (p, filter (fn m => case find_index
- (is_none o check_mode_clause with_generator thy param_vs modes gen_modes m) rs of
- ~1 => true
- | i => (print_failed_mode options thy modes p m rs i; false)) ms)
+ fun invalid_mode m =
+ case find_indices
+ (is_none o check_mode_clause with_generator thy param_vs modes gen_modes m) rs of
+ [] => NONE
+ | is => SOME (error_of thy p m is)
+ val res = map (fn m => (m, invalid_mode m)) ms
+ val ms' = map_filter (fn (m, NONE) => SOME m | _ => NONE) res
+ val errors = map_filter snd res
+ in
+ ((p, ms'), errors)
end;
fun get_modes_pred with_generator thy param_vs clauses modes gen_modes (p, ms) =
@@ -1071,14 +1103,24 @@
let val y = f x
in if x = y then x else fixp f y end;
+fun fixp_with_state f ((x : (string * mode list) list), state) =
+ let
+ val (y, state') = f (x, state)
+ in
+ if x = y then (y, state') else fixp_with_state f (y, state')
+ end
+
fun infer_modes options thy extra_modes all_modes param_vs clauses =
let
- val modes =
- fixp (fn modes =>
- map (check_modes_pred options false thy param_vs clauses (modes @ extra_modes) []) modes)
- all_modes
+ val (modes, errors) =
+ fixp_with_state (fn (modes, errors) =>
+ let
+ val res = map
+ (check_modes_pred options false thy param_vs clauses (modes @ extra_modes) []) modes
+ in (map fst res, errors @ maps snd res) end)
+ (all_modes, [])
in
- map (get_modes_pred false thy param_vs clauses (modes @ extra_modes) []) modes
+ (map (get_modes_pred false thy param_vs clauses (modes @ extra_modes) []) modes, errors)
end;
fun remove_from rem [] = []
@@ -1087,7 +1129,7 @@
NONE => (k, vs)
| SOME vs' => (k, subtract (op =) vs' vs))
:: remove_from rem xs
-
+
fun infer_modes_with_generator options thy extra_modes all_modes param_vs clauses =
let
val prednames = map fst clauses
@@ -1096,16 +1138,21 @@
|> filter_out (fn (name, _) => member (op =) prednames name)
val starting_modes = remove_from extra_modes' all_modes
fun eq_mode (m1, m2) = (m1 = m2)
- val modes =
- fixp (fn modes =>
- map (check_modes_pred options true thy param_vs clauses extra_modes'
- (gen_modes @ modes)) modes) starting_modes
+ val (modes, errors) =
+ fixp_with_state (fn (modes, errors) =>
+ let
+ val res = map
+ (check_modes_pred options true thy param_vs clauses extra_modes'
+ (gen_modes @ modes)) modes
+ in (map fst res, errors @ maps snd res) end) (starting_modes, [])
+ val moded_clauses =
+ map (get_modes_pred true thy param_vs clauses extra_modes (gen_modes @ modes)) modes
+ val (moded_clauses', _) = infer_modes options thy extra_modes all_modes param_vs clauses
+ val join_moded_clauses_table = AList.join (op =)
+ (fn _ => fn ((mps1, mps2)) =>
+ merge (fn ((m1, _), (m2, _)) => eq_mode (m1, m2)) (mps1, mps2))
in
- AList.join (op =)
- (fn _ => fn ((mps1, mps2)) =>
- merge (fn ((m1, _), (m2, _)) => eq_mode (m1, m2)) (mps1, mps2))
- (infer_modes options thy extra_modes all_modes param_vs clauses,
- map (get_modes_pred true thy param_vs clauses extra_modes (gen_modes @ modes)) modes)
+ (join_moded_clauses_table (moded_clauses', moded_clauses), errors)
end;
(* term construction *)
@@ -1524,66 +1571,67 @@
let
val compfuns = PredicateCompFuns.compfuns
val T = AList.lookup (op =) preds name |> the
- fun create_definition (mode as (iss, is)) thy = let
- val mode_cname = create_constname_of_mode options thy "" name T mode
- val mode_cbasename = Long_Name.base_name mode_cname
- val Ts = binder_types T
- val (Ts1, Ts2) = chop (length iss) Ts
- val (Us1, Us2) = split_smodeT is Ts2
- val Ts1' = map2 (fn NONE => I | SOME is => funT_of compfuns ([], is)) iss Ts1
- val funT = (Ts1' @ Us1) ---> (mk_predT compfuns (HOLogic.mk_tupleT Us2))
- val names = Name.variant_list []
- (map (fn i => "x" ^ string_of_int i) (1 upto (length Ts)));
- val param_names = Name.variant_list []
- (map (fn i => "x" ^ string_of_int i) (1 upto (length Ts1')))
- val xparams = map2 (curry Free) param_names Ts1'
- fun mk_vars (i, T) names =
- let
- val vname = Name.variant names ("x" ^ string_of_int (length Ts1' + i))
- in
- case AList.lookup (op =) is i of
- NONE => ((([], [Free (vname, T)]), Free (vname, T)), vname :: names)
- | SOME NONE => ((([Free (vname, T)], []), Free (vname, T)), vname :: names)
- | SOME (SOME pis) =>
- let
- val (Tins, Touts) = split_tupleT pis T
- val name_in = Name.variant names ("x" ^ string_of_int (length Ts1' + i) ^ "in")
- val name_out = Name.variant names ("x" ^ string_of_int (length Ts1' + i) ^ "out")
- val xin = Free (name_in, HOLogic.mk_tupleT Tins)
- val xout = Free (name_out, HOLogic.mk_tupleT Touts)
- val xarg = mk_arg xin xout pis T
- in
- (((if null Tins then [] else [xin],
- if null Touts then [] else [xout]), xarg), name_in :: name_out :: names) end
- end
- val (xinoutargs, names) = fold_map mk_vars ((1 upto (length Ts2)) ~~ Ts2) param_names
- val (xinout, xargs) = split_list xinoutargs
- val (xins, xouts) = pairself flat (split_list xinout)
- val (xparams', names') = fold_map (mk_Eval_of []) ((xparams ~~ Ts1) ~~ iss) names
- fun mk_split_lambda [] t = lambda (Free (Name.variant names' "x", HOLogic.unitT)) t
- | mk_split_lambda [x] t = lambda x t
- | mk_split_lambda xs t =
- let
- fun mk_split_lambda' (x::y::[]) t = HOLogic.mk_split (lambda x (lambda y t))
- | mk_split_lambda' (x::xs) t = HOLogic.mk_split (lambda x (mk_split_lambda' xs t))
- in
- mk_split_lambda' xs t
+ fun create_definition (mode as (iss, is)) thy =
+ let
+ val mode_cname = create_constname_of_mode options thy "" name T mode
+ val mode_cbasename = Long_Name.base_name mode_cname
+ val Ts = binder_types T
+ val (Ts1, Ts2) = chop (length iss) Ts
+ val (Us1, Us2) = split_smodeT is Ts2
+ val Ts1' = map2 (fn NONE => I | SOME is => funT_of compfuns ([], is)) iss Ts1
+ val funT = (Ts1' @ Us1) ---> (mk_predT compfuns (HOLogic.mk_tupleT Us2))
+ val names = Name.variant_list []
+ (map (fn i => "x" ^ string_of_int i) (1 upto (length Ts)));
+ val param_names = Name.variant_list []
+ (map (fn i => "x" ^ string_of_int i) (1 upto (length Ts1')))
+ val xparams = map2 (curry Free) param_names Ts1'
+ fun mk_vars (i, T) names =
+ let
+ val vname = Name.variant names ("x" ^ string_of_int (length Ts1' + i))
+ in
+ case AList.lookup (op =) is i of
+ NONE => ((([], [Free (vname, T)]), Free (vname, T)), vname :: names)
+ | SOME NONE => ((([Free (vname, T)], []), Free (vname, T)), vname :: names)
+ | SOME (SOME pis) =>
+ let
+ val (Tins, Touts) = split_tupleT pis T
+ val name_in = Name.variant names ("x" ^ string_of_int (length Ts1' + i) ^ "in")
+ val name_out = Name.variant names ("x" ^ string_of_int (length Ts1' + i) ^ "out")
+ val xin = Free (name_in, HOLogic.mk_tupleT Tins)
+ val xout = Free (name_out, HOLogic.mk_tupleT Touts)
+ val xarg = mk_arg xin xout pis T
+ in
+ (((if null Tins then [] else [xin],
+ if null Touts then [] else [xout]), xarg), name_in :: name_out :: names) end
+ end
+ val (xinoutargs, names) = fold_map mk_vars ((1 upto (length Ts2)) ~~ Ts2) param_names
+ val (xinout, xargs) = split_list xinoutargs
+ val (xins, xouts) = pairself flat (split_list xinout)
+ val (xparams', names') = fold_map (mk_Eval_of []) ((xparams ~~ Ts1) ~~ iss) names
+ fun mk_split_lambda [] t = lambda (Free (Name.variant names' "x", HOLogic.unitT)) t
+ | mk_split_lambda [x] t = lambda x t
+ | mk_split_lambda xs t =
+ let
+ fun mk_split_lambda' (x::y::[]) t = HOLogic.mk_split (lambda x (lambda y t))
+ | mk_split_lambda' (x::xs) t = HOLogic.mk_split (lambda x (mk_split_lambda' xs t))
+ in
+ mk_split_lambda' xs t
+ end;
+ val predterm = PredicateCompFuns.mk_Enum (mk_split_lambda xouts
+ (list_comb (Const (name, T), xparams' @ xargs)))
+ val lhs = list_comb (Const (mode_cname, funT), xparams @ xins)
+ val def = Logic.mk_equals (lhs, predterm)
+ val ([definition], thy') = thy |>
+ Sign.add_consts_i [(Binding.name mode_cbasename, funT, NoSyn)] |>
+ PureThy.add_defs false [((Binding.name (mode_cbasename ^ "_def"), def), [])]
+ val (intro, elim) =
+ create_intro_elim_rule mode definition mode_cname funT (Const (name, T)) thy'
+ in thy'
+ |> add_predfun name mode (mode_cname, definition, intro, elim)
+ |> PureThy.store_thm (Binding.name (mode_cbasename ^ "I"), intro) |> snd
+ |> PureThy.store_thm (Binding.name (mode_cbasename ^ "E"), elim) |> snd
+ |> Theory.checkpoint
end;
- val predterm = PredicateCompFuns.mk_Enum (mk_split_lambda xouts
- (list_comb (Const (name, T), xparams' @ xargs)))
- val lhs = list_comb (Const (mode_cname, funT), xparams @ xins)
- val def = Logic.mk_equals (lhs, predterm)
- val ([definition], thy') = thy |>
- Sign.add_consts_i [(Binding.name mode_cbasename, funT, NoSyn)] |>
- PureThy.add_defs false [((Binding.name (mode_cbasename ^ "_def"), def), [])]
- val (intro, elim) =
- create_intro_elim_rule mode definition mode_cname funT (Const (name, T)) thy'
- in thy'
- |> add_predfun name mode (mode_cname, definition, intro, elim)
- |> PureThy.store_thm (Binding.name (mode_cbasename ^ "I"), intro) |> snd
- |> PureThy.store_thm (Binding.name (mode_cbasename ^ "E"), elim) |> snd
- |> Theory.checkpoint
- end;
in
fold create_definition modes thy
end;
@@ -2106,9 +2154,10 @@
map SOME (all_smodes_of_typs Rs) | _ => [NONE]) Ts), all_smodes_of_typs Us)
end
val all_modes = map (fn (s, T) =>
- case proposed_modes options s of
+ case proposed_modes options of
NONE => (s, modes_of_typ T)
- | SOME modes' => (s, map (translate_mode' nparams) modes'))
+ | SOME (s', modes') =>
+ if s = s' then (s, map (translate_mode' nparams) modes') else (s, modes_of_typ T))
preds
in (preds, nparams, all_vs, param_vs, extra_modes, clauses, all_modes) end;
@@ -2199,7 +2248,7 @@
define_functions : options -> (string * typ) list -> string * mode list -> theory -> theory,
infer_modes : options -> theory -> (string * mode list) list -> (string * mode list) list
-> string list -> (string * (term list * indprem list) list) list
- -> moded_clause list pred_mode_table,
+ -> moded_clause list pred_mode_table * string list,
prove : options -> theory -> (string * (term list * indprem list) list) list
-> (string * typ) list -> (string * mode list) list
-> moded_clause list pred_mode_table -> term pred_mode_table -> thm pred_mode_table,
@@ -2220,10 +2269,11 @@
val (preds, nparams, all_vs, param_vs, extra_modes, clauses, all_modes) =
prepare_intrs options thy prednames (maps (intros_of thy) prednames)
val _ = print_step options "Infering modes..."
- val moded_clauses =
+ val (moded_clauses, errors) =
#infer_modes (dest_steps steps) options thy extra_modes all_modes param_vs clauses
val modes = map (fn (p, mps) => (p, map fst mps)) moded_clauses
val _ = check_expected_modes preds options modes
+ val _ = check_proposed_modes preds options modes extra_modes errors
val _ = print_modes options thy modes
(*val _ = print_moded_clauses thy moded_clauses*)
val _ = print_step options "Defining executable functions..."