--- a/src/Pure/Tools/codegen_data.ML Mon Oct 02 23:01:04 2006 +0200
+++ b/src/Pure/Tools/codegen_data.ML Mon Oct 02 23:01:05 2006 +0200
@@ -5,8 +5,6 @@
Basic code generator data structures; abstract executable content of theory.
*)
-(* val _ = PolyML.Compiler.maxInlineSize := 0; *)
-
signature CODEGEN_DATA =
sig
type lthms = thm list Susp.T;
@@ -21,6 +19,8 @@
val del_datatype: string -> theory -> theory
val add_inline: thm -> theory -> theory
val del_inline: thm -> theory -> theory
+ val add_inline_proc: (theory -> cterm list -> thm list) -> theory -> theory
+ val add_constrains: (theory -> term list -> (indexname * sort) list) -> theory -> theory
val add_preproc: (theory -> thm list -> thm list) -> theory -> theory
val these_funcs: theory -> CodegenConsts.const -> thm list
val get_datatype: theory -> string
@@ -31,10 +31,9 @@
val typ_func: theory -> thm -> typ
val rewrite_func: thm list -> thm -> thm
- val preprocess_cterm: theory -> cterm -> thm
- val preprocess: theory -> thm list -> thm list
+ val preprocess_cterm: theory -> (string * typ -> typ) -> cterm -> thm * cterm
- val debug: bool ref
+ val trace: bool ref
val strict_functyp: bool ref
end;
@@ -55,8 +54,8 @@
(** diagnostics **)
-val debug = ref false;
-fun debug_msg f x = (if !debug then Output.tracing (f x) else (); x);
+val trace = ref false;
+fun tracing f x = (if !trace then Output.tracing (f x) else (); x);
@@ -64,7 +63,6 @@
type lthms = thm list Susp.T;
val eval_always = ref false;
-val _ = eval_always := true;
fun lazy f = if !eval_always
then Susp.value (f ())
@@ -78,10 +76,12 @@
of SOME thms => (map (ProofContext.pretty_thm ctxt) o rev) thms
| NONE => [Pretty.str "[...]"];
-fun certificate f r =
+fun certificate thy f r =
case Susp.peek r
- of SOME thms => (Susp.value o f) thms
- | NONE => lazy (fn () => (f o Susp.force) r);
+ of SOME thms => (Susp.value o f thy) thms
+ | NONE => let
+ val thy_ref = Theory.self_ref thy;
+ in lazy (fn () => (f (Theory.deref thy_ref) o Susp.force) r) end;
fun merge' _ ([], []) = (false, [])
| merge' _ ([], ys) = (true, ys)
@@ -107,45 +107,104 @@
(** code theorems **)
-(* making function theorems *)
+(* making rewrite theorems *)
fun bad_thm msg thm =
error (msg ^ ": " ^ string_of_thm thm);
+fun check_rew thy thm =
+ let
+ val (lhs, rhs) = (Logic.dest_equals o Thm.prop_of) thm;
+ fun vars_of t = fold_aterms
+ (fn Var (v, _) => insert (op =) v
+ | Free _ => bad_thm "Illegal free variable in rewrite theorem" thm
+ | _ => I) t [];
+ fun tvars_of t = fold_term_types
+ (fn _ => fold_atyps (fn TVar (v, _) => insert (op =) v
+ | TFree _ => bad_thm "Illegal free type variable in rewrite theorem" thm)) t [];
+ val lhs_vs = vars_of lhs;
+ val rhs_vs = vars_of rhs;
+ val lhs_tvs = tvars_of lhs;
+ val rhs_tvs = tvars_of lhs;
+ val _ = if null (subtract (op =) lhs_vs rhs_vs)
+ then ()
+ else bad_thm "Free variables on right hand side of rewrite theorems" thm
+ val _ = if null (subtract (op =) lhs_tvs rhs_tvs)
+ then ()
+ else bad_thm "Free type variables on right hand side of rewrite theorems" thm
+ in thm end;
+
+fun mk_rew thy thm =
+ let
+ val thms = (#mk o #mk_rews o snd o MetaSimplifier.rep_ss o Simplifier.simpset_of) thy thm;
+ in
+ map (check_rew thy) thms
+ end;
+
+
+(* making function theorems *)
+
fun typ_func thy = snd o dest_Const o fst o strip_comb o fst
o Logic.dest_equals o ObjectLogic.drop_judgment thy o Drule.plain_prop_of;
-val mk_rew =
- #mk o #mk_rews o snd o MetaSimplifier.rep_ss o Simplifier.simpset_of;
+val strict_functyp = ref true;
+
+fun dest_func thy = apfst dest_Const o strip_comb o Envir.beta_eta_contract
+ o fst o Logic.dest_equals o ObjectLogic.drop_judgment thy o Drule.plain_prop_of;
+
+fun mk_head thy thm =
+ ((CodegenConsts.norm_of_typ thy o fst o dest_func thy) thm, thm);
-val strict_functyp = ref true;
+fun check_func verbose thy thm = case try (dest_func thy) thm
+ of SOME (c_ty as (c, ty), args) =>
+ let
+ val _ =
+ if has_duplicates (op =)
+ ((fold o fold_aterms) (fn Var (v, _) => cons v | _ => I) args [])
+ then bad_thm "Repeated variables on left hand side of function equation" thm
+ else ()
+ val is_classop = (is_some o AxClass.class_of_param thy) c;
+ val const = CodegenConsts.norm_of_typ thy c_ty;
+ val ty_decl = CodegenConsts.disc_typ_of_const thy
+ (snd o CodegenConsts.typ_of_inst thy) const;
+ val string_of_typ = setmp show_sorts true (Sign.string_of_typ thy);
+ in if Sign.typ_equiv thy (ty_decl, ty)
+ then (const, thm)
+ else (if is_classop orelse (!strict_functyp andalso not
+ (Sign.typ_equiv thy (Type.strip_sorts ty_decl, Type.strip_sorts ty)))
+ then error else (if verbose then warning else K ()) #> K (const, thm))
+ ("Type\n" ^ string_of_typ ty
+ ^ "\nof function theorem\n"
+ ^ string_of_thm thm
+ ^ "\nis strictly less general than declared function type\n"
+ ^ string_of_typ ty_decl)
+ end
+ | NONE => bad_thm "Not a function equation" thm;
+
+fun check_typ_classop thy thm =
+ let
+ val (c_ty as (c, ty), _) = dest_func thy thm;
+ in case AxClass.class_of_param thy c
+ of SOME class => let
+ val const = CodegenConsts.norm_of_typ thy c_ty;
+ val ty_decl = CodegenConsts.disc_typ_of_const thy
+ (snd o CodegenConsts.typ_of_inst thy) const;
+ val string_of_typ = setmp show_sorts true (Sign.string_of_typ thy);
+ in if Sign.typ_equiv thy (ty_decl, ty)
+ then thm
+ else error
+ ("Type\n" ^ string_of_typ ty
+ ^ "\nof function theorem\n"
+ ^ string_of_thm thm
+ ^ "\nis strictly less general than declared function type\n"
+ ^ string_of_typ ty_decl)
+ end
+ | NONE => thm
+ end;
fun mk_func thy raw_thm =
- let
- fun dest_func thy = dest_Const o fst o strip_comb o Envir.beta_eta_contract
- o fst o Logic.dest_equals o ObjectLogic.drop_judgment thy o Drule.plain_prop_of;
- fun mk_head thm = case try (dest_func thy) thm
- of SOME (c_ty as (c, ty)) =>
- let
- val is_classop = (is_some o AxClass.class_of_param thy) c;
- val const = CodegenConsts.norm_of_typ thy c_ty;
- val ty_decl = CodegenConsts.disc_typ_of_const thy
- (snd o CodegenConsts.typ_of_inst thy) const;
- val string_of_typ = setmp show_sorts true (Sign.string_of_typ thy);
- in if Sign.typ_equiv thy (ty_decl, ty)
- then (const, thm)
- else ((if is_classop orelse !strict_functyp then error else warning)
- ("Type\n" ^ string_of_typ ty
- ^ "\nof function theorem\n"
- ^ string_of_thm thm
- ^ "\nis strictly less general than declared function type\n"
- ^ string_of_typ ty_decl); (const, thm))
- end
- | NONE => bad_thm "Not a function equation" thm;
- in
- mk_rew thy raw_thm
- |> map mk_head
- end;
+ mk_rew thy raw_thm
+ |> map (check_func true thy);
fun get_prim_def_funcs thy c =
let
@@ -178,9 +237,7 @@
fun add_drop_redundant thm thms =
let
-(* val _ = writeln "add_drop 01"; *)
val thy = Context.check_thy (Thm.theory_of_thm thm);
-(* val _ = writeln "add_drop 02"; *)
val pattern = (fst o Logic.dest_equals o Drule.plain_prop_of) thm;
fun matches thm' = if (curry (Pattern.matches thy) pattern o
fst o Logic.dest_equals o Drule.plain_prop_of) thm'
@@ -222,19 +279,24 @@
datatype preproc = Preproc of {
inlines: thm list,
+ inline_procs: (serial * (theory -> cterm list -> thm list)) list,
+ constrains: (serial * (theory -> term list -> (indexname * sort) list)) list,
preprocs: (serial * (theory -> thm list -> thm list)) list
};
-fun mk_preproc (inlines, preprocs) =
- Preproc { inlines = inlines, preprocs = preprocs };
-fun map_preproc f (Preproc { inlines, preprocs }) =
- mk_preproc (f (inlines, preprocs));
-fun merge_preproc (Preproc { inlines = inlines1, preprocs = preprocs1 },
- Preproc { inlines = inlines2, preprocs = preprocs2 }) =
+fun mk_preproc ((inlines, inline_procs), (constrains, preprocs)) =
+ Preproc { inlines = inlines, inline_procs = inline_procs, constrains = constrains, preprocs = preprocs };
+fun map_preproc f (Preproc { inlines, inline_procs, constrains, preprocs }) =
+ mk_preproc (f ((inlines, inline_procs), (constrains, preprocs)));
+fun merge_preproc (Preproc { inlines = inlines1, inline_procs = inline_procs1, constrains = constrains1 , preprocs = preprocs1 },
+ Preproc { inlines = inlines2, inline_procs = inline_procs2, constrains = constrains2 , preprocs = preprocs2 }) =
let
val (touched1, inlines) = merge_thms (inlines1, inlines2);
- val (touched2, preprocs) = merge_alist (op =) (K true) (preprocs1, preprocs2);
- in (touched1 orelse touched2, mk_preproc (inlines, preprocs)) end;
+ val (touched2, inline_procs) = merge_alist (op =) (K true) (inline_procs1, inline_procs2);
+ val (touched3, constrains) = merge_alist (op =) (K true) (constrains1, constrains2);
+ val (touched4, preprocs) = merge_alist (op =) (K true) (preprocs1, preprocs2);
+ in (touched1 orelse touched2 orelse touched3 orelse touched4,
+ mk_preproc ((inlines, inline_procs), (constrains, preprocs))) end;
fun join_func_thms (tabs as (tab1, tab2)) =
let
@@ -257,13 +319,13 @@
andalso gen_eq_set (eq_pair eq_string (eq_list (is_equal o Term.typ_ord))) (cs1, cs2);
fun merge_dtyps (tabs as (tab1, tab2)) =
let
- (*EXTEND: could be more clever with respect to constructors*)
val tycos1 = Symtab.keys tab1;
val tycos2 = Symtab.keys tab2;
val tycos' = filter (member eq_string tycos2) tycos1;
- val touched = gen_eq_set (eq_pair (op =) (eq_dtyp))
+ val touched = not (gen_eq_set (op =) (tycos1, tycos2) andalso
+ gen_eq_set (eq_pair (op =) (eq_dtyp))
(AList.make (the o Symtab.lookup tab1) tycos',
- AList.make (the o Symtab.lookup tab2) tycos');
+ AList.make (the o Symtab.lookup tab2) tycos'));
in (touched, Symtab.merge (K true) tabs) end;
datatype spec = Spec of {
@@ -301,7 +363,7 @@
val (touched_cs, spec) = merge_spec (spec1, spec2);
val touched = if touched' then NONE else touched_cs;
in (touched, mk_exec (preproc, spec)) end;
-val empty_exec = mk_exec (mk_preproc ([], []),
+val empty_exec = mk_exec (mk_preproc (([], []), ([], [])),
mk_spec ((Consttab.empty, Consttab.empty), Symtab.empty));
fun the_preproc (Exec { preproc = Preproc x, ...}) = x;
@@ -450,9 +512,9 @@
fun rewrite_func rewrites thm =
let
- val rewrite = Tactic.rewrite true rewrites;
- val (ct_eq, [ct_lhs, ct_rhs]) = (Drule.strip_comb o cprop_of) thm;
- val Const ("==", _) = term_of ct_eq;
+ val rewrite = Tactic.rewrite false rewrites;
+ val (ct_eq, [ct_lhs, ct_rhs]) = (Drule.strip_comb o Thm.cprop_of) thm;
+ val Const ("==", _) = Thm.term_of ct_eq;
val (ct_f, ct_args) = Drule.strip_comb ct_lhs;
val rhs' = rewrite ct_rhs;
val args' = map rewrite ct_args;
@@ -484,12 +546,12 @@
cons (Thm.ctyp_of thy (TVar (x_i, sort)), Thm.ctyp_of thy ty)) env [];
in map (Thm.instantiate (instT, [])) thms end;
-fun certify_const thy c thms =
+fun certify_const thy c c_thms =
let
fun cert (c', thm) = if CodegenConsts.eq_const (c, c')
then thm else bad_thm ("Wrong head of function equation,\nexpected constant "
^ CodegenConsts.string_of_const thy c) thm
- in (map cert o maps (mk_func thy)) thms end;
+ in map cert c_thms end;
fun mk_cos tyco vs cos =
let
@@ -589,7 +651,7 @@
fun add_funcl (c, lthms) thy =
let
val c' = CodegenConsts.norm thy c;
- val lthms' = certificate (certify_const thy c') lthms;
+ val lthms' = certificate thy (fn thy => certify_const thy c' o maps (mk_func thy)) lthms;
in
map_exec_purge (SOME [c]) (map_funcs (Consttab.map_default (c', (Susp.value [], []))
(add_lthms lthms'))) thy
@@ -601,7 +663,7 @@
val consts = map (CodegenConsts.norm_of_typ thy o dest_Const o fst) cs;
val add =
map_dtyps (Symtab.update_new (tyco,
- (vs_cos, certificate (certify_datatype thy tyco cs) lthms)))
+ (vs_cos, certificate thy (fn thy => certify_datatype thy tyco cs) lthms)))
#> map_dconstrs (fold (fn c => Consttab.update (c, tyco)) consts)
in map_exec_purge (SOME consts) add thy end;
@@ -616,52 +678,145 @@
in map_exec_purge (SOME consts) del thy end;
fun add_inline thm thy =
- map_exec_purge NONE (map_preproc (apfst (fold (insert eq_thm) (mk_rew thy thm)))) thy;
+ (map_exec_purge NONE o map_preproc o apfst o apfst) (fold (insert eq_thm) (mk_rew thy thm)) thy;
fun del_inline thm thy =
- map_exec_purge NONE (map_preproc (apfst (fold (remove eq_thm) (mk_rew thy thm)))) thy ;
+ (map_exec_purge NONE o map_preproc o apfst o apfst) (fold (remove eq_thm) (mk_rew thy thm)) thy ;
+
+fun add_inline_proc f =
+ (map_exec_purge NONE o map_preproc o apfst o apsnd) (cons (serial (), f));
+
+fun add_constrains f =
+ (map_exec_purge NONE o map_preproc o apsnd o apfst) (cons (serial (), f));
fun add_preproc f =
- map_exec_purge NONE (map_preproc (apsnd (cons (serial (), f))));
+ (map_exec_purge NONE o map_preproc o apsnd o apsnd) (cons (serial (), f));
+
+local
+
+fun gen_apply_constrain prep post const_typ thy fs x =
+ let
+ val ts = prep x;
+ val tvars = (fold o fold_aterms) Term.add_tvars ts [];
+ val consts = (fold o fold_aterms) (fn Const c => cons c | _ => I) ts [];
+ fun insts_of const_typ (c, ty) =
+ let
+ val ty_decl = const_typ (c, ty);
+ val env = Vartab.dest (Type.raw_match (ty_decl, ty) Vartab.empty);
+ val insts = map_filter
+ (fn (v, (sort, TVar (_, sort'))) =>
+ if Sorts.sort_le (Sign.classes_of thy) (sort, sort')
+ then NONE else SOME (v, sort)
+ | _ => NONE) env
+ in
+ insts
+ end
+ val const_insts = case const_typ
+ of NONE => []
+ | SOME const_typ => maps (insts_of const_typ) consts;
+ fun add_inst (v, sort') =
+ let
+ val sort = (the o AList.lookup (op =) tvars) v
+ in
+ AList.map_default (op =) (v, (sort, sort))
+ (apsnd (fn sort => Sorts.inter_sort (Sign.classes_of thy) (sort, sort')))
+ end;
+ val inst =
+ []
+ |> fold (fn f => fold add_inst (f thy ts)) fs
+ |> fold add_inst const_insts;
+ in
+ post thy inst x
+ end;
-fun getf_first [] _ = NONE
- | getf_first (f::fs) x = case f x
- of NONE => getf_first fs x
- | y as SOME x => y;
+val apply_constrain = gen_apply_constrain (maps
+ ((fn (args, rhs) => rhs :: (snd o strip_comb) args) o Logic.dest_equals o Thm.prop_of))
+ (fn thy => fn inst => map (check_typ_classop thy o Thm.instantiate (map (fn (v, (sort, sort')) =>
+ (Thm.ctyp_of thy (TVar (v, sort)), Thm.ctyp_of thy (TVar (v, sort')))
+ ) inst, []))) NONE;
+fun apply_constrain_cterm thy const_typ = gen_apply_constrain (single o Thm.term_of)
+ (fn thy => fn inst => pair inst o Thm.cterm_of thy o map_types
+ (TermSubst.instantiateT (map (fn (v, (sort, sort')) => ((v, sort), TVar (v, sort'))) inst)) o Thm.term_of) (SOME const_typ) thy;
+
+fun gen_apply_inline_proc prep post thy f x =
+ let
+ val cts = prep x;
+ val rews = map (check_rew thy) (f thy cts);
+ in post rews x end;
+
+val apply_inline_proc = gen_apply_inline_proc (maps
+ ((fn [args, rhs] => rhs :: (snd o Drule.strip_comb) args) o snd o Drule.strip_comb o Thm.cprop_of))
+ (fn rews => map (rewrite_func rews));
+val apply_inline_proc_cterm = gen_apply_inline_proc single
+ (Tactic.rewrite false);
-fun getf_first_list [] x = []
- | getf_first_list (f::fs) x = case f x
- of [] => getf_first_list fs x
- | xs => xs;
+fun apply_preproc thy f [] = []
+ | apply_preproc thy f (thms as (thm :: _)) =
+ let
+ val thms' = f thy thms;
+ val c = (CodegenConsts.norm_of_typ thy o fst o dest_func thy) thm;
+ in (certify_const thy c o map (mk_head thy)) thms' end;
+
+fun cmp_thms thy =
+ make_ord (fn (thm1, thm2) => not (Sign.typ_instance thy (typ_func thy thm1, typ_func thy thm2)));
+
+fun rhs_conv conv thm =
+ let
+ val thm' = (conv o snd o Drule.dest_equals o Thm.cprop_of) thm;
+ in Thm.transitive thm thm' end
+
+fun drop_classes thy inst thm =
+ let
+ val unconstr = map (fn (v, (_, sort')) =>
+ (Thm.ctyp_of thy o TVar) (v, sort')) inst;
+ val instmap = map (fn (v, (sort, _)) =>
+ pairself (Thm.ctyp_of thy o TVar) ((v, []), (v, sort))) inst;
+ in
+ thm
+ |> fold Thm.unconstrainT unconstr
+ |> Thm.instantiate (instmap, [])
+ |> Tactic.rule_by_tactic ((REPEAT o CHANGED o ALLGOALS o Tactic.resolve_tac) (AxClass.class_intros thy))
+ end;
+
+in
fun preprocess thy thms =
- let
- fun cmp_thms (thm1, thm2) =
- not (Sign.typ_instance thy (typ_func thy thm1, typ_func thy thm2));
- in
- thms
- |> map (rewrite_func ((#inlines o the_preproc o get_exec) thy))
- |> fold (fn (_, f) => f thy) ((#preprocs o the_preproc o get_exec) thy)
- |> map (rewrite_func ((#inlines o the_preproc o get_exec) thy))
- |> sort (make_ord cmp_thms)
- |> common_typ_funcs thy
- end;
+ thms
+ |> fold (fn (_, f) => apply_preproc thy f) ((#preprocs o the_preproc o get_exec) thy)
+ |> map (rewrite_func ((#inlines o the_preproc o get_exec) thy))
+ |> apply_constrain thy ((map snd o #constrains o the_preproc o get_exec) thy)
+ |> map (rewrite_func ((#inlines o the_preproc o get_exec) thy))
+ |> fold (fn (_, f) => apply_inline_proc thy f) ((#inline_procs o the_preproc o get_exec) thy)
+ |> map (snd o check_func false thy)
+ |> sort (cmp_thms thy)
+ |> common_typ_funcs thy;
-fun preprocess_cterm thy =
- Tactic.rewrite false ((#inlines o the_preproc o get_exec) thy);
+fun preprocess_cterm thy const_typ ct =
+ ct
+ |> apply_constrain_cterm thy const_typ ((map snd o #constrains o the_preproc o get_exec) thy)
+ |-> (fn inst =>
+ Thm.reflexive
+ #> fold (rhs_conv o Tactic.rewrite false o single) ((#inlines o the_preproc o get_exec) thy)
+ #> fold (fn (_, f) => rhs_conv (apply_inline_proc_cterm thy f)) ((#inline_procs o the_preproc o get_exec) thy)
+ #> (fn thm => (drop_classes thy inst thm, ((fn xs => nth xs 1) o snd o Drule.strip_comb o Thm.cprop_of) thm))
+ );
+
+end; (*local*)
fun these_funcs thy c =
let
- fun test_funcs c =
+ val funcs_1 =
Consttab.lookup ((the_funcs o get_exec) thy) c
|> Option.map (Susp.force o fst)
|> these
|> map (Thm.transfer thy);
- val test_defs = get_prim_def_funcs thy;
+ val funcs_2 = case funcs_1
+ of [] => get_prim_def_funcs thy c
+ | xs => xs;
fun drop_refl thy = filter_out (is_equal o Term.fast_term_ord o Logic.dest_equals
o ObjectLogic.drop_judgment thy o Drule.plain_prop_of);
in
- getf_first_list [test_funcs, test_defs] c
+ funcs_2
|> preprocess thy
|> drop_refl thy
end;