clarified code for building function equation system; explicit check of type discipline
authorhaftmann
Wed, 15 Nov 2006 17:05:46 +0100
changeset 21387 5d3d340cb783
parent 21386 a80a35d67cf1
child 21388 9d8344cf029f
clarified code for building function equation system; explicit check of type discipline
src/Pure/Tools/codegen_funcgr.ML
--- a/src/Pure/Tools/codegen_funcgr.ML	Wed Nov 15 17:05:45 2006 +0100
+++ b/src/Pure/Tools/codegen_funcgr.ML	Wed Nov 15 17:05:46 2006 +0100
@@ -187,9 +187,7 @@
 
 fun all_classops thy tyco class =
   AxClass.params_of thy class
-(*   |> tap (fn _ => writeln ("INST " ^ tyco ^ " - " ^ class))  *)
   |> AList.make (fn c => CodegenConsts.disc_typ_of_classop thy (c, [Type (tyco, [])]))
-        (*typ_of_classop is very liberal in its type arguments*)
   |> map (CodegenConsts.norm_of_typ thy);
 
 fun instdefs_of thy insts =
@@ -230,31 +228,35 @@
     |> pair (SOME c)
   end;
 
+exception INVALID of CodegenConsts.const list * string;
+
 fun specialize_typs thy funcgr eqss =
   let
-    fun max k [] = k
-      | max k (l::ls) = max (if k < l then l else k) ls;
+    fun max [] = 0
+      | max [l] = l
+      | max (k::l::ls) = max ((if k < l then l else k) :: ls);
     fun typscheme_of (c, ty) =
       try (Constgraph.get_node funcgr) (CodegenConsts.norm_of_typ thy (c, ty))
       |> Option.map fst;
-    fun incr_indices (c:'a, thms) maxidx =
+    fun incr_indices (c, thms) maxidx =
       let
-        val thms' = map (Thm.incr_indexes maxidx) thms;
-        val maxidx' = Int.max
-          (maxidx, max ~1 (map Thm.maxidx_of thms') + 1);
+        val thms' = map (Thm.incr_indexes (maxidx + 1)) thms;
+        val maxidx' = max (maxidx :: map Thm.maxidx_of thms');
       in ((c, thms'), maxidx') end;
-    val tsig = Sign.tsig_of thy;
-    fun unify_const thms (c, ty) (env, maxidx) =
+    val (eqss', maxidx) =
+      fold_map incr_indices eqss 0;
+    fun unify_const (c, ty) (env, maxidx) =
       case typscheme_of (c, ty)
        of SOME ty_decl => let
-            val ty_decl' = Logic.incr_tvar maxidx ty_decl;
-            val maxidx' = Int.max (Term.maxidx_of_typ ty_decl' + 1, maxidx);
-          in Type.unify tsig (ty_decl', ty) (env, maxidx')
-          handle TUNIFY => setmp show_sorts true error ("Failed to instantiate\n"
+            val ty_decl' = Logic.incr_tvar (maxidx + 1) ty_decl;
+            val maxidx' = max [maxidx, Term.maxidx_of_typ ty_decl'];
+          in Type.unify (Sign.tsig_of thy) (ty_decl', ty) (env, maxidx')
+          handle TUNIFY => raise INVALID ([], setmp show_sorts true (setmp show_types true (fn f => f ())) (fn _ => ("Failed to instantiate\n"
             ^ (Sign.string_of_typ thy o Envir.norm_type env) ty_decl' ^ "\nto\n"
-            ^ (Sign.string_of_typ thy o Envir.norm_type env) ty ^ ",\n"
-            ^ "in function theorems\n"
-            ^ cat_lines (map string_of_thm thms))
+            ^ (Sign.string_of_typ thy o Envir.norm_type env) ty
+            ^ ",\nfor constant " ^ quote c
+            ^ "\nin function theorems\n"
+            ^ (cat_lines o maps (map (Sign.string_of_term thy o map_types (Envir.norm_type env) o Thm.prop_of) o snd)) eqss')))
           end
         | NONE => (env, maxidx);
     fun apply_unifier unif (c, []) = (c, [])
@@ -269,20 +271,25 @@
                   TVar (v, (snd o dest_TVar o Envir.norm_type unif) ty))
               end;
             val instmap = map mk_inst tvars;
-            val (thms' as thm' :: _) = map (Drule.zero_var_indexes o Thm.instantiate (instmap, [])) thms
-            val _ = if fst c <> "" andalso not (Sign.typ_equiv thy (Type.strip_sorts (CodegenData.typ_func thy thm), Type.strip_sorts (CodegenData.typ_func thy thm')))
-              then error ("illegal function type instantiation:\n" ^ Sign.string_of_typ thy (CodegenData.typ_func thy thm)
-                ^ "\nto" ^ Sign.string_of_typ thy (CodegenData.typ_func thy thm'))
-              else ();
+            val (thms' as thm' :: _) = map (Drule.zero_var_indexes o Thm.instantiate (instmap, [])) thms;
+            val (ty, ty') = pairself (CodegenData.typ_func thy) (thm, thm');
+            val _ = if fst c = ""
+              orelse (is_none o AxClass.class_of_param thy o fst) c andalso
+                Sign.typ_equiv thy (Type.strip_sorts ty, Type.strip_sorts ty')
+              orelse Sign.typ_equiv thy (ty, ty')
+              then ()
+              else raise INVALID ([], "illegal function type instantiation:\n" ^ Sign.string_of_typ thy (CodegenData.typ_func thy thm)
+                ^ "\nto " ^ Sign.string_of_typ thy (CodegenData.typ_func thy thm')
+                ^ ",\nfor constant " ^ CodegenConsts.string_of_const thy c
+                ^ "\nin function theorems\n"
+                ^ (cat_lines o map string_of_thm) thms)
           in (c, thms') end;
     fun rhs_of' thy (("", []), thms as [_]) =
           add_things_of thy (cons o snd) (NONE, thms) []
       | rhs_of' thy (c, thms) =
           add_things_of thy (cons o snd) (SOME c, thms) [];
-    val (eqss', maxidx) =
-      fold_map incr_indices eqss 0;
     val (unif, _) =
-      fold (fn (c, thms) => fold (unify_const thms) (rhs_of' thy (c, thms)))
+      fold (fn (c, thms) => fold unify_const (rhs_of' thy (c, thms)))
         eqss' (Vartab.empty, maxidx);
     val eqss'' =
       map (apply_unifier unif) eqss';
@@ -309,14 +316,13 @@
   fold (snd oo ensure_const thy funcgr) cs Constgraph.empty
   |> (fn auxgr => fold (merge_new_eqsyss thy)
        (map (AList.make (Constgraph.get_node auxgr))
-       (rev (Constgraph.strong_conn auxgr))) funcgr);
+       (rev (Constgraph.strong_conn auxgr))) funcgr)
+  handle INVALID (cs', msg) => raise INVALID (cs @ cs', msg);
 
 fun drop_classes thy tfrees thm =
   let
-(*     val _ = writeln ("DROP1 " ^ setmp show_types true string_of_thm thm);  *)
     val (_, thm') = Thm.varifyT' [] thm;
     val tvars = Term.add_tvars (Thm.prop_of thm') [];
-(*     val _ = writeln ("DROP2 " ^ setmp show_types true string_of_thm thm');  *)
     val unconstr = map (Thm.ctyp_of thy o TVar) tvars;
     val instmap = map2 (fn (v_i, _) => fn (v, sort) => pairself (Thm.ctyp_of thy)
       (TVar (v_i, []), TFree (v, sort))) tvars tfrees;
@@ -325,12 +331,13 @@
     |> fold Thm.unconstrainT unconstr
     |> Thm.instantiate (instmap, [])
     |> Tactic.rule_by_tactic ((REPEAT o CHANGED o ALLGOALS o Tactic.resolve_tac) (AxClass.class_intros thy))
-(*     |> tap (fn thm => writeln ("DROP3 " ^ setmp show_types true string_of_thm thm))  *)
   end;
 
 in
 
-val ensure_consts = ensure_consts;
+val ensure_consts = (fn thy => fn cs => fn funcgr => ensure_consts thy
+  cs funcgr handle INVALID (cs', msg) => error (msg ^ ",\nwhile preprocessing equations for constant(s) "
+    ^ commas (map (CodegenConsts.string_of_const thy) cs')));
 
 fun make thy consts =
   Funcgr.change thy (ensure_consts thy consts);
@@ -340,35 +347,24 @@
     val _ = Sign.no_vars (Sign.pp thy) (Thm.term_of ct);
     val _ = Term.fold_types (Type.no_tvars #> K I) (Thm.term_of ct) ();
     val thm1 = CodegenData.preprocess_cterm thy ct;
-(*     val _ = writeln ("THM1 " ^ setmp show_types true string_of_thm thm1);  *)
     val ct' = Drule.dest_equals_rhs (Thm.cprop_of thm1);
     val consts = CodegenConsts.consts_of thy (Thm.term_of ct');
     val funcgr = make thy consts;
     val (_, thm2) = Thm.varifyT' [] thm1;
-(*     val _ = writeln ("THM2 " ^ setmp show_types true string_of_thm thm2);  *)
     val thm3 = Thm.reflexive (Drule.dest_equals_rhs (Thm.cprop_of thm2));
-(*     val _ = writeln ("THM3 " ^ setmp show_types true string_of_thm thm3);  *)
     val [(_, [thm4])] = specialize_typs thy funcgr [(("", []), [thm3])];
-(*     val _ = writeln ("THM4 " ^ setmp show_types true string_of_thm thm4);  *)
     val tfrees = Term.add_tfrees (Thm.prop_of thm1) [];
-(*     val _ = writeln "TFREES";  *)
-(*     val _ = (writeln o cat_lines o map (fn (v, sort) => v ^ "::" ^ Sign.string_of_sort thy sort)) tfrees;  *)
     fun inst thm =
       let
         val tvars = Term.add_tvars (Thm.prop_of thm) [];
-(*         val _ = writeln "TVARS";  *)
-(*         val _ = (writeln o cat_lines o map (fn ((v, i), sort) => v ^ "_" ^ string_of_int i ^ "::" ^ Sign.string_of_sort thy sort)) tvars;  *)
         val instmap = map2 (fn (v_i, sort) => fn (v, _) => pairself (Thm.ctyp_of thy)
           (TVar (v_i, sort), TFree (v, sort))) tvars tfrees;
       in Thm.instantiate (instmap, []) thm end;
     val thm5 = inst thm2;
     val thm6 = inst thm4;
-(*     val _ = writeln ("THM5 " ^ setmp show_types true string_of_thm thm5);  *)
-(*     val _ = writeln ("THM6 " ^ setmp show_types true string_of_thm thm6);  *)
     val ct'' = Drule.dest_equals_rhs (Thm.cprop_of thm6);
     val cs = fold_aterms (fn Const c => cons c | _ => I) (Thm.term_of ct'') [];
     val drop = drop_classes thy tfrees;
-(*     val _ = writeln "ADD INST";  *)
     val funcgr' = ensure_consts thy
       (instdefs_of thy (fold (fold (insert (op =)) o insts_of thy funcgr) cs [])) funcgr
   in (f drop ct'' thm5, Funcgr.change thy (K funcgr')) end;