--- a/src/Pure/Tools/codegen_theorems.ML Thu Aug 17 09:24:51 2006 +0200
+++ b/src/Pure/Tools/codegen_theorems.ML Thu Aug 17 09:24:56 2006 +0200
@@ -24,6 +24,9 @@
val common_typ: theory -> (thm -> typ) -> thm list -> thm list;
val preprocess: theory -> thm list -> thm list;
+ val prove_freeness: theory -> tactic -> string
+ -> (string * sort) list * (string * typ list) list -> thm list;
+
type thmtab;
val mk_thmtab: theory -> (string * typ) list -> thmtab;
val get_sortalgebra: thmtab -> Sorts.algebra;
@@ -32,6 +35,7 @@
-> ((string * sort) list * (string * typ list) list) option;
val get_fun_thms: thmtab -> string * typ -> thm list;
+ val pretty_funtab: theory -> thm list CodegenConsts.Consttab.table -> Pretty.T;
val print_thms: theory -> unit;
val init_obj: (thm * thm) * (thm * thm) -> theory -> theory;
@@ -81,7 +85,7 @@
fun init_obj ((TrueI, FalseE), (conjI, atomize_eq)) thy =
case CodegenTheoremsSetup.get thy
- of SOME _ => error "code generator already set up for object logic"
+ of SOME _ => error "Code generator already set up for object logic"
| NONE =>
let
fun strip_implies t = (Logic.strip_imp_prems t, Logic.strip_imp_concl t);
@@ -114,7 +118,7 @@
#> apsnd (map Term.dest_Var)
#> apfst Term.dest_Const
)
- |> (fn (v1, ((conj, _), v2)) => if v1 = v2 then conj else error "wrong premise")
+ |> (fn (v1, ((conj, _), v2)) => if v1 = v2 then conj else error "Wrong premise")
fun dest_atomize_eq thm=
Drule.plain_prop_of thm
|> Logic.dest_equals
@@ -130,10 +134,10 @@
#> apsnd Term.dest_Var
)
|> (fn (((eq, _), v2), (v1a as (_, TVar (_, sort)), v1b)) =>
- if [v1a, v1b] = v2 andalso sort = Sign.defaultS thy then eq else error "wrong premise")
+ if [v1a, v1b] = v2 andalso sort = Sign.defaultS thy then eq else error "Wrong premise")
in
((dest_TrueI TrueI, [dest_FalseE FalseE, dest_conjI conjI, dest_atomize_eq atomize_eq])
- handle _ => error "bad code generator setup")
+ handle _ => error "Bad code generator setup")
|> (fn ((tr, b), [fl, con, eq]) => CodegenTheoremsSetup.put
(SOME ((b, atomize_eq), ((tr, fl), (con, eq)))) thy)
end;
@@ -141,7 +145,7 @@
fun get_obj thy =
case CodegenTheoremsSetup.get thy
of SOME ((b, atomize), x) => ((Type (b, []), atomize) ,x)
- | NONE => error "no object logic setup for code theorems";
+ | NONE => error "No object logic setup for code theorems";
fun mk_true thy =
let
@@ -260,14 +264,14 @@
case try (make_eq thy #> Drule.plain_prop_of
#> ObjectLogic.drop_judgment thy #> Logic.dest_equals) thm
of SOME eq => (eq, thm)
- | NONE => err_thm "not an equation" thm;
+ | NONE => err_thm "Not an equation" thm;
fun dest_fun thy thm =
let
fun dest_fun' ((lhs, _), thm) =
case try (dest_Const o fst o strip_comb) lhs
of SOME (c, ty) => (c, (ty, thm))
- | NONE => err_thm "not a function equation" thm;
+ | NONE => err_thm "Not a function equation" thm;
in
thm
|> dest_eq thy
@@ -280,21 +284,23 @@
(* data structures *)
+structure Consttab = CodegenConsts.Consttab;
+
fun merge' eq (xys as (xs, ys)) =
if eq_list eq (xs, ys) then (false, xs) else (true, merge eq xys);
fun alist_merge' eq_key eq (xys as (xs, ys)) =
if eq_list (eq_pair eq_key eq) (xs, ys) then (false, xs) else (true, AList.merge eq_key eq xys);
-fun list_symtab_join' eq (xyt as (xt, yt)) =
+fun list_consttab_join' eq (xyt as (xt, yt)) =
let
- val xc = Symtab.keys xt;
- val yc = Symtab.keys yt;
- val zc = filter (member (op =) yc) xc;
+ val xc = Consttab.keys xt;
+ val yc = Consttab.keys yt;
+ val zc = filter (member CodegenConsts.eq_const yc) xc;
val wc = subtract (op =) zc xc @ subtract (op =) zc yc;
- fun same_thms c = if eq_list eq_thm ((the o Symtab.lookup xt) c, (the o Symtab.lookup yt) c)
+ fun same_thms c = if eq_list eq_thm ((the o Consttab.lookup xt) c, (the o Consttab.lookup yt) c)
then NONE else SOME c;
- in (wc @ map_filter same_thms zc, Symtab.join (K (merge eq)) xyt) end;
+ in (wc @ map_filter same_thms zc, Consttab.join (K (merge eq)) xyt) end;
datatype notify = Notify of (serial * ((string * typ) list option -> theory -> theory)) list;
@@ -337,7 +343,7 @@
datatype funthms = Funthms of {
dirty: string list,
- funs: thm list Symtab.table
+ funs: thm list Consttab.table
};
fun mk_funthms (dirty, funs) =
@@ -347,8 +353,8 @@
fun merge_funthms _ (Funthms { dirty = dirty1, funs = funs1 },
Funthms { dirty = dirty2, funs = funs2 }) =
let
- val (dirty3, funs) = list_symtab_join' eq_thm (funs1, funs2);
- in mk_funthms (merge (op =) (merge (op =) (dirty1, dirty2), dirty3), funs) end;
+ val (dirty3, funs) = list_consttab_join' eq_thm (funs1, funs2);
+ in mk_funthms (merge (op =) (merge (op =) (dirty1, dirty2), map fst dirty3), funs) end;
datatype T = T of {
dirty: bool,
@@ -380,7 +386,7 @@
val name = "Pure/codegen_theorems_data";
type T = T;
val empty = mk_T ((false, mk_notify []), (mk_preproc ([], []),
- (mk_extrs ([], []), mk_funthms ([], Symtab.empty))));
+ (mk_extrs ([], []), mk_funthms ([], Consttab.empty))));
val copy = I;
val extend = I;
val merge = merge_T;
@@ -388,7 +394,7 @@
let
val pretty_thm = ProofContext.pretty_thm (ProofContext.init thy);
val funthms = (fn T { funthms, ... } => funthms) data;
- val funs = (Symtab.dest o (fn Funthms { funs, ... } => funs)) funthms;
+ val funs = (Consttab.dest o (fn Funthms { funs, ... } => funs)) funthms;
val preproc = (fn T { preproc, ... } => preproc) data;
val unfolds = (fn Preproc { unfolds, ... } => unfolds) preproc;
in
@@ -398,7 +404,7 @@
(*Pretty.fbreaks ( *)
map (fn (c, thms) =>
(Pretty.block o Pretty.fbreaks) (
- Pretty.str c :: map pretty_thm (rev thms)
+ (Pretty.str o CodegenConsts.string_of_const thy) c :: map pretty_thm (rev thms)
)
) funs
(*) *) @ [
@@ -437,7 +443,9 @@
(* notifiers *)
fun all_typs thy c =
- map (pair c) (Sign.the_const_type thy c :: (map (#lhs) o Theory.definitions_of thy) c);
+ let
+ val c_tys = (map (pair c o #lhs o snd) o Defs.specifications_of (Theory.defs_of thy)) c;
+ in (c, Sign.the_const_type thy c) :: map (CodegenConsts.typ_of_typinst thy) c_tys end;
fun add_notify f =
map_data (fn ((dirty, notify), x) =>
@@ -489,20 +497,20 @@
fun add_fun thm thy =
case dest_fun thy thm
- of (c, _) =>
+ of (c, (ty, _)) =>
thy
|> map_data (fn (x, (preproc, (extrs, funthms))) =>
(x, (preproc, (extrs, funthms |> map_funthms (fn (dirty, funs) =>
- (dirty, funs |> Symtab.default (c, []) |> Symtab.map_entry c (cons thm)))))))
+ (dirty, funs |> Consttab.map_default (CodegenConsts.norminst_of_typ thy (c, ty), []) (cons thm)))))))
|> notify_all (SOME c);
fun del_fun thm thy =
case dest_fun thy thm
- of (c, _) =>
+ of (c, (ty, _)) =>
thy
|> map_data (fn (x, (preproc, (extrs, funthms))) =>
(x, (preproc, (extrs, funthms |> map_funthms (fn (dirty, funs) =>
- (dirty, funs |> Symtab.map_entry c (remove eq_thm thm)))))))
+ (dirty, funs |> Consttab.map_entry (CodegenConsts.norminst_of_typ thy (c, ty)) (remove eq_thm thm)))))))
|> notify_all (SOME c);
fun add_unfold thm thy =
@@ -523,9 +531,7 @@
thy
|> map_data (fn (x, (preproc, (extrs, funthms))) =>
(x, (preproc, (extrs, funthms |> map_funthms (fn (dirty, funs) =>
- (dirty, funs |> Symtab.map_entry c
- (filter (fn thm => Sign.typ_instance thy
- ((fst o snd o dest_fun thy) thm, ty)))))))))
+ (dirty, funs |> Consttab.update (CodegenConsts.norminst_of_typ thy (c, ty), [])))))))
|> notify_all (SOME c);
@@ -556,7 +562,12 @@
in (thm', max') end;
val (thms', maxidx) = fold_map incr_thm thms 0;
val (ty1::tys) = map extract_typ thms;
- fun unify ty = Sign.typ_unify thy (ty1, ty);
+ fun unify ty env = Sign.typ_unify thy (ty1, ty) env
+ handle Type.TUNIFY =>
+ error ("Type unificaton failed, while unifying function equations\n"
+ ^ (cat_lines o map Display.string_of_thm) thms
+ ^ "\nwith types\n"
+ ^ (cat_lines o map (Sign.string_of_typ thy)) (ty1 :: tys));
val (env, _) = fold unify tys (Vartab.empty, maxidx)
val instT = Vartab.fold (fn (x_i, (sort, ty)) =>
cons (Thm.ctyp_of thy (TVar (x_i, sort)), Thm.ctyp_of thy ty)) env [];
@@ -611,41 +622,47 @@
fun get_funs thy (c, ty) =
let
val _ = debug_msg (fn _ => "[cg_thm] const (1) " ^ c ^ " :: " ^ Sign.string_of_typ thy ty) ()
- val filter_typ = map_filter (fn (_, (ty', thm)) =>
- if Sign.typ_instance thy (ty, ty')
- then SOME thm else debug_msg (fn _ => "[cg_thm] dropping " ^ string_of_thm thm) NONE);
+ val postprocess_typ = case AxClass.class_of_param thy c
+ of NONE => map_filter (fn (_, (ty', thm)) =>
+ if Sign.typ_instance thy (ty, ty')
+ then SOME thm else debug_msg (fn _ => "[cg_thm] dropping " ^ string_of_thm thm) NONE)
+ | SOME _ => let
+ (*FIXME make this more elegant*)
+ val ty' = CodegenConsts.typ_of_classop thy (CodegenConsts.norminst_of_typ thy (c, ty));
+ val ct = Thm.cterm_of thy (Const (c, ty'));
+ val thm' = Thm.reflexive ct;
+ in map (snd o snd) #> cons thm' #> common_typ thy (extr_typ thy) #> tl end;
fun get_funs (c, ty) =
- (these o Symtab.lookup (the_funs thy)) c
+ (these o Consttab.lookup (the_funs thy) o CodegenConsts.norminst_of_typ thy) (c, ty)
|> debug_msg (fn _ => "[cg_thm] trying funs")
|> map (dest_fun thy)
- |> filter_typ;
+ |> postprocess_typ;
fun get_extr (c, ty) =
getf_first_list (map (fn f => f thy) (the_funs_extrs thy)) (c, ty)
|> debug_msg (fn _ => "[cg_thm] trying extr")
|> map (dest_fun thy)
- |> filter_typ;
+ |> postprocess_typ;
fun get_spec (c, ty) =
- Theory.definitions_of thy c
+ (CodegenConsts.find_def thy o CodegenConsts.norminst_of_typ thy) (c, ty)
|> debug_msg (fn _ => "[cg_thm] trying spec")
- (* FIXME avoid dynamic name space lookup!? (via Thm.get_axiom_i etc.??) *)
- |> maps (fn { name, ... } => these (try (PureThy.get_thms thy) (Name name)))
+ |> Option.mapPartial (fn ((_, name), _) => try (Thm.get_axiom_i thy) name)
+ |> the_list
|> map_filter (try (dest_fun thy))
- |> filter_typ;
+ |> postprocess_typ;
in
getf_first_list [get_funs, get_extr, get_spec] (c, ty)
|> debug_msg (fn _ => "[cg_thm] const (2) " ^ c ^ " :: " ^ Sign.string_of_typ thy ty)
|> preprocess thy
end;
-fun get_datatypes thy dtco =
+fun prove_freeness thy tac dtco vs_cos =
let
- val _ = debug_msg (fn _ => "[cg_thm] datatype " ^ dtco) ()
val truh = mk_true thy;
val fals = mk_false thy;
fun mk_lhs vs ((co1, tys1), (co2, tys2)) =
let
val dty = Type (dtco, map TFree vs);
- val (xs1, xs2) = chop (length tys1) (Name.invent_list [] "x" (length tys1 + length tys2));
+ val (xs1, xs2) = chop (length tys1) (Name.invent_list [] "a" (length tys1 + length tys2));
val frees1 = map2 (fn x => fn ty => Free (x, ty)) xs1 tys1;
val frees2 = map2 (fn x => fn ty => Free (x, ty)) xs2 tys2;
fun zip_co co xs tys = list_comb (Const (co,
@@ -667,13 +684,18 @@
fun mk_eqs (vs, cos) =
let val cos' = rev cos
in (op @) (fold (mk_eq vs) (product cos' cos') ([], [])) end;
- fun mk_eq_thms tac vs_cos =
- map (fn t => Goal.prove_global thy [] []
- (ObjectLogic.ensure_propT thy t) (K tac)) (mk_eqs vs_cos);
+ in
+ map (fn t => Goal.prove_global thy [] []
+ (ObjectLogic.ensure_propT thy t) (K tac)) (mk_eqs vs_cos)
+ end;
+
+fun get_datatypes thy dtco =
+ let
+ val _ = debug_msg (fn _ => "[cg_thm] datatype " ^ dtco) ()
in
case getf_first (map (fn f => f thy) (the_datatypes_extrs thy)) dtco
of NONE => NONE
- | SOME (vs_cos, tac) => SOME (vs_cos, mk_eq_thms tac vs_cos)
+ | SOME (vs_cos, tac) => SOME (vs_cos, prove_freeness thy tac dtco vs_cos)
end;
fun get_eq thy (c, ty) =
@@ -691,13 +713,13 @@
fun check_head_lhs thm (lhs, rhs) =
case strip_comb lhs
of (Const (c', _), _) => if c' = c then ()
- else error ("illegal function equation for " ^ quote c
+ else error ("Illegal function equation for " ^ quote c
^ ", actually defining " ^ quote c' ^ ": " ^ Display.string_of_thm thm)
- | _ => error ("illegal function equation: " ^ Display.string_of_thm thm);
+ | _ => error ("Illegal function equation: " ^ Display.string_of_thm thm);
fun check_vars_lhs thm (lhs, rhs) =
if has_duplicates (op =)
(fold_aterms (fn Free (v, _) => cons v | _ => I) lhs [])
- then error ("repeated variables on left hand side of function equation:"
+ then error ("Repeated variables on left hand side of function equation:"
^ Display.string_of_thm thm)
else ();
fun check_vars_rhs thm (lhs, rhs) =
@@ -705,7 +727,7 @@
(fold_aterms (fn Free (v, _) => cons v | _ => I) lhs [])
(fold_aterms (fn Free (v, _) => cons v | _ => I) rhs []))
then ()
- else error ("free variables on right hand side of function equation:"
+ else error ("Free variables on right hand side of function equation:"
^ Display.string_of_thm thm)
val tts = map (Logic.dest_equals o Logic.unvarify o Thm.prop_of) thms;
in