handling nested cases more elegant by requiring less new constants
authorbulwahn
Mon, 27 Sep 2010 12:22:57 +0200
changeset 39723 12cc713036d6
parent 39717 e9bec0b43449
child 39724 ada0cd4900c1
handling nested cases more elegant by requiring less new constants
src/HOL/Tools/Predicate_Compile/predicate_compile_pred.ML
--- a/src/HOL/Tools/Predicate_Compile/predicate_compile_pred.ML	Mon Sep 27 11:12:08 2010 +0200
+++ b/src/HOL/Tools/Predicate_Compile/predicate_compile_pred.ML	Mon Sep 27 12:22:57 2010 +0200
@@ -94,6 +94,56 @@
     error "is_compound: Conjunction should not occur; preprocessing is defect"
   | is_compound _ = false
 
+fun try_destruct_case thy names atom =
+  case find_split_thm thy (fst (strip_comb atom)) of
+    NONE => NONE
+  | SOME raw_split_thm =>
+    let
+      val case_name = fst (dest_Const (fst (strip_comb atom)))
+      val split_thm = prepare_split_thm (ProofContext.init_global thy) raw_split_thm
+      (* TODO: contextify things - this line is to unvarify the split_thm *)
+      (*val ((_, [isplit_thm]), _) =
+        Variable.import true [split_thm] (ProofContext.init_global thy)*)
+      val (assms, concl) = Logic.strip_horn (prop_of split_thm)
+      val (P, [split_t]) = strip_comb (HOLogic.dest_Trueprop concl) 
+      val Tcons = datatype_names_of_case_name thy case_name
+      val ths = maps (instantiated_case_rewrites thy) Tcons
+      val atom' = MetaSimplifier.rewrite_term thy
+        (map (fn th => th RS @{thm eq_reflection}) ths) [] atom
+      val subst = Pattern.match thy (split_t, atom') (Vartab.empty, Vartab.empty)
+      val names' = Term.add_free_names atom' names
+      fun mk_subst_rhs assm =
+        let
+          val (vTs, assm') = strip_all (Envir.beta_norm (Envir.subst_term subst assm))
+          val var_names = Name.variant_list names' (map fst vTs)
+          val vars = map Free (var_names ~~ (map snd vTs))
+          val (prems', pre_res) = Logic.strip_horn (subst_bounds (rev vars, assm'))
+          fun partition_prem_subst prem =
+            case HOLogic.dest_eq (HOLogic.dest_Trueprop prem) of
+              (Free (x, T), r) => (NONE, SOME ((x, T), r))
+            | _ => (SOME prem, NONE)
+          fun partition f xs =
+            let
+              fun partition' acc1 acc2 [] = (rev acc1, rev acc2)
+                | partition' acc1 acc2 (x :: xs) =
+                  let
+                    val (y, z) = f x
+                    val acc1' = case y of NONE => acc1 | SOME y' => y' :: acc1
+                    val acc2' = case z of NONE => acc2 | SOME z' => z' :: acc2
+                  in partition' acc1' acc2' xs end
+            in partition' [] [] xs end
+          val (prems'', subst) = partition partition_prem_subst prems'
+          val (_, [inner_t]) = strip_comb (HOLogic.dest_Trueprop pre_res)
+          val pre_rhs =
+            fold (curry HOLogic.mk_conj) (map HOLogic.dest_Trueprop prems'') inner_t
+          val rhs = Envir.expand_term_frees subst pre_rhs
+        in
+          case try_destruct_case thy (var_names @ names') rhs of
+            NONE => [(subst, rhs)]
+          | SOME (_, srs) => map (fn (subst', rhs') => (subst @ subst', rhs')) srs
+        end
+     in SOME (atom', maps mk_subst_rhs assms) end
+     
 fun flatten constname atom (defs, thy) =
   if is_compound atom then
     let
@@ -124,62 +174,20 @@
           flatten constname atom' (defs, thy)
         end
     | _ =>
-      case find_split_thm thy (fst (strip_comb atom)) of
+      case try_destruct_case thy [] atom of
         NONE => (atom, (defs, thy))
-      | SOME raw_split_thm =>
-        let
-          val (f, args) = strip_comb atom
-          val split_thm = prepare_split_thm (ProofContext.init_global thy) raw_split_thm
-          (* TODO: contextify things - this line is to unvarify the split_thm *)
-          (*val ((_, [isplit_thm]), _) =
-            Variable.import true [split_thm] (ProofContext.init_global thy)*)
-          val (assms, concl) = Logic.strip_horn (prop_of split_thm)
-          val (P, [split_t]) = strip_comb (HOLogic.dest_Trueprop concl) 
-          val Tcons = datatype_names_of_case_name thy (fst (dest_Const f))
-          val ths = maps (instantiated_case_rewrites thy) Tcons
-          val atom = MetaSimplifier.rewrite_term thy
-            (map (fn th => th RS @{thm eq_reflection}) ths) [] atom
-          val (f, args) = strip_comb atom
-          val subst = Pattern.match thy (split_t, atom) (Vartab.empty, Vartab.empty)
-          val (_, split_args) = strip_comb split_t
-          val match = split_args ~~ args
-          val names = Term.add_free_names atom []
-          val frees = map Free (Term.add_frees atom [])
+      | SOME (atom', srs) =>
+        let      
+          val frees = map Free (Term.add_frees atom' [])
           val constname = Name.variant (map (Long_Name.base_name o fst) defs)
-            ((Long_Name.base_name constname) ^ "_aux")
+           ((Long_Name.base_name constname) ^ "_aux")
           val full_constname = Sign.full_bname thy constname
           val constT = map fastype_of frees ---> HOLogic.boolT
           val lhs = list_comb (Const (full_constname, constT), frees)
-          fun new_def assm =
-            let
-              val (vTs, assm') = strip_all (Envir.beta_norm (Envir.subst_term subst assm))
-              val var_names = Name.variant_list names (map fst vTs)
-              val vars = map Free (var_names ~~ (map snd vTs))
-              val (prems', pre_res) = Logic.strip_horn (subst_bounds (rev vars, assm'))
-              fun partition_prem_subst prem =
-                case HOLogic.dest_eq (HOLogic.dest_Trueprop prem) of
-                  (Free (x, T), r) => (NONE, SOME ((x, T), r))
-                | _ => (SOME prem, NONE)
-              fun partition f xs =
-                let
-                  fun partition' acc1 acc2 [] = (rev acc1, rev acc2)
-                    | partition' acc1 acc2 (x :: xs) =
-                      let
-                        val (y, z) = f x
-                        val acc1' = case y of NONE => acc1 | SOME y' => y' :: acc1
-                        val acc2' = case z of NONE => acc2 | SOME z' => z' :: acc2
-                      in partition' acc1' acc2' xs end
-                in partition' [] [] xs end
-              val (prems'', subst) = partition partition_prem_subst prems'
-              val (_, [inner_t]) = strip_comb (HOLogic.dest_Trueprop pre_res)
-              val pre_def = Logic.mk_equals (lhs,
-                fold (curry HOLogic.mk_conj) (map HOLogic.dest_Trueprop prems'') inner_t)
-              val def = Envir.expand_term_frees subst pre_def
-            in
-              def
-            end
-         val new_defs = map new_def assms
-         val (definition, thy') = thy
+          fun mk_def (subst, rhs) =
+            Logic.mk_equals (fold Envir.expand_term_frees (map single subst) lhs, rhs)
+          val new_defs = map mk_def srs
+          val (definition, thy') = thy
           |> Sign.add_consts_i [(Binding.name constname, constT, NoSyn)]
           |> fold_map Specification.axiom (map_index
               (fn (i, t) => ((Binding.name (constname ^ "_def" ^ string_of_int i), []), t)) new_defs)