# HG changeset patch # User haftmann # Date 1162283353 -3600 # Node ID e333c844b05704a143cb084f878b9f201b537a40 # Parent 5c7edac0c645282d81179b91164619e74419d0a4 refined algorithm diff -r 5c7edac0c645 -r e333c844b057 src/Pure/Tools/codegen_funcgr.ML --- a/src/Pure/Tools/codegen_funcgr.ML Tue Oct 31 09:29:12 2006 +0100 +++ b/src/Pure/Tools/codegen_funcgr.ML Tue Oct 31 09:29:13 2006 +0100 @@ -9,11 +9,13 @@ signature CODEGEN_FUNCGR = sig type T; - val mk_funcgr: theory -> CodegenConsts.const list -> (string * typ) list -> T - val all_deps_of: T -> CodegenConsts.const list -> CodegenConsts.const list list - val get_funcs: T -> CodegenConsts.const -> thm list - val get_func_typs: T -> (CodegenConsts.const * typ) list - val normalize: theory -> thm list -> thm list + val make: theory -> CodegenConsts.const list -> T + val make_term: theory -> cterm -> (cterm * (thm * (thm -> thm))) * T + val funcs: T -> CodegenConsts.const -> thm list + val typ: T -> CodegenConsts.const -> typ + val deps: T -> CodegenConsts.const list -> CodegenConsts.const list list + val all: T -> CodegenConsts.const list + val norm_vars: theory -> thm list -> thm list val print_codethms: theory -> CodegenConsts.const list -> unit structure Constgraph : GRAPH end; @@ -60,17 +62,13 @@ in fold (fn ct => fn thm => Thm.combination thm (Thm.reflexive ct)) vs_ct thm end; - fun beta_norm thm = - let - val rhs = (snd o Logic.dest_equals o Drule.plain_prop_of) thm; - val thm' = Thm.beta_conversion true (cterm_of thy rhs); - in Thm.transitive thm thm' end; in thm |> eta_expand - |> beta_norm + |> Drule.fconv_rule Drule.beta_eta_conversion end; + fun canonical_tvars thy thm = let fun mk_inst (v_i as (v, i), (v', sort)) (s as (maxidx, set, acc)) = @@ -105,7 +103,7 @@ val (_, _, inst) = fold mk_inst (vars_of thm) (maxidx + 1, [], []); in Thm.instantiate ([], inst) thm end; -fun normalize thy thms = +fun norm_vars thy thms = let fun burrow_thms f [] = [] | burrow_thms f thms = @@ -129,13 +127,13 @@ (** retrieval **) -fun get_funcs funcgr (c_tys as (c, _)) = - (these o Option.map snd o try (Constgraph.get_node funcgr)) c_tys; +fun funcs funcgr = + these o Option.map snd o try (Constgraph.get_node funcgr); -fun get_func_typs funcgr = - AList.make (fst o Constgraph.get_node funcgr) (Constgraph.keys funcgr); +fun typ funcgr = + fst o Constgraph.get_node funcgr; -fun all_deps_of funcgr cs = +fun deps funcgr cs = let val conn = Constgraph.strong_conn funcgr; val order = rev conn; @@ -144,25 +142,24 @@ |> filter_out null end; +fun all funcgr = Constgraph.keys funcgr; + local fun add_things_of thy f (c, thms) = (fold o fold_aterms) (fn Const c_ty => let val c' = CodegenConsts.norm_of_typ thy c_ty - in if CodegenConsts.eq_const (c, c') then I + in if is_some c andalso CodegenConsts.eq_const (the c, c') then I else f (c', c_ty) end | _ => I) (maps (op :: o swap o apfst (snd o strip_comb) o Logic.dest_equals o Drule.plain_prop_of) thms) fun rhs_of thy (c, thms) = Consttab.empty - |> add_things_of thy (Consttab.update o rpair () o fst) (c, thms) + |> add_things_of thy (Consttab.update o rpair () o fst) (SOME c, thms) |> Consttab.keys; -fun rhs_of' thy (c, thms) = - add_things_of thy (cons o snd) (c, thms) []; - fun insts_of thy funcgr (c, ty) = let val tys = Sign.const_typargs thy (c, ty); @@ -170,9 +167,6 @@ val ty_decl = CodegenConsts.disc_typ_of_const thy (fst o Constgraph.get_node funcgr o CodegenConsts.norm thy) (c, tys); val tys_decl = Sign.const_typargs thy (c, ty_decl); - val pp = Sign.pp thy; - val algebra = Sign.classes_of thy; - fun classrel (x, _) _ = x; fun constructor tyco xs class = (tyco, class) :: maps (maps fst) xs; fun variable (TVar (_, sort)) = map (pair []) sort @@ -182,15 +176,20 @@ | mk_inst (Type (tyco1, tys1)) (Type (tyco2, tys2)) = if tyco1 <> tyco2 then error "bad instance" else fold2 mk_inst tys1 tys2; + val pp = Sign.pp thy; + val algebra = Sign.classes_of thy; + fun classrel (x, _) _ = x; + fun of_sort_deriv (ty, sort) = + Sorts.of_sort_derivation pp algebra + { classrel = classrel, constructor = constructor, variable = variable } + (ty, sort) in - flat (maps (Sorts.of_sort_derivation pp algebra - { classrel = classrel, constructor = constructor, variable = variable }) - (fold2 mk_inst tys tys_decl [])) + flat (maps of_sort_deriv (fold2 mk_inst tys tys_decl [])) end; fun all_classops thy tyco class = - maps (AxClass.params_of thy) - (Graph.all_succs ((#classes o Sorts.rep_algebra o Sign.classes_of) thy) [class]) + AxClass.params_of thy class +(* |> tap (fn _ => writeln ("INST " ^ tyco ^ " - " ^ class)) *) |> AList.make (fn c => CodegenConsts.disc_typ_of_classop thy (c, [Type (tyco, [])])) (*typ_of_classop is very liberal in its type arguments*) |> map (CodegenConsts.norm_of_typ thy); @@ -206,10 +205,10 @@ (Graph.all_succs thy_classes classes))) tab []) end; -fun insts_of_thms thy funcgr c_thms = +fun insts_of_thms thy funcgr (c, thms) = let val insts = add_things_of thy (fn (_, c_ty) => fold (insert (op =)) - (insts_of thy funcgr c_ty)) c_thms []; + (insts_of thy funcgr c_ty)) (SOME c, thms) []; in instdefs_of thy insts end; fun ensure_const thy funcgr c auxgr = @@ -222,7 +221,7 @@ |> Constgraph.new_node (c, []) |> pair (SOME c) else let - val thms = normalize thy (CodegenData.these_funcs thy c); + val thms = norm_vars thy (CodegenData.these_funcs thy c); val rhs = rhs_of thy (c, thms); in auxgr @@ -240,7 +239,7 @@ fun typscheme_of (c, ty) = try (Constgraph.get_node funcgr) (CodegenConsts.norm_of_typ thy (c, ty)) |> Option.map fst; - fun incr_indices (c, thms) maxidx = + fun incr_indices (c:'a, thms) maxidx = let val thms' = map (Thm.incr_indexes maxidx) thms; val maxidx' = Int.max @@ -263,12 +262,25 @@ fun apply_unifier unif (c, []) = (c, []) | apply_unifier unif (c, thms as thm :: _) = let - val ty = CodegenData.typ_func thy thm; - val ty' = Envir.norm_type unif ty; - val env = Type.typ_match (Sign.tsig_of thy) (ty, ty') Vartab.empty; - val inst = Thm.instantiate (Vartab.fold (fn (x_i, (sort, ty)) => - cons (Thm.ctyp_of thy (TVar (x_i, sort)), Thm.ctyp_of thy ty)) env [], []); - in (c, map (Drule.zero_var_indexes o inst) thms) end; + val tvars = Term.add_tvars (Thm.prop_of thm) []; + fun mk_inst (v_i_sort as (v, _)) = + let + val ty = TVar v_i_sort; + in + pairself (Thm.ctyp_of thy) (ty, + TVar (v, (snd o dest_TVar o Envir.norm_type unif) ty)) + end; + val instmap = map mk_inst tvars; + val (thms' as thm' :: _) = map (Drule.zero_var_indexes o Thm.instantiate (instmap, [])) thms + val _ = if fst c <> "" andalso not (Sign.typ_equiv thy (Type.strip_sorts (CodegenData.typ_func thy thm), Type.strip_sorts (CodegenData.typ_func thy thm'))) + then error ("illegal function type instantiation:\n" ^ Sign.string_of_typ thy (CodegenData.typ_func thy thm) + ^ "\nto" ^ Sign.string_of_typ thy (CodegenData.typ_func thy thm')) + else (); + in (c, thms') end; + fun rhs_of' thy (("", []), thms as [_]) = + add_things_of thy (cons o snd) (NONE, thms) [] + | rhs_of' thy (c, thms) = + add_things_of thy (cons o snd) (SOME c, thms) []; val (eqss', maxidx) = fold_map incr_indices eqss 0; val (unif, _) = @@ -301,16 +313,67 @@ (map (AList.make (Constgraph.get_node auxgr)) (rev (Constgraph.strong_conn auxgr))) funcgr); +fun drop_classes thy tfrees thm = + let +(* val _ = writeln ("DROP1 " ^ setmp show_types true string_of_thm thm); *) + val (_, thm') = Thm.varifyT' [] thm; + val tvars = Term.add_tvars (Thm.prop_of thm') []; +(* val _ = writeln ("DROP2 " ^ setmp show_types true string_of_thm thm'); *) + val unconstr = map (Thm.ctyp_of thy o TVar) tvars; + val instmap = map2 (fn (v_i, _) => fn (v, sort) => pairself (Thm.ctyp_of thy) + (TVar (v_i, []), TFree (v, sort))) tvars tfrees; + in + thm' + |> fold Thm.unconstrainT unconstr + |> Thm.instantiate (instmap, []) + |> Tactic.rule_by_tactic ((REPEAT o CHANGED o ALLGOALS o Tactic.resolve_tac) (AxClass.class_intros thy)) +(* |> tap (fn thm => writeln ("DROP3 " ^ setmp show_types true string_of_thm thm)) *) + end; + in val ensure_consts = ensure_consts; -fun mk_funcgr thy consts cs = - Funcgr.change thy ( - ensure_consts thy consts - #> (fn funcgr => ensure_consts thy - (instdefs_of thy (fold (fold (insert (op =)) o insts_of thy funcgr) cs [])) funcgr) - ); +fun make thy consts = + Funcgr.change thy (ensure_consts thy consts); + +fun make_term thy ct = + let + val _ = Sign.no_vars (Sign.pp thy) (Thm.term_of ct); + val _ = Term.fold_types (Type.no_tvars #> K I) (Thm.term_of ct) (); + val thm1 = CodegenData.preprocess_cterm thy ct; +(* val _ = writeln ("THM1 " ^ setmp show_types true string_of_thm thm1); *) + val ct' = Drule.dest_equals_rhs (Thm.cprop_of thm1); + val consts = CodegenConsts.consts_of thy (Thm.term_of ct'); + val funcgr = make thy consts; + val (_, thm2) = Thm.varifyT' [] thm1; +(* val _ = writeln ("THM2 " ^ setmp show_types true string_of_thm thm2); *) + val thm3 = Thm.reflexive (Drule.dest_equals_rhs (Thm.cprop_of thm2)); +(* val _ = writeln ("THM3 " ^ setmp show_types true string_of_thm thm3); *) + val [(_, [thm4])] = specialize_typs thy funcgr [(("", []), [thm3])]; +(* val _ = writeln ("THM4 " ^ setmp show_types true string_of_thm thm4); *) + val tfrees = Term.add_tfrees (Thm.prop_of thm1) []; +(* val _ = writeln "TFREES"; *) +(* val _ = (writeln o cat_lines o map (fn (v, sort) => v ^ "::" ^ Sign.string_of_sort thy sort)) tfrees; *) + fun inst thm = + let + val tvars = Term.add_tvars (Thm.prop_of thm) []; +(* val _ = writeln "TVARS"; *) +(* val _ = (writeln o cat_lines o map (fn ((v, i), sort) => v ^ "_" ^ string_of_int i ^ "::" ^ Sign.string_of_sort thy sort)) tvars; *) + val instmap = map2 (fn (v_i, sort) => fn (v, _) => pairself (Thm.ctyp_of thy) + (TVar (v_i, sort), TFree (v, sort))) tvars tfrees; + in Thm.instantiate (instmap, []) thm end; + val thm5 = inst thm2; + val thm6 = inst thm4; +(* val _ = writeln ("THM5 " ^ setmp show_types true string_of_thm thm5); *) +(* val _ = writeln ("THM6 " ^ setmp show_types true string_of_thm thm6); *) + val ct'' = Drule.dest_equals_rhs (Thm.cprop_of thm6); + val cs = fold_aterms (fn Const c => cons c | _ => I) (Thm.term_of ct'') []; + val drop = drop_classes thy tfrees; +(* val _ = writeln "ADD INST"; *) + val funcgr' = ensure_consts thy + (instdefs_of thy (fold (fold (insert (op =)) o insts_of thy funcgr) cs [])) funcgr + in ((ct'', (thm5, drop)), Funcgr.change thy (K funcgr')) end; end; (*local*) @@ -327,7 +390,7 @@ |> Pretty.writeln; fun print_codethms thy consts = - mk_funcgr thy consts [] |> print_funcgr thy; + make thy consts |> print_funcgr thy; fun print_codethms_e thy cs = print_codethms thy (map (CodegenConsts.read_const thy) cs);