--- a/src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML Sat Oct 24 16:55:43 2009 +0200
+++ b/src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML Sat Oct 24 16:55:43 2009 +0200
@@ -438,8 +438,9 @@
SOME (s, ms) => (case AList.lookup (op =) modes s of
SOME modes =>
if not (eq_set (map (map (rpair NONE)) ms, map snd modes)) then
- error ("expected modes were not inferred:"
- ^ "infered modes for " ^ s ^ ": " ^ commas (map (string_of_smode o snd) modes))
+ error ("expected modes were not inferred:\n"
+ ^ "inferred modes for " ^ s ^ ": "
+ ^ commas (map ((enclose "[" "]") o string_of_smode o snd) modes))
else ()
| NONE => ())
| NONE => ()
@@ -1165,9 +1166,79 @@
in (t' $ u', nvs'') end
| distinct_v x nvs = (x, nvs);
-fun compile_match thy compfuns eqs eqs' out_ts success_t =
+(** specific rpred functions -- move them to the correct place in this file *)
+
+fun mk_Eval_of additional_arguments ((x, T), NONE) names = (x, names)
+ | mk_Eval_of additional_arguments ((x, T), SOME mode) names =
+ let
+ val Ts = binder_types T
+ (*val argnames = Name.variant_list names
+ (map (fn i => "x" ^ string_of_int i) (1 upto (length Ts)));
+ val args = map Free (argnames ~~ Ts)
+ val (inargs, outargs) = split_smode mode args*)
+ fun mk_split_lambda [] t = lambda (Free (Name.variant names "x", HOLogic.unitT)) t
+ | mk_split_lambda [x] t = lambda x t
+ | mk_split_lambda xs t =
+ let
+ fun mk_split_lambda' (x::y::[]) t = HOLogic.mk_split (lambda x (lambda y t))
+ | mk_split_lambda' (x::xs) t = HOLogic.mk_split (lambda x (mk_split_lambda' xs t))
+ in
+ mk_split_lambda' xs t
+ end;
+ fun mk_arg (i, T) =
+ let
+ val vname = Name.variant names ("x" ^ string_of_int i)
+ val default = Free (vname, T)
+ in
+ case AList.lookup (op =) mode i of
+ NONE => (([], [default]), [default])
+ | SOME NONE => (([default], []), [default])
+ | SOME (SOME pis) =>
+ case HOLogic.strip_tupleT T of
+ [] => error "pair mode but unit tuple" (*(([default], []), [default])*)
+ | [_] => error "pair mode but not a tuple" (*(([default], []), [default])*)
+ | Ts =>
+ let
+ val vnames = Name.variant_list names
+ (map (fn j => "x" ^ string_of_int i ^ "p" ^ string_of_int j)
+ (1 upto length Ts))
+ val args = map Free (vnames ~~ Ts)
+ fun split_args (i, arg) (ins, outs) =
+ if member (op =) pis i then
+ (arg::ins, outs)
+ else
+ (ins, arg::outs)
+ val (inargs, outargs) = fold_rev split_args ((1 upto length Ts) ~~ args) ([], [])
+ fun tuple args = if null args then [] else [HOLogic.mk_tuple args]
+ in ((tuple inargs, tuple outargs), args) end
+ end
+ val (inoutargs, args) = split_list (map mk_arg (1 upto (length Ts) ~~ Ts))
+ val (inargs, outargs) = pairself flat (split_list inoutargs)
+ val r = PredicateCompFuns.mk_Eval (list_comb (x, inargs @ additional_arguments), mk_tuple outargs)
+ val t = fold_rev mk_split_lambda args r
+ in
+ (t, names)
+ end;
+
+fun compile_arg compilation_modifiers compfuns additional_arguments thy param_vs iss arg =
+ let
+ fun map_params (t as Free (f, T)) =
+ if member (op =) param_vs f then
+ case (the (AList.lookup (op =) (param_vs ~~ iss) f)) of
+ SOME is =>
+ let
+ val T' = #funT_of compilation_modifiers compfuns ([], is) T
+ in fst (mk_Eval_of additional_arguments ((Free (f, T'), T), SOME is) []) end
+ | NONE => t
+ else t
+ | map_params t = t
+ in map_aterms map_params arg end
+
+fun compile_match compilation_modifiers compfuns additional_arguments param_vs iss thy eqs eqs' out_ts success_t =
let
val eqs'' = maps mk_eq eqs @ eqs'
+ val eqs'' =
+ map (compile_arg compilation_modifiers compfuns additional_arguments thy param_vs iss) eqs''
val names = fold Term.add_free_names (success_t :: eqs'' @ out_ts) [];
val name = Name.variant names "x";
val name' = Name.variant (name :: names) "y";
@@ -1228,76 +1299,9 @@
| (Free (name, T), params) =>
list_comb (Free (name, #funT_of compilation_modifiers compfuns mode T), params @ inargs @ additional_arguments)
-(** specific rpred functions -- move them to the correct place in this file *)
-
-fun mk_Eval_of depth ((x, T), NONE) names = (x, names)
- | mk_Eval_of depth ((x, T), SOME mode) names =
- let
- val Ts = binder_types T
- (*val argnames = Name.variant_list names
- (map (fn i => "x" ^ string_of_int i) (1 upto (length Ts)));
- val args = map Free (argnames ~~ Ts)
- val (inargs, outargs) = split_smode mode args*)
- fun mk_split_lambda [] t = lambda (Free (Name.variant names "x", HOLogic.unitT)) t
- | mk_split_lambda [x] t = lambda x t
- | mk_split_lambda xs t =
- let
- fun mk_split_lambda' (x::y::[]) t = HOLogic.mk_split (lambda x (lambda y t))
- | mk_split_lambda' (x::xs) t = HOLogic.mk_split (lambda x (mk_split_lambda' xs t))
- in
- mk_split_lambda' xs t
- end;
- fun mk_arg (i, T) =
- let
- val vname = Name.variant names ("x" ^ string_of_int i)
- val default = Free (vname, T)
- in
- case AList.lookup (op =) mode i of
- NONE => (([], [default]), [default])
- | SOME NONE => (([default], []), [default])
- | SOME (SOME pis) =>
- case HOLogic.strip_tupleT T of
- [] => error "pair mode but unit tuple" (*(([default], []), [default])*)
- | [_] => error "pair mode but not a tuple" (*(([default], []), [default])*)
- | Ts =>
- let
- val vnames = Name.variant_list names
- (map (fn j => "x" ^ string_of_int i ^ "p" ^ string_of_int j)
- (1 upto length Ts))
- val args = map Free (vnames ~~ Ts)
- fun split_args (i, arg) (ins, outs) =
- if member (op =) pis i then
- (arg::ins, outs)
- else
- (ins, arg::outs)
- val (inargs, outargs) = fold_rev split_args ((1 upto length Ts) ~~ args) ([], [])
- fun tuple args = if null args then [] else [HOLogic.mk_tuple args]
- in ((tuple inargs, tuple outargs), args) end
- end
- val (inoutargs, args) = split_list (map mk_arg (1 upto (length Ts) ~~ Ts))
- val (inargs, outargs) = pairself flat (split_list inoutargs)
- val depth_t = case depth of NONE => [] | SOME (polarity, depth_t) => [polarity, depth_t]
- val r = PredicateCompFuns.mk_Eval (list_comb (x, inargs @ depth_t), mk_tuple outargs)
- val t = fold_rev mk_split_lambda args r
- in
- (t, names)
- end;
-
-fun compile_arg depth thy param_vs iss arg =
- let
- val funT_of = case depth of NONE => funT_of | SOME _ => depth_limited_funT_of
- fun map_params (t as Free (f, T)) =
- if member (op =) param_vs f then
- case (the (AList.lookup (op =) (param_vs ~~ iss) f)) of
- SOME is => let val T' = funT_of PredicateCompFuns.compfuns ([], is) T
- in fst (mk_Eval_of depth ((Free (f, T'), T), SOME is) []) end
- | NONE => t
- else t
- | map_params t = t
- in map_aterms map_params arg end
-
fun compile_clause compilation_modifiers compfuns thy all_vs param_vs additional_arguments (iss, is) inp (ts, moded_ps) =
let
+ val compile_match = compile_match compilation_modifiers compfuns additional_arguments param_vs iss thy
fun check_constrt t (names, eqs) =
if is_constrt thy t then (t, (names, eqs)) else
let
@@ -1316,7 +1320,7 @@
val (out_ts''', (names'', constr_vs)) = fold_map distinct_v
out_ts'' (names', map (rpair []) vs);
in
- compile_match thy compfuns constr_vs (eqs @ eqs') out_ts'''
+ compile_match constr_vs (eqs @ eqs') out_ts'''
(mk_single compfuns (mk_tuple out_ts))
end
| compile_prems out_ts vs names ((p, mode as Mode ((_, is), _, _)) :: ps) =
@@ -1332,13 +1336,8 @@
Prem (us, t) =>
let
val (in_ts, out_ts''') = split_smode is us;
- (* TODO: add test case for compile_arg *)
- (*val in_ts = map (compile_arg depth thy param_vs iss) in_ts*)
- (* additional_arguments
- val args = case depth of
- NONE => in_ts
- | SOME (polarity, depth_t) => in_ts @ [polarity, depth_t]
- *)
+ val in_ts = map (compile_arg compilation_modifiers compfuns additional_arguments
+ thy param_vs iss) in_ts
val u =
compile_expr compilation_modifiers compfuns thy (mode, t) in_ts additional_arguments'
val rest = compile_prems out_ts''' vs' names'' ps
@@ -1348,11 +1347,8 @@
| Negprem (us, t) =>
let
val (in_ts, out_ts''') = split_smode is us
- (* additional_arguments
- val args = case depth of
- NONE => in_ts
- | SOME (polarity, depth_t) => in_ts @ [HOLogic.mk_not polarity, depth_t]
- *)
+ val in_ts = map (compile_arg compilation_modifiers compfuns additional_arguments
+ thy param_vs iss) in_ts
val u = mk_not compfuns
(compile_expr compilation_modifiers compfuns thy (mode, t) in_ts additional_arguments')
val rest = compile_prems out_ts''' vs' names'' ps
@@ -1361,6 +1357,8 @@
end
| Sidecond t =>
let
+ val t = compile_arg compilation_modifiers compfuns additional_arguments
+ thy param_vs iss t
val rest = compile_prems [] vs' names'' ps;
in
(mk_if compfuns t, rest)
@@ -1374,7 +1372,7 @@
(u, rest)
end
in
- compile_match thy compfuns constr_vs' eqs out_ts''
+ compile_match constr_vs' eqs out_ts''
(mk_bind compfuns (compiled_clause, rest))
end
val prem_t = compile_prems in_ts' param_vs all_vs' moded_ps;
@@ -1463,7 +1461,7 @@
val param_names' = Name.variant_list (param_names @ argnames)
(map (fn i => "p" ^ string_of_int i) (1 upto (length iss)))
val param_vs = map Free (param_names' ~~ Ts1)
- val (params', names) = fold_map (mk_Eval_of NONE) ((params ~~ Ts1) ~~ iss) []
+ val (params', names) = fold_map (mk_Eval_of []) ((params ~~ Ts1) ~~ iss) []
val predpropI = HOLogic.mk_Trueprop (list_comb (pred, param_vs @ args))
val predpropE = HOLogic.mk_Trueprop (list_comb (pred, params' @ args))
val param_eqs = map (HOLogic.mk_Trueprop o HOLogic.mk_eq) (param_vs ~~ params')
@@ -1566,7 +1564,7 @@
val (xinoutargs, names) = fold_map mk_vars ((1 upto (length Ts2)) ~~ Ts2) param_names
val (xinout, xargs) = split_list xinoutargs
val (xins, xouts) = pairself flat (split_list xinout)
- val (xparams', names') = fold_map (mk_Eval_of NONE) ((xparams ~~ Ts1) ~~ iss) names
+ val (xparams', names') = fold_map (mk_Eval_of []) ((xparams ~~ Ts1) ~~ iss) names
fun mk_split_lambda [] t = lambda (Free (Name.variant names' "x", HOLogic.unitT)) t
| mk_split_lambda [x] t = lambda x t
| mk_split_lambda xs t =