--- a/src/Pure/Isar/code.ML Wed Jan 13 10:18:45 2010 +0100
+++ b/src/Pure/Isar/code.ML Wed Jan 13 12:20:37 2010 +0100
@@ -29,12 +29,15 @@
val mk_eqn_liberal: theory -> thm -> (thm * bool) option
val assert_eqn: theory -> thm * bool -> thm * bool
val const_typ_eqn: theory -> thm -> string * typ
- type cert = thm * bool list
+ val expand_eta: theory -> int -> thm -> thm
+ type cert
val empty_cert: theory -> string -> cert
val cert_of_eqns: theory -> string -> (thm * bool) list -> cert
val constrain_cert: theory -> sort list -> cert -> cert
- val eqns_of_cert: theory -> cert -> (thm * bool) list
- val dest_cert: theory -> cert -> ((string * sort) list * typ) * ((term list * term) * (thm * bool)) list
+ val typscheme_cert: theory -> cert -> (string * sort) list * typ
+ val equations_cert: theory -> cert -> ((string * sort) list * typ) * (term list * term) list
+ val equations_thms_cert: theory -> cert -> ((string * sort) list * typ) * ((term list * term) * (thm * bool)) list
+ val pretty_cert: theory -> cert -> Pretty.T list
(*executable code*)
val add_type: string -> theory -> theory
@@ -511,20 +514,71 @@
fun typscheme thy (c, ty) = logical_typscheme thy (c, subst_signature thy c ty);
-fun assert_eqns_const thy c eqns =
+
+(* technical transformations of code equations *)
+
+fun expand_eta thy k thm =
+ let
+ val (lhs, rhs) = (Logic.dest_equals o Thm.plain_prop_of) thm;
+ val (_, args) = strip_comb lhs;
+ val l = if k = ~1
+ then (length o fst o strip_abs) rhs
+ else Int.max (0, k - length args);
+ val (raw_vars, _) = Term.strip_abs_eta l rhs;
+ val vars = burrow_fst (Name.variant_list (map (fst o fst) (Term.add_vars lhs [])))
+ raw_vars;
+ fun expand (v, ty) thm = Drule.fun_cong_rule thm
+ (Thm.cterm_of thy (Var ((v, 0), ty)));
+ in
+ thm
+ |> fold expand vars
+ |> Conv.fconv_rule Drule.beta_eta_conversion
+ end;
+
+fun same_arity thy thms =
let
- fun cert (eqn as (thm, _)) = if c = const_eqn thy thm
- then eqn else error ("Wrong head of code equation,\nexpected constant "
- ^ string_of_const thy c ^ "\n" ^ Display.string_of_thm_global thy thm)
- in map (cert o assert_eqn thy) eqns end;
+ val num_args_of = length o snd o strip_comb o fst o Logic.dest_equals;
+ val k = fold (Integer.max o num_args_of o Thm.prop_of) thms 0;
+ in map (expand_eta thy k) thms end;
+
+fun mk_desymbolization pre post mk vs =
+ let
+ val names = map (pre o fst o fst) vs
+ |> map (Name.desymbolize false)
+ |> Name.variant_list []
+ |> map post;
+ in map_filter (fn (((v, i), x), v') =>
+ if v = v' andalso i = 0 then NONE
+ else SOME (((v, i), x), mk ((v', 0), x))) (vs ~~ names)
+ end;
+
+fun desymbolize_tvars thy thms =
+ let
+ val tvs = fold (Term.add_tvars o Thm.prop_of) thms [];
+ val tvar_subst = mk_desymbolization (unprefix "'") (prefix "'") TVar tvs;
+ in map (Thm.certify_instantiate (tvar_subst, [])) thms end;
+
+fun desymbolize_vars thy thm =
+ let
+ val vs = Term.add_vars (Thm.prop_of thm) [];
+ val var_subst = mk_desymbolization I I Var vs;
+ in Thm.certify_instantiate ([], var_subst) thm end;
+
+fun canonize_thms thy = desymbolize_tvars thy #> same_arity thy #> map (desymbolize_vars thy);
(* code equation certificates *)
-type cert = thm * bool list;
+fun build_head thy (c, ty) =
+ Thm.cterm_of thy (Logic.mk_equals (Free ("HEAD", ty), Const (c, ty)));
-fun mk_head_cterm thy (c, ty) =
- Thm.cterm_of thy (Logic.mk_equals (Free ("HEAD", ty), Const (c, ty)));
+fun get_head thy cert_thm =
+ let
+ val [head] = (#hyps o Thm.crep_thm) cert_thm;
+ val (_, Const (c, ty)) = (Logic.dest_equals o Thm.term_of) head;
+ in (typscheme thy (c, ty), head) end;
+
+abstype cert = Cert of thm * bool list with
fun empty_cert thy c =
let
@@ -535,14 +589,18 @@
| NONE => Name.invent_list [] Name.aT (length tvars)
|> map (fn v => TFree (v, []));
val ty = typ_subst_TVars (tvars ~~ tvars') raw_ty;
- val chead = mk_head_cterm thy (c, ty);
- in (Thm.weaken chead Drule.dummy_thm, []) end;
+ val chead = build_head thy (c, ty);
+ in Cert (Thm.weaken chead Drule.dummy_thm, []) end;
fun cert_of_eqns thy c [] = empty_cert thy c
- | cert_of_eqns thy c eqns =
+ | cert_of_eqns thy c raw_eqns =
let
- val _ = assert_eqns_const thy c eqns;
+ val eqns = burrow_fst (canonize_thms thy) raw_eqns;
+ val _ = map (assert_eqn thy) eqns;
val (thms, propers) = split_list eqns;
+ val _ = map (fn thm => if c = const_eqn thy thm then ()
+ else error ("Wrong head of code equation,\nexpected constant "
+ ^ string_of_const thy c ^ "\n" ^ Display.string_of_thm_global thy thm)) thms;
fun tvars_of T = rev (Term.add_tvarsT T []);
val vss = map (tvars_of o snd o head_eqn) thms;
fun inter_sorts vs =
@@ -551,55 +609,59 @@
val vts = Name.names Name.context Name.aT sorts;
val thms as thm :: _ =
map2 (fn vs => Thm.certify_instantiate (vs ~~ map TFree vts, [])) vss thms;
- val head_thm = Thm.symmetric (Thm.assume (mk_head_cterm thy (head_eqn (hd thms))));
+ val head_thm = Thm.symmetric (Thm.assume (build_head thy (head_eqn (hd thms))));
fun head_conv ct = if can Thm.dest_comb ct
then Conv.fun_conv head_conv ct
else Conv.rewr_conv head_thm ct;
val rewrite_head = Conv.fconv_rule (Conv.arg1_conv head_conv);
val cert_thm = Conjunction.intr_balanced (map rewrite_head thms);
- in (cert_thm, propers) end;
-
-fun head_cert thy cert_thm =
- let
- val [head] = Thm.hyps_of cert_thm;
- val (Free (h, _), Const (c, ty)) = Logic.dest_equals head;
- in ((c, typscheme thy (c, ty)), (head, h)) end;
+ in Cert (cert_thm, propers) end;
-fun constrain_cert thy sorts (cert_thm, propers) =
+fun constrain_cert thy sorts (Cert (cert_thm, propers)) =
let
- val ((c, (vs, _)), (head, _)) = head_cert thy cert_thm;
- val subst = map2 (fn (v, _) => fn sort => (v, sort)) vs sorts;
- val head' = (map_types o map_atyps)
- (fn TFree (v, _) => TFree (v, the (AList.lookup (op =) subst v))) head;
- val inst = (map2 (fn (v, sort) => fn sort' =>
- pairself (Thm.ctyp_of thy) (TVar ((v, 0), sort), TFree (v, sort'))) vs sorts, []);
+ val ((vs, _), head) = get_head thy cert_thm;
+ val subst = map2 (fn (v, sort) => fn sort' =>
+ (v, Sorts.inter_sort (Sign.classes_of thy) (sort, sort'))) vs sorts;
+ val head' = Thm.term_of head
+ |> (map_types o map_atyps)
+ (fn TFree (v, _) => TFree (v, the (AList.lookup (op =) subst v)))
+ |> Thm.cterm_of thy;
+ val inst = map2 (fn (v, sort) => fn (_, sort') =>
+ (((v, 0), sort), TFree (v, sort'))) vs subst;
val cert_thm' = cert_thm
- |> Thm.implies_intr (Thm.cterm_of thy head)
+ |> Thm.implies_intr head
|> Thm.varifyT
- |> Thm.instantiate inst
- |> Thm.elim_implies (Thm.assume (Thm.cterm_of thy head'));
- in (cert_thm', propers) end;
+ |> Thm.certify_instantiate (inst, [])
+ |> Thm.elim_implies (Thm.assume head');
+ in (Cert (cert_thm', propers)) end;
-fun eqns_of_cert thy (cert_thm, []) = []
- | eqns_of_cert thy (cert_thm, propers) =
- let
- val (_, (head, _)) = head_cert thy cert_thm;
- val thms = cert_thm
- |> LocalDefs.expand [Thm.cterm_of thy head]
- |> Thm.varifyT
- |> Conjunction.elim_balanced (length propers)
- in thms ~~ propers end;
+fun typscheme_cert thy (Cert (cert_thm, _)) =
+ fst (get_head thy cert_thm);
-fun dest_cert thy (cert as (cert_thm, propers)) =
+fun equations_cert thy (cert as Cert (cert_thm, propers)) =
let
- val eqns = eqns_of_cert thy cert;
- val ((_, vs_ty), _) = head_cert thy cert_thm;
- val equations = if null propers then [] else cert_thm
- |> Thm.prop_of
+ val tyscm = typscheme_cert thy cert;
+ val equations = if null propers then [] else
+ Thm.prop_of cert_thm
|> Logic.dest_conjunction_balanced (length propers)
|> map Logic.dest_equals
|> (map o apfst) (snd o strip_comb)
- in (vs_ty, equations ~~ eqns) end;
+ in (tyscm, equations) end;
+
+fun equations_thms_cert thy (cert as Cert (cert_thm, propers)) =
+ let
+ val (tyscm, equations) = equations_cert thy cert;
+ val thms = if null propers then [] else
+ cert_thm
+ |> LocalDefs.expand [snd (get_head thy cert_thm)]
+ |> Thm.varifyT
+ |> Conjunction.elim_balanced (length propers)
+ in (tyscm, equations ~~ (thms ~~ propers)) end;
+
+fun pretty_cert thy = map (Display.pretty_thm_global thy o AxClass.overload thy o fst o snd)
+ o snd o equations_thms_cert thy;
+
+end;
(* code equation access *)