constructing proof
authorhaftmann
Tue, 31 Oct 2006 09:29:18 +0100
changeset 21124 8648b5dd6a87
parent 21123 9f7c430cf9ac
child 21125 9b7d35ca1eef
constructing proof
src/Pure/Tools/nbe.ML
--- a/src/Pure/Tools/nbe.ML	Tue Oct 31 09:29:17 2006 +0100
+++ b/src/Pure/Tools/nbe.ML	Tue Oct 31 09:29:18 2006 +0100
@@ -8,7 +8,6 @@
 signature NBE =
 sig
   (*preconditions: no Vars/TVars in term*)
-  val norm_term: theory -> term -> term
   val normalization_conv: cterm -> thm
   val lookup: string -> NBE_Eval.Univ
   val update: string * NBE_Eval.Univ -> unit
@@ -19,8 +18,7 @@
 struct
 
 val trace = ref false;
-fun tracing f = if !trace then Output.tracing (f ()) else ();
-
+fun tracing f x = if !trace then (Output.tracing (f x); x) else x;
 
 
 (** data setup **)
@@ -97,7 +95,7 @@
 val _ = Context.add_setup NBE_Data.init;
 
 
-(** interface **)
+(** norm by eval **)
 
 (* sandbox communication *)
 
@@ -105,97 +103,106 @@
 fun lookup s = (the o Symtab.lookup (!tab)) s;
 fun update sx = (tab := Symtab.update sx (!tab));
 
-
-(* norm by eval *)
-
 local
 
-(* FIXME better turn this into a function
-    NBE_Eval.Univ Symtab.table -> NBE_Eval.Univ Symtab.table
-    with implicit side effect *)
-fun use_code NONE = ()
-  | use_code (SOME s) =
-      (tracing (fn () => "\n---generated code:\n" ^ s);
-       use_text (Output.tracing o enclose "\n---compiler echo:\n" "\n---\n",
-            Output.tracing o enclose "\n--- compiler echo (with error!):\n" "\n---\n")
-        (!trace) s);
+(* function generation *)
 
 fun generate thy funs =
   let
+    (* FIXME better turn this into a function
+        NBE_Eval.Univ Symtab.table -> NBE_Eval.Univ Symtab.table
+        with implicit side effect *)
+    fun use_code NONE = ()
+      | use_code (SOME s) =
+          (tracing (fn () => "\n---generated code:\n" ^ s);
+           use_text (Output.tracing o enclose "\n---compiler echo:\n" "\n---\n",
+                Output.tracing o enclose "\n--- compiler echo (with error!):\n" "\n---\n")
+            (!trace) s);
+
+    val _ = tracing (fn () => "new definitions: " ^ (commas o maps (map fst)) funs);
     val _ = tab := NBE_Data.get thy;;
     val _ = Library.seq (use_code o NBE_Codegen.generate thy
       (fn s => Symtab.defined (!tab) s)) funs;
   in NBE_Data.change thy (K (!tab)) end;
 
-fun compile_term thy t =
+fun ensure_funs thy t =
   let
-    (*FIXME: proper interfaces in codegen_*)
-    val (consts, cs) = CodegenConsts.consts_of thy t;
-    val funcgr = CodegenFuncgr.mk_funcgr thy consts cs;
-    (*FIXME: proper interfaces in codegen_*)
-    fun const_typ (c, ty) =
-      let
-        val const = CodegenConsts.norm_of_typ thy (c, ty);
-      in case CodegenFuncgr.get_funcs funcgr const
-       of (thm :: _) => CodegenData.typ_func thy thm
-        | [] => Sign.the_const_type thy c
-      end;
-    val (_, ct) = CodegenData.preprocess_cterm thy const_typ (Thm.cterm_of thy t)
-    val t' = Thm.term_of ct;
-    val (consts, cs) = CodegenConsts.consts_of thy t';
+    val consts = CodegenConsts.consts_of thy t;
     val pre_consts = consts_of_pres thy;
     val consts' = pre_consts @ consts;
-    val funcgr = CodegenFuncgr.mk_funcgr thy consts' cs;
+    val funcgr = CodegenFuncgr.make thy consts';
     val nbe_tab = NBE_Data.get thy;
     val all_consts =
-      (pre_consts :: CodegenFuncgr.all_deps_of funcgr consts')
+      (pre_consts :: CodegenFuncgr.deps funcgr consts')
       |> (map o filter_out) (Symtab.defined nbe_tab o CodegenNames.const thy)
       |> filter_out null;
     val funs = (map o map)
-      (fn c => (CodegenNames.const thy c, apply_pres thy (CodegenFuncgr.get_funcs funcgr c))) all_consts;
-    val _ = tracing (fn () => "new definitions: " ^ (commas o maps (map fst)) funs);
-    val _ = generate thy funs;
-    val nt = NBE_Eval.eval thy (!tab) t';
-  in nt end;
+      (fn c => (CodegenNames.const thy c, apply_pres thy (CodegenFuncgr.funcs funcgr c))) all_consts;
+  in generate thy funs end;
+
+(* term evaluation *)
 
-fun subst_Frees [] = I
-  | subst_Frees inst =
-      Term.map_aterms (fn (t as Free(s, _)) => the_default t (AList.lookup (op =) inst s)
-                  | t => t);
+fun eval_term thy t =
+  let
+    fun subst_Frees [] = I
+      | subst_Frees inst =
+          Term.map_aterms (fn (t as Free(s, _)) => the_default t (AList.lookup (op =) inst s)
+                      | t => t);
+    val anno_vars =
+      subst_Frees (map (fn (s, T) => (s, Free (s, T))) (Term.add_frees t []))
+      #> subst_Vars  (map (fn (ixn, T) => (ixn, Var(ixn,T))) (Term.add_vars t []))
+    fun check_tvars t = if null (Term.term_tvars t) then t else
+      error ("Illegal schematic type variables in normalized term: "
+        ^ setmp show_types true (Sign.string_of_term thy) t);
+    val ty = type_of t;
+    fun constrain t = Sign.infer_types (Sign.pp thy) thy (Sign.consts_of thy)
+      (K NONE) (K NONE) Name.context false ([t], ty) |> fst;
+    val _ = ensure_funs thy t;
+  in
+    t
+    |> tracing (fn t => "Input:\n" ^ Display.raw_string_of_term t)
+    |> NBE_Eval.eval thy (!tab)
+    |> tracing (fn nt => "Normalized:\n" ^ NBE_Eval.string_of_nterm nt)
+    |> NBE_Codegen.nterm_to_term thy
+    |> tracing (fn t =>"Converted back:\n" ^ Display.raw_string_of_term t)
+    |> anno_vars
+    |> tracing (fn t =>"Vars typed:\n" ^ Display.raw_string_of_term t)
+    |> constrain
+    |> check_tvars
+  end;
 
-fun var_tab t = (Term.add_frees t [], Term.add_vars t []);
+(* evaluation oracle *)
+
+exception Normalization of term;
 
-fun anno_vars (Ftab, Vtab) =
-  subst_Vars  (map (fn (ixn, T) => (ixn, Var(ixn,T))) Vtab) o
-  subst_Frees (map (fn (s, T) =>   (s,   Free(s,T)))  Ftab)
+fun normalization_oracle (thy, Normalization t) =
+  Logic.mk_equals (t, eval_term thy t);
+
+fun normalization_invoke thy t =
+  Thm.invoke_oracle_i thy "Pure.normalization" (thy, Normalization t);
 
 in
 
-fun norm_term thy t =
+(* interface *)
+
+fun normalization_conv ct =
   let
-    val _ = tracing (fn () => "Input:\n" ^ Display.raw_string_of_term t);
-    val nt = compile_term thy t;
-    val vtab = var_tab t;
-    val ty = type_of t;
-    fun constrain ty t = Sign.infer_types (Sign.pp thy) thy (Sign.consts_of thy)
-        (K NONE) (K NONE) Name.context false ([t], ty) |> fst;
-    val _ = tracing (fn () => "Normalized:\n" ^ NBE_Eval.string_of_nterm nt);
-    val t1 = NBE_Codegen.nterm_to_term thy nt;
-    val _ = tracing (fn () =>"Converted back:\n" ^ Display.raw_string_of_term t1);
-    val t2 = anno_vars vtab t1;
-    val _ = tracing (fn () =>"Vars typed:\n" ^ Display.raw_string_of_term t2);
-    val t3 = constrain ty t2;
-    val _ = if null (Term.term_tvars t3) then () else
-      error ("Illegal schematic type variables in normalized term: "
-        ^ setmp show_types true (Sign.string_of_term thy) t3);
-    val eq = apply_posts thy (Thm.cterm_of thy t3);
-    val t4 = snd(Logic.dest_equals(prop_of eq))
-  in t4 end;
+    val thy = Thm.theory_of_cterm ct;
+    val ((ct', (thm1, drop_classes)), _) = CodegenFuncgr.make_term thy ct;
+    val t = Thm.term_of ct';
+    val thm2 = normalization_invoke thy t;
+    val thm3 = apply_posts thy ((snd o Drule.dest_equals o Thm.cprop_of) thm2);
+  in
+    Thm.transitive thm1 (drop_classes (Thm.transitive thm2 thm3)) handle
+      THM _ => error ("normalization_conv - could not construct proof:\n"
+        ^ (cat_lines o map string_of_thm) [thm1, thm2, thm3])
+  end;
 
 fun norm_print_term ctxt modes t =
   let
     val thy = ProofContext.theory_of ctxt;
-    val t' = norm_term thy t;
+    val ct = Thm.cterm_of thy t;
+    val (_, t') = (Logic.dest_equals o Thm.prop_of o normalization_conv) ct;
     val ty = Term.type_of t';
     val p = Library.setmp print_mode (modes @ ! print_mode) (fn () =>
       Pretty.block [Pretty.quote (ProofContext.pretty_term ctxt t'), Pretty.fbrk,
@@ -207,23 +214,11 @@
     val ctxt = Context.proof_of (Toplevel.context_of state);
   in norm_print_term ctxt modes (ProofContext.read_term ctxt raw_t) end;
 
-end; (*local*)
-
-
-(* oracle *)
-
-exception Normalization of term;
-
-fun normalization_oracle (thy, Normalization t) =
-  Logic.mk_equals (t, norm_term thy t);
-
-fun normalization_conv ct =
-  let val {sign, t, ...} = rep_cterm ct
-  in Thm.invoke_oracle_i sign "Pure.normalization" (sign, Normalization t) end;
-
 val _ = Context.add_setup
   (Theory.add_oracle ("normalization", normalization_oracle));
 
+end; (*local*)
+
 
 (* Isar setup *)