# HG changeset patch # User bulwahn # Date 1285246213 -7200 # Node ID 655307cb8489cd5ed236589f1e71633f16e24290 # Parent 7bf0c7f0f24c5be6dab77a68eb7e82a6c7603c55 rewriting function mk_Eval_of in predicate compiler diff -r 7bf0c7f0f24c -r 655307cb8489 src/HOL/Predicate_Compile_Examples/Predicate_Compile_Examples.thy --- a/src/HOL/Predicate_Compile_Examples/Predicate_Compile_Examples.thy Thu Sep 23 10:39:25 2010 +0200 +++ b/src/HOL/Predicate_Compile_Examples/Predicate_Compile_Examples.thy Thu Sep 23 14:50:13 2010 +0200 @@ -522,7 +522,6 @@ thm filter2.equation thm filter2.random_dseq_equation -(* inductive filter3 for P where @@ -530,9 +529,9 @@ code_pred (expected_modes: (o => bool) => i => o => bool, (o => bool) => i => i => bool , (i => bool) => i => o => bool, (i => bool) => i => i => bool) [skip_proof] filter3 . -code_pred [dseq] filter3 . -thm filter3.dseq_equation -*) +code_pred filter3 . +thm filter3.equation + (* inductive filter4 where diff -r 7bf0c7f0f24c -r 655307cb8489 src/HOL/Tools/Predicate_Compile/predicate_compile_compilations.ML --- a/src/HOL/Tools/Predicate_Compile/predicate_compile_compilations.ML Thu Sep 23 10:39:25 2010 +0200 +++ b/src/HOL/Tools/Predicate_Compile/predicate_compile_compilations.ML Thu Sep 23 14:50:13 2010 +0200 @@ -47,7 +47,7 @@ fun mk_Eval (f, x) = let - val T = fastype_of x + val T = dest_predT (fastype_of f) in Const (@{const_name Predicate.eval}, mk_predT T --> T --> HOLogic.boolT) $ f $ x end; diff -r 7bf0c7f0f24c -r 655307cb8489 src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML --- a/src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML Thu Sep 23 10:39:25 2010 +0200 +++ b/src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML Thu Sep 23 14:50:13 2010 +0200 @@ -1402,7 +1402,7 @@ val modes = map (fn (s, ms) => (s, map (fn ((p, m), r) => m) ms)) modes' in (modes, modes) end val (in_ts, out_ts) = split_mode mode ts - val in_vs = maps (vars_of_destructable_term ctxt) in_ts + val in_vs = union (op =) param_vs (maps (vars_of_destructable_term ctxt) in_ts) val out_vs = terms_vs out_ts fun known_vs_after p vs = (case p of Prem t => union (op =) vs (term_vs t) @@ -1590,80 +1590,54 @@ in (t' $ u', nvs'') end | distinct_v x nvs = (x, nvs); -(** specific rpred functions -- move them to the correct place in this file *) -fun mk_Eval_of additional_arguments ((x, T), NONE) names = (x, names) - | mk_Eval_of additional_arguments ((x, T), SOME mode) names = - let - val Ts = binder_types T - fun mk_split_lambda [] t = lambda (Free (Name.variant names "x", HOLogic.unitT)) t - | mk_split_lambda [x] t = lambda x t - | mk_split_lambda xs t = - let - fun mk_split_lambda' (x::y::[]) t = HOLogic.mk_split (lambda x (lambda y t)) - | mk_split_lambda' (x::xs) t = HOLogic.mk_split (lambda x (mk_split_lambda' xs t)) - in - mk_split_lambda' xs t - end; - fun mk_arg (i, T) = - let - val vname = Name.variant names ("x" ^ string_of_int i) - val default = Free (vname, T) - in - case AList.lookup (op =) mode i of - NONE => (([], [default]), [default]) - | SOME NONE => (([default], []), [default]) - | SOME (SOME pis) => - case HOLogic.strip_tupleT T of - [] => error "pair mode but unit tuple" (*(([default], []), [default])*) - | [_] => error "pair mode but not a tuple" (*(([default], []), [default])*) - | Ts => - let - val vnames = Name.variant_list names - (map (fn j => "x" ^ string_of_int i ^ "p" ^ string_of_int j) - (1 upto length Ts)) - val args = map2 (curry Free) vnames Ts - fun split_args (i, arg) (ins, outs) = - if member (op =) pis i then - (arg::ins, outs) - else - (ins, arg::outs) - val (inargs, outargs) = fold_rev split_args ((1 upto length Ts) ~~ args) ([], []) - fun tuple args = if null args then [] else [HOLogic.mk_tuple args] - in ((tuple inargs, tuple outargs), args) end - end - val (inoutargs, args) = split_list (map mk_arg (1 upto (length Ts) ~~ Ts)) - val (inargs, outargs) = pairself flat (split_list inoutargs) - val r = PredicateCompFuns.mk_Eval - (list_comb (x, inargs @ additional_arguments), HOLogic.mk_tuple outargs) - val t = fold_rev mk_split_lambda args r - in - (t, names) - end; +(** specific rpred functions -- move them to the correct place in this file *) +fun mk_Eval_of (P as (Free (f, _)), T) mode = +let + fun mk_bounds (Type (@{type_name Product_Type.prod}, [T1, T2])) i = + let + val (bs2, i') = mk_bounds T2 i + val (bs1, i'') = mk_bounds T1 i' + in + (HOLogic.pair_const T1 T2 $ bs1 $ bs2, i'' + 1) + end + | mk_bounds T i = (Bound i, i + 1) + fun mk_prod ((t1, T1), (t2, T2)) = (HOLogic.pair_const T1 T2 $ t1 $ t2, HOLogic.mk_prodT (T1, T2)) + fun mk_tuple [] = (HOLogic.unit, HOLogic.unitT) + | mk_tuple tTs = foldr1 mk_prod tTs; + fun mk_split_abs (T as Type (@{type_name Product_Type.prod}, [T1, T2])) t = absdummy (T, HOLogic.split_const (T1, T2, @{typ bool}) $ (mk_split_abs T1 (mk_split_abs T2 t))) + | mk_split_abs T t = absdummy (T, t) + val args = rev (fst (fold_map mk_bounds (rev (binder_types T)) 0)) + val (inargs, outargs) = split_mode mode args + val (inTs, outTs) = split_map_modeT (fn _ => fn T => (SOME T, NONE)) mode (binder_types T) + val inner_term = PredicateCompFuns.mk_Eval (list_comb (P, inargs), fst (mk_tuple (outargs ~~ outTs))) +in + fold_rev mk_split_abs (binder_types T) inner_term +end -(* TODO: uses param_vs -- change necessary for compilation with new modes *) -fun compile_arg compilation_modifiers additional_arguments ctxt param_vs iss arg = +fun compile_arg compilation_modifiers additional_arguments ctxt param_modes arg = let fun map_params (t as Free (f, T)) = - if member (op =) param_vs f then - case (AList.lookup (op =) (param_vs ~~ iss) f) of - SOME is => + (case (AList.lookup (op =) param_modes f) of + SOME mode => let - val _ = error "compile_arg: A parameter in a input position -- do we have a test case?" - val T' = Comp_Mod.funT_of compilation_modifiers is T - in t(*fst (mk_Eval_of additional_arguments ((Free (f, T'), T), is) [])*) end - | NONE => t - else t + val T' = Comp_Mod.funT_of compilation_modifiers mode T + in + mk_Eval_of (Free (f, T'), T) mode + end + | NONE => t) | map_params t = t - in map_aterms map_params arg end + in + map_aterms map_params arg + end -fun compile_match compilation_modifiers additional_arguments - param_vs iss ctxt eqs eqs' out_ts success_t = +fun compile_match compilation_modifiers additional_arguments ctxt param_modes + eqs eqs' out_ts success_t = let val compfuns = Comp_Mod.compfuns compilation_modifiers val eqs'' = maps mk_eq eqs @ eqs' val eqs'' = - map (compile_arg compilation_modifiers additional_arguments ctxt param_vs iss) eqs'' + map (compile_arg compilation_modifiers additional_arguments ctxt param_modes) eqs'' val names = fold Term.add_free_names (success_t :: eqs'' @ out_ts) []; val name = Name.variant names "x"; val name' = Name.variant (name :: names) "y"; @@ -1692,12 +1666,12 @@ | (t, Term Output) => Syntax.string_of_term ctxt t ^ "[Output]" | (t, Context m) => Syntax.string_of_term ctxt t ^ "[" ^ string_of_mode m ^ "]") -fun compile_expr compilation_modifiers ctxt (t, deriv) additional_arguments = +fun compile_expr compilation_modifiers ctxt (t, deriv) param_modes additional_arguments = let val compfuns = Comp_Mod.compfuns compilation_modifiers fun expr_of (t, deriv) = (case (t, deriv) of - (t, Term Input) => SOME t + (t, Term Input) => SOME (compile_arg compilation_modifiers additional_arguments ctxt param_modes t) | (t, Term Output) => NONE | (Const (name, T), Context mode) => (case alternative_compilation_of ctxt name mode of @@ -1728,13 +1702,12 @@ list_comb (the (expr_of (t, deriv)), additional_arguments) end -fun compile_clause compilation_modifiers ctxt all_vs param_vs additional_arguments - mode inp (in_ts, out_ts) moded_ps = +fun compile_clause compilation_modifiers ctxt all_vs param_modes additional_arguments + inp (in_ts, out_ts) moded_ps = let val compfuns = Comp_Mod.compfuns compilation_modifiers - val iss = ho_arg_modes_of mode (* FIXME! *) val compile_match = compile_match compilation_modifiers - additional_arguments param_vs iss ctxt + additional_arguments ctxt param_modes val (in_ts', (all_vs', eqs)) = fold_map (collect_non_invertible_subterms ctxt) in_ts (all_vs, []); fun compile_prems out_ts' vs names [] = @@ -1761,7 +1734,7 @@ Prem t => let val u = - compile_expr compilation_modifiers ctxt (t, deriv) additional_arguments' + compile_expr compilation_modifiers ctxt (t, deriv) param_modes additional_arguments' val (_, out_ts''') = split_mode mode (snd (strip_comb t)) val rest = compile_prems out_ts''' vs' names'' ps in @@ -1772,7 +1745,7 @@ val neg_compilation_modifiers = negative_comp_modifiers_of compilation_modifiers val u = mk_not compfuns - (compile_expr neg_compilation_modifiers ctxt (t, deriv) additional_arguments') + (compile_expr neg_compilation_modifiers ctxt (t, deriv) param_modes additional_arguments') val (_, out_ts''') = split_mode mode (snd (strip_comb t)) val rest = compile_prems out_ts''' vs' names'' ps in @@ -1781,7 +1754,7 @@ | Sidecond t => let val t = compile_arg compilation_modifiers additional_arguments - ctxt param_vs iss t + ctxt param_modes t val rest = compile_prems [] vs' names'' ps; in (mk_if compfuns t, rest) @@ -1797,7 +1770,7 @@ compile_match constr_vs' eqs out_ts'' (mk_bind compfuns (compiled_clause, rest)) end - val prem_t = compile_prems in_ts' param_vs all_vs' moded_ps; + val prem_t = compile_prems in_ts' (map fst param_modes) all_vs' moded_ps; in mk_bind compfuns (mk_single compfuns inp, prem_t) end @@ -1909,7 +1882,7 @@ | _ => raise Fail "unexpected pattern") -fun compile_switch compilation_modifiers ctxt all_vs param_vs additional_arguments mode +fun compile_switch compilation_modifiers ctxt all_vs param_modes additional_arguments mode in_ts' outTs switch_tree = let val compfuns = Comp_Mod.compfuns compilation_modifiers @@ -1929,8 +1902,8 @@ val in_ts' = map (Pattern.rewrite_term thy (map swap fsubst) []) (map snd pat') val out_ts' = map (Pattern.rewrite_term thy (map swap fsubst) []) out_ts in - compile_clause compilation_modifiers ctxt all_vs param_vs additional_arguments - mode inp (in_ts', out_ts') moded_ps' + compile_clause compilation_modifiers ctxt all_vs param_modes additional_arguments + inp (in_ts', out_ts') moded_ps' end in SOME (foldr1 (mk_sup compfuns) (map compile_clause' moded_clauses)) end | compile_switch_tree all_vs ctxt_eqs (Node ((position, switched_clauses), left_clauses)) = @@ -1991,17 +1964,18 @@ (param_vs, (all_vs @ param_vs)) val in_ts' = map_filter (map_filter_prod (fn t as Free (x, _) => if member (op =) param_vs x then NONE else SOME t | t => SOME t)) in_ts + val param_modes = param_vs ~~ ho_arg_modes_of mode val compilation = if detect_switches options then the_default (mk_bot compfuns (HOLogic.mk_tupleT outTs)) - (compile_switch compilation_modifiers ctxt all_vs param_vs additional_arguments - mode in_ts' outTs (mk_switch_tree ctxt mode moded_cls)) + (compile_switch compilation_modifiers ctxt all_vs param_modes additional_arguments mode + in_ts' outTs (mk_switch_tree ctxt mode moded_cls)) else let val cl_ts = map (fn (ts, moded_prems) => - compile_clause compilation_modifiers ctxt all_vs param_vs additional_arguments - mode (HOLogic.mk_tuple in_ts') (split_mode mode ts) moded_prems) moded_cls; + compile_clause compilation_modifiers ctxt all_vs param_modes additional_arguments + (HOLogic.mk_tuple in_ts') (split_mode mode ts) moded_prems) moded_cls; in Comp_Mod.wrap_compilation compilation_modifiers compfuns s T mode additional_arguments (if null cl_ts then @@ -3236,7 +3210,7 @@ | Pos_Random_DSeq => pos_random_dseq_comp_modifiers | New_Pos_Random_DSeq => new_pos_random_dseq_comp_modifiers val t_pred = compile_expr comp_modifiers ctxt - (body, deriv) additional_arguments; + (body, deriv) [] additional_arguments; val T_pred = dest_predT compfuns (fastype_of t_pred) val arrange = split_lambda (HOLogic.mk_tuple outargs) output_tuple in