first working version
authorhaftmann
Wed, 18 Feb 2009 08:23:12 +0100
changeset 29963 590e0db3a267
parent 29962 bd4dc7fa742d
child 29964 be317a8a50a8
first working version
src/Tools/code/code_funcgr_new.ML
--- a/src/Tools/code/code_funcgr_new.ML	Wed Feb 18 08:23:11 2009 +0100
+++ b/src/Tools/code/code_funcgr_new.ML	Wed Feb 18 08:23:12 2009 +0100
@@ -1,9 +1,8 @@
 (*  Title:      Tools/code/code_funcgr.ML
-    ID:         $Id$
     Author:     Florian Haftmann, TU Muenchen
 
-Retrieving, well-sorting and structuring defining equations in graph
-with explicit dependencies.
+Retrieving, well-sorting and structuring code equations in graph
+with explicit dependencies -- the waisenhaus algorithm.
 *)
 
 signature CODE_FUNCGR =
@@ -28,12 +27,8 @@
 
 type T = (((string * sort) list * typ) * (thm * bool) list) Graph.T;
 
-fun eqns funcgr =
-  these o Option.map snd o try (Graph.get_node funcgr);
-
-fun typ funcgr =
-  fst o Graph.get_node funcgr;
-
+fun eqns funcgr = these o Option.map snd o try (Graph.get_node funcgr);
+fun typ funcgr = fst o Graph.get_node funcgr;
 fun all funcgr = Graph.keys funcgr;
 
 fun pretty thy funcgr =
@@ -48,23 +43,22 @@
   |> Pretty.chunks;
 
 
-(** generic combinators **)
-
-fun fold_consts f thms =
-  thms
-  |> maps (op :: o swap o apfst (snd o strip_comb) o Logic.dest_equals o Thm.plain_prop_of)
-  |> (fold o fold_aterms) (fn Const c => f c | _ => I);
-
-fun consts_of (const, []) = []
-  | consts_of (const, thms as _ :: _) = 
-      let
-        fun the_const (c, _) = if c = const then I else insert (op =) c
-      in fold_consts the_const (map fst thms) [] end;
-
-
 (** graph algorithm **)
 
-(* some nonsense -- FIXME *)
+(* generic *)
+
+fun tracing s = (if ! Toplevel.debug then Output.tracing s else ());
+
+fun complete_proper_sort thy =
+  Sign.complete_sort thy #> filter (can (AxClass.get_info thy));
+
+fun inst_params thy tyco class =
+  map (fn (c, _) => AxClass.param_of_inst thy (c, tyco))
+    ((#params o AxClass.get_info thy) class);
+
+fun consts_of thy eqns = [] |> (fold o fold o fold_aterms)
+  (fn Const (c, ty) => insert (op =) (c, Sign.const_typargs thy (c, Logic.unvarifyT ty)) | _ => I)
+    (map (op :: o swap o apfst (snd o strip_comb) o Logic.dest_equals o Thm.plain_prop_of o fst) eqns);
 
 fun lhs_rhss_of thy c =
   let
@@ -73,33 +67,9 @@
       |> burrow_fst (Code_Unit.norm_varnames thy Code_Name.purify_tvar Code_Name.purify_var);
     val (lhs, _) = case eqns of [] => Code.default_typscheme thy c
       | ((thm, _) :: _) => (snd o Code_Unit.head_eqn thy) thm;
-    val rhss = fold_consts (fn (c, ty) =>
-      insert (op =) (c, Sign.const_typargs thy (c, Logic.unvarifyT ty))) (map fst eqns) [];
+    val rhss = consts_of thy eqns;
   in (lhs, rhss) end;
 
-fun inst_params thy tyco class =
-  map (fn (c, _) => AxClass.param_of_inst thy (c, tyco))
-    ((#params o AxClass.get_info thy) class);
-
-fun complete_proper_sort thy sort =
-  Sign.complete_sort thy sort |> filter (can (AxClass.get_info thy));
-
-fun minimal_proper_sort thy sort =
-  complete_proper_sort thy sort |> Sign.minimize_sort thy;
-
-fun dicts_of thy algebra (T, sort) =
-  let
-    fun class_relation (x, _) _ = x;
-    fun type_constructor tyco xs class =
-      inst_params thy tyco class @ (maps o maps) fst xs;
-    fun type_variable (TFree (_, sort)) = map (pair []) sort;
-  in
-    flat (Sorts.of_sort_derivation (Syntax.pp_global thy) algebra
-      { class_relation = class_relation, type_constructor = type_constructor,
-        type_variable = type_variable } (T, minimal_proper_sort thy sort)
-       handle Sorts.CLASS_ERROR _ => [] (*permissive!*))
-  end;
-
 
 (* data structures *)
 
@@ -179,6 +149,7 @@
     val classess = map (complete_proper_sort thy)
       (Sign.arity_sorts thy tyco [class]);
     val inst_params = inst_params thy tyco class;
+    (*FIXME also consider existing things here*)
   in
     vardeps
     |> fold (fn superclass => assert thy (Inst (superclass, tyco))) superclasses
@@ -199,6 +170,7 @@
   let
     val _ = tracing "add_const";
     val (lhs, rhss) = lhs_rhss_of thy c;
+    (*FIXME build lhs_rhss_of such that it points to existing graph if possible*)
     fun styp_of (Type (tyco, tys)) = Tyco (tyco, map styp_of tys)
       | styp_of (TFree (v, _)) = Var (Fun c, find_index (fn (v', _) => v = v') lhs);
     val rhss' = (map o apsnd o map) styp_of rhss;
@@ -220,33 +192,62 @@
 
 (* applying instantiations *)
 
+fun dicts_of thy (proj_sort, algebra) (T, sort) =
+  let
+    fun class_relation (x, _) _ = x;
+    fun type_constructor tyco xs class =
+      inst_params thy tyco class @ (maps o maps) fst xs;
+    fun type_variable (TFree (_, sort)) = map (pair []) (proj_sort sort);
+  in
+    flat (Sorts.of_sort_derivation (Syntax.pp_global thy) algebra
+      { class_relation = class_relation, type_constructor = type_constructor,
+        type_variable = type_variable } (T, proj_sort sort)
+       handle Sorts.CLASS_ERROR _ => [] (*permissive!*))
+  end;
+
+fun instances_of (*FIXME move to sorts.ML*) algebra =
+  let
+    val { classes, arities } = Sorts.rep_algebra algebra;
+    val sort_classes = fn cs => filter (member (op = o apsnd fst) cs)
+      (flat (rev (Graph.strong_conn classes)));
+  in
+    Symtab.fold (fn (a, cs) => append ((map (pair a) o sort_classes) cs))
+      arities []
+  end;
+
 fun algebra_of thy vardeps =
   let
     val pp = Syntax.pp_global thy;
     val thy_algebra = Sign.classes_of thy;
     val is_proper = can (AxClass.get_info thy);
-    val arities = Vargraph.fold (fn ((Fun _, _), _) => I
+    val classrels = Sorts.classrels_of thy_algebra
+      |> filter (is_proper o fst)
+      |> (map o apsnd) (filter is_proper);
+    val instances = instances_of thy_algebra
+      |> filter (is_proper o snd);
+    fun add_class (class, superclasses) algebra =
+      Sorts.add_class pp (class, Sorts.minimize_sort algebra superclasses) algebra;
+    val arity_constraints = Vargraph.fold (fn ((Fun _, _), _) => I
       | ((Inst (class, tyco), k), ((_, classes), _)) =>
           AList.map_default (op =)
             ((tyco, class), replicate (Sign.arity_number thy tyco) [])
               (nth_map k (K classes))) vardeps [];
-    val classrels = Sorts.classrels_of thy_algebra
-      |> filter (is_proper o fst)
-      |> (map o apsnd) (filter is_proper);
-    fun add_arity (tyco, class) = case AList.lookup (op =) arities (tyco, class)
-     of SOME sorts => Sorts.add_arities pp (tyco, [(class, sorts)])
-      | NONE => if Sign.arity_number thy tyco = 0
-          then (tracing (tyco ^ "::" ^ class); Sorts.add_arities pp (tyco, [(class, [])]))
-          else I;
-    val instances = Sorts.instances_of thy_algebra
-      |> filter (is_proper o snd)
+    fun add_arity (tyco, class) algebra =
+      case AList.lookup (op =) arity_constraints (tyco, class)
+       of SOME sorts => (tracing (Pretty.output (Syntax.pretty_arity (ProofContext.init thy)
+              (tyco, sorts, [class])));
+            Sorts.add_arities pp
+              (tyco, [(class, map (Sorts.minimize_sort algebra) sorts)]) algebra)
+        | NONE => if Sign.arity_number thy tyco = 0
+            then Sorts.add_arities pp (tyco, [(class, [])]) algebra
+            else algebra;
   in
     Sorts.empty_algebra
-    |> fold (Sorts.add_class pp) classrels
+    |> fold add_class classrels
     |> fold add_arity instances
   end;
 
-fun add_eqs thy algebra vardeps c gr =
+fun add_eqs thy (proj_sort, algebra) vardeps c gr =
   let
     val eqns = Code.these_eqns thy c
       |> burrow_fst (Code_Unit.norm_args thy)
@@ -260,28 +261,27 @@
     val tyscm = case eqns' of [] => Code.default_typscheme thy c
       | ((thm, _) :: _) => (snd o Code_Unit.head_eqn thy) thm;
     val _ = tracing ("tyscm " ^ makestring (map snd (fst tyscm)));
-    val rhss = fold_consts (fn (c, ty) =>
-      insert (op =) (c, Sign.const_typargs thy (c, Logic.unvarifyT ty))) (map fst eqns') [];
+    val rhss = consts_of thy eqns';
   in
     gr
     |> Graph.new_node (c, (tyscm, eqns'))
-    |> fold (fn (c', Ts) => ensure_eqs_dep thy algebra vardeps c c'
+    |> fold (fn (c', Ts) => ensure_eqs_dep thy (proj_sort, algebra) vardeps c c'
         #-> (fn (vs, _) =>
-          fold2 (ensure_match thy algebra vardeps c) Ts (map snd vs))) rhss
+          fold2 (ensure_match thy (proj_sort, algebra) vardeps c) Ts (map snd vs))) rhss
     |> pair tyscm
   end
-and ensure_match thy algebra vardeps c T sort gr =
+and ensure_match thy (proj_sort, algebra) vardeps c T sort gr =
   gr
-  |> fold (fn c' => ensure_eqs_dep thy algebra vardeps c c' #> snd)
-       (dicts_of thy algebra (T, sort))
-and ensure_eqs_dep thy algebra vardeps c c' gr =
+  |> fold (fn c' => ensure_eqs_dep thy (proj_sort, algebra) vardeps c c' #> snd)
+       (dicts_of thy (proj_sort, algebra) (T, proj_sort sort))
+and ensure_eqs_dep thy (proj_sort, algebra) vardeps c c' gr =
   gr
-  |> ensure_eqs thy algebra vardeps c'
+  |> ensure_eqs thy (proj_sort, algebra) vardeps c'
   ||> Graph.add_edge (c, c')
-and ensure_eqs thy algebra vardeps c gr =
+and ensure_eqs thy (proj_sort, algebra) vardeps c gr =
   case try (Graph.get_node gr) c
    of SOME (tyscm, _) => (tyscm, gr)
-    | NONE => add_eqs thy algebra vardeps c gr;
+    | NONE => add_eqs thy (proj_sort, algebra) vardeps c gr;
 
 fun extend_graph thy cs gr =
   let
@@ -291,13 +291,10 @@
     val _ = tracing "obtaining algebra";
     val algebra = algebra_of thy vardeps;
     val _ = tracing "obtaining equations";
-    val (_, gr) = fold_map (ensure_eqs thy algebra vardeps) cs gr;
+    val proj_sort = complete_proper_sort thy #> Sorts.minimize_sort algebra;
+    val (_, gr') = fold_map (ensure_eqs thy (proj_sort, algebra) vardeps) cs gr;
     val _ = tracing "sort projection";
-    val minimal_proper_sort = fn sort => sort
-      |> Sorts.complete_sort (Sign.classes_of thy)
-      |> filter (can (AxClass.get_info thy))
-      |> Sorts.minimize_sort algebra;
-  in ((minimal_proper_sort, algebra), gr) end;
+  in ((proj_sort, algebra), gr') end;
 
 
 (** retrieval interfaces **)
@@ -320,7 +317,7 @@
       insert (op =) (Sign.const_typargs thy (c, Logic.unvarifyT ty), c)) consts' [];
     val typ_matches = maps (fn (tys, c) => tys ~~ map snd (fst (fst (Graph.get_node funcgr' c))))
       const_matches;
-    val dicts = maps (dicts_of thy (snd algebra')) typ_matches;
+    val dicts = maps (dicts_of thy algebra') typ_matches;
     val (algebra'', funcgr'') = extend_graph thy dicts funcgr';
   in (evaluator_lift (evaluator_funcgr algebra'') thm funcgr'', funcgr'') end;