unified variable names in case expressions; no exponential fork in translation of case expressions
authorhaftmann
Tue, 17 Feb 2009 18:45:41 +0100
changeset 29952 9aed85067721
parent 29947 0a51765d2084
child 29953 7a2eb84343f9
unified variable names in case expressions; no exponential fork in translation of case expressions
src/Tools/code/code_haskell.ML
src/Tools/code/code_ml.ML
src/Tools/code/code_thingol.ML
--- a/src/Tools/code/code_haskell.ML	Mon Feb 16 19:35:52 2009 -0800
+++ b/src/Tools/code/code_haskell.ML	Tue Feb 17 18:45:41 2009 +0100
@@ -101,7 +101,7 @@
     and pr_bind tyvars = pr_haskell_bind (pr_term tyvars)
     and pr_case tyvars thm vars fxy (cases as ((_, [_]), _)) =
           let
-            val (binds, t) = Code_Thingol.unfold_let (ICase cases);
+            val (binds, body) = Code_Thingol.unfold_let (ICase cases);
             fun pr ((pat, ty), t) vars =
               vars
               |> pr_bind tyvars thm BR ((NONE, SOME pat), ty)
@@ -110,20 +110,20 @@
           in
             Pretty.block_enclose (
               str "let {",
-              concat [str "}", str "in", pr_term tyvars thm vars' NOBR t]
+              concat [str "}", str "in", pr_term tyvars thm vars' NOBR body]
             ) ps
           end
-      | pr_case tyvars thm vars fxy (((td, ty), bs as _ :: _), _) =
+      | pr_case tyvars thm vars fxy (((t, ty), clauses as _ :: _), _) =
           let
-            fun pr (pat, t) =
+            fun pr (pat, body) =
               let
                 val (p, vars') = pr_bind tyvars thm NOBR ((NONE, SOME pat), ty) vars;
-              in semicolon [p, str "->", pr_term tyvars thm vars' NOBR t] end;
+              in semicolon [p, str "->", pr_term tyvars thm vars' NOBR body] end;
           in
             Pretty.block_enclose (
-              concat [str "(case", pr_term tyvars thm vars NOBR td, str "of", str "{"],
+              concat [str "(case", pr_term tyvars thm vars NOBR t, str "of", str "{"],
               str "})"
-            ) (map pr bs)
+            ) (map pr clauses)
           end
       | pr_case tyvars thm vars fxy ((_, []), _) = str "error \"empty case\"";
     fun pr_stmt (name, Code_Thingol.Fun (_, ((vs, ty), []))) =
--- a/src/Tools/code/code_ml.ML	Mon Feb 16 19:35:52 2009 -0800
+++ b/src/Tools/code/code_ml.ML	Tue Feb 17 18:45:41 2009 +0100
@@ -130,7 +130,7 @@
     and pr_bind is_closure = gen_pr_bind pr_bind' (pr_term is_closure)
     and pr_case is_closure thm vars fxy (cases as ((_, [_]), _)) =
           let
-            val (binds, t') = Code_Thingol.unfold_let (ICase cases);
+            val (binds, body) = Code_Thingol.unfold_let (ICase cases);
             fun pr ((pat, ty), t) vars =
               vars
               |> pr_bind is_closure thm NOBR ((NONE, SOME pat), ty)
@@ -139,24 +139,24 @@
           in
             Pretty.chunks [
               [str ("let"), Pretty.fbrk, Pretty.chunks ps] |> Pretty.block,
-              [str ("in"), Pretty.fbrk, pr_term is_closure thm vars' NOBR t'] |> Pretty.block,
+              [str ("in"), Pretty.fbrk, pr_term is_closure thm vars' NOBR body] |> Pretty.block,
               str ("end")
             ]
           end
-      | pr_case is_closure thm vars fxy (((td, ty), b::bs), _) =
+      | pr_case is_closure thm vars fxy (((t, ty), clause :: clauses), _) =
           let
-            fun pr delim (pat, t) =
+            fun pr delim (pat, body) =
               let
                 val (p, vars') = pr_bind is_closure thm NOBR ((NONE, SOME pat), ty) vars;
               in
-                concat [str delim, p, str "=>", pr_term is_closure thm vars' NOBR t]
+                concat [str delim, p, str "=>", pr_term is_closure thm vars' NOBR body]
               end;
           in
             (Pretty.enclose "(" ")" o single o brackify fxy) (
               str "case"
-              :: pr_term is_closure thm vars NOBR td
-              :: pr "of" b
-              :: map (pr "|") bs
+              :: pr_term is_closure thm vars NOBR t
+              :: pr "of" clause
+              :: map (pr "|") clauses
             )
           end
       | pr_case is_closure thm vars fxy ((_, []), _) = str "raise Fail \"empty case\"";
@@ -434,26 +434,26 @@
     and pr_bind is_closure = gen_pr_bind pr_bind' (pr_term is_closure)
     and pr_case is_closure thm vars fxy (cases as ((_, [_]), _)) =
           let
-            val (binds, t') = Code_Thingol.unfold_let (ICase cases);
+            val (binds, body) = Code_Thingol.unfold_let (ICase cases);
             fun pr ((pat, ty), t) vars =
               vars
               |> pr_bind is_closure thm NOBR ((NONE, SOME pat), ty)
               |>> (fn p => concat
                   [str "let", p, str "=", pr_term is_closure thm vars NOBR t, str "in"])
             val (ps, vars') = fold_map pr binds vars;
-          in Pretty.chunks (ps @| pr_term is_closure thm vars' NOBR t') end
-      | pr_case is_closure thm vars fxy (((td, ty), b::bs), _) =
+          in Pretty.chunks (ps @| pr_term is_closure thm vars' NOBR body) end
+      | pr_case is_closure thm vars fxy (((t, ty), clause :: clauses), _) =
           let
-            fun pr delim (pat, t) =
+            fun pr delim (pat, body) =
               let
                 val (p, vars') = pr_bind is_closure thm NOBR ((NONE, SOME pat), ty) vars;
-              in concat [str delim, p, str "->", pr_term is_closure thm vars' NOBR t] end;
+              in concat [str delim, p, str "->", pr_term is_closure thm vars' NOBR body] end;
           in
             (Pretty.enclose "(" ")" o single o brackify fxy) (
               str "match"
-              :: pr_term is_closure thm vars NOBR td
-              :: pr "with" b
-              :: map (pr "|") bs
+              :: pr_term is_closure thm vars NOBR t
+              :: pr "with" clause
+              :: map (pr "|") clauses
             )
           end
       | pr_case is_closure thm vars fxy ((_, []), _) = str "failwith \"empty case\"";
--- a/src/Tools/code/code_thingol.ML	Mon Feb 16 19:35:52 2009 -0800
+++ b/src/Tools/code/code_thingol.ML	Tue Feb 17 18:45:41 2009 +0100
@@ -486,6 +486,12 @@
 
 (* translation *)
 
+(*FIXME move to code(_unit).ML*)
+fun get_case_scheme thy c = case Code.get_case_data thy c
+ of SOME (proto_case_scheme as (_, case_pats)) => 
+      SOME (1 + (if null case_pats then 1 else length case_pats), proto_case_scheme)
+  | NONE => NONE
+
 fun ensure_class thy (algbr as (_, algebra)) funcgr class =
   let
     val superclasses = (Sorts.minimize_sort algebra o Sorts.super_classes algebra) class;
@@ -669,58 +675,72 @@
   translate_const thy algbr funcgr thm c_ty
   ##>> fold_map (translate_term thy algbr funcgr thm) ts
   #>> (fn (t, ts) => t `$$ ts)
-and translate_case thy algbr funcgr thm n cases (app as ((c, ty), ts)) =
+and translate_case thy algbr funcgr thm (num_args, (t_pos, case_pats)) (c_ty, ts) =
   let
-    val (tys, _) =
-      (chop (1 + (if null cases then 1 else length cases)) o fst o strip_type) ty;
-    val dt = nth ts n;
-    val dty = nth tys n;
-    fun is_undefined (Const (c, _)) = Code.is_undefined thy c
-      | is_undefined _ = false;
-    fun mk_case (co, n) t =
+    val (tys, _) = (chop num_args o fst o strip_type o snd) c_ty;
+    val t = nth ts t_pos;
+    val ty = nth tys t_pos;
+    val ts_clause = nth_drop t_pos ts;
+    fun mk_clause (co, num_co_args) t =
       let
         val _ = if (is_some o Code.get_datatype_of_constr thy) co then ()
           else error ("Non-constructor " ^ quote co
             ^ " encountered in case pattern"
             ^ (case thm of NONE => ""
               | SOME thm => ", in equation\n" ^ Display.string_of_thm thm))
-        val (vs, body) = Term.strip_abs_eta n t;
-        val selector = list_comb (Const (co, map snd vs ---> dty), map Free vs);
-      in if is_undefined body then NONE else SOME (selector, body) end;
-    fun mk_ds [] =
+        val (vs, body) = Term.strip_abs_eta num_co_args t;
+        val not_undefined = case body
+         of (Const (c, _)) => not (Code.is_undefined thy c)
+          | _ => true;
+        val pat = list_comb (Const (co, map snd vs ---> ty), map Free vs);
+      in (not_undefined, (pat, body)) end;
+    val clauses = if null case_pats then let val ([v_ty], body) =
+        Term.strip_abs_eta 1 (the_single ts_clause)
+      in [(true, (Free v_ty, body))] end
+      else map (uncurry mk_clause)
+        (AList.make (Code_Unit.no_args thy) case_pats ~~ ts_clause);
+    fun retermify ty (_, (IVar x, body)) =
+          (x, ty) `|-> body
+      | retermify _ (_, (pat, body)) =
           let
-            val ([v_ty], body) = Term.strip_abs_eta 1 (the_single (nth_drop n ts))
-          in [(Free v_ty, body)] end
-      | mk_ds cases = map_filter (uncurry mk_case)
-          (AList.make (Code_Unit.no_args thy) cases ~~ nth_drop n ts);
+            val (IConst (_, (_, tys)), ts) = unfold_app pat;
+            val vs = map2 (fn IVar x => fn ty => (x, ty)) ts tys;
+          in vs `|--> body end;
+    fun mk_icase const t ty clauses =
+      let
+        val (ts1, ts2) = chop t_pos (map (retermify ty) clauses);
+      in
+        ICase (((t, ty), map_filter (fn (b, d) => if b then SOME d else NONE) clauses),
+          const `$$ (ts1 @ t :: ts2))
+      end;
   in
-    translate_term thy algbr funcgr thm dt
-    ##>> translate_typ thy algbr funcgr dty
-    ##>> fold_map (fn (pat, body) => translate_term thy algbr funcgr thm pat
-          ##>> translate_term thy algbr funcgr thm body) (mk_ds cases)
-    ##>> translate_app_default thy algbr funcgr thm app
-    #>> (fn (((dt, dty), ds), t0) => ICase (((dt, dty), ds), t0))
+    translate_const thy algbr funcgr thm c_ty
+    ##>> translate_term thy algbr funcgr thm t
+    ##>> translate_typ thy algbr funcgr ty
+    ##>> fold_map (fn (b, (pat, body)) => translate_term thy algbr funcgr thm pat
+      ##>> translate_term thy algbr funcgr thm body
+      #>> pair b) clauses
+    #>> (fn (((const, t), ty), ds) => mk_icase const t ty ds)
   end
-and translate_app thy algbr funcgr thm ((c, ty), ts) = case Code.get_case_data thy c
- of SOME (n, cases) => let val i = 1 + (if null cases then 1 else length cases) in
-      if length ts < i then
+and translate_app thy algbr funcgr thm ((c, ty), ts) = case get_case_scheme thy c
+ of SOME (case_scheme as (num_args, _)) =>
+      if length ts < num_args then
         let
           val k = length ts;
-          val tys = (curry Library.take (i - k) o curry Library.drop k o fst o strip_type) ty;
+          val tys = (curry Library.take (num_args - k) o curry Library.drop k o fst o strip_type) ty;
           val ctxt = (fold o fold_aterms) Term.declare_term_frees ts Name.context;
           val vs = Name.names ctxt "a" tys;
         in
           fold_map (translate_typ thy algbr funcgr) tys
-          ##>> translate_case thy algbr funcgr thm n cases ((c, ty), ts @ map Free vs)
+          ##>> translate_case thy algbr funcgr thm case_scheme ((c, ty), ts @ map Free vs)
           #>> (fn (tys, t) => map2 (fn (v, _) => pair v) vs tys `|--> t)
         end
-      else if length ts > i then
-        translate_case thy algbr funcgr thm n cases ((c, ty), Library.take (i, ts))
-        ##>> fold_map (translate_term thy algbr funcgr thm) (Library.drop (i, ts))
+      else if length ts > num_args then
+        translate_case thy algbr funcgr thm case_scheme ((c, ty), Library.take (num_args, ts))
+        ##>> fold_map (translate_term thy algbr funcgr thm) (Library.drop (num_args, ts))
         #>> (fn (t, ts) => t `$$ ts)
       else
-        translate_case thy algbr funcgr thm n cases ((c, ty), ts)
-      end
+        translate_case thy algbr funcgr thm case_scheme ((c, ty), ts)
   | NONE => translate_app_default thy algbr funcgr thm ((c, ty), ts);