--- a/src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML Thu Sep 23 10:39:25 2010 +0200
+++ b/src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML Thu Sep 23 14:50:13 2010 +0200
@@ -1402,7 +1402,7 @@
val modes = map (fn (s, ms) => (s, map (fn ((p, m), r) => m) ms)) modes'
in (modes, modes) end
val (in_ts, out_ts) = split_mode mode ts
- val in_vs = maps (vars_of_destructable_term ctxt) in_ts
+ val in_vs = union (op =) param_vs (maps (vars_of_destructable_term ctxt) in_ts)
val out_vs = terms_vs out_ts
fun known_vs_after p vs = (case p of
Prem t => union (op =) vs (term_vs t)
@@ -1590,80 +1590,54 @@
in (t' $ u', nvs'') end
| distinct_v x nvs = (x, nvs);
-(** specific rpred functions -- move them to the correct place in this file *)
-fun mk_Eval_of additional_arguments ((x, T), NONE) names = (x, names)
- | mk_Eval_of additional_arguments ((x, T), SOME mode) names =
- let
- val Ts = binder_types T
- fun mk_split_lambda [] t = lambda (Free (Name.variant names "x", HOLogic.unitT)) t
- | mk_split_lambda [x] t = lambda x t
- | mk_split_lambda xs t =
- let
- fun mk_split_lambda' (x::y::[]) t = HOLogic.mk_split (lambda x (lambda y t))
- | mk_split_lambda' (x::xs) t = HOLogic.mk_split (lambda x (mk_split_lambda' xs t))
- in
- mk_split_lambda' xs t
- end;
- fun mk_arg (i, T) =
- let
- val vname = Name.variant names ("x" ^ string_of_int i)
- val default = Free (vname, T)
- in
- case AList.lookup (op =) mode i of
- NONE => (([], [default]), [default])
- | SOME NONE => (([default], []), [default])
- | SOME (SOME pis) =>
- case HOLogic.strip_tupleT T of
- [] => error "pair mode but unit tuple" (*(([default], []), [default])*)
- | [_] => error "pair mode but not a tuple" (*(([default], []), [default])*)
- | Ts =>
- let
- val vnames = Name.variant_list names
- (map (fn j => "x" ^ string_of_int i ^ "p" ^ string_of_int j)
- (1 upto length Ts))
- val args = map2 (curry Free) vnames Ts
- fun split_args (i, arg) (ins, outs) =
- if member (op =) pis i then
- (arg::ins, outs)
- else
- (ins, arg::outs)
- val (inargs, outargs) = fold_rev split_args ((1 upto length Ts) ~~ args) ([], [])
- fun tuple args = if null args then [] else [HOLogic.mk_tuple args]
- in ((tuple inargs, tuple outargs), args) end
- end
- val (inoutargs, args) = split_list (map mk_arg (1 upto (length Ts) ~~ Ts))
- val (inargs, outargs) = pairself flat (split_list inoutargs)
- val r = PredicateCompFuns.mk_Eval
- (list_comb (x, inargs @ additional_arguments), HOLogic.mk_tuple outargs)
- val t = fold_rev mk_split_lambda args r
- in
- (t, names)
- end;
+(** specific rpred functions -- move them to the correct place in this file *)
+fun mk_Eval_of (P as (Free (f, _)), T) mode =
+let
+ fun mk_bounds (Type (@{type_name Product_Type.prod}, [T1, T2])) i =
+ let
+ val (bs2, i') = mk_bounds T2 i
+ val (bs1, i'') = mk_bounds T1 i'
+ in
+ (HOLogic.pair_const T1 T2 $ bs1 $ bs2, i'' + 1)
+ end
+ | mk_bounds T i = (Bound i, i + 1)
+ fun mk_prod ((t1, T1), (t2, T2)) = (HOLogic.pair_const T1 T2 $ t1 $ t2, HOLogic.mk_prodT (T1, T2))
+ fun mk_tuple [] = (HOLogic.unit, HOLogic.unitT)
+ | mk_tuple tTs = foldr1 mk_prod tTs;
+ fun mk_split_abs (T as Type (@{type_name Product_Type.prod}, [T1, T2])) t = absdummy (T, HOLogic.split_const (T1, T2, @{typ bool}) $ (mk_split_abs T1 (mk_split_abs T2 t)))
+ | mk_split_abs T t = absdummy (T, t)
+ val args = rev (fst (fold_map mk_bounds (rev (binder_types T)) 0))
+ val (inargs, outargs) = split_mode mode args
+ val (inTs, outTs) = split_map_modeT (fn _ => fn T => (SOME T, NONE)) mode (binder_types T)
+ val inner_term = PredicateCompFuns.mk_Eval (list_comb (P, inargs), fst (mk_tuple (outargs ~~ outTs)))
+in
+ fold_rev mk_split_abs (binder_types T) inner_term
+end
-(* TODO: uses param_vs -- change necessary for compilation with new modes *)
-fun compile_arg compilation_modifiers additional_arguments ctxt param_vs iss arg =
+fun compile_arg compilation_modifiers additional_arguments ctxt param_modes arg =
let
fun map_params (t as Free (f, T)) =
- if member (op =) param_vs f then
- case (AList.lookup (op =) (param_vs ~~ iss) f) of
- SOME is =>
+ (case (AList.lookup (op =) param_modes f) of
+ SOME mode =>
let
- val _ = error "compile_arg: A parameter in a input position -- do we have a test case?"
- val T' = Comp_Mod.funT_of compilation_modifiers is T
- in t(*fst (mk_Eval_of additional_arguments ((Free (f, T'), T), is) [])*) end
- | NONE => t
- else t
+ val T' = Comp_Mod.funT_of compilation_modifiers mode T
+ in
+ mk_Eval_of (Free (f, T'), T) mode
+ end
+ | NONE => t)
| map_params t = t
- in map_aterms map_params arg end
+ in
+ map_aterms map_params arg
+ end
-fun compile_match compilation_modifiers additional_arguments
- param_vs iss ctxt eqs eqs' out_ts success_t =
+fun compile_match compilation_modifiers additional_arguments ctxt param_modes
+ eqs eqs' out_ts success_t =
let
val compfuns = Comp_Mod.compfuns compilation_modifiers
val eqs'' = maps mk_eq eqs @ eqs'
val eqs'' =
- map (compile_arg compilation_modifiers additional_arguments ctxt param_vs iss) eqs''
+ map (compile_arg compilation_modifiers additional_arguments ctxt param_modes) eqs''
val names = fold Term.add_free_names (success_t :: eqs'' @ out_ts) [];
val name = Name.variant names "x";
val name' = Name.variant (name :: names) "y";
@@ -1692,12 +1666,12 @@
| (t, Term Output) => Syntax.string_of_term ctxt t ^ "[Output]"
| (t, Context m) => Syntax.string_of_term ctxt t ^ "[" ^ string_of_mode m ^ "]")
-fun compile_expr compilation_modifiers ctxt (t, deriv) additional_arguments =
+fun compile_expr compilation_modifiers ctxt (t, deriv) param_modes additional_arguments =
let
val compfuns = Comp_Mod.compfuns compilation_modifiers
fun expr_of (t, deriv) =
(case (t, deriv) of
- (t, Term Input) => SOME t
+ (t, Term Input) => SOME (compile_arg compilation_modifiers additional_arguments ctxt param_modes t)
| (t, Term Output) => NONE
| (Const (name, T), Context mode) =>
(case alternative_compilation_of ctxt name mode of
@@ -1728,13 +1702,12 @@
list_comb (the (expr_of (t, deriv)), additional_arguments)
end
-fun compile_clause compilation_modifiers ctxt all_vs param_vs additional_arguments
- mode inp (in_ts, out_ts) moded_ps =
+fun compile_clause compilation_modifiers ctxt all_vs param_modes additional_arguments
+ inp (in_ts, out_ts) moded_ps =
let
val compfuns = Comp_Mod.compfuns compilation_modifiers
- val iss = ho_arg_modes_of mode (* FIXME! *)
val compile_match = compile_match compilation_modifiers
- additional_arguments param_vs iss ctxt
+ additional_arguments ctxt param_modes
val (in_ts', (all_vs', eqs)) =
fold_map (collect_non_invertible_subterms ctxt) in_ts (all_vs, []);
fun compile_prems out_ts' vs names [] =
@@ -1761,7 +1734,7 @@
Prem t =>
let
val u =
- compile_expr compilation_modifiers ctxt (t, deriv) additional_arguments'
+ compile_expr compilation_modifiers ctxt (t, deriv) param_modes additional_arguments'
val (_, out_ts''') = split_mode mode (snd (strip_comb t))
val rest = compile_prems out_ts''' vs' names'' ps
in
@@ -1772,7 +1745,7 @@
val neg_compilation_modifiers =
negative_comp_modifiers_of compilation_modifiers
val u = mk_not compfuns
- (compile_expr neg_compilation_modifiers ctxt (t, deriv) additional_arguments')
+ (compile_expr neg_compilation_modifiers ctxt (t, deriv) param_modes additional_arguments')
val (_, out_ts''') = split_mode mode (snd (strip_comb t))
val rest = compile_prems out_ts''' vs' names'' ps
in
@@ -1781,7 +1754,7 @@
| Sidecond t =>
let
val t = compile_arg compilation_modifiers additional_arguments
- ctxt param_vs iss t
+ ctxt param_modes t
val rest = compile_prems [] vs' names'' ps;
in
(mk_if compfuns t, rest)
@@ -1797,7 +1770,7 @@
compile_match constr_vs' eqs out_ts''
(mk_bind compfuns (compiled_clause, rest))
end
- val prem_t = compile_prems in_ts' param_vs all_vs' moded_ps;
+ val prem_t = compile_prems in_ts' (map fst param_modes) all_vs' moded_ps;
in
mk_bind compfuns (mk_single compfuns inp, prem_t)
end
@@ -1909,7 +1882,7 @@
| _ => raise Fail "unexpected pattern")
-fun compile_switch compilation_modifiers ctxt all_vs param_vs additional_arguments mode
+fun compile_switch compilation_modifiers ctxt all_vs param_modes additional_arguments mode
in_ts' outTs switch_tree =
let
val compfuns = Comp_Mod.compfuns compilation_modifiers
@@ -1929,8 +1902,8 @@
val in_ts' = map (Pattern.rewrite_term thy (map swap fsubst) []) (map snd pat')
val out_ts' = map (Pattern.rewrite_term thy (map swap fsubst) []) out_ts
in
- compile_clause compilation_modifiers ctxt all_vs param_vs additional_arguments
- mode inp (in_ts', out_ts') moded_ps'
+ compile_clause compilation_modifiers ctxt all_vs param_modes additional_arguments
+ inp (in_ts', out_ts') moded_ps'
end
in SOME (foldr1 (mk_sup compfuns) (map compile_clause' moded_clauses)) end
| compile_switch_tree all_vs ctxt_eqs (Node ((position, switched_clauses), left_clauses)) =
@@ -1991,17 +1964,18 @@
(param_vs, (all_vs @ param_vs))
val in_ts' = map_filter (map_filter_prod
(fn t as Free (x, _) => if member (op =) param_vs x then NONE else SOME t | t => SOME t)) in_ts
+ val param_modes = param_vs ~~ ho_arg_modes_of mode
val compilation =
if detect_switches options then
the_default (mk_bot compfuns (HOLogic.mk_tupleT outTs))
- (compile_switch compilation_modifiers ctxt all_vs param_vs additional_arguments
- mode in_ts' outTs (mk_switch_tree ctxt mode moded_cls))
+ (compile_switch compilation_modifiers ctxt all_vs param_modes additional_arguments mode
+ in_ts' outTs (mk_switch_tree ctxt mode moded_cls))
else
let
val cl_ts =
map (fn (ts, moded_prems) =>
- compile_clause compilation_modifiers ctxt all_vs param_vs additional_arguments
- mode (HOLogic.mk_tuple in_ts') (split_mode mode ts) moded_prems) moded_cls;
+ compile_clause compilation_modifiers ctxt all_vs param_modes additional_arguments
+ (HOLogic.mk_tuple in_ts') (split_mode mode ts) moded_prems) moded_cls;
in
Comp_Mod.wrap_compilation compilation_modifiers compfuns s T mode additional_arguments
(if null cl_ts then
@@ -3236,7 +3210,7 @@
| Pos_Random_DSeq => pos_random_dseq_comp_modifiers
| New_Pos_Random_DSeq => new_pos_random_dseq_comp_modifiers
val t_pred = compile_expr comp_modifiers ctxt
- (body, deriv) additional_arguments;
+ (body, deriv) [] additional_arguments;
val T_pred = dest_predT compfuns (fastype_of t_pred)
val arrange = split_lambda (HOLogic.mk_tuple outargs) output_tuple
in