# HG changeset patch # User bulwahn # Date 1269243012 -3600 # Node ID b0d24a74b06bec750b6d105d359acfbe89789866 # Parent bcfa6b4b21c65cd6d6aa51331293b34865d563f0 restructuring function flattening diff -r bcfa6b4b21c6 -r b0d24a74b06b src/HOL/Tools/Predicate_Compile/predicate_compile.ML --- a/src/HOL/Tools/Predicate_Compile/predicate_compile.ML Mon Mar 22 08:30:12 2010 +0100 +++ b/src/HOL/Tools/Predicate_Compile/predicate_compile.ML Mon Mar 22 08:30:12 2010 +0100 @@ -111,7 +111,8 @@ val intross5 = map (fn (s, ths) => (overload_const thy''' s, map (AxClass.overload thy''') ths)) intross4 val intross6 = map_specs (map (expand_tuples thy''')) intross5 - val _ = print_intross options thy''' "introduction rules before registering: " intross6 + val intross7 = map_specs (map (eta_contract_ho_arguments thy''')) intross6 + val _ = print_intross options thy''' "introduction rules before registering: " intross7 val _ = print_step options "Registering introduction rules..." val thy'''' = fold Predicate_Compile_Core.register_intros intross6 thy''' in diff -r bcfa6b4b21c6 -r b0d24a74b06b src/HOL/Tools/Predicate_Compile/predicate_compile_aux.ML --- a/src/HOL/Tools/Predicate_Compile/predicate_compile_aux.ML Mon Mar 22 08:30:12 2010 +0100 +++ b/src/HOL/Tools/Predicate_Compile/predicate_compile_aux.ML Mon Mar 22 08:30:12 2010 +0100 @@ -399,6 +399,18 @@ Logic.list_implies (maps f premises, head) end +fun map_concl f intro = + let + val (premises, head) = Logic.strip_horn intro + in + Logic.list_implies (premises, f head) + end + +(* combinators to apply a function to all basic parts of nested products *) + +fun map_products f (Const ("Pair", T) $ t1 $ t2) = + Const ("Pair", T) $ map_products f t1 $ map_products f t2 + | map_products f t = f t (* split theorems of case expressions *) @@ -619,4 +631,15 @@ intro''''' end +(* eta contract higher-order arguments *) + + +fun eta_contract_ho_arguments thy intro = + let + fun f atom = list_comb (apsnd ((map o map_products) Envir.eta_contract) (strip_comb atom)) + in + map_term thy (map_concl f o map_atoms f) intro + end + + end; diff -r bcfa6b4b21c6 -r b0d24a74b06b src/HOL/Tools/Predicate_Compile/predicate_compile_fun.ML --- a/src/HOL/Tools/Predicate_Compile/predicate_compile_fun.ML Mon Mar 22 08:30:12 2010 +0100 +++ b/src/HOL/Tools/Predicate_Compile/predicate_compile_fun.ML Mon Mar 22 08:30:12 2010 +0100 @@ -37,7 +37,8 @@ in SOME (Envir.subst_term subst p) end - | _ => error ("Multiple matches possible for lookup of " ^ Syntax.string_of_term_global thy t) + | _ => NONE + (*_ => error ("Multiple matches possible for lookup of " ^ Syntax.string_of_term_global thy t)*) fun pred_of_function thy name = case Item_Net.retrieve (Fun_Pred.get thy) (Const (name, Term.dummyT)) of @@ -119,9 +120,8 @@ SOME (c, _) => Predicate_Compile_Data.keep_function thy c | _ => false -fun flatten thy lookup_pred t (names, prems) = - let - fun flatten' (t as Const (name, T)) (names, prems) = +(* dump: +fun flatten' (t as Const (name, T)) (names, prems) = (if is_constr thy name orelse (is_none (lookup_pred t)) then [(t, (names, prems))] else @@ -163,7 +163,55 @@ in [(t, (names, prems))] end - | flatten' t (names, prems) = +*) + +fun flatten thy lookup_pred t (names, prems) = + let + fun lift t (names, prems) = + case lookup_pred (Envir.eta_contract t) of + SOME pred => [(pred, (names, prems))] + | NONE => + let + val (vars, body) = strip_abs t + val _ = assert (fastype_of body = body_type (fastype_of body)) + val absnames = Name.variant_list names (map fst vars) + val frees = map2 (curry Free) absnames (map snd vars) + val body' = subst_bounds (rev frees, body) + val resname = Name.variant (absnames @ names) "res" + val resvar = Free (resname, fastype_of body) + val t = flatten' body' ([], []) + |> map (fn (res, (inner_names, inner_prems)) => + let + fun mk_exists (x, T) t = HOLogic.mk_exists (x, T, t) + val vTs = + fold Term.add_frees inner_prems [] + |> filter (fn (x, T) => member (op =) inner_names x) + val t = + fold mk_exists vTs + (foldr1 HOLogic.mk_conj (HOLogic.mk_eq (res, resvar) :: + map HOLogic.dest_Trueprop inner_prems)) + in + t + end) + |> foldr1 HOLogic.mk_disj + |> fold lambda (resvar :: rev frees) + in + [(t, (names, prems))] + end + and flatten_or_lift (t, T) (names, prems) = + if fastype_of t = T then + flatten' t (names, prems) + else + (* note pred_type might be to general! *) + if (pred_type (fastype_of t) = T) then + lift t (names, prems) + else + error ("unexpected input for flatten or lift" ^ Syntax.string_of_term_global thy t ^ + ", " ^ Syntax.string_of_typ_global thy T) + and flatten' (t as Const (name, T)) (names, prems) = [(t, (names, prems))] + | flatten' (t as Free (f, T)) (names, prems) = [(t, (names, prems))] + | flatten' (t as Abs _) (names, prems) = [(t, (names, prems))] + | flatten' (t as _ $ _) (names, prems) = if Predicate_Compile_Aux.is_constrt thy t orelse keep_functions thy t then [(t, (names, prems))] else @@ -172,11 +220,14 @@ (let val (_, [B, x, y]) = strip_comb t in - (flatten' x (names, prems) - |> map (fn (res, (names, prems)) => (res, (names, (HOLogic.mk_Trueprop B) :: prems)))) - @ (flatten' y (names, prems) - |> map (fn (res, (names, prems)) => - (res, (names, (HOLogic.mk_Trueprop (HOLogic.mk_not B)) :: prems)))) + flatten' B (names, prems) + |> maps (fn (B', (names, prems)) => + (flatten' x (names, prems) + |> map (fn (res, (names, prems)) => (res, (names, (HOLogic.mk_Trueprop B') :: prems)))) + @ (flatten' y (names, prems) + |> map (fn (res, (names, prems)) => + (* in general unsound! *) + (res, (names, (HOLogic.mk_Trueprop (HOLogic.mk_not B')) :: prems))))) end) | Const (@{const_name "Let"}, _) => (let @@ -232,57 +283,47 @@ else let val (f, args) = strip_comb t - (* TODO: special procedure for higher-order functions: split arguments in - simple types and function types *) val args = map (Pattern.eta_long []) args - val resname = Name.variant names "res" - val resvar = Free (resname, body_type (fastype_of t)) val _ = assert (fastype_of t = body_type (fastype_of t)) - val names' = resname :: names - val t' = lookup_pred f - val Ts = case t' of + val f' = lookup_pred f + val Ts = case f' of SOME pred => (fst (split_last (binder_types (fastype_of pred)))) - | NONE => binder_types (fastype_of t) - val namesprems = - case t' of - NONE => - folds_map flatten' args (names', prems) |> - map - (fn (argvs, (names'', prems')) => - let - val prem = HOLogic.mk_Trueprop (HOLogic.mk_eq (resvar, list_comb (f, argvs))) - in (names'', prem :: prems') end) - | SOME pred => - folds_map flatten' args (names', prems) - |> map (fn (argvs, (names'', prems')) => - let - fun lift_arg T t = - if (fastype_of t) = T then t - else - let - val _ = assert (T = - (binder_types (fastype_of t) @ [@{typ bool}] ---> @{typ bool})) - fun mk_if T (b, t, e) = - Const (@{const_name If}, @{typ bool} --> T --> T --> T) $ b $ t $ e - val Ts = binder_types (fastype_of t) - val t = - list_abs (map (pair "x") Ts @ [("b", @{typ bool})], - mk_if @{typ bool} (list_comb (t, map Bound (length Ts downto 1)), - HOLogic.mk_eq (@{term True}, Bound 0), - HOLogic.mk_eq (@{term False}, Bound 0))) - in - t - end - (*val _ = tracing ("Ts: " ^ commas (map (Syntax.string_of_typ_global thy) Ts)) - val _ = map2 check_arity Ts (map fastype_of (argvs @ [resvar]))*) - val argvs' = map2 lift_arg Ts argvs - val prem = HOLogic.mk_Trueprop (list_comb (pred, argvs' @ [resvar])) - in (names'', prem :: prems') end) + | NONE => binder_types (fastype_of f) in - map (pair resvar) namesprems + folds_map flatten_or_lift (args ~~ Ts) (names, prems) |> + (case f' of + NONE => + map (fn (argvs, (names', prems')) => (list_comb (f, argvs), (names', prems'))) + | SOME pred => + map (fn (argvs, (names', prems')) => + let + fun lift_arg T t = + if (fastype_of t) = T then t + else + let + val _ = assert (T = + (binder_types (fastype_of t) @ [@{typ bool}] ---> @{typ bool})) + fun mk_if T (b, t, e) = + Const (@{const_name If}, @{typ bool} --> T --> T --> T) $ b $ t $ e + val Ts = binder_types (fastype_of t) + val t = + list_abs (map (pair "x") Ts @ [("b", @{typ bool})], + mk_if @{typ bool} (list_comb (t, map Bound (length Ts downto 1)), + HOLogic.mk_eq (@{term True}, Bound 0), + HOLogic.mk_eq (@{term False}, Bound 0))) + in + t + end + (*val _ = tracing ("Ts: " ^ commas (map (Syntax.string_of_typ_global thy) Ts)) + val _ = map2 check_arity Ts (map fastype_of (argvs @ [resvar]))*) + val argvs' = map2 lift_arg Ts argvs + val resname = Name.variant names' "res" + val resvar = Free (resname, body_type (fastype_of t)) + val prem = HOLogic.mk_Trueprop (list_comb (pred, argvs' @ [resvar])) + in (resvar, (resname :: names', prem :: prems')) end)) end in - flatten' (Pattern.eta_long [] t) (names, prems) + map (apfst Envir.eta_contract) (flatten' (Pattern.eta_long [] t) (names, prems)) end; (* assumption: mutual recursive predicates all have the same parameters. *) @@ -373,12 +414,6 @@ let (*val _ = tracing ("Rewriting intro with registered mapping for: " ^ commas (Symtab.keys (Pred_Compile_Preproc.get thy)))*) - (*fun lookup_pred (Const (name, T)) = - (case (Symtab.lookup (Pred_Compile_Preproc.get thy) name) of - SOME c => SOME (Const (c, pred_type T)) - | NONE => NONE) - | lookup_pred _ = NONE - *) fun lookup_pred t = lookup thy (Fun_Pred.get thy) t val intro_t = Logic.unvarify_global (prop_of intro) val (prems, concl) = Logic.strip_horn intro_t