refined algorithm
authorhaftmann
Fri, 26 Jan 2007 13:59:06 +0100
changeset 22198 226d29db8e0a
parent 22197 461130ccfef4
child 22199 b617ddd200eb
refined algorithm
src/Pure/Tools/codegen_funcgr.ML
--- a/src/Pure/Tools/codegen_funcgr.ML	Fri Jan 26 13:59:04 2007 +0100
+++ b/src/Pure/Tools/codegen_funcgr.ML	Fri Jan 26 13:59:06 2007 +0100
@@ -172,9 +172,7 @@
       | variable (TFree (_, sort)) = map (pair []) sort;
     fun mk_inst ty (TVar (_, sort)) = cons (ty, sort)
       | mk_inst ty (TFree (_, sort)) = cons (ty, sort)
-      | mk_inst (Type (tyco1, tys1)) (Type (tyco2, tys2)) =
-          if tyco1 <> tyco2 then error "bad instance"
-          else fold2 mk_inst tys1 tys2;
+      | mk_inst (Type (_, tys1)) (Type (_, tys2)) = fold2 mk_inst tys1 tys2;
     fun of_sort_deriv (ty, sort) =
       Sorts.of_sort_derivation (Sign.pp thy) algebra
         { classrel = classrel, constructor = constructor, variable = variable }
@@ -192,68 +190,54 @@
 
 exception INVALID of CodegenConsts.const list * string;
 
-fun specialize_typs' thy funcgr funcss =
-  let
-    fun max xs = fold (curry Int.max) xs 0;
-    fun incr_indices (c, thms) maxidx =
+fun resort_thms algebra tap_typ [] = []
+  | resort_thms algebra tap_typ (thms as thm :: _) =
       let
-        val thms' = map (Thm.incr_indexes (maxidx + 1)) thms;
-        val maxidx' = max (maxidx :: map Thm.maxidx_of thms');
-      in ((c, thms'), maxidx') end;
-    val (funcss', maxidx) =
-      fold_map incr_indices funcss 0;
-    fun typ_of_const (c, ty) = case try (CodegenConsts.norm_of_typ thy) (c, ty)
-     of SOME const => Option.map fst (try (Constgraph.get_node funcgr) const)
-      | NONE => NONE;
-    fun unify_const (c, ty) (env, maxidx) =
-      case typ_of_const (c, ty)
-       of SOME ty_decl => let
-            val ty_decl' = Logic.incr_tvar (maxidx + 1) ty_decl;
-            val maxidx' = max [maxidx, Term.maxidx_of_typ ty_decl'];
-          in Type.unify (Sign.tsig_of thy) (ty_decl', ty) (env, maxidx')
-          handle TUNIFY => raise INVALID ([], setmp show_sorts true (setmp show_types true (fn f => f ())) (fn _ => ("Failed to instantiate\n"
-            ^ (Sign.string_of_typ thy o Envir.norm_type env) ty_decl' ^ "\nto\n"
-            ^ (Sign.string_of_typ thy o Envir.norm_type env) ty
-            ^ ",\nfor constant " ^ quote c
-            ^ "\nin function theorems\n"
-            ^ (cat_lines o maps (map (Sign.string_of_term thy o map_types (Envir.norm_type env) o Thm.prop_of) o snd)) funcss')))
-          end
-        | NONE => (env, maxidx);
-    fun apply_unifier unif (c, []) = (c, [])
-      | apply_unifier unif (c, thms as thm :: _) =
+        val thy = Thm.theory_of_thm thm;
+        val cs = fold_consts (insert (op =)) thms [];
+        fun match_const c (ty, ty_decl) =
           let
-            val tvars = Term.add_tvars (Thm.prop_of thm) [];
-            fun mk_inst (v_i_sort as (v, _)) =
-              let
-                val ty = TVar v_i_sort;
-              in
-                pairself (Thm.ctyp_of thy) (ty,
-                  TVar (v, (snd o dest_TVar o Envir.norm_type unif) ty))
-              end;
-            val instmap = map mk_inst tvars;
-            val (thms' as thm' :: _) = map (Drule.zero_var_indexes o Thm.instantiate (instmap, [])) thms;
-            val _ = case c of NONE => ()
-              | SOME c => (case pairself CodegenFunc.typ_func (thm, thm')
-               of (ty, ty') => if (is_none o AxClass.class_of_param thy o fst) c
-                  andalso Sign.typ_equiv thy (Type.strip_sorts ty, Type.strip_sorts ty')
-                  orelse Sign.typ_equiv thy (ty, ty')
-                  then ()
-                  else raise INVALID ([], "illegal function type instantiation:\n" ^ Sign.string_of_typ thy (CodegenFunc.typ_func thm)
-                    ^ "\nto " ^ Sign.string_of_typ thy (CodegenFunc.typ_func thm')
-                    ^ ",\nfor constant " ^ CodegenConsts.string_of_const thy c
-                    ^ "\nin function theorems\n"
-                    ^ (cat_lines o map string_of_thm) thms))
-          in (c, thms') end;
-    val (unif, _) =
-      fold (fn (_, thms) => fold unify_const (fold_consts (insert (op =)) thms []))
-        funcss' (Vartab.empty, maxidx);
-    val funcss'' = map (apply_unifier unif) funcss';
-  in funcss'' end;
+            val tys = CodegenConsts.typargs thy (c, ty);
+            val sorts = map (snd o dest_TVar) (CodegenConsts.typargs thy (c, ty_decl));
+          in fold2 (curry (CodegenConsts.typ_sort_inst algebra)) tys sorts end;
+        fun match (c_ty as (c, ty)) =
+          case tap_typ c_ty
+           of SOME ty_decl => match_const c (ty, ty_decl)
+            | NONE => I;
+        val tvars = fold match cs Vartab.empty;
+      in map (CodegenFunc.inst_thm tvars) thms end;
 
-fun specialize_typs thy funcgr =
-  (map o apfst) SOME
-  #> specialize_typs' thy funcgr
-  #> (map o apfst) the;
+fun resort_funcss thy algebra funcgr =
+  let
+    val typ_funcgr = try (fst o Constgraph.get_node funcgr o CodegenConsts.norm_of_typ thy);
+    fun resort_dep (const, thms) = (const, resort_thms algebra typ_funcgr thms)
+      handle Sorts.CLASS_ERROR e => raise INVALID ([const], Sorts.msg_class_error (Sign.pp thy) e
+                    ^ ",\nfor constant " ^ CodegenConsts.string_of_const thy const
+                    ^ "\nin defining equations\n"
+                    ^ (cat_lines o map string_of_thm) thms)
+    fun resort_rec tap_typ (const, []) = (true, (const, []))
+      | resort_rec tap_typ (const, thms as thm :: _) =
+          let
+            val ty = CodegenFunc.typ_func thm;
+            val thms' as thm' :: _ = resort_thms algebra tap_typ thms
+            val ty' = CodegenFunc.typ_func thm';
+          in (Sign.typ_equiv thy (ty, ty'), (const, thms')) end;
+    fun resort_recs funcss =
+      let
+        fun tap_typ c_ty = case try (CodegenConsts.norm_of_typ thy) c_ty
+         of SOME const => AList.lookup (CodegenConsts.eq_const) funcss const
+              |> these
+              |> try hd
+              |> Option.map CodegenFunc.typ_func
+          | NONE => NONE;
+        val (unchangeds, funcss') = split_list (map (resort_rec tap_typ) funcss);
+        val unchanged = fold (fn x => fn y => x andalso y) unchangeds true;
+      in (unchanged, funcss') end;
+    fun resort_rec_until funcss =
+      let
+        val (unchanged, funcss') = resort_recs funcss;
+      in if unchanged then funcss' else resort_rec_until funcss' end;
+  in map resort_dep #> resort_rec_until end;
 
 fun classop_const thy algebra class classop tyco =
   let
@@ -320,32 +304,44 @@
 
 fun merge_funcss thy algebra raw_funcss funcgr =
   let
-    val funcss = specialize_typs thy funcgr raw_funcss;
+    val funcss = resort_funcss thy algebra funcgr raw_funcss;
+    fun classop_typ (c, [typarg]) class =
+      let
+        val ty = Sign.the_const_type thy c;
+        val inst = case typarg
+         of Type (tyco, _) => classop_const thy algebra class c tyco
+              |> snd
+              |> the_single
+              |> Logic.varifyT
+          | _ => TVar (("'a", 0), [class]);
+      in Term.map_type_tvar (K inst) ty end;
     fun default_typ (const as (c, tys)) = case CodegenData.tap_typ thy const
      of SOME ty => ty
-      | NONE => let
-          val ty = Sign.the_const_type thy c
-        in case AxClass.class_of_param thy c
-         of SOME class => let
-               val inst = case tys
-                of [Type (tyco, _)] => classop_const thy algebra class c tyco
-                      |> snd
-                      |> the_single
-                      |> Logic.varifyT
-                 | _ => TVar (("'a", 0), [class]);
-              in Term.map_type_tvar (K inst) ty end
-          | NONE => ty
-        end;
+      | NONE => (case AxClass.class_of_param thy c
+         of SOME class => classop_typ const class
+          | NONE => Sign.the_const_type thy c)
+    fun typ_func (const as (c, tys)) thm =
+      let
+        val ty = CodegenFunc.typ_func thm;
+      in case AxClass.class_of_param thy c
+       of SOME class => (case tys
+           of [Type _] => let val ty_decl = classop_typ const class
+              in if Sign.typ_equiv thy (ty, ty_decl) then ty
+              else raise raise INVALID ([const], "Illegal instantation for class operation "
+                    ^ CodegenConsts.string_of_const thy const
+                    ^ ":\n" ^ CodegenConsts.string_of_typ thy ty_decl
+                    ^ "\nto " ^ CodegenConsts.string_of_typ thy ty)
+              end
+            | _ => ty)
+        | NONE => ty
+      end;
     fun add_funcs (const, thms as thm :: _) =
-          Constgraph.new_node (const, (CodegenFunc.typ_func thm, thms))
+          Constgraph.new_node (const, (typ_func const thm, thms))
       | add_funcs (const, []) =
           Constgraph.new_node (const, (default_typ const, []));
-    val _ = writeln ("constants " ^ (commas o map (CodegenConsts.string_of_const thy o fst)) funcss);
-    val _ = writeln ("funcs " ^ (cat_lines o maps (map string_of_thm o snd)) funcss);
     fun add_deps (funcs as (const, thms)) funcgr =
       let
         val deps = consts_of funcs;
-        val _ = writeln ("constant " ^ CodegenConsts.string_of_const thy const);
         val insts = instances_of_consts thy algebra funcgr
           (fold_consts (insert (op =)) thms []);
       in
@@ -413,9 +409,11 @@
     val ct' = Drule.dest_equals_rhs (Thm.cprop_of thm1);
     val consts = CodegenConsts.consts_of thy (Thm.term_of ct');
     val funcgr = make thy consts;
+    val algebra = CodegenData.coregular_algebra thy;
     val (_, thm2) = Thm.varifyT' [] thm1;
     val thm3 = Thm.reflexive (Drule.dest_equals_rhs (Thm.cprop_of thm2));
-    val [(_, [thm4])] = specialize_typs' thy funcgr [(NONE, [thm3])];
+    val typ_funcgr = try (fst o Constgraph.get_node funcgr o CodegenConsts.norm_of_typ thy);
+    val [thm4] = resort_thms algebra typ_funcgr [thm3];
     val tfrees = Term.add_tfrees (Thm.prop_of thm1) [];
     fun inst thm =
       let
@@ -428,7 +426,6 @@
     val ct'' = Drule.dest_equals_rhs (Thm.cprop_of thm6);
     val cs = fold_aterms (fn Const c => cons c | _ => I) (Thm.term_of ct'') [];
     val drop = drop_classes thy tfrees;
-    val algebra = CodegenData.coregular_algebra thy;
     val instdefs = instances_of_consts thy algebra funcgr cs;
     val funcgr' = ensure_consts thy instdefs funcgr;
   in (f drop ct'' thm5, Funcgr.change thy (K funcgr')) end;