src/Tools/Code/code_preproc.ML
changeset 31957 a9742afd403e
parent 31775 2b04504fcb69
child 31962 baa8dce5bc45
--- a/src/Tools/Code/code_preproc.ML	Tue Jul 07 17:21:26 2009 +0200
+++ b/src/Tools/Code/code_preproc.ML	Tue Jul 07 17:21:27 2009 +0200
@@ -102,6 +102,15 @@
 
 fun rhs_conv conv thm = Thm.transitive thm ((conv o Thm.rhs_of) thm);
 
+fun eqn_conv conv =
+  let
+    fun lhs_conv ct = if can Thm.dest_comb ct
+      then Conv.combination_conv lhs_conv conv ct
+      else Conv.all_conv ct;
+  in Conv.combination_conv (Conv.arg_conv lhs_conv) conv end;
+
+val rewrite_eqn = Conv.fconv_rule o eqn_conv o Simplifier.rewrite;
+
 fun term_of_conv thy f =
   Thm.cterm_of thy
   #> f
@@ -117,7 +126,7 @@
   in
     eqns
     |> apply_functrans thy c functrans
-    |> (map o apfst) (Code.rewrite_eqn pre)
+    |> (map o apfst) (rewrite_eqn pre)
     |> (map o apfst) (AxClass.unoverload thy)
     |> map (Code.assert_eqn thy)
     |> burrow_fst (Code.norm_args thy)
@@ -213,9 +222,19 @@
   (fn Const (c, ty) => insert (op =) (c, Sign.const_typargs thy (c, Logic.unvarifyT ty)) | _ => I)
     (map (op :: o swap o apfst (snd o strip_comb) o Logic.dest_equals o Thm.plain_prop_of o fst) eqns);
 
+fun default_typscheme_of thy c =
+  let
+    val ty = (snd o dest_Const o TermSubst.zero_var_indexes o curry Const c
+      o Type.strip_sorts o Sign.the_const_type thy) c;
+  in case AxClass.class_of_param thy c
+   of SOME class => ([(Name.aT, [class])], ty)
+    | NONE => Code.typscheme thy (c, ty)
+  end;
+
 fun tyscm_rhss_of thy c eqns =
   let
-    val tyscm = case eqns of [] => Code.default_typscheme thy c
+    val tyscm = case eqns
+     of [] => default_typscheme_of thy c
       | ((thm, _) :: _) => Code.typscheme_eqn thy thm;
     val rhss = consts_of thy eqns;
   in (tyscm, rhss) end;
@@ -381,6 +400,17 @@
        handle Sorts.CLASS_ERROR _ => [] (*permissive!*))
   end;
 
+fun inst_thm thy tvars' thm =
+  let
+    val tvars = (Term.add_tvars o Thm.prop_of) thm [];
+    val inter_sort = Sorts.inter_sort (Sign.classes_of thy);
+    fun mk_inst (tvar as (v, sort)) = case Vartab.lookup tvars' v
+     of SOME sort' => SOME (pairself (Thm.ctyp_of thy o TVar)
+          (tvar, (v, inter_sort (sort, sort'))))
+      | NONE => NONE;
+    val insts = map_filter mk_inst tvars;
+  in Thm.instantiate (insts, []) thm end;
+
 fun add_arity thy vardeps (class, tyco) =
   AList.default (op =)
     ((class, tyco), map (fn k => (snd o Vargraph.get_node vardeps) (Inst (class, tyco), k))
@@ -394,7 +424,7 @@
     val inst_tab = Vartab.empty |> fold (fn (v, sort) =>
       Vartab.update ((v, 0), sort)) lhs;
     val eqns = proto_eqns
-      |> (map o apfst) (Code.inst_thm thy inst_tab);
+      |> (map o apfst) (inst_thm thy inst_tab);
     val (tyscm, rhss') = tyscm_rhss_of thy c eqns;
     val eqngr' = Graph.new_node (c, (tyscm, eqns)) eqngr;
   in (map (pair c) rhss' @ rhss, eqngr') end;