src/HOL/Tools/datatype_package.ML
changeset 22777 2fc921376a86
parent 22709 9ab51bac6287
child 22846 fb79144af9a3
     1.1 --- a/src/HOL/Tools/datatype_package.ML	Tue Apr 24 15:07:27 2007 +0200
     1.2 +++ b/src/HOL/Tools/datatype_package.ML	Tue Apr 24 15:11:07 2007 +0200
     1.3 @@ -65,9 +65,15 @@
     1.4    val get_datatypes : theory -> DatatypeAux.datatype_info Symtab.table
     1.5    val get_datatype : theory -> string -> DatatypeAux.datatype_info option
     1.6    val the_datatype : theory -> string -> DatatypeAux.datatype_info
     1.7 +  val datatype_of_constr : theory -> string -> DatatypeAux.datatype_info option
     1.8 +  val datatype_of_case : theory -> string -> DatatypeAux.datatype_info option
     1.9    val get_datatype_spec : theory -> string -> ((string * sort) list * (string * typ list) list) option
    1.10    val get_datatype_constrs : theory -> string -> (string * typ) list option
    1.11    val print_datatypes : theory -> unit
    1.12 +  val make_case :  Proof.context -> bool -> string list -> term ->
    1.13 +    (term * term) list -> term * (term * (int * bool)) list
    1.14 +  val strip_case: Proof.context -> bool ->
    1.15 +    term -> (term * (term * term) list) option
    1.16    val setup: theory -> theory
    1.17  end;
    1.18  
    1.19 @@ -84,31 +90,53 @@
    1.20  structure DatatypesData = TheoryDataFun
    1.21  (struct
    1.22    val name = "HOL/datatypes";
    1.23 -  type T = datatype_info Symtab.table;
    1.24 +  type T =
    1.25 +    {types: datatype_info Symtab.table,
    1.26 +     constrs: datatype_info Symtab.table,
    1.27 +     cases: datatype_info Symtab.table};
    1.28  
    1.29 -  val empty = Symtab.empty;
    1.30 +  val empty =
    1.31 +    {types = Symtab.empty, constrs = Symtab.empty, cases = Symtab.empty};
    1.32    val copy = I;
    1.33    val extend = I;
    1.34 -  fun merge _ tabs : T = Symtab.merge (K true) tabs;
    1.35 +  fun merge _
    1.36 +    ({types = types1, constrs = constrs1, cases = cases1},
    1.37 +     {types = types2, constrs = constrs2, cases = cases2}) =
    1.38 +    {types = Symtab.merge (K true) (types1, types2),
    1.39 +     constrs = Symtab.merge (K true) (constrs1, constrs2),
    1.40 +     cases = Symtab.merge (K true) (cases1, cases2)};
    1.41  
    1.42 -  fun print sg tab =
    1.43 +  fun print sg ({types, ...} : T) =
    1.44      Pretty.writeln (Pretty.strs ("datatypes:" ::
    1.45 -      map #1 (NameSpace.extern_table (Sign.type_space sg, tab))));
    1.46 +      map #1 (NameSpace.extern_table (Sign.type_space sg, types))));
    1.47  end);
    1.48  
    1.49 -val get_datatypes = DatatypesData.get;
    1.50 -val put_datatypes = DatatypesData.put;
    1.51 +val get_datatypes = #types o DatatypesData.get;
    1.52 +val map_datatypes = DatatypesData.map;
    1.53  val print_datatypes = DatatypesData.print;
    1.54  
    1.55  
    1.56  (** theory information about datatypes **)
    1.57  
    1.58 +fun put_dt_infos (dt_infos : (string * datatype_info) list) =
    1.59 +  map_datatypes (fn {types, constrs, cases} =>
    1.60 +    {types = fold Symtab.update dt_infos types,
    1.61 +     constrs = fold Symtab.update
    1.62 +       (maps (fn (_, info as {descr, index, ...}) => map (rpair info o fst)
    1.63 +          (#3 (the (AList.lookup op = descr index)))) dt_infos) constrs,
    1.64 +     cases = fold Symtab.update
    1.65 +       (map (fn (_, info as {case_name, ...}) => (case_name, info)) dt_infos)
    1.66 +       cases});
    1.67 +
    1.68  val get_datatype = Symtab.lookup o get_datatypes;
    1.69  
    1.70  fun the_datatype thy name = (case get_datatype thy name of
    1.71        SOME info => info
    1.72      | NONE => error ("Unknown datatype " ^ quote name));
    1.73  
    1.74 +val datatype_of_constr = Symtab.lookup o #constrs o DatatypesData.get;
    1.75 +val datatype_of_case = Symtab.lookup o #cases o DatatypesData.get;
    1.76 +
    1.77  fun get_datatype_descr thy dtco =
    1.78    get_datatype thy dtco
    1.79    |> Option.map (fn info as { descr, index, ... } => 
    1.80 @@ -392,127 +420,23 @@
    1.81  
    1.82  (**** translation rules for case ****)
    1.83  
    1.84 -fun case_tr ctxt [t, u] =
    1.85 -    let
    1.86 -      val thy = ProofContext.theory_of ctxt;
    1.87 -      fun case_error s name ts = raise TERM ("Error in case expression" ^
    1.88 -        getOpt (Option.map (curry op ^ " for datatype ") name, "") ^ ":\n" ^ s, ts);
    1.89 -      fun dest_case1 (Const ("_case1", _) $ t $ u) =
    1.90 -          (case strip_comb t of
    1.91 -            (Const (s, _), ts) =>
    1.92 -              (case try (unprefix Syntax.constN) s of
    1.93 -                SOME c => (c, ts)
    1.94 -              | NONE => (Sign.intern_const thy s, ts))
    1.95 -          | (Free (s, _), ts) => (Sign.intern_const thy s, ts)
    1.96 -          | _ => case_error "Head is not a constructor" NONE [t, u], u)
    1.97 -        | dest_case1 t = raise TERM ("dest_case1", [t]);
    1.98 -      fun dest_case2 (Const ("_case2", _) $ t $ u) = t :: dest_case2 u
    1.99 -        | dest_case2 t = [t];
   1.100 -      val cases as ((cname, _), _) :: _ = map dest_case1 (dest_case2 u);
   1.101 -      val tab = Symtab.dest (get_datatypes thy);
   1.102 -      val (cases', default) = (case split_last cases of
   1.103 -          (cases', (("dummy_pattern", []), t)) => (cases', SOME t)
   1.104 -        | _ => (cases, NONE))
   1.105 -      fun abstr (Free (x, T)) body = Term.absfree (x, T, body)
   1.106 -        | abstr (Const ("_constrain", _) $ Free (x, T) $ tT) body =
   1.107 -            Syntax.const Syntax.constrainAbsC $ Term.absfree (x, T, body) $ tT
   1.108 -        | abstr (Const ("Pair", _) $ x $ y) body =
   1.109 -            Syntax.const "split" $ (abstr x o abstr y) body
   1.110 -        | abstr t _ = case_error "Illegal pattern" NONE [t];
   1.111 -    in case find_first (fn (_, {descr, index, ...}) =>
   1.112 -      exists (equal cname o fst) (#3 (snd (nth descr index)))) tab of
   1.113 -        NONE => case_error ("Not a datatype constructor: " ^ cname) NONE [u]
   1.114 -      | SOME (tname, {descr, sorts, case_name, index, ...}) =>
   1.115 -        let
   1.116 -          val _ = if exists (equal "dummy_pattern" o fst o fst) cases' then
   1.117 -            case_error "Illegal occurrence of '_' dummy pattern" (SOME tname) [u] else ();
   1.118 -          val (_, (_, dts, constrs)) = nth descr index;
   1.119 -          fun find_case (s, dt) cases =
   1.120 -            (case find_first (equal s o fst o fst) cases' of
   1.121 -               NONE => (list_abs (map (rpair dummyT)
   1.122 -                 (DatatypeProp.make_tnames (map (typ_of_dtyp descr sorts) dt)),
   1.123 -                 case default of
   1.124 -                   NONE => (warning ("No clause for constructor " ^ s ^
   1.125 -                     " in case expression"); Const ("HOL.undefined", dummyT))
   1.126 -                 | SOME t => t), cases)
   1.127 -             | SOME (c as ((_, vs), t)) =>
   1.128 -                 if length dt <> length vs then
   1.129 -                    case_error ("Wrong number of arguments for constructor " ^ s)
   1.130 -                      (SOME tname) vs
   1.131 -                 else (fold_rev abstr vs t, remove (op =) c cases))
   1.132 -          val (fs, cases'') = fold_map find_case constrs cases'
   1.133 -        in case (cases'', length constrs = length cases', default) of
   1.134 -            ([], true, SOME _) =>
   1.135 -              case_error "Extra '_' dummy pattern" (SOME tname) [u]
   1.136 -          | (_ :: _, _, _) =>
   1.137 -              let val extra = distinct (op =) (map (fst o fst) cases'')
   1.138 -              in case extra \\ map fst constrs of
   1.139 -                  [] => case_error ("More than one clause for constructor(s) " ^
   1.140 -                    commas extra) (SOME tname) [u]
   1.141 -                | extra' => case_error ("Illegal constructor(s): " ^ commas extra')
   1.142 -                    (SOME tname) [u]
   1.143 -              end
   1.144 -          | _ => list_comb (Syntax.const case_name, fs) $ t
   1.145 -        end
   1.146 -    end
   1.147 -  | case_tr _ ts = raise TERM ("case_tr", ts);
   1.148 +fun make_case ctxt = DatatypeCase.make_case
   1.149 +  (datatype_of_constr (ProofContext.theory_of ctxt)) ctxt;
   1.150 +
   1.151 +fun strip_case ctxt = DatatypeCase.strip_case
   1.152 +  (datatype_of_case (ProofContext.theory_of ctxt));
   1.153  
   1.154 -fun case_tr' constrs ctxt ts =
   1.155 -  if length ts <> length constrs + 1 then raise Match else
   1.156 -  let
   1.157 -    val consts = ProofContext.consts_of ctxt;
   1.158 -
   1.159 -    val (fs, x) = split_last ts;
   1.160 -    fun strip_abs 0 t = ([], t)
   1.161 -      | strip_abs i (Abs p) =
   1.162 -        let val (x, u) = Syntax.atomic_abs_tr' p
   1.163 -        in apfst (cons x) (strip_abs (i-1) u) end
   1.164 -      | strip_abs i (Const ("split", _) $ t) = (case strip_abs (i+1) t of
   1.165 -          (v :: v' :: vs, u) => (Syntax.const "Pair" $ v $ v' :: vs, u));
   1.166 -    fun is_dependent i t =
   1.167 -      let val k = length (strip_abs_vars t) - i
   1.168 -      in k < 0 orelse exists (fn j => j >= k)
   1.169 -        (loose_bnos (strip_abs_body t))
   1.170 -      end;
   1.171 -    val cases = map (fn ((cname, dts), t) =>
   1.172 -      (Consts.extern_early consts cname,
   1.173 -       strip_abs (length dts) t, is_dependent (length dts) t))
   1.174 -      (constrs ~~ fs);
   1.175 -    fun count_cases (_, _, true) = I
   1.176 -      | count_cases (cname, (_, body), false) =
   1.177 -          AList.map_default (op = : term * term -> bool)
   1.178 -            (body, []) (cons cname)
   1.179 -    val cases' = sort (int_ord o swap o pairself (length o snd))
   1.180 -      (fold_rev count_cases cases []);
   1.181 -    fun mk_case1 (cname, (vs, body), _) = Syntax.const "_case1" $
   1.182 -      list_comb (Syntax.const cname, vs) $ body;
   1.183 -    fun is_undefined (Const ("HOL.undefined", _)) = true
   1.184 -      | is_undefined _ = false;
   1.185 -  in
   1.186 -    Syntax.const "_case_syntax" $ x $
   1.187 -      foldr1 (fn (t, u) => Syntax.const "_case2" $ t $ u) (map mk_case1
   1.188 -        (case find_first (is_undefined o fst) cases' of
   1.189 -           SOME (_, cnames) =>
   1.190 -           if length cnames = length constrs then [hd cases]
   1.191 -           else filter_out (fn (_, (_, body), _) => is_undefined body) cases
   1.192 -         | NONE => case cases' of
   1.193 -           [] => cases
   1.194 -         | (default, cnames) :: _ =>
   1.195 -           if length cnames = 1 then cases
   1.196 -           else if length cnames = length constrs then
   1.197 -             [hd cases, ("dummy_pattern", ([], default), false)]
   1.198 -           else
   1.199 -             filter_out (fn (cname, _, _) => member (op =) cnames cname) cases @
   1.200 -             [("dummy_pattern", ([], default), false)]))
   1.201 -  end;
   1.202 -
   1.203 -fun make_case_tr' case_names descr = maps
   1.204 -  (fn ((_, (_, _, constrs)), case_name) =>
   1.205 -    map (rpair (case_tr' constrs)) (NameSpace.accesses' case_name))
   1.206 -      (descr ~~ case_names);
   1.207 +fun add_case_tr' case_names thy =
   1.208 +  Theory.add_advanced_trfuns ([], [],
   1.209 +    map (fn case_name => 
   1.210 +      let val case_name' = Sign.const_syntax_name thy case_name
   1.211 +      in (case_name', DatatypeCase.case_tr' datatype_of_case case_name')
   1.212 +      end) case_names, []) thy;
   1.213  
   1.214  val trfun_setup =
   1.215 -  Theory.add_advanced_trfuns ([], [("_case_syntax", case_tr)], [], []);
   1.216 +  Theory.add_advanced_trfuns ([],
   1.217 +    [("_case_syntax", DatatypeCase.case_tr datatype_of_constr)],
   1.218 +    [], []);
   1.219  
   1.220  
   1.221  (* prepare types *)
   1.222 @@ -586,16 +510,19 @@
   1.223      ||> Theory.parent_path
   1.224      |-> (fn [ax] => pair ax)) (tnames ~~ tss);
   1.225  
   1.226 -fun specify_consts args thy =
   1.227 +fun gen_specify_consts add args thy =
   1.228    let
   1.229      val specs = map (fn (c, T, mx) =>
   1.230        Const (Sign.full_name thy (Syntax.const_name c mx), T)) args;
   1.231    in
   1.232      thy
   1.233 -    |> Sign.add_consts_i args
   1.234 +    |> add args
   1.235      |> Theory.add_finals_i false specs
   1.236    end;
   1.237  
   1.238 +val specify_consts = gen_specify_consts Sign.add_consts_i;
   1.239 +val specify_consts_authentic = gen_specify_consts Sign.add_consts_authentic;
   1.240 +
   1.241  fun add_datatype_axm flat_names new_type_names descr sorts types_syntax constr_syntax dt_info
   1.242      case_names_induct case_names_exhausts thy =
   1.243    let
   1.244 @@ -660,7 +587,7 @@
   1.245  
   1.246        (** case combinators **)
   1.247  
   1.248 -      |> specify_consts (map (fn ((name, T), Ts) =>
   1.249 +      |> specify_consts_authentic (map (fn ((name, T), Ts) =>
   1.250             (name, Ts @ [T] ---> freeT, NoSyn)) (case_names ~~ newTs ~~ case_fn_Ts));
   1.251  
   1.252      val reccomb_names' = map (Sign.full_name thy2') reccomb_names;
   1.253 @@ -731,11 +658,11 @@
   1.254  
   1.255      val thy12 =
   1.256        thy11
   1.257 -      |> Theory.add_advanced_trfuns ([], [], make_case_tr' case_names' (hd descr), [])
   1.258 +      |> add_case_tr' case_names'
   1.259        |> Theory.add_path (space_implode "_" new_type_names)
   1.260        |> add_rules simps case_thms size_thms rec_thms inject distinct
   1.261            weak_case_congs Simplifier.cong_add
   1.262 -      |> put_datatypes (fold Symtab.update dt_infos dt_info)
   1.263 +      |> put_dt_infos dt_infos
   1.264        |> add_cases_induct dt_infos induct
   1.265        |> Theory.parent_path
   1.266        |> store_thmss "splits" new_type_names (map (fn (x, y) => [x, y]) split_thms)
   1.267 @@ -792,11 +719,11 @@
   1.268  
   1.269      val thy12 =
   1.270        thy11
   1.271 -      |> Theory.add_advanced_trfuns ([], [], make_case_tr' case_names (hd descr), [])
   1.272 +      |> add_case_tr' case_names
   1.273        |> Theory.add_path (space_implode "_" new_type_names)
   1.274        |> add_rules simps case_thms size_thms rec_thms inject distinct
   1.275            weak_case_congs (Simplifier.attrib (op addcongs))
   1.276 -      |> put_datatypes (fold Symtab.update dt_infos dt_info)
   1.277 +      |> put_dt_infos dt_infos
   1.278        |> add_cases_induct dt_infos induct
   1.279        |> Theory.parent_path
   1.280        |> store_thmss "splits" new_type_names (map (fn (x, y) => [x, y]) split_thms) |> snd
   1.281 @@ -897,10 +824,10 @@
   1.282  
   1.283      val thy11 =
   1.284        thy10
   1.285 -      |> Theory.add_advanced_trfuns ([], [], make_case_tr' case_names descr, [])
   1.286 +      |> add_case_tr' case_names
   1.287        |> add_rules simps case_thms size_thms rec_thms inject distinct
   1.288             weak_case_congs (Simplifier.attrib (op addcongs))
   1.289 -      |> put_datatypes (fold Symtab.update dt_infos dt_info)
   1.290 +      |> put_dt_infos dt_infos
   1.291        |> add_cases_induct dt_infos induction'
   1.292        |> Theory.parent_path
   1.293        |> store_thmss "splits" new_type_names (map (fn (x, y) => [x, y]) split_thms)