# HG changeset patch # User bulwahn # Date 1269243013 -3600 # Node ID 3122bdd9527576e8ec4f5268d2c8aa237a2c03e4 # Parent 14a0993fe64b89901575e155400a23cb8ccad244 contextifying the compilation of the predicate compiler diff -r 14a0993fe64b -r 3122bdd95275 src/HOL/Tools/Predicate_Compile/predicate_compile_aux.ML --- a/src/HOL/Tools/Predicate_Compile/predicate_compile_aux.ML Mon Mar 22 08:30:13 2010 +0100 +++ b/src/HOL/Tools/Predicate_Compile/predicate_compile_aux.ML Mon Mar 22 08:30:13 2010 +0100 @@ -325,7 +325,7 @@ else false *) -val is_constr = Code.is_constr; +val is_constr = Code.is_constr o ProofContext.theory_of; fun strip_ex (Const ("Ex", _) $ Abs (x, T, t)) = let diff -r 14a0993fe64b -r 3122bdd95275 src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML --- a/src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML Mon Mar 22 08:30:13 2010 +0100 +++ b/src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML Mon Mar 22 08:30:13 2010 +0100 @@ -1027,24 +1027,27 @@ fold_rev (curry Fun) (map (K Output) Ts) Output end -fun is_invertible_function thy (Const (f, _)) = is_constr thy f - | is_invertible_function thy _ = false +fun is_invertible_function ctxt (Const (f, _)) = is_constr ctxt f + | is_invertible_function ctxt _ = false -fun non_invertible_subterms thy (t as Free _) = [] - | non_invertible_subterms thy t = - case (strip_comb t) of (f, args) => - if is_invertible_function thy f then - maps (non_invertible_subterms thy) args +fun non_invertible_subterms ctxt (t as Free _) = [] + | non_invertible_subterms ctxt t = + let + val (f, args) = strip_comb t + in + if is_invertible_function ctxt f then + maps (non_invertible_subterms ctxt) args else [t] + end -fun collect_non_invertible_subterms thy (f as Free _) (names, eqs) = (f, (names, eqs)) - | collect_non_invertible_subterms thy t (names, eqs) = +fun collect_non_invertible_subterms ctxt (f as Free _) (names, eqs) = (f, (names, eqs)) + | collect_non_invertible_subterms ctxt t (names, eqs) = case (strip_comb t) of (f, args) => - if is_invertible_function thy f then + if is_invertible_function ctxt f then let val (args', (names', eqs')) = - fold_map (collect_non_invertible_subterms thy) args (names, eqs) + fold_map (collect_non_invertible_subterms ctxt) args (names, eqs) in (list_comb (f, args'), (names', eqs')) end @@ -1066,18 +1069,21 @@ fun is_possible_output thy vs t = forall (fn t => is_eqT (fastype_of t) andalso forall (member (op =) vs) (term_vs t)) - (non_invertible_subterms thy t) + (non_invertible_subterms (ProofContext.init thy) t) andalso (forall (is_eqT o snd) (inter (fn ((f', _), f) => f = f') vs (Term.add_frees t []))) -fun vars_of_destructable_term thy (Free (x, _)) = [x] - | vars_of_destructable_term thy t = - case (strip_comb t) of (f, args) => - if is_invertible_function thy f then - maps (vars_of_destructable_term thy) args +fun vars_of_destructable_term ctxt (Free (x, _)) = [x] + | vars_of_destructable_term ctxt t = + let + val (f, args) = strip_comb t + in + if is_invertible_function ctxt f then + maps (vars_of_destructable_term ctxt) args else [] + end fun is_constructable thy vs t = forall (member (op =) vs) (term_vs t) @@ -1099,7 +1105,7 @@ SOME ms => SOME (map (fn m => (Context m , [])) ms) | NONE => NONE) -fun derivations_of thy modes vs (Const ("Pair", _) $ t1 $ t2) (Pair (m1, m2)) = +fun derivations_of (thy : theory) modes vs (Const ("Pair", _) $ t1 $ t2) (Pair (m1, m2)) = map_product (fn (m1, mvars1) => fn (m2, mvars2) => (Mode_Pair (m1, m2), union (op =) mvars1 mvars2)) (derivations_of thy modes vs t1 m1) (derivations_of thy modes vs t2 m2) @@ -1215,7 +1221,7 @@ tracing ("modes: " ^ (commas (map (fn (s, ms) => s ^ ": " ^ commas (map (fn (m, r) => string_of_mode m ^ (if r then " random " else " not ")) ms)) modes))) -fun select_mode_prem (mode_analysis_options : mode_analysis_options) thy pol (modes, (pos_modes, neg_modes)) vs ps = +fun select_mode_prem (mode_analysis_options : mode_analysis_options) (thy : theory) pol (modes, (pos_modes, neg_modes)) vs ps = let fun choose_mode_of_prem (Prem t) = partial_hd (sort (deriv_ord2 thy modes t) (all_derivations_of thy pos_modes vs t)) @@ -1246,7 +1252,7 @@ val modes = map (fn (s, ms) => (s, map (fn ((p, m), r) => m) ms)) modes' in (modes, modes) end val (in_ts, out_ts) = split_mode mode ts - val in_vs = maps (vars_of_destructable_term thy) in_ts + val in_vs = maps (vars_of_destructable_term (ProofContext.init thy)) in_ts val out_vs = terms_vs out_ts fun known_vs_after p vs = (case p of Prem t => union (op =) vs (term_vs t) @@ -1509,7 +1515,7 @@ end; (* TODO: uses param_vs -- change necessary for compilation with new modes *) -fun compile_arg compilation_modifiers compfuns additional_arguments thy param_vs iss arg = +fun compile_arg compilation_modifiers compfuns additional_arguments ctxt param_vs iss arg = let fun map_params (t as Free (f, T)) = if member (op =) param_vs f then @@ -1525,11 +1531,11 @@ in map_aterms map_params arg end fun compile_match compilation_modifiers compfuns additional_arguments - param_vs iss thy eqs eqs' out_ts success_t = + param_vs iss ctxt eqs eqs' out_ts success_t = let val eqs'' = maps mk_eq eqs @ eqs' val eqs'' = - map (compile_arg compilation_modifiers compfuns additional_arguments thy param_vs iss) eqs'' + map (compile_arg compilation_modifiers compfuns additional_arguments ctxt param_vs iss) eqs'' val names = fold Term.add_free_names (success_t :: eqs'' @ out_ts) []; val name = Name.variant names "x"; val name' = Name.variant (name :: names) "y"; @@ -1539,8 +1545,7 @@ val v = Free (name, T); val v' = Free (name', T); in - lambda v (fst (Datatype.make_case - (ProofContext.init thy) Datatype_Case.Quiet [] v + lambda v (fst (Datatype.make_case ctxt Datatype_Case.Quiet [] v [(HOLogic.mk_tuple out_ts, if null eqs'' then success_t else Const (@{const_name HOL.If}, HOLogic.boolT --> U --> U --> U) $ @@ -1549,25 +1554,25 @@ (v', mk_bot compfuns U')])) end; -fun string_of_tderiv thy (t, deriv) = +fun string_of_tderiv ctxt (t, deriv) = (case (t, deriv) of (t1 $ t2, Mode_App (deriv1, deriv2)) => - string_of_tderiv thy (t1, deriv1) ^ " $ " ^ string_of_tderiv thy (t2, deriv2) + string_of_tderiv ctxt (t1, deriv1) ^ " $ " ^ string_of_tderiv ctxt (t2, deriv2) | (Const ("Pair", _) $ t1 $ t2, Mode_Pair (deriv1, deriv2)) => - "(" ^ string_of_tderiv thy (t1, deriv1) ^ ", " ^ string_of_tderiv thy (t2, deriv2) ^ ")" - | (t, Term Input) => Syntax.string_of_term_global thy t ^ "[Input]" - | (t, Term Output) => Syntax.string_of_term_global thy t ^ "[Output]" - | (t, Context m) => Syntax.string_of_term_global thy t ^ "[" ^ string_of_mode m ^ "]") + "(" ^ string_of_tderiv ctxt (t1, deriv1) ^ ", " ^ string_of_tderiv ctxt (t2, deriv2) ^ ")" + | (t, Term Input) => Syntax.string_of_term ctxt t ^ "[Input]" + | (t, Term Output) => Syntax.string_of_term ctxt t ^ "[Output]" + | (t, Context m) => Syntax.string_of_term ctxt t ^ "[" ^ string_of_mode m ^ "]") -fun compile_expr compilation_modifiers compfuns thy pol (t, deriv) additional_arguments = +fun compile_expr compilation_modifiers compfuns ctxt pol (t, deriv) additional_arguments = let fun expr_of (t, deriv) = (case (t, deriv) of (t, Term Input) => SOME t | (t, Term Output) => NONE | (Const (name, T), Context mode) => - SOME (Const (function_name_of (Comp_Mod.compilation compilation_modifiers) thy name - (pol, mode), + SOME (Const (function_name_of (Comp_Mod.compilation compilation_modifiers) + (ProofContext.theory_of ctxt) name (pol, mode), Comp_Mod.funT_of compilation_modifiers mode T)) | (Free (s, T), Context m) => SOME (Free (s, Comp_Mod.funT_of compilation_modifiers m T)) @@ -1591,19 +1596,19 @@ list_comb (the (expr_of (t, deriv)), additional_arguments) end -fun compile_clause compilation_modifiers compfuns thy all_vs param_vs additional_arguments +fun compile_clause compilation_modifiers compfuns ctxt all_vs param_vs additional_arguments (pol, mode) inp (ts, moded_ps) = let val iss = ho_arg_modes_of mode val compile_match = compile_match compilation_modifiers compfuns - additional_arguments param_vs iss thy + additional_arguments param_vs iss ctxt val (in_ts, out_ts) = split_mode mode ts; val (in_ts', (all_vs', eqs)) = - fold_map (collect_non_invertible_subterms thy) in_ts (all_vs, []); + fold_map (collect_non_invertible_subterms ctxt) in_ts (all_vs, []); fun compile_prems out_ts' vs names [] = let val (out_ts'', (names', eqs')) = - fold_map (collect_non_invertible_subterms thy) out_ts' (names, []); + fold_map (collect_non_invertible_subterms ctxt) out_ts' (names, []); val (out_ts''', (names'', constr_vs)) = fold_map distinct_v out_ts'' (names', map (rpair []) vs); in @@ -1614,7 +1619,7 @@ let val vs' = distinct (op =) (flat (vs :: map term_vs out_ts)); val (out_ts', (names', eqs)) = - fold_map (collect_non_invertible_subterms thy) out_ts (names, []) + fold_map (collect_non_invertible_subterms ctxt) out_ts (names, []) val (out_ts'', (names'', constr_vs')) = fold_map distinct_v out_ts' ((names', map (rpair []) vs)) val mode = head_mode_of deriv @@ -1624,7 +1629,7 @@ Prem t => let val u = - compile_expr compilation_modifiers compfuns thy + compile_expr compilation_modifiers compfuns ctxt pol (t, deriv) additional_arguments' val (_, out_ts''') = split_mode mode (snd (strip_comb t)) val rest = compile_prems out_ts''' vs' names'' ps @@ -1634,7 +1639,7 @@ | Negprem t => let val u = mk_not compfuns - (compile_expr compilation_modifiers compfuns thy + (compile_expr compilation_modifiers compfuns ctxt (not pol) (t, deriv) additional_arguments') val (_, out_ts''') = split_mode mode (snd (strip_comb t)) val rest = compile_prems out_ts''' vs' names'' ps @@ -1644,7 +1649,7 @@ | Sidecond t => let val t = compile_arg compilation_modifiers compfuns additional_arguments - thy param_vs iss t + ctxt param_vs iss t val rest = compile_prems [] vs' names'' ps; in (mk_if compfuns t, rest) @@ -1667,6 +1672,7 @@ fun compile_pred compilation_modifiers thy all_vs param_vs s T (pol, mode) moded_cls = let + val ctxt = ProofContext.init thy val additional_arguments = Comp_Mod.additional_arguments compilation_modifiers (all_vs @ param_vs) val compfuns = Comp_Mod.compfuns compilation_modifiers @@ -1691,14 +1697,15 @@ (fn t as Free (x, _) => if member (op =) param_vs x then NONE else SOME t | t => SOME t)) in_ts val cl_ts = map (compile_clause compilation_modifiers compfuns - thy all_vs param_vs additional_arguments (pol, mode) (HOLogic.mk_tuple in_ts')) moded_cls; + ctxt all_vs param_vs additional_arguments (pol, mode) (HOLogic.mk_tuple in_ts')) moded_cls; val compilation = Comp_Mod.wrap_compilation compilation_modifiers compfuns s T mode additional_arguments (if null cl_ts then mk_bot compfuns (HOLogic.mk_tupleT outTs) else foldr1 (mk_sup compfuns) cl_ts) val fun_const = - Const (function_name_of (Comp_Mod.compilation compilation_modifiers) thy s (pol, mode), funT) + Const (function_name_of (Comp_Mod.compilation compilation_modifiers) + (ProofContext.theory_of ctxt) s (pol, mode), funT) in HOLogic.mk_Trueprop (HOLogic.mk_eq (list_comb (fun_const, in_ts @ additional_arguments), compilation)) @@ -3032,7 +3039,8 @@ (*| Annotated => annotated_comp_modifiers*) | DSeq => dseq_comp_modifiers | Pos_Random_DSeq => pos_random_dseq_comp_modifiers - val t_pred = compile_expr comp_modifiers compfuns thy true (body, deriv) additional_arguments; + val t_pred = compile_expr comp_modifiers compfuns (ProofContext.init thy) + true (body, deriv) additional_arguments; val T_pred = dest_predT compfuns (fastype_of t_pred) val arrange = split_lambda (HOLogic.mk_tuple outargs) output_tuple in