# HG changeset patch # User haftmann # Date 1263381637 -3600 # Node ID 19fd499cddff39affb8f42abb7c05fdae46dfde0 # Parent fadbdd350dd1d4b41c3c531ae31ea32372b8f830 explicit abstract type of code certificates diff -r fadbdd350dd1 -r 19fd499cddff src/HOL/Tools/recfun_codegen.ML --- a/src/HOL/Tools/recfun_codegen.ML Wed Jan 13 10:18:45 2010 +0100 +++ b/src/HOL/Tools/recfun_codegen.ML Wed Jan 13 12:20:37 2010 +0100 @@ -36,7 +36,7 @@ let val (_, T) = Code.const_typ_eqn thy thm in if null (Term.add_tvarsT T []) orelse (null o fst o strip_type) T then [thm] - else [Code_Thingol.expand_eta thy 1 thm] + else [Code.expand_eta thy 1 thm] end | avoid_value thy thms = thms; @@ -44,8 +44,9 @@ let val c = AxClass.unoverload_const thy (raw_c, T); val raw_thms = Code.get_cert thy (Code_Preproc.preprocess_functrans thy) c - |> Code.eqns_of_cert thy - |> map_filter (fn (thm, linear) => if linear then SOME thm else NONE) + |> Code.equations_thms_cert thy + |> snd + |> map_filter (fn (_, (thm, proper)) => if proper then SOME thm else NONE) |> map (AxClass.overload thy) |> filter (is_instance T o snd o const_of o prop_of); val module_name = case Symtab.lookup (ModuleData.get thy) c @@ -57,7 +58,6 @@ raw_thms |> preprocess thy |> avoid_value thy - |> Code_Thingol.canonize_thms thy |> rpair module_name end; diff -r fadbdd350dd1 -r 19fd499cddff src/Pure/Isar/code.ML --- a/src/Pure/Isar/code.ML Wed Jan 13 10:18:45 2010 +0100 +++ b/src/Pure/Isar/code.ML Wed Jan 13 12:20:37 2010 +0100 @@ -29,12 +29,15 @@ val mk_eqn_liberal: theory -> thm -> (thm * bool) option val assert_eqn: theory -> thm * bool -> thm * bool val const_typ_eqn: theory -> thm -> string * typ - type cert = thm * bool list + val expand_eta: theory -> int -> thm -> thm + type cert val empty_cert: theory -> string -> cert val cert_of_eqns: theory -> string -> (thm * bool) list -> cert val constrain_cert: theory -> sort list -> cert -> cert - val eqns_of_cert: theory -> cert -> (thm * bool) list - val dest_cert: theory -> cert -> ((string * sort) list * typ) * ((term list * term) * (thm * bool)) list + val typscheme_cert: theory -> cert -> (string * sort) list * typ + val equations_cert: theory -> cert -> ((string * sort) list * typ) * (term list * term) list + val equations_thms_cert: theory -> cert -> ((string * sort) list * typ) * ((term list * term) * (thm * bool)) list + val pretty_cert: theory -> cert -> Pretty.T list (*executable code*) val add_type: string -> theory -> theory @@ -511,20 +514,71 @@ fun typscheme thy (c, ty) = logical_typscheme thy (c, subst_signature thy c ty); -fun assert_eqns_const thy c eqns = + +(* technical transformations of code equations *) + +fun expand_eta thy k thm = + let + val (lhs, rhs) = (Logic.dest_equals o Thm.plain_prop_of) thm; + val (_, args) = strip_comb lhs; + val l = if k = ~1 + then (length o fst o strip_abs) rhs + else Int.max (0, k - length args); + val (raw_vars, _) = Term.strip_abs_eta l rhs; + val vars = burrow_fst (Name.variant_list (map (fst o fst) (Term.add_vars lhs []))) + raw_vars; + fun expand (v, ty) thm = Drule.fun_cong_rule thm + (Thm.cterm_of thy (Var ((v, 0), ty))); + in + thm + |> fold expand vars + |> Conv.fconv_rule Drule.beta_eta_conversion + end; + +fun same_arity thy thms = let - fun cert (eqn as (thm, _)) = if c = const_eqn thy thm - then eqn else error ("Wrong head of code equation,\nexpected constant " - ^ string_of_const thy c ^ "\n" ^ Display.string_of_thm_global thy thm) - in map (cert o assert_eqn thy) eqns end; + val num_args_of = length o snd o strip_comb o fst o Logic.dest_equals; + val k = fold (Integer.max o num_args_of o Thm.prop_of) thms 0; + in map (expand_eta thy k) thms end; + +fun mk_desymbolization pre post mk vs = + let + val names = map (pre o fst o fst) vs + |> map (Name.desymbolize false) + |> Name.variant_list [] + |> map post; + in map_filter (fn (((v, i), x), v') => + if v = v' andalso i = 0 then NONE + else SOME (((v, i), x), mk ((v', 0), x))) (vs ~~ names) + end; + +fun desymbolize_tvars thy thms = + let + val tvs = fold (Term.add_tvars o Thm.prop_of) thms []; + val tvar_subst = mk_desymbolization (unprefix "'") (prefix "'") TVar tvs; + in map (Thm.certify_instantiate (tvar_subst, [])) thms end; + +fun desymbolize_vars thy thm = + let + val vs = Term.add_vars (Thm.prop_of thm) []; + val var_subst = mk_desymbolization I I Var vs; + in Thm.certify_instantiate ([], var_subst) thm end; + +fun canonize_thms thy = desymbolize_tvars thy #> same_arity thy #> map (desymbolize_vars thy); (* code equation certificates *) -type cert = thm * bool list; +fun build_head thy (c, ty) = + Thm.cterm_of thy (Logic.mk_equals (Free ("HEAD", ty), Const (c, ty))); -fun mk_head_cterm thy (c, ty) = - Thm.cterm_of thy (Logic.mk_equals (Free ("HEAD", ty), Const (c, ty))); +fun get_head thy cert_thm = + let + val [head] = (#hyps o Thm.crep_thm) cert_thm; + val (_, Const (c, ty)) = (Logic.dest_equals o Thm.term_of) head; + in (typscheme thy (c, ty), head) end; + +abstype cert = Cert of thm * bool list with fun empty_cert thy c = let @@ -535,14 +589,18 @@ | NONE => Name.invent_list [] Name.aT (length tvars) |> map (fn v => TFree (v, [])); val ty = typ_subst_TVars (tvars ~~ tvars') raw_ty; - val chead = mk_head_cterm thy (c, ty); - in (Thm.weaken chead Drule.dummy_thm, []) end; + val chead = build_head thy (c, ty); + in Cert (Thm.weaken chead Drule.dummy_thm, []) end; fun cert_of_eqns thy c [] = empty_cert thy c - | cert_of_eqns thy c eqns = + | cert_of_eqns thy c raw_eqns = let - val _ = assert_eqns_const thy c eqns; + val eqns = burrow_fst (canonize_thms thy) raw_eqns; + val _ = map (assert_eqn thy) eqns; val (thms, propers) = split_list eqns; + val _ = map (fn thm => if c = const_eqn thy thm then () + else error ("Wrong head of code equation,\nexpected constant " + ^ string_of_const thy c ^ "\n" ^ Display.string_of_thm_global thy thm)) thms; fun tvars_of T = rev (Term.add_tvarsT T []); val vss = map (tvars_of o snd o head_eqn) thms; fun inter_sorts vs = @@ -551,55 +609,59 @@ val vts = Name.names Name.context Name.aT sorts; val thms as thm :: _ = map2 (fn vs => Thm.certify_instantiate (vs ~~ map TFree vts, [])) vss thms; - val head_thm = Thm.symmetric (Thm.assume (mk_head_cterm thy (head_eqn (hd thms)))); + val head_thm = Thm.symmetric (Thm.assume (build_head thy (head_eqn (hd thms)))); fun head_conv ct = if can Thm.dest_comb ct then Conv.fun_conv head_conv ct else Conv.rewr_conv head_thm ct; val rewrite_head = Conv.fconv_rule (Conv.arg1_conv head_conv); val cert_thm = Conjunction.intr_balanced (map rewrite_head thms); - in (cert_thm, propers) end; - -fun head_cert thy cert_thm = - let - val [head] = Thm.hyps_of cert_thm; - val (Free (h, _), Const (c, ty)) = Logic.dest_equals head; - in ((c, typscheme thy (c, ty)), (head, h)) end; + in Cert (cert_thm, propers) end; -fun constrain_cert thy sorts (cert_thm, propers) = +fun constrain_cert thy sorts (Cert (cert_thm, propers)) = let - val ((c, (vs, _)), (head, _)) = head_cert thy cert_thm; - val subst = map2 (fn (v, _) => fn sort => (v, sort)) vs sorts; - val head' = (map_types o map_atyps) - (fn TFree (v, _) => TFree (v, the (AList.lookup (op =) subst v))) head; - val inst = (map2 (fn (v, sort) => fn sort' => - pairself (Thm.ctyp_of thy) (TVar ((v, 0), sort), TFree (v, sort'))) vs sorts, []); + val ((vs, _), head) = get_head thy cert_thm; + val subst = map2 (fn (v, sort) => fn sort' => + (v, Sorts.inter_sort (Sign.classes_of thy) (sort, sort'))) vs sorts; + val head' = Thm.term_of head + |> (map_types o map_atyps) + (fn TFree (v, _) => TFree (v, the (AList.lookup (op =) subst v))) + |> Thm.cterm_of thy; + val inst = map2 (fn (v, sort) => fn (_, sort') => + (((v, 0), sort), TFree (v, sort'))) vs subst; val cert_thm' = cert_thm - |> Thm.implies_intr (Thm.cterm_of thy head) + |> Thm.implies_intr head |> Thm.varifyT - |> Thm.instantiate inst - |> Thm.elim_implies (Thm.assume (Thm.cterm_of thy head')); - in (cert_thm', propers) end; + |> Thm.certify_instantiate (inst, []) + |> Thm.elim_implies (Thm.assume head'); + in (Cert (cert_thm', propers)) end; -fun eqns_of_cert thy (cert_thm, []) = [] - | eqns_of_cert thy (cert_thm, propers) = - let - val (_, (head, _)) = head_cert thy cert_thm; - val thms = cert_thm - |> LocalDefs.expand [Thm.cterm_of thy head] - |> Thm.varifyT - |> Conjunction.elim_balanced (length propers) - in thms ~~ propers end; +fun typscheme_cert thy (Cert (cert_thm, _)) = + fst (get_head thy cert_thm); -fun dest_cert thy (cert as (cert_thm, propers)) = +fun equations_cert thy (cert as Cert (cert_thm, propers)) = let - val eqns = eqns_of_cert thy cert; - val ((_, vs_ty), _) = head_cert thy cert_thm; - val equations = if null propers then [] else cert_thm - |> Thm.prop_of + val tyscm = typscheme_cert thy cert; + val equations = if null propers then [] else + Thm.prop_of cert_thm |> Logic.dest_conjunction_balanced (length propers) |> map Logic.dest_equals |> (map o apfst) (snd o strip_comb) - in (vs_ty, equations ~~ eqns) end; + in (tyscm, equations) end; + +fun equations_thms_cert thy (cert as Cert (cert_thm, propers)) = + let + val (tyscm, equations) = equations_cert thy cert; + val thms = if null propers then [] else + cert_thm + |> LocalDefs.expand [snd (get_head thy cert_thm)] + |> Thm.varifyT + |> Conjunction.elim_balanced (length propers) + in (tyscm, equations ~~ (thms ~~ propers)) end; + +fun pretty_cert thy = map (Display.pretty_thm_global thy o AxClass.overload thy o fst o snd) + o snd o equations_thms_cert thy; + +end; (* code equation access *) diff -r fadbdd350dd1 -r 19fd499cddff src/Tools/Code/code_preproc.ML --- a/src/Tools/Code/code_preproc.ML Wed Jan 13 10:18:45 2010 +0100 +++ b/src/Tools/Code/code_preproc.ML Wed Jan 13 12:20:37 2010 +0100 @@ -199,11 +199,7 @@ AList.make (snd o Graph.get_node eqngr) (Graph.keys eqngr) |> (map o apfst) (Code.string_of_const thy) |> sort (string_ord o pairself fst) - |> map (fn (s, cert) => - (Pretty.block o Pretty.fbreaks) ( - Pretty.str s - :: map (Display.pretty_thm_global thy o AxClass.overload thy o fst) (Code.eqns_of_cert thy cert) - )) + |> map (fn (s, cert) => (Pretty.block o Pretty.fbreaks) (Pretty.str s :: Code.pretty_cert thy cert)) |> Pretty.chunks; @@ -220,13 +216,13 @@ map (fn (c, _) => AxClass.param_of_inst thy (c, tyco)) o maps (#params o AxClass.get_info thy); -fun typscheme_rhss thy c cert = +fun typargs_rhss thy c cert = let - val (tyscm, equations) = Code.dest_cert thy cert; + val ((vs, _), equations) = Code.equations_cert thy cert; val rhss = [] |> (fold o fold o fold_aterms) (fn Const (c, ty) => insert (op =) (c, Sign.const_typargs thy (c, ty)) | _ => I) - (map (op :: o swap o fst) equations); - in (tyscm, rhss) end; + (map (op :: o swap) equations); + in (vs, rhss) end; (* data structures *) @@ -266,7 +262,7 @@ of SOME (lhs, cert) => ((lhs, []), cert) | NONE => let val cert = Code.get_cert thy (preprocess thy) c; - val ((lhs, _), rhss) = typscheme_rhss thy c cert; + val (lhs, rhss) = typargs_rhss thy c cert; in ((lhs, rhss), cert) end; fun obtain_instance thy arities (inst as (class, tyco)) = @@ -388,14 +384,6 @@ handle Sorts.CLASS_ERROR _ => [] (*permissive!*)) end; -fun inst_cert thy lhs cert = - let - val ((vs, _), _) = Code.dest_cert thy cert; - val sorts = map (fn (v, sort) => case AList.lookup (op =) lhs v - of SOME sort' => Sorts.inter_sort (Sign.classes_of thy) (sort, sort') - | NONE => sort) vs; - in Code.constrain_cert thy sorts cert end; - fun add_arity thy vardeps (class, tyco) = AList.default (op =) ((class, tyco), map_range (fn k => (snd o Vargraph.get_node vardeps) (Inst (class, tyco), k)) @@ -406,8 +394,8 @@ else let val lhs = map_index (fn (k, (v, _)) => (v, snd (Vargraph.get_node vardeps (Fun c, k)))) proto_lhs; - val cert = inst_cert thy lhs proto_cert; - val ((vs, _), rhss') = typscheme_rhss thy c cert; + val cert = Code.constrain_cert thy (map snd lhs) proto_cert; + val (vs, rhss') = typargs_rhss thy c cert; val eqngr' = Graph.new_node (c, (vs, cert)) eqngr; in (map (pair c) rhss' @ rhss, eqngr') end; diff -r fadbdd350dd1 -r 19fd499cddff src/Tools/Code/code_thingol.ML --- a/src/Tools/Code/code_thingol.ML Wed Jan 13 10:18:45 2010 +0100 +++ b/src/Tools/Code/code_thingol.ML Wed Jan 13 12:20:37 2010 +0100 @@ -86,8 +86,6 @@ -> ((string * stmt) list * (string * stmt) list * ((string * stmt) list * (string * stmt) list)) list - val expand_eta: theory -> int -> thm -> thm - val canonize_thms: theory -> thm list -> thm list val read_const_exprs: theory -> string list -> string list * string list val consts_program: theory -> string list -> string list * (naming * program) val eval_conv: theory @@ -397,60 +395,6 @@ end; (* local *) -(** technical transformations of code equations **) - -fun expand_eta thy k thm = - let - val (lhs, rhs) = (Logic.dest_equals o Thm.plain_prop_of) thm; - val (_, args) = strip_comb lhs; - val l = if k = ~1 - then (length o fst o strip_abs) rhs - else Int.max (0, k - length args); - val (raw_vars, _) = Term.strip_abs_eta l rhs; - val vars = burrow_fst (Name.variant_list (map (fst o fst) (Term.add_vars lhs []))) - raw_vars; - fun expand (v, ty) thm = Drule.fun_cong_rule thm - (Thm.cterm_of thy (Var ((v, 0), ty))); - in - thm - |> fold expand vars - |> Conv.fconv_rule Drule.beta_eta_conversion - end; - -fun same_arity thy thms = - let - val num_args_of = length o snd o strip_comb o fst o Logic.dest_equals; - val k = fold (Integer.max o num_args_of o Thm.prop_of) thms 0; - in map (expand_eta thy k) thms end; - -fun mk_desymbolization pre post mk vs = - let - val names = map (pre o fst o fst) vs - |> map (Name.desymbolize false) - |> Name.variant_list [] - |> map post; - in map_filter (fn (((v, i), x), v') => - if v = v' andalso i = 0 then NONE - else SOME (((v, i), x), mk ((v', 0), x))) (vs ~~ names) - end; - -fun desymbolize_tvars thy thms = - let - val tvs = fold (Term.add_tvars o Thm.prop_of) thms []; - val tvar_subst = mk_desymbolization (unprefix "'") (prefix "'") TVar tvs; - in map (Thm.certify_instantiate (tvar_subst, [])) thms end; - -fun desymbolize_vars thy thm = - let - val vs = Term.add_vars (Thm.prop_of thm) []; - val var_subst = mk_desymbolization I I Var vs; - in Thm.certify_instantiate ([], var_subst) thm end; - -fun canonize_thms thy = map (Thm.transfer thy) - #> desymbolize_tvars thy - #> same_arity thy #> map (desymbolize_vars thy); - - (** statements, abstract programs **) type typscheme = (vname * sort) list * itype; @@ -614,8 +558,8 @@ #>> (fn class => Classparam (c, class)); fun stmt_fun cert = let - val ((vs, ty), raw_eqns) = Code.dest_cert thy cert; - val eqns = burrow_fst (canonize_thms thy) (map snd raw_eqns); + val ((vs, ty), raw_eqns) = Code.equations_thms_cert thy cert; + val eqns = map snd raw_eqns; in fold_map (translate_tyvar_sort thy algbr eqngr) vs ##>> translate_typ thy algbr eqngr ty