--- a/src/Pure/Tools/nbe.ML Tue Oct 31 09:29:17 2006 +0100
+++ b/src/Pure/Tools/nbe.ML Tue Oct 31 09:29:18 2006 +0100
@@ -8,7 +8,6 @@
signature NBE =
sig
(*preconditions: no Vars/TVars in term*)
- val norm_term: theory -> term -> term
val normalization_conv: cterm -> thm
val lookup: string -> NBE_Eval.Univ
val update: string * NBE_Eval.Univ -> unit
@@ -19,8 +18,7 @@
struct
val trace = ref false;
-fun tracing f = if !trace then Output.tracing (f ()) else ();
-
+fun tracing f x = if !trace then (Output.tracing (f x); x) else x;
(** data setup **)
@@ -97,7 +95,7 @@
val _ = Context.add_setup NBE_Data.init;
-(** interface **)
+(** norm by eval **)
(* sandbox communication *)
@@ -105,97 +103,106 @@
fun lookup s = (the o Symtab.lookup (!tab)) s;
fun update sx = (tab := Symtab.update sx (!tab));
-
-(* norm by eval *)
-
local
-(* FIXME better turn this into a function
- NBE_Eval.Univ Symtab.table -> NBE_Eval.Univ Symtab.table
- with implicit side effect *)
-fun use_code NONE = ()
- | use_code (SOME s) =
- (tracing (fn () => "\n---generated code:\n" ^ s);
- use_text (Output.tracing o enclose "\n---compiler echo:\n" "\n---\n",
- Output.tracing o enclose "\n--- compiler echo (with error!):\n" "\n---\n")
- (!trace) s);
+(* function generation *)
fun generate thy funs =
let
+ (* FIXME better turn this into a function
+ NBE_Eval.Univ Symtab.table -> NBE_Eval.Univ Symtab.table
+ with implicit side effect *)
+ fun use_code NONE = ()
+ | use_code (SOME s) =
+ (tracing (fn () => "\n---generated code:\n" ^ s);
+ use_text (Output.tracing o enclose "\n---compiler echo:\n" "\n---\n",
+ Output.tracing o enclose "\n--- compiler echo (with error!):\n" "\n---\n")
+ (!trace) s);
+
+ val _ = tracing (fn () => "new definitions: " ^ (commas o maps (map fst)) funs);
val _ = tab := NBE_Data.get thy;;
val _ = Library.seq (use_code o NBE_Codegen.generate thy
(fn s => Symtab.defined (!tab) s)) funs;
in NBE_Data.change thy (K (!tab)) end;
-fun compile_term thy t =
+fun ensure_funs thy t =
let
- (*FIXME: proper interfaces in codegen_*)
- val (consts, cs) = CodegenConsts.consts_of thy t;
- val funcgr = CodegenFuncgr.mk_funcgr thy consts cs;
- (*FIXME: proper interfaces in codegen_*)
- fun const_typ (c, ty) =
- let
- val const = CodegenConsts.norm_of_typ thy (c, ty);
- in case CodegenFuncgr.get_funcs funcgr const
- of (thm :: _) => CodegenData.typ_func thy thm
- | [] => Sign.the_const_type thy c
- end;
- val (_, ct) = CodegenData.preprocess_cterm thy const_typ (Thm.cterm_of thy t)
- val t' = Thm.term_of ct;
- val (consts, cs) = CodegenConsts.consts_of thy t';
+ val consts = CodegenConsts.consts_of thy t;
val pre_consts = consts_of_pres thy;
val consts' = pre_consts @ consts;
- val funcgr = CodegenFuncgr.mk_funcgr thy consts' cs;
+ val funcgr = CodegenFuncgr.make thy consts';
val nbe_tab = NBE_Data.get thy;
val all_consts =
- (pre_consts :: CodegenFuncgr.all_deps_of funcgr consts')
+ (pre_consts :: CodegenFuncgr.deps funcgr consts')
|> (map o filter_out) (Symtab.defined nbe_tab o CodegenNames.const thy)
|> filter_out null;
val funs = (map o map)
- (fn c => (CodegenNames.const thy c, apply_pres thy (CodegenFuncgr.get_funcs funcgr c))) all_consts;
- val _ = tracing (fn () => "new definitions: " ^ (commas o maps (map fst)) funs);
- val _ = generate thy funs;
- val nt = NBE_Eval.eval thy (!tab) t';
- in nt end;
+ (fn c => (CodegenNames.const thy c, apply_pres thy (CodegenFuncgr.funcs funcgr c))) all_consts;
+ in generate thy funs end;
+
+(* term evaluation *)
-fun subst_Frees [] = I
- | subst_Frees inst =
- Term.map_aterms (fn (t as Free(s, _)) => the_default t (AList.lookup (op =) inst s)
- | t => t);
+fun eval_term thy t =
+ let
+ fun subst_Frees [] = I
+ | subst_Frees inst =
+ Term.map_aterms (fn (t as Free(s, _)) => the_default t (AList.lookup (op =) inst s)
+ | t => t);
+ val anno_vars =
+ subst_Frees (map (fn (s, T) => (s, Free (s, T))) (Term.add_frees t []))
+ #> subst_Vars (map (fn (ixn, T) => (ixn, Var(ixn,T))) (Term.add_vars t []))
+ fun check_tvars t = if null (Term.term_tvars t) then t else
+ error ("Illegal schematic type variables in normalized term: "
+ ^ setmp show_types true (Sign.string_of_term thy) t);
+ val ty = type_of t;
+ fun constrain t = Sign.infer_types (Sign.pp thy) thy (Sign.consts_of thy)
+ (K NONE) (K NONE) Name.context false ([t], ty) |> fst;
+ val _ = ensure_funs thy t;
+ in
+ t
+ |> tracing (fn t => "Input:\n" ^ Display.raw_string_of_term t)
+ |> NBE_Eval.eval thy (!tab)
+ |> tracing (fn nt => "Normalized:\n" ^ NBE_Eval.string_of_nterm nt)
+ |> NBE_Codegen.nterm_to_term thy
+ |> tracing (fn t =>"Converted back:\n" ^ Display.raw_string_of_term t)
+ |> anno_vars
+ |> tracing (fn t =>"Vars typed:\n" ^ Display.raw_string_of_term t)
+ |> constrain
+ |> check_tvars
+ end;
-fun var_tab t = (Term.add_frees t [], Term.add_vars t []);
+(* evaluation oracle *)
+
+exception Normalization of term;
-fun anno_vars (Ftab, Vtab) =
- subst_Vars (map (fn (ixn, T) => (ixn, Var(ixn,T))) Vtab) o
- subst_Frees (map (fn (s, T) => (s, Free(s,T))) Ftab)
+fun normalization_oracle (thy, Normalization t) =
+ Logic.mk_equals (t, eval_term thy t);
+
+fun normalization_invoke thy t =
+ Thm.invoke_oracle_i thy "Pure.normalization" (thy, Normalization t);
in
-fun norm_term thy t =
+(* interface *)
+
+fun normalization_conv ct =
let
- val _ = tracing (fn () => "Input:\n" ^ Display.raw_string_of_term t);
- val nt = compile_term thy t;
- val vtab = var_tab t;
- val ty = type_of t;
- fun constrain ty t = Sign.infer_types (Sign.pp thy) thy (Sign.consts_of thy)
- (K NONE) (K NONE) Name.context false ([t], ty) |> fst;
- val _ = tracing (fn () => "Normalized:\n" ^ NBE_Eval.string_of_nterm nt);
- val t1 = NBE_Codegen.nterm_to_term thy nt;
- val _ = tracing (fn () =>"Converted back:\n" ^ Display.raw_string_of_term t1);
- val t2 = anno_vars vtab t1;
- val _ = tracing (fn () =>"Vars typed:\n" ^ Display.raw_string_of_term t2);
- val t3 = constrain ty t2;
- val _ = if null (Term.term_tvars t3) then () else
- error ("Illegal schematic type variables in normalized term: "
- ^ setmp show_types true (Sign.string_of_term thy) t3);
- val eq = apply_posts thy (Thm.cterm_of thy t3);
- val t4 = snd(Logic.dest_equals(prop_of eq))
- in t4 end;
+ val thy = Thm.theory_of_cterm ct;
+ val ((ct', (thm1, drop_classes)), _) = CodegenFuncgr.make_term thy ct;
+ val t = Thm.term_of ct';
+ val thm2 = normalization_invoke thy t;
+ val thm3 = apply_posts thy ((snd o Drule.dest_equals o Thm.cprop_of) thm2);
+ in
+ Thm.transitive thm1 (drop_classes (Thm.transitive thm2 thm3)) handle
+ THM _ => error ("normalization_conv - could not construct proof:\n"
+ ^ (cat_lines o map string_of_thm) [thm1, thm2, thm3])
+ end;
fun norm_print_term ctxt modes t =
let
val thy = ProofContext.theory_of ctxt;
- val t' = norm_term thy t;
+ val ct = Thm.cterm_of thy t;
+ val (_, t') = (Logic.dest_equals o Thm.prop_of o normalization_conv) ct;
val ty = Term.type_of t';
val p = Library.setmp print_mode (modes @ ! print_mode) (fn () =>
Pretty.block [Pretty.quote (ProofContext.pretty_term ctxt t'), Pretty.fbrk,
@@ -207,23 +214,11 @@
val ctxt = Context.proof_of (Toplevel.context_of state);
in norm_print_term ctxt modes (ProofContext.read_term ctxt raw_t) end;
-end; (*local*)
-
-
-(* oracle *)
-
-exception Normalization of term;
-
-fun normalization_oracle (thy, Normalization t) =
- Logic.mk_equals (t, norm_term thy t);
-
-fun normalization_conv ct =
- let val {sign, t, ...} = rep_cterm ct
- in Thm.invoke_oracle_i sign "Pure.normalization" (sign, Normalization t) end;
-
val _ = Context.add_setup
(Theory.add_oracle ("normalization", normalization_oracle));
+end; (*local*)
+
(* Isar setup *)