diff -r c3844c4d0c2c -r a9742afd403e src/Tools/Code/code_preproc.ML --- a/src/Tools/Code/code_preproc.ML Tue Jul 07 17:21:26 2009 +0200 +++ b/src/Tools/Code/code_preproc.ML Tue Jul 07 17:21:27 2009 +0200 @@ -102,6 +102,15 @@ fun rhs_conv conv thm = Thm.transitive thm ((conv o Thm.rhs_of) thm); +fun eqn_conv conv = + let + fun lhs_conv ct = if can Thm.dest_comb ct + then Conv.combination_conv lhs_conv conv ct + else Conv.all_conv ct; + in Conv.combination_conv (Conv.arg_conv lhs_conv) conv end; + +val rewrite_eqn = Conv.fconv_rule o eqn_conv o Simplifier.rewrite; + fun term_of_conv thy f = Thm.cterm_of thy #> f @@ -117,7 +126,7 @@ in eqns |> apply_functrans thy c functrans - |> (map o apfst) (Code.rewrite_eqn pre) + |> (map o apfst) (rewrite_eqn pre) |> (map o apfst) (AxClass.unoverload thy) |> map (Code.assert_eqn thy) |> burrow_fst (Code.norm_args thy) @@ -213,9 +222,19 @@ (fn Const (c, ty) => insert (op =) (c, Sign.const_typargs thy (c, Logic.unvarifyT ty)) | _ => I) (map (op :: o swap o apfst (snd o strip_comb) o Logic.dest_equals o Thm.plain_prop_of o fst) eqns); +fun default_typscheme_of thy c = + let + val ty = (snd o dest_Const o TermSubst.zero_var_indexes o curry Const c + o Type.strip_sorts o Sign.the_const_type thy) c; + in case AxClass.class_of_param thy c + of SOME class => ([(Name.aT, [class])], ty) + | NONE => Code.typscheme thy (c, ty) + end; + fun tyscm_rhss_of thy c eqns = let - val tyscm = case eqns of [] => Code.default_typscheme thy c + val tyscm = case eqns + of [] => default_typscheme_of thy c | ((thm, _) :: _) => Code.typscheme_eqn thy thm; val rhss = consts_of thy eqns; in (tyscm, rhss) end; @@ -381,6 +400,17 @@ handle Sorts.CLASS_ERROR _ => [] (*permissive!*)) end; +fun inst_thm thy tvars' thm = + let + val tvars = (Term.add_tvars o Thm.prop_of) thm []; + val inter_sort = Sorts.inter_sort (Sign.classes_of thy); + fun mk_inst (tvar as (v, sort)) = case Vartab.lookup tvars' v + of SOME sort' => SOME (pairself (Thm.ctyp_of thy o TVar) + (tvar, (v, inter_sort (sort, sort')))) + | NONE => NONE; + val insts = map_filter mk_inst tvars; + in Thm.instantiate (insts, []) thm end; + fun add_arity thy vardeps (class, tyco) = AList.default (op =) ((class, tyco), map (fn k => (snd o Vargraph.get_node vardeps) (Inst (class, tyco), k)) @@ -394,7 +424,7 @@ val inst_tab = Vartab.empty |> fold (fn (v, sort) => Vartab.update ((v, 0), sort)) lhs; val eqns = proto_eqns - |> (map o apfst) (Code.inst_thm thy inst_tab); + |> (map o apfst) (inst_thm thy inst_tab); val (tyscm, rhss') = tyscm_rhss_of thy c eqns; val eqngr' = Graph.new_node (c, (tyscm, eqns)) eqngr; in (map (pair c) rhss' @ rhss, eqngr') end;