src/HOL/Tools/Datatype/datatype_case.ML
changeset 32671 fbd224850767
parent 32035 8e77b6a250d5
child 32896 99cd75a18b78
equal deleted inserted replaced
32670:cc0bae788b7e 32671:fbd224850767
     5 Nested case expressions on datatypes.
     5 Nested case expressions on datatypes.
     6 *)
     6 *)
     7 
     7 
     8 signature DATATYPE_CASE =
     8 signature DATATYPE_CASE =
     9 sig
     9 sig
       
    10   datatype config = Error | Warning | Quiet;
    10   val make_case: (string -> DatatypeAux.info option) ->
    11   val make_case: (string -> DatatypeAux.info option) ->
    11     Proof.context -> bool -> string list -> term -> (term * term) list ->
    12     Proof.context -> config -> string list -> term -> (term * term) list ->
    12     term * (term * (int * bool)) list
    13     term * (term * (int * bool)) list
    13   val dest_case: (string -> DatatypeAux.info option) -> bool ->
    14   val dest_case: (string -> DatatypeAux.info option) -> bool ->
    14     string list -> term -> (term * (term * term) list) option
    15     string list -> term -> (term * (term * term) list) option
    15   val strip_case: (string -> DatatypeAux.info option) -> bool ->
    16   val strip_case: (string -> DatatypeAux.info option) -> bool ->
    16     term -> (term * (term * term) list) option
    17     term -> (term * (term * term) list) option
    20     string -> Proof.context -> term list -> term
    21     string -> Proof.context -> term list -> term
    21 end;
    22 end;
    22 
    23 
    23 structure DatatypeCase : DATATYPE_CASE =
    24 structure DatatypeCase : DATATYPE_CASE =
    24 struct
    25 struct
       
    26 
       
    27 datatype config = Error | Warning | Quiet;
    25 
    28 
    26 exception CASE_ERROR of string * int;
    29 exception CASE_ERROR of string * int;
    27 
    30 
    28 fun match_type thy pat ob = Sign.typ_match thy (pat, ob) Vartab.empty;
    31 fun match_type thy pat ob = Sign.typ_match thy (pat, ob) Vartab.empty;
    29 
    32 
   258           case_error (quote s ^ " occurs repeatedly in the pattern " ^
   261           case_error (quote s ^ " occurs repeatedly in the pattern " ^
   259             quote (Syntax.string_of_term ctxt pat))
   262             quote (Syntax.string_of_term ctxt pat))
   260         else x :: xs)
   263         else x :: xs)
   261     | _ => I) pat [];
   264     | _ => I) pat [];
   262 
   265 
   263 fun gen_make_case ty_match ty_inst type_of tab ctxt err used x clauses =
   266 fun gen_make_case ty_match ty_inst type_of tab ctxt config used x clauses =
   264   let
   267   let
   265     fun string_of_clause (pat, rhs) = Syntax.string_of_term ctxt
   268     fun string_of_clause (pat, rhs) = Syntax.string_of_term ctxt
   266       (Syntax.const "_case1" $ pat $ rhs);
   269       (Syntax.const "_case1" $ pat $ rhs);
   267     val _ = map (no_repeat_vars ctxt o fst) clauses;
   270     val _ = map (no_repeat_vars ctxt o fst) clauses;
   268     val rows = map_index (fn (i, (pat, rhs)) =>
   271     val rows = map_index (fn (i, (pat, rhs)) =>
   283     val patts2 = Library.sort (Library.int_ord o Library.pairself row_of_pat) patts1
   286     val patts2 = Library.sort (Library.int_ord o Library.pairself row_of_pat) patts1
   284     val finals = map row_of_pat patts2
   287     val finals = map row_of_pat patts2
   285     val originals = map (row_of_pat o #2) rows
   288     val originals = map (row_of_pat o #2) rows
   286     val _ = case originals \\ finals of
   289     val _ = case originals \\ finals of
   287         [] => ()
   290         [] => ()
   288       | is => (if err then case_error else warning)
   291         | is => (case config of Error => case_error | Warning => warning | Quiet => fn _ => {})
   289           ("The following clauses are redundant (covered by preceding clauses):\n" ^
   292           ("The following clauses are redundant (covered by preceding clauses):\n" ^
   290            cat_lines (map (string_of_clause o nth clauses) is));
   293            cat_lines (map (string_of_clause o nth clauses) is));
   291   in
   294   in
   292     (case_tm, patts2)
   295     (case_tm, patts2)
   293   end;
   296   end;
   336             end
   339             end
   337         | dest_case1 t = case_error "dest_case1";
   340         | dest_case1 t = case_error "dest_case1";
   338       fun dest_case2 (Const ("_case2", _) $ t $ u) = t :: dest_case2 u
   341       fun dest_case2 (Const ("_case2", _) $ t $ u) = t :: dest_case2 u
   339         | dest_case2 t = [t];
   342         | dest_case2 t = [t];
   340       val (cases, cnstrts) = split_list (map dest_case1 (dest_case2 u));
   343       val (cases, cnstrts) = split_list (map dest_case1 (dest_case2 u));
   341       val (case_tm, _) = make_case_untyped (tab_of thy) ctxt err []
   344       val (case_tm, _) = make_case_untyped (tab_of thy) ctxt
       
   345         (if err then Error else Warning) []
   342         (fold (fn tT => fn t => Syntax.const "_constrain" $ t $ tT)
   346         (fold (fn tT => fn t => Syntax.const "_constrain" $ t $ tT)
   343            (flat cnstrts) t) cases;
   347            (flat cnstrts) t) cases;
   344     in case_tm end
   348     in case_tm end
   345   | case_tr _ _ _ ts = case_error "case_tr";
   349   | case_tr _ _ _ ts = case_error "case_tr";
   346 
   350