# HG changeset patch # User haftmann # Date 1155799496 -7200 # Node ID 21227c43ba26f820c5fc06680f0ebbca9a153a67 # Parent df3252bbc0e6780f5118acd20e7a261ace14f1dd improved thmtab diff -r df3252bbc0e6 -r 21227c43ba26 src/Pure/Tools/codegen_theorems.ML --- a/src/Pure/Tools/codegen_theorems.ML Thu Aug 17 09:24:51 2006 +0200 +++ b/src/Pure/Tools/codegen_theorems.ML Thu Aug 17 09:24:56 2006 +0200 @@ -24,6 +24,9 @@ val common_typ: theory -> (thm -> typ) -> thm list -> thm list; val preprocess: theory -> thm list -> thm list; + val prove_freeness: theory -> tactic -> string + -> (string * sort) list * (string * typ list) list -> thm list; + type thmtab; val mk_thmtab: theory -> (string * typ) list -> thmtab; val get_sortalgebra: thmtab -> Sorts.algebra; @@ -32,6 +35,7 @@ -> ((string * sort) list * (string * typ list) list) option; val get_fun_thms: thmtab -> string * typ -> thm list; + val pretty_funtab: theory -> thm list CodegenConsts.Consttab.table -> Pretty.T; val print_thms: theory -> unit; val init_obj: (thm * thm) * (thm * thm) -> theory -> theory; @@ -81,7 +85,7 @@ fun init_obj ((TrueI, FalseE), (conjI, atomize_eq)) thy = case CodegenTheoremsSetup.get thy - of SOME _ => error "code generator already set up for object logic" + of SOME _ => error "Code generator already set up for object logic" | NONE => let fun strip_implies t = (Logic.strip_imp_prems t, Logic.strip_imp_concl t); @@ -114,7 +118,7 @@ #> apsnd (map Term.dest_Var) #> apfst Term.dest_Const ) - |> (fn (v1, ((conj, _), v2)) => if v1 = v2 then conj else error "wrong premise") + |> (fn (v1, ((conj, _), v2)) => if v1 = v2 then conj else error "Wrong premise") fun dest_atomize_eq thm= Drule.plain_prop_of thm |> Logic.dest_equals @@ -130,10 +134,10 @@ #> apsnd Term.dest_Var ) |> (fn (((eq, _), v2), (v1a as (_, TVar (_, sort)), v1b)) => - if [v1a, v1b] = v2 andalso sort = Sign.defaultS thy then eq else error "wrong premise") + if [v1a, v1b] = v2 andalso sort = Sign.defaultS thy then eq else error "Wrong premise") in ((dest_TrueI TrueI, [dest_FalseE FalseE, dest_conjI conjI, dest_atomize_eq atomize_eq]) - handle _ => error "bad code generator setup") + handle _ => error "Bad code generator setup") |> (fn ((tr, b), [fl, con, eq]) => CodegenTheoremsSetup.put (SOME ((b, atomize_eq), ((tr, fl), (con, eq)))) thy) end; @@ -141,7 +145,7 @@ fun get_obj thy = case CodegenTheoremsSetup.get thy of SOME ((b, atomize), x) => ((Type (b, []), atomize) ,x) - | NONE => error "no object logic setup for code theorems"; + | NONE => error "No object logic setup for code theorems"; fun mk_true thy = let @@ -260,14 +264,14 @@ case try (make_eq thy #> Drule.plain_prop_of #> ObjectLogic.drop_judgment thy #> Logic.dest_equals) thm of SOME eq => (eq, thm) - | NONE => err_thm "not an equation" thm; + | NONE => err_thm "Not an equation" thm; fun dest_fun thy thm = let fun dest_fun' ((lhs, _), thm) = case try (dest_Const o fst o strip_comb) lhs of SOME (c, ty) => (c, (ty, thm)) - | NONE => err_thm "not a function equation" thm; + | NONE => err_thm "Not a function equation" thm; in thm |> dest_eq thy @@ -280,21 +284,23 @@ (* data structures *) +structure Consttab = CodegenConsts.Consttab; + fun merge' eq (xys as (xs, ys)) = if eq_list eq (xs, ys) then (false, xs) else (true, merge eq xys); fun alist_merge' eq_key eq (xys as (xs, ys)) = if eq_list (eq_pair eq_key eq) (xs, ys) then (false, xs) else (true, AList.merge eq_key eq xys); -fun list_symtab_join' eq (xyt as (xt, yt)) = +fun list_consttab_join' eq (xyt as (xt, yt)) = let - val xc = Symtab.keys xt; - val yc = Symtab.keys yt; - val zc = filter (member (op =) yc) xc; + val xc = Consttab.keys xt; + val yc = Consttab.keys yt; + val zc = filter (member CodegenConsts.eq_const yc) xc; val wc = subtract (op =) zc xc @ subtract (op =) zc yc; - fun same_thms c = if eq_list eq_thm ((the o Symtab.lookup xt) c, (the o Symtab.lookup yt) c) + fun same_thms c = if eq_list eq_thm ((the o Consttab.lookup xt) c, (the o Consttab.lookup yt) c) then NONE else SOME c; - in (wc @ map_filter same_thms zc, Symtab.join (K (merge eq)) xyt) end; + in (wc @ map_filter same_thms zc, Consttab.join (K (merge eq)) xyt) end; datatype notify = Notify of (serial * ((string * typ) list option -> theory -> theory)) list; @@ -337,7 +343,7 @@ datatype funthms = Funthms of { dirty: string list, - funs: thm list Symtab.table + funs: thm list Consttab.table }; fun mk_funthms (dirty, funs) = @@ -347,8 +353,8 @@ fun merge_funthms _ (Funthms { dirty = dirty1, funs = funs1 }, Funthms { dirty = dirty2, funs = funs2 }) = let - val (dirty3, funs) = list_symtab_join' eq_thm (funs1, funs2); - in mk_funthms (merge (op =) (merge (op =) (dirty1, dirty2), dirty3), funs) end; + val (dirty3, funs) = list_consttab_join' eq_thm (funs1, funs2); + in mk_funthms (merge (op =) (merge (op =) (dirty1, dirty2), map fst dirty3), funs) end; datatype T = T of { dirty: bool, @@ -380,7 +386,7 @@ val name = "Pure/codegen_theorems_data"; type T = T; val empty = mk_T ((false, mk_notify []), (mk_preproc ([], []), - (mk_extrs ([], []), mk_funthms ([], Symtab.empty)))); + (mk_extrs ([], []), mk_funthms ([], Consttab.empty)))); val copy = I; val extend = I; val merge = merge_T; @@ -388,7 +394,7 @@ let val pretty_thm = ProofContext.pretty_thm (ProofContext.init thy); val funthms = (fn T { funthms, ... } => funthms) data; - val funs = (Symtab.dest o (fn Funthms { funs, ... } => funs)) funthms; + val funs = (Consttab.dest o (fn Funthms { funs, ... } => funs)) funthms; val preproc = (fn T { preproc, ... } => preproc) data; val unfolds = (fn Preproc { unfolds, ... } => unfolds) preproc; in @@ -398,7 +404,7 @@ (*Pretty.fbreaks ( *) map (fn (c, thms) => (Pretty.block o Pretty.fbreaks) ( - Pretty.str c :: map pretty_thm (rev thms) + (Pretty.str o CodegenConsts.string_of_const thy) c :: map pretty_thm (rev thms) ) ) funs (*) *) @ [ @@ -437,7 +443,9 @@ (* notifiers *) fun all_typs thy c = - map (pair c) (Sign.the_const_type thy c :: (map (#lhs) o Theory.definitions_of thy) c); + let + val c_tys = (map (pair c o #lhs o snd) o Defs.specifications_of (Theory.defs_of thy)) c; + in (c, Sign.the_const_type thy c) :: map (CodegenConsts.typ_of_typinst thy) c_tys end; fun add_notify f = map_data (fn ((dirty, notify), x) => @@ -489,20 +497,20 @@ fun add_fun thm thy = case dest_fun thy thm - of (c, _) => + of (c, (ty, _)) => thy |> map_data (fn (x, (preproc, (extrs, funthms))) => (x, (preproc, (extrs, funthms |> map_funthms (fn (dirty, funs) => - (dirty, funs |> Symtab.default (c, []) |> Symtab.map_entry c (cons thm))))))) + (dirty, funs |> Consttab.map_default (CodegenConsts.norminst_of_typ thy (c, ty), []) (cons thm))))))) |> notify_all (SOME c); fun del_fun thm thy = case dest_fun thy thm - of (c, _) => + of (c, (ty, _)) => thy |> map_data (fn (x, (preproc, (extrs, funthms))) => (x, (preproc, (extrs, funthms |> map_funthms (fn (dirty, funs) => - (dirty, funs |> Symtab.map_entry c (remove eq_thm thm))))))) + (dirty, funs |> Consttab.map_entry (CodegenConsts.norminst_of_typ thy (c, ty)) (remove eq_thm thm))))))) |> notify_all (SOME c); fun add_unfold thm thy = @@ -523,9 +531,7 @@ thy |> map_data (fn (x, (preproc, (extrs, funthms))) => (x, (preproc, (extrs, funthms |> map_funthms (fn (dirty, funs) => - (dirty, funs |> Symtab.map_entry c - (filter (fn thm => Sign.typ_instance thy - ((fst o snd o dest_fun thy) thm, ty))))))))) + (dirty, funs |> Consttab.update (CodegenConsts.norminst_of_typ thy (c, ty), []))))))) |> notify_all (SOME c); @@ -556,7 +562,12 @@ in (thm', max') end; val (thms', maxidx) = fold_map incr_thm thms 0; val (ty1::tys) = map extract_typ thms; - fun unify ty = Sign.typ_unify thy (ty1, ty); + fun unify ty env = Sign.typ_unify thy (ty1, ty) env + handle Type.TUNIFY => + error ("Type unificaton failed, while unifying function equations\n" + ^ (cat_lines o map Display.string_of_thm) thms + ^ "\nwith types\n" + ^ (cat_lines o map (Sign.string_of_typ thy)) (ty1 :: tys)); val (env, _) = fold unify tys (Vartab.empty, maxidx) val instT = Vartab.fold (fn (x_i, (sort, ty)) => cons (Thm.ctyp_of thy (TVar (x_i, sort)), Thm.ctyp_of thy ty)) env []; @@ -611,41 +622,47 @@ fun get_funs thy (c, ty) = let val _ = debug_msg (fn _ => "[cg_thm] const (1) " ^ c ^ " :: " ^ Sign.string_of_typ thy ty) () - val filter_typ = map_filter (fn (_, (ty', thm)) => - if Sign.typ_instance thy (ty, ty') - then SOME thm else debug_msg (fn _ => "[cg_thm] dropping " ^ string_of_thm thm) NONE); + val postprocess_typ = case AxClass.class_of_param thy c + of NONE => map_filter (fn (_, (ty', thm)) => + if Sign.typ_instance thy (ty, ty') + then SOME thm else debug_msg (fn _ => "[cg_thm] dropping " ^ string_of_thm thm) NONE) + | SOME _ => let + (*FIXME make this more elegant*) + val ty' = CodegenConsts.typ_of_classop thy (CodegenConsts.norminst_of_typ thy (c, ty)); + val ct = Thm.cterm_of thy (Const (c, ty')); + val thm' = Thm.reflexive ct; + in map (snd o snd) #> cons thm' #> common_typ thy (extr_typ thy) #> tl end; fun get_funs (c, ty) = - (these o Symtab.lookup (the_funs thy)) c + (these o Consttab.lookup (the_funs thy) o CodegenConsts.norminst_of_typ thy) (c, ty) |> debug_msg (fn _ => "[cg_thm] trying funs") |> map (dest_fun thy) - |> filter_typ; + |> postprocess_typ; fun get_extr (c, ty) = getf_first_list (map (fn f => f thy) (the_funs_extrs thy)) (c, ty) |> debug_msg (fn _ => "[cg_thm] trying extr") |> map (dest_fun thy) - |> filter_typ; + |> postprocess_typ; fun get_spec (c, ty) = - Theory.definitions_of thy c + (CodegenConsts.find_def thy o CodegenConsts.norminst_of_typ thy) (c, ty) |> debug_msg (fn _ => "[cg_thm] trying spec") - (* FIXME avoid dynamic name space lookup!? (via Thm.get_axiom_i etc.??) *) - |> maps (fn { name, ... } => these (try (PureThy.get_thms thy) (Name name))) + |> Option.mapPartial (fn ((_, name), _) => try (Thm.get_axiom_i thy) name) + |> the_list |> map_filter (try (dest_fun thy)) - |> filter_typ; + |> postprocess_typ; in getf_first_list [get_funs, get_extr, get_spec] (c, ty) |> debug_msg (fn _ => "[cg_thm] const (2) " ^ c ^ " :: " ^ Sign.string_of_typ thy ty) |> preprocess thy end; -fun get_datatypes thy dtco = +fun prove_freeness thy tac dtco vs_cos = let - val _ = debug_msg (fn _ => "[cg_thm] datatype " ^ dtco) () val truh = mk_true thy; val fals = mk_false thy; fun mk_lhs vs ((co1, tys1), (co2, tys2)) = let val dty = Type (dtco, map TFree vs); - val (xs1, xs2) = chop (length tys1) (Name.invent_list [] "x" (length tys1 + length tys2)); + val (xs1, xs2) = chop (length tys1) (Name.invent_list [] "a" (length tys1 + length tys2)); val frees1 = map2 (fn x => fn ty => Free (x, ty)) xs1 tys1; val frees2 = map2 (fn x => fn ty => Free (x, ty)) xs2 tys2; fun zip_co co xs tys = list_comb (Const (co, @@ -667,13 +684,18 @@ fun mk_eqs (vs, cos) = let val cos' = rev cos in (op @) (fold (mk_eq vs) (product cos' cos') ([], [])) end; - fun mk_eq_thms tac vs_cos = - map (fn t => Goal.prove_global thy [] [] - (ObjectLogic.ensure_propT thy t) (K tac)) (mk_eqs vs_cos); + in + map (fn t => Goal.prove_global thy [] [] + (ObjectLogic.ensure_propT thy t) (K tac)) (mk_eqs vs_cos) + end; + +fun get_datatypes thy dtco = + let + val _ = debug_msg (fn _ => "[cg_thm] datatype " ^ dtco) () in case getf_first (map (fn f => f thy) (the_datatypes_extrs thy)) dtco of NONE => NONE - | SOME (vs_cos, tac) => SOME (vs_cos, mk_eq_thms tac vs_cos) + | SOME (vs_cos, tac) => SOME (vs_cos, prove_freeness thy tac dtco vs_cos) end; fun get_eq thy (c, ty) = @@ -691,13 +713,13 @@ fun check_head_lhs thm (lhs, rhs) = case strip_comb lhs of (Const (c', _), _) => if c' = c then () - else error ("illegal function equation for " ^ quote c + else error ("Illegal function equation for " ^ quote c ^ ", actually defining " ^ quote c' ^ ": " ^ Display.string_of_thm thm) - | _ => error ("illegal function equation: " ^ Display.string_of_thm thm); + | _ => error ("Illegal function equation: " ^ Display.string_of_thm thm); fun check_vars_lhs thm (lhs, rhs) = if has_duplicates (op =) (fold_aterms (fn Free (v, _) => cons v | _ => I) lhs []) - then error ("repeated variables on left hand side of function equation:" + then error ("Repeated variables on left hand side of function equation:" ^ Display.string_of_thm thm) else (); fun check_vars_rhs thm (lhs, rhs) = @@ -705,7 +727,7 @@ (fold_aterms (fn Free (v, _) => cons v | _ => I) lhs []) (fold_aterms (fn Free (v, _) => cons v | _ => I) rhs [])) then () - else error ("free variables on right hand side of function equation:" + else error ("Free variables on right hand side of function equation:" ^ Display.string_of_thm thm) val tts = map (Logic.dest_equals o Logic.unvarify o Thm.prop_of) thms; in