src/HOL/Library/case_converter.ML
changeset 69568 de09a7261120
parent 68301 fb5653a7a879
child 69593 3dda49e08b9d
--- a/src/HOL/Library/case_converter.ML	Sun Dec 30 10:30:41 2018 +0100
+++ b/src/HOL/Library/case_converter.ML	Tue Jan 01 17:04:53 2019 +0100
@@ -3,8 +3,11 @@
 
 signature CASE_CONVERTER =
 sig
-  val to_case: Proof.context -> (string * string -> bool) -> (string * typ -> int) ->
+  type elimination_strategy
+  val to_case: Proof.context -> elimination_strategy -> (string * typ -> int) ->
     thm list -> thm list option
+  val replace_by_type: (Proof.context -> string * string -> bool) -> elimination_strategy
+  val keep_constructor_context: elimination_strategy
 end;
 
 structure Case_Converter : CASE_CONVERTER =
@@ -60,30 +63,72 @@
     Coordinate (merge_consts xs ys)
   end;
 
-fun term_to_coordinates P term = 
-  let
-    val (ctr, args) = strip_comb term
-  in
-    case ctr of Const (s, T) =>
-      if P (body_type T |> dest_Type |> fst, s)
-      then SOME (End (body_type T))
-      else
-        let
-          fun f (i, t) = term_to_coordinates P t |> Option.map (pair i)
-          val tcos = map_filter I (map_index f args)
-        in
-          if null tcos then NONE
-          else SOME (Coordinate (map (pair s) tcos))
-        end
-    | _ => NONE
-  end;
-
 fun coordinates_to_list (End x) = [(x, [])]
   | coordinates_to_list (Coordinate xs) = 
   let
     fun f (s, (n, xss)) = map (fn (T, xs) => (T, (s, n) :: xs)) (coordinates_to_list xss)
   in flat (map f xs) end;
 
+type elimination_strategy = Proof.context -> term list -> term_coordinate list
+
+fun replace_by_type replace_ctr ctxt pats =
+  let
+    fun term_to_coordinates P term = 
+      let
+        val (ctr, args) = strip_comb term
+      in
+        case ctr of Const (s, T) =>
+          if P (body_type T |> dest_Type |> fst, s)
+          then SOME (End (body_type T))
+          else
+            let
+              fun f (i, t) = term_to_coordinates P t |> Option.map (pair i)
+              val tcos = map_filter I (map_index f args)
+            in
+              if null tcos then NONE
+              else SOME (Coordinate (map (pair s) tcos))
+            end
+        | _ => NONE
+      end
+    in
+      map_filter (term_to_coordinates (replace_ctr ctxt)) pats
+    end
+
+fun keep_constructor_context ctxt pats =
+  let
+    fun to_coordinates [] = NONE
+      | to_coordinates pats =
+        let
+          val (fs, argss) = map strip_comb pats |> split_list
+          val f = hd fs
+          fun is_single_ctr (Const (name, T)) = 
+              let
+                val tyco = body_type T |> dest_Type |> fst
+                val _ = Ctr_Sugar.ctr_sugar_of ctxt tyco |> the |> #ctrs
+              in
+                case Ctr_Sugar.ctr_sugar_of ctxt tyco of
+                  NONE => error ("Not a free constructor " ^ name ^ " in pattern")
+                | SOME info =>
+                  case #ctrs info of [Const (name', _)] => name = name'
+                    | _ => false
+              end
+            | is_single_ctr _ = false
+        in 
+          if not (is_single_ctr f) andalso forall (fn x => f = x) fs then
+            let
+              val patss = Ctr_Sugar_Util.transpose argss
+              fun recurse (i, pats) = to_coordinates pats |> Option.map (pair i)
+              val coords = map_filter I (map_index recurse patss)
+            in
+              if null coords then NONE
+              else SOME (Coordinate (map (pair (dest_Const f |> fst)) coords))
+            end
+          else SOME (End (body_type (fastype_of f)))
+          end
+    in
+      the_list (to_coordinates pats)
+    end
+
 
 (* AL: TODO: change from term to const_name *)
 fun find_ctr ctr1 xs =
@@ -453,7 +498,8 @@
       ctxt1)
   end;
 
-fun build_case_t replace_ctr ctr_count head lhss rhss ctxt =
+
+fun build_case_t elimination_strategy ctr_count head lhss rhss ctxt =
   let
     val num_eqs = length lhss
     val _ = if length rhss = num_eqs andalso num_eqs > 0 then ()
@@ -464,16 +510,17 @@
     val _ = if forall (fn m => length m = n) lhss then ()
       else raise Fail "expected equal number of arguments"
 
-    fun to_coordinates (n, ts) = case map_filter (term_to_coordinates replace_ctr) ts of
-        [] => NONE
-      | (tco :: tcos) => SOME (n, fold term_coordinate_merge tcos tco |> coordinates_to_list)
+    fun to_coordinates (n, ts) = 
+      case elimination_strategy ctxt ts of
+          [] => NONE
+        | (tco :: tcos) => SOME (n, fold term_coordinate_merge tcos tco |> coordinates_to_list)
     fun add_T (n, xss) = map (fn (T, xs) => (T, (n, xs))) xss
     val (typ_list, poss) = lhss
       |> Ctr_Sugar_Util.transpose
       |> map_index to_coordinates
       |> map_filter (Option.map add_T)
       |> flat
-      |> split_list 
+      |> split_list
   in
     if null poss then ([], [], ctxt)
     else terms_to_case ctxt (dest_Const #> ctr_count) head lhss rhss typ_list poss