added type inference at the end of normalization
authornipkow
Tue, 06 Jun 2006 19:24:05 +0200
changeset 19795 746274ca400b
parent 19794 100ba10eee64
child 19796 d86e7b1fc472
added type inference at the end of normalization
src/Pure/Tools/nbe.ML
src/Pure/Tools/nbe_codegen.ML
src/Pure/Tools/nbe_eval.ML
--- a/src/Pure/Tools/nbe.ML	Tue Jun 06 19:16:42 2006 +0200
+++ b/src/Pure/Tools/nbe.ML	Tue Jun 06 19:24:05 2006 +0200
@@ -2,13 +2,15 @@
     Author:     Tobias Nipkow, Florian Haftmann, TU Muenchen
 
 Toplevel theory interface for "normalization by evaluation"
+Preconditions: no Vars
 *)
 
 signature NBE =
 sig
   val norm_by_eval_i: term -> theory -> term * theory;
   val lookup: string -> NBE_Eval.Univ;
-  val update: string * NBE_Eval.Univ -> unit;
+  val update: string * NBE_Eval.Univ -> unit
+  val trace_nbe: bool ref
 end;
 
 structure NBE: NBE =
@@ -25,6 +27,10 @@
   fun print _ _ = ();
 end);
 
+val trace_nbe = ref false;
+
+fun nbe_trace fs = if !trace_nbe then tracing(fs()) else ();
+
 val _ = Context.add_setup NBE_Data.init;
 
 val tab : NBE_Eval.Univ Symtab.table ref = ref Symtab.empty;
@@ -32,12 +38,30 @@
 fun update sx = (tab := Symtab.update sx (!tab));
 fun defined s = Symtab.defined (!tab) s;
 
-fun use_show s = (writeln ("\n---generated code:\n"^ s);
-     use_text(writeln o enclose "\n---compiler echo:\n" "\n---\n",
-              writeln o enclose "\n--- compiler echo (with error!):\n" 
-                                "\n---\n")
-      true s);
+fun use_show "" = ()
+  | use_show s =
+ (if !trace_nbe then tracing ("\n---generated code:\n"^ s) else ();
+  use_text(tracing o enclose "\n---compiler echo:\n" "\n---\n",
+           tracing o enclose "\n--- compiler echo (with error!):\n" "\n---\n")
+      (!trace_nbe) s);
 
+(* FIXME move to term.ML *)
+fun subst_Frees [] tm = tm
+  | subst_Frees inst tm =
+      let
+        fun subst (t as Free(s, _)) = the_default t (AList.lookup (op =) inst s)
+          | subst (Abs (a, T, t)) = Abs (a, T, subst t)
+          | subst (t $ u) = subst t $ subst u
+          | subst t = t;
+      in subst tm end;
+
+fun var_tab t = (Term.add_frees t [], Term.add_vars t []);
+
+fun anno_vars (Ftab,Vtab) =
+  subst_Vars  (map (fn (ixn,T) => (ixn, Var(ixn,T))) Vtab) o
+  subst_Frees (map (fn (s,T) =>   (s,   Free(s,T)))  Ftab)
+
+(* FIXME try to use isar_cmd/print_term to take care of context *)
 fun norm_by_eval_i t thy =
   let
     val nbe_tab = NBE_Data.get thy;
@@ -46,23 +70,26 @@
     val (t', thy') = CodegenPackage.codegen_term t thy;
     val modl_new = CodegenPackage.get_root_module thy';
     val diff = CodegenThingol.diff_module (modl_new, modl_old);
-    val _ = writeln ("new definitions: " ^ (commas o map fst) diff);
+    val _ = nbe_trace (fn() => "new definitions: " ^ (commas o map fst) diff);
     val _ = (tab := nbe_tab;
              Library.seq (use_show o NBE_Codegen.generate defined) diff);
     val thy'' = NBE_Data.put (!tab) thy';
     val nt' = NBE_Eval.nbe (!tab) t';
-    val _ = print nt';
+    val _ = nbe_trace (fn() => "Input:\n" ^ Display.raw_string_of_term t)
+    val _ =  nbe_trace (fn()=> "Normalized:\n" ^ NBE_Eval.string_of_nterm nt');
     val t' = NBE_Codegen.nterm_to_term thy'' nt';
-(*
-    val _ = print t';
-    val (t'', _) =
+    val _ = nbe_trace (fn()=>"Converted back:\n" ^ Display.raw_string_of_term t');
+    val t'' = anno_vars (var_tab t) t';
+    val _ = nbe_trace (fn()=>"Vars typed:\n" ^ Display.raw_string_of_term t'');
+    val ty = type_of t;
+    val (t''', _) =
       Sign.infer_types (Sign.pp thy) thy (Sign.consts_of thy) (K NONE) (K NONE)
-        [] true ([t'], type_of t);
-*)
-    val _ = (Pretty.writeln o Sign.pretty_term thy'') t';
-  in
-    (t', thy'')
-  end;
+        [] false ([t''], ty);
+    val s = Pretty.string_of
+      (Pretty.block [Pretty.quote (Sign.pretty_term thy'' t'''), Pretty.fbrk,
+        Pretty.str "::", Pretty.brk 1, Pretty.quote (Sign.pretty_typ thy'' ty)])
+    val _ = writeln s
+  in  (t''', thy'')  end;
 
 fun norm_by_eval raw_t thy = norm_by_eval_i (Sign.read_term thy raw_t) thy;
 
--- a/src/Pure/Tools/nbe_codegen.ML	Tue Jun 06 19:16:42 2006 +0200
+++ b/src/Pure/Tools/nbe_codegen.ML	Tue Jun 06 19:24:05 2006 +0200
@@ -141,6 +141,15 @@
 
 open NBE_Eval;
 
+val tcount = ref 0;
+
+(* FIXME get rid of TVar case!!! *)
+fun varifyT ty =
+  let val ty' = map_type_tvar (fn ((s,i),S) => TypeInfer.param (!tcount + i) (s,S)) ty;
+      val _ = (tcount := !tcount + maxidx_of_typ ty + 1);
+      val ty'' = map_type_tfree (TypeInfer.param (!tcount)) ty'
+  in tcount := !tcount+1; ty'' end;
+
 fun nterm_to_term thy t =
   let
     fun consts_of (C s) = insert (op =) s
@@ -149,14 +158,14 @@
       | consts_of (A (t1, t2)) = consts_of t1 #> consts_of t2
       | consts_of (AbsN (_, t)) = consts_of t;
     val consts = consts_of t [];
-    val the_const = the o AList.lookup (op =)
-      (consts ~~ CodegenPackage.consts_of_idfs thy consts);
+    val ctab = consts ~~ CodegenPackage.consts_of_idfs thy consts;
+    val the_const = apsnd varifyT o the o AList.lookup (op =) ctab;
     fun to_term bounds (C s) = Const (the_const s)
       | to_term bounds (V s) = Free (s, dummyT)
       | to_term bounds (B i) = Bound (find_index (fn j => i = j) bounds)
       | to_term bounds (A (t1, t2)) = to_term bounds t1 $ to_term bounds t2
       | to_term bounds (AbsN (i, t)) =
           Abs("u", dummyT, to_term (i::bounds) t);
-  in to_term [] t end;
+  in tcount := 0; to_term [] t end;
 
 end;
--- a/src/Pure/Tools/nbe_eval.ML	Tue Jun 06 19:16:42 2006 +0200
+++ b/src/Pure/Tools/nbe_eval.ML	Tue Jun 06 19:24:05 2006 +0200
@@ -25,17 +25,15 @@
     | Fun of (Univ list -> Univ) * (Univ list) * int * (unit -> nterm)
                                          (*functions*);
 
-  val nbe: Univ Symtab.table -> CodegenThingol.iexpr -> nterm;
-  val apply: Univ -> Univ -> Univ;
+  val nbe: Univ Symtab.table -> CodegenThingol.iexpr -> nterm
+  val apply: Univ -> Univ -> Univ
 
-  val to_term: Univ -> nterm;
+  val to_term: Univ -> nterm
 
-  val mk_Fun: string * (Univ list -> Univ) * int -> string * Univ;
-  val new_name: unit -> int;
+  val mk_Fun: string * (Univ list -> Univ) * int -> string * Univ
+  val new_name: unit -> int
 
-  (* For testing
-  val eval: (Univ list) -> term -> Univ
-  *)
+  val string_of_nterm: nterm -> string
 end;
 
 structure NBE_Eval: NBE_EVAL =
@@ -48,6 +46,14 @@
   | A of nterm * nterm
   | AbsN of int * nterm;
 
+fun string_of_nterm(C s) = "(C \"" ^ s ^ "\")"
+  | string_of_nterm(V s) = "(V \"" ^ s ^ "\")"
+  | string_of_nterm(B n) = "(B " ^ string_of_int n ^ ")"
+  | string_of_nterm(A(s,t)) =
+       "(A " ^ string_of_nterm s ^ string_of_nterm t ^ ")"
+  | string_of_nterm(AbsN(n,t)) =
+      "(Abs " ^ string_of_int n ^ " " ^ string_of_nterm t ^ ")";
+
 fun apps t args = foldr (fn (y,x) => A(x,y)) t args;
 
 (* ------------------------------ The semantical universe --------------------- *)