diff -r ded5b770ec1c -r 99b9a6290446 src/Tools/Code/code_preproc.ML --- a/src/Tools/Code/code_preproc.ML Tue Jan 12 09:59:45 2010 +0100 +++ b/src/Tools/Code/code_preproc.ML Tue Jan 12 16:27:42 2010 +0100 @@ -18,7 +18,7 @@ type code_algebra type code_graph - val eqns: code_graph -> string -> (thm * bool) list + val cert: code_graph -> string -> Code.cert val sortargs: code_graph -> string -> sort list val all: code_graph -> string list val pretty: theory -> code_graph -> Pretty.T @@ -53,8 +53,8 @@ let val pre = Simplifier.merge_ss (pre1, pre2); val post = Simplifier.merge_ss (post1, post2); - val functrans = AList.merge (op =) (eq_fst (op =)) (functrans1, functrans2); - (* FIXME handle AList.DUP (!?) *) + val functrans = AList.merge (op =) (eq_fst (op =)) (functrans1, functrans2) + handle AList.DUP => error ("Duplicate function transformer"); in make_thmproc ((pre, post), functrans) end; structure Code_Preproc_Data = Theory_Data @@ -102,23 +102,14 @@ (* post- and preprocessing *) -fun apply_functrans thy c _ [] = [] - | apply_functrans thy c [] eqns = eqns - | apply_functrans thy c functrans eqns = eqns - |> perhaps (perhaps_loop (perhaps_apply functrans)) - |> Code.assert_eqns_const thy c - (*FIXME in future, the check here should be more accurate wrt. type schemes - -- perhaps by means of upcoming code certificates with a corresponding - preprocessor protocol*); - fun trans_conv_rule conv thm = Thm.transitive thm ((conv o Thm.rhs_of) thm); -fun eqn_conv conv = +fun eqn_conv conv ct = let fun lhs_conv ct = if can Thm.dest_comb ct then Conv.combination_conv lhs_conv conv ct else Conv.all_conv ct; - in Conv.combination_conv (Conv.arg_conv lhs_conv) conv end; + in Conv.combination_conv (Conv.arg_conv lhs_conv) conv ct end; val rewrite_eqn = Conv.fconv_rule o eqn_conv o Simplifier.rewrite; @@ -129,17 +120,15 @@ #> Logic.dest_equals #> snd; -fun preprocess thy c eqns = +fun preprocess thy eqns = let val pre = (Simplifier.theory_context thy o #pre o the_thmproc) thy; val functrans = (map (fn (_, (_, f)) => f thy) o #functrans o the_thmproc) thy; in eqns - |> apply_functrans thy c functrans + |> perhaps (perhaps_loop (perhaps_apply functrans)) |> (map o apfst) (rewrite_eqn pre) - |> (map o apfst) (AxClass.unoverload thy) - |> map (Code.assert_eqn thy) end; fun preprocess_conv thy ct = @@ -196,20 +185,20 @@ (** sort algebra and code equation graph types **) type code_algebra = (sort -> sort) * Sorts.algebra; -type code_graph = ((string * sort) list * (thm * bool) list) Graph.T; +type code_graph = ((string * sort) list * Code.cert) Graph.T; -fun eqns eqngr = these o Option.map snd o try (Graph.get_node eqngr); -fun sortargs eqngr = map snd o fst o Graph.get_node eqngr +fun cert eqngr = snd o Graph.get_node eqngr; +fun sortargs eqngr = map snd o fst o Graph.get_node eqngr; fun all eqngr = Graph.keys eqngr; fun pretty thy eqngr = AList.make (snd o Graph.get_node eqngr) (Graph.keys eqngr) |> (map o apfst) (Code.string_of_const thy) |> sort (string_ord o pairself fst) - |> map (fn (s, thms) => + |> map (fn (s, cert) => (Pretty.block o Pretty.fbreaks) ( Pretty.str s - :: map (Display.pretty_thm_global thy o fst) thms + :: map (Display.pretty_thm_global thy o AxClass.overload thy o fst) (Code.eqns_of_cert thy cert) )) |> Pretty.chunks; @@ -227,12 +216,12 @@ map (fn (c, _) => AxClass.param_of_inst thy (c, tyco)) o maps (#params o AxClass.get_info thy); -fun typscheme_rhss thy c eqns = +fun typscheme_rhss thy c cert = let - val tyscm = Code.typscheme_eqns thy c (map fst eqns); + val (tyscm, equations) = Code.dest_cert thy cert; val rhss = [] |> (fold o fold o fold_aterms) - (fn Const (c, ty) => insert (op =) (c, Sign.const_typargs thy (c, Logic.unvarifyT ty)) | _ => I) - (map (op :: o swap o apfst (snd o strip_comb) o Logic.dest_equals o Thm.plain_prop_of o fst) eqns); + (fn Const (c, ty) => insert (op =) (c, Sign.const_typargs thy (c, ty)) | _ => I) + (map (op :: o swap o fst) equations); in (tyscm, rhss) end; @@ -259,7 +248,7 @@ | NONE => Free; type vardeps_data = ((string * styp list) list * class list) Vargraph.T - * (((string * sort) list * (thm * bool) list) Symtab.table + * (((string * sort) list * Code.cert) Symtab.table * (class * string) list); val empty_vardeps_data : vardeps_data = @@ -270,12 +259,11 @@ fun obtain_eqns thy eqngr c = case try (Graph.get_node eqngr) c - of SOME (lhs, eqns) => ((lhs, []), []) + of SOME (lhs, cert) => ((lhs, []), cert) | NONE => let - val eqns = Code.these_eqns thy c - |> preprocess thy c; - val ((lhs, _), rhss) = typscheme_rhss thy c eqns; - in ((lhs, rhss), eqns) end; + val cert = Code.get_cert thy (preprocess thy) c; + val ((lhs, _), rhss) = typscheme_rhss thy c cert; + in ((lhs, rhss), cert) end; fun obtain_instance thy arities (inst as (class, tyco)) = case AList.lookup (op =) arities inst @@ -396,32 +384,27 @@ handle Sorts.CLASS_ERROR _ => [] (*permissive!*)) end; -fun inst_thm thy tvars' thm = +fun inst_cert thy lhs cert = let - val tvars = (Term.add_tvars o Thm.prop_of) thm []; - val inter_sort = Sorts.inter_sort (Sign.classes_of thy); - fun mk_inst (tvar as (v, sort)) = case Vartab.lookup tvars' v - of SOME sort' => SOME (pairself (Thm.ctyp_of thy o TVar) - (tvar, (v, inter_sort (sort, sort')))) - | NONE => NONE; - val insts = map_filter mk_inst tvars; - in Thm.instantiate (insts, []) thm end; + val ((vs, _), _) = Code.dest_cert thy cert; + val sorts = map (fn (v, sort) => case AList.lookup (op =) lhs v + of SOME sort' => Sorts.inter_sort (Sign.classes_of thy) (sort, sort') + | NONE => sort) vs; + in Code.constrain_cert thy sorts cert end; fun add_arity thy vardeps (class, tyco) = AList.default (op =) ((class, tyco), - map_range (fn k => (snd o Vargraph.get_node vardeps) (Inst (class, tyco), k)) (Sign.arity_number thy tyco)); + map_range (fn k => (snd o Vargraph.get_node vardeps) (Inst (class, tyco), k)) + (Sign.arity_number thy tyco)); -fun add_eqs thy vardeps (c, (proto_lhs, proto_eqns)) (rhss, eqngr) = +fun add_cert thy vardeps (c, (proto_lhs, proto_cert)) (rhss, eqngr) = if can (Graph.get_node eqngr) c then (rhss, eqngr) else let val lhs = map_index (fn (k, (v, _)) => (v, snd (Vargraph.get_node vardeps (Fun c, k)))) proto_lhs; - val inst_tab = Vartab.empty |> fold (fn (v, sort) => - Vartab.update ((v, 0), sort)) lhs; - val eqns = proto_eqns - |> (map o apfst) (inst_thm thy inst_tab); - val ((vs, _), rhss') = typscheme_rhss thy c eqns; - val eqngr' = Graph.new_node (c, (vs, eqns)) eqngr; + val cert = inst_cert thy lhs proto_cert; + val ((vs, _), rhss') = typscheme_rhss thy c cert; + val eqngr' = Graph.new_node (c, (vs, cert)) eqngr; in (map (pair c) rhss' @ rhss, eqngr') end; fun extend_arities_eqngr thy cs ts (arities, (eqngr : code_graph)) = @@ -435,7 +418,7 @@ val pp = Syntax.pp_global thy; val algebra = Sorts.subalgebra pp (is_proper_class thy) (AList.lookup (op =) arities') (Sign.classes_of thy); - val (rhss, eqngr') = Symtab.fold (add_eqs thy vardeps) eqntab ([], eqngr); + val (rhss, eqngr') = Symtab.fold (add_cert thy vardeps) eqntab ([], eqngr); fun deps_of (c, rhs) = c :: maps (dicts_of thy algebra) (rhs ~~ sortargs eqngr' c); val eqngr'' = fold (fn (c, rhs) => fold