# HG changeset patch # User haftmann # Date 1263310062 -3600 # Node ID 99b9a6290446b41d7c880ee8f4a61ce30ea59961 # Parent ded5b770ec1cfb6a527572b8a28185dff7da78ab code certificates as integral part of code generation diff -r ded5b770ec1c -r 99b9a6290446 src/HOL/Tools/recfun_codegen.ML --- a/src/HOL/Tools/recfun_codegen.ML Tue Jan 12 09:59:45 2010 +0100 +++ b/src/HOL/Tools/recfun_codegen.ML Tue Jan 12 16:27:42 2010 +0100 @@ -43,8 +43,10 @@ fun get_equations thy defs (raw_c, T) = if raw_c = @{const_name "op ="} then ([], "") else let val c = AxClass.unoverload_const thy (raw_c, T); - val raw_thms = Code.these_eqns thy c + val raw_thms = Code.get_cert thy I c + |> Code.eqns_of_cert thy |> map_filter (fn (thm, linear) => if linear then SOME thm else NONE) + |> map (AxClass.overload thy) |> filter (is_instance T o snd o const_of o prop_of); val module_name = case Symtab.lookup (ModuleData.get thy) c of SOME module_name => module_name diff -r ded5b770ec1c -r 99b9a6290446 src/Pure/Isar/code.ML --- a/src/Pure/Isar/code.ML Tue Jan 12 09:59:45 2010 +0100 +++ b/src/Pure/Isar/code.ML Tue Jan 12 16:27:42 2010 +0100 @@ -28,17 +28,13 @@ val mk_eqn_warning: theory -> thm -> (thm * bool) option val mk_eqn_liberal: theory -> thm -> (thm * bool) option val assert_eqn: theory -> thm * bool -> thm * bool - val assert_eqns_const: theory -> string - -> (thm * bool) list -> (thm * bool) list val const_typ_eqn: theory -> thm -> string * typ - val typscheme_eqn: theory -> thm -> (string * sort) list * typ - val typscheme_eqns: theory -> string -> thm list -> (string * sort) list * typ - val standard_typscheme: theory -> thm list -> thm list type cert = thm * bool list - val cert_of_eqns: theory -> (thm * bool) list -> cert + val empty_cert: theory -> string -> cert + val cert_of_eqns: theory -> string -> (thm * bool) list -> cert val constrain_cert: theory -> sort list -> cert -> cert - val dest_cert: theory -> cert -> (string * ((string * sort) list * typ)) * ((term list * term) * bool) list val eqns_of_cert: theory -> cert -> (thm * bool) list + val dest_cert: theory -> cert -> ((string * sort) list * typ) * ((term list * term) * (thm * bool)) list (*executable code*) val add_type: string -> theory -> theory @@ -61,8 +57,7 @@ val add_undefined: string -> theory -> theory val get_datatype: theory -> string -> ((string * sort) list * (string * typ list) list) val get_datatype_of_constr: theory -> string -> string option - val these_eqns: theory -> string -> (thm * bool) list - val eqn_cert: theory -> string -> cert + val get_cert: theory -> ((thm * bool) list -> (thm * bool) list) -> string -> cert val get_case_scheme: theory -> string -> (int * (int * string list)) option val undefineds: theory -> string list val print_codesetup: theory -> unit @@ -531,20 +526,6 @@ fun typscheme thy (c, ty) = logical_typscheme thy (c, subst_signature thy c ty); -fun typscheme_eqn thy = typscheme thy o apsnd Logic.unvarifyT o const_typ_eqn thy; - -fun typscheme_eqns thy c [] = - let - val raw_ty = const_typ thy c; - val tvars = Term.add_tvar_namesT raw_ty []; - val tvars' = case AxClass.class_of_param thy c - of SOME class => [TFree (Name.aT, [class])] - | NONE => Name.invent_list [] Name.aT (length tvars) - |> map (fn v => TFree (v, [])); - val ty = typ_subst_TVars (tvars ~~ tvars') raw_ty; - in logical_typscheme thy (c, ty) end - | typscheme_eqns thy c (thm :: _) = typscheme_eqn thy thm; - fun assert_eqns_const thy c eqns = let fun cert (eqn as (thm, _)) = if c = const_eqn thy thm @@ -555,93 +536,97 @@ (* code equation certificates *) -fun standard_typscheme thy thms = - let - fun tvars_of T = rev (Term.add_tvarsT T []); - val vss = map (tvars_of o snd o head_eqn) thms; - fun inter_sorts vs = - fold (curry (Sorts.inter_sort (Sign.classes_of thy)) o snd) vs []; - val sorts = map_transpose inter_sorts vss; - val vts = Name.names Name.context Name.aT sorts - |> map (fn (v, sort) => TVar ((v, 0), sort)); - in map2 (fn vs => Thm.certify_instantiate (vs ~~ vts, [])) vss thms end; - type cert = thm * bool list; -fun cert_of_eqns thy [] = (Drule.dummy_thm, []) - | cert_of_eqns thy eqns = +fun mk_head_cterm thy (c, ty) = + Thm.cterm_of thy (Logic.mk_equals (Free ("HEAD", ty), Const (c, ty))); + +fun empty_cert thy c = + let + val raw_ty = const_typ thy c; + val tvars = Term.add_tvar_namesT raw_ty []; + val tvars' = case AxClass.class_of_param thy c + of SOME class => [TFree (Name.aT, [class])] + | NONE => Name.invent_list [] Name.aT (length tvars) + |> map (fn v => TFree (v, [])); + val ty = typ_subst_TVars (tvars ~~ tvars') raw_ty; + val chead = mk_head_cterm thy (c, ty); + in (Thm.weaken chead Drule.dummy_thm, []) end; + +fun cert_of_eqns thy c [] = empty_cert thy c + | cert_of_eqns thy c eqns = let - val propers = map snd eqns; - val thms as thm :: _ = (map Thm.freezeT o standard_typscheme thy o map fst) eqns; (*FIXME*) - val (c, ty) = head_eqn thm; - val head_thm = Thm.assume (Thm.cterm_of thy (Logic.mk_equals - (Free ("HEAD", ty), Const (c, ty)))) |> Thm.symmetric; + val _ = assert_eqns_const thy c eqns; + val (thms, propers) = split_list eqns; + fun tvars_of T = rev (Term.add_tvarsT T []); + val vss = map (tvars_of o snd o head_eqn) thms; + fun inter_sorts vs = + fold (curry (Sorts.inter_sort (Sign.classes_of thy)) o snd) vs []; + val sorts = map_transpose inter_sorts vss; + val vts = Name.names Name.context Name.aT sorts; + val thms as thm :: _ = + map2 (fn vs => Thm.certify_instantiate (vs ~~ map TFree vts, [])) vss thms; + val head_thm = Thm.symmetric (Thm.assume (mk_head_cterm thy (head_eqn (hd thms)))); fun head_conv ct = if can Thm.dest_comb ct then Conv.fun_conv head_conv ct else Conv.rewr_conv head_thm ct; val rewrite_head = Conv.fconv_rule (Conv.arg1_conv head_conv); - val cert = Conjunction.intr_balanced (map rewrite_head thms); - in (cert, propers) end; + val cert_thm = Conjunction.intr_balanced (map rewrite_head thms); + in (cert_thm, propers) end; -fun head_cert thy cert = +fun head_cert thy cert_thm = + let + val [head] = Thm.hyps_of cert_thm; + val (Free (h, _), Const (c, ty)) = Logic.dest_equals head; + in ((c, typscheme thy (c, ty)), (head, h)) end; + +fun constrain_cert thy sorts (cert_thm, propers) = let - val [head] = Thm.hyps_of cert; - val (Free (h, _), Const (c, ty)) = (Logic.dest_equals o the_single o Thm.hyps_of) cert; - in ((c, typscheme thy (AxClass.unoverload_const thy (c, ty), ty)), (head, h)) end; + val ((c, (vs, _)), (head, _)) = head_cert thy cert_thm; + val subst = map2 (fn (v, _) => fn sort => (v, sort)) vs sorts; + val head' = (map_types o map_atyps) + (fn TFree (v, _) => TFree (v, the (AList.lookup (op =) subst v))) head; + val inst = (map2 (fn (v, sort) => fn sort' => + pairself (Thm.ctyp_of thy) (TVar ((v, 0), sort), TFree (v, sort'))) vs sorts, []); + val cert_thm' = cert_thm + |> Thm.implies_intr (Thm.cterm_of thy head) + |> Thm.varifyT + |> Thm.instantiate inst + |> Thm.elim_implies (Thm.assume (Thm.cterm_of thy head')); + in (cert_thm', propers) end; -fun constrain_cert thy sorts (cert, []) = (cert, []) - | constrain_cert thy sorts (cert, propers) = +fun eqns_of_cert thy (cert_thm, []) = [] + | eqns_of_cert thy (cert_thm, propers) = let - val ((c, (vs, _)), (head, _)) = head_cert thy cert; - val subst = map2 (fn (v, _) => fn sort => (v, sort)) vs sorts; - val head' = (map_types o map_atyps) - (fn TFree (v, _) => TFree (v, the (AList.lookup (op =) subst v))) head; - val inst = (map2 (fn (v, sort) => fn sort' => - pairself (Thm.ctyp_of thy) (TVar ((v, 0), sort), TFree (v, sort'))) vs sorts ,[]); - val cert' = cert - |> Thm.implies_intr (Thm.cterm_of thy head) - |> Thm.varifyT - |> Thm.instantiate inst - |> Thm.elim_implies (Thm.assume (Thm.cterm_of thy head')) - in (cert', propers) end; - -fun dest_cert thy (cert, propers) = - let - val (c_vs_ty, (head, h)) = head_cert thy cert; - val equations = cert - |> Thm.prop_of - |> Logic.dest_conjunction_balanced (length propers) - |> map Logic.dest_equals - |> (map o apfst) strip_comb - |> (map o apfst) (fn (Free (h', _), ts) => case h = h' of True => ts) - in (c_vs_ty, equations ~~ propers) end; - -fun eqns_of_cert thy (cert, []) = [] - | eqns_of_cert thy (cert, propers) = - let - val (_, (head, _)) = head_cert thy cert; - val thms = cert + val (_, (head, _)) = head_cert thy cert_thm; + val thms = cert_thm |> LocalDefs.expand [Thm.cterm_of thy head] |> Thm.varifyT |> Conjunction.elim_balanced (length propers) in thms ~~ propers end; +fun dest_cert thy (cert as (cert_thm, propers)) = + let + val eqns = eqns_of_cert thy cert; + val ((_, vs_ty), _) = head_cert thy cert_thm; + val equations = if null propers then [] else cert_thm + |> Thm.prop_of + |> Logic.dest_conjunction_balanced (length propers) + |> map Logic.dest_equals + |> (map o apfst) (snd o strip_comb) + in (vs_ty, equations ~~ eqns) end; + (* code equation access *) -fun these_eqns thy c = +fun get_cert thy f c = Symtab.lookup ((the_eqns o the_exec) thy) c |> Option.map (snd o snd o fst) |> these |> (map o apfst) (Thm.transfer thy) - |> burrow_fst (standard_typscheme thy); - -fun eqn_cert thy c = - Symtab.lookup ((the_eqns o the_exec) thy) c - |> Option.map (snd o snd o fst) - |> these - |> (map o apfst) (Thm.transfer thy) - |> cert_of_eqns thy; + |> f + |> (map o apfst) (AxClass.unoverload thy) + |> cert_of_eqns thy c; (* cases *) diff -r ded5b770ec1c -r 99b9a6290446 src/Tools/Code/code_preproc.ML --- a/src/Tools/Code/code_preproc.ML Tue Jan 12 09:59:45 2010 +0100 +++ b/src/Tools/Code/code_preproc.ML Tue Jan 12 16:27:42 2010 +0100 @@ -18,7 +18,7 @@ type code_algebra type code_graph - val eqns: code_graph -> string -> (thm * bool) list + val cert: code_graph -> string -> Code.cert val sortargs: code_graph -> string -> sort list val all: code_graph -> string list val pretty: theory -> code_graph -> Pretty.T @@ -53,8 +53,8 @@ let val pre = Simplifier.merge_ss (pre1, pre2); val post = Simplifier.merge_ss (post1, post2); - val functrans = AList.merge (op =) (eq_fst (op =)) (functrans1, functrans2); - (* FIXME handle AList.DUP (!?) *) + val functrans = AList.merge (op =) (eq_fst (op =)) (functrans1, functrans2) + handle AList.DUP => error ("Duplicate function transformer"); in make_thmproc ((pre, post), functrans) end; structure Code_Preproc_Data = Theory_Data @@ -102,23 +102,14 @@ (* post- and preprocessing *) -fun apply_functrans thy c _ [] = [] - | apply_functrans thy c [] eqns = eqns - | apply_functrans thy c functrans eqns = eqns - |> perhaps (perhaps_loop (perhaps_apply functrans)) - |> Code.assert_eqns_const thy c - (*FIXME in future, the check here should be more accurate wrt. type schemes - -- perhaps by means of upcoming code certificates with a corresponding - preprocessor protocol*); - fun trans_conv_rule conv thm = Thm.transitive thm ((conv o Thm.rhs_of) thm); -fun eqn_conv conv = +fun eqn_conv conv ct = let fun lhs_conv ct = if can Thm.dest_comb ct then Conv.combination_conv lhs_conv conv ct else Conv.all_conv ct; - in Conv.combination_conv (Conv.arg_conv lhs_conv) conv end; + in Conv.combination_conv (Conv.arg_conv lhs_conv) conv ct end; val rewrite_eqn = Conv.fconv_rule o eqn_conv o Simplifier.rewrite; @@ -129,17 +120,15 @@ #> Logic.dest_equals #> snd; -fun preprocess thy c eqns = +fun preprocess thy eqns = let val pre = (Simplifier.theory_context thy o #pre o the_thmproc) thy; val functrans = (map (fn (_, (_, f)) => f thy) o #functrans o the_thmproc) thy; in eqns - |> apply_functrans thy c functrans + |> perhaps (perhaps_loop (perhaps_apply functrans)) |> (map o apfst) (rewrite_eqn pre) - |> (map o apfst) (AxClass.unoverload thy) - |> map (Code.assert_eqn thy) end; fun preprocess_conv thy ct = @@ -196,20 +185,20 @@ (** sort algebra and code equation graph types **) type code_algebra = (sort -> sort) * Sorts.algebra; -type code_graph = ((string * sort) list * (thm * bool) list) Graph.T; +type code_graph = ((string * sort) list * Code.cert) Graph.T; -fun eqns eqngr = these o Option.map snd o try (Graph.get_node eqngr); -fun sortargs eqngr = map snd o fst o Graph.get_node eqngr +fun cert eqngr = snd o Graph.get_node eqngr; +fun sortargs eqngr = map snd o fst o Graph.get_node eqngr; fun all eqngr = Graph.keys eqngr; fun pretty thy eqngr = AList.make (snd o Graph.get_node eqngr) (Graph.keys eqngr) |> (map o apfst) (Code.string_of_const thy) |> sort (string_ord o pairself fst) - |> map (fn (s, thms) => + |> map (fn (s, cert) => (Pretty.block o Pretty.fbreaks) ( Pretty.str s - :: map (Display.pretty_thm_global thy o fst) thms + :: map (Display.pretty_thm_global thy o AxClass.overload thy o fst) (Code.eqns_of_cert thy cert) )) |> Pretty.chunks; @@ -227,12 +216,12 @@ map (fn (c, _) => AxClass.param_of_inst thy (c, tyco)) o maps (#params o AxClass.get_info thy); -fun typscheme_rhss thy c eqns = +fun typscheme_rhss thy c cert = let - val tyscm = Code.typscheme_eqns thy c (map fst eqns); + val (tyscm, equations) = Code.dest_cert thy cert; val rhss = [] |> (fold o fold o fold_aterms) - (fn Const (c, ty) => insert (op =) (c, Sign.const_typargs thy (c, Logic.unvarifyT ty)) | _ => I) - (map (op :: o swap o apfst (snd o strip_comb) o Logic.dest_equals o Thm.plain_prop_of o fst) eqns); + (fn Const (c, ty) => insert (op =) (c, Sign.const_typargs thy (c, ty)) | _ => I) + (map (op :: o swap o fst) equations); in (tyscm, rhss) end; @@ -259,7 +248,7 @@ | NONE => Free; type vardeps_data = ((string * styp list) list * class list) Vargraph.T - * (((string * sort) list * (thm * bool) list) Symtab.table + * (((string * sort) list * Code.cert) Symtab.table * (class * string) list); val empty_vardeps_data : vardeps_data = @@ -270,12 +259,11 @@ fun obtain_eqns thy eqngr c = case try (Graph.get_node eqngr) c - of SOME (lhs, eqns) => ((lhs, []), []) + of SOME (lhs, cert) => ((lhs, []), cert) | NONE => let - val eqns = Code.these_eqns thy c - |> preprocess thy c; - val ((lhs, _), rhss) = typscheme_rhss thy c eqns; - in ((lhs, rhss), eqns) end; + val cert = Code.get_cert thy (preprocess thy) c; + val ((lhs, _), rhss) = typscheme_rhss thy c cert; + in ((lhs, rhss), cert) end; fun obtain_instance thy arities (inst as (class, tyco)) = case AList.lookup (op =) arities inst @@ -396,32 +384,27 @@ handle Sorts.CLASS_ERROR _ => [] (*permissive!*)) end; -fun inst_thm thy tvars' thm = +fun inst_cert thy lhs cert = let - val tvars = (Term.add_tvars o Thm.prop_of) thm []; - val inter_sort = Sorts.inter_sort (Sign.classes_of thy); - fun mk_inst (tvar as (v, sort)) = case Vartab.lookup tvars' v - of SOME sort' => SOME (pairself (Thm.ctyp_of thy o TVar) - (tvar, (v, inter_sort (sort, sort')))) - | NONE => NONE; - val insts = map_filter mk_inst tvars; - in Thm.instantiate (insts, []) thm end; + val ((vs, _), _) = Code.dest_cert thy cert; + val sorts = map (fn (v, sort) => case AList.lookup (op =) lhs v + of SOME sort' => Sorts.inter_sort (Sign.classes_of thy) (sort, sort') + | NONE => sort) vs; + in Code.constrain_cert thy sorts cert end; fun add_arity thy vardeps (class, tyco) = AList.default (op =) ((class, tyco), - map_range (fn k => (snd o Vargraph.get_node vardeps) (Inst (class, tyco), k)) (Sign.arity_number thy tyco)); + map_range (fn k => (snd o Vargraph.get_node vardeps) (Inst (class, tyco), k)) + (Sign.arity_number thy tyco)); -fun add_eqs thy vardeps (c, (proto_lhs, proto_eqns)) (rhss, eqngr) = +fun add_cert thy vardeps (c, (proto_lhs, proto_cert)) (rhss, eqngr) = if can (Graph.get_node eqngr) c then (rhss, eqngr) else let val lhs = map_index (fn (k, (v, _)) => (v, snd (Vargraph.get_node vardeps (Fun c, k)))) proto_lhs; - val inst_tab = Vartab.empty |> fold (fn (v, sort) => - Vartab.update ((v, 0), sort)) lhs; - val eqns = proto_eqns - |> (map o apfst) (inst_thm thy inst_tab); - val ((vs, _), rhss') = typscheme_rhss thy c eqns; - val eqngr' = Graph.new_node (c, (vs, eqns)) eqngr; + val cert = inst_cert thy lhs proto_cert; + val ((vs, _), rhss') = typscheme_rhss thy c cert; + val eqngr' = Graph.new_node (c, (vs, cert)) eqngr; in (map (pair c) rhss' @ rhss, eqngr') end; fun extend_arities_eqngr thy cs ts (arities, (eqngr : code_graph)) = @@ -435,7 +418,7 @@ val pp = Syntax.pp_global thy; val algebra = Sorts.subalgebra pp (is_proper_class thy) (AList.lookup (op =) arities') (Sign.classes_of thy); - val (rhss, eqngr') = Symtab.fold (add_eqs thy vardeps) eqntab ([], eqngr); + val (rhss, eqngr') = Symtab.fold (add_cert thy vardeps) eqntab ([], eqngr); fun deps_of (c, rhs) = c :: maps (dicts_of thy algebra) (rhs ~~ sortargs eqngr' c); val eqngr'' = fold (fn (c, rhs) => fold diff -r ded5b770ec1c -r 99b9a6290446 src/Tools/Code/code_thingol.ML --- a/src/Tools/Code/code_thingol.ML Tue Jan 12 09:59:45 2010 +0100 +++ b/src/Tools/Code/code_thingol.ML Tue Jan 12 16:27:42 2010 +0100 @@ -447,7 +447,7 @@ in Thm.certify_instantiate ([], var_subst) thm end; fun canonize_thms thy = map (Thm.transfer thy) - #> Code.standard_typscheme thy #> desymbolize_tvars thy + #> desymbolize_tvars thy #> same_arity thy #> map (desymbolize_vars thy); @@ -612,10 +612,10 @@ fun stmt_classparam class = ensure_class thy algbr eqngr class #>> (fn class => Classparam (c, class)); - fun stmt_fun raw_eqns = + fun stmt_fun cert = let - val eqns = burrow_fst (canonize_thms thy) raw_eqns; - val (vs, ty) = Code.typscheme_eqns thy c (map fst eqns); + val ((vs, ty), raw_eqns) = Code.dest_cert thy cert; + val eqns = burrow_fst (canonize_thms thy) (map snd raw_eqns); in fold_map (translate_tyvar_sort thy algbr eqngr) vs ##>> translate_typ thy algbr eqngr ty @@ -626,7 +626,7 @@ of SOME tyco => stmt_datatypecons tyco | NONE => (case AxClass.class_of_param thy c of SOME class => stmt_classparam class - | NONE => stmt_fun (Code_Preproc.eqns eqngr c)) + | NONE => stmt_fun (Code_Preproc.cert eqngr c)) in ensure_stmt lookup_const (declare_const thy) stmt_const c end and ensure_class thy (algbr as (_, algebra)) eqngr class = let @@ -933,11 +933,7 @@ let val (_, eqngr) = Code_Preproc.obtain thy consts []; val all_consts = Graph.all_succs eqngr consts; - in - eqngr - |> Graph.subgraph (member (op =) all_consts) - |> Graph.map_nodes ((apsnd o map o apfst) (AxClass.overload thy)) - end; + in Graph.subgraph (member (op =) all_consts) eqngr end; fun code_thms thy = Pretty.writeln o Code_Preproc.pretty thy o code_depgr thy;