towards proper handling of argument order in comprehensions
authorhaftmann
Thu, 30 Jul 2009 13:52:18 +0200
changeset 32341 c8c17c2e6ceb
parent 32340 b4632820e74c
child 32342 3fabf5b5fc83
towards proper handling of argument order in comprehensions
src/HOL/ex/predicate_compile.ML
--- a/src/HOL/ex/predicate_compile.ML	Thu Jul 30 13:52:18 2009 +0200
+++ b/src/HOL/ex/predicate_compile.ML	Thu Jul 30 13:52:18 2009 +0200
@@ -82,9 +82,9 @@
   | dest_tuple (Const (@{const_name Pair}, _) $ t1 $ t2) = t1 :: (dest_tuple t2)
   | dest_tuple t = [t]
 
-fun mk_pred_enumT T = Type (@{type_name "Predicate.pred"}, [T])
+fun mk_pred_enumT T = Type (@{type_name Predicate.pred}, [T])
 
-fun dest_pred_enumT (Type (@{type_name "Predicate.pred"}, [T])) = T
+fun dest_pred_enumT (Type (@{type_name Predicate.pred}, [T])) = T
   | dest_pred_enumT T = raise TYPE ("dest_pred_enumT", [T], []);
 
 fun mk_Enum f =
@@ -119,6 +119,10 @@
 fun mk_not_pred t = let val T = mk_pred_enumT HOLogic.unitT
   in Const (@{const_name Predicate.not_pred}, T --> T) $ t end
 
+fun mk_pred_map T1 T2 tf tp = Const (@{const_name Predicate.map},
+  (T1 --> T2) --> mk_pred_enumT T1 --> mk_pred_enumT T2) $ tf $ tp;
+
+
 (* destruction of intro rules *)
 
 (* FIXME: look for other place where this functionality was used before *)
@@ -383,7 +387,7 @@
 
 fun get_args is ts = let
   fun get_args' _ _ [] = ([], [])
-    | get_args' is i (t::ts) = (if i mem is then apfst else apsnd) (cons t)
+    | get_args' is i (t::ts) = (if member (op =) is i then apfst else apsnd) (cons t)
         (get_args' is (i+1) ts)
 in get_args' is 1 ts end
 
@@ -1527,18 +1531,17 @@
 
 val eval_ref = ref (NONE : (unit -> term Predicate.pred) option);
 
+(*FIXME turn this into an LCF-guarded preprocessor for comprehensions*)
 fun analyze_compr thy t_compr =
   let
     val split = case t_compr of (Const (@{const_name Collect}, _) $ t) => t
       | _ => error ("Not a set comprehension: " ^ Syntax.string_of_term_global thy t_compr);
     val (body, Ts, fp) = HOLogic.strip_splits split;
-      (*FIXME former order of tuple positions must be restored*)
-    val (pred as Const (name, T), all_args) = strip_comb body
-    val (params, args) = chop (nparams_of thy name) all_args
+    val (pred as Const (name, T), all_args) = strip_comb body;
+    val (params, args) = chop (nparams_of thy name) all_args;
     val user_mode = map_filter I (map_index
       (fn (i, t) => case t of Bound j => if j < length Ts then NONE
-        else SOME (i+1) | _ => SOME (i+1)) args) (*FIXME dangling bounds should not occur*)
-    val (inargs, _) = get_args user_mode args;
+        else SOME (i+1) | _ => SOME (i+1)) args); (*FIXME dangling bounds should not occur*)
     val modes = filter (fn Mode (_, is, _) => is = user_mode)
       (modes_of_term (all_modes_of thy) (list_comb (pred, params)));
     val m = case modes
@@ -1547,9 +1550,63 @@
       | [m] => m
       | m :: _ :: _ => (warning ("Multiple modes possible for comprehension "
                 ^ Syntax.string_of_term_global thy t_compr); m);
-    val t_eval = list_comb (compile_expr thy (all_modes_of thy) (SOME m, list_comb (pred, params)),
-      inargs)
+    val (inargs, outargs) = get_args user_mode args;
+    val t_pred = list_comb (compile_expr thy (all_modes_of thy) (SOME m, list_comb (pred, params)),
+      inargs);
+    val t_eval = if null outargs then t_pred else let
+        val outargs_bounds = map (fn Bound i => i) outargs;
+        val outargsTs = map (nth Ts) outargs_bounds;
+        val T_pred = mk_tupleT outargsTs;
+        val T_compr = HOLogic.mk_tupleT fp Ts;
+        val arrange_bounds = map_index I outargs_bounds
+          |> sort (prod_ord (K EQUAL) int_ord)
+          |> map fst;
+        val arrange = funpow (length outargs_bounds - 1) HOLogic.mk_split
+          (Term.list_abs (map (pair "") outargsTs,
+            HOLogic.mk_tuple fp T_compr (map Bound arrange_bounds)))
+      in mk_pred_map T_pred T_compr arrange t_pred end
   in t_eval end;
 
+fun eval thy t_compr =
+  let
+    val t = analyze_compr thy t_compr;
+    val T = dest_pred_enumT (fastype_of t);
+    val t' = mk_pred_map T HOLogic.termT (HOLogic.term_of_const T) t;
+  in (T, Code_ML.eval NONE ("Predicate_Compile.eval_ref", eval_ref) Predicate.map thy t' []) end;
+
+fun values ctxt k t_compr =
+  let
+    val thy = ProofContext.theory_of ctxt;
+    val (T, t) = eval thy t_compr;
+    val setT = HOLogic.mk_setT T;
+    val (ts, _) = Predicate.yieldn k t;
+    val elemsT = HOLogic.mk_set T ts;
+  in if k = ~1 orelse length ts < k then elemsT
+    else Const (@{const_name Set.union}, setT --> setT --> setT) $ elemsT $ t_compr
+  end;
+
+fun values_cmd modes k raw_t state =
+  let
+    val ctxt = Toplevel.context_of state;
+    val t = Syntax.read_term ctxt raw_t;
+    val t' = values ctxt k t;
+    val ty' = Term.type_of t';
+    val ctxt' = Variable.auto_fixes t' ctxt;
+    val p = PrintMode.with_modes modes (fn () =>
+      Pretty.block [Pretty.quote (Syntax.pretty_term ctxt' t'), Pretty.fbrk,
+        Pretty.str "::", Pretty.brk 1, Pretty.quote (Syntax.pretty_typ ctxt' ty')]) ();
+  in Pretty.writeln p end;
+
+local structure P = OuterParse in
+
+val opt_modes = Scan.optional (P.$$$ "(" |-- P.!!! (Scan.repeat1 P.xname --| P.$$$ ")")) [];
+
+val _ = OuterSyntax.improper_command "values" "enumerate and print comprehensions" OuterKeyword.diag
+  (opt_modes -- Scan.optional P.nat ~1 -- P.term
+    >> (fn ((modes, k), t) => Toplevel.no_timing o Toplevel.keep
+        (values_cmd modes k t)));
+
 end;
 
+end;
+