--- a/src/Tools/Code/code_preproc.ML Tue Jul 07 17:21:26 2009 +0200
+++ b/src/Tools/Code/code_preproc.ML Tue Jul 07 17:21:27 2009 +0200
@@ -102,6 +102,15 @@
fun rhs_conv conv thm = Thm.transitive thm ((conv o Thm.rhs_of) thm);
+fun eqn_conv conv =
+ let
+ fun lhs_conv ct = if can Thm.dest_comb ct
+ then Conv.combination_conv lhs_conv conv ct
+ else Conv.all_conv ct;
+ in Conv.combination_conv (Conv.arg_conv lhs_conv) conv end;
+
+val rewrite_eqn = Conv.fconv_rule o eqn_conv o Simplifier.rewrite;
+
fun term_of_conv thy f =
Thm.cterm_of thy
#> f
@@ -117,7 +126,7 @@
in
eqns
|> apply_functrans thy c functrans
- |> (map o apfst) (Code.rewrite_eqn pre)
+ |> (map o apfst) (rewrite_eqn pre)
|> (map o apfst) (AxClass.unoverload thy)
|> map (Code.assert_eqn thy)
|> burrow_fst (Code.norm_args thy)
@@ -213,9 +222,19 @@
(fn Const (c, ty) => insert (op =) (c, Sign.const_typargs thy (c, Logic.unvarifyT ty)) | _ => I)
(map (op :: o swap o apfst (snd o strip_comb) o Logic.dest_equals o Thm.plain_prop_of o fst) eqns);
+fun default_typscheme_of thy c =
+ let
+ val ty = (snd o dest_Const o TermSubst.zero_var_indexes o curry Const c
+ o Type.strip_sorts o Sign.the_const_type thy) c;
+ in case AxClass.class_of_param thy c
+ of SOME class => ([(Name.aT, [class])], ty)
+ | NONE => Code.typscheme thy (c, ty)
+ end;
+
fun tyscm_rhss_of thy c eqns =
let
- val tyscm = case eqns of [] => Code.default_typscheme thy c
+ val tyscm = case eqns
+ of [] => default_typscheme_of thy c
| ((thm, _) :: _) => Code.typscheme_eqn thy thm;
val rhss = consts_of thy eqns;
in (tyscm, rhss) end;
@@ -381,6 +400,17 @@
handle Sorts.CLASS_ERROR _ => [] (*permissive!*))
end;
+fun inst_thm thy tvars' thm =
+ let
+ val tvars = (Term.add_tvars o Thm.prop_of) thm [];
+ val inter_sort = Sorts.inter_sort (Sign.classes_of thy);
+ fun mk_inst (tvar as (v, sort)) = case Vartab.lookup tvars' v
+ of SOME sort' => SOME (pairself (Thm.ctyp_of thy o TVar)
+ (tvar, (v, inter_sort (sort, sort'))))
+ | NONE => NONE;
+ val insts = map_filter mk_inst tvars;
+ in Thm.instantiate (insts, []) thm end;
+
fun add_arity thy vardeps (class, tyco) =
AList.default (op =)
((class, tyco), map (fn k => (snd o Vargraph.get_node vardeps) (Inst (class, tyco), k))
@@ -394,7 +424,7 @@
val inst_tab = Vartab.empty |> fold (fn (v, sort) =>
Vartab.update ((v, 0), sort)) lhs;
val eqns = proto_eqns
- |> (map o apfst) (Code.inst_thm thy inst_tab);
+ |> (map o apfst) (inst_thm thy inst_tab);
val (tyscm, rhss') = tyscm_rhss_of thy c eqns;
val eqngr' = Graph.new_node (c, (tyscm, eqns)) eqngr;
in (map (pair c) rhss' @ rhss, eqngr') end;