src/HOL/Tools/Datatype/datatype_case.ML
changeset 43255 7df9edc6a2d6
parent 43254 2127c138ba3a
child 43256 375809f9afad
equal deleted inserted replaced
43254:2127c138ba3a 43255:7df9edc6a2d6
    30 
    30 
    31 exception CASE_ERROR of string * int;
    31 exception CASE_ERROR of string * int;
    32 
    32 
    33 fun match_type thy pat ob = Sign.typ_match thy (pat, ob) Vartab.empty;
    33 fun match_type thy pat ob = Sign.typ_match thy (pat, ob) Vartab.empty;
    34 
    34 
    35 (*---------------------------------------------------------------------------
    35 (* Get information about datatypes *)
    36  * Get information about datatypes
       
    37  *---------------------------------------------------------------------------*)
       
    38 
    36 
    39 fun ty_info tab sT =
    37 fun ty_info tab sT =
    40   (case tab sT of
    38   (case tab sT of
    41     SOME ({descr, case_name, index, sorts, ...} : info) =>
    39     SOME ({descr, case_name, index, sorts, ...} : info) =>
    42       let
    40       let
    49             Const (cname, Logic.varifyT_global (map mk_ty dts' ---> T))) constrs}
    47             Const (cname, Logic.varifyT_global (map mk_ty dts' ---> T))) constrs}
    50       end
    48       end
    51   | NONE => NONE);
    49   | NONE => NONE);
    52 
    50 
    53 
    51 
    54 (*---------------------------------------------------------------------------
    52 (*Each pattern carries with it a tag i, which denotes the clause it
    55  * Each pattern carries with it a tag i, which denotes
    53 came from. i = ~1 indicates that the clause was added by pattern
    56  * the clause it came from. i = ~1 indicates that
    54 completion.*)
    57  * the clause was added by pattern completion.
       
    58  *---------------------------------------------------------------------------*)
       
    59 
       
    60 fun pattern_subst theta (tm, x) = (subst_free theta tm, x);
       
    61 
    55 
    62 fun add_row_used ((prfx, pats), (tm, tag)) =
    56 fun add_row_used ((prfx, pats), (tm, tag)) =
    63   fold Term.add_free_names (tm :: pats @ map Free prfx);
    57   fold Term.add_free_names (tm :: pats @ map Free prfx);
    64 
    58 
    65 (* try to preserve names given by user *)
    59 (*try to preserve names given by user*)
    66 fun default_names names ts =
    60 fun default_names names ts =
    67   map (fn ("", Free (name', _)) => name' | (name, _) => name) (names ~~ ts);
    61   map (fn ("", Free (name', _)) => name' | (name, _) => name) (names ~~ ts);
    68 
    62 
    69 fun strip_constraints (Const (@{syntax_const "_constrain"}, _) $ t $ tT) =
    63 fun strip_constraints (Const (@{syntax_const "_constrain"}, _) $ t $ tT) =
    70       strip_constraints t ||> cons tT
    64       strip_constraints t ||> cons tT
    73 fun mk_fun_constrain tT t =
    67 fun mk_fun_constrain tT t =
    74   Syntax.const @{syntax_const "_constrain"} $ t $
    68   Syntax.const @{syntax_const "_constrain"} $ t $
    75     (Syntax.const @{type_syntax fun} $ tT $ Syntax.const @{type_syntax dummy});
    69     (Syntax.const @{type_syntax fun} $ tT $ Syntax.const @{type_syntax dummy});
    76 
    70 
    77 
    71 
    78 (*---------------------------------------------------------------------------
    72 (*Produce an instance of a constructor, plus fresh variables for its arguments.*)
    79  * Produce an instance of a constructor, plus genvars for its arguments.
       
    80  *---------------------------------------------------------------------------*)
       
    81 fun fresh_constr ty_match ty_inst colty used c =
    73 fun fresh_constr ty_match ty_inst colty used c =
    82   let
    74   let
    83     val (_, Ty) = dest_Const c
    75     val (_, Ty) = dest_Const c
    84     val Ts = binder_types Ty;
    76     val Ts = binder_types Ty;
    85     val names = Name.variant_list used
    77     val names = Name.variant_list used
    90     val c' = ty_inst ty_theta c
    82     val c' = ty_inst ty_theta c
    91     val gvars = map (ty_inst ty_theta o Free) (names ~~ Ts)
    83     val gvars = map (ty_inst ty_theta o Free) (names ~~ Ts)
    92   in (c', gvars) end;
    84   in (c', gvars) end;
    93 
    85 
    94 
    86 
    95 (*---------------------------------------------------------------------------
    87 (*Goes through a list of rows and picks out the ones beginning with a
    96  * Goes through a list of rows and picks out the ones beginning with a
    88  pattern with constructor = name.*)
    97  * pattern with constructor = name.
       
    98  *---------------------------------------------------------------------------*)
       
    99 fun mk_group (name, T) rows =
    89 fun mk_group (name, T) rows =
   100   let val k = length (binder_types T) in
    90   let val k = length (binder_types T) in
   101     fold (fn (row as ((prfx, p :: rst), rhs as (_, i))) =>
    91     fold (fn (row as ((prfx, p :: rst), rhs as (_, i))) =>
   102       fn ((in_group, not_in_group), (names, cnstrts)) =>
    92       fn ((in_group, not_in_group), (names, cnstrts)) =>
   103         (case strip_comb p of
    93         (case strip_comb p of
   114             else ((in_group, row :: not_in_group), (names, cnstrts))
   104             else ((in_group, row :: not_in_group), (names, cnstrts))
   115         | _ => raise CASE_ERROR ("Not a constructor pattern", i)))
   105         | _ => raise CASE_ERROR ("Not a constructor pattern", i)))
   116     rows (([], []), (replicate k "", replicate k [])) |>> pairself rev
   106     rows (([], []), (replicate k "", replicate k [])) |>> pairself rev
   117   end;
   107   end;
   118 
   108 
   119 (*---------------------------------------------------------------------------
   109 
   120  * Partition the rows. Not efficient: we should use hashing.
   110 (* Partitioning *)
   121  *---------------------------------------------------------------------------*)
   111 
   122 fun partition _ _ _ _ _ _ _ [] = raise CASE_ERROR ("partition: no rows", ~1)
   112 fun partition _ _ _ _ _ _ _ [] = raise CASE_ERROR ("partition: no rows", ~1)
   123   | partition ty_match ty_inst type_of used constructors colty res_ty
   113   | partition ty_match ty_inst type_of used constructors colty res_ty
   124         (rows as (((prfx, _ :: rstp), _) :: _)) =
   114         (rows as (((prfx, _ :: rstp), _) :: _)) =
   125       let
   115       let
   126         fun part {constrs = [], rows = [], A} = rev A
   116         fun part {constrs = [], rows = [], A} = rev A
   154                     constraints = cnstrts,
   144                     constraints = cnstrts,
   155                     group = in_group'} :: A}
   145                     group = in_group'} :: A}
   156               end
   146               end
   157       in part {constrs = constructors, rows = rows, A = []} end;
   147       in part {constrs = constructors, rows = rows, A = []} end;
   158 
   148 
   159 (*---------------------------------------------------------------------------
       
   160  * Misc. routines used in mk_case
       
   161  *---------------------------------------------------------------------------*)
       
   162 
       
   163 fun v_to_prfx (prfx, Free v::pats) = (v::prfx,pats)
   149 fun v_to_prfx (prfx, Free v::pats) = (v::prfx,pats)
   164   | v_to_prfx _ = raise CASE_ERROR ("mk_case: v_to_prfx", ~1);
   150   | v_to_prfx _ = raise CASE_ERROR ("mk_case: v_to_prfx", ~1);
   165 
   151 
   166 
   152 
   167 (*----------------------------------------------------------------------------
   153 (* Translation of pattern terms into nested case expressions. *)
   168  * Translation of pattern terms into nested case expressions.
   154  
   169  *
       
   170  * This performs the translation and also builds the full set of patterns.
       
   171  * Thus it supports the construction of induction theorems even when an
       
   172  * incomplete set of patterns is given.
       
   173  *---------------------------------------------------------------------------*)
       
   174 
       
   175 fun mk_case tab ctxt ty_match ty_inst type_of used range_ty =
   155 fun mk_case tab ctxt ty_match ty_inst type_of used range_ty =
   176   let
   156   let
   177     val name = Name.variant used "a";
   157     val name = Name.variant used "a";
   178     fun expand constructors used ty ((_, []), _) =
   158     fun expand constructors used ty ((_, []), _) =
   179           raise CASE_ERROR ("mk_case: expand_var_row", ~1)
   159           raise CASE_ERROR ("mk_case: expand_var_row", ~1)
   180       | expand constructors used ty (row as ((prfx, p :: rst), rhs)) =
   160       | expand constructors used ty (row as ((prfx, p :: rst), (rhs, tag))) =
   181           if is_Free p then
   161           if is_Free p then
   182             let
   162             let
   183               val used' = add_row_used row used;
   163               val used' = add_row_used row used;
   184               fun expnd c =
   164               fun expnd c =
   185                 let val capp =
   165                 let val capp =
   186                   list_comb (fresh_constr ty_match ty_inst ty used' c)
   166                   list_comb (fresh_constr ty_match ty_inst ty used' c)
   187                 in ((prfx, capp :: rst), pattern_subst [(p, capp)] rhs)
   167                 in ((prfx, capp :: rst), (subst_free [(p, capp)] rhs, tag))
   188                 end
   168                 end
   189             in map expnd constructors end
   169             in map expnd constructors end
   190           else [row]
   170           else [row]
   191     fun mk {rows = [], ...} = raise CASE_ERROR ("no rows", ~1)
   171     fun mk {rows = [], ...} = raise CASE_ERROR ("no rows", ~1)
   192       | mk {path = [], rows = ((prfx, []), (tm, tag)) :: _} =  (* Done *)
   172       | mk {path = [], rows = ((prfx, []), (tm, tag)) :: _} =  (* Done *)
   197           let val col0 = map (fn ((_, p :: _), (_, i)) => (p, i)) rows in
   177           let val col0 = map (fn ((_, p :: _), (_, i)) => (p, i)) rows in
   198             (case Option.map (apfst head_of) (find_first (not o is_Free o fst) col0) of
   178             (case Option.map (apfst head_of) (find_first (not o is_Free o fst) col0) of
   199               NONE =>
   179               NONE =>
   200                 let
   180                 let
   201                   val rows' = map (fn ((v, _), row) => row ||>
   181                   val rows' = map (fn ((v, _), row) => row ||>
   202                     pattern_subst [(v, u)] |>> v_to_prfx) (col0 ~~ rows);
   182                     apfst (subst_free [(v, u)]) |>> v_to_prfx) (col0 ~~ rows);
   203                 in mk {path = rstp, rows = rows'} end
   183                 in mk {path = rstp, rows = rows'} end
   204             | SOME (Const (cname, cT), i) =>
   184             | SOME (Const (cname, cT), i) =>
   205                 (case ty_info tab (cname, cT) of
   185                 (case ty_info tab (cname, cT) of
   206                   NONE => raise CASE_ERROR ("Not a datatype constructor: " ^ cname, i)
   186                   NONE => raise CASE_ERROR ("Not a datatype constructor: " ^ cname, i)
   207                 | SOME {case_name, constructors} =>
   187                 | SOME {case_name, constructors} =>
   232       | mk _ = raise CASE_ERROR ("Malformed row matrix", ~1)
   212       | mk _ = raise CASE_ERROR ("Malformed row matrix", ~1)
   233   in mk end;
   213   in mk end;
   234 
   214 
   235 fun case_error s = error ("Error in case expression:\n" ^ s);
   215 fun case_error s = error ("Error in case expression:\n" ^ s);
   236 
   216 
   237 (* Repeated variable occurrences in a pattern are not allowed. *)
   217 (*Repeated variable occurrences in a pattern are not allowed.*)
   238 fun no_repeat_vars ctxt pat = fold_aterms
   218 fun no_repeat_vars ctxt pat = fold_aterms
   239   (fn x as Free (s, _) => (fn xs =>
   219   (fn x as Free (s, _) => (fn xs =>
   240         if member op aconv xs x then
   220         if member op aconv xs x then
   241           case_error (quote s ^ " occurs repeatedly in the pattern " ^
   221           case_error (quote s ^ " occurs repeatedly in the pattern " ^
   242             quote (Syntax.string_of_term ctxt pat))
   222             quote (Syntax.string_of_term ctxt pat))
   322              (flat cnstrts) t) cases;
   302              (flat cnstrts) t) cases;
   323       in case_tm end
   303       in case_tm end
   324   | case_tr _ _ _ ts = case_error "case_tr";
   304   | case_tr _ _ _ ts = case_error "case_tr";
   325 
   305 
   326 
   306 
   327 (*---------------------------------------------------------------------------
   307 (* Pretty printing of nested case expressions *)
   328  * Pretty printing of nested case expressions
       
   329  *---------------------------------------------------------------------------*)
       
   330 
   308 
   331 (* destruct one level of pattern matching *)
   309 (* destruct one level of pattern matching *)
   332 
   310 
   333 fun gen_dest_case name_of type_of tab d used t =
   311 fun gen_dest_case name_of type_of tab d used t =
   334   (case apfst name_of (strip_comb t) of
   312   (case apfst name_of (strip_comb t) of