# HG changeset patch # User haftmann # Date 1168327908 -3600 # Node ID 8e19bad4125fc8406196ca373d6d389dff878e52 # Parent 979671292fbecc8531d837e42303092e7735fc92 moved a lot to codegen_func.ML diff -r 979671292fbe -r 8e19bad4125f src/Pure/Tools/codegen_data.ML --- a/src/Pure/Tools/codegen_data.ML Tue Jan 09 08:31:47 2007 +0100 +++ b/src/Pure/Tools/codegen_data.ML Tue Jan 09 08:31:48 2007 +0100 @@ -29,10 +29,8 @@ val print_thms: theory -> unit - val typ_func: theory -> thm -> typ val typ_funcs: theory -> CodegenConsts.const * thm list -> typ - val rewrite_func: thm list -> thm -> thm - val preprocess_cterm: theory -> cterm -> thm + val preprocess_cterm: cterm -> thm val trace: bool ref end; @@ -107,144 +105,6 @@ (** code theorems **) -(* making rewrite theorems *) - -fun bad_thm msg thm = - error (msg ^ ": " ^ string_of_thm thm); - -fun check_rew thy thm = - let - val (lhs, rhs) = (Logic.dest_equals o Thm.prop_of) thm; - fun vars_of t = fold_aterms - (fn Var (v, _) => insert (op =) v - | Free _ => bad_thm "Illegal free variable in rewrite theorem" thm - | _ => I) t []; - fun tvars_of t = fold_term_types - (fn _ => fold_atyps (fn TVar (v, _) => insert (op =) v - | TFree _ => bad_thm "Illegal free type variable in rewrite theorem" thm)) t []; - val lhs_vs = vars_of lhs; - val rhs_vs = vars_of rhs; - val lhs_tvs = tvars_of lhs; - val rhs_tvs = tvars_of lhs; - val _ = if null (subtract (op =) lhs_vs rhs_vs) - then () - else bad_thm "Free variables on right hand side of rewrite theorems" thm - val _ = if null (subtract (op =) lhs_tvs rhs_tvs) - then () - else bad_thm "Free type variables on right hand side of rewrite theorems" thm - in thm end; - -fun mk_rew thy thm = - let - val thms = (#mk o #mk_rews o snd o MetaSimplifier.rep_ss o Simplifier.simpset_of) thy thm; - in - map (check_rew thy) thms - end; - - -(* making function theorems *) - -fun typ_func thy = snd o dest_Const o fst o strip_comb - o fst o Logic.dest_equals o ObjectLogic.drop_judgment thy o Drule.plain_prop_of; - -val strict_functyp = ref true; - -fun dest_func thy = apfst dest_Const o strip_comb - o fst o Logic.dest_equals o ObjectLogic.drop_judgment thy o Drule.plain_prop_of - o Drule.fconv_rule Drule.beta_eta_conversion; - -fun mk_head thy thm = - ((CodegenConsts.norm_of_typ thy o fst o dest_func thy) thm, thm); - -fun check_func thy thm = case try (dest_func thy) thm - of SOME (c_ty as (c, ty), args) => - let - val _ = - if has_duplicates (op =) - ((fold o fold_aterms) (fn Var (v, _) => cons v - | _ => I - ) args []) - then bad_thm "Repeated variables on left hand side of function equation" thm - else () - fun no_abs (Abs _) = bad_thm "Abstraction on left hand side of function equation" thm - | no_abs (t1 $ t2) = (no_abs t1; no_abs t2) - | no_abs _ = (); - val _ = map no_abs args; - val is_classop = (is_some o AxClass.class_of_param thy) c; - val const = CodegenConsts.norm_of_typ thy c_ty; - val ty_decl = CodegenConsts.disc_typ_of_const thy - (snd o CodegenConsts.typ_of_inst thy) const; - val string_of_typ = setmp show_sorts true (Sign.string_of_typ thy); - in if Sign.typ_equiv thy (ty_decl, ty) - then SOME (const, thm) - else (if is_classop - then if !strict_functyp - then error - else warning #> K NONE - else if Sign.typ_equiv thy (Type.strip_sorts ty_decl, Type.strip_sorts ty) - then warning #> (K o SOME) (const, thm) - else if !strict_functyp - then error - else warning #> K NONE) - ("Type\n" ^ string_of_typ ty - ^ "\nof function theorem\n" - ^ string_of_thm thm - ^ "\nis strictly less general than declared function type\n" - ^ string_of_typ ty_decl) - end - | NONE => bad_thm "Not a function equation" thm; - -fun check_typ_classop thy thm = - let - val (c_ty as (c, ty), _) = dest_func thy thm; - in case AxClass.class_of_param thy c - of SOME class => let - val const = CodegenConsts.norm_of_typ thy c_ty; - val ty_decl = CodegenConsts.disc_typ_of_const thy - (snd o CodegenConsts.typ_of_inst thy) const; - val string_of_typ = setmp show_sorts true (Sign.string_of_typ thy); - in if Sign.typ_equiv thy (ty_decl, ty) - then thm - else error - ("Type\n" ^ string_of_typ ty - ^ "\nof function theorem\n" - ^ string_of_thm thm - ^ "\nis strictly less general than declared function type\n" - ^ string_of_typ ty_decl) - end - | NONE => thm - end; - -fun mk_func thy raw_thm = - mk_rew thy raw_thm - |> map_filter (check_func thy); - -fun get_prim_def_funcs thy c = - let - fun constrain thm0 thm = case AxClass.class_of_param thy (fst c) - of SOME _ => - let - val ty_decl = CodegenConsts.disc_typ_of_classop thy c; - val max = maxidx_of_typ ty_decl + 1; - val thm = Thm.incr_indexes max thm; - val ty = typ_func thy thm; - val (env, _) = Sign.typ_unify thy (ty_decl, ty) (Vartab.empty, max); - val instT = Vartab.fold (fn (x_i, (sort, ty)) => - cons (Thm.ctyp_of thy (TVar (x_i, sort)), Thm.ctyp_of thy ty)) env []; - in Thm.instantiate (instT, []) thm end - | NONE => thm - in case CodegenConsts.find_def thy c - of SOME ((_, thm), _) => - thm - |> Thm.transfer thy - |> try (map snd o mk_func thy) - |> these - |> map (constrain thm) - |> map (CodegenFunc.expand_eta thy ~1) - | NONE => [] - end; - - (* pairs of (selected, deleted) function theorems *) type sdthms = thm list Susp.T * thm list; @@ -529,31 +389,18 @@ (** theorem transformation and certification **) -fun rewrite_func rewrites thm = - let - val rewrite = MetaSimplifier.rewrite false rewrites; - val (ct_eq, [ct_lhs, ct_rhs]) = (Drule.strip_comb o Thm.cprop_of) thm; - val Const ("==", _) = Thm.term_of ct_eq; - val (ct_f, ct_args) = Drule.strip_comb ct_lhs; - val rhs' = rewrite ct_rhs; - val args' = map rewrite ct_args; - val lhs' = Thm.symmetric (fold (fn th1 => fn th2 => Thm.combination th2 th1) - args' (Thm.reflexive ct_f)); - in - Thm.transitive (Thm.transitive lhs' thm) rhs' - end handle Bind => raise ERROR "rewrite_func" - -fun common_typ_funcs thy [] = [] - | common_typ_funcs thy [thm] = [thm] - | common_typ_funcs thy thms = +fun common_typ_funcs [] = [] + | common_typ_funcs [thm] = [thm] + | common_typ_funcs thms = let + val thy = Thm.theory_of_thm (hd thms) fun incr_thm thm max = let val thm' = incr_indexes max thm; val max' = Thm.maxidx_of thm' + 1; in (thm', max') end; val (thms', maxidx) = fold_map incr_thm thms 0; - val (ty1::tys) = map (typ_func thy) thms'; + val (ty1::tys) = map CodegenFunc.typ_func thms'; fun unify ty env = Sign.typ_unify thy (ty1, ty) env handle Type.TUNIFY => error ("Type unificaton failed, while unifying function equations\n" @@ -568,8 +415,8 @@ fun certify_const thy c c_thms = let fun cert (c', thm) = if CodegenConsts.eq_const (c, c') - then thm else bad_thm ("Wrong head of function equation,\nexpected constant " - ^ CodegenConsts.string_of_const thy c) thm + then thm else error ("Wrong head of function equation,\nexpected constant " + ^ CodegenConsts.string_of_const thy c ^ "\n" ^ string_of_thm thm) in map cert c_thms end; fun mk_cos tyco vs cos = @@ -647,9 +494,9 @@ (** interfaces **) -fun add_func thm thy = +fun gen_add_func mk_func thm thy = let - val thms = mk_func thy thm; + val thms = mk_func thm; val cs = map fst thms; in map_exec_purge (SOME cs) (map_funcs @@ -657,11 +504,12 @@ (c, (Susp.value [], [])) (add_thm thm)) thms)) thy end; -fun add_func_legacy thm = setmp strict_functyp false (add_func thm); +val add_func = gen_add_func CodegenFunc.mk_func; +val add_func_legacy = gen_add_func CodegenFunc.legacy_mk_func; fun del_func thm thy = let - val thms = mk_func thy thm; + val thms = CodegenFunc.mk_func thm; val cs = map fst thms; in map_exec_purge (SOME cs) (map_funcs @@ -672,7 +520,7 @@ fun add_funcl (c, lthms) thy = let val c' = CodegenConsts.norm thy c; - val lthms' = certificate thy (fn thy => certify_const thy c' o maps (mk_func thy)) lthms; + val lthms' = certificate thy (fn thy => certify_const thy c' o maps (CodegenFunc.mk_func)) lthms; in map_exec_purge (SOME [c]) (map_funcs (Consttab.map_default (c', (Susp.value [], [])) (add_lthms lthms'))) thy @@ -699,10 +547,10 @@ in map_exec_purge (SOME consts) del thy end; fun add_inline thm thy = - (map_exec_purge NONE o map_preproc o apfst o apfst) (fold (insert eq_thm) (mk_rew thy thm)) thy; + (map_exec_purge NONE o map_preproc o apfst o apfst) (fold (insert eq_thm) (CodegenFunc.mk_rew thm)) thy; fun del_inline thm thy = - (map_exec_purge NONE o map_preproc o apfst o apfst) (fold (remove eq_thm) (mk_rew thy thm)) thy ; + (map_exec_purge NONE o map_preproc o apfst o apfst) (fold (remove eq_thm) (CodegenFunc.mk_rew thm)) thy ; fun add_inline_proc f = (map_exec_purge NONE o map_preproc o apfst o apsnd) (cons (serial (), f)); @@ -715,12 +563,12 @@ fun gen_apply_inline_proc prep post thy f x = let val cts = prep x; - val rews = map (check_rew thy) (f thy cts); + val rews = map CodegenFunc.check_rew (f thy cts); in post rews x end; val apply_inline_proc = gen_apply_inline_proc (maps ((fn [args, rhs] => rhs :: (snd o Drule.strip_comb) args) o snd o Drule.strip_comb o Thm.cprop_of)) - (fn rews => map (rewrite_func rews)); + (fn rews => map (CodegenFunc.rewrite_func rews)); val apply_inline_proc_cterm = gen_apply_inline_proc single (MetaSimplifier.rewrite false); @@ -728,11 +576,11 @@ | apply_preproc thy f (thms as (thm :: _)) = let val thms' = f thy thms; - val c = (CodegenConsts.norm_of_typ thy o fst o dest_func thy) thm; - in (certify_const thy c o map (mk_head thy)) thms' end; + val c = (CodegenConsts.norm_of_typ thy o fst o CodegenFunc.dest_func) thm; + in (certify_const thy c o map CodegenFunc.mk_head) thms' end; fun cmp_thms thy = - make_ord (fn (thm1, thm2) => not (Sign.typ_instance thy (typ_func thy thm1, typ_func thy thm2))); + make_ord (fn (thm1, thm2) => not (Sign.typ_instance thy (CodegenFunc.typ_func thm1, CodegenFunc.typ_func thm2))); fun rhs_conv conv thm = let @@ -744,19 +592,23 @@ fun preprocess thy thms = thms |> fold (fn (_, f) => apply_preproc thy f) ((#preprocs o the_preproc o get_exec) thy) - |> map (rewrite_func ((#inlines o the_preproc o get_exec) thy)) + |> map (CodegenFunc.rewrite_func ((#inlines o the_preproc o get_exec) thy)) |> fold (fn (_, f) => apply_inline_proc thy f) ((#inline_procs o the_preproc o get_exec) thy) (*FIXME - must check: rewrite rule, function equation, proper constant |> map (snd o check_func false thy) *) |> sort (cmp_thms thy) - |> common_typ_funcs thy; + |> common_typ_funcs; -fun preprocess_cterm thy ct = - ct - |> Thm.reflexive - |> fold (rhs_conv o MetaSimplifier.rewrite false o single) - ((#inlines o the_preproc o get_exec) thy) - |> fold (fn (_, f) => rhs_conv (apply_inline_proc_cterm thy f)) - ((#inline_procs o the_preproc o get_exec) thy) +fun preprocess_cterm ct = + let + val thy = Thm.theory_of_cterm ct + in + ct + |> Thm.reflexive + |> fold (rhs_conv o MetaSimplifier.rewrite false o single) + ((#inlines o the_preproc o get_exec) thy) + |> fold (fn (_, f) => rhs_conv (apply_inline_proc_cterm thy f)) + ((#inline_procs o the_preproc o get_exec) thy) + end; end; (*local*) @@ -768,7 +620,7 @@ |> these |> map (Thm.transfer thy); val funcs_2 = case funcs_1 - of [] => get_prim_def_funcs thy c + of [] => CodegenFunc.get_prim_def_funcs thy c | xs => xs; fun drop_refl thy = filter_out (is_equal o Term.fast_term_ord o Logic.dest_equals o ObjectLogic.drop_judgment thy o Drule.plain_prop_of); @@ -789,9 +641,9 @@ fun typ_funcs thy (c as (name, _), []) = (case AxClass.class_of_param thy name of SOME class => CodegenConsts.disc_typ_of_classop thy c | NONE => (case Option.map (Susp.force o fst) (Consttab.lookup ((the_funcs o get_exec) thy) c) - of SOME [eq] => typ_func thy eq + of SOME [eq] => CodegenFunc.typ_func eq | _ => Sign.the_const_type thy name)) - | typ_funcs thy (_, eq :: _) = typ_func thy eq; + | typ_funcs thy (_, eq :: _) = CodegenFunc.typ_func eq; (** code attributes **) diff -r 979671292fbe -r 8e19bad4125f src/Pure/Tools/codegen_func.ML --- a/src/Pure/Tools/codegen_func.ML Tue Jan 09 08:31:47 2007 +0100 +++ b/src/Pure/Tools/codegen_func.ML Tue Jan 09 08:31:48 2007 +0100 @@ -2,43 +2,152 @@ ID: $Id$ Author: Florian Haftmann, TU Muenchen -Handling defining equations ("func"s) for code generator framework +Handling defining equations ("func"s) for code generator framework. *) -(* FIXME move various stuff here *) - signature CODEGEN_FUNC = sig - val expand_eta: theory -> int -> thm -> thm + val check_rew: thm -> thm + val mk_rew: thm -> thm list + val check_func: thm -> (CodegenConsts.const * thm) option + val mk_func: thm -> (CodegenConsts.const * thm) list + val dest_func: thm -> (string * typ) * term list + val mk_head: thm -> CodegenConsts.const * thm + val typ_func: thm -> typ + val legacy_mk_func: thm -> (CodegenConsts.const * thm) list + val expand_eta: int -> thm -> thm + val rewrite_func: thm list -> thm -> thm + val get_prim_def_funcs: theory -> string * typ list -> thm list end; structure CodegenFunc : CODEGEN_FUNC = struct -(* FIXME get rid of this code duplication *) -val purify_name = +fun lift_thm_thy f thm = f (Thm.theory_of_thm thm) thm; + +fun bad_thm msg thm = + error (msg ^ ": " ^ string_of_thm thm); + + +(* making rewrite theorems *) + +fun check_rew thm = + let + val thy = Thm.theory_of_thm thm; + val (lhs, rhs) = (Logic.dest_equals o Thm.prop_of) thm; + fun vars_of t = fold_aterms + (fn Var (v, _) => insert (op =) v + | Free _ => bad_thm "Illegal free variable in rewrite theorem" thm + | _ => I) t []; + fun tvars_of t = fold_term_types + (fn _ => fold_atyps (fn TVar (v, _) => insert (op =) v + | TFree _ => bad_thm "Illegal free type variable in rewrite theorem" thm)) t []; + val lhs_vs = vars_of lhs; + val rhs_vs = vars_of rhs; + val lhs_tvs = tvars_of lhs; + val rhs_tvs = tvars_of lhs; + val _ = if null (subtract (op =) lhs_vs rhs_vs) + then () + else bad_thm "Free variables on right hand side of rewrite theorems" thm + val _ = if null (subtract (op =) lhs_tvs rhs_tvs) + then () + else bad_thm "Free type variables on right hand side of rewrite theorems" thm + in thm end; + +fun mk_rew thm = let - fun is_valid s = Symbol.is_ascii_letter s orelse Symbol.is_ascii_digit s orelse s = "'"; - val is_junk = not o is_valid andf Symbol.not_eof; - val junk = Scan.many is_junk; - val scan_valids = Symbol.scanner "Malformed input" - ((junk |-- - (Scan.optional (Scan.one Symbol.is_ascii_letter) "x" ^^ (Scan.many is_valid >> implode) - --| junk)) - -- Scan.repeat ((Scan.many1 is_valid >> implode) --| junk) >> op ::); - in explode #> scan_valids #> space_implode "_" end; + val thy = Thm.theory_of_thm thm; + val thms = (#mk o #mk_rews o snd o MetaSimplifier.rep_ss o Simplifier.simpset_of) thy thm; + in + map check_rew thms + end; + + +(* making function theorems *) + +val typ_func = lift_thm_thy (fn thy => snd o dest_Const o fst o strip_comb + o fst o Logic.dest_equals o ObjectLogic.drop_judgment thy o Drule.plain_prop_of); + +val dest_func = lift_thm_thy (fn thy => apfst dest_Const o strip_comb + o fst o Logic.dest_equals o ObjectLogic.drop_judgment thy o Drule.plain_prop_of + o Drule.fconv_rule Drule.beta_eta_conversion); + +val mk_head = lift_thm_thy (fn thy => fn thm => + ((CodegenConsts.norm_of_typ thy o fst o dest_func) thm, thm)); -val purify_lower = - explode - #> (fn cs => (if forall Symbol.is_ascii_upper cs - then map else nth_map 0) Symbol.to_ascii_lower cs) - #> implode; +fun gen_check_func strict_functyp thm = case try dest_func thm + of SOME (c_ty as (c, ty), args) => + let + val thy = Thm.theory_of_thm thm; + val _ = + if has_duplicates (op =) + ((fold o fold_aterms) (fn Var (v, _) => cons v + | _ => I + ) args []) + then bad_thm "Repeated variables on left hand side of function equation" thm + else () + fun no_abs (Abs _) = bad_thm "Abstraction on left hand side of function equation" thm + | no_abs (t1 $ t2) = (no_abs t1; no_abs t2) + | no_abs _ = (); + val _ = map no_abs args; + val is_classop = (is_some o AxClass.class_of_param thy) c; + val const = CodegenConsts.norm_of_typ thy c_ty; + val ty_decl = CodegenConsts.disc_typ_of_const thy + (snd o CodegenConsts.typ_of_inst thy) const; + val string_of_typ = setmp show_sorts true (Sign.string_of_typ thy); + val error_warning = if strict_functyp + then error + else warning #> K NONE + in if Sign.typ_equiv thy (ty_decl, ty) + then SOME (const, thm) + else (if is_classop + then error_warning + else if Sign.typ_equiv thy (Type.strip_sorts ty_decl, Type.strip_sorts ty) + then warning #> (K o SOME) (const, thm) + else error_warning) + ("Type\n" ^ string_of_typ ty + ^ "\nof function theorem\n" + ^ string_of_thm thm + ^ "\nis strictly less general than declared function type\n" + ^ string_of_typ ty_decl) + end + | NONE => bad_thm "Not a function equation" thm; -fun purify_var "" = "x" - | purify_var v = (purify_name #> purify_lower) v; +val check_func = gen_check_func true; +val legacy_check_func = gen_check_func false; -fun expand_eta thy k thm = +fun check_typ_classop thm = let + val thy = Thm.theory_of_thm thm; + val (c_ty as (c, ty), _) = dest_func thm; + in case AxClass.class_of_param thy c + of SOME class => let + val const = CodegenConsts.norm_of_typ thy c_ty; + val ty_decl = CodegenConsts.disc_typ_of_const thy + (snd o CodegenConsts.typ_of_inst thy) const; + val string_of_typ = setmp show_sorts true (Sign.string_of_typ thy); + in if Sign.typ_equiv thy (ty_decl, ty) + then thm + else error + ("Type\n" ^ string_of_typ ty + ^ "\nof function theorem\n" + ^ string_of_thm thm + ^ "\nis strictly less general than declared function type\n" + ^ string_of_typ ty_decl) + end + | NONE => thm + end; + +fun gen_mk_func check_func = map_filter check_func o mk_rew; +val mk_func = gen_mk_func check_func; +val legacy_mk_func = gen_mk_func legacy_check_func; + + +(* utilities *) + +fun expand_eta k thm = + let + val thy = Thm.theory_of_thm thm; val (lhs, rhs) = (Logic.dest_equals o Drule.plain_prop_of) thm; val (head, args) = strip_comb lhs; val l = if k = ~1 @@ -48,7 +157,7 @@ fun get_name _ 0 used = ([], used) | get_name (Abs (v, ty, t)) k used = used - |> Name.variants [purify_var v] + |> Name.variants [v] ||>> get_name t (k - 1) |>> (fn ([v'], vs') => (v', ty) :: vs') | get_name t k used = @@ -68,4 +177,43 @@ fold (fn refl => fn thm => Thm.combination thm refl) vs_refl thm end; +fun get_prim_def_funcs thy c = + let + fun constrain thm0 thm = case AxClass.class_of_param thy (fst c) + of SOME _ => + let + val ty_decl = CodegenConsts.disc_typ_of_classop thy c; + val max = maxidx_of_typ ty_decl + 1; + val thm = Thm.incr_indexes max thm; + val ty = typ_func thm; + val (env, _) = Sign.typ_unify thy (ty_decl, ty) (Vartab.empty, max); + val instT = Vartab.fold (fn (x_i, (sort, ty)) => + cons (Thm.ctyp_of thy (TVar (x_i, sort)), Thm.ctyp_of thy ty)) env []; + in Thm.instantiate (instT, []) thm end + | NONE => thm + in case CodegenConsts.find_def thy c + of SOME ((_, thm), _) => + thm + |> Thm.transfer thy + |> try (map snd o mk_func) + |> these + |> map (constrain thm) + |> map (expand_eta ~1) + | NONE => [] + end; + +fun rewrite_func rewrites thm = + let + val rewrite = MetaSimplifier.rewrite false rewrites; + val (ct_eq, [ct_lhs, ct_rhs]) = (Drule.strip_comb o Thm.cprop_of) thm; + val Const ("==", _) = Thm.term_of ct_eq; + val (ct_f, ct_args) = Drule.strip_comb ct_lhs; + val rhs' = rewrite ct_rhs; + val args' = map rewrite ct_args; + val lhs' = Thm.symmetric (fold (fn th1 => fn th2 => Thm.combination th2 th1) + args' (Thm.reflexive ct_f)); + in + Thm.transitive (Thm.transitive lhs' thm) rhs' + end handle Bind => raise ERROR "rewrite_func" + end; diff -r 979671292fbe -r 8e19bad4125f src/Pure/Tools/nbe.ML --- a/src/Pure/Tools/nbe.ML Tue Jan 09 08:31:47 2007 +0100 +++ b/src/Pure/Tools/nbe.ML Tue Jan 09 08:31:48 2007 +0100 @@ -72,7 +72,7 @@ let val ctxt = ProofContext.init thy; val pres = (map (LocalDefs.meta_rewrite_rule ctxt) o fst) (NBE_Rewrite.get thy) - in map (CodegenData.rewrite_func pres) end + in map (CodegenFunc.rewrite_func pres) end fun apply_posts thy = let