src/HOL/Tools/Datatype/datatype_case.ML
author wenzelm
Sat, 14 Jan 2012 21:16:15 +0100
changeset 46219 426ed18eba43
parent 46188 8297006abc13
child 49660 de49d9b4d7bc
permissions -rw-r--r--
discontinued old-style Term.list_abs in favour of plain Term.abs;

(*  Title:      HOL/Tools/Datatype/datatype_case.ML
    Author:     Konrad Slind, Cambridge University Computer Laboratory
    Author:     Stefan Berghofer, TU Muenchen

Datatype package: nested case expressions on datatypes.

TODO:
  * Avoid fragile operations on syntax trees (with type constraints
    getting in the way).  Instead work with auxiliary "destructor"
    constants in translations and introduce the actual case
    combinators in a separate term check phase (similar to term
    abbreviations).

  * Avoid hard-wiring to the datatype package.  Instead provide
    generic declarations of case splits based on an internal data slot.
*)

(*Interface of the translation between case syntax with nested patterns
  and applications of datatype case combinators.*)
signature DATATYPE_CASE =
sig
  (*reaction to redundant clauses in make_case: error, warning, or nothing*)
  datatype config = Error | Warning | Quiet
  type info = Datatype_Aux.info
  (*make_case ctxt config used x clauses: compile the (pattern, rhs) clauses
    matching x into nested case combinator applications; used lists names
    to avoid when inventing fresh variables*)
  val make_case :  Proof.context -> config -> string list -> term -> (term * term) list -> term
  (*strip_case ctxt d t: the inverse direction, recovering (scrutinee, clauses)
    from nested case combinator applications; d selects dummy patterns
    (rather than fresh variables) for synthesized default clauses*)
  val strip_case : Proof.context -> bool -> term -> (term * (term * term) list) option
  (*parse translation for case syntax; the flag turns redundant clauses into
    errors instead of warnings*)
  val case_tr: bool -> Proof.context -> term list -> term
  (*configuration option (default true): print case combinators as case syntax*)
  val show_cases: bool Config.T
  (*print translation for the given (marked) case combinator name*)
  val case_tr': string -> Proof.context -> term list -> term
  (*register case_tr' as print translation for the given case combinator names*)
  val add_case_tr' : string list -> theory -> theory
  (*install the parse translation for case syntax*)
  val setup: theory -> theory
end;

structure Datatype_Case : DATATYPE_CASE =
struct

(*how make_case reacts to redundant clauses*)
datatype config = Error | Warning | Quiet;
type info = Datatype_Aux.info;

(*internal failure during case compilation: a message together with the index
  of the responsible clause, or ~1 if no particular clause is to blame*)
exception CASE_ERROR of string * int;

(*Match the pattern type pat against the object type ob, starting from an
  empty type environment.  Failure surfaces as Type.TYPE_MATCH, which
  fresh_constr turns into a CASE_ERROR.*)
fun match_type thy pat ob =
  let val no_bindings = Vartab.empty
  in Sign.typ_match thy (pat, ob) no_bindings end;

(* Get information about datatypes *)

(*From the datatype info, extract the name of its case combinator and its
  constructor constants, whose types have schematic type variables
  (Logic.varifyT_global).*)
fun ty_info ({descr, case_name, index, ...} : info) =
  let
    val (_, (type_name, type_args, constr_specs)) = nth descr index;
    val dtyp_to_typ = Datatype_Aux.typ_of_dtyp descr;
    val dt_type = Type (type_name, map dtyp_to_typ type_args);
    fun mk_constr (cname, arg_dts) =
      Const (cname, Logic.varifyT_global (map dtyp_to_typ arg_dts ---> dt_type));
  in {case_name = case_name, constructors = map mk_constr constr_specs} end;


(*Each pattern carries with it a tag i, which denotes the clause it
came from. i = ~1 indicates that the clause was added by pattern
completion.*)

(*Add the names of all free variables of a row (its right-hand side, its
  remaining patterns and its prefix variables) to a list of used names.*)
fun add_row_used ((prfx, pats), (tm, _)) =
  let val row_terms = tm :: (pats @ map Free prfx)
  in fold Term.add_free_names row_terms end;

(*Pick the binder name for a constructor argument.  A non-empty name that was
  already chosen is kept; otherwise the name of a variable pattern t is used
  (or "" if t is no variable).  Position constraints in cs are kept only when
  t is a variable.*)
fun default_name name (t, cs) =
  let
    val chosen =
      if name <> "" then name
      else (case t of Free (x, _) => x | _ => name);
    val kept = if is_Free t then cs else filter_out Term_Position.is_position cs;
  in (chosen, kept) end;

(*Remove nested "_constrain" wrappers from t, returning the bare term and the
  list of constraint arguments, outermost first.*)
fun strip_constraints t =
  (case t of
    Const (@{syntax_const "_constrain"}, _) $ body $ tT =>
      let val (bare, tTs) = strip_constraints body
      in (bare, tT :: tTs) end
  | _ => (t, []));

(*Wrap t in a "_constrain" (resp. "_constrainAbs") syntax node carrying the
  constraint tT.*)
fun constrain tT t =
  let val c = Syntax.const @{syntax_const "_constrain"} in c $ t $ tT end;
fun constrain_Abs tT t =
  let val c = Syntax.const @{syntax_const "_constrainAbs"} in c $ t $ tT end;


(*Produce an instance of a constructor, plus fresh variables for its arguments.*)
(*Produce an instance of constructor c whose result type is matched against
  the column type colty, together with fresh argument variables whose names
  avoid used.  A type mismatch is reported as CASE_ERROR without a clause
  index.*)
fun fresh_constr ty_match ty_inst colty used c =
  let
    val T = snd (dest_Const c);
    val arg_Ts = binder_types T;
    val arg_names =
      Datatype_Prop.make_tnames (map Logic.unvarifyT_global arg_Ts)
      |> Name.variant_list used;
    val theta = ty_match (body_type T) colty
      handle Type.TYPE_MATCH => raise CASE_ERROR ("type mismatch", ~1);
    val inst = ty_inst theta;
  in (inst c, map (inst o Free) (arg_names ~~ arg_Ts)) end;

(*Like strip_comb, but a "_constrain" node is treated as an atomic head, and
  position information is removed from the head term.*)
fun strip_comb_positions tm =
  let
    fun is_constraint (Const (@{syntax_const "_constrain"}, _) $ _ $ _) = true
      | is_constraint _ = false;
    fun strip (t as f $ x) args =
          if is_constraint t then (t, args) else strip f (x :: args)
      | strip t args = (t, args);
  in strip tm [] |>> Term_Position.strip_positions end;

(*Go through a list of rows and pick out the ones beginning with a
  pattern with constructor = name.*)
(*mk_group (name, T) rows returns ((in_group, not_in_group), (names, cnstrts)):
  - in_group: rows whose first pattern is constructor name applied to exactly
    k arguments (k = number of argument types of T), with that pattern
    replaced by its argument patterns (outer "_constrain" wrappers removed);
  - not_in_group: all other rows, unchanged;
  both lists keep the original row order (rows are consed onto accumulators
  during the fold and reversed at the end).
  names and cnstrts have one entry per constructor argument: the binder name
  chosen by default_name (first non-empty variable name seen, else ""), and
  the "_constrain" arguments collected from all grouped rows.
  Raises CASE_ERROR for a wrong number of constructor arguments or a first
  pattern whose head is not a constant.  Assumes every row still has at least
  one pattern (NOTE(review): an empty pattern list would raise Match).*)
fun mk_group (name, T) rows =
  let val k = length (binder_types T) in
    fold (fn (row as ((prfx, p :: ps), rhs as (_, i))) =>
      fn ((in_group, not_in_group), (names, cnstrts)) =>
        (case strip_comb_positions p of
          (Const (name', _), args) =>
            if name = name' then
              if length args = k then
                let
                  val constraints' = map strip_constraints args;
                  (*NOTE(review): cnstrts' is unused; default_name below
                    re-derives the per-argument constraints*)
                  val (args', cnstrts') = split_list constraints';
                  val (names', cnstrts'') = split_list (map2 default_name names constraints');
                in
                  ((((prfx, args' @ ps), rhs) :: in_group, not_in_group),
                   (names', map2 append cnstrts cnstrts''))
                end
              else raise CASE_ERROR ("Wrong number of arguments for constructor " ^ quote name, i)
            else ((in_group, row :: not_in_group), (names, cnstrts))
        | _ => raise CASE_ERROR ("Not a constructor pattern", i)))
    rows (([], []), (replicate k "", replicate k [])) |>> pairself rev
  end;


(* Partitioning *)

(*Split the rows into one subproblem per element of constructors, in that
  order.  Each subproblem is a record with
  - constructor: the constructor instantiated to the column type colty,
  - new_formals: fresh variables for its arguments,
  - names, constraints: per-argument binder names and constraints (mk_group),
  - group: the rows starting with that constructor or, if there are none, a
    single completion row (fresh variables for the new and remaining columns,
    rhs "undefined" of type res_ty, tag ~1).
  Rows whose first pattern matches none of the constructors raise CASE_ERROR.*)
fun partition _ _ _ _ _ _ _ [] = raise CASE_ERROR ("partition: no rows", ~1)
  | partition ty_match ty_inst type_of used constructors colty res_ty
        (rows as (((prfx, _ :: ps), _) :: _)) =
      let
        fun part [] [] = []
          (*rows left after all constructors were tried are not constructor patterns*)
          | part [] ((_, (_, i)) :: _) = raise CASE_ERROR ("Not a constructor pattern", i)
          | part (c :: cs) rows =
              let
                val ((in_group, not_in_group), (names, cnstrts)) = mk_group (dest_Const c) rows;
                val used' = fold add_row_used in_group used;
                val (c', gvars) = fresh_constr ty_match ty_inst colty used' c;
                val in_group' =
                  if null in_group  (* Constructor not given *)
                  then
                    (*pattern completion: a row matching this constructor with
                      fresh variables, mapped to undefined*)
                    let
                      val Ts = map type_of ps;
                      val xs =
                        Name.variant_list
                          (fold Term.add_free_names gvars used')
                          (replicate (length ps) "x");
                    in
                      [((prfx, gvars @ map Free (xs ~~ Ts)),
                        (Const (@{const_syntax undefined}, res_ty), ~1))]
                    end
                  else in_group;
              in
                {constructor = c',
                 new_formals = gvars,
                 names = names,
                 constraints = cnstrts,
                 group = in_group'} :: part cs not_in_group
              end;
      in part constructors rows end;

(*Move the leading variable pattern of a row into the row's prefix.*)
fun v_to_prfx (prfx, pats) =
  (case pats of
    Free v :: rest => (v :: prfx, rest)
  | _ => raise CASE_ERROR ("mk_case: v_to_prfx", ~1));


(* Translation of pattern terms into nested case expressions. *)

(*mk_case ctxt ty_match ty_inst type_of used range_ty path rows compiles a
  pattern matrix into nested case combinator applications.  path holds the
  terms matched by the pattern columns; each row is ((prfx, pats), (rhs, tag)).
  The result is (tags, tree): tree is the compiled term of type range_ty and
  tags are the tags of the rows that contribute to it.  used lists names to
  avoid for fresh variables.*)
fun mk_case ctxt ty_match ty_inst type_of used range_ty =
  let
    val get_info = Datatype_Data.info_of_constr_permissive (Proof_Context.theory_of ctxt);

    (*If a row starts with a variable pattern p, replace it by one row per
      constructor, instantiating p (also in the rhs) with that constructor
      applied to fresh variables; rows starting otherwise are kept.*)
    fun expand constructors used ty ((_, []), _) = raise CASE_ERROR ("mk_case: expand", ~1)
      | expand constructors used ty (row as ((prfx, p :: ps), (rhs, tag))) =
          if is_Free p then
            let
              val used' = add_row_used row used;
              fun expnd c =
                let val capp = list_comb (fresh_constr ty_match ty_inst ty used' c)
                in ((prfx, capp :: ps), (subst_free [(p, capp)] rhs, tag)) end;
            in map expnd constructors end
          else [row];

    (*fallback binder name for new variables without a user-given name*)
    val name = singleton (Name.variant_list used) "a";

    fun mk _ [] = raise CASE_ERROR ("no rows", ~1)
      | mk [] (((_, []), (tm, tag)) :: _) = ([tag], tm) (* Done *)
      (*a first row consisting of a single variable pattern matches everything,
        so the rows after it are dropped*)
      | mk path (rows as ((row as ((_, [Free _]), _)) :: _ :: _)) = mk path [row]
      | mk (u :: us) (rows as ((_, _ :: _), _) :: _) =
          let val col0 = map (fn ((_, p :: _), (_, i)) => (p, i)) rows in
            (*look for a non-variable pattern in the first column*)
            (case Option.map (apfst (fst o strip_comb_positions))
                (find_first (not o is_Free o fst) col0) of
              NONE =>
                (*only variables: substitute u for them and move them to the prefix*)
                let
                  val rows' = map (fn ((v, _), row) => row ||>
                    apfst (subst_free [(v, u)]) |>> v_to_prfx) (col0 ~~ rows);
                in mk us rows' end
            | SOME (Const (cname, cT), i) =>
                (*constructor pattern: expand variable rows over all constructors,
                  split by constructor, solve the subproblems recursively, and
                  abstract each result over the constructor's new variables to
                  form the arguments of the case combinator*)
                (case Option.map ty_info (get_info (cname, cT)) of
                  NONE => raise CASE_ERROR ("Not a datatype constructor: " ^ quote cname, i)
                | SOME {case_name, constructors} =>
                    let
                      val pty = body_type cT;
                      val used' = fold Term.add_free_names us used;
                      val nrows = maps (expand constructors used' pty) rows;
                      val subproblems =
                        partition ty_match ty_inst type_of used'
                          constructors pty range_ty nrows;
                      val (pat_rect, dtrees) =
                        split_list (map (fn {new_formals, group, ...} =>
                          mk (new_formals @ us) group) subproblems);
                      val case_functions =
                        map2 (fn {new_formals, names, constraints, ...} =>
                          fold_rev (fn ((x as Free (_, T), s), cnstrts) => fn t =>
                            Abs (if s = "" then name else s, T, abstract_over (x, t))
                            |> fold constrain_Abs cnstrts) (new_formals ~~ names ~~ constraints))
                        subproblems dtrees;
                      val types = map type_of (case_functions @ [u]);
                      val case_const = Const (case_name, types ---> range_ty);
                      val tree = list_comb (case_const, case_functions @ [u]);
                    in (flat pat_rect, tree) end)
            | SOME (t, i) =>
                raise CASE_ERROR ("Not a datatype constructor: " ^ Syntax.string_of_term ctxt t, i))
          end
      | mk _ _ = raise CASE_ERROR ("Malformed row matrix", ~1)
  in mk end;

(*Abort with a user-level error about a case expression.*)
fun case_error msg =
  let val header = "Error in case expression:\n"
  in error (header ^ msg) end;

local

(*Repeated variable occurrences in a pattern are not allowed.
  Fails via case_error if a free variable occurs more than once in pat
  (position information is stripped first); otherwise returns the collected
  variable occurrences, which the caller ignores.*)
fun no_repeat_vars ctxt pat = fold_aterms
  (fn x as Free (s, _) =>
      (fn xs =>
        if member op aconv xs x then
          case_error (quote s ^ " occurs repeatedly in the pattern " ^
            quote (Syntax.string_of_term ctxt pat))
        else x :: xs)
    | _ => I) (Term_Position.strip_positions pat) [];

(*Shared implementation of make_case and make_case_untyped; ty_match, ty_inst
  and type_of abstract over type handling.  Compiles the clauses
  [(pat, rhs), ...] for scrutinee x into a case term, after checking for
  repeated pattern variables and a common result type.  Clauses not used by
  the compiled term are reported as redundant according to config.*)
fun gen_make_case ty_match ty_inst type_of ctxt config used x clauses =
  let
    fun string_of_clause (pat, rhs) =
      Syntax.string_of_term ctxt (Syntax.const @{syntax_const "_case1"} $ pat $ rhs);
    val _ = map (no_repeat_vars ctxt o fst) clauses;
    (*one single-column row per clause, tagged with the clause index*)
    val rows = map_index (fn (i, (pat, rhs)) => (([], [pat]), (rhs, i))) clauses;
    val rangeT =
      (case distinct (op =) (map (type_of o snd) clauses) of
        [] => case_error "no clauses given"
      | [T] => T
      | _ => case_error "all cases must have the same result type");
    val used' = fold add_row_used rows used;
    (*a CASE_ERROR with index i >= 0 is reported together with clause i*)
    val (tags, case_tm) =
      mk_case ctxt ty_match ty_inst type_of used' rangeT [x] rows
        handle CASE_ERROR (msg, i) =>
          case_error
            (msg ^ (if i < 0 then "" else "\nIn clause\n" ^ string_of_clause (nth clauses i)));
    (*clauses whose tag does not occur in the compiled term are redundant*)
    val _ =
      (case subtract (op =) tags (map (snd o snd) rows) of
        [] => ()
      | is =>
          (case config of Error => case_error | Warning => warning | Quiet => fn _ => ())
            ("The following clauses are redundant (covered by preceding clauses):\n" ^
              cat_lines (map (string_of_clause o nth clauses) is)));
  in
    case_tm
  end;

in

(*typed version: types are matched in the current theory and instantiated in terms*)
fun make_case ctxt =
  gen_make_case (match_type (Proof_Context.theory_of ctxt))
    Envir.subst_term_types fastype_of ctxt;

(*untyped version for parse translations: no type matching, all types dummyT*)
val make_case_untyped =
  gen_make_case (K (K Vartab.empty)) (K (Term.map_types (K dummyT))) (K dummyT);

end;


(* parse translation *)

(*Parse translation for "_case_syntax" with arguments [t, u]: t is the
  scrutinee, u the clause list built from "_case1" (pattern => rhs) and
  "_case2" (clause separator) nodes.  The clauses are compiled with
  make_case_untyped; redundant clauses are errors if err is true, warnings
  otherwise.  Non-position constraints found at the top of the clause
  patterns are applied to the scrutinee t instead.*)
fun case_tr err ctxt [t, u] =
      let
        val thy = Proof_Context.theory_of ctxt;
        val intern_const_syntax = Consts.intern_syntax (Proof_Context.consts_of ctxt);

        (* replace occurrences of dummy_pattern by distinct variables *)
        (* internalize constant names                                 *)
        (* FIXME proper name context!? *)
        fun prep_pat ((c as Const (@{syntax_const "_constrain"}, _)) $ t $ tT) used =
              let val (t', used') = prep_pat t used
              in (c $ t' $ tT, used') end
          | prep_pat (Const (@{const_syntax dummy_pattern}, T)) used =
              let val x = singleton (Name.variant_list used) "x"
              in (Free (x, T), x :: used) end
          | prep_pat (Const (s, T)) used = (Const (intern_const_syntax s, T), used)
          (*a variable whose name denotes a declared constant is that constant*)
          | prep_pat (v as Free (s, T)) used =
              let val s' = Proof_Context.intern_const ctxt s in
                if Sign.declared_const thy s' then (Const (s', T), used)
                else (v, used)
              end
          | prep_pat (t $ u) used =
              let
                val (t', used') = prep_pat t used;
                val (u', used'') = prep_pat u used';
              in (t' $ u', used'') end
          | prep_pat t used = case_error ("Bad pattern: " ^ Syntax.string_of_term ctxt t);

        (*one clause: prepared pattern, rhs, and the pattern's outer constraints;
          fresh names for dummy patterns avoid all free names of the clause*)
        fun dest_case1 (t as Const (@{syntax_const "_case1"}, _) $ l $ r) =
              let val (l', cnstrts) = strip_constraints l
              in ((fst (prep_pat l' (Term.add_free_names t [])), r), cnstrts) end
          | dest_case1 t = case_error "dest_case1";

        (*flatten the right-nested "_case2" clause list*)
        fun dest_case2 (Const (@{syntax_const "_case2"}, _) $ t $ u) = t :: dest_case2 u
          | dest_case2 t = [t];

        val (cases, cnstrts) = split_list (map dest_case1 (dest_case2 u));
      in
        make_case_untyped ctxt
          (if err then Error else Warning) []
          (fold constrain (filter_out Term_Position.is_position (flat cnstrts)) t)
          cases
      end
  | case_tr _ _ _ = case_error "case_tr";

(*Install case_tr (redundant clauses are errors) as the parse translation
  of the "_case_syntax" syntax constant.*)
val trfun_setup =
  let val parse_trs = [(@{syntax_const "_case_syntax"}, case_tr true)]
  in Sign.add_advanced_trfuns ([], parse_trs, [], []) end;


(* Pretty printing of nested case expressions *)

(* destruct one level of pattern matching *)

local

(*Destruct one level of pattern matching.
  gen_dest_case name_of type_of ctxt d used t: if t is an application
  c f_1 ... f_n x of the case combinator c of a datatype (head recognized via
  name_of and Datatype_Data.info_of_case) with exactly one branch function per
  constructor, return SOME (x, clauses) with (pattern, rhs) clauses,
  compressed as follows:
  - if some branch that ignores its constructor arguments has body
    "undefined", all clauses with body "undefined" are dropped (or, if every
    constructor has that body, only the first clause is kept);
  - otherwise the most frequent body shared by several argument-independent
    branches becomes a trailing default clause replacing them (with only the
    first clause kept if all constructors share that body).
  The pattern of a default clause is a dummy pattern if d is true, otherwise
  a variable named apart from used.  type_of computes the types of terms (or
  yields dummyT on syntax trees).  Returns NONE for any other term.*)
fun gen_dest_case name_of type_of ctxt d used t =
  (case apfst name_of (strip_comb t) of
    (SOME cname, ts as _ :: _) =>
      let
        val (fs, x) = split_last ts;
        (*Split branch t of a constructor with argument types Us (i = length Us)
          into (xs, body): xs are i variables (names avoiding used and the
          body), body is t applied to them.  Missing abstractions are made up
          by variables named "x" (eta-expansion).*)
        fun strip_abs i Us t =
          let
            val zs = strip_abs_vars t;
            val j = length zs;
            val (xs, ys) =
              if j < i then (zs @ map (pair "x") (drop j Us), [])
              else chop i zs;
            val u = fold_rev Term.abs ys (strip_abs_body t);
            val xs' = map Free
              ((fold_map Name.variant (map fst xs)
                  (Term.declare_term_names u used) |> fst) ~~
               map snd xs);
            val (xs1, xs2) = chop j xs'
          in (xs', list_comb (subst_bounds (rev xs1, u), xs2)) end;
        (*does branch t refer to one of its first i bound variables (the
          constructor arguments), or have fewer than i abstractions?*)
        fun is_dependent i t =
          let val k = length (strip_abs_vars t) - i
          in k < 0 orelse exists (fn j => j >= k) (loose_bnos (strip_abs_body t)) end;
        (*group argument-independent branches by their (identical) bodies*)
        fun count_cases (_, _, true) = I
          | count_cases (c, (_, body), false) = AList.map_default op aconv (body, []) (cons c);
        val is_undefined = name_of #> equal (SOME @{const_name undefined});
        fun mk_case (c, (xs, body), _) = (list_comb (c, xs), body);
        val get_info = Datatype_Data.info_of_case (Proof_Context.theory_of ctxt);
      in
        (case Option.map ty_info (get_info cname) of
          SOME {constructors, ...} =>
            if length fs = length constructors then
              let
                val cases = map (fn (Const (s, U), t) =>
                  let
                    val Us = binder_types U;
                    val k = length Us;
                    val p as (xs, _) = strip_abs k Us t;
                  in
                    (Const (s, map type_of xs ---> type_of x), p, is_dependent k t)
                  end) (constructors ~~ fs);
                (*body groups, largest first*)
                val cases' =
                  sort (int_ord o swap o pairself (length o snd))
                    (fold_rev count_cases cases []);
                (*The default clause's pattern stands for the scrutinee x, so it
                  must have the type of x, not that of the whole case
                  expression t (its result type).*)
                val R = type_of x;
                val dummy =
                  if d then Term.dummy_pattern R
                  else Free (Name.variant "x" used |> fst, R);
              in
                SOME (x,
                  map mk_case
                    (case find_first (is_undefined o fst) cases' of
                      SOME (_, cs) =>
                        if length cs = length constructors then [hd cases]
                        else filter_out (fn (_, (_, body), _) => is_undefined body) cases
                    | NONE =>
                        (case cases' of
                          [] => cases
                        | (default, cs) :: _ =>
                            if length cs = 1 then cases
                            else if length cs = length constructors then
                              [hd cases, (dummy, ([], default), false)]
                            else
                              filter_out (fn (c, _, _) => member op aconv cs c) cases @
                                [(dummy, ([], default), false)])))
              end
            else NONE
        | _ => NONE)
      end
  | _ => NONE);

in

(*on logical terms: real constant names and types*)
val dest_case = gen_dest_case (try (dest_Const #> fst)) fastype_of;
(*on print-translation syntax: marked constant names, no types*)
val dest_case' = gen_dest_case (try (dest_Const #> fst #> Lexicon.unmark_const)) (K dummyT);

end;


(* destruct nested patterns *)

local

(*Flatten a clause whose right-hand side is itself a case distinction on a
  variable occurring in the pattern: for each nested clause, substitute its
  pattern for that variable and continue recursively.  This happens only if
  the variable occurs in none of the nested right-hand sides; otherwise the
  clause is kept as it is.*)
fun strip_case'' dest (pat, rhs) =
  let
    fun keep () = [(pat, rhs)];
  in
    (case dest (Term.declare_term_frees pat Name.context) rhs of
      SOME (exp as Free _, clauses) =>
        let
          fun occurs_in u = Term.exists_subterm (fn s => exp aconv s) u;
          fun flatten (pat', rhs') =
            strip_case'' dest (subst_free [(exp, pat')] pat, rhs');
        in
          if occurs_in pat andalso not (exists (occurs_in o snd) clauses)
          then maps flatten clauses
          else keep ()
        end
    | _ => keep ())
  end;

(*Destruct the outermost case distinction of t and flatten the nested ones.*)
fun gen_strip_case dest t =
  dest Name.context t
  |> Option.map (fn (x, clauses) => (x, maps (strip_case'' dest) clauses));

in

fun strip_case ctxt d = gen_strip_case (dest_case ctxt d);
fun strip_case' ctxt d = gen_strip_case (dest_case' ctxt d);

end;


(* print translation *)

(*configuration option "show_cases" (default: true); when it is false,
  case_tr' raises Match, i.e. declines to print case combinators as case
  syntax*)
val show_cases = Attrib.setup_config_bool @{binding show_cases} (K true);

(*Print translation for the (marked) case combinator cname applied to ts:
  rebuild "_case_syntax" with one "_case1" node per clause, joined by
  "_case2".  Pattern variables are marked as bound, both in the pattern and
  in the right-hand side.  Raises Match (no translation) if show_cases is
  disabled or the term cannot be destructed.*)
fun case_tr' cname ctxt ts =
  let
    fun mark_pat (Free p) = Syntax_Trans.mark_boundT p
      | mark_pat (Const (s, _)) = Syntax.const (Lexicon.mark_const s)
      | mark_pat t = t;
    fun clause_of (pat, rhs) =
      let
        val pat_vars = Term.add_frees pat [];
        fun mark_rhs (v as Free (s, T)) =
              if member (op =) pat_vars (s, T) then Syntax_Trans.mark_bound s else v
          | mark_rhs t = t;
      in
        Syntax.const @{syntax_const "_case1"} $ map_aterms mark_pat pat $
          map_aterms mark_rhs rhs
      end;
    fun join (t, u) = Syntax.const @{syntax_const "_case2"} $ t $ u;
  in
    if not (Config.get ctxt show_cases) then raise Match
    else
      (case strip_case' ctxt true (list_comb (Syntax.const cname, ts)) of
        NONE => raise Match
      | SOME (x, clauses) =>
          Syntax.const @{syntax_const "_case_syntax"} $ x $
            foldr1 join (map clause_of clauses))
  end;

(*Register case_tr' as print translation for each given case combinator,
  keyed by its marked constant name.*)
fun add_case_tr' case_names thy =
  let
    fun print_tr case_name =
      let val marked = Lexicon.mark_const case_name
      in (marked, case_tr' marked) end;
  in Sign.add_advanced_trfuns ([], [], map print_tr case_names, []) thy end;


(* theory setup *)

(*theory setup: installs the parse translation for case syntax; print
  translations are registered separately via add_case_tr'*)
val setup = trfun_setup;

end;