various code refinements
authorhaftmann
Mon, 02 Oct 2006 23:01:11 +0200
changeset 20846 5fde744176d7
parent 20845 c55dcf606f65
child 20847 7e8c724339e0
various code refinements
src/Pure/Tools/nbe.ML
src/Pure/Tools/nbe_codegen.ML
src/Pure/Tools/nbe_eval.ML
--- a/src/Pure/Tools/nbe.ML	Mon Oct 02 23:01:09 2006 +0200
+++ b/src/Pure/Tools/nbe.ML	Mon Oct 02 23:01:11 2006 +0200
@@ -11,24 +11,25 @@
   val normalization_conv: cterm -> thm
   val lookup: string -> NBE_Eval.Univ
   val update: string * NBE_Eval.Univ -> unit
-  val trace_nbe: bool ref
+  val trace: bool ref
 end;
 
 structure NBE: NBE =
 struct
 
+val trace = ref false;
+fun tracing f = if !trace then Output.tracing (f ()) else ();
+
 
 (* theory data setup *)
 
-structure NBE_Data = TheoryDataFun
+structure NBE_Data = CodeDataFun
 (struct
   val name = "Pure/NBE"
   type T = NBE_Eval.Univ Symtab.table
   val empty = Symtab.empty
-  val copy = I
-  val extend = I
   fun merge _ = Symtab.merge (K true)
-  fun print _ _ = ()
+  fun purge _ _ = Symtab.empty
 end);
 
 val _ = Context.add_setup NBE_Data.init;
@@ -39,18 +40,62 @@
 val tab : NBE_Eval.Univ Symtab.table ref = ref Symtab.empty;
 fun lookup s = (the o Symtab.lookup (!tab)) s;
 fun update sx = (tab := Symtab.update sx (!tab));
-fun defined s = Symtab.defined (!tab) s;
 
 
-(* FIXME replace by Term.map_aterms *)
-fun subst_Frees [] tm = tm
-  | subst_Frees inst tm =
+(* norm by eval *)
+
+local
+
+(* FIXME better turn this into a function
+    NBE_Eval.Univ Symtab.table -> NBE_Eval.Univ Symtab.table
+    with implicit side effect *)
+fun use_code NONE = ()
+  | use_code (SOME s) =
+      (tracing (fn () => "\n---generated code:\n" ^ s);
+       use_text (Output.tracing o enclose "\n---compiler echo:\n" "\n---\n",
+            Output.tracing o enclose "\n--- compiler echo (with error!):\n" "\n---\n")
+        (!trace) s);
+
+fun generate thy funs =
+  let
+    val _ = tab := NBE_Data.get thy;;
+    val _ = Library.seq (use_code o NBE_Codegen.generate thy
+      (fn s => Symtab.defined (!tab) s)) funs;
+  in NBE_Data.change thy (K (!tab)) end;
+
+fun compile_term thy t =
+  let
+    (*FIXME: proper interfaces in codegen_*)
+    val (consts, cs) = CodegenConsts.consts_of thy t;
+    val funcgr = CodegenFuncgr.mk_funcgr thy consts cs;
+    (*FIXME: proper interfaces in codegen_*)
+    fun const_typ (c, ty) =
       let
-        fun subst (t as Free(s, _)) = the_default t (AList.lookup (op =) inst s)
-          | subst (Abs (a, T, t)) = Abs (a, T, subst t)
-          | subst (t $ u) = subst t $ subst u
-          | subst t = t;
-      in subst tm end;
+        val const = CodegenConsts.norm_of_typ thy (c, ty);
+      in case CodegenFuncgr.get_funcs funcgr const
+       of (thm :: _) => CodegenData.typ_func thy thm
+        | [] => Sign.the_const_type thy c
+      end;
+    val (_, ct) = CodegenData.preprocess_cterm thy const_typ (Thm.cterm_of thy t)
+    val t' = Thm.term_of ct;
+    val (consts, cs) = CodegenConsts.consts_of thy t';
+    val funcgr = CodegenFuncgr.mk_funcgr thy consts cs;
+    val nbe_tab = NBE_Data.get thy;
+    val all_consts =
+      CodegenFuncgr.all_deps_of funcgr consts
+      |> (map o filter_out) (Symtab.defined nbe_tab o CodegenNames.const thy)
+      |> filter_out null;
+    val funs = (map o map)
+      (fn c => (CodegenNames.const thy c, CodegenFuncgr.get_funcs funcgr c)) all_consts;
+    val _ = tracing (fn () => "new definitions: " ^ (commas o maps (map fst)) funs);
+    val _ = generate thy funs;
+    val nt = NBE_Eval.nbe thy (!tab) t';
+  in nt end;
+
+fun subst_Frees [] = I
+  | subst_Frees inst =
+      Term.map_aterms (fn (t as Free(s, _)) => the_default t (AList.lookup (op =) inst s)
+                  | t => t);
 
 fun var_tab t = (Term.add_frees t [], Term.add_vars t []);
 
@@ -58,70 +103,46 @@
   subst_Vars  (map (fn (ixn, T) => (ixn, Var(ixn,T))) Vtab) o
   subst_Frees (map (fn (s, T) =>   (s,   Free(s,T)))  Ftab)
 
-
-(* debugging *)
-
-val trace_nbe = ref false;
-
-fun trace f = if !trace_nbe then tracing (f ()) else ();
+in
 
-(* FIXME better turn this into a function
-    NBE_Eval.Univ Symtab.table -> NBE_Eval.Univ Symtab.table
-    with implicit side effect *)
-fun use_code "" = ()
-  | use_code s =
-      (if !trace_nbe then tracing ("\n---generated code:\n"^ s) else ();
-       use_text(tracing o enclose "\n---compiler echo:\n" "\n---\n",
-            tracing o enclose "\n--- compiler echo (with error!):\n" "\n---\n")
-        (!trace_nbe) s);
+fun norm_term thy t =
+  let
+    val _ = tracing (fn () => "Input:\n" ^ Display.raw_string_of_term t);
+    val nt = compile_term thy t;
+    val vtab = var_tab t;
+    val ty = type_of t;
+    fun constrain ty t = Sign.infer_types (Sign.pp thy) thy (Sign.consts_of thy)
+        (K NONE) (K NONE) Name.context false ([t], ty) |> fst;
+    val _ = tracing (fn () => "Normalized:\n" ^ NBE_Eval.string_of_nterm nt);
+    val t' = NBE_Codegen.nterm_to_term thy nt;
+    val _ = tracing (fn () =>"Converted back:\n" ^ Display.raw_string_of_term t');
+    val t'' = anno_vars vtab t';
+    val _ = tracing (fn () =>"Vars typed:\n" ^ Display.raw_string_of_term t'');
+    val t''' = constrain ty t'';
+    val _ = if null (Term.term_tvars t''') then () else
+      error ("Illegal schematic type variables in normalized term: "
+        ^ setmp show_types true (Sign.string_of_term thy) t''');
+  in t''' end;
+
+fun norm_print_term ctxt modes t =
+  let
+    val thy = ProofContext.theory_of ctxt;
+    val t' = norm_term thy t;
+    val ty = Term.type_of t';
+    val p = Library.setmp print_mode (modes @ ! print_mode) (fn () =>
+      Pretty.block [Pretty.quote (ProofContext.pretty_term ctxt t'), Pretty.fbrk,
+        Pretty.str "::", Pretty.brk 1, Pretty.quote (ProofContext.pretty_typ ctxt ty)]) ();
+  in Pretty.writeln p end;
+
+fun norm_print_term_e (modes, raw_t) state =
+  let
+    val ctxt = (Proof.context_of o Toplevel.enter_forward_proof) state;
+  in norm_print_term ctxt modes (ProofContext.read_term ctxt raw_t) end;
+
+end; (*local*)
 
 
-(* norm by eval *)
-
-(* FIXME try to use isar_cmd/print_term to take care of context *)
-fun norm_print t thy =
-  let
-    val _ = trace (fn () => "Input:\n" ^ Display.raw_string_of_term t);
-    fun compile_term t thy =
-      let
-        val _ = CodegenPackage.purge_code thy;
-        val nbe_tab = NBE_Data.get thy;
-        val (eq_thm, t'') = CodegenPackage.codegen_term thy t;
-        val t' = (snd o Logic.dest_equals o Drule.plain_prop_of) eq_thm;
-        val modl_new = CodegenPackage.get_root_module thy;
-        val diff = CodegenThingol.diff_module (modl_new, CodegenThingol.empty_module);
-        val _ = trace (fn () => "new definitions: " ^ (commas o map fst) diff);
-        val _ = (tab := nbe_tab;
-             Library.seq (use_code o NBE_Codegen.generate defined) diff);
-        val thy' = NBE_Data.put (!tab) thy;
-        val nt' = NBE_Eval.nbe (!tab) t'';
-      in ((t', nt'), thy') end;
-    fun eval_term t nt thy =
-      let
-        val vtab = var_tab t;
-        val ty = type_of t;
-        fun restrain ty t = Sign.infer_types (Sign.pp thy) thy (Sign.consts_of thy)
-            (K NONE) (K NONE) Name.context false ([t], ty) |> fst;
-        val _ = trace (fn () => "Normalized:\n" ^ NBE_Eval.string_of_nterm nt);
-        val t' = NBE_Codegen.nterm_to_term thy nt;
-        val _ = trace (fn () =>"Converted back:\n" ^ Display.raw_string_of_term t');
-        val t'' = anno_vars vtab t';
-        val _ = trace (fn () =>"Vars typed:\n" ^ Display.raw_string_of_term t'');
-        val t''' = restrain ty t''
-        val s = Pretty.string_of
-            (Pretty.block [Pretty.quote (Sign.pretty_term thy t'''), Pretty.fbrk,
-                Pretty.str "::", Pretty.brk 1, Pretty.quote (Sign.pretty_typ thy ty)])
-        val _ = writeln s
-      in (t''', thy) end;
-  in
-    thy
-    |> compile_term t
-    |-> (fn (t, nt) => eval_term t nt)
-  end;
-
-fun norm_print' s thy = norm_print (Sign.read_term thy s) thy;
-
-fun norm_term thy t = fst (norm_print t (Theory.copy thy));
+(* normalization oracle *)
 
 exception Normalization of term;
 
@@ -135,13 +156,16 @@
 val _ = Context.add_setup
   (Theory.add_oracle ("normalization", normalization_oracle));
 
+
 (* Isar setup *)
 
 local structure P = OuterParse and K = OuterKeyword in
 
+val opt_modes = Scan.optional (P.$$$ "(" |-- P.!!! (Scan.repeat1 P.xname --| P.$$$ ")")) [];
+
 val nbeP =
-  OuterSyntax.command "normal_form" "normalization by evaluation" K.thy_decl
-    (P.term >> (fn s => Toplevel.theory (snd o norm_print' s)));
+  OuterSyntax.improper_command "normal_form" "normalize term by evaluation" K.diag
+    (opt_modes -- P.typ >> (Toplevel.keep o norm_print_term_e));
 
 end;
 
--- a/src/Pure/Tools/nbe_codegen.ML	Mon Oct 02 23:01:09 2006 +0200
+++ b/src/Pure/Tools/nbe_codegen.ML	Mon Oct 02 23:01:11 2006 +0200
@@ -16,7 +16,7 @@
 
 signature NBE_CODEGEN =
 sig
-  val generate: (string -> bool) -> string * CodegenThingol.def -> string;
+  val generate: theory -> (string -> bool) -> (string * thm list) list -> string option;
   val nterm_to_term: theory -> NBE_Eval.nterm -> term;
 end
 
@@ -58,8 +58,11 @@
 
 fun eqns name ees =
   let fun eqn (es,e) = name ^ " " ^ space_implode " " es ^ " = " ^ e
-  in "fun " ^ space_implode "\n  | " (map eqn ees) ^ ";\n" end;
+  in space_implode "\n  | " (map eqn ees) end;
 
+fun eqnss (es :: ess) = prefix "fun " es :: map (prefix "and ") ess
+  |> space_implode "\n"
+  |> suffix "\n";
 
 fun Val v s = "val " ^ v ^ " = " ^ s;
 fun Let ds e = "let\n" ^ (space_implode "\n" ds) ^ " in " ^ e ^ " end"
@@ -88,83 +91,74 @@
 		   S.abs (S.tup []) (S.app Eval_C
 	(S.quote nm))]);
 
-open BasicCodegenThingol;
-
-fun mk_rexpr defined nm ar =
+fun mk_rexpr defined names ar =
   let
-    fun mk args = CodegenThingol.map_pure (mk' args)
-    and mk' args (IConst (c, _)) =
-          if c = nm then selfcall nm ar args
-          else if defined c then S.nbe_apps (MLname c) args
-          else S.app Eval_Constr (S.tup [S.quote c, S.list args])
-      | mk' args (IVar s) = S.nbe_apps (MLvname s) args
-      | mk' args (e1 `$ e2) = mk (args @ [mk [] e2]) e1
-      | mk' args ((nm, _) `|-> e) = S.nbe_apps (mk_nbeFUN (nm, mk [] e)) args;
+    fun mk args (Const (c, _)) = 
+          if member (op =) names c then selfcall c ar args
+            else if defined c then S.nbe_apps (MLname c) args
+            else S.app Eval_Constr (S.tup [S.quote c, S.list args])
+      | mk args (Free (v, _)) = S.nbe_apps (MLvname v) args
+      | mk args (t1 $ t2) = mk (args @ [mk [] t2]) t1
+      | mk args (Abs (v, _, t)) = S.nbe_apps (mk_nbeFUN (v, mk [] t)) args;
   in mk [] end;
 
 val mk_lexpr =
   let
-    fun mk args = CodegenThingol.map_pure (mk' args)
-    and mk' args (IConst (c, _)) =
+    fun mk args (Const (c, _)) =
           S.app Eval_Constr (S.tup [S.quote c, S.list args])
-      | mk' args (IVar s) = if args = [] then MLvname s else 
+      | mk args (Free (v, _)) = if null args then MLvname v else 
           sys_error "NBE mk_lexpr illegal higher order pattern"
-      | mk' args (e1 `$ e2) = mk (args @ [mk [] e2]) e1
-      | mk' args (_ `|-> _) =
+      | mk args (t1 $ t2) = mk (args @ [mk [] t2]) t1
+      | mk args (Abs _) =
           sys_error "NBE mk_lexpr illegal pattern";
   in mk [] end;
 
-fun mk_eqn defined nm ar (lhs,e) =
-  if has_duplicates (op =) (fold CodegenThingol.add_varnames lhs []) then [] else
-  [([S.list(map mk_lexpr (rev lhs))], mk_rexpr defined nm ar e)];
-
 fun lookup nm = S.Val (MLname nm) (tab_lookup (S.quote nm));
 
-fun generate defined (nm, CodegenThingol.Fun (eqns, _)) =
+fun generate thy defined [(_, [])] = NONE
+  | generate thy defined raw_eqnss =
       let
-        val ar = (length o fst o hd) eqns;
-        val params = (S.list o rev o MLparams) ar;
-        val funs =
+        val eqnss0 = map (fn (name, thms as thm :: _) =>
+          (name, ((length o snd o strip_comb o fst o Logic.dest_equals o prop_of) thm,
+            map (apfst (snd o strip_comb) o Logic.dest_equals o Logic.unvarify
+              o prop_of) thms)))
+          raw_eqnss;
+        val eqnss = (map o apsnd o apsnd o map) (fn (args, t) =>
+          (map (NBE_Eval.prep_term thy) args, NBE_Eval.prep_term thy t)) eqnss0
+        val names = map fst eqnss;
+        val used_funs =
           []
-          |> fold (fn (_, e) => CodegenThingol.add_constnames e) eqns
-          |> remove (op =) nm;
-        val globals = map lookup (filter defined funs);
-        val default_eqn = ([params], S.app Eval_Constr (S.tup[S.quote nm,params]));
-        val code = S.eqns (MLname nm)
-                          (maps (mk_eqn defined nm ar) eqns @ [default_eqn])
-        val register = tab_update
-            (S.app Eval_mk_Fun (S.tup[S.quote nm, MLname nm, string_of_int ar]))
-      in
-        S.Let (globals @ [code]) register
-      end
-  | generate _ _ = "";
+          |> fold (fold (fold_aterms (fn Const (c, _) => insert (op =) c
+                                      | _ => I)) o map snd o snd o snd) eqnss
+          |> subtract (op =) names;
+        fun mk_def (name, (ar, eqns)) =
+          let
+            fun mk_eqn (args, t) = ([S.list (map mk_lexpr (rev args))],
+              mk_rexpr defined names ar t);
+            val default_params = (S.list o rev o MLparams) ar;
+            val default_eqn = ([default_params], S.app Eval_Constr (S.tup [S.quote name, default_params]));
+          in S.eqns (MLname name) (map mk_eqn eqns @ [default_eqn]) end;
+        val globals = map lookup (filter defined used_funs);
+        fun register (name, (ar, _)) = tab_update
+            (S.app Eval_mk_Fun (S.tup [S.quote name, MLname name, string_of_int ar]))
+      in SOME (S.Let (globals @ [S.eqnss (map mk_def eqnss)]) (space_implode "; " (map register eqnss))) end;
 
 open NBE_Eval;
 
 val tcount = ref 0;
 
-(* FIXME get rid of TFree case!!! *)
 fun varifyT ty =
   let val ty' = map_type_tvar (fn ((s,i),S) => TypeInfer.param (!tcount + i) (s,S)) ty;
       val _ = (tcount := !tcount + maxidx_of_typ ty + 1);
-      val ty'' = map_type_tfree (TypeInfer.param (!tcount)) ty'
-  in tcount := !tcount+1; ty'' end;
+  in tcount := !tcount+1; ty' end;
 
 fun nterm_to_term thy t =
   let
-    fun consts_of (C s) = insert (op =) s
-      | consts_of (V _) = I
-      | consts_of (B _) = I
-      | consts_of (A (t1, t2)) = consts_of t1 #> consts_of t2
-      | consts_of (AbsN (_, t)) = consts_of t;
-    val consts = consts_of t [];
-    val ctab = AList.make (CodegenPackage.const_of_idf thy) consts;
-    val the_const = apsnd varifyT o the o AList.lookup (op =) ctab;
-    fun to_term bounds (C s) = Const (the_const s)
-      | to_term bounds (V s) = Free (s, dummyT)
-      | to_term bounds (B i) = Bound (find_index (fn j => i = j) bounds)
-      | to_term bounds (A (t1, t2)) = to_term bounds t1 $ to_term bounds t2
-      | to_term bounds (AbsN (i, t)) =
+   fun to_term bounds (C s) = Const ((apsnd varifyT o CodegenPackage.const_of_idf thy) s)
+     | to_term bounds (V s) = Free (s, dummyT)
+     | to_term bounds (B i) = Bound (find_index (fn j => i = j) bounds)
+     | to_term bounds (A (t1, t2)) = to_term bounds t1 $ to_term bounds t2
+     | to_term bounds (AbsN (i, t)) =
           Abs("u", dummyT, to_term (i::bounds) t);
   in tcount := 0; to_term [] t end;
 
--- a/src/Pure/Tools/nbe_eval.ML	Mon Oct 02 23:01:09 2006 +0200
+++ b/src/Pure/Tools/nbe_eval.ML	Mon Oct 02 23:01:11 2006 +0200
@@ -25,8 +25,9 @@
     | Fun of (Univ list -> Univ) * (Univ list) * int * (unit -> nterm)
                                          (*functions*);
 
-  val nbe: Univ Symtab.table -> CodegenThingol.iterm -> nterm
+  val nbe: theory -> Univ Symtab.table -> term -> nterm
   val apply: Univ -> Univ -> Univ
+  val prep_term: theory -> term -> term
 
   val to_term: Univ -> nterm
 
@@ -107,16 +108,18 @@
 fun mk_Fun(name,v,0) = (name,v [])
   | mk_Fun(name,v,n) = (name,Fun(v,[],n, fn () => C name));
 
-(* ---------------- table with the meaning of constants -------------------- *)
-
-val xfun_tab: Univ Symtab.table ref = ref Symtab.empty;
-
-fun lookup s = case Symtab.lookup (!xfun_tab) s of
-    SOME x => x
-  | NONE   => Constr(s,[]);
 
 (* ------------------ evaluation with greetings to Tarski ------------------ *)
 
+fun prep_term thy (Const c) = Const (CodegenNames.const thy (CodegenConsts.norm_of_typ thy c), dummyT)
+  | prep_term thy (Free v_ty) = Free v_ty
+  | prep_term thy (s $ t) = prep_term thy s $ prep_term thy t
+  | prep_term thy (Abs (raw_v, ty, raw_t)) =
+      let
+        val (v, t) = Syntax.variant_abs (CodegenNames.purify_var raw_v, ty, raw_t);
+      in Abs (v, ty, prep_term thy t) end;
+
+
 (* generation of fresh names *)
 
 val counter = ref 0;
@@ -125,24 +128,28 @@
 
 (* greetings to Tarski *)
 
-open BasicCodegenThingol;
-
-fun eval xs =
+fun eval lookup =
   let
-    fun evl (IConst (c, _)) = lookup c
-      | evl (IVar n) =
-          AList.lookup (op =) xs n
-          |> the_default (Var (n, []))
-      | evl (s `$ t) = apply (eval xs s) (eval xs t)
-      | evl ((n, _) `|-> t) =
-          Fun (fn [x] => eval (AList.update (op =) (n, x) xs) t,
-             [], 1,
-             fn () => let val var = new_name() in
-                 AbsN (var, to_term (eval (AList.update (op =) (n, BVar (var, [])) xs) t)) end)
-  in CodegenThingol.map_pure evl end;
+    fun evl vars (Const (s, _)) = lookup s
+      | evl vars (Free (v, _)) = 
+          AList.lookup (op =) vars v
+          |> the_default (Var (v, []))
+      | evl vars (s $ t) = apply (evl vars s) (evl vars t)
+      | evl vars (Abs (v, _, t)) =
+          Fun (fn [x] => evl (AList.update (op =) (v, x) vars) t,
+            [], 1,
+            fn () => let val var = new_name() in
+              AbsN (var, to_term (evl (AList.update (op =) (v, BVar (var, [])) vars) t)) end)
+  in evl end;
+
 
 (* finally... *)
 
-fun nbe tab t = (counter :=0; xfun_tab := tab; to_term(eval [] t));
+fun nbe thy tab t =
+  let
+    fun lookup s = case Symtab.lookup tab s
+       of SOME x => x
+        | NONE   => Constr (s, []);
+  in (counter := 0; to_term (eval lookup [] (prep_term thy t))) end;
 
 end;