src/Pure/Tools/codegen_data.ML
changeset 22744 5cbe966d67a2
parent 22705 6199df39688d
child 22806 45ac82e7b887
--- a/src/Pure/Tools/codegen_data.ML	Fri Apr 20 11:21:41 2007 +0200
+++ b/src/Pure/Tools/codegen_data.ML	Fri Apr 20 11:21:42 2007 +0200
@@ -14,6 +14,7 @@
   val add_func: bool -> thm -> theory -> theory
   val del_func: thm -> theory -> theory
   val add_funcl: CodegenConsts.const * thm list Susp.T -> theory -> theory
+  val add_func_attr: bool -> Attrib.src
   val add_inline: thm -> theory -> theory
   val del_inline: thm -> theory -> theory
   val add_inline_proc: string * (theory -> cterm list -> thm list) -> theory -> theory
@@ -23,6 +24,7 @@
   val add_datatype: string * ((string * sort) list * (string * typ list) list)
     -> theory -> theory
   val add_datatype_consts: CodegenConsts.const list -> theory -> theory
+  val add_datatype_consts_cmd: string list -> theory -> theory
 
   val coregular_algebra: theory -> Sorts.algebra
   val operational_algebra: theory -> (sort -> sort) * Sorts.algebra
@@ -33,6 +35,8 @@
 
   val preprocess_cterm: cterm -> thm
 
+  val print_codesetup: theory -> unit
+
   val trace: bool ref
 end;
 
@@ -91,7 +95,7 @@
   then (false, xs)
   else (true, AList.merge eq_key eq xys);
 
-val merge_thms = merge' Thm.eq_thm;
+val merge_thms = merge' Thm.eq_thm_prop;
 
 fun merge_lthms (r1, r2) =
   if Susp.same (r1, r2)
@@ -122,7 +126,7 @@
     fun drop thm' = not (matches args (args_of thm'))
       orelse (warning ("Dropping redundant defining equation\n" ^ string_of_thm thm'); false);
     val (keeps, drops) = List.partition drop sels;
-  in (thm :: keeps, dels |> fold (insert Thm.eq_thm) drops |> remove Thm.eq_thm thm) end;
+  in (thm :: keeps, dels |> fold (insert Thm.eq_thm_prop) drops |> remove Thm.eq_thm_prop thm) end;
 
 fun add_thm thm (sels, dels) =
   apfst Susp.value (add_drop_redundant thm (Susp.force sels, dels));
@@ -135,7 +139,7 @@
       fold add_thm (Susp.force lthms) (sels, dels);
 
 fun del_thm thm (sels, dels) =
-  (Susp.value (remove Thm.eq_thm thm (Susp.force sels)), thm :: dels);
+  (Susp.value (remove Thm.eq_thm_prop thm (Susp.force sels)), thm :: dels);
 
 fun pretty_sdthms ctxt (sels, _) = pretty_lthms ctxt sels;
 
@@ -144,8 +148,8 @@
     val (dels_t, dels) = merge_thms (dels1, dels2);
   in if dels_t
     then let
-      val (_, sels) = merge_thms (Susp.force sels1, subtract Thm.eq_thm dels1 (Susp.force sels2))
-      val (_, dels) = merge_thms (dels1, subtract Thm.eq_thm (Susp.force sels1) dels2)
+      val (_, sels) = merge_thms (Susp.force sels1, subtract Thm.eq_thm_prop dels1 (Susp.force sels2))
+      val (_, dels) = merge_thms (dels1, subtract Thm.eq_thm_prop (Susp.force sels1) dels2)
     in (true, ((lazy_thms o K) sels, dels)) end
     else let
       val (sels_t, sels) = merge_lthms (sels1, sels2)
@@ -377,6 +381,8 @@
     end;
 end);
 
+val print_codesetup = CodeData.print;
+
 fun init k = CodeData.map
   (fn (exec, data) => (exec, ref (Datatab.update (k, invoke_empty k) (! data))));
 
@@ -402,16 +408,16 @@
 
 fun common_typ_funcs [] = []
   | common_typ_funcs [thm] = [thm]
-  | common_typ_funcs thms =
+  | common_typ_funcs (thms as thm :: _) =
       let
-        val thy = Thm.theory_of_thm (hd thms)
+        val thy = Thm.theory_of_thm thm;
         fun incr_thm thm max =
           let
             val thm' = incr_indexes max thm;
             val max' = Thm.maxidx_of thm' + 1;
           in (thm', max') end;
         val (thms', maxidx) = fold_map incr_thm thms 0;
-        val (ty1::tys) = map CodegenFunc.typ_func thms';
+        val ty1 :: tys = map (snd o CodegenFunc.head_func) thms';
         fun unify ty env = Sign.typ_unify thy (ty1, ty) env
           handle Type.TUNIFY =>
             error ("Type unificaton failed, while unifying defining equations\n"
@@ -423,12 +429,12 @@
           cons (Thm.ctyp_of thy (TVar (x_i, sort)), Thm.ctyp_of thy ty)) env [];
       in map (Thm.instantiate (instT, [])) thms' end;
 
-fun certify_const thy c c_thms =
+fun certify_const thy const thms =
   let
-    fun cert (c', thm) = if CodegenConsts.eq_const (c, c')
+    fun cert thm = if CodegenConsts.eq_const (const, fst (CodegenFunc.head_func thm))
       then thm else error ("Wrong head of defining equation,\nexpected constant "
-        ^ CodegenConsts.string_of_const thy c ^ "\n" ^ string_of_thm thm)
-  in map cert c_thms end;
+        ^ CodegenConsts.string_of_const thy const ^ "\n" ^ string_of_thm thm)
+  in map cert thms end;
 
 
 
@@ -459,7 +465,7 @@
       |> maps these
       |> map (Thm.transfer thy);
     val sorts = map (map (snd o dest_TVar) o snd o dest_Type o the_single
-      o Sign.const_typargs thy o fst o CodegenFunc.dest_func) funcs;
+      o Sign.const_typargs thy o (fn ((c, _), ty) => (c, ty)) o CodegenFunc.head_func) funcs;
   in sorts end;
 
 fun weakest_constraints thy (class, tyco) =
@@ -512,51 +518,49 @@
 val classop_weakest_typ = gen_classop_typ weakest_constraints;
 val classop_strongest_typ = gen_classop_typ strongest_constraints;
 
-fun gen_mk_func_typ strict thm =
+fun assert_func_typ thm =
   let
     val thy = Thm.theory_of_thm thm;
-    val raw_funcs = CodegenFunc.mk_func strict thm;
-    val error_warning = if strict then error else warning #> K NONE;
     fun check_typ_classop class (const as (c, SOME tyco), thm) =
           let
-            val ((_, ty), _) = CodegenFunc.dest_func thm;
+            val (_, ty) = CodegenFunc.head_func thm;
             val ty_decl = classop_weakest_typ thy class (c, tyco);
             val ty_strongest = classop_strongest_typ thy class (c, tyco);
             fun constrain thm = 
               let
                 val max = Thm.maxidx_of thm + 1;
                 val ty_decl' = Logic.incr_tvar max ty_decl;
-                val ((_, ty'), _) = CodegenFunc.dest_func thm;
+                val (_, ty') = CodegenFunc.head_func thm;
                 val (env, _) = Sign.typ_unify thy (ty_decl', ty') (Vartab.empty, max);
                 val instT = Vartab.fold (fn (x_i, (sort, ty)) =>
                   cons (Thm.ctyp_of thy (TVar (x_i, sort)), Thm.ctyp_of thy ty)) env [];
               in Thm.instantiate (instT, []) thm end;
           in if Sign.typ_instance thy (ty_strongest, ty)
             then if Sign.typ_instance thy (ty, ty_decl)
-            then SOME (const, thm)
+            then thm
             else (warning ("Constraining type\n" ^ CodegenConsts.string_of_typ thy ty
               ^ "\nof defining equation\n"
               ^ string_of_thm thm
               ^ "\nto permitted most general type\n"
               ^ CodegenConsts.string_of_typ thy ty_decl);
-              SOME (const, constrain thm))
-            else error_warning ("Type\n" ^ CodegenConsts.string_of_typ thy ty
+              constrain thm)
+            else CodegenFunc.bad_thm ("Type\n" ^ CodegenConsts.string_of_typ thy ty
               ^ "\nof defining equation\n"
               ^ string_of_thm thm
               ^ "\nis incompatible with permitted least general type\n"
               ^ CodegenConsts.string_of_typ thy ty_strongest)
           end
       | check_typ_classop class ((c, NONE), thm) =
-          error_warning ("Illegal type for class operation " ^ quote c
+          CodegenFunc.bad_thm ("Illegal type for class operation " ^ quote c
            ^ "\nin defining equation\n"
            ^ string_of_thm thm);
     fun check_typ_fun (const as (c, _), thm) =
       let
-        val ((_, ty), _) = CodegenFunc.dest_func thm;
+        val (_, ty) = CodegenFunc.head_func thm;
         val ty_decl = Sign.the_const_type thy c;
       in if Sign.typ_equiv thy (Type.strip_sorts ty_decl, Type.strip_sorts ty)
-        then SOME (const, thm)
-        else error_warning ("Type\n" ^ CodegenConsts.string_of_typ thy ty
+        then thm
+        else CodegenFunc.bad_thm ("Type\n" ^ CodegenConsts.string_of_typ thy ty
            ^ "\nof defining equation\n"
            ^ string_of_thm thm
            ^ "\nis incompatible declared function type\n"
@@ -566,23 +570,34 @@
       case AxClass.class_of_param thy c
        of SOME class => check_typ_classop class (const, thm)
         | NONE => check_typ_fun (const, thm);
-    val funcs = map_filter check_typ raw_funcs;
-  in funcs end;
+  in check_typ (fst (CodegenFunc.head_func thm), thm) end;
+
+val mk_func = CodegenFunc.error_thm
+  (assert_func_typ o CodegenFunc.mk_func);
+val mk_func_liberal = CodegenFunc.warning_thm
+  (assert_func_typ o CodegenFunc.mk_func);
 
 end;
 
-
 (** interfaces **)
 
-fun add_func strict thm thy =
-  let
-    val funcs = gen_mk_func_typ strict thm;
-    val cs = map fst funcs;
-  in
-    map_exec_purge (SOME cs) (map_funcs 
-     (fold (fn (c, thm) => Consttab.map_default
-       (c, (Susp.value [], [])) (add_thm thm)) funcs)) thy
-  end;
+fun add_func true thm thy =
+      let
+        val func = mk_func thm;
+        val (const, _) = CodegenFunc.head_func func;
+      in map_exec_purge (SOME [const]) (map_funcs
+        (Consttab.map_default
+          (const, (Susp.value [], [])) (add_thm func))) thy
+      end
+  | add_func false thm thy =
+      case mk_func_liberal thm
+       of SOME func => let
+              val (const, _) = CodegenFunc.head_func func
+            in map_exec_purge (SOME [const]) (map_funcs
+              (Consttab.map_default
+                (const, (Susp.value [], [])) (add_thm func))) thy
+            end
+        | NONE => thy;
 
 fun delete_force msg key xs =
   if AList.defined (op =) xs key then AList.delete (op =) key xs
@@ -590,23 +605,26 @@
 
 fun del_func thm thy =
   let
-    val funcs = gen_mk_func_typ false thm;
-    val cs = map fst funcs;
-  in
-    map_exec_purge (SOME cs) (map_funcs
-     (fold (fn (c, thm) => Consttab.map_entry c
-       (del_thm thm)) funcs)) thy
+    val func = mk_func thm;
+    val (const, _) = CodegenFunc.head_func func;
+  in map_exec_purge (SOME [const]) (map_funcs
+    (Consttab.map_entry
+      const (del_thm func))) thy
   end;
 
 fun add_funcl (const, lthms) thy =
   let
-    val lthms' = certificate thy (fn thy => certify_const thy const
-      o maps (CodegenFunc.mk_func true)) lthms;
+    val lthms' = certificate thy (fn thy => certify_const thy const) lthms;
+      (*FIXME must check compatibility with sort algebra;
+        alas, naive checking results in non-termination!*)
   in
     map_exec_purge (SOME [const]) (map_funcs (Consttab.map_default (const, (Susp.value [], []))
       (add_lthms lthms'))) thy
   end;
 
+fun add_func_attr strict = Attrib.internal (fn _ => Thm.declaration_attribute
+  (fn thm => Context.mapping (add_func strict thm) I));
+
 local
 
 fun del_datatype tyco thy =
@@ -637,12 +655,12 @@
 
 fun add_inline thm thy =
   (map_exec_purge NONE o map_preproc o apfst o apfst)
-    (fold (insert Thm.eq_thm) (CodegenFunc.mk_rew thm)) thy;
+    (insert Thm.eq_thm_prop (CodegenFunc.mk_rew thm)) thy;
         (*fully applied in order to get right context for mk_rew!*)
 
 fun del_inline thm thy =
   (map_exec_purge NONE o map_preproc o apfst o apfst)
-    (fold (remove Thm.eq_thm) (CodegenFunc.mk_rew thm)) thy;
+    (remove Thm.eq_thm_prop (CodegenFunc.mk_rew thm)) thy;
         (*fully applied in order to get right context for mk_rew!*)
 
 fun add_inline_proc (name, f) =
@@ -680,12 +698,9 @@
 fun apply_preproc thy f [] = []
   | apply_preproc thy f (thms as (thm :: _)) =
       let
+        val (const, _) = CodegenFunc.head_func thm;
         val thms' = f thy thms;
-        val thms'' as ((const, _) :: _) = map CodegenFunc.mk_head thms'
-      in (certify_const thy const o map CodegenFunc.mk_head) thms' end;
-
-fun cmp_thms thy =
-  make_ord (fn (thm1, thm2) => not (Sign.typ_instance thy (CodegenFunc.typ_func thm1, CodegenFunc.typ_func thm2)));
+      in certify_const thy const thms' end;
 
 fun rhs_conv conv thm =
   let
@@ -700,7 +715,6 @@
   |> map (CodegenFunc.rewrite_func ((#inlines o the_preproc o get_exec) thy))
   |> fold (fn (_, (_, f)) => apply_inline_proc thy f) ((#inline_procs o the_preproc o get_exec) thy)
 (*FIXME - must check: rewrite rule, defining equation, proper constant |> map (snd o check_func false thy) *)
-  |> sort (cmp_thms thy)
   |> common_typ_funcs;
 
 fun preprocess_cterm ct =
@@ -757,38 +771,14 @@
   |> these
   |> map (Thm.transfer thy);
 
-fun find_def thy (const as (c, _)) =
-  let
-    val specs = Defs.specifications_of (Theory.defs_of thy) c;
-    val ty = case try (default_typ_proto thy) const
-     of NONE => NONE
-      | SOME ty => ty;
-    val tys = Sign.const_typargs thy (c, ty |> the_default (Sign.the_const_type thy c));
-    fun get_def (_, { is_def, name, lhs, rhs, thyname }) =
-      if is_def andalso forall (Sign.typ_instance thy) (tys ~~ lhs) then
-        try (Thm.get_axiom_i thy) name
-      else NONE
-  in get_first get_def specs end;
-
 in
 
 fun these_funcs thy const =
   let
-    fun get_prim_def_funcs (const as (c, tys)) =
-      case find_def thy const
-       of SOME thm =>
-            thm
-            |> Thm.transfer thy
-            |> gen_mk_func_typ false
-            |> map (CodegenFunc.expand_eta ~1 o snd)
-        | NONE => []
     fun drop_refl thy = filter_out (is_equal o Term.fast_term_ord o Logic.dest_equals
       o ObjectLogic.drop_judgment thy o Thm.plain_prop_of);
-    val funcs = case get_funcs thy const
-     of [] => get_prim_def_funcs const
-    | funcs => funcs
   in
-    funcs
+    get_funcs thy const
     |> preprocess thy
     |> drop_refl thy
   end;
@@ -796,61 +786,14 @@
 fun default_typ thy (const as (c, _)) = case default_typ_proto thy const
  of SOME ty => ty
   | NONE => (case get_funcs thy const
-     of thm :: _ => CodegenFunc.typ_func thm
+     of thm :: _ => snd (CodegenFunc.head_func thm)
       | [] => Sign.the_const_type thy c);
 
 end; (*local*)
 
-
-(** code attributes **)
-
-local
-  fun add_simple_attribute (name, f) =
-    (Codegen.add_attribute name o (Scan.succeed o Thm.declaration_attribute))
-      (fn th => Context.mapping (f th) I);
-in
-  val _ = map (Context.add_setup o add_simple_attribute) [
-    ("func", add_func true),
-    ("nofunc", del_func),
-    ("unfold", (fn thm => Codegen.add_unfold thm #> add_inline thm)),
-    ("inline", add_inline),
-    ("noinline", del_inline)
-  ]
-end; (*local*)
-
-
-(** Isar setup **)
-
-local
-
-structure P = OuterParse
-and K = OuterKeyword
-
-val print_codesetupK = "print_codesetup";
-val code_datatypeK = "code_datatype";
-
-in
-
-val print_codesetupP =
-  OuterSyntax.improper_command print_codesetupK "print code generator setup of this theory" K.diag
-    (Scan.succeed
-      (Toplevel.no_timing o Toplevel.unknown_theory o Toplevel.keep (CodeData.print o Toplevel.theory_of)));
-
-val code_datatypeP =
-  OuterSyntax.command code_datatypeK "define set of code datatype constructors" K.thy_decl (
-    Scan.repeat1 P.term
-    >> (Toplevel.theory o add_datatype_consts_cmd)
-  );
-
-
-val _ = OuterSyntax.add_parsers [print_codesetupP, code_datatypeP];
-
-end; (*local*)
-
 end; (*struct*)
 
 
-
 (** type-safe interfaces for data depedent on executable content **)
 
 signature CODE_DATA_ARGS =