src/Pure/Tools/codegen_data.ML
changeset 22423 c1836b14c63a
parent 22360 26ead7ed4f4b
child 22484 25dfebd7b4c8
--- a/src/Pure/Tools/codegen_data.ML	Fri Mar 09 08:45:50 2007 +0100
+++ b/src/Pure/Tools/codegen_data.ML	Fri Mar 09 08:45:53 2007 +0100
@@ -15,9 +15,8 @@
   val add_func_legacy: thm -> theory -> theory
   val del_func: thm -> theory -> theory
   val add_funcl: CodegenConsts.const * thm list Susp.T -> theory -> theory
-  val add_datatype: string * (((string * sort) list * (string * typ list) list) * thm list Susp.T)
+  val add_datatype: string * ((string * sort) list * (string * typ list) list)
     -> theory -> theory
-  val del_datatype: string -> theory -> theory
   val add_inline: thm -> theory -> theory
   val del_inline: thm -> theory -> theory
   val add_inline_proc: string * (theory -> cterm list -> thm list) -> theory -> theory
@@ -28,8 +27,7 @@
   val operational_algebra: theory -> (sort -> sort) * Sorts.algebra
   val these_funcs: theory -> CodegenConsts.const -> thm list
   val tap_typ: theory -> CodegenConsts.const -> typ option
-  val get_datatype: theory -> string
-    -> ((string * sort) list * (string * typ list) list) option
+  val get_datatype: theory -> string -> ((string * sort) list * (string * typ list) list)
   val get_datatype_of_constr: theory -> CodegenConsts.const -> string option
 
   val preprocess_cterm: cterm -> thm
@@ -193,7 +191,7 @@
     in (SOME consts, thms) end;
 
 val eq_string = op = : string * string -> bool;
-fun eq_dtyp (((vs1, cs1), _), ((vs2, cs2), _)) = 
+fun eq_dtyp ((vs1, cs1), (vs2, cs2)) = 
   gen_eq_set (eq_pair eq_string (gen_eq_set eq_string)) (vs1, vs2)
     andalso gen_eq_set (eq_pair eq_string (eq_list (is_equal o Term.typ_ord))) (cs1, cs2);
 fun merge_dtyps (tabs as (tab1, tab2)) =
@@ -210,7 +208,7 @@
 datatype spec = Spec of {
   funcs: sdthms Consttab.table,
   dconstrs: string Consttab.table,
-  dtyps: (((string * sort) list * (string * typ list) list) * thm list Susp.T) Symtab.table
+  dtyps: ((string * sort) list * (string * typ list) list) Symtab.table
 };
 
 fun mk_spec ((funcs, dconstrs), dtyps) =
@@ -328,15 +326,17 @@
         (Pretty.block o Pretty.fbreaks) (
           Pretty.str s :: pretty_sdthms ctxt lthms
         );
-      fun pretty_dtyp (s, cos) =
-        (Pretty.block o Pretty.breaks) (
-          Pretty.str s
-          :: Pretty.str "="
-          :: Pretty.separate "|" (map (fn (c, []) => Pretty.str c
-               | (c, tys) =>
-                   (Pretty.block o Pretty.breaks)
-                      (Pretty.str c :: Pretty.str "of" :: map (Pretty.quote o Sign.pretty_typ thy) tys)) cos)
-        )
+      fun pretty_dtyp (s, []) =
+            Pretty.str s
+        | pretty_dtyp (s, cos) =
+            (Pretty.block o Pretty.breaks) (
+              Pretty.str s
+              :: Pretty.str "="
+              :: separate (Pretty.str "|") (map (fn (c, []) => Pretty.str c
+                   | (c, tys) =>
+                       (Pretty.block o Pretty.breaks)
+                          (Pretty.str c :: Pretty.str "of" :: map (Pretty.quote o Sign.pretty_typ thy) tys)) cos)
+            );
       val inlines = (#inlines o the_preproc) exec;
       val inline_procs = (map fst o #inline_procs o the_preproc) exec;
       val preprocs = (map fst o #preprocs o the_preproc) exec;
@@ -346,13 +346,14 @@
         |> sort (string_ord o pairself fst);
       val dtyps = the_dtyps exec
         |> Symtab.dest
-        |> map (fn (dtco, ((vs, cos), _)) => (Sign.string_of_typ thy (Type (dtco, map TFree vs)), cos))
+        |> map (fn (dtco, (vs, cos)) => (Sign.string_of_typ thy (Type (dtco, map TFree vs)), cos))
         |> sort (string_ord o pairself fst)
     in
       (Pretty.writeln o Pretty.chunks) [
         Pretty.block (
           Pretty.str "defining equations:"
-          :: map pretty_func funs
+          :: Pretty.fbrk
+          :: (Pretty.fbreaks o map pretty_func) funs
         ),
         Pretty.block (
           Pretty.str "inlining theorems:"
@@ -431,77 +432,6 @@
         ^ CodegenConsts.string_of_const thy c ^ "\n" ^ string_of_thm thm)
   in map cert c_thms end;
 
-fun mk_cos tyco vs cos =
-  let
-    val dty = Type (tyco, map TFree vs);
-    fun mk_co (co, tys) = (Const (co, (tys ---> dty)), map I tys);
-  in map mk_co cos end;
-
-fun mk_co_args (co, tys) ctxt =
-  let
-    val names = Name.invents ctxt "a" (length tys);
-    val ctxt' = fold Name.declare names ctxt;
-    val vs = map2 (fn v => fn ty => Free (fst (v, 0), I ty)) names tys;
-  in (vs, ctxt') end;
-
-fun check_freeness thy cos thms =
-  let
-    val props = AList.make Drule.plain_prop_of thms;
-    fun sym_product [] = []
-      | sym_product (x::xs) = map (pair x) xs @ sym_product xs;
-    val quodlibet =
-      let
-        val judg = ObjectLogic.fixed_judgment (the_context ()) "x";
-        val [free] = fold_aterms (fn v as Free _ => cons v | _ => I) judg [];
-        val judg' = Term.subst_free [(free, Bound 0)] judg;
-        val prop = Type ("prop", []);
-        val prop' = fastype_of judg';
-      in
-        Const ("all", (prop' --> prop) --> prop) $ Abs ("P", prop', judg')
-      end;
-    fun check_inj (co, []) =
-          NONE
-      | check_inj (co, tys) =
-          let
-            val ((xs, ys), _) = Name.context
-              |> mk_co_args (co, tys)
-              ||>> mk_co_args (co, tys);
-            val prem = Logic.mk_equals
-              (list_comb (co, xs), list_comb (co, ys));
-            val concl = Logic.mk_conjunction_list
-              (map2 (curry Logic.mk_equals) xs ys);
-            val t = Logic.mk_implies (prem, concl);
-          in case find_first (curry Term.could_unify t o snd) props
-           of SOME (thm, _) => SOME thm
-            | NONE => error ("Could not prove injectiveness statement\n"
-               ^ Sign.string_of_term thy t
-               ^ "\nfor constructor "
-               ^ CodegenConsts.string_of_const_typ thy (dest_Const co)
-               ^ "\nwith theorems\n" ^ cat_lines (map string_of_thm thms))
-          end;
-    fun check_dist ((co1, tys1), (co2, tys2)) =
-          let
-            val ((xs1, xs2), _) = Name.context
-              |> mk_co_args (co1, tys1)
-              ||>> mk_co_args (co2, tys2);
-            val prem = Logic.mk_equals
-              (list_comb (co1, xs1), list_comb (co2, xs2));
-            val t = Logic.mk_implies (prem, quodlibet);
-          in case find_first (curry Term.could_unify t o snd) props
-           of SOME (thm, _) => thm
-            | NONE => error ("Could not prove distinctness statement\n"
-               ^ Sign.string_of_term thy t
-               ^ "\nfor constructors "
-               ^ CodegenConsts.string_of_const_typ thy (dest_Const co1)
-               ^ " and "
-               ^ CodegenConsts.string_of_const_typ thy (dest_Const co2)
-               ^ "\nwith theorems\n" ^ cat_lines (map string_of_thm thms))
-          end;
-  in (map_filter check_inj cos, map check_dist (sym_product cos)) end;
-
-fun certify_datatype thy dtco cs thms =
-  (op @) (check_freeness thy cs thms);
-
 
 
 (** operational sort algebra and class discipline **)
@@ -684,37 +614,102 @@
       (add_lthms lthms'))) thy
   end;
 
-fun add_datatype (tyco, (vs_cos as (vs, cos), lthms)) thy =
+local
+
+fun consts_of_cos thy tyco vs cos =
+  let
+    val dty = Type (tyco, map TFree vs);
+    fun mk_co (co, tys) = CodegenConsts.norm_of_typ thy (co, tys ---> dty);
+  in map mk_co cos end;
+
+fun co_of_const thy (c, ty) =
   let
-    val cs = mk_cos tyco vs cos;
-    val consts = map (CodegenConsts.norm_of_typ thy o dest_Const o fst) cs;
-    val add =
-      map_dtyps (Symtab.update_new (tyco,
-        (vs_cos, certificate thy (fn thy => certify_datatype thy tyco cs) lthms)))
-      #> map_dconstrs (fold (fn c => Consttab.update (c, tyco)) consts)
-  in map_exec_purge (SOME consts) add thy end;
+    fun err () = error
+     ("Illegal type for datatype constructor: " ^ Sign.string_of_typ thy ty);
+    val (tys, ty') = strip_type ty;
+    val (tyco, vs) = ((apsnd o map) dest_TFree o dest_Type) ty'
+      handle TYPE _ => err ();
+    val sorts = if has_duplicates (eq_fst op =) vs then err ()
+      else map snd vs;
+    val vs_names = Name.invent_list [] "'a" (length vs);
+    val vs_map = map fst vs ~~ vs_names;
+    val vs' = vs_names ~~ sorts;
+    val tys' = (map o map_type_tfree) (fn (v, sort) =>
+      (TFree ((the o AList.lookup (op =) vs_map) v, sort))) tys
+      handle Option => err ();
+  in (tyco, (vs', (c, tys'))) end;
 
 fun del_datatype tyco thy =
+  case Symtab.lookup ((the_dtyps o get_exec) thy) tyco
+   of SOME (vs, cos) => let
+        val consts = consts_of_cos thy tyco vs cos;
+        val del =
+          map_dtyps (Symtab.delete tyco)
+          #> map_dconstrs (fold Consttab.delete consts)
+      in map_exec_purge (SOME consts) del thy end
+    | NONE => thy;
+
+(*FIXME integrate this auxiliary properly*)
+
+in
+
+fun add_datatype (tyco, (vs_cos as (vs, cos))) thy =
   let
-    val SOME ((vs, cos), _) = Symtab.lookup ((the_dtyps o get_exec) thy) tyco;
-    val cs = mk_cos tyco vs cos;
-    val consts = map (CodegenConsts.norm_of_typ thy o dest_Const o fst) cs;
-    val del =
-      map_dtyps (Symtab.delete tyco)
-      #> map_dconstrs (fold Consttab.delete consts)
-  in map_exec_purge (SOME consts) del thy end;
+    val consts = consts_of_cos thy tyco vs cos;
+    val add =
+      map_dtyps (Symtab.update_new (tyco, vs_cos))
+      #> map_dconstrs (fold (fn c => Consttab.update (c, tyco)) consts)
+  in
+    thy
+    |> del_datatype tyco
+    |> map_exec_purge (SOME consts) add
+  end;
+
+fun add_datatype_consts cs thy =
+  let
+    val raw_cos = map (co_of_const thy) cs;
+    val (tyco, (vs_names, sorts_cos)) = if (length o distinct (eq_fst op =)) raw_cos = 1
+      then ((fst o hd) raw_cos, ((map fst o fst o snd o hd) raw_cos,
+        map ((apfst o map) snd o snd) raw_cos))
+      else error ("Term constructors not referring to the same type: "
+        ^ commas (map (CodegenConsts.string_of_const_typ thy) cs));
+    val sorts = foldr1 ((uncurry o map2 o curry o Sorts.inter_sort) (Sign.classes_of thy))
+      (map fst sorts_cos);
+    val cos = map snd sorts_cos;
+    val vs = vs_names ~~ sorts;
+  in
+    thy
+    |> add_datatype (tyco, (vs, cos))
+  end;
+
+fun add_datatype_consts_cmd raw_cs thy =
+  let
+    val cs = map (apsnd Logic.unvarifyT o CodegenConsts.typ_of_inst thy
+      o CodegenConsts.read_const thy) raw_cs
+  in
+    thy
+    |> add_datatype_consts cs
+  end;
+
+end; (*local*)
 
 fun add_inline thm thy =
-  (map_exec_purge NONE o map_preproc o apfst o apfst) (fold (insert Thm.eq_thm) (CodegenFunc.mk_rew thm)) thy;
+  (map_exec_purge NONE o map_preproc o apfst o apfst)
+    (fold (insert Thm.eq_thm) (CodegenFunc.mk_rew thm)) thy;
+        (*fully applied in order to get right context for mk_rew!*)
 
 fun del_inline thm thy =
-  (map_exec_purge NONE o map_preproc o apfst o apfst) (fold (remove Thm.eq_thm) (CodegenFunc.mk_rew thm)) thy ;
+  (map_exec_purge NONE o map_preproc o apfst o apfst)
+    (fold (remove Thm.eq_thm) (CodegenFunc.mk_rew thm)) thy;
+        (*fully applied in order to get right context for mk_rew!*)
 
 fun add_inline_proc (name, f) =
-  (map_exec_purge NONE o map_preproc o apfst o apsnd) (AList.update (op =) (name, (serial (), f)));
+  (map_exec_purge NONE o map_preproc o apfst o apsnd)
+    (AList.update (op =) (name, (serial (), f)));
 
 fun del_inline_proc name =
-  (map_exec_purge NONE o map_preproc o apfst o apsnd) (delete_force "inline procedure" name);
+  (map_exec_purge NONE o map_preproc o apfst o apsnd)
+    (delete_force "inline procedure" name);
 
 fun add_preproc (name, f) =
   (map_exec_purge NONE o map_preproc o apsnd) (AList.update (op =) (name, (serial (), f)));
@@ -774,6 +769,25 @@
 
 end; (*local*)
 
+fun get_datatype thy tyco =
+  case Symtab.lookup ((the_dtyps o get_exec) thy) tyco
+   of SOME spec => spec
+    | NONE => Sign.arity_number thy tyco
+        |> Name.invents Name.context "'a"
+        |> map (rpair [])
+        |> rpair [];
+
+fun get_datatype_of_constr thy =
+  Consttab.lookup ((the_dcontrs o get_exec) thy);
+
+fun get_datatype_constr thy const =
+  case Consttab.lookup ((the_dcontrs o get_exec) thy) const
+   of SOME tyco => let
+        val (vs, cs) = get_datatype thy tyco;
+        (*FIXME continue here*)
+      in NONE end
+    | NONE => NONE;
+
 local
 
 fun get_funcs thy const =
@@ -812,14 +826,6 @@
 
 end; (*local*)
 
-fun get_datatype thy tyco =
-  Symtab.lookup ((the_dtyps o get_exec) thy) tyco
-  |> Option.map (fn (spec, thms) => (Susp.force thms; spec));
-
-fun get_datatype_of_constr thy c =
-  Consttab.lookup ((the_dcontrs o get_exec) thy) c
-  |> (Option.map o tap) (fn dtco => get_datatype thy dtco);
-
 
 (** code attributes **)
 
@@ -846,15 +852,23 @@
 and K = OuterKeyword
 
 val print_codesetupK = "print_codesetup";
+val code_datatypeK = "code_datatype";
 
 in
 
 val print_codesetupP =
-  OuterSyntax.improper_command print_codesetupK "print code generator setup of this theory" OuterKeyword.diag
+  OuterSyntax.improper_command print_codesetupK "print code generator setup of this theory" K.diag
     (Scan.succeed
       (Toplevel.no_timing o Toplevel.unknown_theory o Toplevel.keep (CodeData.print o Toplevel.theory_of)));
 
-val _ = OuterSyntax.add_parsers [print_codesetupP];
+val code_datatypeP =
+  OuterSyntax.command code_datatypeK "define set of code datatype constructors" K.thy_decl (
+    Scan.repeat1 P.term
+    >> (Toplevel.theory o add_datatype_consts_cmd)
+  );
+
+
+val _ = OuterSyntax.add_parsers [print_codesetupP, code_datatypeP];
 
 end; (*local*)