diff -r 997aa3a3e4bb -r c9f428269b38 src/HOL/Tools/Predicate_Compile/predicate_compile_fun.ML --- a/src/HOL/Tools/Predicate_Compile/predicate_compile_fun.ML Tue Feb 23 10:02:14 2010 +0100 +++ b/src/HOL/Tools/Predicate_Compile/predicate_compile_fun.ML Tue Feb 23 13:36:15 2010 +0100 @@ -9,6 +9,8 @@ val define_predicates : (string * thm list) list -> theory -> (string * thm list) list * theory val rewrite_intro : theory -> thm -> thm list val pred_of_function : theory -> string -> string option + + val add_function_predicate_translation : (term * term) -> theory -> theory end; structure Predicate_Compile_Fun : PREDICATE_COMPILE_FUN = @@ -16,19 +18,36 @@ open Predicate_Compile_Aux; -(* Table from constant name (string) to term of inductive predicate *) -structure Pred_Compile_Preproc = Theory_Data +(* Table from function to inductive predicate *) +structure Fun_Pred = Theory_Data ( - type T = string Symtab.table; - val empty = Symtab.empty; + type T = (term * term) Item_Net.T; + val empty = Item_Net.init (op aconv o pairself fst) (single o fst); val extend = I; - fun merge data : T = Symtab.merge (op =) data; (* FIXME handle Symtab.DUP ?? *) + val merge = Item_Net.merge; ) -fun pred_of_function thy name = Symtab.lookup (Pred_Compile_Preproc.get thy) name +fun lookup thy net t = + case Item_Net.retrieve net t of + [] => NONE + | [(f, p)] => + let + val subst = Pattern.match thy (f, t) (Vartab.empty, Vartab.empty) + in + SOME (Envir.subst_term subst p) + end + | _ => error ("Multiple matches possible for lookup of " ^ Syntax.string_of_term_global thy t) -fun defined thy = Symtab.defined (Pred_Compile_Preproc.get thy) +fun pred_of_function thy name = + case Item_Net.retrieve (Fun_Pred.get thy) (Const (name, Term.dummyT)) of + [] => NONE + | [(f, p)] => SOME (fst (dest_Const p)) + | _ => error ("Multiple matches possible for lookup of constant " ^ name) +fun defined_const thy name = is_some (pred_of_function thy name) + +fun add_function_predicate_translation (f, p) = + Fun_Pred.map (Item_Net.update (f, p)) fun transform_ho_typ (T as Type ("fun", _)) = let @@ -63,27 +82,6 @@ (Free (Long_Name.base_name name ^ "P", pred_type T)) end -fun mk_param thy lookup_pred (t as Free (v, _)) = lookup_pred t - | mk_param thy lookup_pred t = - if Predicate_Compile_Aux.is_predT (fastype_of t) then - t - else - let - val (vs, body) = strip_abs t - val names = Term.add_free_names body [] - val vs_names = Name.variant_list names (map fst vs) - val vs' = map2 (curry Free) vs_names (map snd vs) - val body' = subst_bounds (rev vs', body) - val (f, args) = strip_comb body' - val resname = Name.variant (vs_names @ names) "res" - val resvar = Free (resname, body_type (fastype_of body')) - (*val P = case try lookup_pred f of SOME P => P | NONE => error "mk_param" - val pred_body = list_comb (P, args @ [resvar]) - *) - val pred_body = HOLogic.mk_eq (body', resvar) - val param = fold_rev lambda (vs' @ [resvar]) pred_body - in param end - (* creates the list of premises for every intro rule *) (* theory -> term -> (string list, term list list) *) @@ -92,22 +90,6 @@ val (func, args) = strip_comb lhs in ((func, args), rhs) end; -fun string_of_typ T = Syntax.string_of_typ_global @{theory} T - -fun string_of_term t = - case t of - Const (c, T) => "Const (" ^ c ^ ", " ^ string_of_typ T ^ ")" - | Free (c, T) => "Free (" ^ c ^ ", " ^ string_of_typ T ^ ")" - | Var ((c, i), T) => "Var ((" ^ c ^ ", " ^ string_of_int i ^ "), " ^ string_of_typ T ^ ")" - | Bound i => "Bound " ^ string_of_int i - | Abs (x, T, t) => "Abs (" ^ x ^ ", " ^ string_of_typ T ^ ", " ^ string_of_term t ^ ")" - | t1 $ t2 => "(" ^ string_of_term t1 ^ ") $ (" ^ string_of_term t2 ^ ")" - -fun ind_package_get_nparams thy name = - case try (Inductive.the_inductive (ProofContext.init thy)) name of - SOME (_, result) => length (Inductive.params_of (#raw_induct result)) - | NONE => error ("No such predicate: " ^ quote name) - (* TODO: does not work with higher order functions yet *) fun mk_rewr_eq (func, pred) = let @@ -122,49 +104,6 @@ (HOLogic.mk_eq (res, list_comb (func, args)), list_comb (pred, args @ [res])) end; -fun has_split_rule_cname @{const_name "nat_case"} = true - | has_split_rule_cname @{const_name "list_case"} = true - | has_split_rule_cname _ = false - -fun has_split_rule_term thy (Const (@{const_name "nat_case"}, _)) = true - | has_split_rule_term thy (Const (@{const_name "list_case"}, _)) = true - | has_split_rule_term thy _ = false - -fun has_split_rule_term' thy (Const (@{const_name "If"}, _)) = true - | has_split_rule_term' thy (Const (@{const_name "Let"}, _)) = true - | has_split_rule_term' thy c = has_split_rule_term thy c - -fun prepare_split_thm ctxt split_thm = - (split_thm RS @{thm iffD2}) - |> LocalDefs.unfold ctxt [@{thm atomize_conjL[symmetric]}, - @{thm atomize_all[symmetric]}, @{thm atomize_imp[symmetric]}] - -fun find_split_thm thy (Const (name, typ)) = - let - fun split_name str = - case first_field "." str - of (SOME (field, rest)) => field :: split_name rest - | NONE => [str] - val splitted_name = split_name name - in - if length splitted_name > 0 andalso - String.isSuffix "_case" (List.last splitted_name) - then - (List.take (splitted_name, length splitted_name - 1)) @ ["split"] - |> space_implode "." - |> PureThy.get_thm thy - |> SOME - handle ERROR msg => NONE - else NONE - end - | find_split_thm _ _ = NONE - -fun find_split_thm' thy (Const (@{const_name "If"}, _)) = SOME @{thm split_if} - | find_split_thm' thy (Const (@{const_name "Let"}, _)) = SOME @{thm refl} (* TODO *) - | find_split_thm' thy c = find_split_thm thy c - -fun strip_all t = (Term.strip_all_vars t, Term.strip_all_body t) - fun folds_map f xs y = let fun folds_map' acc [] y = [(rev acc, y)] @@ -174,23 +113,91 @@ folds_map' [] xs y end; -fun mk_prems thy (lookup_pred, get_nparams) t (names, prems) = +fun keep_functions thy t = + case try dest_Const (fst (strip_comb t)) of + SOME (c, _) => Predicate_Compile_Data.keep_function thy c + | _ => false + +fun mk_prems thy lookup_pred t (names, prems) = let fun mk_prems' (t as Const (name, T)) (names, prems) = - if is_constr thy name orelse (is_none (try lookup_pred t)) then + (if is_constr thy name orelse (is_none (lookup_pred t)) then [(t, (names, prems))] - else [(lookup_pred t, (names, prems))] + else + (*(if is_none (try lookup_pred t) then + [(Abs ("uu", fastype_of t, HOLogic.mk_eq (t, Bound 0)), (names, prems))] + else*) [(the (lookup_pred t), (names, prems))]) | mk_prems' (t as Free (f, T)) (names, prems) = - [(lookup_pred t, (names, prems))] + (case lookup_pred t of + SOME t' => [(t', (names, prems))] + | NONE => [(t, (names, prems))]) | mk_prems' (t as Abs _) (names, prems) = if Predicate_Compile_Aux.is_predT (fastype_of t) then - [(t, (names, prems))] else error "mk_prems': Abs " - (* mk_param *) + ([(Envir.eta_contract t, (names, prems))]) + else + let + val (vars, body) = strip_abs t + val _ = assert (fastype_of body = body_type (fastype_of body)) + val absnames = Name.variant_list names (map fst vars) + val frees = map2 (curry Free) absnames (map snd vars) + val body' = subst_bounds (rev frees, body) + val resname = Name.variant (absnames @ names) "res" + val resvar = Free (resname, fastype_of body) + val t = mk_prems' body' ([], []) + |> map (fn (res, (inner_names, inner_prems)) => + let + fun mk_exists (x, T) t = HOLogic.mk_exists (x, T, t) + val vTs = + fold Term.add_frees inner_prems [] + |> filter (fn (x, T) => member (op =) inner_names x) + val t = + fold mk_exists vTs + (foldr1 HOLogic.mk_conj (HOLogic.mk_eq (resvar, res) :: + map HOLogic.dest_Trueprop inner_prems)) + in + t + end) + |> foldr1 HOLogic.mk_disj + |> fold lambda (resvar :: rev frees) + in + [(t, (names, prems))] + end | mk_prems' t (names, prems) = - if Predicate_Compile_Aux.is_constrt thy t then + if Predicate_Compile_Aux.is_constrt thy t orelse keep_functions thy t then [(t, (names, prems))] else - if has_split_rule_term' thy (fst (strip_comb t)) then + case (fst (strip_comb t)) of + Const (@{const_name "If"}, _) => + (let + val (_, [B, x, y]) = strip_comb t + in + (mk_prems' x (names, prems) + |> map (fn (res, (names, prems)) => (res, (names, (HOLogic.mk_Trueprop B) :: prems)))) + @ (mk_prems' y (names, prems) + |> map (fn (res, (names, prems)) => + (res, (names, (HOLogic.mk_Trueprop (HOLogic.mk_not B)) :: prems)))) + end) + | Const (@{const_name "Let"}, _) => + (let + val (_, [f, g]) = strip_comb t + in + mk_prems' f (names, prems) + |> maps (fn (res, (names, prems)) => + mk_prems' (betapply (g, res)) (names, prems)) + end) + | Const (@{const_name "split"}, _) => + (let + val (_, [g, res]) = strip_comb t + val [res1, res2] = Name.variant_list names ["res1", "res2"] + val (T1, T2) = HOLogic.dest_prodT (fastype_of res) + val (resv1, resv2) = (Free (res1, T1), Free (res2, T2)) + in + mk_prems' (betapplys (g, [resv1, resv2])) + (res1 :: res2 :: names, + HOLogic.mk_Trueprop (HOLogic.mk_eq (res, HOLogic.mk_prod (resv1, resv2))) :: prems) + end) + | _ => + if has_split_thm thy (fst (strip_comb t)) then let val (f, args) = strip_comb t val split_thm = prepare_split_thm (ProofContext.init thy) (the (find_split_thm' thy f)) @@ -208,8 +215,15 @@ val vars = map Free (var_names ~~ (map snd vTs)) val (prems', pre_res) = Logic.strip_horn (subst_bounds (rev vars, assm')) val (_, [inner_t]) = strip_comb (HOLogic.dest_Trueprop pre_res) + val (lhss : term list, rhss) = + split_list (map (HOLogic.dest_eq o HOLogic.dest_Trueprop) prems') in - mk_prems' inner_t (var_names @ names, prems' @ prems) + folds_map mk_prems' lhss (var_names @ names, prems) + |> map (fn (ress, (names, prems)) => + let + val prems' = map (HOLogic.mk_Trueprop o HOLogic.mk_eq) (ress ~~ rhss) + in (names, prems' @ prems) end) + |> maps (mk_prems' inner_t) end in maps mk_prems_of_assm assms @@ -219,53 +233,77 @@ val (f, args) = strip_comb t (* TODO: special procedure for higher-order functions: split arguments in simple types and function types *) + val args = map (Pattern.eta_long []) args val resname = Name.variant names "res" val resvar = Free (resname, body_type (fastype_of t)) + val _ = assert (fastype_of t = body_type (fastype_of t)) val names' = resname :: names fun mk_prems'' (t as Const (c, _)) = - if is_constr thy c orelse (is_none (try lookup_pred t)) then + if is_constr thy c orelse (is_none (lookup_pred t)) then + let + val _ = ()(*tracing ("not translating function " ^ Syntax.string_of_term_global thy t)*) + in folds_map mk_prems' args (names', prems) |> map (fn (argvs, (names'', prems')) => let val prem = HOLogic.mk_Trueprop (HOLogic.mk_eq (resvar, list_comb (f, argvs))) in (names'', prem :: prems') end) + end else let - val pred = lookup_pred t - val nparams = get_nparams pred - val (params, args) = chop nparams args - val params' = map (mk_param thy lookup_pred) params + (* lookup_pred is falsch für polymorphe Argumente und bool. *) + val pred = the (lookup_pred t) + val Ts = binder_types (fastype_of pred) in folds_map mk_prems' args (names', prems) |> map (fn (argvs, (names'', prems')) => let - val prem = HOLogic.mk_Trueprop (list_comb (pred, params' @ argvs @ [resvar])) + fun lift_arg T t = + if (fastype_of t) = T then t + else + let + val _ = assert (T = + (binder_types (fastype_of t) @ [@{typ bool}] ---> @{typ bool})) + fun mk_if T (b, t, e) = + Const (@{const_name If}, @{typ bool} --> T --> T --> T) $ b $ t $ e + val Ts = binder_types (fastype_of t) + val t = + list_abs (map (pair "x") Ts @ [("b", @{typ bool})], + mk_if @{typ bool} (list_comb (t, map Bound (length Ts downto 1)), + HOLogic.mk_eq (@{term True}, Bound 0), + HOLogic.mk_eq (@{term False}, Bound 0))) + in + t + end + (*val _ = tracing ("Ts: " ^ commas (map (Syntax.string_of_typ_global thy) Ts)) + val _ = map2 check_arity Ts (map fastype_of (argvs @ [resvar]))*) + val argvs' = map2 lift_arg (fst (split_last Ts)) argvs + val prem = HOLogic.mk_Trueprop (list_comb (pred, argvs' @ [resvar])) in (names'', prem :: prems') end) end | mk_prems'' (t as Free (_, _)) = - let - (* higher order argument call *) - val pred = lookup_pred t - in - folds_map mk_prems' args (resname :: names, prems) - |> map (fn (argvs, (names', prems')) => - let - val prem = HOLogic.mk_Trueprop (list_comb (pred, argvs @ [resvar])) - in (names', prem :: prems') end) - end + folds_map mk_prems' args (names', prems) |> + map + (fn (argvs, (names'', prems')) => + let + val prem = + case lookup_pred t of + NONE => HOLogic.mk_Trueprop (HOLogic.mk_eq (resvar, list_comb (f, argvs))) + | SOME p => HOLogic.mk_Trueprop (list_comb (p, argvs @ [resvar])) + in (names'', prem :: prems') end) | mk_prems'' t = error ("Invalid term: " ^ Syntax.string_of_term_global thy t) in map (pair resvar) (mk_prems'' f) end in - mk_prems' t (names, prems) + mk_prems' (Pattern.eta_long [] t) (names, prems) end; (* assumption: mutual recursive predicates all have the same parameters. *) fun define_predicates specs thy = - if forall (fn (const, _) => member (op =) (Symtab.keys (Pred_Compile_Preproc.get thy)) const) specs then + if forall (fn (const, _) => defined_const thy const) specs then ([], thy) else let @@ -275,36 +313,20 @@ (* create prednames *) val ((funs, argss), rhss) = map_split dest_code_eqn eqns |>> split_list val argss' = map (map transform_ho_arg) argss - val pnames = map dest_Free (distinct (op =) (maps (filter (is_funtype o fastype_of)) argss')) + (* TODO: higher order arguments also occur in tuples! *) + val ho_argss = distinct (op =) (maps (filter (is_funtype o fastype_of)) argss) + val params = distinct (op =) (maps (filter (is_funtype o fastype_of)) argss') + val pnames = map dest_Free params val preds = map pred_of funs val prednames = map (fst o dest_Free) preds val funnames = map (fst o dest_Const) funs val fun_pred_names = (funnames ~~ prednames) (* mapping from term (Free or Const) to term *) - fun lookup_pred (Const (name, T)) = - (case (Symtab.lookup (Pred_Compile_Preproc.get thy) name) of - SOME c => Const (c, pred_type T) - | NONE => - (case AList.lookup op = fun_pred_names name of - SOME f => Free (f, pred_type T) - | NONE => Const (name, T))) - | lookup_pred (Free (name, T)) = - if member op = (map fst pnames) name then - Free (name, transform_ho_typ T) - else - Free (name, T) - | lookup_pred t = - error ("lookup function is not defined for " ^ Syntax.string_of_term_global thy t) - - (* mapping from term (predicate term, not function term!) to int *) - fun get_nparams (Const (name, _)) = - the_default 0 (try (ind_package_get_nparams thy) name) - | get_nparams (Free (name, _)) = - (if member op = prednames name then - length pnames - else 0) - | get_nparams t = error ("No parameters for " ^ (Syntax.string_of_term_global thy t)) - + fun map_Free f = Free o f o dest_Free + val net = fold Item_Net.update + ((funs ~~ preds) @ (ho_argss ~~ params)) + (Fun_Pred.get thy) + fun lookup_pred t = lookup thy net t (* create intro rules *) fun mk_intros ((func, pred), (args, rhs)) = @@ -314,14 +336,15 @@ else let val names = Term.add_free_names rhs [] - in mk_prems thy (lookup_pred, get_nparams) rhs (names, []) + in mk_prems thy lookup_pred rhs (names, []) |> map (fn (resultt, (names', prems)) => Logic.list_implies (prems, HOLogic.mk_Trueprop (list_comb (pred, args @ [resultt])))) end fun mk_rewr_thm (func, pred) = @{thm refl} in - case try (maps mk_intros) ((funs ~~ preds) ~~ (argss' ~~ rhss)) of - NONE => ([], thy) + case (*try *)SOME (maps mk_intros ((funs ~~ preds) ~~ (argss' ~~ rhss))) of + NONE => + let val _ = tracing "error occured!" in ([], thy) end | SOME intr_ts => if is_some (try (map (cterm_of thy)) intr_ts) then let @@ -333,53 +356,59 @@ no_elim = false, no_ind = false, skip_mono = false, fork_mono = false} (map (fn (s, T) => ((Binding.name s, T), NoSyn)) (distinct (op =) (map dest_Free preds))) - pnames + [] (map (fn x => (Attrib.empty_binding, x)) intr_ts) [] ||> Sign.restore_naming thy val prednames = map (fst o dest_Const) (#preds ind_result) (* val rewr_thms = map mk_rewr_eq ((distinct (op =) funs) ~~ (#preds ind_result)) *) (* add constants to my table *) + val specs = map (fn predname => (predname, filter (Predicate_Compile_Aux.is_intro predname) (#intrs ind_result))) prednames + (* val thy'' = Pred_Compile_Preproc.map (fold Symtab.update_new (consts ~~ prednames)) thy' + *) + + val thy'' = Fun_Pred.map + (fold Item_Net.update (map (apfst Logic.varify) + (distinct (op =) funs ~~ (#preds ind_result)))) thy' + (*val _ = print_specs thy'' specs*) in (specs, thy'') end else let - val _ = tracing "Introduction rules of function_predicate are not welltyped" + val _ = Output.tracing ( + "Introduction rules of function_predicate are not welltyped: " ^ + commas (map (Syntax.string_of_term_global thy) intr_ts)) in ([], thy) end end fun rewrite_intro thy intro = let - fun lookup_pred (Const (name, T)) = + (*val _ = tracing ("Rewriting intro with registered mapping for: " ^ + commas (Symtab.keys (Pred_Compile_Preproc.get thy)))*) + (*fun lookup_pred (Const (name, T)) = (case (Symtab.lookup (Pred_Compile_Preproc.get thy) name) of - SOME c => Const (c, pred_type T) - | NONE => error ("Function " ^ name ^ " is not inductified")) - | lookup_pred (Free (name, T)) = Free (name, T) - | lookup_pred _ = error "lookup function is not defined!" - - fun get_nparams (Const (name, _)) = - the_default 0 (try (ind_package_get_nparams thy) name) - | get_nparams (Free _) = 0 - | get_nparams t = error ("No parameters for " ^ (Syntax.string_of_term_global thy t)) - + SOME c => SOME (Const (c, pred_type T)) + | NONE => NONE) + | lookup_pred _ = NONE + *) + fun lookup_pred t = lookup thy (Fun_Pred.get thy) t val intro_t = (Logic.unvarify o prop_of) intro val (prems, concl) = Logic.strip_horn intro_t val frees = map fst (Term.add_frees intro_t []) fun rewrite prem names = let + (*val _ = tracing ("Rewriting premise " ^ Syntax.string_of_term_global thy prem ^ "...")*) val t = (HOLogic.dest_Trueprop prem) val (lit, mk_lit) = case try HOLogic.dest_not t of SOME t => (t, HOLogic.mk_not) | NONE => (t, I) - val (P, args) = (strip_comb lit) + val (P, args) = (strip_comb lit) in - folds_map ( - fn t => if (is_funtype (fastype_of t)) then (fn x => [(t, x)]) - else mk_prems thy (lookup_pred, get_nparams) t) args (names, []) + folds_map (mk_prems thy lookup_pred) args (names, []) |> map (fn (resargs, (names', prems')) => let val prem' = HOLogic.mk_Trueprop (mk_lit (list_comb (P, resargs)))