src/HOL/Library/case_converter.ML
changeset 69568 de09a7261120
parent 68301 fb5653a7a879
child 69593 3dda49e08b9d
equal deleted inserted replaced
69567:6b4c41037649 69568:de09a7261120
     1 (* Author: Pascal Stoop, ETH Zurich
     1 (* Author: Pascal Stoop, ETH Zurich
     2    Author: Andreas Lochbihler, Digital Asset *)
     2    Author: Andreas Lochbihler, Digital Asset *)
     3 
     3 
     4 signature CASE_CONVERTER =
     4 signature CASE_CONVERTER =
     5 sig
     5 sig
     6   val to_case: Proof.context -> (string * string -> bool) -> (string * typ -> int) ->
     6   type elimination_strategy
       
     7   val to_case: Proof.context -> elimination_strategy -> (string * typ -> int) ->
     7     thm list -> thm list option
     8     thm list -> thm list option
       
     9   val replace_by_type: (Proof.context -> string * string -> bool) -> elimination_strategy
       
    10   val keep_constructor_context: elimination_strategy
     8 end;
    11 end;
     9 
    12 
    10 structure Case_Converter : CASE_CONVERTER =
    13 structure Case_Converter : CASE_CONVERTER =
    11 struct
    14 struct
    12 
    15 
    58           | ((_, (_, x)) :: _, xs') => (s1, (n, term_coordinate_merge x y)) :: (merge_consts xs' ys)
    61           | ((_, (_, x)) :: _, xs') => (s1, (n, term_coordinate_merge x y)) :: (merge_consts xs' ys)
    59   in
    62   in
    60     Coordinate (merge_consts xs ys)
    63     Coordinate (merge_consts xs ys)
    61   end;
    64   end;
    62 
    65 
    63 fun term_to_coordinates P term = 
       
    64   let
       
    65     val (ctr, args) = strip_comb term
       
    66   in
       
    67     case ctr of Const (s, T) =>
       
    68       if P (body_type T |> dest_Type |> fst, s)
       
    69       then SOME (End (body_type T))
       
    70       else
       
    71         let
       
    72           fun f (i, t) = term_to_coordinates P t |> Option.map (pair i)
       
    73           val tcos = map_filter I (map_index f args)
       
    74         in
       
    75           if null tcos then NONE
       
    76           else SOME (Coordinate (map (pair s) tcos))
       
    77         end
       
    78     | _ => NONE
       
    79   end;
       
    80 
       
    81 fun coordinates_to_list (End x) = [(x, [])]
    66 fun coordinates_to_list (End x) = [(x, [])]
    82   | coordinates_to_list (Coordinate xs) = 
    67   | coordinates_to_list (Coordinate xs) = 
    83   let
    68   let
    84     fun f (s, (n, xss)) = map (fn (T, xs) => (T, (s, n) :: xs)) (coordinates_to_list xss)
    69     fun f (s, (n, xss)) = map (fn (T, xs) => (T, (s, n) :: xs)) (coordinates_to_list xss)
    85   in flat (map f xs) end;
    70   in flat (map f xs) end;
       
    71 
       
    72 type elimination_strategy = Proof.context -> term list -> term_coordinate list
       
    73 
       
    74 fun replace_by_type replace_ctr ctxt pats =
       
    75   let
       
    76     fun term_to_coordinates P term = 
       
    77       let
       
    78         val (ctr, args) = strip_comb term
       
    79       in
       
    80         case ctr of Const (s, T) =>
       
    81           if P (body_type T |> dest_Type |> fst, s)
       
    82           then SOME (End (body_type T))
       
    83           else
       
    84             let
       
    85               fun f (i, t) = term_to_coordinates P t |> Option.map (pair i)
       
    86               val tcos = map_filter I (map_index f args)
       
    87             in
       
    88               if null tcos then NONE
       
    89               else SOME (Coordinate (map (pair s) tcos))
       
    90             end
       
    91         | _ => NONE
       
    92       end
       
    93     in
       
    94       map_filter (term_to_coordinates (replace_ctr ctxt)) pats
       
    95     end
       
    96 
       
    97 fun keep_constructor_context ctxt pats =
       
    98   let
       
    99     fun to_coordinates [] = NONE
       
   100       | to_coordinates pats =
       
   101         let
       
   102           val (fs, argss) = map strip_comb pats |> split_list
       
   103           val f = hd fs
       
   104           fun is_single_ctr (Const (name, T)) = 
       
   105               let
       
   106                 val tyco = body_type T |> dest_Type |> fst
       
   107                 val _ = Ctr_Sugar.ctr_sugar_of ctxt tyco |> the |> #ctrs
       
   108               in
       
   109                 case Ctr_Sugar.ctr_sugar_of ctxt tyco of
       
   110                   NONE => error ("Not a free constructor " ^ name ^ " in pattern")
       
   111                 | SOME info =>
       
   112                   case #ctrs info of [Const (name', _)] => name = name'
       
   113                     | _ => false
       
   114               end
       
   115             | is_single_ctr _ = false
       
   116         in 
       
   117           if not (is_single_ctr f) andalso forall (fn x => f = x) fs then
       
   118             let
       
   119               val patss = Ctr_Sugar_Util.transpose argss
       
   120               fun recurse (i, pats) = to_coordinates pats |> Option.map (pair i)
       
   121               val coords = map_filter I (map_index recurse patss)
       
   122             in
       
   123               if null coords then NONE
       
   124               else SOME (Coordinate (map (pair (dest_Const f |> fst)) coords))
       
   125             end
       
   126           else SOME (End (body_type (fastype_of f)))
       
   127           end
       
   128     in
       
   129       the_list (to_coordinates pats)
       
   130     end
    86 
   131 
    87 
   132 
    88 (* AL: TODO: change from term to const_name *)
   133 (* AL: TODO: change from term to const_name *)
    89 fun find_ctr ctr1 xs =
   134 fun find_ctr ctr1 xs =
    90   let
   135   let
   451     (map make_eq_term (lhss2 ~~ rhss2),
   496     (map make_eq_term (lhss2 ~~ rhss2),
   452       get_split_theorems ctxt1 (type_name_fun Symtab.empty),
   497       get_split_theorems ctxt1 (type_name_fun Symtab.empty),
   453       ctxt1)
   498       ctxt1)
   454   end;
   499   end;
   455 
   500 
   456 fun build_case_t replace_ctr ctr_count head lhss rhss ctxt =
   501 
       
   502 fun build_case_t elimination_strategy ctr_count head lhss rhss ctxt =
   457   let
   503   let
   458     val num_eqs = length lhss
   504     val num_eqs = length lhss
   459     val _ = if length rhss = num_eqs andalso num_eqs > 0 then ()
   505     val _ = if length rhss = num_eqs andalso num_eqs > 0 then ()
   460       else raise Fail
   506       else raise Fail
   461         ("expected same number of left-hand sides as right-hand sides\n"
   507         ("expected same number of left-hand sides as right-hand sides\n"
   462           ^ "and at least one equation")
   508           ^ "and at least one equation")
   463     val n = length (hd lhss)
   509     val n = length (hd lhss)
   464     val _ = if forall (fn m => length m = n) lhss then ()
   510     val _ = if forall (fn m => length m = n) lhss then ()
   465       else raise Fail "expected equal number of arguments"
   511       else raise Fail "expected equal number of arguments"
   466 
   512 
   467     fun to_coordinates (n, ts) = case map_filter (term_to_coordinates replace_ctr) ts of
   513     fun to_coordinates (n, ts) = 
   468         [] => NONE
   514       case elimination_strategy ctxt ts of
   469       | (tco :: tcos) => SOME (n, fold term_coordinate_merge tcos tco |> coordinates_to_list)
   515           [] => NONE
       
   516         | (tco :: tcos) => SOME (n, fold term_coordinate_merge tcos tco |> coordinates_to_list)
   470     fun add_T (n, xss) = map (fn (T, xs) => (T, (n, xs))) xss
   517     fun add_T (n, xss) = map (fn (T, xs) => (T, (n, xs))) xss
   471     val (typ_list, poss) = lhss
   518     val (typ_list, poss) = lhss
   472       |> Ctr_Sugar_Util.transpose
   519       |> Ctr_Sugar_Util.transpose
   473       |> map_index to_coordinates
   520       |> map_index to_coordinates
   474       |> map_filter (Option.map add_T)
   521       |> map_filter (Option.map add_T)
   475       |> flat
   522       |> flat
   476       |> split_list 
   523       |> split_list
   477   in
   524   in
   478     if null poss then ([], [], ctxt)
   525     if null poss then ([], [], ctxt)
   479     else terms_to_case ctxt (dest_Const #> ctr_count) head lhss rhss typ_list poss
   526     else terms_to_case ctxt (dest_Const #> ctr_count) head lhss rhss typ_list poss
   480   end;
   527   end;
   481 
   528