Improved code generator for Collect.
authorberghofe
Wed, 11 Jul 2007 11:38:25 +0200
changeset 23761 9cebbaccf8a2
parent 23760 aca2c7f80e2f
child 23762 24eef53a9ad3
Improved code generator for Collect.
src/HOL/Tools/inductive_codegen.ML
--- a/src/HOL/Tools/inductive_codegen.ML	Wed Jul 11 11:36:06 2007 +0200
+++ b/src/HOL/Tools/inductive_codegen.ML	Wed Jul 11 11:38:25 2007 +0200
@@ -615,18 +615,57 @@
   | SOME _ =>
       (add_edge (name, dep) gr, mk_qual_id module (get_const_id name gr));
 
+(* convert n-tuple to nested pairs *)
+
+fun conv_ntuple fs ts p =
+  let
+    val k = length fs;
+    val xs = map (fn i => Pretty.str ("x" ^ string_of_int i)) (0 upto k);
+    val xs' = map (fn Bound i => nth xs (k - i)) ts;
+    fun conv xs js =
+      if js mem fs then
+        let
+          val (p, xs') = conv xs (1::js);
+          val (q, xs'') = conv xs' (2::js)
+        in (mk_tuple [p, q], xs'') end
+      else (hd xs, tl xs)
+  in
+    if k > 0 then
+      Pretty.block
+        [Pretty.str "Seq.map (fn", Pretty.brk 1,
+         mk_tuple xs', Pretty.str " =>", Pretty.brk 1, fst (conv xs []),
+         Pretty.str ")", Pretty.brk 1, parens p]
+    else p
+  end;
+
 fun inductive_codegen thy defs gr dep module brack t = (case strip_comb t of
-    (Const ("Collect", Type (_, [_, Type (_, [U])])), [u]) => (case strip_comb u of
-        (Const (s, T), ts) => (case (get_clauses thy s, get_assoc_code thy (s, T)) of
-          (SOME (names, thyname, k, intrs), NONE) =>
-            let val (gr', call_p) = mk_ind_call thy defs gr dep module true
-              s T (ts @ [Term.dummy_pattern U]) names thyname k intrs
-            in SOME (gr', (if brack then parens else I) (Pretty.block
-              [Pretty.str "Seq.list_of", Pretty.brk 1, Pretty.str "(",
-               call_p, Pretty.str ")"]))
-            end
-        | _ => NONE)
-      | _ => NONE)
+    (Const ("Collect", _), [u]) =>
+      let val (r, Ts, fs) = HOLogic.strip_split u
+      in case strip_comb r of
+          (Const (s, T), ts) =>
+            (case (get_clauses thy s, get_assoc_code thy (s, T)) of
+              (SOME (names, thyname, k, intrs), NONE) =>
+                let
+                  val (ts1, ts2) = chop k ts;
+                  val ts2' = map
+                    (fn Bound i => Term.dummy_pattern (nth Ts i) | t => t) ts2;
+                  val (ots, its) = List.partition is_Bound ts2;
+                  val no_loose = forall (fn t => not (loose_bvar (t, 0)))
+                in
+                  if null (duplicates op = ots) andalso
+                    no_loose ts1 andalso no_loose its
+                  then
+                    let val (gr', call_p) = mk_ind_call thy defs gr dep module true
+                      s T (ts1 @ ts2') names thyname k intrs
+                    in SOME (gr', (if brack then parens else I) (Pretty.block
+                      [Pretty.str "Seq.list_of", Pretty.brk 1, Pretty.str "(",
+                       conv_ntuple fs ots call_p, Pretty.str ")"]))
+                    end
+                  else NONE
+                end
+            | _ => NONE)
+        | _ => NONE
+      end
   | (Const (s, T), ts) => (case Symtab.lookup (#eqns (CodegenData.get thy)) s of
       NONE => (case (get_clauses thy s, get_assoc_code thy (s, T)) of
         (SOME (names, thyname, k, intrs), NONE) =>