tuned
authorhaftmann
Mon, 08 Oct 2007 22:03:30 +0200
changeset 24917 8b97a94ab187
parent 24916 dc56dd1b3cda
child 24918 22013215eece
tuned
src/Pure/Isar/code_unit.ML
--- a/src/Pure/Isar/code_unit.ML	Mon Oct 08 22:03:28 2007 +0200
+++ b/src/Pure/Isar/code_unit.ML	Mon Oct 08 22:03:30 2007 +0200
@@ -372,24 +372,20 @@
 
 (* case cerificates *)
 
-fun case_cert thm =
+fun case_certificate thm =
   let
-    (*FIXME rework this code*)
     val thy = Thm.theory_of_thm thm;
-    val (cas, t_pats) = (Logic.dest_implies o Thm.prop_of) thm;
-    val pats = Logic.dest_conjunctions t_pats;
-    val (head, proto_case_expr) = Logic.dest_equals cas;
+    val ((head, raw_case_expr), cases) = (apfst Logic.dest_equals
+      o apsnd Logic.dest_conjunctions o Logic.dest_implies o Thm.prop_of) thm;
     val _ = case head of Free _ => true
       | Var _ => true
       | _ => raise TERM ("case_cert", []);
-    val ([(case_expr_v, _)], case_expr) = Term.strip_abs_eta 1 proto_case_expr;
-    val (Const (c_case, _), raw_params) = strip_comb case_expr;
-    val i = find_index
-      (fn Free (v, _) => v = case_expr_v | _ => false) raw_params;
-    val _ = if i = ~1 then raise TERM ("case_cert", []) else ();
-    val t_params = nth_drop i raw_params;
-    val params = map (fst o dest_Var) t_params;
-    fun dest_pat t =
+    val ([(case_var, _)], case_expr) = Term.strip_abs_eta 1 raw_case_expr;
+    val (Const (case_const, _), raw_params) = strip_comb case_expr;
+    val n = find_index (fn Free (v, _) => v = case_var | _ => false) raw_params;
+    val _ = if n = ~1 then raise TERM ("case_cert", []) else ();
+    val params = map (fst o dest_Var) (nth_drop n raw_params);
+    fun dest_case t =
       let
         val (head' $ t_co, rhs) = Logic.dest_equals t;
         val _ = if head' = head then () else raise TERM ("case_cert", []);
@@ -397,9 +393,9 @@
         val (Var (param, _), args') = strip_comb rhs;
         val _ = if args' = args then () else raise TERM ("case_cert", []);
       in (param, co) end;
-    fun analyze_pats pats =
+    fun analyze_cases cases =
       let
-        val co_list = fold (AList.update (op =) o dest_pat) pats [];
+        val co_list = fold (AList.update (op =) o dest_case) cases [];
       in map (the o AList.lookup (op =) co_list) params end;
     fun analyze_let t =
       let
@@ -408,10 +404,13 @@
         val _ = if arg' = arg then () else raise TERM ("case_cert", []);
         val _ = if [param'] = params then () else raise TERM ("case_cert", []);
       in [] end;
-  in (c_case, (i, case pats
-   of [pat] => (analyze_pats pats handle Bind => analyze_let pat)
-    | _ :: _ => analyze_pats pats))
-  end handle Bind => error "bad case certificate"
-    | TERM _ => error "bad case certificate";
+    fun analyze (cases as [let_case]) =
+          (analyze_cases cases handle Bind => analyze_let let_case)
+      | analyze cases = analyze_cases cases;
+  in (case_const, (n, analyze cases)) end;
+
+fun case_cert thm = case_certificate thm
+  handle Bind => error "bad case certificate"
+      | TERM _ => error "bad case certificate";
 
 end;