--- a/src/Tools/nbe.ML Wed Sep 15 12:16:08 2010 +0200
+++ b/src/Tools/nbe.ML Wed Sep 15 12:16:35 2010 +0200
@@ -377,15 +377,16 @@
in ml_abs deps_vars (ml_Let (ml_fundefs (flat fun_vars)) (ml_list fun_vals)) end;
-(* code compilation *)
+(* compile equations *)
-fun compile_eqnss ctxt gr raw_deps [] = []
- | compile_eqnss ctxt gr raw_deps eqnss =
+fun compile_eqnss thy nbe_program raw_deps [] = []
+ | compile_eqnss thy nbe_program raw_deps eqnss =
let
+ val ctxt = ProofContext.init_global thy;
val (deps, deps_vals) = split_list (map_filter
- (fn dep => Option.map (fn univ => (dep, univ)) (fst ((Graph.get_node gr dep)))) raw_deps);
+ (fn dep => Option.map (fn univ => (dep, univ)) (fst ((Graph.get_node nbe_program dep)))) raw_deps);
val idx_of = raw_deps
- |> map (fn dep => (dep, snd (Graph.get_node gr dep)))
+ |> map (fn dep => (dep, snd (Graph.get_node nbe_program dep)))
|> AList.lookup (op =)
|> (fn f => the o f);
val s = assemble_eqnss idx_of deps eqnss;
@@ -400,7 +401,7 @@
end;
-(* preparing function equations *)
+(* extract equations from statements *)
fun eqns_of_stmt (_, Code_Thingol.Fun (_, ((_, []), _))) =
[]
@@ -428,7 +429,10 @@
map (fn (_, (_, (inst, dss))) => IConst (inst, (([], dss), []))) super_instances
@ map (IConst o snd o fst) classparam_instances)]))];
-fun compile_stmts ctxt stmts_deps =
+
+(* compile whole programs *)
+
+fun compile_stmts thy stmts_deps =
let
val names = map (fst o fst) stmts_deps;
val names_deps = map (fn ((name, _), deps) => (name, deps)) stmts_deps;
@@ -437,13 +441,13 @@
|> maps snd
|> distinct (op =)
|> fold (insert (op =)) names;
- fun new_node name (gr, (maxidx, idx_tab)) = if can (Graph.get_node gr) name
- then (gr, (maxidx, idx_tab))
- else (Graph.new_node (name, (NONE, maxidx)) gr,
+ fun new_node name (nbe_program, (maxidx, idx_tab)) = if can (Graph.get_node nbe_program) name
+ then (nbe_program, (maxidx, idx_tab))
+ else (Graph.new_node (name, (NONE, maxidx)) nbe_program,
(maxidx + 1, Inttab.update_new (maxidx, name) idx_tab));
- fun compile gr = eqnss
- |> compile_eqnss ctxt gr refl_deps
- |> rpair gr;
+ fun compile nbe_program = eqnss
+ |> compile_eqnss thy nbe_program refl_deps
+ |> rpair nbe_program;
in
fold new_node refl_deps
#> apfst (fold (fn (name, deps) => fold (curry Graph.add_edge name) deps) names_deps
@@ -451,12 +455,12 @@
#-> fold (fn (name, univ) => (Graph.map_node name o apfst) (K (SOME univ))))
end;
-fun ensure_stmts ctxt program =
+fun compile_program thy program =
let
- fun add_stmts names (gr, (maxidx, idx_tab)) = if exists ((can o Graph.get_node) gr) names
- then (gr, (maxidx, idx_tab))
- else (gr, (maxidx, idx_tab))
- |> compile_stmts ctxt (map (fn name => ((name, Graph.get_node program name),
+ fun add_stmts names (nbe_program, (maxidx, idx_tab)) = if exists ((can o Graph.get_node) nbe_program) names
+ then (nbe_program, (maxidx, idx_tab))
+ else (nbe_program, (maxidx, idx_tab))
+ |> compile_stmts thy (map (fn name => ((name, Graph.get_node program name),
Graph.imm_succs program name)) names);
in
fold_rev add_stmts (Graph.strong_conn program)
@@ -465,20 +469,20 @@
(** evaluation **)
-(* term evaluation *)
+(* term evaluation by compilation *)
-fun eval_term ctxt gr deps (vs : (string * sort) list, t) =
+fun compile_term thy nbe_program deps (vs : (string * sort) list, t) =
let
val dict_frees = maps (fn (v, sort) => map_index (curry DFree v o fst) sort) vs;
in
("", (vs, [([], t)]))
- |> singleton (compile_eqnss ctxt gr deps)
+ |> singleton (compile_eqnss thy nbe_program deps)
|> snd
|> (fn t => apps t (rev dict_frees))
end;
-(* reification *)
+(* reconstruction *)
fun typ_of_itype program vs (ityco `%% itys) =
let
@@ -525,6 +529,29 @@
in of_univ 0 t 0 |> fst end;
+(* evaluation with type reconstruction *)
+
+fun eval_term thy program (nbe_program, idx_tab) ((vs0, (vs, ty)), t) deps =
+ let
+ val ctxt = Syntax.init_pretty_global thy;
+ val string_of_term = Syntax.string_of_term (Config.put show_types true ctxt);
+ val ty' = typ_of_itype program vs0 ty;
+ fun type_infer t = singleton
+ (Type_Infer.infer_types ctxt (try (Type.strip_sorts o Sign.the_const_type thy)) (K NONE))
+ (Type.constraint ty' t);
+ fun check_tvars t =
+ if null (Term.add_tvars t []) then t
+ else error ("Illegal schematic type variables in normalized term: " ^ string_of_term t);
+ in
+ compile_term thy nbe_program deps (vs, t)
+ |> term_of_univ thy program idx_tab
+ |> traced (fn t => "Normalized:\n" ^ string_of_term t)
+ |> type_infer
+ |> traced (fn t => "Types inferred:\n" ^ string_of_term t)
+ |> check_tvars
+ |> traced (fn _ => "---\n")
+ end;
+
(* function store *)
structure Nbe_Functions = Code_Data
@@ -533,43 +560,11 @@
val empty = (Graph.empty, (0, Inttab.empty));
);
-
-(* compilation, evaluation and reification *)
-
-fun compile_eval thy program =
- let
- val ctxt = ProofContext.init_global thy;
- val (gr, (_, idx_tab)) =
- Nbe_Functions.change thy (ensure_stmts ctxt program);
- in fn vs_t => fn deps =>
- vs_t
- |> eval_term ctxt gr deps
- |> term_of_univ thy program idx_tab
- end;
-
-
-(* evaluation with type reconstruction *)
-
-fun normalize thy program ((vs0, (vs, ty)), t) deps =
+fun compile thy program =
let
- val ctxt = Syntax.init_pretty_global thy;
- val ty' = typ_of_itype program vs0 ty;
- fun type_infer t =
- singleton
- (Type_Infer.infer_types ctxt (try (Type.strip_sorts o Sign.the_const_type thy)) (K NONE))
- (Type.constraint ty' t);
- val string_of_term = Syntax.string_of_term (Config.put show_types true ctxt);
- fun check_tvars t =
- if null (Term.add_tvars t []) then t
- else error ("Illegal schematic type variables in normalized term: " ^ string_of_term t);
- in
- compile_eval thy program (vs, t) deps
- |> traced (fn t => "Normalized:\n" ^ string_of_term t)
- |> type_infer
- |> traced (fn t => "Types inferred:\n" ^ string_of_term t)
- |> check_tvars
- |> traced (fn _ => "---\n")
- end;
+ val (nbe_program, (_, idx_tab)) =
+ Nbe_Functions.change thy (compile_program thy program);
+ in (nbe_program, idx_tab) end;
(* evaluation oracle *)
@@ -583,7 +578,7 @@
val (_, raw_oracle) = Context.>>> (Context.map_theory_result
(Thm.add_oracle (Binding.name "norm", fn (thy, program, vsp_ty_t, deps, ct) =>
- mk_equals thy ct (normalize thy program vsp_ty_t deps))));
+ mk_equals thy ct (eval_term thy program (compile thy program) vsp_ty_t deps))));
fun oracle thy program vsp_ty_t deps ct = raw_oracle (thy, program, vsp_ty_t, deps, ct);
@@ -601,7 +596,8 @@
(fn thy => lift_triv_classes_conv thy (Code_Thingol.dynamic_eval_conv thy (K (oracle thy)))));
fun dynamic_eval_value thy = lift_triv_classes_rew thy
- (no_frees_rew (Code_Thingol.dynamic_eval_value thy I (K (normalize thy))));
+ (no_frees_rew (Code_Thingol.dynamic_eval_value thy I
+ (K (fn program => eval_term thy program (compile thy program)))));
(* evaluation command *)