improved evaluation interface
authorhaftmann
Tue, 21 Aug 2007 13:30:38 +0200
changeset 24381 560e8ecdf633
parent 24380 c215e256beca
child 24382 54da7d61372d
improved evaluation interface
src/HOL/Library/Eval.thy
src/Tools/code/code_package.ML
src/Tools/code/code_target.ML
src/Tools/code/code_thingol.ML
src/Tools/nbe.ML
--- a/src/HOL/Library/Eval.thy	Tue Aug 21 13:30:36 2007 +0200
+++ b/src/HOL/Library/Eval.thy	Tue Aug 21 13:30:38 2007 +0200
@@ -167,7 +167,8 @@
 end;
 *}
 
-oracle eval_oracle ("term * CodeThingol.code * CodeThingol.iterm * CodeThingol.itype * cterm") = {* fn thy => fn (t0, code, t, ty, ct) => 
+oracle eval_oracle ("term * CodeThingol.code * (CodeThingol.typscheme * CodeThingol.iterm) * cterm") =
+{* fn thy => fn (t0, code, ((vs, ty), t), ct) => 
 let
   val _ = (Term.map_types o Term.map_atyps) (fn _ =>
     error ("Term " ^ Sign.string_of_term thy t0 ^ " contains polymorphic type"))
@@ -181,7 +182,7 @@
 
 open Eval;
 
-fun eval_invoke thy t0 code (t, ty) _ ct = eval_oracle thy (t0, code, t, ty, ct);
+fun eval_invoke thy t0 code vs_ty_t _ ct = eval_oracle thy (t0, code, vs_ty_t, ct);
 
 fun eval_conv ct =
   let
--- a/src/Tools/code/code_package.ML	Tue Aug 21 13:30:36 2007 +0200
+++ b/src/Tools/code/code_package.ML	Tue Aug 21 13:30:38 2007 +0200
@@ -9,10 +9,12 @@
 sig
   (* interfaces *)
   val eval_conv: theory
-    -> (CodeThingol.code -> CodeThingol.iterm * CodeThingol.itype -> string list -> cterm -> thm)
+    -> (CodeThingol.code -> CodeThingol.typscheme * CodeThingol.iterm
+       -> string list -> cterm -> thm)
     -> cterm -> thm;
   val eval_term: theory
-    -> (CodeThingol.code -> CodeThingol.iterm * CodeThingol.itype -> string list -> cterm -> 'a)
+    -> (CodeThingol.code -> CodeThingol.typscheme * CodeThingol.iterm
+       -> string list -> cterm -> 'a)
     -> cterm -> 'a;
   val satisfies_ref: bool option ref;
   val satisfies: theory -> cterm -> string list -> bool;
@@ -285,10 +287,10 @@
           ##>> exprgen_term thy algbr funcgr rhs;
       in
         trns
-        |> timeap (fold_map (exprgen_eq o dest_eqthm) thms)
-        ||>> fold_map (exprgen_tyvar_sort thy algbr funcgr) vs
+        |> fold_map (exprgen_tyvar_sort thy algbr funcgr) vs
         ||>> exprgen_typ thy algbr funcgr ty
-        |>> (fn ((eqs, vs), ty) => CodeThingol.Fun (eqs, (vs, ty)))
+        ||>> timeap (fold_map (exprgen_eq o dest_eqthm) thms)
+        |>> (fn ((vs, ty), eqs) => CodeThingol.Fun ((vs, ty), eqs))
       end;
     val defgen = if (is_some o Code.get_datatype_of_constr thy) const
       then defgen_datatypecons
@@ -536,25 +538,28 @@
     val value_name = "Isabelle_Eval.EVAL.EVAL";
     fun ensure_eval thy algbr funcgr t = 
       let
+        val ty = fastype_of t;
+        val vs = typ_tfrees ty;
         val defgen_eval =
-          exprgen_term' thy algbr funcgr t
-          ##>> exprgen_typ thy algbr funcgr (fastype_of t)
-          #>> (fn (t, ty) => CodeThingol.Fun ([([], t)], ([], ty)));
+          fold_map (exprgen_tyvar_sort thy algbr funcgr) vs
+          ##>> exprgen_typ thy algbr funcgr ty
+          ##>> exprgen_term' thy algbr funcgr t
+          #>> (fn ((vs, ty), t) => CodeThingol.Fun ((vs, ty), [([], t)]));
         fun result (dep, code) =
           let
-            val CodeThingol.Fun ([([], t)], ([], ty)) = Graph.get_node code value_name;
+            val CodeThingol.Fun ((vs, ty), [([], t)]) = Graph.get_node code value_name;
             val deps = Graph.imm_succs code value_name;
             val code' = Graph.del_nodes [value_name] code;
             val code'' = CodeThingol.project_code false [] (SOME deps) code';
-          in ((code'', (t, ty), deps), (dep, code')) end;
+          in ((code'', ((vs, ty), t), deps), (dep, code')) end;
       in
         ensure_def thy defgen_eval "evaluation" value_name
         #> result
       end;
     fun h funcgr ct =
       let
-        val (code, (t, ty), deps) = generate thy funcgr ensure_eval (Thm.term_of ct);
-      in g code (t, ty) deps ct end;
+        val (code, vs_ty_t, deps) = generate thy funcgr ensure_eval (Thm.term_of ct);
+      in g code vs_ty_t deps ct end;
   in f thy h end;
 
 fun eval_conv thy = raw_eval CodeFuncgr.eval_conv thy;
@@ -564,7 +569,7 @@
 
 fun satisfies thy ct witnesses =
   let
-    fun evl code (t, ty) deps ct =
+    fun evl code ((vs, ty), t) deps ct =
       let
         val t0 = Thm.term_of ct
         val _ = (Term.map_types o Term.map_atyps) (fn _ =>
--- a/src/Tools/code/code_target.ML	Tue Aug 21 13:30:36 2007 +0200
+++ b/src/Tools/code/code_target.ML	Tue Aug 21 13:30:38 2007 +0200
@@ -293,7 +293,7 @@
 (** SML/OCaml serializer **)
 
 datatype ml_def =
-    MLFuns of (string * ((iterm list * iterm) list * typscheme)) list
+    MLFuns of (string * (typscheme * (iterm list * iterm) list)) list
   | MLDatas of (string * ((vname * sort) list * (string * itype list) list)) list
   | MLClass of string * ((class * string) list * (vname * (string * itype) list))
   | MLClassinst of string * ((class * (string * (vname * sort) list))
@@ -423,13 +423,13 @@
                 fun mk [] [] = "val"
                   | mk (_::_) _ = "fun"
                   | mk [] vs = if (null o filter_out (null o snd)) vs then "val" else "fun";
-                fun chk (_, ((ts, _) :: _, (vs, _))) NONE = SOME (mk ts vs)
-                  | chk (_, ((ts, _) :: _, (vs, _))) (SOME defi) =
+                fun chk (_, ((vs, _), (ts, _) :: _)) NONE = SOME (mk ts vs)
+                  | chk (_, ((vs, _), (ts, _) :: _)) (SOME defi) =
                       if defi = mk ts vs then SOME defi
                       else error ("Mixing simultaneous vals and funs not implemented: "
                         ^ commas (map (labelled_name o fst) funns));
               in the (fold chk funns NONE) end;
-            fun pr_funn definer (name, (eqs as eq::eqs', (raw_vs, ty))) =
+            fun pr_funn definer (name, ((raw_vs, ty), eqs as eq :: eqs')) =
               let
                 val vs = filter_out (null o snd) raw_vs;
                 val shift = if null eqs' then I else
@@ -758,7 +758,7 @@
                       :: maps (append [Pretty.fbrk, str "|", Pretty.brk 1] o single o pr_eq) eqs'
                     )
                   end;
-            fun pr_funn definer (name, (eqs, (vs, ty))) =
+            fun pr_funn definer (name, ((vs, ty), eqs)) =
               concat (
                 str definer
                 :: (str o deresolv) name
@@ -1173,7 +1173,7 @@
             ) (map pr bs)
           end
       | pr_case vars fxy ((_, []), _) = str "error \"empty case\"";
-    fun pr_def (name, CodeThingol.Fun (eqs, (vs, ty))) =
+    fun pr_def (name, CodeThingol.Fun ((vs, ty), eqs)) =
           let
             val tyvars = CodeName.intro_vars (map fst vs) init_syms;
             fun pr_eq (ts, t) =
--- a/src/Tools/code/code_thingol.ML	Tue Aug 21 13:30:36 2007 +0200
+++ b/src/Tools/code/code_thingol.ML	Tue Aug 21 13:30:38 2007 +0200
@@ -58,7 +58,7 @@
 
   datatype def =
       Bot
-    | Fun of (iterm list * iterm) list * typscheme
+    | Fun of typscheme * (iterm list * iterm) list
     | Datatype of (vname * sort) list * (string * itype list) list
     | Datatypecons of string
     | Class of (class * string) list * (vname * (string * itype) list)
@@ -231,12 +231,10 @@
 
 (** definitions, transactions **)
 
-(* type definitions *)
-
 type typscheme = (vname * sort) list * itype;
 datatype def =
     Bot
-  | Fun of (iterm list * iterm) list * typscheme
+  | Fun of typscheme * (iterm list * iterm) list
   | Datatype of (vname * sort) list * (string * itype list) list
   | Datatypecons of string
   | Class of (class * string) list * (vname * (string * itype) list)
@@ -273,7 +271,7 @@
 fun project_code delete_empty_funs hidden raw_selected code =
   let
     fun is_empty_fun name = case Graph.get_node code name
-     of Fun ([], _) => true
+     of Fun (_, []) => true
       | _ => false;
     val names = subtract (op =) hidden (Graph.keys code);
     val deleted = Graph.all_preds code (filter is_empty_fun names);
@@ -291,7 +289,7 @@
   end;
 
 fun empty_funs code =
-  Graph.fold (fn (name, (Fun ([], _), _)) => cons name
+  Graph.fold (fn (name, (Fun (_, []), _)) => cons name
                | _ => I) code [];
 
 fun is_cons code name = case Graph.get_node code name
@@ -322,8 +320,8 @@
 
 fun check_prep_def code Bot =
       Bot
-  | check_prep_def code (Fun (eqs, d)) =
-      Fun (check_funeqs eqs, d)
+  | check_prep_def code (Fun (d, eqs)) =
+      Fun (d, check_funeqs eqs)
   | check_prep_def code (d as Datatype _) =
       d
   | check_prep_def code (Datatypecons dtco) =
@@ -405,7 +403,7 @@
 
 fun add_eval_def (name, (t, ty)) code =
   code
-  |> Graph.new_node (name, Fun ([([], t)], ([], ty)))
+  |> Graph.new_node (name, Fun (([], ty), [([], t)]))
   |> fold (curry Graph.add_edge name) (Graph.keys code);
 
 end; (*struct*)
--- a/src/Tools/nbe.ML	Tue Aug 21 13:30:36 2007 +0200
+++ b/src/Tools/nbe.ML	Tue Aug 21 13:30:38 2007 +0200
@@ -15,7 +15,7 @@
 signature NBE =
 sig
   datatype Univ = 
-      Const of string * Univ list        (*named constructors*)
+      Const of string * Univ list            (*named (uninterpreted) constants*)
     | Free of string * Univ list
     | BVar of int * Univ list
     | Abs of (int * (Univ list -> Univ)) * Univ list;
@@ -63,11 +63,11 @@
 *)
 
 datatype Univ = 
-    Const of string * Univ list        (*named constructors*)
+    Const of string * Univ list        (*named (uninterpreted) constants*)
   | Free of string * Univ list         (*free variables*)
   | BVar of int * Univ list            (*bound named variables*)
   | Abs of (int * (Univ list -> Univ)) * Univ list
-                                      (*functions*);
+                                      (*abstractions as closures*);
 
 (* constructor functions *)
 
@@ -108,7 +108,7 @@
       let
         val _ = univs_ref := [];
         val s = "Nbe.univs_ref := " ^ raw_s;
-        val _ = tracing (fn () => "\n---generated code:\n" ^ s) ();
+        val _ = tracing (fn () => "\n--- generated code:\n" ^ s) ();
         val _ = tab_ref := SOME tab;
         val _ = use_text "" (Output.tracing o enclose "\n---compiler echo:\n" "\n---\n",
           Output.tracing o enclose "\n--- compiler echo (with error):\n" "\n---\n")
@@ -228,7 +228,7 @@
         val result = ml_list (map (fn (c, n) => nbe_abss n (nbe_fun c)) num_args);
       in (cs, ml_Let (bind_funs @ [bind_locals]) result) end;
 
-fun assemble_eval thy is_fun (t, deps) =
+fun assemble_eval thy is_fun (((vs, ty), t), deps) =
   let
     val funs = CodeThingol.fold_constnames (insert (op =)) t [];
     val frees = CodeThingol.fold_unbound_varnames (insert (op =)) t [];
@@ -238,9 +238,9 @@
     val result = ml_list [nbe_value `$` ml_list (map nbe_free frees)];
   in ([nbe_value], ml_Let (bind_funs @ [bind_value]) result) end;
 
-fun eqns_of_stmt ((_, CodeThingol.Fun ([], _)), _) =
+fun eqns_of_stmt ((_, CodeThingol.Fun (_, [])), _) =
       NONE
-  | eqns_of_stmt ((name, CodeThingol.Fun (eqns, _)), deps) =
+  | eqns_of_stmt ((name, CodeThingol.Fun (_, eqns)), deps) =
       SOME ((name, eqns), deps)
   | eqns_of_stmt ((_, CodeThingol.Datatypecons _), _) =
       NONE
@@ -312,7 +312,7 @@
 
 (* evaluation with type reconstruction *)
 
-fun eval thy code t t' deps =
+fun eval thy code t vs_ty_t deps =
   let
     val ty = type_of t;
     fun subst_Frees [] = I
@@ -328,37 +328,37 @@
       error ("Illegal schematic type variables in normalized term: "
         ^ setmp show_types true (Sign.string_of_term thy) t);
   in
-    (t', deps)
+    (vs_ty_t, deps)
     |> eval_term thy (Symtab.defined (ensure_funs thy code))
     |> term_of_univ thy
     |> tracing (fn t => "Normalized:\n" ^ setmp show_types true Display.raw_string_of_term t)
-    |> tracing (fn _ => "Term type:\n" ^ Display.raw_string_of_typ ty)
     |> anno_vars
     |> tracing (fn t => "Vars typed:\n" ^ setmp show_types true Display.raw_string_of_term t)
-    |> tracing (fn t => setmp show_types true (Sign.string_of_term thy) t)
     |> constrain
     |> tracing (fn t => "Types inferred:\n" ^ setmp show_types true Display.raw_string_of_term t)
     |> check_tvars
+    |> tracing (fn _ => "---\n")
   end;
 
 (* evaluation oracle *)
 
-exception Normalization of CodeThingol.code * term * CodeThingol.iterm * string list;
+exception Normalization of CodeThingol.code * term
+  * (CodeThingol.typscheme * CodeThingol.iterm) * string list;
 
-fun normalization_oracle (thy, Normalization (code, t, t', deps)) =
-  Logic.mk_equals (t, eval thy code t t' deps);
+fun normalization_oracle (thy, Normalization (code, t, vs_ty_t, deps)) =
+  Logic.mk_equals (t, eval thy code t vs_ty_t deps);
 
-fun normalization_invoke thy code t t' deps =
-  Thm.invoke_oracle_i thy "HOL.normalization" (thy, Normalization (code, t, t', deps));
+fun normalization_invoke thy code t vs_ty_t deps =
+  Thm.invoke_oracle_i thy "HOL.normalization" (thy, Normalization (code, t, vs_ty_t, deps));
   (*FIXME get rid of hardwired theory name*)
 
 fun normalization_conv ct =
   let
     val thy = Thm.theory_of_cterm ct;
-    fun conv code (t', ty') deps ct =
+    fun conv code vs_ty_t deps ct =
       let
         val t = Thm.term_of ct;
-      in normalization_invoke thy code t t' deps end;
+      in normalization_invoke thy code t vs_ty_t deps end;
   in CodePackage.eval_conv thy conv ct end;
 
 (* evaluation command *)