Added more flexible parse / print translations for case expressions.
authorberghofe
Fri, 21 May 2004 21:48:03 +0200
changeset 14799 a405aadff16c
parent 14798 702cb4859cab
child 14800 50581f2b2c0e
Added more flexible parse / print translations for case expressions.
src/HOL/Tools/datatype_abs_proofs.ML
src/HOL/Tools/datatype_package.ML
src/HOL/Tools/datatype_prop.ML
--- a/src/HOL/Tools/datatype_abs_proofs.ML	Fri May 21 21:47:07 2004 +0200
+++ b/src/HOL/Tools/datatype_abs_proofs.ML	Fri May 21 21:48:03 2004 +0200
@@ -339,8 +339,7 @@
           (DatatypeProp.make_cases new_type_names descr sorts thy2)
 
   in
-    thy2 |> Theory.add_trrules_i
-      (DatatypeProp.make_case_trrules new_type_names descr) |>
+    thy2 |>
     parent_path flat_names |>
     store_thmss "cases" new_type_names case_thms |>
     apsnd (rpair case_names)
--- a/src/HOL/Tools/datatype_package.ML	Fri May 21 21:47:07 2004 +0200
+++ b/src/HOL/Tools/datatype_package.ML	Fri May 21 21:48:03 2004 +0200
@@ -393,6 +393,116 @@
    fn thy => (simpset_ref_of thy := simpset_of thy addsimprocs [distinct_simproc]; thy)];
 
 
+(**** translation rules for case ****)
+
+fun case_tr sg [t, u] =
+    let
+      fun case_error s name ts = raise TERM ("Error in case expression" ^
+        if_none (apsome (curry op ^ " for datatype ") name) "" ^ ":\n" ^ s, ts);
+      fun dest_case1 (Const ("_case1", _) $ t $ u) = (case strip_comb t of
+            (Const (s, _), ts) => (Sign.intern_const sg s, ts)
+          | (Free (s, _), ts) => (Sign.intern_const sg s, ts)
+          | _ => case_error "Head is not a constructor" None [t, u], u)
+        | dest_case1 t = raise TERM ("dest_case1", [t]);
+      fun dest_case2 (Const ("_case2", _) $ t $ u) = t :: dest_case2 u
+        | dest_case2 t = [t];
+      val cases as ((cname, _), _) :: _ = map dest_case1 (dest_case2 u);
+      val tab = Symtab.dest (get_datatypes_sg sg);
+      val (cases', default) = (case split_last cases of
+          (cases', (("dummy_pattern", []), t)) => (cases', Some t)
+        | _ => (cases, None))
+      fun abstr (Free (x, T), body) = Term.absfree (x, T, body)
+        | abstr (Const ("_constrain", _) $ Free (x, T) $ tT, body) =
+            Syntax.const Syntax.constrainAbsC $ Term.absfree (x, T, body) $ tT
+        | abstr (Const ("Pair", _) $ x $ y, body) =
+            Syntax.const "split" $ abstr (x, abstr (y, body))
+        | abstr (t, _) = case_error "Illegal pattern" None [t];
+    in case find_first (fn (_, {descr, index, ...}) =>
+      exists (equal cname o fst) (#3 (snd (nth_elem (index, descr))))) tab of
+        None => case_error ("Not a datatype constructor: " ^ cname) None [u]
+      | Some (tname, {descr, case_name, index, ...}) =>
+        let
+          val _ = if exists (equal "dummy_pattern" o fst o fst) cases' then
+            case_error "Illegal occurrence of '_' dummy pattern" (Some tname) [u] else ();
+          val (_, (_, dts, constrs)) = nth_elem (index, descr);
+          val sorts = map (rpair [] o dest_DtTFree) dts;
+          fun find_case (cases, (s, dt)) =
+            (case find_first (equal s o fst o fst) cases' of
+               None => (case default of
+                   None => case_error ("No clause for constructor " ^ s) (Some tname) [u]
+                 | Some t => (cases, list_abs (map (rpair dummyT) (DatatypeProp.make_tnames
+                     (map (typ_of_dtyp descr sorts) dt)), t)))
+             | Some (c as ((_, vs), t)) =>
+                 if length dt <> length vs then
+                    case_error ("Wrong number of arguments for constructor " ^ s)
+                      (Some tname) vs
+                 else (cases \ c, foldr abstr (vs, t)))
+          val (cases'', fs) = foldl_map find_case (cases', constrs)
+        in case (cases'', length constrs = length cases', default) of
+            ([], true, Some _) =>
+              case_error "Extra '_' dummy pattern" (Some tname) [u]
+          | (_ :: _, _, _) =>
+              let val extra = distinct (map (fst o fst) cases'')
+              in case extra \\ map fst constrs of
+                  [] => case_error ("More than one clause for constructor(s) " ^
+                    commas extra) (Some tname) [u]
+                | extra' => case_error ("Illegal constructor(s): " ^ commas extra')
+                    (Some tname) [u]
+              end
+          | _ => list_comb (Syntax.const case_name, fs) $ t
+        end
+    end
+  | case_tr sg ts = raise TERM ("case_tr", ts);
+
+fun case_tr' constrs sg ts =
+  if length ts <> length constrs + 1 then raise Match else
+  let
+    val (fs, x) = split_last ts;
+    fun strip_abs 0 t = ([], t)
+      | strip_abs i (Abs p) =
+        let val (x, u) = Syntax.atomic_abs_tr' p
+        in apfst (cons x) (strip_abs (i-1) u) end
+      | strip_abs i (Const ("split", _) $ t) = (case strip_abs (i+1) t of
+          (v :: v' :: vs, u) => (Syntax.const "Pair" $ v $ v' :: vs, u));
+    fun is_dependent i t =
+      let val k = length (strip_abs_vars t) - i
+      in k < 0 orelse exists (fn j => j >= k)
+        (loose_bnos (strip_abs_body t))
+      end;
+    val cases = map (fn ((cname, dts), t) =>
+      (Sign.cond_extern sg Sign.constK cname,
+       strip_abs (length dts) t, is_dependent (length dts) t))
+      (constrs ~~ fs);
+    fun count_cases (cs, (_, _, true)) = cs
+      | count_cases (cs, (cname, (_, body), false)) = (case assoc (cs, body) of
+          None => (body, [cname]) :: cs
+        | Some cnames => overwrite (cs, (body, cnames @ [cname])));
+    val cases' = sort (int_ord o Library.swap o pairself (length o snd))
+      (foldl count_cases ([], cases));
+    fun mk_case1 (cname, (vs, body), _) = Syntax.const "_case1" $
+      list_comb (Syntax.const cname, vs) $ body;
+  in
+    Syntax.const "_case_syntax" $ x $
+      foldr1 (fn (t, u) => Syntax.const "_case2" $ t $ u) (map mk_case1
+        (case cases' of
+           [] => cases
+         | (default, cnames) :: _ =>
+           if length cnames = 1 then cases
+           else if length cnames = length constrs then
+             [hd cases, ("dummy_pattern", ([], default), false)]
+           else
+             filter_out (fn (cname, _, _) => cname mem cnames) cases @
+             [("dummy_pattern", ([], default), false)]))
+  end;
+
+fun make_case_tr' case_names descr = flat (map
+  (fn ((_, (_, _, constrs)), case_name) => map (rpair (case_tr' constrs))
+    (NameSpace.accesses' case_name)) (descr ~~ case_names));
+
+val trfun_setup =
+  [Theory.add_advanced_trfuns ([], [("_case_syntax", case_tr)], [], [])];
+
+
 (* prepare types *)
 
 fun read_typ sign ((Ts, sorts), str) =
@@ -529,8 +639,7 @@
 
       Theory.add_consts_i (map (fn ((name, T), Ts) =>
         (name, Ts @ [T] ---> freeT, NoSyn))
-          (case_names ~~ newTs ~~ case_fn_Ts)) |>
-      Theory.add_trrules_i (DatatypeProp.make_case_trrules new_type_names descr);
+          (case_names ~~ newTs ~~ case_fn_Ts));
 
     val reccomb_names' = map (Sign.intern_const (Theory.sign_of thy2')) reccomb_names;
     val case_names' = map (Sign.intern_const (Theory.sign_of thy2')) case_names;
@@ -600,6 +709,7 @@
     val split_thms = split ~~ split_asm;
 
     val thy12 = thy11 |>
+      Theory.add_advanced_trfuns ([], [], make_case_tr' case_names' (hd descr), []) |>
       Theory.add_path (space_implode "_" new_type_names) |>
       add_rules simps case_thms size_thms rec_thms inject distinct
                 weak_case_congs Simplifier.cong_add_global |> 
@@ -657,6 +767,7 @@
     val simps = flat (distinct @ inject @ case_thms) @ size_thms @ rec_thms;
 
     val thy12 = thy11 |>
+      Theory.add_advanced_trfuns ([], [], make_case_tr' case_names (hd descr), []) |>
       Theory.add_path (space_implode "_" new_type_names) |>
       add_rules simps case_thms size_thms rec_thms inject distinct
                 weak_case_congs (Simplifier.change_global_ss (op addcongs)) |> 
@@ -765,6 +876,7 @@
     val simps = flat (distinct @ inject @ case_thms) @ size_thms @ rec_thms;
 
     val thy11 = thy10 |>
+      Theory.add_advanced_trfuns ([], [], make_case_tr' case_names descr, []) |>
       add_rules simps case_thms size_thms rec_thms inject distinct
                 weak_case_congs (Simplifier.change_global_ss (op addcongs)) |> 
       put_datatypes (foldr Symtab.update (dt_infos, dt_info)) |>
@@ -870,7 +982,7 @@
 
 (* setup theory *)
 
-val setup = [DatatypesData.init, Method.add_methods tactic_emulations] @ simproc_setup;
+val setup = [DatatypesData.init, Method.add_methods tactic_emulations] @ simproc_setup @ trfun_setup;
 
 
 (* outer syntax *)
--- a/src/HOL/Tools/datatype_prop.ML	Fri May 21 21:47:07 2004 +0200
+++ b/src/HOL/Tools/datatype_prop.ML	Fri May 21 21:48:03 2004 +0200
@@ -22,8 +22,6 @@
     (string * sort) list -> theory -> term list list
   val make_splits : string list -> DatatypeAux.descr list ->
     (string * sort) list -> theory -> (term * term) list
-  val make_case_trrules : string list -> DatatypeAux.descr list ->
-    ast Syntax.trrule list
   val make_size : DatatypeAux.descr list -> (string * sort) list ->
     theory -> term list
   val make_weak_case_congs : string list -> DatatypeAux.descr list ->
@@ -344,39 +342,6 @@
     (make_case_combs new_type_names descr sorts thy "f"))
   end;
 
-(************************ translation rules for case **************************)
-
-fun make_case_trrules new_type_names descr =
-  let
-    fun mk_asts i j ((cname, cargs)::constrs) =
-      let
-        val k = length cargs;
-        val xs = map (fn i => Variable ("x" ^ string_of_int i)) (i upto i + k - 1);
-        val t = Variable ("t" ^ string_of_int j);
-        val ast = Syntax.mk_appl (Constant "_case1")
-          [Syntax.mk_appl (Constant (Sign.base_name cname)) xs, t];
-        val ast' = foldr (fn (x, y) =>
-          Syntax.mk_appl (Constant "_abs") [x, y]) (xs, t)
-      in
-        (case constrs of
-            [] => (ast, [ast'])
-          | cs => let val (ast'', asts) = mk_asts (i + k) (j + 1) cs
-              in (Syntax.mk_appl (Constant "_case2") [ast, ast''],
-                  ast'::asts)
-              end)
-      end;
-
-    fun mk_trrule ((_, (_, _, constrs)), tname) =
-      let val (ast, asts) = mk_asts 1 1 constrs
-      in Syntax.ParsePrintRule
-        (Syntax.mk_appl (Constant "_case_syntax") [Variable "t", ast],
-         Syntax.mk_appl (Constant (tname ^ "_case"))
-           (asts @ [Variable "t"]))
-      end
-
-  in
-    map mk_trrule (hd descr ~~ new_type_names)
-  end;
 
 (******************************* size functions *******************************)