added explicit maintainance of coregular code theorems for overloaded constants
authorhaftmann
Thu, 25 Jan 2007 09:32:50 +0100
changeset 22184 a125f38a559a
parent 22183 0e6c0aeb04ec
child 22185 24bf0e403526
added explicit maintainance of coregular code theorems for overloaded constants
src/Pure/Tools/codegen_data.ML
--- a/src/Pure/Tools/codegen_data.ML	Thu Jan 25 09:32:49 2007 +0100
+++ b/src/Pure/Tools/codegen_data.ML	Thu Jan 25 09:32:50 2007 +0100
@@ -24,15 +24,16 @@
   val del_inline_proc: string -> theory -> theory
   val add_preproc: string * (theory -> thm list -> thm list) -> theory -> theory
   val del_preproc: string -> theory -> theory
-  val class_arity: theory -> class -> string -> sort list
+  val coregular_algebra: theory -> Sorts.algebra
+  val operational_algebra: theory -> (sort -> sort) * Sorts.algebra
   val these_funcs: theory -> CodegenConsts.const -> thm list
+  val tap_typ: theory -> CodegenConsts.const -> typ option
   val get_datatype: theory -> string
     -> ((string * sort) list * (string * typ list) list) option
   val get_datatype_of_constr: theory -> CodegenConsts.const -> string option
 
   val print_thms: theory -> unit
 
-  val typ_funcs: theory -> CodegenConsts.const * thm list -> typ
   val preprocess_cterm: cterm -> thm
 
   val trace: bool ref
@@ -335,9 +336,8 @@
           :: Pretty.str "="
           :: Pretty.separate "|" (map (fn (c, []) => Pretty.str c
                | (c, tys) =>
-                   Pretty.block
-                      (Pretty.str c :: Pretty.brk 1 :: Pretty.str "of" :: Pretty.brk 1
-                      :: Pretty.breaks (map (Pretty.quote o Sign.pretty_typ thy) tys))) cos)
+                   (Pretty.block o Pretty.breaks)
+                      (Pretty.str c :: Pretty.str "of" :: map (Pretty.quote o Sign.pretty_typ thy) tys)) cos)
         )
       val inlines = (#inlines o the_preproc) exec;
       val inline_procs = (map fst o #inline_procs o the_preproc) exec;
@@ -507,7 +507,7 @@
 
 
 
-(** operational sort algebra **)
+(** operational sort algebra and class discipline **)
 
 local
 
@@ -523,73 +523,156 @@
     val inters = curry (Sorts.inter_sort algebra);
   in aggregate (map2 inters) end;
 
-fun get_raw_funcs thy tyco clsop =
-  let
-    val vs = Name.invents Name.context "" (Sign.arity_number thy tyco);
-    val c = CodegenConsts.norm thy (clsop, [Type (tyco, map (TFree o rpair []) vs)])
-  in
-    Consttab.lookup ((the_funcs o get_exec) thy) c
-    |> Option.map (Susp.force o fst)
-    |> these
-    |> map (Thm.transfer thy)
-  end;
-
-fun constraints thy class tyco =
+fun specific_constraints thy (class, tyco) =
   let
     val vs = Name.invents Name.context "" (Sign.arity_number thy tyco);
     val clsops = (these o Option.map snd o try (AxClass.params_of_class thy)) class;
-    val funcs = maps (get_raw_funcs thy tyco o fst) clsops;
+    val funcs = clsops
+      |> map (fn (clsop, _) => CodegenConsts.norm thy (clsop, [Type (tyco, map (TFree o rpair []) vs)]))
+      |> map (Consttab.lookup ((the_funcs o get_exec) thy))
+      |> (map o Option.map) (Susp.force o fst)
+      |> maps these
+      |> map (Thm.transfer thy);
     val sorts = map (map (snd o dest_TVar) o snd o dest_Type o the_single
       o Sign.const_typargs thy o fst o CodegenFunc.dest_func) funcs;
-  in inter_sorts thy sorts end;
+  in sorts end;
+
+fun weakest_constraints thy (class, tyco) =
+  let
+    val all_superclasses = class :: Graph.all_succs ((#classes o Sorts.rep_algebra o Sign.classes_of) thy) [class];
+  in case inter_sorts thy (maps (fn class => specific_constraints thy (class, tyco)) all_superclasses)
+   of SOME sorts => sorts
+    | NONE => Sign.arity_sorts thy tyco [class]
+  end;
 
-fun weakest_constraints thy class tyco =
-  case constraints thy class tyco
-   of sorts as SOME _ => sorts
-    | NONE => let
-        val sorts = map_filter (fn class => weakest_constraints thy class tyco)
-          (Sign.super_classes thy class);
-      in inter_sorts thy sorts end;
+fun strongest_constraints thy (class, tyco) =
+  let
+    val algebra = Sign.classes_of thy;
+    val all_subclasses = class :: Graph.all_preds ((#classes o Sorts.rep_algebra) algebra) [class];
+    val inst_subclasses = filter (can (Sorts.mg_domain algebra tyco) o single) all_subclasses;
+  in case inter_sorts thy (maps (fn class => specific_constraints thy (class, tyco)) inst_subclasses)
+   of SOME sorts => sorts
+    | NONE => replicate
+        (Sign.arity_number thy tyco) (Sign.certify_sort thy (Sign.all_classes thy))
+  end;
+
+fun gen_classop_typ constr thy class (c, tyco) = 
+  let
+    val (var, cs) = try (AxClass.params_of_class thy) class |> the_default ("'a", [])
+    val ty = (the o AList.lookup (op =) cs) c;
+    val sort_args = Name.names (Name.declare var Name.context) "'a"
+      (constr thy (class, tyco));
+    val ty_inst = Type (tyco, map TFree sort_args);
+  in Logic.varifyT (map_type_tfree (K ty_inst) ty) end;
+
+(*FIXME: make distinct step: building algebra from code theorems*)
+fun retrieve_algebra thy operational =
+  Sorts.subalgebra (Sign.pp thy) operational
+    (weakest_constraints thy)
+    (Sign.classes_of thy);
 
 in
 
-fun class_arity thy class tyco =
-  weakest_constraints thy class tyco
-  |> the_default (Sign.arity_sorts thy tyco [class]);
+fun coregular_algebra thy = retrieve_algebra thy (K true) |> snd;
+fun operational_algebra thy =
+  let
+    fun add_iff_operational class classes =
+      if (not o null o these o Option.map #params o try (AxClass.get_definition thy)) class
+        orelse (length o gen_inter (op =))
+          ((Sign.certify_sort thy o Sign.super_classes thy) class, classes) >= 2
+      then class :: classes
+      else classes;
+    val operational_classes = fold add_iff_operational (Sign.all_classes thy) []
+  in retrieve_algebra thy (member (op =) operational_classes) end;
+
+val classop_weakest_typ = gen_classop_typ weakest_constraints;
+val classop_strongest_typ = gen_classop_typ strongest_constraints;
 
-fun upward_compatible_constraints thy sorts class tyco =
-  case constraints thy class tyco
-   of SOME sorts' => forall (Sign.subsort thy) (sorts ~~ sorts')
-    | NONE => forall (fn class => upward_compatible_constraints thy sorts class tyco)
-        (Graph.imm_preds ((#classes o Sorts.rep_algebra o Sign.classes_of) thy) class);
+fun gen_mk_func_typ strict_functyp thm =
+  let
+    val thy = Thm.theory_of_thm thm;
+    val raw_funcs = CodegenFunc.mk_func thm;
+    val error_warning = if strict_functyp then error else warning #> K NONE;
+    val string_of_typ = setmp show_sorts true (Sign.string_of_typ thy);
+    fun check_typ_classop class (const as (c, [Type (tyco, _)]), thm) =
+          let
+            val ((_, ty), _) = CodegenFunc.dest_func thm;
+            val ty_decl = classop_weakest_typ thy class (c, tyco);
+            val ty_strongest = classop_strongest_typ thy class (c, tyco);
+            fun constrain thm = 
+              let
+                val max = Thm.maxidx_of thm + 1;
+                val ty_decl' = Logic.incr_tvar max ty_decl;
+                val ((_, ty'), _) = CodegenFunc.dest_func thm;
+                val (env, _) = Sign.typ_unify thy (ty_decl', ty') (Vartab.empty, max);
+                val instT = Vartab.fold (fn (x_i, (sort, ty)) =>
+                  cons (Thm.ctyp_of thy (TVar (x_i, sort)), Thm.ctyp_of thy ty)) env [];
+              in Thm.instantiate (instT, []) thm end;
+          in if Sign.typ_instance thy (ty_strongest, ty)
+            then if Sign.typ_instance thy (ty, ty_decl)
+            then SOME (const, thm)
+            else (warning ("Constraining type\n" ^ string_of_typ ty
+              ^ "\nof function theorem\n"
+              ^ string_of_thm thm
+              ^ "\nto permitted most general type\n"
+              ^ string_of_typ ty_decl);
+              SOME (const, constrain thm))
+            else error_warning ("Type\n" ^ string_of_typ ty
+              ^ "\nof function theorem\n"
+              ^ string_of_thm thm
+              ^ "\nis incompatible with permitted least general type\n"
+              ^ string_of_typ ty_strongest)
+          end
+      | check_typ_classop class ((c, [_]), thm) =
+          (if strict_functyp then error else warning #> K NONE)
+           ("Illegal type for class operation " ^ quote c
+           ^ "\nin function theorem\n"
+           ^ string_of_thm thm);
+    fun check_typ_fun (const as (c, _), thm) =
+      let
+        val ((_, ty), _) = CodegenFunc.dest_func thm;
+        val ty_decl = Sign.the_const_type thy c;
+      in if Sign.typ_equiv thy (Type.strip_sorts ty_decl, Type.strip_sorts ty)
+        then SOME (const, thm)
+        else error_warning ("Type\n" ^ string_of_typ ty
+           ^ "\nof function theorem\n"
+           ^ string_of_thm thm
+           ^ "\nis incompatible declared function type\n"
+           ^ string_of_typ ty_decl)
+      end;
+    fun check_typ (const as (c, tys), thm) =
+      case AxClass.class_of_param thy c
+       of SOME class => check_typ_classop class (const, thm)
+        | NONE => check_typ_fun (const, thm);
+    val funcs = map_filter check_typ raw_funcs;
+  in funcs end;
 
 end;
 
 
-
 (** interfaces **)
 
-fun gen_add_func mk_func thm thy =
+fun gen_add_func strict_functyp thm thy =
   let
-    val thms = mk_func thm;
-    val cs = map fst thms;
+    val funcs = gen_mk_func_typ strict_functyp thm;
+    val cs = map fst funcs;
   in
     map_exec_purge (SOME cs) (map_funcs 
      (fold (fn (c, thm) => Consttab.map_default
-       (c, (Susp.value [], [])) (add_thm thm)) thms)) thy
+       (c, (Susp.value [], [])) (add_thm thm)) funcs)) thy
   end;
 
-val add_func = gen_add_func CodegenFunc.mk_func;
-val add_func_legacy = gen_add_func CodegenFunc.legacy_mk_func;
+val add_func = gen_add_func true;
+val add_func_legacy = gen_add_func false;
 
 fun del_func thm thy =
   let
-    val thms = CodegenFunc.mk_func thm;
-    val cs = map fst thms;
+    val funcs = gen_mk_func_typ false thm;
+    val cs = map fst funcs;
   in
     map_exec_purge (SOME cs) (map_funcs
      (fold (fn (c, thm) => Consttab.map_entry c
-       (del_thm thm)) thms)) thy
+       (del_thm thm)) funcs)) thy
   end;
 
 fun add_funcl (c, lthms) thy =
@@ -644,7 +727,7 @@
 fun gen_apply_inline_proc prep post thy f x =
   let
     val cts = prep x;
-    val rews = map CodegenFunc.check_rew (f thy cts);
+    val rews = map CodegenFunc.assert_rew (f thy cts);
   in post rews x end;
 
 val apply_inline_proc = gen_apply_inline_proc (maps
@@ -681,36 +764,56 @@
 
 fun preprocess_cterm ct =
   let
-    val thy = Thm.theory_of_cterm ct
+    val thy = Thm.theory_of_cterm ct;
   in
     ct
     |> Thm.reflexive
     |> fold (rhs_conv o MetaSimplifier.rewrite false o single)
-      ((#inlines o the_preproc o get_exec) thy)
+        ((#inlines o the_preproc o get_exec) thy)
     |> fold (fn (_, (_, f)) => rhs_conv (apply_inline_proc_cterm thy f))
-      ((#inline_procs o the_preproc o get_exec) thy)
+        ((#inline_procs o the_preproc o get_exec) thy)
   end;
 
 end; (*local*)
 
-fun these_funcs thy c =
+local
+
+fun get_funcs thy const =
+  Consttab.lookup ((the_funcs o get_exec) thy) const
+  |> Option.map (Susp.force o fst)
+  |> these
+  |> map (Thm.transfer thy);
+
+in
+
+fun these_funcs thy const =
   let
-    val funcs_1 =
-      Consttab.lookup ((the_funcs o get_exec) thy) c
-      |> Option.map (Susp.force o fst)
-      |> these
-      |> map (Thm.transfer thy);
-    val funcs_2 = case funcs_1
-     of [] => CodegenFunc.get_prim_def_funcs thy c
-      | xs => xs;
+    fun get_prim_def_funcs (const as (c, tys)) =
+      case CodegenConsts.find_def thy const
+       of SOME (_, thm) =>
+            thm
+            |> Thm.transfer thy
+            |> gen_mk_func_typ false
+            |> map (CodegenFunc.expand_eta ~1 o snd)
+        | NONE => []
     fun drop_refl thy = filter_out (is_equal o Term.fast_term_ord o Logic.dest_equals
       o ObjectLogic.drop_judgment thy o Drule.plain_prop_of);
+    val funcs = case get_funcs thy const
+     of [] => get_prim_def_funcs const
+    | funcs => funcs
   in
-    funcs_2
+    funcs
     |> preprocess thy
     |> drop_refl thy
   end;
 
+fun tap_typ thy const =
+  case get_funcs thy const
+   of thm :: _ => SOME (CodegenFunc.typ_func thm)
+    | [] => NONE;
+
+end; (*local*)
+
 fun get_datatype thy tyco =
   Symtab.lookup ((the_dtyps o get_exec) thy) tyco
   |> Option.map (fn (spec, thms) => (Susp.force thms; spec));
@@ -719,13 +822,6 @@
   Consttab.lookup ((the_dcontrs o get_exec) thy) c
   |> (Option.map o tap) (fn dtco => get_datatype thy dtco);
 
-fun typ_funcs thy (c as (name, _), []) = (case AxClass.class_of_param thy name
-     of SOME class => CodegenConsts.disc_typ_of_classop thy c
-      | NONE => (case Option.map (Susp.force o fst) (Consttab.lookup ((the_funcs o get_exec) thy) c)
-         of SOME [eq] => CodegenFunc.typ_func eq
-          | _ => Sign.the_const_type thy name))
-  | typ_funcs thy (_, eq :: _) = CodegenFunc.typ_func eq;
-
 
 (** code attributes **)