--- a/src/Tools/Code/code_preproc.ML Tue Jan 12 09:59:45 2010 +0100
+++ b/src/Tools/Code/code_preproc.ML Tue Jan 12 16:27:42 2010 +0100
@@ -18,7 +18,7 @@
type code_algebra
type code_graph
- val eqns: code_graph -> string -> (thm * bool) list
+ val cert: code_graph -> string -> Code.cert
val sortargs: code_graph -> string -> sort list
val all: code_graph -> string list
val pretty: theory -> code_graph -> Pretty.T
@@ -53,8 +53,8 @@
let
val pre = Simplifier.merge_ss (pre1, pre2);
val post = Simplifier.merge_ss (post1, post2);
- val functrans = AList.merge (op =) (eq_fst (op =)) (functrans1, functrans2);
- (* FIXME handle AList.DUP (!?) *)
+ val functrans = AList.merge (op =) (eq_fst (op =)) (functrans1, functrans2)
+ handle AList.DUP => error ("Duplicate function transformer");
in make_thmproc ((pre, post), functrans) end;
structure Code_Preproc_Data = Theory_Data
@@ -102,23 +102,14 @@
(* post- and preprocessing *)
-fun apply_functrans thy c _ [] = []
- | apply_functrans thy c [] eqns = eqns
- | apply_functrans thy c functrans eqns = eqns
- |> perhaps (perhaps_loop (perhaps_apply functrans))
- |> Code.assert_eqns_const thy c
- (*FIXME in future, the check here should be more accurate wrt. type schemes
- -- perhaps by means of upcoming code certificates with a corresponding
- preprocessor protocol*);
-
fun trans_conv_rule conv thm = Thm.transitive thm ((conv o Thm.rhs_of) thm);
-fun eqn_conv conv =
+fun eqn_conv conv ct =
let
fun lhs_conv ct = if can Thm.dest_comb ct
then Conv.combination_conv lhs_conv conv ct
else Conv.all_conv ct;
- in Conv.combination_conv (Conv.arg_conv lhs_conv) conv end;
+ in Conv.combination_conv (Conv.arg_conv lhs_conv) conv ct end;
val rewrite_eqn = Conv.fconv_rule o eqn_conv o Simplifier.rewrite;
@@ -129,17 +120,15 @@
#> Logic.dest_equals
#> snd;
-fun preprocess thy c eqns =
+fun preprocess thy eqns =
let
val pre = (Simplifier.theory_context thy o #pre o the_thmproc) thy;
val functrans = (map (fn (_, (_, f)) => f thy) o #functrans
o the_thmproc) thy;
in
eqns
- |> apply_functrans thy c functrans
+ |> perhaps (perhaps_loop (perhaps_apply functrans))
|> (map o apfst) (rewrite_eqn pre)
- |> (map o apfst) (AxClass.unoverload thy)
- |> map (Code.assert_eqn thy)
end;
fun preprocess_conv thy ct =
@@ -196,20 +185,20 @@
(** sort algebra and code equation graph types **)
type code_algebra = (sort -> sort) * Sorts.algebra;
-type code_graph = ((string * sort) list * (thm * bool) list) Graph.T;
+type code_graph = ((string * sort) list * Code.cert) Graph.T;
-fun eqns eqngr = these o Option.map snd o try (Graph.get_node eqngr);
-fun sortargs eqngr = map snd o fst o Graph.get_node eqngr
+fun cert eqngr = snd o Graph.get_node eqngr;
+fun sortargs eqngr = map snd o fst o Graph.get_node eqngr;
fun all eqngr = Graph.keys eqngr;
fun pretty thy eqngr =
AList.make (snd o Graph.get_node eqngr) (Graph.keys eqngr)
|> (map o apfst) (Code.string_of_const thy)
|> sort (string_ord o pairself fst)
- |> map (fn (s, thms) =>
+ |> map (fn (s, cert) =>
(Pretty.block o Pretty.fbreaks) (
Pretty.str s
- :: map (Display.pretty_thm_global thy o fst) thms
+ :: map (Display.pretty_thm_global thy o AxClass.overload thy o fst) (Code.eqns_of_cert thy cert)
))
|> Pretty.chunks;
@@ -227,12 +216,12 @@
map (fn (c, _) => AxClass.param_of_inst thy (c, tyco))
o maps (#params o AxClass.get_info thy);
-fun typscheme_rhss thy c eqns =
+fun typscheme_rhss thy c cert =
let
- val tyscm = Code.typscheme_eqns thy c (map fst eqns);
+ val (tyscm, equations) = Code.dest_cert thy cert;
val rhss = [] |> (fold o fold o fold_aterms)
- (fn Const (c, ty) => insert (op =) (c, Sign.const_typargs thy (c, Logic.unvarifyT ty)) | _ => I)
- (map (op :: o swap o apfst (snd o strip_comb) o Logic.dest_equals o Thm.plain_prop_of o fst) eqns);
+ (fn Const (c, ty) => insert (op =) (c, Sign.const_typargs thy (c, ty)) | _ => I)
+ (map (op :: o swap o fst) equations);
in (tyscm, rhss) end;
@@ -259,7 +248,7 @@
| NONE => Free;
type vardeps_data = ((string * styp list) list * class list) Vargraph.T
- * (((string * sort) list * (thm * bool) list) Symtab.table
+ * (((string * sort) list * Code.cert) Symtab.table
* (class * string) list);
val empty_vardeps_data : vardeps_data =
@@ -270,12 +259,11 @@
fun obtain_eqns thy eqngr c =
case try (Graph.get_node eqngr) c
- of SOME (lhs, eqns) => ((lhs, []), [])
+ of SOME (lhs, cert) => ((lhs, []), cert)
| NONE => let
- val eqns = Code.these_eqns thy c
- |> preprocess thy c;
- val ((lhs, _), rhss) = typscheme_rhss thy c eqns;
- in ((lhs, rhss), eqns) end;
+ val cert = Code.get_cert thy (preprocess thy) c;
+ val ((lhs, _), rhss) = typscheme_rhss thy c cert;
+ in ((lhs, rhss), cert) end;
fun obtain_instance thy arities (inst as (class, tyco)) =
case AList.lookup (op =) arities inst
@@ -396,32 +384,27 @@
handle Sorts.CLASS_ERROR _ => [] (*permissive!*))
end;
-fun inst_thm thy tvars' thm =
+fun inst_cert thy lhs cert =
let
- val tvars = (Term.add_tvars o Thm.prop_of) thm [];
- val inter_sort = Sorts.inter_sort (Sign.classes_of thy);
- fun mk_inst (tvar as (v, sort)) = case Vartab.lookup tvars' v
- of SOME sort' => SOME (pairself (Thm.ctyp_of thy o TVar)
- (tvar, (v, inter_sort (sort, sort'))))
- | NONE => NONE;
- val insts = map_filter mk_inst tvars;
- in Thm.instantiate (insts, []) thm end;
+ val ((vs, _), _) = Code.dest_cert thy cert;
+ val sorts = map (fn (v, sort) => case AList.lookup (op =) lhs v
+ of SOME sort' => Sorts.inter_sort (Sign.classes_of thy) (sort, sort')
+ | NONE => sort) vs;
+ in Code.constrain_cert thy sorts cert end;
fun add_arity thy vardeps (class, tyco) =
AList.default (op =) ((class, tyco),
- map_range (fn k => (snd o Vargraph.get_node vardeps) (Inst (class, tyco), k)) (Sign.arity_number thy tyco));
+ map_range (fn k => (snd o Vargraph.get_node vardeps) (Inst (class, tyco), k))
+ (Sign.arity_number thy tyco));
-fun add_eqs thy vardeps (c, (proto_lhs, proto_eqns)) (rhss, eqngr) =
+fun add_cert thy vardeps (c, (proto_lhs, proto_cert)) (rhss, eqngr) =
if can (Graph.get_node eqngr) c then (rhss, eqngr)
else let
val lhs = map_index (fn (k, (v, _)) =>
(v, snd (Vargraph.get_node vardeps (Fun c, k)))) proto_lhs;
- val inst_tab = Vartab.empty |> fold (fn (v, sort) =>
- Vartab.update ((v, 0), sort)) lhs;
- val eqns = proto_eqns
- |> (map o apfst) (inst_thm thy inst_tab);
- val ((vs, _), rhss') = typscheme_rhss thy c eqns;
- val eqngr' = Graph.new_node (c, (vs, eqns)) eqngr;
+ val cert = inst_cert thy lhs proto_cert;
+ val ((vs, _), rhss') = typscheme_rhss thy c cert;
+ val eqngr' = Graph.new_node (c, (vs, cert)) eqngr;
in (map (pair c) rhss' @ rhss, eqngr') end;
fun extend_arities_eqngr thy cs ts (arities, (eqngr : code_graph)) =
@@ -435,7 +418,7 @@
val pp = Syntax.pp_global thy;
val algebra = Sorts.subalgebra pp (is_proper_class thy)
(AList.lookup (op =) arities') (Sign.classes_of thy);
- val (rhss, eqngr') = Symtab.fold (add_eqs thy vardeps) eqntab ([], eqngr);
+ val (rhss, eqngr') = Symtab.fold (add_cert thy vardeps) eqntab ([], eqngr);
fun deps_of (c, rhs) = c :: maps (dicts_of thy algebra)
(rhs ~~ sortargs eqngr' c);
val eqngr'' = fold (fn (c, rhs) => fold