# HG changeset patch # User haftmann # Date 1248954738 -7200 # Node ID c8c17c2e6cebcadab63d9f11071e3b77e0872858 # Parent b4632820e74ce3b871cf6d2d5384650f1de85636 towards proper handling of argument order in comprehensions diff -r b4632820e74c -r c8c17c2e6ceb src/HOL/ex/predicate_compile.ML --- a/src/HOL/ex/predicate_compile.ML Thu Jul 30 13:52:18 2009 +0200 +++ b/src/HOL/ex/predicate_compile.ML Thu Jul 30 13:52:18 2009 +0200 @@ -82,9 +82,9 @@ | dest_tuple (Const (@{const_name Pair}, _) $ t1 $ t2) = t1 :: (dest_tuple t2) | dest_tuple t = [t] -fun mk_pred_enumT T = Type (@{type_name "Predicate.pred"}, [T]) +fun mk_pred_enumT T = Type (@{type_name Predicate.pred}, [T]) -fun dest_pred_enumT (Type (@{type_name "Predicate.pred"}, [T])) = T +fun dest_pred_enumT (Type (@{type_name Predicate.pred}, [T])) = T | dest_pred_enumT T = raise TYPE ("dest_pred_enumT", [T], []); fun mk_Enum f = @@ -119,6 +119,10 @@ fun mk_not_pred t = let val T = mk_pred_enumT HOLogic.unitT in Const (@{const_name Predicate.not_pred}, T --> T) $ t end +fun mk_pred_map T1 T2 tf tp = Const (@{const_name Predicate.map}, + (T1 --> T2) --> mk_pred_enumT T1 --> mk_pred_enumT T2) $ tf $ tp; + + (* destruction of intro rules *) (* FIXME: look for other place where this functionality was used before *) @@ -383,7 +387,7 @@ fun get_args is ts = let fun get_args' _ _ [] = ([], []) - | get_args' is i (t::ts) = (if i mem is then apfst else apsnd) (cons t) + | get_args' is i (t::ts) = (if member (op =) is i then apfst else apsnd) (cons t) (get_args' is (i+1) ts) in get_args' is 1 ts end @@ -1527,18 +1531,17 @@ val eval_ref = ref (NONE : (unit -> term Predicate.pred) option); +(*FIXME turn this into an LCF-guarded preprocessor for comprehensions*) fun analyze_compr thy t_compr = let val split = case t_compr of (Const (@{const_name Collect}, _) $ t) => t | _ => error ("Not a set comprehension: " ^ Syntax.string_of_term_global thy t_compr); val (body, Ts, fp) = HOLogic.strip_splits split; - (*FIXME former order of tuple positions must be restored*) - val (pred as Const (name, T), all_args) = strip_comb body - val (params, args) = chop (nparams_of thy name) all_args + val (pred as Const (name, T), all_args) = strip_comb body; + val (params, args) = chop (nparams_of thy name) all_args; val user_mode = map_filter I (map_index (fn (i, t) => case t of Bound j => if j < length Ts then NONE - else SOME (i+1) | _ => SOME (i+1)) args) (*FIXME dangling bounds should not occur*) - val (inargs, _) = get_args user_mode args; + else SOME (i+1) | _ => SOME (i+1)) args); (*FIXME dangling bounds should not occur*) val modes = filter (fn Mode (_, is, _) => is = user_mode) (modes_of_term (all_modes_of thy) (list_comb (pred, params))); val m = case modes @@ -1547,9 +1550,63 @@ | [m] => m | m :: _ :: _ => (warning ("Multiple modes possible for comprehension " ^ Syntax.string_of_term_global thy t_compr); m); - val t_eval = list_comb (compile_expr thy (all_modes_of thy) (SOME m, list_comb (pred, params)), - inargs) + val (inargs, outargs) = get_args user_mode args; + val t_pred = list_comb (compile_expr thy (all_modes_of thy) (SOME m, list_comb (pred, params)), + inargs); + val t_eval = if null outargs then t_pred else let + val outargs_bounds = map (fn Bound i => i) outargs; + val outargsTs = map (nth Ts) outargs_bounds; + val T_pred = mk_tupleT outargsTs; + val T_compr = HOLogic.mk_tupleT fp Ts; + val arrange_bounds = map_index I outargs_bounds + |> sort (prod_ord (K EQUAL) int_ord) + |> map fst; + val arrange = funpow (length outargs_bounds - 1) HOLogic.mk_split + (Term.list_abs (map (pair "") outargsTs, + HOLogic.mk_tuple fp T_compr (map Bound arrange_bounds))) + in mk_pred_map T_pred T_compr arrange t_pred end in t_eval end; +fun eval thy t_compr = + let + val t = analyze_compr thy t_compr; + val T = dest_pred_enumT (fastype_of t); + val t' = mk_pred_map T HOLogic.termT (HOLogic.term_of_const T) t; + in (T, Code_ML.eval NONE ("Predicate_Compile.eval_ref", eval_ref) Predicate.map thy t' []) end; + +fun values ctxt k t_compr = + let + val thy = ProofContext.theory_of ctxt; + val (T, t) = eval thy t_compr; + val setT = HOLogic.mk_setT T; + val (ts, _) = Predicate.yieldn k t; + val elemsT = HOLogic.mk_set T ts; + in if k = ~1 orelse length ts < k then elemsT + else Const (@{const_name Set.union}, setT --> setT --> setT) $ elemsT $ t_compr + end; + +fun values_cmd modes k raw_t state = + let + val ctxt = Toplevel.context_of state; + val t = Syntax.read_term ctxt raw_t; + val t' = values ctxt k t; + val ty' = Term.type_of t'; + val ctxt' = Variable.auto_fixes t' ctxt; + val p = PrintMode.with_modes modes (fn () => + Pretty.block [Pretty.quote (Syntax.pretty_term ctxt' t'), Pretty.fbrk, + Pretty.str "::", Pretty.brk 1, Pretty.quote (Syntax.pretty_typ ctxt' ty')]) (); + in Pretty.writeln p end; + +local structure P = OuterParse in + +val opt_modes = Scan.optional (P.$$$ "(" |-- P.!!! (Scan.repeat1 P.xname --| P.$$$ ")")) []; + +val _ = OuterSyntax.improper_command "values" "enumerate and print comprehensions" OuterKeyword.diag + (opt_modes -- Scan.optional P.nat ~1 -- P.term + >> (fn ((modes, k), t) => Toplevel.no_timing o Toplevel.keep + (values_cmd modes k t))); + end; +end; +