src/HOL/Tools/Datatype/datatype_case.ML
changeset 46139 df2aad3f0ecf
parent 46135 6bff2ebaf7bb
child 46140 463b594e186a
equal deleted inserted replaced
46138:85f8d8a8c711 46139:df2aad3f0ecf
    65 
    65 
    66 fun strip_constraints (Const (@{syntax_const "_constrain"}, _) $ t $ tT) =
    66 fun strip_constraints (Const (@{syntax_const "_constrain"}, _) $ t $ tT) =
    67       strip_constraints t ||> cons tT
    67       strip_constraints t ||> cons tT
    68   | strip_constraints t = (t, []);
    68   | strip_constraints t = (t, []);
    69 
    69 
    70 val recover_constraints =
    70 fun constrain tT t = Syntax.const @{syntax_const "_constrain"} $ t $ tT;
    71   fold (fn tT => fn t => Syntax.const @{syntax_const "_constrain"} $ t $ tT);
       
    72 
       
    73 fun constrain_Abs tT t = Syntax.const @{syntax_const "_constrainAbs"} $ t $ tT;
    71 fun constrain_Abs tT t = Syntax.const @{syntax_const "_constrainAbs"} $ t $ tT;
    74 
    72 
    75 
    73 
    76 (*Produce an instance of a constructor, plus fresh variables for its arguments.*)
    74 (*Produce an instance of a constructor, plus fresh variables for its arguments.*)
    77 fun fresh_constr ty_match ty_inst colty used c =
    75 fun fresh_constr ty_match ty_inst colty used c =
    88   in (c', gvars) end;
    86   in (c', gvars) end;
    89 
    87 
    90 fun strip_comb_positions tm =
    88 fun strip_comb_positions tm =
    91   let
    89   let
    92     fun result t ts = (Term_Position.strip_positions t, ts);
    90     fun result t ts = (Term_Position.strip_positions t, ts);
    93     fun strip (t as (Const (@{syntax_const "_constrain"}, _) $ _ $ _)) ts = result t ts
    91     fun strip (t as Const (@{syntax_const "_constrain"}, _) $ _ $ _) ts = result t ts
    94       | strip (f $ t) ts = strip f (t :: ts)
    92       | strip (f $ t) ts = strip f (t :: ts)
    95       | strip t ts = result t ts;
    93       | strip t ts = result t ts;
    96   in strip tm [] end;
    94   in strip tm [] end;
    97 
    95 
    98 (*Go through a list of rows and pick out the ones beginning with a
    96 (*Go through a list of rows and pick out the ones beginning with a
   107               if length args = k then
   105               if length args = k then
   108                 let val (args', cnstrts') = split_list (map strip_constraints args) in
   106                 let val (args', cnstrts') = split_list (map strip_constraints args) in
   109                   ((((prfx, args' @ ps), rhs) :: in_group, not_in_group),
   107                   ((((prfx, args' @ ps), rhs) :: in_group, not_in_group),
   110                    (default_names names args', map2 append cnstrts cnstrts'))
   108                    (default_names names args', map2 append cnstrts cnstrts'))
   111                 end
   109                 end
   112               else raise CASE_ERROR ("Wrong number of arguments for constructor " ^ name, i)
   110               else raise CASE_ERROR ("Wrong number of arguments for constructor " ^ quote name, i)
   113             else ((in_group, row :: not_in_group), (names, cnstrts))
   111             else ((in_group, row :: not_in_group), (names, cnstrts))
   114         | _ => raise CASE_ERROR ("Not a constructor pattern", i)))
   112         | _ => raise CASE_ERROR ("Not a constructor pattern", i)))
   115     rows (([], []), (replicate k "", replicate k [])) |>> pairself rev
   113     rows (([], []), (replicate k "", replicate k [])) |>> pairself rev
   116   end;
   114   end;
   117 
   115 
   140                           (replicate (length ps) "x");
   138                           (replicate (length ps) "x");
   141                     in
   139                     in
   142                       [((prfx, gvars @ map Free (xs ~~ Ts)),
   140                       [((prfx, gvars @ map Free (xs ~~ Ts)),
   143                         (Const (@{const_syntax undefined}, res_ty), ~1))]
   141                         (Const (@{const_syntax undefined}, res_ty), ~1))]
   144                     end
   142                     end
   145                   else in_group
   143                   else in_group;
   146               in
   144               in
   147                 {constructor = c',
   145                 {constructor = c',
   148                  new_formals = gvars,
   146                  new_formals = gvars,
   149                  names = names,
   147                  names = names,
   150                  constraints = cnstrts,
   148                  constraints = cnstrts,
   160 
   158 
   161 fun mk_case ctxt ty_match ty_inst type_of used range_ty =
   159 fun mk_case ctxt ty_match ty_inst type_of used range_ty =
   162   let
   160   let
   163     val get_info = Datatype_Data.info_of_constr_permissive (Proof_Context.theory_of ctxt);
   161     val get_info = Datatype_Data.info_of_constr_permissive (Proof_Context.theory_of ctxt);
   164 
   162 
   165     fun expand constructors used ty ((_, []), _) = raise CASE_ERROR ("mk_case: expand_var_row", ~1)
   163     fun expand constructors used ty ((_, []), _) = raise CASE_ERROR ("mk_case: expand", ~1)
   166       | expand constructors used ty (row as ((prfx, p :: ps), (rhs, tag))) =
   164       | expand constructors used ty (row as ((prfx, p :: ps), (rhs, tag))) =
   167           if is_Free p then
   165           if is_Free p then
   168             let
   166             let
   169               val used' = add_row_used row used;
   167               val used' = add_row_used row used;
   170               fun expnd c =
   168               fun expnd c =
   187                   val rows' = map (fn ((v, _), row) => row ||>
   185                   val rows' = map (fn ((v, _), row) => row ||>
   188                     apfst (subst_free [(v, u)]) |>> v_to_prfx) (col0 ~~ rows);
   186                     apfst (subst_free [(v, u)]) |>> v_to_prfx) (col0 ~~ rows);
   189                 in mk us rows' end
   187                 in mk us rows' end
   190             | SOME (Const (cname, cT), i) =>
   188             | SOME (Const (cname, cT), i) =>
   191                 (case Option.map ty_info (get_info (cname, cT)) of
   189                 (case Option.map ty_info (get_info (cname, cT)) of
   192                   NONE => raise CASE_ERROR ("Not a datatype constructor: " ^ cname, i)
   190                   NONE => raise CASE_ERROR ("Not a datatype constructor: " ^ quote cname, i)
   193                 | SOME {case_name, constructors} =>
   191                 | SOME {case_name, constructors} =>
   194                     let
   192                     let
   195                       val pty = body_type cT;
   193                       val pty = body_type cT;
   196                       val used' = fold Term.add_free_names us used;
   194                       val used' = fold Term.add_free_names us used;
   197                       val nrows = maps (expand constructors used' pty) rows;
   195                       val nrows = maps (expand constructors used' pty) rows;
   310 
   308 
   311         val (cases, cnstrts) = split_list (map dest_case1 (dest_case2 u));
   309         val (cases, cnstrts) = split_list (map dest_case1 (dest_case2 u));
   312       in
   310       in
   313         make_case_untyped ctxt
   311         make_case_untyped ctxt
   314           (if err then Error else Warning) []
   312           (if err then Error else Warning) []
   315           (recover_constraints (filter_out Term_Position.is_position (flat cnstrts)) t)
   313           (fold constrain (filter_out Term_Position.is_position (flat cnstrts)) t)
   316           cases
   314           cases
   317       end
   315       end
   318   | case_tr _ _ _ = case_error "case_tr";
   316   | case_tr _ _ _ = case_error "case_tr";
   319 
   317 
   320 val trfun_setup =
   318 val trfun_setup =