# HG changeset patch # User haftmann # Date 1139558947 -3600 # Node ID 14c1b2f5dda47eaa5cfe16cf5738e986e1241fae # Parent 0f7b92f75df78be4cfa18b7c5065f9baddc01a6e improved code generator devarification diff -r 0f7b92f75df7 -r 14c1b2f5dda4 src/HOL/Product_Type.thy --- a/src/HOL/Product_Type.thy Fri Feb 10 02:22:59 2006 +0100 +++ b/src/HOL/Product_Type.thy Fri Feb 10 09:09:07 2006 +0100 @@ -780,7 +780,7 @@ "snd" ("snd") code_alias - "*" "Product_Type.*" + "*" "Product_Type.pair" "Pair" "Product_Type.Pair" "fst" "Product_Type.fst" "snd" "Product_Type.snd" diff -r 0f7b92f75df7 -r 14c1b2f5dda4 src/HOL/Sum_Type.thy --- a/src/HOL/Sum_Type.thy Fri Feb 10 02:22:59 2006 +0100 +++ b/src/HOL/Sum_Type.thy Fri Feb 10 09:09:07 2006 +0100 @@ -227,5 +227,11 @@ val basic_monos = thms "basic_monos"; *} +subsection {* Codegenerator setup *} + +code_alias + "+" "Sum_Type.sum" + "Inr" "Sum_Type.Inr" + "Inl" "Sum_Type.Inl" end diff -r 0f7b92f75df7 -r 14c1b2f5dda4 src/HOL/Tools/datatype_codegen.ML --- a/src/HOL/Tools/datatype_codegen.ML Fri Feb 10 02:22:59 2006 +0100 +++ b/src/HOL/Tools/datatype_codegen.ML Fri Feb 10 09:09:07 2006 +0100 @@ -297,13 +297,17 @@ | datatype_tycodegen _ _ _ _ _ _ _ = NONE; -val setup = +val setup = add_codegen "datatype" datatype_codegen #> add_tycodegen "datatype" datatype_tycodegen #> CodegenPackage.set_get_datatype DatatypePackage.get_datatype #> CodegenPackage.set_get_all_datatype_cons DatatypePackage.get_all_datatype_cons #> + (fn thy => thy |> CodegenPackage.add_eqextr_default ("equality", + (CodegenPackage.eqextr_eq + DatatypePackage.get_eq_equations + (Sign.read_term thy "False")))) #> CodegenPackage.ensure_datatype_case_consts DatatypePackage.get_datatype_case_consts DatatypePackage.get_case_const_data; diff -r 0f7b92f75df7 -r 14c1b2f5dda4 src/HOL/Tools/datatype_package.ML --- a/src/HOL/Tools/datatype_package.ML Fri Feb 10 02:22:59 2006 +0100 +++ b/src/HOL/Tools/datatype_package.ML Fri Feb 10 09:09:07 2006 +0100 @@ -70,6 +70,7 @@ val get_datatype_case_consts : theory -> string list val get_case_const_data : theory -> string -> (string * int) list option val get_all_datatype_cons : theory -> (string * string) list + val get_eq_equations: theory -> string -> thm list val constrs_of : theory -> string -> term list option val case_const_of : theory -> string -> term option val weak_case_congs_of : theory -> thm list @@ -159,6 +160,32 @@ (fn (co, _) => cons (co, dtco)) ((snd o the oo get_datatype) thy dtco)) (get_datatypes thy) []; +fun get_eq_equations thy dtco = + case get_datatype thy dtco + of SOME (vars, cos) => + let + fun co_inject thm = + ((fst o dest_Const o fst o strip_comb o fst o HOLogic.dest_eq o fst + o HOLogic.dest_eq o HOLogic.dest_Trueprop o Thm.prop_of) thm, thm RS HOL.eq_reflection); + val inject = (map co_inject o #inject o the o datatype_info thy) dtco; + fun mk_refl co = + let + fun infer t = + (fst o Sign.infer_types (Sign.pp thy) thy (Sign.consts_of thy) (K NONE) (K NONE) [] true) + ([t], Type (dtco, map (fn (v, sort) => TVar ((v, 0), sort)) vars)) + val t = (Thm.cterm_of thy o infer) (Const (co, dummyT)); + in + HOL.refl + |> Drule.instantiate' [(SOME o Thm.ctyp_of_term) t] [SOME t] + |> (fn thm => thm RS Eq_TrueI) + end; + fun get_eq co = + case AList.lookup (op =) inject co + of SOME eq => eq + | NONE => mk_refl co; + in map (get_eq o fst) cos end + | NONE => []; + fun find_tname var Bi = let val frees = map dest_Free (term_frees Bi) val params = rename_wrt_term Bi (Logic.strip_params Bi); diff -r 0f7b92f75df7 -r 14c1b2f5dda4 src/Pure/Tools/codegen_package.ML --- a/src/Pure/Tools/codegen_package.ML Fri Feb 10 02:22:59 2006 +0100 +++ b/src/Pure/Tools/codegen_package.ML Fri Feb 10 09:09:07 2006 +0100 @@ -10,7 +10,9 @@ sig type auxtab; type eqextr = theory -> auxtab - -> (string * typ) -> (thm list * typ) option; + -> string * typ -> (thm list * typ) option; + type eqextr_default = theory -> auxtab + -> string * typ -> ((thm list * term option) * typ) option; type defgen; type appgen = theory -> auxtab -> (string * typ) * term list -> CodegenThingol.transact @@ -19,6 +21,7 @@ val add_appconst: string * ((int * int) * appgen) -> theory -> theory; val add_appconst_i: xstring * ((int * int) * appgen) -> theory -> theory; val add_eqextr: string * eqextr -> theory -> theory; + val add_eqextr_default: string * eqextr_default -> theory -> theory; val add_prim_class: xstring -> (string * string) -> theory -> theory; val add_prim_tyco: xstring -> (string * string) @@ -43,12 +46,14 @@ -> appgen; val appgen_number_of: (term -> term) -> (theory -> term -> IntInf.int) -> string -> string -> appgen; + val eqextr_eq: (theory -> string -> thm list) -> term + -> eqextr_default; val add_case_const: (theory -> string -> (string * int) list option) -> xstring -> theory -> theory; val add_case_const_i: (theory -> string -> (string * int) list option) -> string -> theory -> theory; - val print_codegen_generated: theory -> unit; + val print_code: theory -> unit; val rename_inconsistent: theory -> theory; val ensure_datatype_case_consts: (theory -> string list) -> (theory -> string -> (string * int) list option) @@ -91,12 +96,15 @@ (* code generator basics *) +val alias_ref = ref (fn thy : theory => fn s : string => s); +fun alias_get name = ! alias_ref name; + structure InstNameMangler = NameManglerFun ( type ctxt = theory; type src = string * (class * string); val ord = prod_ord string_ord (prod_ord string_ord string_ord); fun mk thy ((thyname, (cls, tyco)), i) = - NameSpace.base cls ^ "_" ^ NameSpace.base tyco ^ implode (replicate i "'") + (NameSpace.base o alias_get thy) cls ^ "_" ^ (NameSpace.base o alias_get thy) tyco ^ implode (replicate i "'") |> NameSpace.append thyname; fun is_valid _ _ = true; fun maybe_unique _ _ = NONE; @@ -110,7 +118,7 @@ fun mk thy ((c, (ty_decl, ty)), i) = let fun mangle (Type (tyco, tys)) = - NameSpace.base tyco :: Library.flat (List.mapPartial mangle tys) |> SOME + (NameSpace.base o alias_get thy) tyco :: Library.flat (List.mapPartial mangle tys) |> SOME | mangle _ = NONE in @@ -158,7 +166,9 @@ * (InstNameMangler.T * ((typ * typ list) Symtab.table * ConstNameMangler.T) * DatatypeconsNameMangler.T); type eqextr = theory -> auxtab - -> (string * typ) -> (thm list * typ) option; + -> string * typ -> (thm list * typ) option; +type eqextr_default = theory -> auxtab + -> string * typ -> ((thm list * term option) * typ) option; type defgen = theory -> auxtab -> gen_defgen; type appgen = theory -> auxtab -> (string * typ) * term list -> transact -> iexpr * transact; @@ -191,7 +201,7 @@ type gens = { appconst: ((int * int) * (appgen * stamp)) Symtab.table, - eqextrs: (string * (eqextr * stamp)) list + eqextrs: (string * (eqextr_default * stamp)) list }; fun map_gens f { appconst, eqextrs } = @@ -310,11 +320,11 @@ in CodegenData.put { modl = modl, gens = gens, target_data = target_data, logic_data = logic_data } thy end; -fun print_codegen_generated thy = +fun print_code thy = let val module = (#modl o CodegenData.get) thy; in - (writeln o Pretty.output o Pretty.chunks) [pretty_module module, pretty_deps module] + (Pretty.writeln o Pretty.chunks) [pretty_module module, pretty_deps module] end; @@ -329,7 +339,7 @@ (tab |> Symtab.update (src, dst), tab_rev |> Symtab.update (dst, src)))))); -val alias_get = perhaps o Symtab.lookup o fst o #alias o #logic_data o CodegenData.get; +val _ = alias_ref := (perhaps o Symtab.lookup o fst o #alias o #logic_data o CodegenData.get); val alias_rev = perhaps o Symtab.lookup o snd o #alias o #logic_data o CodegenData.get; fun add_nsp shallow name = @@ -347,7 +357,7 @@ val (modl, shallow) = split_last idf''; in if nsp = shallow - then (SOME o NameSpace.pack) (modl @ [idf_base]) + then (SOME o NameSpace.pack) (modl @ [idf_base]) else NONE end; @@ -427,11 +437,22 @@ (fn (appconst, eqextrs) => (appconst, eqextrs |> Output.update_warn (op =) ("overwriting existing equation extractor " ^ name) + (name, ((Option.map o apfst o rpair) NONE ooo eqx , stamp ())))), + target_data, logic_data)); + +fun add_eqextr_default (name, eqx) = + map_codegen_data + (fn (modl, gens, target_data, logic_data) => + (modl, + gens |> map_gens + (fn (appconst, eqextrs) => + (appconst, eqextrs + |> Output.update_warn (op =) ("overwriting existing equation extractor " ^ name) (name, (eqx, stamp ())))), target_data, logic_data)); fun get_eqextrs thy tabs = - (map (fn (_, (eqx, _)) => eqx thy tabs) o #eqextrs o #gens o CodegenData.get) thy; + (map (fn (name, (eqx, _)) => (name, eqx thy tabs)) o #eqextrs o #gens o CodegenData.get) thy; fun set_get_all_datatype_cons f = map_codegen_data @@ -465,7 +486,7 @@ |> Symtab.update ( #ml CodegenSerializer.serializers |> apsnd (fn seri => seri - (nsp_dtcon, nsp_class, fn tyco' => tyco' = idf_of_name thy nsp_tyco tyco ) + (nsp_dtcon, nsp_class, fn tyco' => tyco' = idf_of_name thy nsp_tyco tyco ) [[nsp_module], [nsp_class, nsp_tyco], [nsp_const, nsp_overl, nsp_dtcon, nsp_mem, nsp_inst]] ) ) @@ -474,27 +495,28 @@ (* sophisticated devarification *) -fun assert f msg x = - if f x then x - else error msg; - -val _ : ('a -> bool) -> string -> 'a -> 'a = assert; +fun eq_typ thy (ty1, ty2) = + Sign.typ_instance thy (ty1, ty2) + andalso Sign.typ_instance thy (ty2, ty1); fun devarify_typs tys = let - fun add_rename (var as ((v, _), sort)) used = + fun add_rename (vi as (v, _), sorts) used = let val v' = "'" ^ variant used (unprefix "'" v) - in (((var, TFree (v', sort)), (v', TVar var)), v' :: used) end; + in (map (fn sort => (((vi, sort), TFree (v', sort)), (v', TVar (vi, sort)))) sorts, v' :: used) end; fun typ_names (Type (tyco, tys)) (vars, names) = (vars, names |> insert (op =) (NameSpace.base tyco)) |> fold typ_names tys | typ_names (TFree (v, _)) (vars, names) = (vars, names |> insert (op =) (unprefix "'" v)) | typ_names (TVar (vi, sort)) (vars, names) = - (vars |> AList.update (op =) (vi, sort), names); + (vars + |> AList.default (op =) (vi, []) + |> AList.map_entry (op =) vi (cons sort), + names); val (vars, used) = fold typ_names tys ([], []); - val (renames, reverse) = fold_map add_rename vars used |> fst |> split_list; + val (renames, reverse) = fold_map add_rename vars used |> fst |> Library.flat |> split_list; in (reverse, map (Term.instantiateT renames) tys) end; @@ -513,16 +535,19 @@ fun devarify_terms ts = let - fun add_rename (var as ((v, _), ty)) used = + fun add_rename (vi as (v, _), tys) used = let val v' = variant used v - in (((var, Free (v', ty)), (v', Var var)), v' :: used) end; + in (map (fn ty => (((vi, ty), Free (v', ty)), (v', Var (vi, ty)))) tys, v' :: used) end; fun term_names (Const (c, _)) (vars, names) = (vars, names |> insert (op =) (NameSpace.base c)) | term_names (Free (v, _)) (vars, names) = (vars, names |> insert (op =) v) - | term_names (Var (v, sort)) (vars, names) = - (vars |> AList.update (op =) (v, sort), names) + | term_names (Var (vi, ty)) (vars, names) = + (vars + |> AList.default (op =) (vi, []) + |> AList.map_entry (op =) vi (cons ty), + names) | term_names (Bound _) vars_names = vars_names | term_names (Abs (v, _, _)) (vars, names) = @@ -530,7 +555,7 @@ | term_names (t1 $ t2) vars_names = vars_names |> term_names t1 |> term_names t2 val (vars, used) = fold term_names ts ([], []); - val (renames, reverse) = fold_map add_rename vars used |> fst |> split_list; + val (renames, reverse) = fold_map add_rename vars used |> fst |> Library.flat |> split_list; in (reverse, map (Term.instantiate ([], renames)) ts) end; @@ -576,7 +601,7 @@ fun defgen_datatype thy (tabs as (_, (_, _, dtcontab))) dtco trns = case name_of_idf thy nsp_tyco dtco of SOME dtco => - (case get_datatype thy dtco + (case get_datatype thy dtco of SOME (vars, cos) => let val cos' = map (fn (co, tys) => (DatatypeconsNameMangler.get thy dtcontab (co, dtco) |> @@ -635,8 +660,8 @@ |> fold_map (ensure_def_class thy tabs) clss |-> (fn clss => pair (Lookup (clss, (v |> unprefix "'", i)))) and mk_fun thy tabs (c, ty) trns = - case get_first (fn eqx => eqx (c, ty)) (get_eqextrs thy tabs) - of SOME (eq_thms, ty) => + case get_first (fn (name, eqx) => (eqx (c, ty))) (get_eqextrs thy tabs) + of SOME ((eq_thms, default), ty) => let val sortctxt = ClassPackage.extract_sortctxt thy ty; fun dest_eqthm eq_thm = @@ -649,12 +674,22 @@ ^ ", actually defining " ^ quote c') | _ => error ("illegal function equation for " ^ quote c) end; + fun mk_default t = + let + val (tys, ty') = strip_type ty; + val vs = Term.invent_names (add_term_names (t, [])) "x" (length tys); + in + if (not o eq_typ thy) (type_of t, ty') + then error ("inconsistent type for default rule") + else (map2 (curry Free) vs tys, t) + end; in trns |> (codegen_eqs thy tabs o map dest_eqthm) eq_thms + ||>> (codegen_eqs thy tabs o the_list o Option.map mk_default) default ||>> codegen_type thy tabs [ty] ||>> fold_map (exprgen_tyvar_sort thy tabs) sortctxt - |-> (fn ((eqs, [ty]), sortctxt) => (pair o SOME) (eqs, (sortctxt, ty))) + |-> (fn (((eqs, eq_default), [ty]), sortctxt) => (pair o SOME) (eqs @ eq_default, (sortctxt, ty))) end | NONE => (NONE, trns) and ensure_def_inst thy (tabs as (_, (insttab, _, _))) (cls, tyco) trns = @@ -825,49 +860,14 @@ trns |> appgen_default thy tabs ((f, ty), ts); -(* fun ensure_def_eq thy tabs (dtco, (eqpred, arity)) trns = - let - val name_dtco = (the ooo name_of_idf) thy nsp_tyco dtco; - val idf_eqinst = idf_of_name thy nsp_eq_inst name_dtco; - val idf_eqpred = idf_of_name thy nsp_eq_pred name_dtco; - val inst_sortlookup = map (fn (v, _) => [ClassPackage.Lookup ([], (v, 0))]) arity; - fun mk_eq_pred _ trns = - trns - |> succeed (eqpred) - fun mk_eq_inst _ trns = - trns - |> gen_ensure_def [("eqpred", mk_eq_pred)] ("generating equality predicate for " ^ quote dtco) idf_eqpred - |> succeed (Classinst ((class_eq, (dtco, arity)), ([], [(fun_eq, (idf_eqpred, inst_sortlookup))]))); - in - trns - |> gen_ensure_def [("eqinst", mk_eq_inst)] ("generating equality instance for " ^ quote dtco) idf_eqinst - end; *) - -(* expression generators *) - -(* fun appgen_eq thy tabs (("op =", Type ("fun", [ty, _])), [t1, t2]) trns = - trns - |> invoke_eq (exprgen_type thy tabs) (ensure_def_eq thy tabs) ty - |-> (fn false => error ("could not derive equality for " ^ Sign.string_of_typ thy ty) - | true => fn trns => trns - |> exprgen_term thy tabs t1 - ||>> exprgen_term thy tabs t2 - |-> (fn (e1, e2) => pair (Fun_eq `$ e1 `$ e2))); *) - (* function extractors *) fun eqextr_defs thy ((deftab, _), _) (c, ty) = - let - fun eq_typ (ty1, ty2) = - Sign.typ_instance thy (ty1, ty2) - andalso Sign.typ_instance thy (ty2, ty1) - in - Option.mapPartial (get_first (fn (ty', thm) => if eq_typ (ty, ty') - then SOME ([thm], ty') - else NONE - )) (Symtab.lookup deftab c) - end; + Option.mapPartial (get_first (fn (ty', thm) => if eq_typ thy (ty, ty') + then SOME ([thm], ty') + else NONE + )) (Symtab.lookup deftab c); (* parametrized generators, for instantiation in HOL *) @@ -916,6 +916,17 @@ |> exprgen_term thy tabs (mk_int_to_nat bin) else error ("invalid type constructor for numeral: " ^ quote tyco); +fun eqextr_eq f fals thy tabs ("op =", ty) = + (case ty + of Type ("fun", [Type (dtco, _), _]) => + (case f thy dtco + of [] => NONE + | [eq] => SOME ((Codegen.preprocess thy [eq], NONE), ty) + | eqs => SOME ((Codegen.preprocess thy eqs, SOME fals), ty)) + | _ => NONE) + | eqextr_eq f fals thy tabs _ = + NONE; + fun appgen_datatype_case cos thy tabs ((_, ty), ts) trns = let val (ts', t) = split_last ts; @@ -972,7 +983,7 @@ in if forall is_Var args then SOME ((c, ty), tm) else NONE end handle TERM _ => NONE; fun prep_def def = (case Codegen.preprocess thy [def] of - [def'] => def' | _ => error "mk_auxtab: bad preprocessor"); + [def'] => def' | _ => error "mk_tabs: bad preprocessor"); fun add_def (name, _) = case (dest o prep_def o Thm.get_axiom thy) name of SOME ((c, ty), tm) => @@ -990,6 +1001,22 @@ (fn (tyco, thyname) => InstNameMangler.declare thy (thyname, (cls, tyco))) clsinsts) (ClassPackage.get_classtab thy) |-> (fn _ => I); + fun add_monoeq thy (overltab1, overltab2) = + let + val c = "op ="; + val ty = Sign.the_const_type thy c; + fun inst dtco = + map_atyps (fn _ => Type (dtco, + (map (fn (v, sort) => TVar ((v, 0), sort)) o fst o the o get_datatype thy) dtco)) ty + val dtcos = fold (insert (op =) o snd) (get_all_datatype_cons thy) []; + val tys = map inst dtcos; + in + (overltab1 + |> Symtab.update_new (c, (ty, tys)), + overltab2 + |> fold (fn ty' => ConstNameMangler.declare thy + (idf_of_name thy nsp_overl c, (ty, ty')) #> snd) tys) + end; fun mk_overltabs thy deftab = (Symtab.empty, ConstNameMangler.empty) |> Symtab.fold @@ -998,19 +1025,20 @@ if (is_none o ClassPackage.lookup_const_class thy) c then (fn (overltab1, overltab2) => ( overltab1 - |> Symtab.update_new (c, (Sign.the_const_constraint thy c, map fst tytab)), + |> Symtab.update_new (c, (Sign.the_const_type thy c, map fst tytab)), overltab2 |> fold (fn (ty, _) => ConstNameMangler.declare thy - (idf_of_name thy nsp_overl c, (Sign.the_const_constraint thy c, ty)) #> snd) tytab)) + (idf_of_name thy nsp_overl c, (Sign.the_const_type thy c, ty)) #> snd) tytab)) else I - ) deftab; + ) deftab + |> add_monoeq thy; fun mk_dtcontab thy = DatatypeconsNameMangler.empty |> fold_map (fn (_, co_dtco) => DatatypeconsNameMangler.declare_multi thy co_dtco) (fold (fn (co, dtco) => let - val key = ((NameSpace.drop_base o NameSpace.drop_base) co, NameSpace.base co) + val key = ((NameSpace.drop_base o NameSpace.drop_base) co, NameSpace.base co); in AList.default (op =) (key, []) #> AList.map_entry (op =) key (cons (co, dtco)) end ) (get_all_datatype_cons thy) []) |-> (fn _ => I); @@ -1030,7 +1058,7 @@ fun get_serializer target = case Symtab.lookup (!serializers) target of SOME seri => seri - | NONE => error ("unknown code target language: " ^ quote target); + | NONE => Scan.fail_with (fn _ => "unknown code target language: " ^ quote target) (); fun map_module f = map_codegen_data (fn (modl, gens, target_data, logic_data) => @@ -1094,7 +1122,7 @@ then () else error ("no such constant: " ^ quote c); val ty = case raw_ty - of NONE => Sign.the_const_constraint thy c + of NONE => Sign.the_const_type thy c | SOME raw_ty => read_typ thy raw_ty; in (c, ty) end; @@ -1127,7 +1155,7 @@ (fn thy => fn tabs => idf_of_const thy tabs o read_const thy) CodegenSerializer.parse_targetdef; -val ensure_prim = (map_module oo CodegenThingol.ensure_prim); +val ensure_prim = map_module oo CodegenThingol.ensure_prim; (* syntax *) @@ -1235,7 +1263,7 @@ (** toplevel interface **) local - + fun generate_code (SOME raw_consts) thy = let val consts = map (read_const thy) raw_consts; @@ -1363,7 +1391,7 @@ P.name -- parse_syntax_const raw_const )) ) - >> (Toplevel.theory oo fold o fold) + >> (Toplevel.theory oo fold o fold) (fn (target, modifier) => modifier target) ); @@ -1376,7 +1404,6 @@ val _ = Context.add_setup ( add_eqextr ("defs", eqextr_defs) -(* add_appconst_i ("op =", ((2, 2), appgen_eq)) *) ); end; (* local *) diff -r 0f7b92f75df7 -r 14c1b2f5dda4 src/Pure/Tools/codegen_serializer.ML --- a/src/Pure/Tools/codegen_serializer.ML Fri Feb 10 02:22:59 2006 +0100 +++ b/src/Pure/Tools/codegen_serializer.ML Fri Feb 10 09:09:07 2006 +0100 @@ -98,19 +98,10 @@ end; in mk (const_syntax c) es end; -val _ : (string -> iexpr list -> Pretty.T list) - -> (fixity -> iexpr -> Pretty.T) - -> (string - -> ((int * int) - * (fixity - -> (fixity -> iexpr -> Pretty.T) - -> iexpr list -> Pretty.T)) option) - -> fixity -> string * iexpr list -> Pretty.T = from_app; - fun fillin_mixfix fxy_this ms fxy_ctxt pr args = let fun fillin [] [] = - [] + [] | fillin (Arg fxy :: ms) (a :: args) = pr fxy a :: fillin ms args | fillin (Ignore :: ms) args = @@ -118,7 +109,11 @@ | fillin (Pretty p :: ms) args = p :: fillin ms args | fillin (Quote q :: ms) args = - pr BR q :: fillin ms args; + pr BR q :: fillin ms args + | fillin [] _ = + error ("inconsistent mixfix: too many arguments") + | fillin _ [] = + error ("inconsistent mixfix: too less arguments"); in gen_brackify (eval_fxy fxy_this fxy_ctxt) (fillin ms args) end;