# HG changeset patch # User bulwahn # Date 1249367696 -7200 # Node ID c2b74affab85d1d7d499e72c0611217d85fede2d # Parent 55166cd57a6d99b88620a8ff5e777bb9be2a58ab imported patch generic compilation of predicate compiler with different monads diff -r 55166cd57a6d -r c2b74affab85 src/HOL/ex/Predicate_Compile.thy --- a/src/HOL/ex/Predicate_Compile.thy Tue Aug 04 08:34:56 2009 +0200 +++ b/src/HOL/ex/Predicate_Compile.thy Tue Aug 04 08:34:56 2009 +0200 @@ -1,5 +1,5 @@ theory Predicate_Compile -imports Complex_Main Lattice_Syntax Code_Eval +imports Complex_Main Lattice_Syntax Code_Eval RPred uses "predicate_compile.ML" begin diff -r 55166cd57a6d -r c2b74affab85 src/HOL/ex/RPred.thy --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/src/HOL/ex/RPred.thy Tue Aug 04 08:34:56 2009 +0200 @@ -0,0 +1,46 @@ +theory RPred +imports Quickcheck Random Predicate +begin + +types 'a rpred = "Random.seed \ ('a Predicate.pred \ Random.seed)" + +section {* The RandomPred Monad *} + +text {* A monad to combine the random state monad and the predicate monad *} + +definition bot :: "'a rpred" + where "bot = Pair (bot_class.bot)" + +definition return :: "'a => 'a rpred" + where "return x = Pair (Predicate.single x)" + +definition bind :: "'a rpred \ ('a \ 'b rpred) \ 'b rpred" (infixl "\=" 60) + where "bind RP f = + (\s. let (P, s') = RP s; + (s1, s2) = Random.split_seed s' + in (Predicate.bind P (%a. fst (f a s1)), s2))" + +definition supp :: "'a rpred \ 'a rpred \ 'a rpred" (infixl "\" 80) +where + "supp RP1 RP2 = (\s. let (P1, s') = RP1 s; (P2, s'') = RP2 s' + in (upper_semilattice_class.sup P1 P2, s''))" + +definition if_rpred :: "bool \ unit rpred" +where + "if_rpred b = (if b then return () else bot)" + +(* Missing a good definition for negation: not_rpred *) + +definition not_rpred :: "unit Predicate.pred \ unit rpred" +where + "not_rpred = Pair o Predicate.not_pred" + +definition lift_pred :: "'a Predicate.pred \ 'a rpred" + where + "lift_pred = Pair" + +definition lift_random :: "(Random.seed \ ('a \ (unit \ term)) \ Random.seed) \ ('a \ (unit \ term)) rpred" + where "lift_random g = (\s. let (v, s') = g s in (Predicate.single v, s'))" + + +end \ No newline at end of file diff -r 55166cd57a6d -r c2b74affab85 src/HOL/ex/predicate_compile.ML --- a/src/HOL/ex/predicate_compile.ML Tue Aug 04 08:34:56 2009 +0200 +++ b/src/HOL/ex/predicate_compile.ML Tue Aug 04 08:34:56 2009 +0200 @@ -7,7 +7,7 @@ signature PREDICATE_COMPILE = sig type mode = int list option list * int list - val add_equations_of: string list -> theory -> theory + val add_equations_of: bool -> string list -> theory -> theory val register_predicate : (thm list * thm * int) -> theory -> theory val is_registered : theory -> string -> bool val fetch_pred_data : theory -> string -> (thm list * thm * int) @@ -34,11 +34,20 @@ val add_equations : string -> theory -> theory val code_pred_intros_attrib : attribute (* used by Quickcheck_Generator *) - val funT_of : mode -> typ -> typ + (*val funT_of : mode -> typ -> typ val mk_if_pred : term -> term - val mk_Eval : term * term -> term + val mk_Eval : term * term -> term*) val mk_tupleT : typ list -> typ - val mk_predT : typ -> typ +(* val mk_predT : typ -> typ *) + (* temporary for compilation *) + datatype indprem = Prem of term list * term | Negprem of term list * term | Sidecond of term; + val prepare_intrs: theory -> string list -> + (string * typ) list * int * string list * string list * (string * mode list) list * + (string * (term list * indprem list) list) list * (string * (int option list * int)) list + val infer_modes : theory -> (string * (int list option list * int list) list) list + -> (string * (int option list * int)) list -> string list + -> (string * (term list * indprem list) list) list -> (string * mode list) list + val split_mode : int list -> term list -> (term list * term list) end; structure Predicate_Compile : PREDICATE_COMPILE = @@ -89,23 +98,39 @@ | dest_tuple (Const (@{const_name Pair}, _) $ t1 $ t2) = t1 :: (dest_tuple t2) | dest_tuple t = [t] +(** data structures for generic compilation for different monads **) +(* maybe rename functions more generic: + mk_predT -> mk_monadT; dest_predT -> dest_monadT + mk_single -> mk_return (?) +*) +datatype compilation_funs = CompilationFuns of { + mk_predT : typ -> typ, + dest_predT : typ -> typ, + mk_bot : typ -> term, + mk_single : term -> term, + mk_bind : term * term -> term, + mk_sup : term * term -> term, + mk_if : term -> term, + mk_not : term -> term +}; + +fun mk_predT (CompilationFuns funs) = #mk_predT funs +fun dest_predT (CompilationFuns funs) = #dest_predT funs +fun mk_bot (CompilationFuns funs) = #mk_bot funs +fun mk_single (CompilationFuns funs) = #mk_single funs +fun mk_bind (CompilationFuns funs) = #mk_bind funs +fun mk_sup (CompilationFuns funs) = #mk_sup funs +fun mk_if (CompilationFuns funs) = #mk_if funs +fun mk_not (CompilationFuns funs) = #mk_not funs + +structure PredicateCompFuns = +struct + fun mk_predT T = Type (@{type_name "Predicate.pred"}, [T]) fun dest_predT (Type (@{type_name "Predicate.pred"}, [T])) = T | dest_predT T = raise TYPE ("dest_predT", [T], []); -fun mk_Enum f = - let val T as Type ("fun", [T', _]) = fastype_of f - in - Const (@{const_name Predicate.Pred}, T --> mk_predT T') $ f - end; - -fun mk_Eval (f, x) = - let val T = fastype_of x - in - Const (@{const_name Predicate.eval}, mk_predT T --> T --> HOLogic.boolT) $ f $ x - end; - fun mk_bot T = Const (@{const_name Orderings.bot}, mk_predT T); fun mk_single t = @@ -120,12 +145,70 @@ val mk_sup = HOLogic.mk_binop @{const_name sup}; -fun mk_if_pred cond = Const (@{const_name Predicate.if_pred}, +fun mk_if cond = Const (@{const_name Predicate.if_pred}, HOLogic.boolT --> mk_predT HOLogic.unitT) $ cond; -fun mk_not_pred t = let val T = mk_predT HOLogic.unitT +fun mk_not t = let val T = mk_predT HOLogic.unitT in Const (@{const_name Predicate.not_pred}, T --> T) $ t end +fun mk_Enum f = + let val T as Type ("fun", [T', _]) = fastype_of f + in + Const (@{const_name Predicate.Pred}, T --> mk_predT T') $ f + end; + +fun mk_Eval (f, x) = + let val T = fastype_of x + in + Const (@{const_name Predicate.eval}, mk_predT T --> T --> HOLogic.boolT) $ f $ x + end; + +val compfuns = CompilationFuns {mk_predT = mk_predT, dest_predT = dest_predT, mk_bot = mk_bot, + mk_single = mk_single, mk_bind = mk_bind, mk_sup = mk_sup, mk_if = mk_if, mk_not = mk_not} + +end; + +structure RPredCompFuns = +struct + +fun mk_rpredT T = + @{typ "Random.seed"} --> HOLogic.mk_prodT ((PredicateCompFuns.mk_predT T), @{typ "Random.seed"}) + +fun dest_rpredT (Type ("fun", [_, + Type (@{type_name "*"}, [Type (@{type_name "Predicate.pred"}, [T]), _])])) = T + | dest_rpredT T = raise TYPE ("dest_rpredT", [T], []); + +fun mk_rpredT T = + @{typ "Random.seed"} --> HOLogic.mk_prodT ((PredicateCompFuns.mk_predT T), @{typ "Random.seed"}) + +fun mk_bot T = Const(@{const_name RPred.bot}, mk_rpredT T) + +fun mk_single t = + let + val T = fastype_of t + in + Const (@{const_name RPred.return}, T --> mk_rpredT T) $ t + end; + +fun mk_bind (x, f) = + let + val T as (Type ("fun", [_, U])) = fastype_of f + in + Const (@{const_name RPred.bind}, fastype_of x --> T --> U) $ x $ f + end + +val mk_sup = HOLogic.mk_binop @{const_name RPred.supp} + +fun mk_if cond = Const (@{const_name RPred.if_rpred}, + HOLogic.boolT --> mk_rpredT HOLogic.unitT) $ cond; + +fun mk_not t = error "Negation is not defined for RPred" + +val compfuns = CompilationFuns {mk_predT = mk_rpredT, dest_predT = dest_rpredT, mk_bot = mk_bot, + mk_single = mk_single, mk_bind = mk_bind, mk_sup = mk_sup, mk_if = mk_if, mk_not = mk_not} + +end; + (* destruction of intro rules *) (* FIXME: look for other place where this functionality was used before *) @@ -148,7 +231,6 @@ cat_lines (map (fn (s, ms) => s ^ ": " ^ commas (map string_of_mode ms)) modes)); - datatype predfun_data = PredfunData of { name : string, definition : thm, @@ -197,7 +279,7 @@ Option.map rep_pred_data (try (Graph.get_node (PredData.get thy)) name) fun the_pred_data thy name = case lookup_pred_data thy name - of NONE => error ("No such predicate " ^ quote name) + of NONE => error ("No such predicate " ^ quote name) | SOME data => data; val is_registered = is_some oo lookup_pred_data @@ -225,7 +307,7 @@ (#functions (the_pred_data thy name)) mode) fun the_predfun_data thy name mode = case lookup_predfun_data thy name mode - of NONE => error ("No such mode" ^ string_of_mode mode) + of NONE => error ("No function defined for mode " ^ string_of_mode mode ^ " of predicate " ^ name) | SOME data => data; val predfun_name_of = #name ooo the_predfun_data @@ -236,6 +318,36 @@ val predfun_elim_of = #elim ooo the_predfun_data +(* TODO: maybe join chop nparams and split_mode is +to some function split_mode mode and rename split_mode to split_smode *) +fun split_mode is ts = let + fun split_mode' _ _ [] = ([], []) + | split_mode' is i (t::ts) = (if i mem is then apfst else apsnd) (cons t) + (split_mode' is (i+1) ts) +in split_mode' is 1 ts end + +(* Remark: types of param_funT_of and funT_of are swapped - which is the more +canonical order? *) +(* maybe remove param_funT_of completely - by using funT_of *) +fun param_funT_of compfuns T NONE = T + | param_funT_of compfuns T (SOME mode) = let + val Ts = binder_types T; + val (Us1, Us2) = split_mode mode Ts + in Us1 ---> (mk_predT compfuns (mk_tupleT Us2)) end; + +fun funT_of compfuns (iss, is) T = let + val Ts = binder_types T + val (paramTs, argTs) = chop (length iss) Ts + val paramTs' = map2 (fn SOME is => funT_of compfuns ([], is) | NONE => I) iss paramTs + val (inargTs, outargTs) = split_mode is argTs + in + (paramTs' @ inargTs) ---> (mk_predT compfuns (mk_tupleT outargTs)) + end; + +(* TODO: duplicate code in funT_of and this function *) +fun mk_predfun_of thy compfuns (name, T) mode = + Const (predfun_name_of thy name mode, funT_of compfuns mode T) + fun print_stored_rules thy = let val preds = (Graph.keys o PredData.get) thy @@ -304,6 +416,25 @@ elimrule end; +(* special case: predicate with no introduction rule *) +fun noclause thy predname = let + val T = (Logic.unvarifyT o Sign.the_const_type thy) predname + val Ts = binder_types T + val names = Name.variant_list [] + (map (fn i => "x" ^ (string_of_int i)) (1 upto (length Ts))) + val vs = map2 (curry Free) names Ts + val clausehd = HOLogic.mk_Trueprop (list_comb(Const (predname, T), vs)) + val intro_t = Logic.mk_implies (@{prop False}, clausehd) + val P = HOLogic.mk_Trueprop (Free ("P", HOLogic.boolT)) + val elim_t = Logic.list_implies ([clausehd, Logic.mk_implies (@{prop False}, P)], P) + val intro = Goal.prove (ProofContext.init thy) names [] intro_t + (fn {...} => etac @{thm FalseE} 1) + val elim = Goal.prove (ProofContext.init thy) ("P" :: names) [] elim_t + (fn {...} => etac (the_elim_of thy predname) 1) +in + ([intro], elim, 0) +end + fun fetch_pred_data thy name = case try (Inductive.the_inductive (ProofContext.init thy)) name of SOME (info as (_, result)) => @@ -315,7 +446,7 @@ val intros = map (preprocess_intro thy) (filter is_intro_of (#intrs result)) val elim = nth (#elims result) (find_index (fn s => s = name) (#names (fst info))) val nparams = length (Inductive.params_of (#raw_induct result)) - in (intros, elim, nparams) end + in if null intros then noclause thy name else (intros, elim, nparams) end | NONE => error ("No such predicate: " ^ quote name) (* updaters *) @@ -362,10 +493,9 @@ fun set_nparams name nparams = let fun set (intros, elim, _ ) = (intros, elim, nparams) in PredData.map (Graph.map_node name (map_pred_data (apfst set))) end - + fun register_predicate (intros, elim, nparams) thy = let val (name, _) = dest_Const (fst (strip_intro_concl nparams (prop_of (hd intros)))) - fun set _ = (intros, SOME elim, nparams) in PredData.map (Graph.new_node (name, mk_pred_data ((intros, SOME elim, nparams), [])) #> fold Graph.add_edge (map (pair name) (depending_preds_of thy intros))) thy @@ -400,12 +530,6 @@ fun term_vTs tm = fold_aterms (fn Free xT => cons xT | _ => I) tm []; -fun get_args is ts = let - fun get_args' _ _ [] = ([], []) - | get_args' is i (t::ts) = (if i mem is then apfst else apsnd) (cons t) - (get_args' is (i+1) ts) -in get_args' is 1 ts end - (*FIXME this function should not be named merge... make it local instead*) fun merge xs [] = xs | merge [] ys = ys @@ -438,8 +562,8 @@ error ("Too few arguments for inductive predicate " ^ name) else chop (length iss) args; val k = length args2; - val perm = map (fn i => (find_index (fn t => t = Bound (b - i)) args2) + 1) - (1 upto b) + val perm = map (fn i => (find_index_eq (Bound (b - i)) args2) + 1) + (1 upto b) val partial_mode = (1 upto k) \\ perm in if not (partial_mode subset is) then [] else @@ -495,11 +619,24 @@ datatype indprem = Prem of term list * term | Negprem of term list * term | Sidecond of term; +fun print_clausess thy clausess = + let + val _ = Output.tracing "function print_clauses" + fun print_prem (Prem (ts, p)) = Syntax.string_of_term_global thy (list_comb (p, ts)) + | print_prem _ = error "print_clausess: unimplemented" + fun print_clause pred (ts, prems) = + (space_implode " --> " (map print_prem prems)) ^ " --> " ^ pred ^ " " + ^ (space_implode " " (map (Syntax.string_of_term_global thy) ts)) + fun print_clauses (pred, clauses) = + "clauses of " ^ pred ^ ": " ^ cat_lines (map (print_clause pred) clauses) + val _ = Output.tracing (cat_lines (map print_clauses clausess)) + in () end; + fun select_mode_prem thy modes vs ps = find_first (is_some o snd) (ps ~~ map (fn Prem (us, t) => find_first (fn Mode (_, is, _) => let - val (in_ts, out_ts) = get_args is us; + val (in_ts, out_ts) = split_mode is us; val (out_ts', in_ts') = List.partition (is_constrt thy) out_ts; val vTs = maps term_vTs out_ts'; val dupTs = map snd (duplicates (op =) vTs) @ @@ -521,7 +658,7 @@ | Sidecond t => if term_vs t subset vs then SOME (Mode (([], []), [], [])) else NONE ) ps); - + fun check_mode_clause thy param_vs modes (iss, is) (ts, ps) = let val modes' = modes @ List.mapPartial @@ -533,7 +670,7 @@ | SOME (x, _) => check_mode_prems (case x of Prem (us, _) => vs union terms_vs us | _ => vs) (filter_out (equal x) ps)) - val (in_ts, in_ts') = List.partition (is_constrt thy) (fst (get_args is ts)); + val (in_ts, in_ts') = List.partition (is_constrt thy) (fst (split_mode is ts)); val in_vs = terms_vs in_ts; val concl_vs = terms_vs ts val _ = Output.tracing ("ts :" ^ (commas (map (Syntax.string_of_term_global thy) ts))) @@ -543,7 +680,6 @@ (case check_mode_prems (param_vs union in_vs) ps of NONE => false | SOME vs => concl_vs subset vs) - in ret end; @@ -568,27 +704,8 @@ | SOME k' => map SOME (subsets 1 k')) ks), subsets 1 k))) arities); - (* term construction *) -(* Remark: types of param_funT_of and funT_of are swapped - which is the more -canonical order? *) -fun param_funT_of T NONE = T - | param_funT_of T (SOME mode) = let - val Ts = binder_types T; - val (Us1, Us2) = get_args mode Ts - in Us1 ---> (mk_predT (mk_tupleT Us2)) end; - -fun funT_of (iss, is) T = let - val Ts = binder_types T - val (paramTs, argTs) = chop (length iss) Ts - val paramTs' = map2 (fn SOME is => funT_of ([], is) | NONE => I) iss paramTs - val (inargTs, outargTs) = get_args is argTs - in - (paramTs' @ inargTs) ---> (mk_predT (mk_tupleT outargTs)) - end; - - fun mk_v (names, vs) s T = (case AList.lookup (op =) vs s of NONE => (Free (s, T), (names, (s, [])::vs)) | SOME xs => @@ -607,7 +724,7 @@ in (t' $ u', nvs'') end | distinct_v x nvs = (x, nvs); -fun compile_match thy eqs eqs' out_ts success_t = +fun compile_match thy compfuns eqs eqs' out_ts success_t = let val eqs'' = maps mk_eq eqs @ eqs' val names = fold Term.add_free_names (success_t :: eqs'' @ out_ts) []; @@ -615,7 +732,7 @@ val name' = Name.variant (name :: names) "y"; val T = mk_tupleT (map fastype_of out_ts); val U = fastype_of success_t; - val U' = dest_predT U; + val U' = dest_predT compfuns U; val v = Free (name, T); val v' = Free (name', T); in @@ -625,8 +742,8 @@ if null eqs'' then success_t else Const (@{const_name HOL.If}, HOLogic.boolT --> U --> U --> U) $ foldr1 HOLogic.mk_conj eqs'' $ success_t $ - mk_bot U'), - (v', mk_bot U')])) + mk_bot compfuns U'), + (v', mk_bot compfuns U')])) end; (*FIXME function can be removed*) @@ -639,81 +756,78 @@ in fold_rev lambda vs (f (list_comb (t, vs))) end; - - - -fun compile_param_ext thy modes (NONE, t) = t - | compile_param_ext thy modes (m as SOME (Mode ((iss, is'), is, ms)), t) = +(* +fun compile_param_ext thy compfuns modes (NONE, t) = t + | compile_param_ext thy compfuns modes (m as SOME (Mode ((iss, is'), is, ms)), t) = let val (vs, u) = strip_abs t - val (ivs, ovs) = get_args is vs + val (ivs, ovs) = split_mode is vs val (f, args) = strip_comb u val (params, args') = chop (length ms) args - val (inargs, outargs) = get_args is' args' + val (inargs, outargs) = split_mode is' args' val b = length vs - val perm = map (fn i => find_index (fn t => t = Bound (b - i)) args' + 1) (1 upto b) + val perm = map (fn i => (find_index_eq (Bound (b - i)) args') + 1) (1 upto b) val outp_perm = - snd (get_args is perm) + snd (split_mode is perm) |> map (fn i => i - length (filter (fn x => x < i) is')) - val names = [] (* TODO *) + val names = [] -- TODO val out_names = Name.variant_list names (replicate (length outargs) "x") val f' = case f of Const (name, T) => if AList.defined op = modes name then - Const (predfun_name_of thy name (iss, is'), funT_of (iss, is') T) + mk_predfun_of thy compfuns (name, T) (iss, is') else error "compile param: Not an inductive predicate with correct mode" - | Free (name, T) => Free (name, param_funT_of T (SOME is')) - val outTs = dest_tupleT (dest_predT (body_type (fastype_of f'))) + | Free (name, T) => Free (name, param_funT_of compfuns T (SOME is')) + val outTs = dest_tupleT (dest_predT compfuns (body_type (fastype_of f'))) val out_vs = map Free (out_names ~~ outTs) val params' = map (compile_param thy modes) (ms ~~ params) val f_app = list_comb (f', params' @ inargs) - val single_t = (mk_single (mk_tuple (map (fn i => nth out_vs (i - 1)) outp_perm))) - val match_t = compile_match thy [] [] out_vs single_t + val single_t = (mk_single compfuns (mk_tuple (map (fn i => nth out_vs (i - 1)) outp_perm))) + val match_t = compile_match thy compfuns [] [] out_vs single_t in list_abs (ivs, - mk_bind (f_app, match_t)) + mk_bind compfuns (f_app, match_t)) end - | compile_param_ext _ _ _ = error "compile params" + | compile_param_ext _ _ _ _ = error "compile params" +*) -and compile_param thy modes (NONE, t) = t - | compile_param thy modes (m as SOME (Mode ((iss, is'), is, ms)), t) = +fun compile_param thy compfuns modes (NONE, t) = t + | compile_param thy compfuns modes (m as SOME (Mode ((iss, is'), is, ms)), t) = (* (case t of Abs _ => error "compile_param: Invalid term" *) (* compile_param_ext thy modes (m, t) *) (* | _ => let *) let val (f, args) = strip_comb (Envir.eta_contract t) val (params, args') = chop (length ms) args - val params' = map (compile_param thy modes) (ms ~~ params) + val params' = map (compile_param thy compfuns modes) (ms ~~ params) val f' = case f of Const (name, T) => if AList.defined op = modes name then - Const (predfun_name_of thy name (iss, is'), funT_of (iss, is') T) + mk_predfun_of thy compfuns (name, T) (iss, is') else error "compile param: Not an inductive predicate with correct mode" - | Free (name, T) => Free (name, param_funT_of T (SOME is')) + | Free (name, T) => Free (name, param_funT_of compfuns T (SOME is')) in list_comb (f', params' @ args') end - | compile_param _ _ _ = error "compile params" + | compile_param _ _ _ _ = error "compile params" -fun compile_expr thy modes (SOME (Mode (mode, is, ms)), t) = +fun compile_expr thy compfuns modes (SOME (Mode (mode, is, ms)), t) = (case strip_comb t of (Const (name, T), params) => if AList.defined op = modes name then let - val (Ts, Us) = get_args is - (curry Library.drop (length ms) (fst (strip_type T))) - val params' = map (compile_param thy modes) (ms ~~ params) - in list_comb (Const (predfun_name_of thy name mode, ((map fastype_of params') @ Ts) ---> - mk_predT (mk_tupleT Us)), params') + val params' = map (compile_param thy compfuns modes) (ms ~~ params) + in + list_comb (mk_predfun_of thy compfuns (name, T) mode, params') end else error "not a valid inductive expression" | (Free (name, T), args) => (*if name mem param_vs then *) (* Higher order mode call *) - let val r = Free (name, param_funT_of T (SOME is)) + let val r = Free (name, param_funT_of compfuns T (SOME is)) in list_comb (r, args) end) - | compile_expr _ _ _ = error "not a valid inductive expression" + | compile_expr _ _ _ _ = error "not a valid inductive expression" -fun compile_clause thy all_vs param_vs modes (iss, is) (ts, ps) inp = +fun compile_clause thy compfuns all_vs param_vs modes (iss, is) (ts, ps) inp = let val modes' = modes @ List.mapPartial (fn (_, NONE) => NONE | (v, SOME js) => SOME (v, [([], js)])) @@ -725,7 +839,7 @@ val v = Free (s, fastype_of t) in (v, (s::names, HOLogic.mk_eq (v, t)::eqs)) end; - val (in_ts, out_ts) = get_args is ts; + val (in_ts, out_ts) = split_mode is ts; val (in_ts', (all_vs', eqs)) = fold_map check_constrt in_ts (all_vs, []); @@ -736,8 +850,8 @@ val (out_ts''', (names'', constr_vs)) = fold_map distinct_v out_ts'' (names', map (rpair []) vs); in - compile_match thy constr_vs (eqs @ eqs') out_ts''' - (mk_single (mk_tuple out_ts)) + compile_match thy compfuns constr_vs (eqs @ eqs') out_ts''' + (mk_single compfuns (mk_tuple out_ts)) end | compile_prems out_ts vs names ps = let @@ -752,59 +866,57 @@ val (compiled_clause, rest) = case p of Prem (us, t) => let - val (in_ts, out_ts''') = get_args js us; - val u = list_comb (compile_expr thy modes (mode, t), in_ts) + val (in_ts, out_ts''') = split_mode js us; + val u = list_comb (compile_expr thy compfuns modes (mode, t), in_ts) val rest = compile_prems out_ts''' vs' names'' ps' in (u, rest) end | Negprem (us, t) => let - val (in_ts, out_ts''') = get_args js us - val u = list_comb (compile_expr thy modes (mode, t), in_ts) + val (in_ts, out_ts''') = split_mode js us + val u = list_comb (compile_expr thy compfuns modes (mode, t), in_ts) val rest = compile_prems out_ts''' vs' names'' ps' in - (mk_not_pred u, rest) + (mk_not compfuns u, rest) end | Sidecond t => let val rest = compile_prems [] vs' names'' ps'; in - (mk_if_pred t, rest) + (mk_if compfuns t, rest) end in - compile_match thy constr_vs' eqs out_ts'' - (mk_bind (compiled_clause, rest)) + compile_match thy compfuns constr_vs' eqs out_ts'' + (mk_bind compfuns (compiled_clause, rest)) end val prem_t = compile_prems in_ts' param_vs all_vs' ps; in - mk_bind (mk_single inp, prem_t) + mk_bind compfuns (mk_single compfuns inp, prem_t) end -fun compile_pred thy all_vs param_vs modes s T cls mode = +fun compile_pred thy compfuns all_vs param_vs modes s T cls mode = let val Ts = binder_types T; val (Ts1, Ts2) = chop (length param_vs) Ts; - val Ts1' = map2 param_funT_of Ts1 (fst mode) - val (Us1, Us2) = get_args (snd mode) Ts2; + val Ts1' = map2 (param_funT_of compfuns) Ts1 (fst mode) + val (Us1, _) = split_mode (snd mode) Ts2; val xnames = Name.variant_list param_vs (map (fn i => "x" ^ string_of_int i) (snd mode)); val xs = map2 (fn s => fn T => Free (s, T)) xnames Us1; val cl_ts = - map (fn cl => compile_clause thy + map (fn cl => compile_clause thy compfuns all_vs param_vs modes mode cl (mk_tuple xs)) cls; - val mode_id = predfun_name_of thy s mode in HOLogic.mk_Trueprop (HOLogic.mk_eq - (list_comb (Const (mode_id, (Ts1' @ Us1) ---> - mk_predT (mk_tupleT Us2)), + (list_comb (mk_predfun_of thy compfuns (s, T) mode, map2 (fn s => fn T => Free (s, T)) param_vs Ts1' @ xs), - foldr1 mk_sup cl_ts)) + foldr1 (mk_sup compfuns) cl_ts)) end; -fun compile_preds thy all_vs param_vs modes preds = +fun compile_preds thy compfuns all_vs param_vs modes preds = map (fn (s, (T, cls)) => - map (compile_pred thy all_vs param_vs modes s T cls) + map (compile_pred thy compfuns all_vs param_vs modes s T cls) ((the o AList.lookup (op =) modes) s)) preds; @@ -812,7 +924,6 @@ val HOL_basic_ss' = HOL_basic_ss setSolver (mk_solver "all_tac_solver" (fn _ => fn _ => all_tac)) - (* Definition of executable functions and their intro and elim rules *) fun print_arities arities = tracing ("Arities:\n" ^ @@ -827,8 +938,8 @@ val argnames = Name.variant_list names (map (fn i => "x" ^ string_of_int i) (1 upto (length Ts))); val args = map Free (argnames ~~ Ts) - val (inargs, outargs) = get_args mode args - val r = mk_Eval (list_comb (x, inargs), mk_tuple outargs) + val (inargs, outargs) = split_mode mode args + val r = PredicateCompFuns.mk_Eval (list_comb (x, inargs), mk_tuple outargs) val t = fold_rev lambda args r in (t, argnames @ names) @@ -841,10 +952,10 @@ val argnames = Name.variant_list [] (map (fn i => "x" ^ string_of_int i) (1 upto (length Ts))); val (Ts1, Ts2) = chop nparams Ts; - val Ts1' = map2 param_funT_of Ts1 (fst mode) + val Ts1' = map2 (param_funT_of (PredicateCompFuns.compfuns)) Ts1 (fst mode) val args = map Free (argnames ~~ (Ts1' @ Ts2)) val (params, io_args) = chop nparams args - val (inargs, outargs) = get_args (snd mode) io_args + val (inargs, outargs) = split_mode (snd mode) io_args val param_names = Name.variant_list argnames (map (fn i => "p" ^ string_of_int i) (1 upto nparams)) val param_vs = map Free (param_names ~~ Ts1) @@ -853,9 +964,9 @@ val predpropE = HOLogic.mk_Trueprop (list_comb (pred, params' @ io_args)) val param_eqs = map (HOLogic.mk_Trueprop o HOLogic.mk_eq) (param_vs ~~ params') val funargs = params @ inargs - val funpropE = HOLogic.mk_Trueprop (mk_Eval (list_comb (funtrm, funargs), + val funpropE = HOLogic.mk_Trueprop (PredicateCompFuns.mk_Eval (list_comb (funtrm, funargs), if null outargs then Free("y", HOLogic.unitT) else mk_tuple outargs)) - val funpropI = HOLogic.mk_Trueprop (mk_Eval (list_comb (funtrm, funargs), + val funpropI = HOLogic.mk_Trueprop (PredicateCompFuns.mk_Eval (list_comb (funtrm, funargs), mk_tuple outargs)) val introtrm = Logic.list_implies (predpropI :: param_eqs, funpropI) val _ = tracing (Syntax.string_of_term_global thy introtrm) @@ -870,52 +981,78 @@ (introthm, elimthm) end; +fun create_constname_of_mode thy name mode = + let + fun string_of_mode mode = if null mode then "0" + else space_implode "_" (map string_of_int mode) + fun string_of_HOmode m s = case m of NONE => s | SOME mode => s ^ "__" ^ (string_of_mode mode) + val HOmode = fold string_of_HOmode (fst mode) "" + in + (Sign.full_bname thy (Long_Name.base_name name)) ^ + (if HOmode = "" then "_" else HOmode ^ "___") ^ (string_of_mode (snd mode)) + end; + fun create_definitions preds nparams (name, modes) thy = let val _ = tracing "create definitions" + val compfuns = PredicateCompFuns.compfuns val T = AList.lookup (op =) preds name |> the fun create_definition mode thy = let - fun string_of_mode mode = if null mode then "0" - else space_implode "_" (map string_of_int mode) - val HOmode = let - fun string_of_HOmode m s = case m of NONE => s | SOME mode => s ^ "__" ^ (string_of_mode mode) - in (fold string_of_HOmode (fst mode) "") end; - val mode_id = name ^ (if HOmode = "" then "_" else HOmode ^ "___") - ^ (string_of_mode (snd mode)) + val mode_cname = create_constname_of_mode thy name mode + val mode_cbasename = Long_Name.base_name mode_cname val Ts = binder_types T; val (Ts1, Ts2) = chop nparams Ts; - val Ts1' = map2 param_funT_of Ts1 (fst mode) - val (Us1, Us2) = get_args (snd mode) Ts2; + val Ts1' = map2 (param_funT_of compfuns) Ts1 (fst mode) + val (Us1, Us2) = split_mode (snd mode) Ts2; + val funT = (Ts1' @ Us1) ---> (mk_predT compfuns (mk_tupleT Us2)) val names = Name.variant_list [] (map (fn i => "x" ^ string_of_int i) (1 upto (length Ts))); val xs = map Free (names ~~ (Ts1' @ Ts2)); val (xparams, xargs) = chop nparams xs; val (xparams', names') = fold_map mk_Eval_of ((xparams ~~ Ts1) ~~ (fst mode)) names - val (xins, xouts) = get_args (snd mode) xargs; + val (xins, xouts) = split_mode (snd mode) xargs; fun mk_split_lambda [] t = lambda (Free (Name.variant names' "x", HOLogic.unitT)) t | mk_split_lambda [x] t = lambda x t | mk_split_lambda xs t = let fun mk_split_lambda' (x::y::[]) t = HOLogic.mk_split (lambda x (lambda y t)) | mk_split_lambda' (x::xs) t = HOLogic.mk_split (lambda x (mk_split_lambda' xs t)) in mk_split_lambda' xs t end; - val predterm = mk_Enum (mk_split_lambda xouts (list_comb (Const (name, T), xparams' @ xargs))) - val funT = (Ts1' @ Us1) ---> (mk_predT (mk_tupleT Us2)) - val mode_id = Sign.full_bname thy (Long_Name.base_name mode_id) - val lhs = list_comb (Const (mode_id, funT), xparams @ xins) + val predterm = PredicateCompFuns.mk_Enum (mk_split_lambda xouts (list_comb (Const (name, T), xparams' @ xargs))) + val lhs = list_comb (Const (mode_cname, funT), xparams @ xins) val def = Logic.mk_equals (lhs, predterm) val ([definition], thy') = thy |> - Sign.add_consts_i [(Binding.name (Long_Name.base_name mode_id), funT, NoSyn)] |> - PureThy.add_defs false [((Binding.name (Long_Name.base_name mode_id ^ "_def"), def), [])] - val (intro, elim) = create_intro_elim_rule nparams mode definition mode_id funT (Const (name, T)) thy' - in thy' |> add_predfun name mode (mode_id, definition, intro, elim) - |> PureThy.store_thm (Binding.name (Long_Name.base_name mode_id ^ "I"), intro) |> snd - |> PureThy.store_thm (Binding.name (Long_Name.base_name mode_id ^ "E"), elim) |> snd + Sign.add_consts_i [(Binding.name mode_cbasename, funT, NoSyn)] |> + PureThy.add_defs false [((Binding.name (mode_cbasename ^ "_def"), def), [])] + val (intro, elim) = create_intro_elim_rule nparams mode definition mode_cname funT (Const (name, T)) thy' + in thy' |> add_predfun name mode (mode_cname, definition, intro, elim) + |> PureThy.store_thm (Binding.name (mode_cbasename ^ "I"), intro) |> snd + |> PureThy.store_thm (Binding.name (mode_cbasename ^ "E"), elim) |> snd |> Theory.checkpoint end; in fold create_definition modes thy end; - + +(* TODO: use own theory datastructure for rpred *) +fun rpred_create_definitions preds nparams (name, modes) thy = + let + val T = AList.lookup (op =) preds name |> the + fun create_definition mode thy = + let + val mode_cname = create_constname_of_mode thy name mode + val Ts = binder_types T; + val (Ts1, Ts2) = chop nparams Ts; + val Ts1' = map2 (param_funT_of RPredCompFuns.compfuns) Ts1 (fst mode) + val (Us1, Us2) = split_mode (snd mode) Ts2; + val funT = (Ts1' @ Us1) ---> (RPredCompFuns.mk_rpredT (mk_tupleT Us2)) + in + thy |> Sign.add_consts_i [(Binding.name (Long_Name.base_name mode_cname), funT, NoSyn)] + |> add_predfun name mode (mode_cname, @{thm refl}, @{thm refl}, @{thm refl}) + end; + in + fold create_definition modes thy + end; + (**************************************************************************************) (* Proving equivalence of term *) @@ -973,7 +1110,7 @@ (Const (name, T), args) => if AList.defined op = modes name then (let val introrule = predfun_intro_of thy name mode - (*val (in_args, out_args) = get_args is us + (*val (in_args, out_args) = split_mode is us val (pred, rargs) = strip_comb (HOLogic.dest_Trueprop (hd (Logic.strip_imp_prems (prop_of introrule)))) val nparams = length ms (* get_nparams thy (fst (dest_Const pred)) *) @@ -1052,7 +1189,7 @@ val v = Free (s, fastype_of t) in ((s::names, HOLogic.mk_eq (v, t)::eqs), v) end; - val (in_ts, clause_out_ts) = get_args is ts; + val (in_ts, clause_out_ts) = split_mode is ts; val ((all_vs', eqs), in_ts') = (*FIXME*) Library.foldl_map check_constrt ((all_vs, []), in_ts); fun prove_prems out_ts vs [] = @@ -1069,7 +1206,7 @@ val rps' = filter_out (equal p) rps; val rest_tac = (case p of Prem (us, t) => let - val (in_ts, out_ts''') = get_args js us + val (in_ts, out_ts''') = split_mode js us val rec_tac = prove_prems out_ts''' vs' rps' in print_tac "before clause:" @@ -1081,7 +1218,7 @@ end | Negprem (us, t) => let - val (in_ts, out_ts''') = get_args js us + val (in_ts, out_ts''') = split_mode js us val rec_tac = prove_prems out_ts''' vs' rps' val name = (case strip_comb t of (Const (c, _), _) => SOME c | _ => NONE) val (_, params) = strip_comb t @@ -1137,9 +1274,7 @@ THEN print_tac "proved one direction" end; -(*******************************************************************************************************) -(* Proof in the other direction ************************************************************************) -(*******************************************************************************************************) +(** Proof in the other direction **) fun prove_match2 thy out_ts = let fun split_term_tac (Free _) = all_tac @@ -1237,7 +1372,7 @@ |> preprocess_intro thy |> (fn thm => hd (ind_set_codegen_preproc thy [thm])) (* FIXME preprocess |> Simplifier.full_simplify (HOL_basic_ss addsimps [@ {thm Predicate.memb_code}]) *) - val (in_ts, clause_out_ts) = get_args is ts; + val (in_ts, clause_out_ts) = split_mode is ts; val ((all_vs', eqs), in_ts') = (*FIXME*) Library.foldl_map check_constrt ((all_vs, []), in_ts); fun prove_prems2 out_ts vs [] = @@ -1263,14 +1398,14 @@ val ps' = filter_out (equal p) ps; val rest_tac = (case p of Prem (us, t) => let - val (in_ts, out_ts''') = get_args js us + val (in_ts, out_ts''') = split_mode js us val rec_tac = prove_prems2 out_ts''' vs' ps' in (prove_expr2 thy modes (mode, t)) THEN rec_tac end | Negprem (us, t) => let - val (in_ts, out_ts''') = get_args js us + val (in_ts, out_ts''') = split_mode js us val rec_tac = prove_prems2 out_ts''' vs' ps' val name = (case strip_comb t of (Const (c, _), _) => SOME c | _ => NONE) val (_, params) = strip_comb t @@ -1317,6 +1452,8 @@ THEN (EVERY (map prove_clause (clauses ~~ (1 upto (length clauses))))) end; +(** proof procedure **) + fun prove_pred thy all_vs param_vs modes clauses (((pred, T), mode), t) = let val ctxt = ProofContext.init thy val clauses' = the (AList.lookup (op =) clauses pred) @@ -1335,31 +1472,19 @@ fun prove_preds thy all_vs param_vs modes clauses pmts = map (prove_pred thy all_vs param_vs modes clauses) pmts -(* special case: inductive predicate with no clauses *) -fun noclause (predname, T) thy = let - val Ts = binder_types T - val names = Name.variant_list [] - (map (fn i => "x" ^ (string_of_int i)) (1 upto (length Ts))) - val vs = map2 (curry Free) names Ts - val clausehd = HOLogic.mk_Trueprop (list_comb(Const (predname, T), vs)) - val intro_t = Logic.mk_implies (@{prop False}, clausehd) - val P = HOLogic.mk_Trueprop (Free ("P", HOLogic.boolT)) - val elim_t = Logic.list_implies ([clausehd, Logic.mk_implies (@{prop False}, P)], P) - val intro = Goal.prove (ProofContext.init thy) names [] intro_t - (fn {...} => etac @{thm FalseE} 1) - val elim = Goal.prove (ProofContext.init thy) ("P" :: names) [] elim_t - (fn {...} => etac (the_elim_of thy predname) 1) -in - add_intro intro thy - |> set_elim elim -end +fun rpred_prove_preds thy pmts = + let + fun prove_pred (((pred, T), mode), t) = + let + val _ = Output.tracing ("prove_preds:" ^ Syntax.string_of_term_global thy t) + in SkipProof.make_thm thy t end + in + map prove_pred pmts + end fun prepare_intrs thy prednames = let - (* FIXME: preprocessing moved to fetch_pred_data *) - val intrs = map (preprocess_intro thy) (maps (intros_of thy) prednames) - |> ind_set_codegen_preproc thy (*FIXME preprocessor - |> map (Simplifier.full_simplify (HOL_basic_ss addsimps [@ {thm Predicate.memb_code}]))*) + val intrs = maps (intros_of thy) prednames |> map (Logic.unvarify o prop_of) val nparams = nparams_of thy (hd prednames) val preds = distinct (op =) (map (dest_Const o fst o (strip_intro_concl nparams)) intrs) @@ -1406,33 +1531,53 @@ fun add (key, value) table = AList.update op = (key, these (AList.lookup op = table key) @ [value]) table in fold add kvs [] end; - + +fun add_generators_to_clauses thy all_vs clauses = + let + val _ = Output.tracing ("all_vs:" ^ commas all_vs) + in [clauses] end; + (* main function *) -fun add_equations_of prednames thy = +fun add_equations_of rpred prednames thy = let val _ = tracing ("starting add_equations with " ^ commas prednames ^ "...") - (* null clause handling *) - (* - val thy' = fold (fn pred as (predname, T) => fn thy => - if null (intros_of thy predname) then noclause pred thy else thy) preds thy - *) val (preds, nparams, all_vs, param_vs, extra_modes, clauses, arities) = prepare_intrs thy prednames val _ = tracing "Infering modes..." + (* + val clauses_with_generators = add_generators_to_clauses thy all_vs clauses + val modess = map (infer_modes thy extra_modes arities param_vs) clauses_with_generators + fun print_modess (clauses, modes) = + let + val _ = print_clausess thy clauses + val _ = print_modes modes + in + () + end; + val _ = map print_modess (clauses_with_generators ~~ modess) + *) val modes = infer_modes thy extra_modes arities param_vs clauses val _ = print_modes modes val _ = tracing "Defining executable functions..." - val thy' = fold (create_definitions preds nparams) modes thy |> Theory.checkpoint + val thy' = + (if rpred then + fold (rpred_create_definitions preds nparams) modes thy + else fold (create_definitions preds nparams) modes thy) + |> Theory.checkpoint val clauses' = map (fn (s, cls) => (s, (the (AList.lookup (op =) preds s), cls))) clauses val _ = tracing "Compiling equations..." - val ts = compile_preds thy' all_vs param_vs (extra_modes @ modes) clauses' -(* val _ = map (tracing o (Syntax.string_of_term_global thy')) (flat ts) *) + val compfuns = if rpred then RPredCompFuns.compfuns else PredicateCompFuns.compfuns + val ts = compile_preds thy' compfuns all_vs param_vs (extra_modes @ modes) clauses' + val _ = map (Output.tracing o (Syntax.string_of_term_global thy')) (flat ts) val pred_mode = maps (fn (s, (T, _)) => map (pair (s, T)) ((the o AList.lookup (op =) modes) s)) clauses' val _ = tracing "Proving equations..." val result_thms = - prove_preds thy' all_vs param_vs (extra_modes @ modes) clauses (pred_mode ~~ (flat ts)) + if rpred then + rpred_prove_preds thy' (pred_mode ~~ (flat ts)) + else + prove_preds thy' all_vs param_vs (extra_modes @ modes) clauses (pred_mode ~~ (flat ts)) val thy'' = fold (fn (name, result_thms) => fn thy => snd (PureThy.add_thmss [((Binding.qualify true (Long_Name.base_name name) (Binding.name "equation"), result_thms), [Attrib.attribute_i thy Code.add_default_eqn_attrib])] thy)) @@ -1475,7 +1620,7 @@ val scc = strong_conn_of (PredData.get thy') [name] val thy'' = fold_rev (fn preds => fn thy => - if forall (null o modes_of thy) preds then add_equations_of preds thy else thy) + if forall (null o modes_of thy) preds then add_equations_of true preds thy else thy) scc thy' |> Theory.checkpoint in thy'' end @@ -1563,7 +1708,7 @@ val user_mode = map_filter I (map_index (fn (i, t) => case t of Bound j => if j < length Ts then NONE else SOME (i+1) | _ => SOME (i+1)) args) (*FIXME dangling bounds should not occur*) - val (inargs, _) = get_args user_mode args; + val (inargs, _) = split_mode user_mode args; val modes = filter (fn Mode (_, is, _) => is = user_mode) (modes_of_term (all_modes_of thy) (list_comb (pred, params))); val m = case modes @@ -1572,7 +1717,7 @@ | [m] => m | m :: _ :: _ => (warning ("Multiple modes possible for comprehension " ^ Syntax.string_of_term_global thy t_compr); m); - val t_eval = list_comb (compile_expr thy (all_modes_of thy) (SOME m, list_comb (pred, params)), + val t_eval = list_comb (compile_expr thy PredicateCompFuns.compfuns (all_modes_of thy) (SOME m, list_comb (pred, params)), inargs) in t_eval end;