# HG changeset patch # User haftmann # Date 1159822865 -7200 # Node ID 6792583aa463db4280afd309395e42d66b2fcd68 # Parent a5343075bdc5d6c8e3f51c0253a525351168596e changed preprocessing framework diff -r a5343075bdc5 -r 6792583aa463 src/Pure/Tools/codegen_data.ML --- a/src/Pure/Tools/codegen_data.ML Mon Oct 02 23:01:04 2006 +0200 +++ b/src/Pure/Tools/codegen_data.ML Mon Oct 02 23:01:05 2006 +0200 @@ -5,8 +5,6 @@ Basic code generator data structures; abstract executable content of theory. *) -(* val _ = PolyML.Compiler.maxInlineSize := 0; *) - signature CODEGEN_DATA = sig type lthms = thm list Susp.T; @@ -21,6 +19,8 @@ val del_datatype: string -> theory -> theory val add_inline: thm -> theory -> theory val del_inline: thm -> theory -> theory + val add_inline_proc: (theory -> cterm list -> thm list) -> theory -> theory + val add_constrains: (theory -> term list -> (indexname * sort) list) -> theory -> theory val add_preproc: (theory -> thm list -> thm list) -> theory -> theory val these_funcs: theory -> CodegenConsts.const -> thm list val get_datatype: theory -> string @@ -31,10 +31,9 @@ val typ_func: theory -> thm -> typ val rewrite_func: thm list -> thm -> thm - val preprocess_cterm: theory -> cterm -> thm - val preprocess: theory -> thm list -> thm list + val preprocess_cterm: theory -> (string * typ -> typ) -> cterm -> thm * cterm - val debug: bool ref + val trace: bool ref val strict_functyp: bool ref end; @@ -55,8 +54,8 @@ (** diagnostics **) -val debug = ref false; -fun debug_msg f x = (if !debug then Output.tracing (f x) else (); x); +val trace = ref false; +fun tracing f x = (if !trace then Output.tracing (f x) else (); x); @@ -64,7 +63,6 @@ type lthms = thm list Susp.T; val eval_always = ref false; -val _ = eval_always := true; fun lazy f = if !eval_always then Susp.value (f ()) @@ -78,10 +76,12 @@ of SOME thms => (map (ProofContext.pretty_thm ctxt) o rev) thms | NONE => [Pretty.str "[...]"]; -fun certificate f r = +fun certificate thy f r = case Susp.peek r - of SOME thms => (Susp.value o f) thms - | NONE => lazy (fn () => (f o Susp.force) r); + of SOME thms => (Susp.value o f thy) thms + | NONE => let + val thy_ref = Theory.self_ref thy; + in lazy (fn () => (f (Theory.deref thy_ref) o Susp.force) r) end; fun merge' _ ([], []) = (false, []) | merge' _ ([], ys) = (true, ys) @@ -107,45 +107,104 @@ (** code theorems **) -(* making function theorems *) +(* making rewrite theorems *) fun bad_thm msg thm = error (msg ^ ": " ^ string_of_thm thm); +fun check_rew thy thm = + let + val (lhs, rhs) = (Logic.dest_equals o Thm.prop_of) thm; + fun vars_of t = fold_aterms + (fn Var (v, _) => insert (op =) v + | Free _ => bad_thm "Illegal free variable in rewrite theorem" thm + | _ => I) t []; + fun tvars_of t = fold_term_types + (fn _ => fold_atyps (fn TVar (v, _) => insert (op =) v + | TFree _ => bad_thm "Illegal free type variable in rewrite theorem" thm)) t []; + val lhs_vs = vars_of lhs; + val rhs_vs = vars_of rhs; + val lhs_tvs = tvars_of lhs; + val rhs_tvs = tvars_of lhs; + val _ = if null (subtract (op =) lhs_vs rhs_vs) + then () + else bad_thm "Free variables on right hand side of rewrite theorems" thm + val _ = if null (subtract (op =) lhs_tvs rhs_tvs) + then () + else bad_thm "Free type variables on right hand side of rewrite theorems" thm + in thm end; + +fun mk_rew thy thm = + let + val thms = (#mk o #mk_rews o snd o MetaSimplifier.rep_ss o Simplifier.simpset_of) thy thm; + in + map (check_rew thy) thms + end; + + +(* making function theorems *) + fun typ_func thy = snd o dest_Const o fst o strip_comb o fst o Logic.dest_equals o ObjectLogic.drop_judgment thy o Drule.plain_prop_of; -val mk_rew = - #mk o #mk_rews o snd o MetaSimplifier.rep_ss o Simplifier.simpset_of; +val strict_functyp = ref true; + +fun dest_func thy = apfst dest_Const o strip_comb o Envir.beta_eta_contract + o fst o Logic.dest_equals o ObjectLogic.drop_judgment thy o Drule.plain_prop_of; + +fun mk_head thy thm = + ((CodegenConsts.norm_of_typ thy o fst o dest_func thy) thm, thm); -val strict_functyp = ref true; +fun check_func verbose thy thm = case try (dest_func thy) thm + of SOME (c_ty as (c, ty), args) => + let + val _ = + if has_duplicates (op =) + ((fold o fold_aterms) (fn Var (v, _) => cons v | _ => I) args []) + then bad_thm "Repeated variables on left hand side of function equation" thm + else () + val is_classop = (is_some o AxClass.class_of_param thy) c; + val const = CodegenConsts.norm_of_typ thy c_ty; + val ty_decl = CodegenConsts.disc_typ_of_const thy + (snd o CodegenConsts.typ_of_inst thy) const; + val string_of_typ = setmp show_sorts true (Sign.string_of_typ thy); + in if Sign.typ_equiv thy (ty_decl, ty) + then (const, thm) + else (if is_classop orelse (!strict_functyp andalso not + (Sign.typ_equiv thy (Type.strip_sorts ty_decl, Type.strip_sorts ty))) + then error else (if verbose then warning else K ()) #> K (const, thm)) + ("Type\n" ^ string_of_typ ty + ^ "\nof function theorem\n" + ^ string_of_thm thm + ^ "\nis strictly less general than declared function type\n" + ^ string_of_typ ty_decl) + end + | NONE => bad_thm "Not a function equation" thm; + +fun check_typ_classop thy thm = + let + val (c_ty as (c, ty), _) = dest_func thy thm; + in case AxClass.class_of_param thy c + of SOME class => let + val const = CodegenConsts.norm_of_typ thy c_ty; + val ty_decl = CodegenConsts.disc_typ_of_const thy + (snd o CodegenConsts.typ_of_inst thy) const; + val string_of_typ = setmp show_sorts true (Sign.string_of_typ thy); + in if Sign.typ_equiv thy (ty_decl, ty) + then thm + else error + ("Type\n" ^ string_of_typ ty + ^ "\nof function theorem\n" + ^ string_of_thm thm + ^ "\nis strictly less general than declared function type\n" + ^ string_of_typ ty_decl) + end + | NONE => thm + end; fun mk_func thy raw_thm = - let - fun dest_func thy = dest_Const o fst o strip_comb o Envir.beta_eta_contract - o fst o Logic.dest_equals o ObjectLogic.drop_judgment thy o Drule.plain_prop_of; - fun mk_head thm = case try (dest_func thy) thm - of SOME (c_ty as (c, ty)) => - let - val is_classop = (is_some o AxClass.class_of_param thy) c; - val const = CodegenConsts.norm_of_typ thy c_ty; - val ty_decl = CodegenConsts.disc_typ_of_const thy - (snd o CodegenConsts.typ_of_inst thy) const; - val string_of_typ = setmp show_sorts true (Sign.string_of_typ thy); - in if Sign.typ_equiv thy (ty_decl, ty) - then (const, thm) - else ((if is_classop orelse !strict_functyp then error else warning) - ("Type\n" ^ string_of_typ ty - ^ "\nof function theorem\n" - ^ string_of_thm thm - ^ "\nis strictly less general than declared function type\n" - ^ string_of_typ ty_decl); (const, thm)) - end - | NONE => bad_thm "Not a function equation" thm; - in - mk_rew thy raw_thm - |> map mk_head - end; + mk_rew thy raw_thm + |> map (check_func true thy); fun get_prim_def_funcs thy c = let @@ -178,9 +237,7 @@ fun add_drop_redundant thm thms = let -(* val _ = writeln "add_drop 01"; *) val thy = Context.check_thy (Thm.theory_of_thm thm); -(* val _ = writeln "add_drop 02"; *) val pattern = (fst o Logic.dest_equals o Drule.plain_prop_of) thm; fun matches thm' = if (curry (Pattern.matches thy) pattern o fst o Logic.dest_equals o Drule.plain_prop_of) thm' @@ -222,19 +279,24 @@ datatype preproc = Preproc of { inlines: thm list, + inline_procs: (serial * (theory -> cterm list -> thm list)) list, + constrains: (serial * (theory -> term list -> (indexname * sort) list)) list, preprocs: (serial * (theory -> thm list -> thm list)) list }; -fun mk_preproc (inlines, preprocs) = - Preproc { inlines = inlines, preprocs = preprocs }; -fun map_preproc f (Preproc { inlines, preprocs }) = - mk_preproc (f (inlines, preprocs)); -fun merge_preproc (Preproc { inlines = inlines1, preprocs = preprocs1 }, - Preproc { inlines = inlines2, preprocs = preprocs2 }) = +fun mk_preproc ((inlines, inline_procs), (constrains, preprocs)) = + Preproc { inlines = inlines, inline_procs = inline_procs, constrains = constrains, preprocs = preprocs }; +fun map_preproc f (Preproc { inlines, inline_procs, constrains, preprocs }) = + mk_preproc (f ((inlines, inline_procs), (constrains, preprocs))); +fun merge_preproc (Preproc { inlines = inlines1, inline_procs = inline_procs1, constrains = constrains1 , preprocs = preprocs1 }, + Preproc { inlines = inlines2, inline_procs = inline_procs2, constrains = constrains2 , preprocs = preprocs2 }) = let val (touched1, inlines) = merge_thms (inlines1, inlines2); - val (touched2, preprocs) = merge_alist (op =) (K true) (preprocs1, preprocs2); - in (touched1 orelse touched2, mk_preproc (inlines, preprocs)) end; + val (touched2, inline_procs) = merge_alist (op =) (K true) (inline_procs1, inline_procs2); + val (touched3, constrains) = merge_alist (op =) (K true) (constrains1, constrains2); + val (touched4, preprocs) = merge_alist (op =) (K true) (preprocs1, preprocs2); + in (touched1 orelse touched2 orelse touched3 orelse touched4, + mk_preproc ((inlines, inline_procs), (constrains, preprocs))) end; fun join_func_thms (tabs as (tab1, tab2)) = let @@ -257,13 +319,13 @@ andalso gen_eq_set (eq_pair eq_string (eq_list (is_equal o Term.typ_ord))) (cs1, cs2); fun merge_dtyps (tabs as (tab1, tab2)) = let - (*EXTEND: could be more clever with respect to constructors*) val tycos1 = Symtab.keys tab1; val tycos2 = Symtab.keys tab2; val tycos' = filter (member eq_string tycos2) tycos1; - val touched = gen_eq_set (eq_pair (op =) (eq_dtyp)) + val touched = not (gen_eq_set (op =) (tycos1, tycos2) andalso + gen_eq_set (eq_pair (op =) (eq_dtyp)) (AList.make (the o Symtab.lookup tab1) tycos', - AList.make (the o Symtab.lookup tab2) tycos'); + AList.make (the o Symtab.lookup tab2) tycos')); in (touched, Symtab.merge (K true) tabs) end; datatype spec = Spec of { @@ -301,7 +363,7 @@ val (touched_cs, spec) = merge_spec (spec1, spec2); val touched = if touched' then NONE else touched_cs; in (touched, mk_exec (preproc, spec)) end; -val empty_exec = mk_exec (mk_preproc ([], []), +val empty_exec = mk_exec (mk_preproc (([], []), ([], [])), mk_spec ((Consttab.empty, Consttab.empty), Symtab.empty)); fun the_preproc (Exec { preproc = Preproc x, ...}) = x; @@ -450,9 +512,9 @@ fun rewrite_func rewrites thm = let - val rewrite = Tactic.rewrite true rewrites; - val (ct_eq, [ct_lhs, ct_rhs]) = (Drule.strip_comb o cprop_of) thm; - val Const ("==", _) = term_of ct_eq; + val rewrite = Tactic.rewrite false rewrites; + val (ct_eq, [ct_lhs, ct_rhs]) = (Drule.strip_comb o Thm.cprop_of) thm; + val Const ("==", _) = Thm.term_of ct_eq; val (ct_f, ct_args) = Drule.strip_comb ct_lhs; val rhs' = rewrite ct_rhs; val args' = map rewrite ct_args; @@ -484,12 +546,12 @@ cons (Thm.ctyp_of thy (TVar (x_i, sort)), Thm.ctyp_of thy ty)) env []; in map (Thm.instantiate (instT, [])) thms end; -fun certify_const thy c thms = +fun certify_const thy c c_thms = let fun cert (c', thm) = if CodegenConsts.eq_const (c, c') then thm else bad_thm ("Wrong head of function equation,\nexpected constant " ^ CodegenConsts.string_of_const thy c) thm - in (map cert o maps (mk_func thy)) thms end; + in map cert c_thms end; fun mk_cos tyco vs cos = let @@ -589,7 +651,7 @@ fun add_funcl (c, lthms) thy = let val c' = CodegenConsts.norm thy c; - val lthms' = certificate (certify_const thy c') lthms; + val lthms' = certificate thy (fn thy => certify_const thy c' o maps (mk_func thy)) lthms; in map_exec_purge (SOME [c]) (map_funcs (Consttab.map_default (c', (Susp.value [], [])) (add_lthms lthms'))) thy @@ -601,7 +663,7 @@ val consts = map (CodegenConsts.norm_of_typ thy o dest_Const o fst) cs; val add = map_dtyps (Symtab.update_new (tyco, - (vs_cos, certificate (certify_datatype thy tyco cs) lthms))) + (vs_cos, certificate thy (fn thy => certify_datatype thy tyco cs) lthms))) #> map_dconstrs (fold (fn c => Consttab.update (c, tyco)) consts) in map_exec_purge (SOME consts) add thy end; @@ -616,52 +678,145 @@ in map_exec_purge (SOME consts) del thy end; fun add_inline thm thy = - map_exec_purge NONE (map_preproc (apfst (fold (insert eq_thm) (mk_rew thy thm)))) thy; + (map_exec_purge NONE o map_preproc o apfst o apfst) (fold (insert eq_thm) (mk_rew thy thm)) thy; fun del_inline thm thy = - map_exec_purge NONE (map_preproc (apfst (fold (remove eq_thm) (mk_rew thy thm)))) thy ; + (map_exec_purge NONE o map_preproc o apfst o apfst) (fold (remove eq_thm) (mk_rew thy thm)) thy ; + +fun add_inline_proc f = + (map_exec_purge NONE o map_preproc o apfst o apsnd) (cons (serial (), f)); + +fun add_constrains f = + (map_exec_purge NONE o map_preproc o apsnd o apfst) (cons (serial (), f)); fun add_preproc f = - map_exec_purge NONE (map_preproc (apsnd (cons (serial (), f)))); + (map_exec_purge NONE o map_preproc o apsnd o apsnd) (cons (serial (), f)); + +local + +fun gen_apply_constrain prep post const_typ thy fs x = + let + val ts = prep x; + val tvars = (fold o fold_aterms) Term.add_tvars ts []; + val consts = (fold o fold_aterms) (fn Const c => cons c | _ => I) ts []; + fun insts_of const_typ (c, ty) = + let + val ty_decl = const_typ (c, ty); + val env = Vartab.dest (Type.raw_match (ty_decl, ty) Vartab.empty); + val insts = map_filter + (fn (v, (sort, TVar (_, sort'))) => + if Sorts.sort_le (Sign.classes_of thy) (sort, sort') + then NONE else SOME (v, sort) + | _ => NONE) env + in + insts + end + val const_insts = case const_typ + of NONE => [] + | SOME const_typ => maps (insts_of const_typ) consts; + fun add_inst (v, sort') = + let + val sort = (the o AList.lookup (op =) tvars) v + in + AList.map_default (op =) (v, (sort, sort)) + (apsnd (fn sort => Sorts.inter_sort (Sign.classes_of thy) (sort, sort'))) + end; + val inst = + [] + |> fold (fn f => fold add_inst (f thy ts)) fs + |> fold add_inst const_insts; + in + post thy inst x + end; -fun getf_first [] _ = NONE - | getf_first (f::fs) x = case f x - of NONE => getf_first fs x - | y as SOME x => y; +val apply_constrain = gen_apply_constrain (maps + ((fn (args, rhs) => rhs :: (snd o strip_comb) args) o Logic.dest_equals o Thm.prop_of)) + (fn thy => fn inst => map (check_typ_classop thy o Thm.instantiate (map (fn (v, (sort, sort')) => + (Thm.ctyp_of thy (TVar (v, sort)), Thm.ctyp_of thy (TVar (v, sort'))) + ) inst, []))) NONE; +fun apply_constrain_cterm thy const_typ = gen_apply_constrain (single o Thm.term_of) + (fn thy => fn inst => pair inst o Thm.cterm_of thy o map_types + (TermSubst.instantiateT (map (fn (v, (sort, sort')) => ((v, sort), TVar (v, sort'))) inst)) o Thm.term_of) (SOME const_typ) thy; + +fun gen_apply_inline_proc prep post thy f x = + let + val cts = prep x; + val rews = map (check_rew thy) (f thy cts); + in post rews x end; + +val apply_inline_proc = gen_apply_inline_proc (maps + ((fn [args, rhs] => rhs :: (snd o Drule.strip_comb) args) o snd o Drule.strip_comb o Thm.cprop_of)) + (fn rews => map (rewrite_func rews)); +val apply_inline_proc_cterm = gen_apply_inline_proc single + (Tactic.rewrite false); -fun getf_first_list [] x = [] - | getf_first_list (f::fs) x = case f x - of [] => getf_first_list fs x - | xs => xs; +fun apply_preproc thy f [] = [] + | apply_preproc thy f (thms as (thm :: _)) = + let + val thms' = f thy thms; + val c = (CodegenConsts.norm_of_typ thy o fst o dest_func thy) thm; + in (certify_const thy c o map (mk_head thy)) thms' end; + +fun cmp_thms thy = + make_ord (fn (thm1, thm2) => not (Sign.typ_instance thy (typ_func thy thm1, typ_func thy thm2))); + +fun rhs_conv conv thm = + let + val thm' = (conv o snd o Drule.dest_equals o Thm.cprop_of) thm; + in Thm.transitive thm thm' end + +fun drop_classes thy inst thm = + let + val unconstr = map (fn (v, (_, sort')) => + (Thm.ctyp_of thy o TVar) (v, sort')) inst; + val instmap = map (fn (v, (sort, _)) => + pairself (Thm.ctyp_of thy o TVar) ((v, []), (v, sort))) inst; + in + thm + |> fold Thm.unconstrainT unconstr + |> Thm.instantiate (instmap, []) + |> Tactic.rule_by_tactic ((REPEAT o CHANGED o ALLGOALS o Tactic.resolve_tac) (AxClass.class_intros thy)) + end; + +in fun preprocess thy thms = - let - fun cmp_thms (thm1, thm2) = - not (Sign.typ_instance thy (typ_func thy thm1, typ_func thy thm2)); - in - thms - |> map (rewrite_func ((#inlines o the_preproc o get_exec) thy)) - |> fold (fn (_, f) => f thy) ((#preprocs o the_preproc o get_exec) thy) - |> map (rewrite_func ((#inlines o the_preproc o get_exec) thy)) - |> sort (make_ord cmp_thms) - |> common_typ_funcs thy - end; + thms + |> fold (fn (_, f) => apply_preproc thy f) ((#preprocs o the_preproc o get_exec) thy) + |> map (rewrite_func ((#inlines o the_preproc o get_exec) thy)) + |> apply_constrain thy ((map snd o #constrains o the_preproc o get_exec) thy) + |> map (rewrite_func ((#inlines o the_preproc o get_exec) thy)) + |> fold (fn (_, f) => apply_inline_proc thy f) ((#inline_procs o the_preproc o get_exec) thy) + |> map (snd o check_func false thy) + |> sort (cmp_thms thy) + |> common_typ_funcs thy; -fun preprocess_cterm thy = - Tactic.rewrite false ((#inlines o the_preproc o get_exec) thy); +fun preprocess_cterm thy const_typ ct = + ct + |> apply_constrain_cterm thy const_typ ((map snd o #constrains o the_preproc o get_exec) thy) + |-> (fn inst => + Thm.reflexive + #> fold (rhs_conv o Tactic.rewrite false o single) ((#inlines o the_preproc o get_exec) thy) + #> fold (fn (_, f) => rhs_conv (apply_inline_proc_cterm thy f)) ((#inline_procs o the_preproc o get_exec) thy) + #> (fn thm => (drop_classes thy inst thm, ((fn xs => nth xs 1) o snd o Drule.strip_comb o Thm.cprop_of) thm)) + ); + +end; (*local*) fun these_funcs thy c = let - fun test_funcs c = + val funcs_1 = Consttab.lookup ((the_funcs o get_exec) thy) c |> Option.map (Susp.force o fst) |> these |> map (Thm.transfer thy); - val test_defs = get_prim_def_funcs thy; + val funcs_2 = case funcs_1 + of [] => get_prim_def_funcs thy c + | xs => xs; fun drop_refl thy = filter_out (is_equal o Term.fast_term_ord o Logic.dest_equals o ObjectLogic.drop_judgment thy o Drule.plain_prop_of); in - getf_first_list [test_funcs, test_defs] c + funcs_2 |> preprocess thy |> drop_refl thy end;