refined algorithm
authorhaftmann
Tue, 31 Oct 2006 09:29:13 +0100
changeset 21120 e333c844b057
parent 21119 5c7edac0c645
child 21121 fae2187d6e2f
refined algorithm
src/Pure/Tools/codegen_funcgr.ML
--- a/src/Pure/Tools/codegen_funcgr.ML	Tue Oct 31 09:29:12 2006 +0100
+++ b/src/Pure/Tools/codegen_funcgr.ML	Tue Oct 31 09:29:13 2006 +0100
@@ -9,11 +9,13 @@
 signature CODEGEN_FUNCGR =
 sig
   type T;
-  val mk_funcgr: theory -> CodegenConsts.const list -> (string * typ) list -> T
-  val all_deps_of: T -> CodegenConsts.const list -> CodegenConsts.const list list
-  val get_funcs: T -> CodegenConsts.const -> thm list
-  val get_func_typs: T -> (CodegenConsts.const * typ) list
-  val normalize: theory -> thm list -> thm list
+  val make: theory -> CodegenConsts.const list -> T
+  val make_term: theory -> cterm -> (cterm * (thm * (thm -> thm))) * T
+  val funcs: T -> CodegenConsts.const -> thm list
+  val typ: T -> CodegenConsts.const -> typ
+  val deps: T -> CodegenConsts.const list -> CodegenConsts.const list list
+  val all: T -> CodegenConsts.const list
+  val norm_vars: theory -> thm list -> thm list
   val print_codethms: theory -> CodegenConsts.const list -> unit
   structure Constgraph : GRAPH
 end;
@@ -60,17 +62,13 @@
       in
         fold (fn ct => fn thm => Thm.combination thm (Thm.reflexive ct)) vs_ct thm
       end;
-    fun beta_norm thm =
-      let
-        val rhs = (snd o Logic.dest_equals o Drule.plain_prop_of) thm;
-        val thm' = Thm.beta_conversion true (cterm_of thy rhs);
-      in Thm.transitive thm thm' end;
   in
     thm
     |> eta_expand
-    |> beta_norm
+    |> Drule.fconv_rule Drule.beta_eta_conversion
   end;
 
+
 fun canonical_tvars thy thm =
   let
     fun mk_inst (v_i as (v, i), (v', sort)) (s as (maxidx, set, acc)) =
@@ -105,7 +103,7 @@
     val (_, _, inst) = fold mk_inst (vars_of thm) (maxidx + 1, [], []);
   in Thm.instantiate ([], inst) thm end;
 
-fun normalize thy thms =
+fun norm_vars thy thms =
   let
     fun burrow_thms f [] = []
       | burrow_thms f thms =
@@ -129,13 +127,13 @@
 
 (** retrieval **)
 
-fun get_funcs funcgr (c_tys as (c, _)) =
-  (these o Option.map snd o try (Constgraph.get_node funcgr)) c_tys;
+fun funcs funcgr =
+  these o Option.map snd o try (Constgraph.get_node funcgr);
 
-fun get_func_typs funcgr =
-  AList.make (fst o Constgraph.get_node funcgr) (Constgraph.keys funcgr);
+fun typ funcgr =
+  fst o Constgraph.get_node funcgr;
 
-fun all_deps_of funcgr cs =
+fun deps funcgr cs =
   let
     val conn = Constgraph.strong_conn funcgr;
     val order = rev conn;
@@ -144,25 +142,24 @@
     |> filter_out null
   end;
 
+fun all funcgr = Constgraph.keys funcgr;
+
 local
 
 fun add_things_of thy f (c, thms) =
   (fold o fold_aterms)
      (fn Const c_ty => let
             val c' = CodegenConsts.norm_of_typ thy c_ty
-          in if CodegenConsts.eq_const (c, c') then I
+          in if is_some c andalso CodegenConsts.eq_const (the c, c') then I
           else f (c', c_ty) end
        | _ => I) (maps (op :: o swap o apfst (snd o strip_comb)
             o Logic.dest_equals o Drule.plain_prop_of) thms)
 
 fun rhs_of thy (c, thms) =
   Consttab.empty
-  |> add_things_of thy (Consttab.update o rpair () o fst) (c, thms)
+  |> add_things_of thy (Consttab.update o rpair () o fst) (SOME c, thms)
   |> Consttab.keys;
 
-fun rhs_of' thy (c, thms) =
-  add_things_of thy (cons o snd) (c, thms) [];
-
 fun insts_of thy funcgr (c, ty) =
   let
     val tys = Sign.const_typargs thy (c, ty);
@@ -170,9 +167,6 @@
     val ty_decl = CodegenConsts.disc_typ_of_const thy
       (fst o Constgraph.get_node funcgr o CodegenConsts.norm thy) (c, tys);
     val tys_decl = Sign.const_typargs thy (c, ty_decl);
-    val pp = Sign.pp thy;
-    val algebra = Sign.classes_of thy;
-    fun classrel (x, _) _ = x;
     fun constructor tyco xs class =
       (tyco, class) :: maps (maps fst) xs;
     fun variable (TVar (_, sort)) = map (pair []) sort
@@ -182,15 +176,20 @@
       | mk_inst (Type (tyco1, tys1)) (Type (tyco2, tys2)) =
           if tyco1 <> tyco2 then error "bad instance"
           else fold2 mk_inst tys1 tys2;
+    val pp = Sign.pp thy;
+    val algebra = Sign.classes_of thy;
+    fun classrel (x, _) _ = x;
+    fun of_sort_deriv (ty, sort) =
+      Sorts.of_sort_derivation pp algebra
+        { classrel = classrel, constructor = constructor, variable = variable }
+        (ty, sort)
   in
-    flat (maps (Sorts.of_sort_derivation pp algebra
-      { classrel = classrel, constructor = constructor, variable = variable })
-      (fold2 mk_inst tys tys_decl []))
+    flat (maps of_sort_deriv (fold2 mk_inst tys tys_decl []))
   end;
 
 fun all_classops thy tyco class =
-  maps (AxClass.params_of thy)
-      (Graph.all_succs ((#classes o Sorts.rep_algebra o Sign.classes_of) thy) [class])
+  AxClass.params_of thy class
+(*   |> tap (fn _ => writeln ("INST " ^ tyco ^ " - " ^ class))  *)
   |> AList.make (fn c => CodegenConsts.disc_typ_of_classop thy (c, [Type (tyco, [])]))
         (*typ_of_classop is very liberal in its type arguments*)
   |> map (CodegenConsts.norm_of_typ thy);
@@ -206,10 +205,10 @@
          (Graph.all_succs thy_classes classes))) tab [])
   end;
 
-fun insts_of_thms thy funcgr c_thms =
+fun insts_of_thms thy funcgr (c, thms) =
   let
     val insts = add_things_of thy (fn (_, c_ty) => fold (insert (op =))
-      (insts_of thy funcgr c_ty)) c_thms [];
+      (insts_of thy funcgr c_ty)) (SOME c, thms) [];
   in instdefs_of thy insts end;
 
 fun ensure_const thy funcgr c auxgr =
@@ -222,7 +221,7 @@
     |> Constgraph.new_node (c, [])
     |> pair (SOME c)
   else let
-    val thms = normalize thy (CodegenData.these_funcs thy c);
+    val thms = norm_vars thy (CodegenData.these_funcs thy c);
     val rhs = rhs_of thy (c, thms);
   in
     auxgr
@@ -240,7 +239,7 @@
     fun typscheme_of (c, ty) =
       try (Constgraph.get_node funcgr) (CodegenConsts.norm_of_typ thy (c, ty))
       |> Option.map fst;
-    fun incr_indices (c, thms) maxidx =
+    fun incr_indices (c:'a, thms) maxidx =
       let
         val thms' = map (Thm.incr_indexes maxidx) thms;
         val maxidx' = Int.max
@@ -263,12 +262,25 @@
     fun apply_unifier unif (c, []) = (c, [])
       | apply_unifier unif (c, thms as thm :: _) =
           let
-            val ty = CodegenData.typ_func thy thm;
-            val ty' = Envir.norm_type unif ty;
-            val env = Type.typ_match (Sign.tsig_of thy) (ty, ty') Vartab.empty;
-            val inst = Thm.instantiate (Vartab.fold (fn (x_i, (sort, ty)) =>
-              cons (Thm.ctyp_of thy (TVar (x_i, sort)), Thm.ctyp_of thy ty)) env [], []);
-          in (c, map (Drule.zero_var_indexes o inst) thms) end;
+            val tvars = Term.add_tvars (Thm.prop_of thm) [];
+            fun mk_inst (v_i_sort as (v, _)) =
+              let
+                val ty = TVar v_i_sort;
+              in
+                pairself (Thm.ctyp_of thy) (ty,
+                  TVar (v, (snd o dest_TVar o Envir.norm_type unif) ty))
+              end;
+            val instmap = map mk_inst tvars;
+            val (thms' as thm' :: _) = map (Drule.zero_var_indexes o Thm.instantiate (instmap, [])) thms
+            val _ = if fst c <> "" andalso not (Sign.typ_equiv thy (Type.strip_sorts (CodegenData.typ_func thy thm), Type.strip_sorts (CodegenData.typ_func thy thm')))
+              then error ("illegal function type instantiation:\n" ^ Sign.string_of_typ thy (CodegenData.typ_func thy thm)
+                ^ "\nto" ^ Sign.string_of_typ thy (CodegenData.typ_func thy thm'))
+              else ();
+          in (c, thms') end;
+    fun rhs_of' thy (("", []), thms as [_]) =
+          add_things_of thy (cons o snd) (NONE, thms) []
+      | rhs_of' thy (c, thms) =
+          add_things_of thy (cons o snd) (SOME c, thms) [];
     val (eqss', maxidx) =
       fold_map incr_indices eqss 0;
     val (unif, _) =
@@ -301,16 +313,67 @@
        (map (AList.make (Constgraph.get_node auxgr))
        (rev (Constgraph.strong_conn auxgr))) funcgr);
 
+fun drop_classes thy tfrees thm =
+  let
+(*     val _ = writeln ("DROP1 " ^ setmp show_types true string_of_thm thm);  *)
+    val (_, thm') = Thm.varifyT' [] thm;
+    val tvars = Term.add_tvars (Thm.prop_of thm') [];
+(*     val _ = writeln ("DROP2 " ^ setmp show_types true string_of_thm thm');  *)
+    val unconstr = map (Thm.ctyp_of thy o TVar) tvars;
+    val instmap = map2 (fn (v_i, _) => fn (v, sort) => pairself (Thm.ctyp_of thy)
+      (TVar (v_i, []), TFree (v, sort))) tvars tfrees;
+  in
+    thm'
+    |> fold Thm.unconstrainT unconstr
+    |> Thm.instantiate (instmap, [])
+    |> Tactic.rule_by_tactic ((REPEAT o CHANGED o ALLGOALS o Tactic.resolve_tac) (AxClass.class_intros thy))
+(*     |> tap (fn thm => writeln ("DROP3 " ^ setmp show_types true string_of_thm thm))  *)
+  end;
+
 in
 
 val ensure_consts = ensure_consts;
 
-fun mk_funcgr thy consts cs =
-  Funcgr.change thy (
-    ensure_consts thy consts
-    #> (fn funcgr => ensure_consts thy
-         (instdefs_of thy (fold (fold (insert (op =)) o insts_of thy funcgr) cs [])) funcgr)
-  );
+fun make thy consts =
+  Funcgr.change thy (ensure_consts thy consts);
+
+fun make_term thy ct =
+  let
+    val _ = Sign.no_vars (Sign.pp thy) (Thm.term_of ct);
+    val _ = Term.fold_types (Type.no_tvars #> K I) (Thm.term_of ct) ();
+    val thm1 = CodegenData.preprocess_cterm thy ct;
+(*     val _ = writeln ("THM1 " ^ setmp show_types true string_of_thm thm1);  *)
+    val ct' = Drule.dest_equals_rhs (Thm.cprop_of thm1);
+    val consts = CodegenConsts.consts_of thy (Thm.term_of ct');
+    val funcgr = make thy consts;
+    val (_, thm2) = Thm.varifyT' [] thm1;
+(*     val _ = writeln ("THM2 " ^ setmp show_types true string_of_thm thm2);  *)
+    val thm3 = Thm.reflexive (Drule.dest_equals_rhs (Thm.cprop_of thm2));
+(*     val _ = writeln ("THM3 " ^ setmp show_types true string_of_thm thm3);  *)
+    val [(_, [thm4])] = specialize_typs thy funcgr [(("", []), [thm3])];
+(*     val _ = writeln ("THM4 " ^ setmp show_types true string_of_thm thm4);  *)
+    val tfrees = Term.add_tfrees (Thm.prop_of thm1) [];
+(*     val _ = writeln "TFREES";  *)
+(*     val _ = (writeln o cat_lines o map (fn (v, sort) => v ^ "::" ^ Sign.string_of_sort thy sort)) tfrees;  *)
+    fun inst thm =
+      let
+        val tvars = Term.add_tvars (Thm.prop_of thm) [];
+(*         val _ = writeln "TVARS";  *)
+(*         val _ = (writeln o cat_lines o map (fn ((v, i), sort) => v ^ "_" ^ string_of_int i ^ "::" ^ Sign.string_of_sort thy sort)) tvars;  *)
+        val instmap = map2 (fn (v_i, sort) => fn (v, _) => pairself (Thm.ctyp_of thy)
+          (TVar (v_i, sort), TFree (v, sort))) tvars tfrees;
+      in Thm.instantiate (instmap, []) thm end;
+    val thm5 = inst thm2;
+    val thm6 = inst thm4;
+(*     val _ = writeln ("THM5 " ^ setmp show_types true string_of_thm thm5);  *)
+(*     val _ = writeln ("THM6 " ^ setmp show_types true string_of_thm thm6);  *)
+    val ct'' = Drule.dest_equals_rhs (Thm.cprop_of thm6);
+    val cs = fold_aterms (fn Const c => cons c | _ => I) (Thm.term_of ct'') [];
+    val drop = drop_classes thy tfrees;
+(*     val _ = writeln "ADD INST";  *)
+    val funcgr' = ensure_consts thy
+      (instdefs_of thy (fold (fold (insert (op =)) o insts_of thy funcgr) cs [])) funcgr
+  in ((ct'', (thm5, drop)), Funcgr.change thy (K funcgr')) end;
 
 end; (*local*)
 
@@ -327,7 +390,7 @@
   |> Pretty.writeln;
 
 fun print_codethms thy consts =
-  mk_funcgr thy consts [] |> print_funcgr thy;
+  make thy consts |> print_funcgr thy;
 
 fun print_codethms_e thy cs =
   print_codethms thy (map (CodegenConsts.read_const thy) cs);