src/HOL/Library/case_converter.ML
changeset 69568 de09a7261120
parent 68301 fb5653a7a879
child 69593 3dda49e08b9d
     1.1 --- a/src/HOL/Library/case_converter.ML	Sun Dec 30 10:30:41 2018 +0100
     1.2 +++ b/src/HOL/Library/case_converter.ML	Tue Jan 01 17:04:53 2019 +0100
     1.3 @@ -3,8 +3,11 @@
     1.4  
     1.5  signature CASE_CONVERTER =
     1.6  sig
     1.7 -  val to_case: Proof.context -> (string * string -> bool) -> (string * typ -> int) ->
     1.8 +  type elimination_strategy
     1.9 +  val to_case: Proof.context -> elimination_strategy -> (string * typ -> int) ->
    1.10      thm list -> thm list option
    1.11 +  val replace_by_type: (Proof.context -> string * string -> bool) -> elimination_strategy
    1.12 +  val keep_constructor_context: elimination_strategy
    1.13  end;
    1.14  
    1.15  structure Case_Converter : CASE_CONVERTER =
    1.16 @@ -60,30 +63,72 @@
    1.17      Coordinate (merge_consts xs ys)
    1.18    end;
    1.19  
    1.20 -fun term_to_coordinates P term = 
    1.21 -  let
    1.22 -    val (ctr, args) = strip_comb term
    1.23 -  in
    1.24 -    case ctr of Const (s, T) =>
    1.25 -      if P (body_type T |> dest_Type |> fst, s)
    1.26 -      then SOME (End (body_type T))
    1.27 -      else
    1.28 -        let
    1.29 -          fun f (i, t) = term_to_coordinates P t |> Option.map (pair i)
    1.30 -          val tcos = map_filter I (map_index f args)
    1.31 -        in
    1.32 -          if null tcos then NONE
    1.33 -          else SOME (Coordinate (map (pair s) tcos))
    1.34 -        end
    1.35 -    | _ => NONE
    1.36 -  end;
    1.37 -
    1.38  fun coordinates_to_list (End x) = [(x, [])]
    1.39    | coordinates_to_list (Coordinate xs) = 
    1.40    let
    1.41      fun f (s, (n, xss)) = map (fn (T, xs) => (T, (s, n) :: xs)) (coordinates_to_list xss)
    1.42    in flat (map f xs) end;
    1.43  
    1.44 +type elimination_strategy = Proof.context -> term list -> term_coordinate list
    1.45 +
    1.46 +fun replace_by_type replace_ctr ctxt pats =
    1.47 +  let
    1.48 +    fun term_to_coordinates P term = 
    1.49 +      let
    1.50 +        val (ctr, args) = strip_comb term
    1.51 +      in
    1.52 +        case ctr of Const (s, T) =>
    1.53 +          if P (body_type T |> dest_Type |> fst, s)
    1.54 +          then SOME (End (body_type T))
    1.55 +          else
    1.56 +            let
    1.57 +              fun f (i, t) = term_to_coordinates P t |> Option.map (pair i)
    1.58 +              val tcos = map_filter I (map_index f args)
    1.59 +            in
    1.60 +              if null tcos then NONE
    1.61 +              else SOME (Coordinate (map (pair s) tcos))
    1.62 +            end
    1.63 +        | _ => NONE
    1.64 +      end
    1.65 +    in
    1.66 +      map_filter (term_to_coordinates (replace_ctr ctxt)) pats
    1.67 +    end
    1.68 +
    1.69 +fun keep_constructor_context ctxt pats =
    1.70 +  let
    1.71 +    fun to_coordinates [] = NONE
    1.72 +      | to_coordinates pats =
    1.73 +        let
    1.74 +          val (fs, argss) = map strip_comb pats |> split_list
    1.75 +          val f = hd fs
    1.76 +          fun is_single_ctr (Const (name, T)) = 
    1.77 +              let
    1.78 +                val tyco = body_type T |> dest_Type |> fst
    1.79 +                val _ = Ctr_Sugar.ctr_sugar_of ctxt tyco |> the |> #ctrs
    1.80 +              in
    1.81 +                case Ctr_Sugar.ctr_sugar_of ctxt tyco of
    1.82 +                  NONE => error ("Not a free constructor " ^ name ^ " in pattern")
    1.83 +                | SOME info =>
    1.84 +                  case #ctrs info of [Const (name', _)] => name = name'
    1.85 +                    | _ => false
    1.86 +              end
    1.87 +            | is_single_ctr _ = false
    1.88 +        in 
    1.89 +          if not (is_single_ctr f) andalso forall (fn x => f = x) fs then
    1.90 +            let
    1.91 +              val patss = Ctr_Sugar_Util.transpose argss
    1.92 +              fun recurse (i, pats) = to_coordinates pats |> Option.map (pair i)
    1.93 +              val coords = map_filter I (map_index recurse patss)
    1.94 +            in
    1.95 +              if null coords then NONE
    1.96 +              else SOME (Coordinate (map (pair (dest_Const f |> fst)) coords))
    1.97 +            end
    1.98 +          else SOME (End (body_type (fastype_of f)))
    1.99 +          end
   1.100 +    in
   1.101 +      the_list (to_coordinates pats)
   1.102 +    end
   1.103 +
   1.104  
   1.105  (* AL: TODO: change from term to const_name *)
   1.106  fun find_ctr ctr1 xs =
   1.107 @@ -453,7 +498,8 @@
   1.108        ctxt1)
   1.109    end;
   1.110  
   1.111 -fun build_case_t replace_ctr ctr_count head lhss rhss ctxt =
   1.112 +
   1.113 +fun build_case_t elimination_strategy ctr_count head lhss rhss ctxt =
   1.114    let
   1.115      val num_eqs = length lhss
   1.116      val _ = if length rhss = num_eqs andalso num_eqs > 0 then ()
   1.117 @@ -464,16 +510,17 @@
   1.118      val _ = if forall (fn m => length m = n) lhss then ()
   1.119        else raise Fail "expected equal number of arguments"
   1.120  
   1.121 -    fun to_coordinates (n, ts) = case map_filter (term_to_coordinates replace_ctr) ts of
   1.122 -        [] => NONE
   1.123 -      | (tco :: tcos) => SOME (n, fold term_coordinate_merge tcos tco |> coordinates_to_list)
   1.124 +    fun to_coordinates (n, ts) = 
   1.125 +      case elimination_strategy ctxt ts of
   1.126 +          [] => NONE
   1.127 +        | (tco :: tcos) => SOME (n, fold term_coordinate_merge tcos tco |> coordinates_to_list)
   1.128      fun add_T (n, xss) = map (fn (T, xs) => (T, (n, xs))) xss
   1.129      val (typ_list, poss) = lhss
   1.130        |> Ctr_Sugar_Util.transpose
   1.131        |> map_index to_coordinates
   1.132        |> map_filter (Option.map add_T)
   1.133        |> flat
   1.134 -      |> split_list 
   1.135 +      |> split_list
   1.136    in
   1.137      if null poss then ([], [], ctxt)
   1.138      else terms_to_case ctxt (dest_Const #> ctr_count) head lhss rhss typ_list poss