--- a/src/HOL/Tools/recfun_codegen.ML Tue Jan 12 09:59:45 2010 +0100
+++ b/src/HOL/Tools/recfun_codegen.ML Tue Jan 12 16:27:42 2010 +0100
@@ -43,8 +43,10 @@
fun get_equations thy defs (raw_c, T) = if raw_c = @{const_name "op ="} then ([], "") else
let
val c = AxClass.unoverload_const thy (raw_c, T);
- val raw_thms = Code.these_eqns thy c
+ val raw_thms = Code.get_cert thy I c
+ |> Code.eqns_of_cert thy
|> map_filter (fn (thm, linear) => if linear then SOME thm else NONE)
+ |> map (AxClass.overload thy)
|> filter (is_instance T o snd o const_of o prop_of);
val module_name = case Symtab.lookup (ModuleData.get thy) c
of SOME module_name => module_name
--- a/src/Pure/Isar/code.ML Tue Jan 12 09:59:45 2010 +0100
+++ b/src/Pure/Isar/code.ML Tue Jan 12 16:27:42 2010 +0100
@@ -28,17 +28,13 @@
val mk_eqn_warning: theory -> thm -> (thm * bool) option
val mk_eqn_liberal: theory -> thm -> (thm * bool) option
val assert_eqn: theory -> thm * bool -> thm * bool
- val assert_eqns_const: theory -> string
- -> (thm * bool) list -> (thm * bool) list
val const_typ_eqn: theory -> thm -> string * typ
- val typscheme_eqn: theory -> thm -> (string * sort) list * typ
- val typscheme_eqns: theory -> string -> thm list -> (string * sort) list * typ
- val standard_typscheme: theory -> thm list -> thm list
type cert = thm * bool list
- val cert_of_eqns: theory -> (thm * bool) list -> cert
+ val empty_cert: theory -> string -> cert
+ val cert_of_eqns: theory -> string -> (thm * bool) list -> cert
val constrain_cert: theory -> sort list -> cert -> cert
- val dest_cert: theory -> cert -> (string * ((string * sort) list * typ)) * ((term list * term) * bool) list
val eqns_of_cert: theory -> cert -> (thm * bool) list
+ val dest_cert: theory -> cert -> ((string * sort) list * typ) * ((term list * term) * (thm * bool)) list
(*executable code*)
val add_type: string -> theory -> theory
@@ -61,8 +57,7 @@
val add_undefined: string -> theory -> theory
val get_datatype: theory -> string -> ((string * sort) list * (string * typ list) list)
val get_datatype_of_constr: theory -> string -> string option
- val these_eqns: theory -> string -> (thm * bool) list
- val eqn_cert: theory -> string -> cert
+ val get_cert: theory -> ((thm * bool) list -> (thm * bool) list) -> string -> cert
val get_case_scheme: theory -> string -> (int * (int * string list)) option
val undefineds: theory -> string list
val print_codesetup: theory -> unit
@@ -531,20 +526,6 @@
fun typscheme thy (c, ty) = logical_typscheme thy (c, subst_signature thy c ty);
-fun typscheme_eqn thy = typscheme thy o apsnd Logic.unvarifyT o const_typ_eqn thy;
-
-fun typscheme_eqns thy c [] =
- let
- val raw_ty = const_typ thy c;
- val tvars = Term.add_tvar_namesT raw_ty [];
- val tvars' = case AxClass.class_of_param thy c
- of SOME class => [TFree (Name.aT, [class])]
- | NONE => Name.invent_list [] Name.aT (length tvars)
- |> map (fn v => TFree (v, []));
- val ty = typ_subst_TVars (tvars ~~ tvars') raw_ty;
- in logical_typscheme thy (c, ty) end
- | typscheme_eqns thy c (thm :: _) = typscheme_eqn thy thm;
-
fun assert_eqns_const thy c eqns =
let
fun cert (eqn as (thm, _)) = if c = const_eqn thy thm
@@ -555,93 +536,97 @@
(* code equation certificates *)
-fun standard_typscheme thy thms =
- let
- fun tvars_of T = rev (Term.add_tvarsT T []);
- val vss = map (tvars_of o snd o head_eqn) thms;
- fun inter_sorts vs =
- fold (curry (Sorts.inter_sort (Sign.classes_of thy)) o snd) vs [];
- val sorts = map_transpose inter_sorts vss;
- val vts = Name.names Name.context Name.aT sorts
- |> map (fn (v, sort) => TVar ((v, 0), sort));
- in map2 (fn vs => Thm.certify_instantiate (vs ~~ vts, [])) vss thms end;
-
type cert = thm * bool list;
-fun cert_of_eqns thy [] = (Drule.dummy_thm, [])
- | cert_of_eqns thy eqns =
+fun mk_head_cterm thy (c, ty) =
+ Thm.cterm_of thy (Logic.mk_equals (Free ("HEAD", ty), Const (c, ty)));
+
+fun empty_cert thy c =
+ let
+ val raw_ty = const_typ thy c;
+ val tvars = Term.add_tvar_namesT raw_ty [];
+ val tvars' = case AxClass.class_of_param thy c
+ of SOME class => [TFree (Name.aT, [class])]
+ | NONE => Name.invent_list [] Name.aT (length tvars)
+ |> map (fn v => TFree (v, []));
+ val ty = typ_subst_TVars (tvars ~~ tvars') raw_ty;
+ val chead = mk_head_cterm thy (c, ty);
+ in (Thm.weaken chead Drule.dummy_thm, []) end;
+
+fun cert_of_eqns thy c [] = empty_cert thy c
+ | cert_of_eqns thy c eqns =
let
- val propers = map snd eqns;
- val thms as thm :: _ = (map Thm.freezeT o standard_typscheme thy o map fst) eqns; (*FIXME*)
- val (c, ty) = head_eqn thm;
- val head_thm = Thm.assume (Thm.cterm_of thy (Logic.mk_equals
- (Free ("HEAD", ty), Const (c, ty)))) |> Thm.symmetric;
+ val _ = assert_eqns_const thy c eqns;
+ val (thms, propers) = split_list eqns;
+ fun tvars_of T = rev (Term.add_tvarsT T []);
+ val vss = map (tvars_of o snd o head_eqn) thms;
+ fun inter_sorts vs =
+ fold (curry (Sorts.inter_sort (Sign.classes_of thy)) o snd) vs [];
+ val sorts = map_transpose inter_sorts vss;
+ val vts = Name.names Name.context Name.aT sorts;
+ val thms as thm :: _ =
+ map2 (fn vs => Thm.certify_instantiate (vs ~~ map TFree vts, [])) vss thms;
+ val head_thm = Thm.symmetric (Thm.assume (mk_head_cterm thy (head_eqn (hd thms))));
fun head_conv ct = if can Thm.dest_comb ct
then Conv.fun_conv head_conv ct
else Conv.rewr_conv head_thm ct;
val rewrite_head = Conv.fconv_rule (Conv.arg1_conv head_conv);
- val cert = Conjunction.intr_balanced (map rewrite_head thms);
- in (cert, propers) end;
+ val cert_thm = Conjunction.intr_balanced (map rewrite_head thms);
+ in (cert_thm, propers) end;
-fun head_cert thy cert =
+fun head_cert thy cert_thm =
+ let
+ val [head] = Thm.hyps_of cert_thm;
+ val (Free (h, _), Const (c, ty)) = Logic.dest_equals head;
+ in ((c, typscheme thy (c, ty)), (head, h)) end;
+
+fun constrain_cert thy sorts (cert_thm, propers) =
let
- val [head] = Thm.hyps_of cert;
- val (Free (h, _), Const (c, ty)) = (Logic.dest_equals o the_single o Thm.hyps_of) cert;
- in ((c, typscheme thy (AxClass.unoverload_const thy (c, ty), ty)), (head, h)) end;
+ val ((c, (vs, _)), (head, _)) = head_cert thy cert_thm;
+ val subst = map2 (fn (v, _) => fn sort => (v, sort)) vs sorts;
+ val head' = (map_types o map_atyps)
+ (fn TFree (v, _) => TFree (v, the (AList.lookup (op =) subst v))) head;
+ val inst = (map2 (fn (v, sort) => fn sort' =>
+ pairself (Thm.ctyp_of thy) (TVar ((v, 0), sort), TFree (v, sort'))) vs sorts, []);
+ val cert_thm' = cert_thm
+ |> Thm.implies_intr (Thm.cterm_of thy head)
+ |> Thm.varifyT
+ |> Thm.instantiate inst
+ |> Thm.elim_implies (Thm.assume (Thm.cterm_of thy head'));
+ in (cert_thm', propers) end;
-fun constrain_cert thy sorts (cert, []) = (cert, [])
- | constrain_cert thy sorts (cert, propers) =
+fun eqns_of_cert thy (cert_thm, []) = []
+ | eqns_of_cert thy (cert_thm, propers) =
let
- val ((c, (vs, _)), (head, _)) = head_cert thy cert;
- val subst = map2 (fn (v, _) => fn sort => (v, sort)) vs sorts;
- val head' = (map_types o map_atyps)
- (fn TFree (v, _) => TFree (v, the (AList.lookup (op =) subst v))) head;
- val inst = (map2 (fn (v, sort) => fn sort' =>
- pairself (Thm.ctyp_of thy) (TVar ((v, 0), sort), TFree (v, sort'))) vs sorts ,[]);
- val cert' = cert
- |> Thm.implies_intr (Thm.cterm_of thy head)
- |> Thm.varifyT
- |> Thm.instantiate inst
- |> Thm.elim_implies (Thm.assume (Thm.cterm_of thy head'))
- in (cert', propers) end;
-
-fun dest_cert thy (cert, propers) =
- let
- val (c_vs_ty, (head, h)) = head_cert thy cert;
- val equations = cert
- |> Thm.prop_of
- |> Logic.dest_conjunction_balanced (length propers)
- |> map Logic.dest_equals
- |> (map o apfst) strip_comb
- |> (map o apfst) (fn (Free (h', _), ts) => case h = h' of True => ts)
- in (c_vs_ty, equations ~~ propers) end;
-
-fun eqns_of_cert thy (cert, []) = []
- | eqns_of_cert thy (cert, propers) =
- let
- val (_, (head, _)) = head_cert thy cert;
- val thms = cert
+ val (_, (head, _)) = head_cert thy cert_thm;
+ val thms = cert_thm
|> LocalDefs.expand [Thm.cterm_of thy head]
|> Thm.varifyT
|> Conjunction.elim_balanced (length propers)
in thms ~~ propers end;
+fun dest_cert thy (cert as (cert_thm, propers)) =
+ let
+ val eqns = eqns_of_cert thy cert;
+ val ((_, vs_ty), _) = head_cert thy cert_thm;
+ val equations = if null propers then [] else cert_thm
+ |> Thm.prop_of
+ |> Logic.dest_conjunction_balanced (length propers)
+ |> map Logic.dest_equals
+ |> (map o apfst) (snd o strip_comb)
+ in (vs_ty, equations ~~ eqns) end;
+
(* code equation access *)
-fun these_eqns thy c =
+fun get_cert thy f c =
Symtab.lookup ((the_eqns o the_exec) thy) c
|> Option.map (snd o snd o fst)
|> these
|> (map o apfst) (Thm.transfer thy)
- |> burrow_fst (standard_typscheme thy);
-
-fun eqn_cert thy c =
- Symtab.lookup ((the_eqns o the_exec) thy) c
- |> Option.map (snd o snd o fst)
- |> these
- |> (map o apfst) (Thm.transfer thy)
- |> cert_of_eqns thy;
+ |> f
+ |> (map o apfst) (AxClass.unoverload thy)
+ |> cert_of_eqns thy c;
(* cases *)
--- a/src/Tools/Code/code_preproc.ML Tue Jan 12 09:59:45 2010 +0100
+++ b/src/Tools/Code/code_preproc.ML Tue Jan 12 16:27:42 2010 +0100
@@ -18,7 +18,7 @@
type code_algebra
type code_graph
- val eqns: code_graph -> string -> (thm * bool) list
+ val cert: code_graph -> string -> Code.cert
val sortargs: code_graph -> string -> sort list
val all: code_graph -> string list
val pretty: theory -> code_graph -> Pretty.T
@@ -53,8 +53,8 @@
let
val pre = Simplifier.merge_ss (pre1, pre2);
val post = Simplifier.merge_ss (post1, post2);
- val functrans = AList.merge (op =) (eq_fst (op =)) (functrans1, functrans2);
- (* FIXME handle AList.DUP (!?) *)
+ val functrans = AList.merge (op =) (eq_fst (op =)) (functrans1, functrans2)
+ handle AList.DUP => error ("Duplicate function transformer");
in make_thmproc ((pre, post), functrans) end;
structure Code_Preproc_Data = Theory_Data
@@ -102,23 +102,14 @@
(* post- and preprocessing *)
-fun apply_functrans thy c _ [] = []
- | apply_functrans thy c [] eqns = eqns
- | apply_functrans thy c functrans eqns = eqns
- |> perhaps (perhaps_loop (perhaps_apply functrans))
- |> Code.assert_eqns_const thy c
- (*FIXME in future, the check here should be more accurate wrt. type schemes
- -- perhaps by means of upcoming code certificates with a corresponding
- preprocessor protocol*);
-
fun trans_conv_rule conv thm = Thm.transitive thm ((conv o Thm.rhs_of) thm);
-fun eqn_conv conv =
+fun eqn_conv conv ct =
let
fun lhs_conv ct = if can Thm.dest_comb ct
then Conv.combination_conv lhs_conv conv ct
else Conv.all_conv ct;
- in Conv.combination_conv (Conv.arg_conv lhs_conv) conv end;
+ in Conv.combination_conv (Conv.arg_conv lhs_conv) conv ct end;
val rewrite_eqn = Conv.fconv_rule o eqn_conv o Simplifier.rewrite;
@@ -129,17 +120,15 @@
#> Logic.dest_equals
#> snd;
-fun preprocess thy c eqns =
+fun preprocess thy eqns =
let
val pre = (Simplifier.theory_context thy o #pre o the_thmproc) thy;
val functrans = (map (fn (_, (_, f)) => f thy) o #functrans
o the_thmproc) thy;
in
eqns
- |> apply_functrans thy c functrans
+ |> perhaps (perhaps_loop (perhaps_apply functrans))
|> (map o apfst) (rewrite_eqn pre)
- |> (map o apfst) (AxClass.unoverload thy)
- |> map (Code.assert_eqn thy)
end;
fun preprocess_conv thy ct =
@@ -196,20 +185,20 @@
(** sort algebra and code equation graph types **)
type code_algebra = (sort -> sort) * Sorts.algebra;
-type code_graph = ((string * sort) list * (thm * bool) list) Graph.T;
+type code_graph = ((string * sort) list * Code.cert) Graph.T;
-fun eqns eqngr = these o Option.map snd o try (Graph.get_node eqngr);
-fun sortargs eqngr = map snd o fst o Graph.get_node eqngr
+fun cert eqngr = snd o Graph.get_node eqngr;
+fun sortargs eqngr = map snd o fst o Graph.get_node eqngr;
fun all eqngr = Graph.keys eqngr;
fun pretty thy eqngr =
AList.make (snd o Graph.get_node eqngr) (Graph.keys eqngr)
|> (map o apfst) (Code.string_of_const thy)
|> sort (string_ord o pairself fst)
- |> map (fn (s, thms) =>
+ |> map (fn (s, cert) =>
(Pretty.block o Pretty.fbreaks) (
Pretty.str s
- :: map (Display.pretty_thm_global thy o fst) thms
+ :: map (Display.pretty_thm_global thy o AxClass.overload thy o fst) (Code.eqns_of_cert thy cert)
))
|> Pretty.chunks;
@@ -227,12 +216,12 @@
map (fn (c, _) => AxClass.param_of_inst thy (c, tyco))
o maps (#params o AxClass.get_info thy);
-fun typscheme_rhss thy c eqns =
+fun typscheme_rhss thy c cert =
let
- val tyscm = Code.typscheme_eqns thy c (map fst eqns);
+ val (tyscm, equations) = Code.dest_cert thy cert;
val rhss = [] |> (fold o fold o fold_aterms)
- (fn Const (c, ty) => insert (op =) (c, Sign.const_typargs thy (c, Logic.unvarifyT ty)) | _ => I)
- (map (op :: o swap o apfst (snd o strip_comb) o Logic.dest_equals o Thm.plain_prop_of o fst) eqns);
+ (fn Const (c, ty) => insert (op =) (c, Sign.const_typargs thy (c, ty)) | _ => I)
+ (map (op :: o swap o fst) equations);
in (tyscm, rhss) end;
@@ -259,7 +248,7 @@
| NONE => Free;
type vardeps_data = ((string * styp list) list * class list) Vargraph.T
- * (((string * sort) list * (thm * bool) list) Symtab.table
+ * (((string * sort) list * Code.cert) Symtab.table
* (class * string) list);
val empty_vardeps_data : vardeps_data =
@@ -270,12 +259,11 @@
fun obtain_eqns thy eqngr c =
case try (Graph.get_node eqngr) c
- of SOME (lhs, eqns) => ((lhs, []), [])
+ of SOME (lhs, cert) => ((lhs, []), cert)
| NONE => let
- val eqns = Code.these_eqns thy c
- |> preprocess thy c;
- val ((lhs, _), rhss) = typscheme_rhss thy c eqns;
- in ((lhs, rhss), eqns) end;
+ val cert = Code.get_cert thy (preprocess thy) c;
+ val ((lhs, _), rhss) = typscheme_rhss thy c cert;
+ in ((lhs, rhss), cert) end;
fun obtain_instance thy arities (inst as (class, tyco)) =
case AList.lookup (op =) arities inst
@@ -396,32 +384,27 @@
handle Sorts.CLASS_ERROR _ => [] (*permissive!*))
end;
-fun inst_thm thy tvars' thm =
+fun inst_cert thy lhs cert =
let
- val tvars = (Term.add_tvars o Thm.prop_of) thm [];
- val inter_sort = Sorts.inter_sort (Sign.classes_of thy);
- fun mk_inst (tvar as (v, sort)) = case Vartab.lookup tvars' v
- of SOME sort' => SOME (pairself (Thm.ctyp_of thy o TVar)
- (tvar, (v, inter_sort (sort, sort'))))
- | NONE => NONE;
- val insts = map_filter mk_inst tvars;
- in Thm.instantiate (insts, []) thm end;
+ val ((vs, _), _) = Code.dest_cert thy cert;
+ val sorts = map (fn (v, sort) => case AList.lookup (op =) lhs v
+ of SOME sort' => Sorts.inter_sort (Sign.classes_of thy) (sort, sort')
+ | NONE => sort) vs;
+ in Code.constrain_cert thy sorts cert end;
fun add_arity thy vardeps (class, tyco) =
AList.default (op =) ((class, tyco),
- map_range (fn k => (snd o Vargraph.get_node vardeps) (Inst (class, tyco), k)) (Sign.arity_number thy tyco));
+ map_range (fn k => (snd o Vargraph.get_node vardeps) (Inst (class, tyco), k))
+ (Sign.arity_number thy tyco));
-fun add_eqs thy vardeps (c, (proto_lhs, proto_eqns)) (rhss, eqngr) =
+fun add_cert thy vardeps (c, (proto_lhs, proto_cert)) (rhss, eqngr) =
if can (Graph.get_node eqngr) c then (rhss, eqngr)
else let
val lhs = map_index (fn (k, (v, _)) =>
(v, snd (Vargraph.get_node vardeps (Fun c, k)))) proto_lhs;
- val inst_tab = Vartab.empty |> fold (fn (v, sort) =>
- Vartab.update ((v, 0), sort)) lhs;
- val eqns = proto_eqns
- |> (map o apfst) (inst_thm thy inst_tab);
- val ((vs, _), rhss') = typscheme_rhss thy c eqns;
- val eqngr' = Graph.new_node (c, (vs, eqns)) eqngr;
+ val cert = inst_cert thy lhs proto_cert;
+ val ((vs, _), rhss') = typscheme_rhss thy c cert;
+ val eqngr' = Graph.new_node (c, (vs, cert)) eqngr;
in (map (pair c) rhss' @ rhss, eqngr') end;
fun extend_arities_eqngr thy cs ts (arities, (eqngr : code_graph)) =
@@ -435,7 +418,7 @@
val pp = Syntax.pp_global thy;
val algebra = Sorts.subalgebra pp (is_proper_class thy)
(AList.lookup (op =) arities') (Sign.classes_of thy);
- val (rhss, eqngr') = Symtab.fold (add_eqs thy vardeps) eqntab ([], eqngr);
+ val (rhss, eqngr') = Symtab.fold (add_cert thy vardeps) eqntab ([], eqngr);
fun deps_of (c, rhs) = c :: maps (dicts_of thy algebra)
(rhs ~~ sortargs eqngr' c);
val eqngr'' = fold (fn (c, rhs) => fold
--- a/src/Tools/Code/code_thingol.ML Tue Jan 12 09:59:45 2010 +0100
+++ b/src/Tools/Code/code_thingol.ML Tue Jan 12 16:27:42 2010 +0100
@@ -447,7 +447,7 @@
in Thm.certify_instantiate ([], var_subst) thm end;
fun canonize_thms thy = map (Thm.transfer thy)
- #> Code.standard_typscheme thy #> desymbolize_tvars thy
+ #> desymbolize_tvars thy
#> same_arity thy #> map (desymbolize_vars thy);
@@ -612,10 +612,10 @@
fun stmt_classparam class =
ensure_class thy algbr eqngr class
#>> (fn class => Classparam (c, class));
- fun stmt_fun raw_eqns =
+ fun stmt_fun cert =
let
- val eqns = burrow_fst (canonize_thms thy) raw_eqns;
- val (vs, ty) = Code.typscheme_eqns thy c (map fst eqns);
+ val ((vs, ty), raw_eqns) = Code.dest_cert thy cert;
+ val eqns = burrow_fst (canonize_thms thy) (map snd raw_eqns);
in
fold_map (translate_tyvar_sort thy algbr eqngr) vs
##>> translate_typ thy algbr eqngr ty
@@ -626,7 +626,7 @@
of SOME tyco => stmt_datatypecons tyco
| NONE => (case AxClass.class_of_param thy c
of SOME class => stmt_classparam class
- | NONE => stmt_fun (Code_Preproc.eqns eqngr c))
+ | NONE => stmt_fun (Code_Preproc.cert eqngr c))
in ensure_stmt lookup_const (declare_const thy) stmt_const c end
and ensure_class thy (algbr as (_, algebra)) eqngr class =
let
@@ -933,11 +933,7 @@
let
val (_, eqngr) = Code_Preproc.obtain thy consts [];
val all_consts = Graph.all_succs eqngr consts;
- in
- eqngr
- |> Graph.subgraph (member (op =) all_consts)
- |> Graph.map_nodes ((apsnd o map o apfst) (AxClass.overload thy))
- end;
+ in Graph.subgraph (member (op =) all_consts) eqngr end;
fun code_thms thy = Pretty.writeln o Code_Preproc.pretty thy o code_depgr thy;