merged
authorhaftmann
Wed, 15 Sep 2010 12:16:35 +0200
changeset 39393 7565c649e7dd
parent 39389 20db6db55a6b (current diff)
parent 39392 7a0fcee7a2a3 (diff)
child 39394 955ce6038aa5
merged
--- a/src/Tools/nbe.ML	Wed Sep 15 12:16:08 2010 +0200
+++ b/src/Tools/nbe.ML	Wed Sep 15 12:16:35 2010 +0200
@@ -377,15 +377,16 @@
   in ml_abs deps_vars (ml_Let (ml_fundefs (flat fun_vars)) (ml_list fun_vals)) end;
 
 
-(* code compilation *)
+(* compile equations *)
 
-fun compile_eqnss ctxt gr raw_deps [] = []
-  | compile_eqnss ctxt gr raw_deps eqnss =
+fun compile_eqnss thy nbe_program raw_deps [] = []
+  | compile_eqnss thy nbe_program raw_deps eqnss =
       let
+        val ctxt = ProofContext.init_global thy;
         val (deps, deps_vals) = split_list (map_filter
-          (fn dep => Option.map (fn univ => (dep, univ)) (fst ((Graph.get_node gr dep)))) raw_deps);
+          (fn dep => Option.map (fn univ => (dep, univ)) (fst ((Graph.get_node nbe_program dep)))) raw_deps);
         val idx_of = raw_deps
-          |> map (fn dep => (dep, snd (Graph.get_node gr dep)))
+          |> map (fn dep => (dep, snd (Graph.get_node nbe_program dep)))
           |> AList.lookup (op =)
           |> (fn f => the o f);
         val s = assemble_eqnss idx_of deps eqnss;
@@ -400,7 +401,7 @@
       end;
 
 
-(* preparing function equations *)
+(* extract equations from statements *)
 
 fun eqns_of_stmt (_, Code_Thingol.Fun (_, ((_, []), _))) =
       []
@@ -428,7 +429,10 @@
         map (fn (_, (_, (inst, dss))) => IConst (inst, (([], dss), []))) super_instances
         @ map (IConst o snd o fst) classparam_instances)]))];
 
-fun compile_stmts ctxt stmts_deps =
+
+(* compile whole programs *)
+
+fun compile_stmts thy stmts_deps =
   let
     val names = map (fst o fst) stmts_deps;
     val names_deps = map (fn ((name, _), deps) => (name, deps)) stmts_deps;
@@ -437,13 +441,13 @@
       |> maps snd
       |> distinct (op =)
       |> fold (insert (op =)) names;
-    fun new_node name (gr, (maxidx, idx_tab)) = if can (Graph.get_node gr) name
-      then (gr, (maxidx, idx_tab))
-      else (Graph.new_node (name, (NONE, maxidx)) gr,
+    fun new_node name (nbe_program, (maxidx, idx_tab)) = if can (Graph.get_node nbe_program) name
+      then (nbe_program, (maxidx, idx_tab))
+      else (Graph.new_node (name, (NONE, maxidx)) nbe_program,
         (maxidx + 1, Inttab.update_new (maxidx, name) idx_tab));
-    fun compile gr = eqnss
-      |> compile_eqnss ctxt gr refl_deps
-      |> rpair gr;
+    fun compile nbe_program = eqnss
+      |> compile_eqnss thy nbe_program refl_deps
+      |> rpair nbe_program;
   in
     fold new_node refl_deps
     #> apfst (fold (fn (name, deps) => fold (curry Graph.add_edge name) deps) names_deps
@@ -451,12 +455,12 @@
       #-> fold (fn (name, univ) => (Graph.map_node name o apfst) (K (SOME univ))))
   end;
 
-fun ensure_stmts ctxt program =
+fun compile_program thy program =
   let
-    fun add_stmts names (gr, (maxidx, idx_tab)) = if exists ((can o Graph.get_node) gr) names
-      then (gr, (maxidx, idx_tab))
-      else (gr, (maxidx, idx_tab))
-        |> compile_stmts ctxt (map (fn name => ((name, Graph.get_node program name),
+    fun add_stmts names (nbe_program, (maxidx, idx_tab)) = if exists ((can o Graph.get_node) nbe_program) names
+      then (nbe_program, (maxidx, idx_tab))
+      else (nbe_program, (maxidx, idx_tab))
+        |> compile_stmts thy (map (fn name => ((name, Graph.get_node program name),
           Graph.imm_succs program name)) names);
   in
     fold_rev add_stmts (Graph.strong_conn program)
@@ -465,20 +469,20 @@
 
 (** evaluation **)
 
-(* term evaluation *)
+(* term evaluation by compilation *)
 
-fun eval_term ctxt gr deps (vs : (string * sort) list, t) =
+fun compile_term thy nbe_program deps (vs : (string * sort) list, t) =
   let 
     val dict_frees = maps (fn (v, sort) => map_index (curry DFree v o fst) sort) vs;
   in
     ("", (vs, [([], t)]))
-    |> singleton (compile_eqnss ctxt gr deps)
+    |> singleton (compile_eqnss thy nbe_program deps)
     |> snd
     |> (fn t => apps t (rev dict_frees))
   end;
 
 
-(* reification *)
+(* reconstruction *)
 
 fun typ_of_itype program vs (ityco `%% itys) =
       let
@@ -525,6 +529,29 @@
   in of_univ 0 t 0 |> fst end;
 
 
+(* evaluation with type reconstruction *)
+
+fun eval_term thy program (nbe_program, idx_tab) ((vs0, (vs, ty)), t) deps =
+  let
+    val ctxt = Syntax.init_pretty_global thy;
+    val string_of_term = Syntax.string_of_term (Config.put show_types true ctxt);
+    val ty' = typ_of_itype program vs0 ty;
+    fun type_infer t = singleton
+      (Type_Infer.infer_types ctxt (try (Type.strip_sorts o Sign.the_const_type thy)) (K NONE))
+      (Type.constraint ty' t);
+    fun check_tvars t =
+      if null (Term.add_tvars t []) then t
+      else error ("Illegal schematic type variables in normalized term: " ^ string_of_term t);
+  in
+    compile_term thy nbe_program deps (vs, t)
+    |> term_of_univ thy program idx_tab
+    |> traced (fn t => "Normalized:\n" ^ string_of_term t)
+    |> type_infer
+    |> traced (fn t => "Types inferred:\n" ^ string_of_term t)
+    |> check_tvars
+    |> traced (fn _ => "---\n")
+  end;
+
 (* function store *)
 
 structure Nbe_Functions = Code_Data
@@ -533,43 +560,11 @@
   val empty = (Graph.empty, (0, Inttab.empty));
 );
 
-
-(* compilation, evaluation and reification *)
-
-fun compile_eval thy program =
-  let
-    val ctxt = ProofContext.init_global thy;
-    val (gr, (_, idx_tab)) =
-      Nbe_Functions.change thy (ensure_stmts ctxt program);
-  in fn vs_t => fn deps =>
-    vs_t
-    |> eval_term ctxt gr deps
-    |> term_of_univ thy program idx_tab
-  end;
-
-
-(* evaluation with type reconstruction *)
-
-fun normalize thy program ((vs0, (vs, ty)), t) deps =
+fun compile thy program =
   let
-    val ctxt = Syntax.init_pretty_global thy;
-    val ty' = typ_of_itype program vs0 ty;
-    fun type_infer t =
-      singleton
-        (Type_Infer.infer_types ctxt (try (Type.strip_sorts o Sign.the_const_type thy)) (K NONE))
-        (Type.constraint ty' t);
-    val string_of_term = Syntax.string_of_term (Config.put show_types true ctxt);
-    fun check_tvars t =
-      if null (Term.add_tvars t []) then t
-      else error ("Illegal schematic type variables in normalized term: " ^ string_of_term t);
-  in
-    compile_eval thy program (vs, t) deps
-    |> traced (fn t => "Normalized:\n" ^ string_of_term t)
-    |> type_infer
-    |> traced (fn t => "Types inferred:\n" ^ string_of_term t)
-    |> check_tvars
-    |> traced (fn _ => "---\n")
-  end;
+    val (nbe_program, (_, idx_tab)) =
+      Nbe_Functions.change thy (compile_program thy program);
+  in (nbe_program, idx_tab) end;
 
 
 (* evaluation oracle *)
@@ -583,7 +578,7 @@
 
 val (_, raw_oracle) = Context.>>> (Context.map_theory_result
   (Thm.add_oracle (Binding.name "norm", fn (thy, program, vsp_ty_t, deps, ct) =>
-    mk_equals thy ct (normalize thy program vsp_ty_t deps))));
+    mk_equals thy ct (eval_term thy program (compile thy program) vsp_ty_t deps))));
 
 fun oracle thy program vsp_ty_t deps ct = raw_oracle (thy, program, vsp_ty_t, deps, ct);
 
@@ -601,7 +596,8 @@
   (fn thy => lift_triv_classes_conv thy (Code_Thingol.dynamic_eval_conv thy (K (oracle thy)))));
 
 fun dynamic_eval_value thy = lift_triv_classes_rew thy
-  (no_frees_rew (Code_Thingol.dynamic_eval_value thy I (K (normalize thy))));
+  (no_frees_rew (Code_Thingol.dynamic_eval_value thy I
+    (K (fn program => eval_term thy program (compile thy program)))));
 
 
 (* evaluation command *)