src/HOL/Tools/datatype_package.ML
changeset 22777 2fc921376a86
parent 22709 9ab51bac6287
child 22846 fb79144af9a3
--- a/src/HOL/Tools/datatype_package.ML	Tue Apr 24 15:07:27 2007 +0200
+++ b/src/HOL/Tools/datatype_package.ML	Tue Apr 24 15:11:07 2007 +0200
@@ -65,9 +65,15 @@
   val get_datatypes : theory -> DatatypeAux.datatype_info Symtab.table
   val get_datatype : theory -> string -> DatatypeAux.datatype_info option
   val the_datatype : theory -> string -> DatatypeAux.datatype_info
+  val datatype_of_constr : theory -> string -> DatatypeAux.datatype_info option
+  val datatype_of_case : theory -> string -> DatatypeAux.datatype_info option
   val get_datatype_spec : theory -> string -> ((string * sort) list * (string * typ list) list) option
   val get_datatype_constrs : theory -> string -> (string * typ) list option
   val print_datatypes : theory -> unit
+  val make_case :  Proof.context -> bool -> string list -> term ->
+    (term * term) list -> term * (term * (int * bool)) list
+  val strip_case: Proof.context -> bool ->
+    term -> (term * (term * term) list) option
   val setup: theory -> theory
 end;
 
@@ -84,31 +90,53 @@
 structure DatatypesData = TheoryDataFun
 (struct
   val name = "HOL/datatypes";
-  type T = datatype_info Symtab.table;
+  type T =
+    {types: datatype_info Symtab.table,
+     constrs: datatype_info Symtab.table,
+     cases: datatype_info Symtab.table};
 
-  val empty = Symtab.empty;
+  val empty =
+    {types = Symtab.empty, constrs = Symtab.empty, cases = Symtab.empty};
   val copy = I;
   val extend = I;
-  fun merge _ tabs : T = Symtab.merge (K true) tabs;
+  fun merge _
+    ({types = types1, constrs = constrs1, cases = cases1},
+     {types = types2, constrs = constrs2, cases = cases2}) =
+    {types = Symtab.merge (K true) (types1, types2),
+     constrs = Symtab.merge (K true) (constrs1, constrs2),
+     cases = Symtab.merge (K true) (cases1, cases2)};
 
-  fun print sg tab =
+  fun print sg ({types, ...} : T) =
     Pretty.writeln (Pretty.strs ("datatypes:" ::
-      map #1 (NameSpace.extern_table (Sign.type_space sg, tab))));
+      map #1 (NameSpace.extern_table (Sign.type_space sg, types))));
 end);
 
-val get_datatypes = DatatypesData.get;
-val put_datatypes = DatatypesData.put;
+val get_datatypes = #types o DatatypesData.get;
+val map_datatypes = DatatypesData.map;
 val print_datatypes = DatatypesData.print;
 
 
 (** theory information about datatypes **)
 
+fun put_dt_infos (dt_infos : (string * datatype_info) list) =
+  map_datatypes (fn {types, constrs, cases} =>
+    {types = fold Symtab.update dt_infos types,
+     constrs = fold Symtab.update
+       (maps (fn (_, info as {descr, index, ...}) => map (rpair info o fst)
+          (#3 (the (AList.lookup op = descr index)))) dt_infos) constrs,
+     cases = fold Symtab.update
+       (map (fn (_, info as {case_name, ...}) => (case_name, info)) dt_infos)
+       cases});
+
 val get_datatype = Symtab.lookup o get_datatypes;
 
 fun the_datatype thy name = (case get_datatype thy name of
       SOME info => info
     | NONE => error ("Unknown datatype " ^ quote name));
 
+val datatype_of_constr = Symtab.lookup o #constrs o DatatypesData.get;
+val datatype_of_case = Symtab.lookup o #cases o DatatypesData.get;
+
 fun get_datatype_descr thy dtco =
   get_datatype thy dtco
   |> Option.map (fn info as { descr, index, ... } => 
@@ -392,127 +420,23 @@
 
 (**** translation rules for case ****)
 
-fun case_tr ctxt [t, u] =
-    let
-      val thy = ProofContext.theory_of ctxt;
-      fun case_error s name ts = raise TERM ("Error in case expression" ^
-        getOpt (Option.map (curry op ^ " for datatype ") name, "") ^ ":\n" ^ s, ts);
-      fun dest_case1 (Const ("_case1", _) $ t $ u) =
-          (case strip_comb t of
-            (Const (s, _), ts) =>
-              (case try (unprefix Syntax.constN) s of
-                SOME c => (c, ts)
-              | NONE => (Sign.intern_const thy s, ts))
-          | (Free (s, _), ts) => (Sign.intern_const thy s, ts)
-          | _ => case_error "Head is not a constructor" NONE [t, u], u)
-        | dest_case1 t = raise TERM ("dest_case1", [t]);
-      fun dest_case2 (Const ("_case2", _) $ t $ u) = t :: dest_case2 u
-        | dest_case2 t = [t];
-      val cases as ((cname, _), _) :: _ = map dest_case1 (dest_case2 u);
-      val tab = Symtab.dest (get_datatypes thy);
-      val (cases', default) = (case split_last cases of
-          (cases', (("dummy_pattern", []), t)) => (cases', SOME t)
-        | _ => (cases, NONE))
-      fun abstr (Free (x, T)) body = Term.absfree (x, T, body)
-        | abstr (Const ("_constrain", _) $ Free (x, T) $ tT) body =
-            Syntax.const Syntax.constrainAbsC $ Term.absfree (x, T, body) $ tT
-        | abstr (Const ("Pair", _) $ x $ y) body =
-            Syntax.const "split" $ (abstr x o abstr y) body
-        | abstr t _ = case_error "Illegal pattern" NONE [t];
-    in case find_first (fn (_, {descr, index, ...}) =>
-      exists (equal cname o fst) (#3 (snd (nth descr index)))) tab of
-        NONE => case_error ("Not a datatype constructor: " ^ cname) NONE [u]
-      | SOME (tname, {descr, sorts, case_name, index, ...}) =>
-        let
-          val _ = if exists (equal "dummy_pattern" o fst o fst) cases' then
-            case_error "Illegal occurrence of '_' dummy pattern" (SOME tname) [u] else ();
-          val (_, (_, dts, constrs)) = nth descr index;
-          fun find_case (s, dt) cases =
-            (case find_first (equal s o fst o fst) cases' of
-               NONE => (list_abs (map (rpair dummyT)
-                 (DatatypeProp.make_tnames (map (typ_of_dtyp descr sorts) dt)),
-                 case default of
-                   NONE => (warning ("No clause for constructor " ^ s ^
-                     " in case expression"); Const ("HOL.undefined", dummyT))
-                 | SOME t => t), cases)
-             | SOME (c as ((_, vs), t)) =>
-                 if length dt <> length vs then
-                    case_error ("Wrong number of arguments for constructor " ^ s)
-                      (SOME tname) vs
-                 else (fold_rev abstr vs t, remove (op =) c cases))
-          val (fs, cases'') = fold_map find_case constrs cases'
-        in case (cases'', length constrs = length cases', default) of
-            ([], true, SOME _) =>
-              case_error "Extra '_' dummy pattern" (SOME tname) [u]
-          | (_ :: _, _, _) =>
-              let val extra = distinct (op =) (map (fst o fst) cases'')
-              in case extra \\ map fst constrs of
-                  [] => case_error ("More than one clause for constructor(s) " ^
-                    commas extra) (SOME tname) [u]
-                | extra' => case_error ("Illegal constructor(s): " ^ commas extra')
-                    (SOME tname) [u]
-              end
-          | _ => list_comb (Syntax.const case_name, fs) $ t
-        end
-    end
-  | case_tr _ ts = raise TERM ("case_tr", ts);
+fun make_case ctxt = DatatypeCase.make_case
+  (datatype_of_constr (ProofContext.theory_of ctxt)) ctxt;
+
+fun strip_case ctxt = DatatypeCase.strip_case
+  (datatype_of_case (ProofContext.theory_of ctxt));
 
-fun case_tr' constrs ctxt ts =
-  if length ts <> length constrs + 1 then raise Match else
-  let
-    val consts = ProofContext.consts_of ctxt;
-
-    val (fs, x) = split_last ts;
-    fun strip_abs 0 t = ([], t)
-      | strip_abs i (Abs p) =
-        let val (x, u) = Syntax.atomic_abs_tr' p
-        in apfst (cons x) (strip_abs (i-1) u) end
-      | strip_abs i (Const ("split", _) $ t) = (case strip_abs (i+1) t of
-          (v :: v' :: vs, u) => (Syntax.const "Pair" $ v $ v' :: vs, u));
-    fun is_dependent i t =
-      let val k = length (strip_abs_vars t) - i
-      in k < 0 orelse exists (fn j => j >= k)
-        (loose_bnos (strip_abs_body t))
-      end;
-    val cases = map (fn ((cname, dts), t) =>
-      (Consts.extern_early consts cname,
-       strip_abs (length dts) t, is_dependent (length dts) t))
-      (constrs ~~ fs);
-    fun count_cases (_, _, true) = I
-      | count_cases (cname, (_, body), false) =
-          AList.map_default (op = : term * term -> bool)
-            (body, []) (cons cname)
-    val cases' = sort (int_ord o swap o pairself (length o snd))
-      (fold_rev count_cases cases []);
-    fun mk_case1 (cname, (vs, body), _) = Syntax.const "_case1" $
-      list_comb (Syntax.const cname, vs) $ body;
-    fun is_undefined (Const ("HOL.undefined", _)) = true
-      | is_undefined _ = false;
-  in
-    Syntax.const "_case_syntax" $ x $
-      foldr1 (fn (t, u) => Syntax.const "_case2" $ t $ u) (map mk_case1
-        (case find_first (is_undefined o fst) cases' of
-           SOME (_, cnames) =>
-           if length cnames = length constrs then [hd cases]
-           else filter_out (fn (_, (_, body), _) => is_undefined body) cases
-         | NONE => case cases' of
-           [] => cases
-         | (default, cnames) :: _ =>
-           if length cnames = 1 then cases
-           else if length cnames = length constrs then
-             [hd cases, ("dummy_pattern", ([], default), false)]
-           else
-             filter_out (fn (cname, _, _) => member (op =) cnames cname) cases @
-             [("dummy_pattern", ([], default), false)]))
-  end;
-
-fun make_case_tr' case_names descr = maps
-  (fn ((_, (_, _, constrs)), case_name) =>
-    map (rpair (case_tr' constrs)) (NameSpace.accesses' case_name))
-      (descr ~~ case_names);
+fun add_case_tr' case_names thy =
+  Theory.add_advanced_trfuns ([], [],
+    map (fn case_name => 
+      let val case_name' = Sign.const_syntax_name thy case_name
+      in (case_name', DatatypeCase.case_tr' datatype_of_case case_name')
+      end) case_names, []) thy;
 
 val trfun_setup =
-  Theory.add_advanced_trfuns ([], [("_case_syntax", case_tr)], [], []);
+  Theory.add_advanced_trfuns ([],
+    [("_case_syntax", DatatypeCase.case_tr datatype_of_constr)],
+    [], []);
 
 
 (* prepare types *)
@@ -586,16 +510,19 @@
     ||> Theory.parent_path
     |-> (fn [ax] => pair ax)) (tnames ~~ tss);
 
-fun specify_consts args thy =
+fun gen_specify_consts add args thy =
   let
     val specs = map (fn (c, T, mx) =>
       Const (Sign.full_name thy (Syntax.const_name c mx), T)) args;
   in
     thy
-    |> Sign.add_consts_i args
+    |> add args
     |> Theory.add_finals_i false specs
   end;
 
+val specify_consts = gen_specify_consts Sign.add_consts_i;
+val specify_consts_authentic = gen_specify_consts Sign.add_consts_authentic;
+
 fun add_datatype_axm flat_names new_type_names descr sorts types_syntax constr_syntax dt_info
     case_names_induct case_names_exhausts thy =
   let
@@ -660,7 +587,7 @@
 
       (** case combinators **)
 
-      |> specify_consts (map (fn ((name, T), Ts) =>
+      |> specify_consts_authentic (map (fn ((name, T), Ts) =>
            (name, Ts @ [T] ---> freeT, NoSyn)) (case_names ~~ newTs ~~ case_fn_Ts));
 
     val reccomb_names' = map (Sign.full_name thy2') reccomb_names;
@@ -731,11 +658,11 @@
 
     val thy12 =
       thy11
-      |> Theory.add_advanced_trfuns ([], [], make_case_tr' case_names' (hd descr), [])
+      |> add_case_tr' case_names'
       |> Theory.add_path (space_implode "_" new_type_names)
       |> add_rules simps case_thms size_thms rec_thms inject distinct
           weak_case_congs Simplifier.cong_add
-      |> put_datatypes (fold Symtab.update dt_infos dt_info)
+      |> put_dt_infos dt_infos
       |> add_cases_induct dt_infos induct
       |> Theory.parent_path
       |> store_thmss "splits" new_type_names (map (fn (x, y) => [x, y]) split_thms)
@@ -792,11 +719,11 @@
 
     val thy12 =
       thy11
-      |> Theory.add_advanced_trfuns ([], [], make_case_tr' case_names (hd descr), [])
+      |> add_case_tr' case_names
       |> Theory.add_path (space_implode "_" new_type_names)
       |> add_rules simps case_thms size_thms rec_thms inject distinct
           weak_case_congs (Simplifier.attrib (op addcongs))
-      |> put_datatypes (fold Symtab.update dt_infos dt_info)
+      |> put_dt_infos dt_infos
       |> add_cases_induct dt_infos induct
       |> Theory.parent_path
       |> store_thmss "splits" new_type_names (map (fn (x, y) => [x, y]) split_thms) |> snd
@@ -897,10 +824,10 @@
 
     val thy11 =
       thy10
-      |> Theory.add_advanced_trfuns ([], [], make_case_tr' case_names descr, [])
+      |> add_case_tr' case_names
       |> add_rules simps case_thms size_thms rec_thms inject distinct
            weak_case_congs (Simplifier.attrib (op addcongs))
-      |> put_datatypes (fold Symtab.update dt_infos dt_info)
+      |> put_dt_infos dt_infos
       |> add_cases_induct dt_infos induction'
       |> Theory.parent_path
       |> store_thmss "splits" new_type_names (map (fn (x, y) => [x, y]) split_thms)