src/Tools/Code/code_preproc.ML
changeset 34891 99b9a6290446
parent 34251 cd642bb91f64
child 34893 ecdc526af73a
--- a/src/Tools/Code/code_preproc.ML	Tue Jan 12 09:59:45 2010 +0100
+++ b/src/Tools/Code/code_preproc.ML	Tue Jan 12 16:27:42 2010 +0100
@@ -18,7 +18,7 @@
 
   type code_algebra
   type code_graph
-  val eqns: code_graph -> string -> (thm * bool) list
+  val cert: code_graph -> string -> Code.cert
   val sortargs: code_graph -> string -> sort list
   val all: code_graph -> string list
   val pretty: theory -> code_graph -> Pretty.T
@@ -53,8 +53,8 @@
     let
       val pre = Simplifier.merge_ss (pre1, pre2);
       val post = Simplifier.merge_ss (post1, post2);
-      val functrans = AList.merge (op =) (eq_fst (op =)) (functrans1, functrans2);
-        (* FIXME handle AList.DUP (!?) *)
+      val functrans = AList.merge (op =) (eq_fst (op =)) (functrans1, functrans2)
+        handle AList.DUP => error ("Duplicate function transformer");
     in make_thmproc ((pre, post), functrans) end;
 
 structure Code_Preproc_Data = Theory_Data
@@ -102,23 +102,14 @@
 
 (* post- and preprocessing *)
 
-fun apply_functrans thy c _ [] = []
-  | apply_functrans thy c [] eqns = eqns
-  | apply_functrans thy c functrans eqns = eqns
-      |> perhaps (perhaps_loop (perhaps_apply functrans))
-      |> Code.assert_eqns_const thy c
-      (*FIXME in future, the check here should be more accurate wrt. type schemes
-      -- perhaps by means of upcoming code certificates with a corresponding
-         preprocessor protocol*);
-
 fun trans_conv_rule conv thm = Thm.transitive thm ((conv o Thm.rhs_of) thm);
 
-fun eqn_conv conv =
+fun eqn_conv conv ct =
   let
     fun lhs_conv ct = if can Thm.dest_comb ct
       then Conv.combination_conv lhs_conv conv ct
       else Conv.all_conv ct;
-  in Conv.combination_conv (Conv.arg_conv lhs_conv) conv end;
+  in Conv.combination_conv (Conv.arg_conv lhs_conv) conv ct end;
 
 val rewrite_eqn = Conv.fconv_rule o eqn_conv o Simplifier.rewrite;
 
@@ -129,17 +120,15 @@
   #> Logic.dest_equals
   #> snd;
 
-fun preprocess thy c eqns =
+fun preprocess thy eqns =
   let
     val pre = (Simplifier.theory_context thy o #pre o the_thmproc) thy;
     val functrans = (map (fn (_, (_, f)) => f thy) o #functrans
       o the_thmproc) thy;
   in
     eqns
-    |> apply_functrans thy c functrans
+    |> perhaps (perhaps_loop (perhaps_apply functrans))
     |> (map o apfst) (rewrite_eqn pre)
-    |> (map o apfst) (AxClass.unoverload thy)
-    |> map (Code.assert_eqn thy)
   end;
 
 fun preprocess_conv thy ct =
@@ -196,20 +185,20 @@
 (** sort algebra and code equation graph types **)
 
 type code_algebra = (sort -> sort) * Sorts.algebra;
-type code_graph = ((string * sort) list * (thm * bool) list) Graph.T;
+type code_graph = ((string * sort) list * Code.cert) Graph.T;
 
-fun eqns eqngr = these o Option.map snd o try (Graph.get_node eqngr);
-fun sortargs eqngr = map snd o fst o Graph.get_node eqngr
+fun cert eqngr = snd o Graph.get_node eqngr;
+fun sortargs eqngr = map snd o fst o Graph.get_node eqngr;
 fun all eqngr = Graph.keys eqngr;
 
 fun pretty thy eqngr =
   AList.make (snd o Graph.get_node eqngr) (Graph.keys eqngr)
   |> (map o apfst) (Code.string_of_const thy)
   |> sort (string_ord o pairself fst)
-  |> map (fn (s, thms) =>
+  |> map (fn (s, cert) =>
        (Pretty.block o Pretty.fbreaks) (
          Pretty.str s
-         :: map (Display.pretty_thm_global thy o fst) thms
+         :: map (Display.pretty_thm_global thy o AxClass.overload thy o fst) (Code.eqns_of_cert thy cert)
        ))
   |> Pretty.chunks;
 
@@ -227,12 +216,12 @@
   map (fn (c, _) => AxClass.param_of_inst thy (c, tyco))
     o maps (#params o AxClass.get_info thy);
 
-fun typscheme_rhss thy c eqns =
+fun typscheme_rhss thy c cert =
   let
-    val tyscm = Code.typscheme_eqns thy c (map fst eqns);
+    val (tyscm, equations) = Code.dest_cert thy cert;
     val rhss = [] |> (fold o fold o fold_aterms)
-      (fn Const (c, ty) => insert (op =) (c, Sign.const_typargs thy (c, Logic.unvarifyT ty)) | _ => I)
-        (map (op :: o swap o apfst (snd o strip_comb) o Logic.dest_equals o Thm.plain_prop_of o fst) eqns);
+      (fn Const (c, ty) => insert (op =) (c, Sign.const_typargs thy (c, ty)) | _ => I)
+        (map (op :: o swap o fst) equations);
   in (tyscm, rhss) end;
 
 
@@ -259,7 +248,7 @@
       | NONE => Free;
 
 type vardeps_data = ((string * styp list) list * class list) Vargraph.T
-  * (((string * sort) list * (thm * bool) list) Symtab.table
+  * (((string * sort) list * Code.cert) Symtab.table
     * (class * string) list);
 
 val empty_vardeps_data : vardeps_data =
@@ -270,12 +259,11 @@
 
 fun obtain_eqns thy eqngr c =
   case try (Graph.get_node eqngr) c
-   of SOME (lhs, eqns) => ((lhs, []), [])
+   of SOME (lhs, cert) => ((lhs, []), cert)
     | NONE => let
-        val eqns = Code.these_eqns thy c
-          |> preprocess thy c;
-        val ((lhs, _), rhss) = typscheme_rhss thy c eqns;
-      in ((lhs, rhss), eqns) end;
+        val cert = Code.get_cert thy (preprocess thy) c;
+        val ((lhs, _), rhss) = typscheme_rhss thy c cert;
+      in ((lhs, rhss), cert) end;
 
 fun obtain_instance thy arities (inst as (class, tyco)) =
   case AList.lookup (op =) arities inst
@@ -396,32 +384,27 @@
        handle Sorts.CLASS_ERROR _ => [] (*permissive!*))
   end;
 
-fun inst_thm thy tvars' thm =
+fun inst_cert thy lhs cert =
   let
-    val tvars = (Term.add_tvars o Thm.prop_of) thm [];
-    val inter_sort = Sorts.inter_sort (Sign.classes_of thy);
-    fun mk_inst (tvar as (v, sort)) = case Vartab.lookup tvars' v
-     of SOME sort' => SOME (pairself (Thm.ctyp_of thy o TVar)
-          (tvar, (v, inter_sort (sort, sort'))))
-      | NONE => NONE;
-    val insts = map_filter mk_inst tvars;
-  in Thm.instantiate (insts, []) thm end;
+    val ((vs, _), _) = Code.dest_cert thy cert;
+    val sorts = map (fn (v, sort) => case AList.lookup (op =) lhs v
+     of SOME sort' => Sorts.inter_sort (Sign.classes_of thy) (sort, sort')
+      | NONE => sort) vs;
+  in Code.constrain_cert thy sorts cert end;
 
 fun add_arity thy vardeps (class, tyco) =
   AList.default (op =) ((class, tyco),
-    map_range (fn k => (snd o Vargraph.get_node vardeps) (Inst (class, tyco), k)) (Sign.arity_number thy tyco));
+    map_range (fn k => (snd o Vargraph.get_node vardeps) (Inst (class, tyco), k))
+      (Sign.arity_number thy tyco));
 
-fun add_eqs thy vardeps (c, (proto_lhs, proto_eqns)) (rhss, eqngr) =
+fun add_cert thy vardeps (c, (proto_lhs, proto_cert)) (rhss, eqngr) =
   if can (Graph.get_node eqngr) c then (rhss, eqngr)
   else let
     val lhs = map_index (fn (k, (v, _)) =>
       (v, snd (Vargraph.get_node vardeps (Fun c, k)))) proto_lhs;
-    val inst_tab = Vartab.empty |> fold (fn (v, sort) =>
-      Vartab.update ((v, 0), sort)) lhs;
-    val eqns = proto_eqns
-      |> (map o apfst) (inst_thm thy inst_tab);
-    val ((vs, _), rhss') = typscheme_rhss thy c eqns;
-    val eqngr' = Graph.new_node (c, (vs, eqns)) eqngr;
+    val cert = inst_cert thy lhs proto_cert;
+    val ((vs, _), rhss') = typscheme_rhss thy c cert;
+    val eqngr' = Graph.new_node (c, (vs, cert)) eqngr;
   in (map (pair c) rhss' @ rhss, eqngr') end;
 
 fun extend_arities_eqngr thy cs ts (arities, (eqngr : code_graph)) =
@@ -435,7 +418,7 @@
     val pp = Syntax.pp_global thy;
     val algebra = Sorts.subalgebra pp (is_proper_class thy)
       (AList.lookup (op =) arities') (Sign.classes_of thy);
-    val (rhss, eqngr') = Symtab.fold (add_eqs thy vardeps) eqntab ([], eqngr);
+    val (rhss, eqngr') = Symtab.fold (add_cert thy vardeps) eqntab ([], eqngr);
     fun deps_of (c, rhs) = c :: maps (dicts_of thy algebra)
       (rhs ~~ sortargs eqngr' c);
     val eqngr'' = fold (fn (c, rhs) => fold