--- a/src/Pure/Tools/codegen_funcgr.ML Tue Oct 31 09:29:12 2006 +0100
+++ b/src/Pure/Tools/codegen_funcgr.ML Tue Oct 31 09:29:13 2006 +0100
@@ -9,11 +9,13 @@
signature CODEGEN_FUNCGR =
sig
type T;
- val mk_funcgr: theory -> CodegenConsts.const list -> (string * typ) list -> T
- val all_deps_of: T -> CodegenConsts.const list -> CodegenConsts.const list list
- val get_funcs: T -> CodegenConsts.const -> thm list
- val get_func_typs: T -> (CodegenConsts.const * typ) list
- val normalize: theory -> thm list -> thm list
+ val make: theory -> CodegenConsts.const list -> T
+ val make_term: theory -> cterm -> (cterm * (thm * (thm -> thm))) * T
+ val funcs: T -> CodegenConsts.const -> thm list
+ val typ: T -> CodegenConsts.const -> typ
+ val deps: T -> CodegenConsts.const list -> CodegenConsts.const list list
+ val all: T -> CodegenConsts.const list
+ val norm_vars: theory -> thm list -> thm list
val print_codethms: theory -> CodegenConsts.const list -> unit
structure Constgraph : GRAPH
end;
@@ -60,17 +62,13 @@
in
fold (fn ct => fn thm => Thm.combination thm (Thm.reflexive ct)) vs_ct thm
end;
- fun beta_norm thm =
- let
- val rhs = (snd o Logic.dest_equals o Drule.plain_prop_of) thm;
- val thm' = Thm.beta_conversion true (cterm_of thy rhs);
- in Thm.transitive thm thm' end;
in
thm
|> eta_expand
- |> beta_norm
+ |> Drule.fconv_rule Drule.beta_eta_conversion
end;
+
fun canonical_tvars thy thm =
let
fun mk_inst (v_i as (v, i), (v', sort)) (s as (maxidx, set, acc)) =
@@ -105,7 +103,7 @@
val (_, _, inst) = fold mk_inst (vars_of thm) (maxidx + 1, [], []);
in Thm.instantiate ([], inst) thm end;
-fun normalize thy thms =
+fun norm_vars thy thms =
let
fun burrow_thms f [] = []
| burrow_thms f thms =
@@ -129,13 +127,13 @@
(** retrieval **)
-fun get_funcs funcgr (c_tys as (c, _)) =
- (these o Option.map snd o try (Constgraph.get_node funcgr)) c_tys;
+fun funcs funcgr =
+ these o Option.map snd o try (Constgraph.get_node funcgr);
-fun get_func_typs funcgr =
- AList.make (fst o Constgraph.get_node funcgr) (Constgraph.keys funcgr);
+fun typ funcgr =
+ fst o Constgraph.get_node funcgr;
-fun all_deps_of funcgr cs =
+fun deps funcgr cs =
let
val conn = Constgraph.strong_conn funcgr;
val order = rev conn;
@@ -144,25 +142,24 @@
|> filter_out null
end;
+fun all funcgr = Constgraph.keys funcgr;
+
local
fun add_things_of thy f (c, thms) =
(fold o fold_aterms)
(fn Const c_ty => let
val c' = CodegenConsts.norm_of_typ thy c_ty
- in if CodegenConsts.eq_const (c, c') then I
+ in if is_some c andalso CodegenConsts.eq_const (the c, c') then I
else f (c', c_ty) end
| _ => I) (maps (op :: o swap o apfst (snd o strip_comb)
o Logic.dest_equals o Drule.plain_prop_of) thms)
fun rhs_of thy (c, thms) =
Consttab.empty
- |> add_things_of thy (Consttab.update o rpair () o fst) (c, thms)
+ |> add_things_of thy (Consttab.update o rpair () o fst) (SOME c, thms)
|> Consttab.keys;
-fun rhs_of' thy (c, thms) =
- add_things_of thy (cons o snd) (c, thms) [];
-
fun insts_of thy funcgr (c, ty) =
let
val tys = Sign.const_typargs thy (c, ty);
@@ -170,9 +167,6 @@
val ty_decl = CodegenConsts.disc_typ_of_const thy
(fst o Constgraph.get_node funcgr o CodegenConsts.norm thy) (c, tys);
val tys_decl = Sign.const_typargs thy (c, ty_decl);
- val pp = Sign.pp thy;
- val algebra = Sign.classes_of thy;
- fun classrel (x, _) _ = x;
fun constructor tyco xs class =
(tyco, class) :: maps (maps fst) xs;
fun variable (TVar (_, sort)) = map (pair []) sort
@@ -182,15 +176,20 @@
| mk_inst (Type (tyco1, tys1)) (Type (tyco2, tys2)) =
if tyco1 <> tyco2 then error "bad instance"
else fold2 mk_inst tys1 tys2;
+ val pp = Sign.pp thy;
+ val algebra = Sign.classes_of thy;
+ fun classrel (x, _) _ = x;
+ fun of_sort_deriv (ty, sort) =
+ Sorts.of_sort_derivation pp algebra
+ { classrel = classrel, constructor = constructor, variable = variable }
+ (ty, sort)
in
- flat (maps (Sorts.of_sort_derivation pp algebra
- { classrel = classrel, constructor = constructor, variable = variable })
- (fold2 mk_inst tys tys_decl []))
+ flat (maps of_sort_deriv (fold2 mk_inst tys tys_decl []))
end;
fun all_classops thy tyco class =
- maps (AxClass.params_of thy)
- (Graph.all_succs ((#classes o Sorts.rep_algebra o Sign.classes_of) thy) [class])
+ AxClass.params_of thy class
+(* |> tap (fn _ => writeln ("INST " ^ tyco ^ " - " ^ class)) *)
|> AList.make (fn c => CodegenConsts.disc_typ_of_classop thy (c, [Type (tyco, [])]))
(*typ_of_classop is very liberal in its type arguments*)
|> map (CodegenConsts.norm_of_typ thy);
@@ -206,10 +205,10 @@
(Graph.all_succs thy_classes classes))) tab [])
end;
-fun insts_of_thms thy funcgr c_thms =
+fun insts_of_thms thy funcgr (c, thms) =
let
val insts = add_things_of thy (fn (_, c_ty) => fold (insert (op =))
- (insts_of thy funcgr c_ty)) c_thms [];
+ (insts_of thy funcgr c_ty)) (SOME c, thms) [];
in instdefs_of thy insts end;
fun ensure_const thy funcgr c auxgr =
@@ -222,7 +221,7 @@
|> Constgraph.new_node (c, [])
|> pair (SOME c)
else let
- val thms = normalize thy (CodegenData.these_funcs thy c);
+ val thms = norm_vars thy (CodegenData.these_funcs thy c);
val rhs = rhs_of thy (c, thms);
in
auxgr
@@ -240,7 +239,7 @@
fun typscheme_of (c, ty) =
try (Constgraph.get_node funcgr) (CodegenConsts.norm_of_typ thy (c, ty))
|> Option.map fst;
- fun incr_indices (c, thms) maxidx =
+ fun incr_indices (c:'a, thms) maxidx =
let
val thms' = map (Thm.incr_indexes maxidx) thms;
val maxidx' = Int.max
@@ -263,12 +262,25 @@
fun apply_unifier unif (c, []) = (c, [])
| apply_unifier unif (c, thms as thm :: _) =
let
- val ty = CodegenData.typ_func thy thm;
- val ty' = Envir.norm_type unif ty;
- val env = Type.typ_match (Sign.tsig_of thy) (ty, ty') Vartab.empty;
- val inst = Thm.instantiate (Vartab.fold (fn (x_i, (sort, ty)) =>
- cons (Thm.ctyp_of thy (TVar (x_i, sort)), Thm.ctyp_of thy ty)) env [], []);
- in (c, map (Drule.zero_var_indexes o inst) thms) end;
+ val tvars = Term.add_tvars (Thm.prop_of thm) [];
+ fun mk_inst (v_i_sort as (v, _)) =
+ let
+ val ty = TVar v_i_sort;
+ in
+ pairself (Thm.ctyp_of thy) (ty,
+ TVar (v, (snd o dest_TVar o Envir.norm_type unif) ty))
+ end;
+ val instmap = map mk_inst tvars;
+ val (thms' as thm' :: _) = map (Drule.zero_var_indexes o Thm.instantiate (instmap, [])) thms
+ val _ = if fst c <> "" andalso not (Sign.typ_equiv thy (Type.strip_sorts (CodegenData.typ_func thy thm), Type.strip_sorts (CodegenData.typ_func thy thm')))
+ then error ("illegal function type instantiation:\n" ^ Sign.string_of_typ thy (CodegenData.typ_func thy thm)
+ ^ "\nto" ^ Sign.string_of_typ thy (CodegenData.typ_func thy thm'))
+ else ();
+ in (c, thms') end;
+ fun rhs_of' thy (("", []), thms as [_]) =
+ add_things_of thy (cons o snd) (NONE, thms) []
+ | rhs_of' thy (c, thms) =
+ add_things_of thy (cons o snd) (SOME c, thms) [];
val (eqss', maxidx) =
fold_map incr_indices eqss 0;
val (unif, _) =
@@ -301,16 +313,67 @@
(map (AList.make (Constgraph.get_node auxgr))
(rev (Constgraph.strong_conn auxgr))) funcgr);
+fun drop_classes thy tfrees thm =
+ let
+(* val _ = writeln ("DROP1 " ^ setmp show_types true string_of_thm thm); *)
+ val (_, thm') = Thm.varifyT' [] thm;
+ val tvars = Term.add_tvars (Thm.prop_of thm') [];
+(* val _ = writeln ("DROP2 " ^ setmp show_types true string_of_thm thm'); *)
+ val unconstr = map (Thm.ctyp_of thy o TVar) tvars;
+ val instmap = map2 (fn (v_i, _) => fn (v, sort) => pairself (Thm.ctyp_of thy)
+ (TVar (v_i, []), TFree (v, sort))) tvars tfrees;
+ in
+ thm'
+ |> fold Thm.unconstrainT unconstr
+ |> Thm.instantiate (instmap, [])
+ |> Tactic.rule_by_tactic ((REPEAT o CHANGED o ALLGOALS o Tactic.resolve_tac) (AxClass.class_intros thy))
+(* |> tap (fn thm => writeln ("DROP3 " ^ setmp show_types true string_of_thm thm)) *)
+ end;
+
in
val ensure_consts = ensure_consts;
-fun mk_funcgr thy consts cs =
- Funcgr.change thy (
- ensure_consts thy consts
- #> (fn funcgr => ensure_consts thy
- (instdefs_of thy (fold (fold (insert (op =)) o insts_of thy funcgr) cs [])) funcgr)
- );
+fun make thy consts =
+ Funcgr.change thy (ensure_consts thy consts);
+
+fun make_term thy ct =
+ let
+ val _ = Sign.no_vars (Sign.pp thy) (Thm.term_of ct);
+ val _ = Term.fold_types (Type.no_tvars #> K I) (Thm.term_of ct) ();
+ val thm1 = CodegenData.preprocess_cterm thy ct;
+(* val _ = writeln ("THM1 " ^ setmp show_types true string_of_thm thm1); *)
+ val ct' = Drule.dest_equals_rhs (Thm.cprop_of thm1);
+ val consts = CodegenConsts.consts_of thy (Thm.term_of ct');
+ val funcgr = make thy consts;
+ val (_, thm2) = Thm.varifyT' [] thm1;
+(* val _ = writeln ("THM2 " ^ setmp show_types true string_of_thm thm2); *)
+ val thm3 = Thm.reflexive (Drule.dest_equals_rhs (Thm.cprop_of thm2));
+(* val _ = writeln ("THM3 " ^ setmp show_types true string_of_thm thm3); *)
+ val [(_, [thm4])] = specialize_typs thy funcgr [(("", []), [thm3])];
+(* val _ = writeln ("THM4 " ^ setmp show_types true string_of_thm thm4); *)
+ val tfrees = Term.add_tfrees (Thm.prop_of thm1) [];
+(* val _ = writeln "TFREES"; *)
+(* val _ = (writeln o cat_lines o map (fn (v, sort) => v ^ "::" ^ Sign.string_of_sort thy sort)) tfrees; *)
+ fun inst thm =
+ let
+ val tvars = Term.add_tvars (Thm.prop_of thm) [];
+(* val _ = writeln "TVARS"; *)
+(* val _ = (writeln o cat_lines o map (fn ((v, i), sort) => v ^ "_" ^ string_of_int i ^ "::" ^ Sign.string_of_sort thy sort)) tvars; *)
+ val instmap = map2 (fn (v_i, sort) => fn (v, _) => pairself (Thm.ctyp_of thy)
+ (TVar (v_i, sort), TFree (v, sort))) tvars tfrees;
+ in Thm.instantiate (instmap, []) thm end;
+ val thm5 = inst thm2;
+ val thm6 = inst thm4;
+(* val _ = writeln ("THM5 " ^ setmp show_types true string_of_thm thm5); *)
+(* val _ = writeln ("THM6 " ^ setmp show_types true string_of_thm thm6); *)
+ val ct'' = Drule.dest_equals_rhs (Thm.cprop_of thm6);
+ val cs = fold_aterms (fn Const c => cons c | _ => I) (Thm.term_of ct'') [];
+ val drop = drop_classes thy tfrees;
+(* val _ = writeln "ADD INST"; *)
+ val funcgr' = ensure_consts thy
+ (instdefs_of thy (fold (fold (insert (op =)) o insts_of thy funcgr) cs [])) funcgr
+ in ((ct'', (thm5, drop)), Funcgr.change thy (K funcgr')) end;
end; (*local*)
@@ -327,7 +390,7 @@
|> Pretty.writeln;
fun print_codethms thy consts =
- mk_funcgr thy consts [] |> print_funcgr thy;
+ make thy consts |> print_funcgr thy;
fun print_codethms_e thy cs =
print_codethms thy (map (CodegenConsts.read_const thy) cs);