# HG changeset patch # User haftmann # Date 1284545795 -7200 # Node ID 7565c649e7dd25f0d85cdb8d12950d301a86e043 # Parent 20db6db55a6b06e7826093d8a45be05830d6b774# Parent 7a0fcee7a2a393035e270bbe439783f819aba6c2 merged diff -r 20db6db55a6b -r 7565c649e7dd src/Tools/nbe.ML --- a/src/Tools/nbe.ML Wed Sep 15 12:16:08 2010 +0200 +++ b/src/Tools/nbe.ML Wed Sep 15 12:16:35 2010 +0200 @@ -377,15 +377,16 @@ in ml_abs deps_vars (ml_Let (ml_fundefs (flat fun_vars)) (ml_list fun_vals)) end; -(* code compilation *) +(* compile equations *) -fun compile_eqnss ctxt gr raw_deps [] = [] - | compile_eqnss ctxt gr raw_deps eqnss = +fun compile_eqnss thy nbe_program raw_deps [] = [] + | compile_eqnss thy nbe_program raw_deps eqnss = let + val ctxt = ProofContext.init_global thy; val (deps, deps_vals) = split_list (map_filter - (fn dep => Option.map (fn univ => (dep, univ)) (fst ((Graph.get_node gr dep)))) raw_deps); + (fn dep => Option.map (fn univ => (dep, univ)) (fst ((Graph.get_node nbe_program dep)))) raw_deps); val idx_of = raw_deps - |> map (fn dep => (dep, snd (Graph.get_node gr dep))) + |> map (fn dep => (dep, snd (Graph.get_node nbe_program dep))) |> AList.lookup (op =) |> (fn f => the o f); val s = assemble_eqnss idx_of deps eqnss; @@ -400,7 +401,7 @@ end; -(* preparing function equations *) +(* extract equations from statements *) fun eqns_of_stmt (_, Code_Thingol.Fun (_, ((_, []), _))) = [] @@ -428,7 +429,10 @@ map (fn (_, (_, (inst, dss))) => IConst (inst, (([], dss), []))) super_instances @ map (IConst o snd o fst) classparam_instances)]))]; -fun compile_stmts ctxt stmts_deps = + +(* compile whole programs *) + +fun compile_stmts thy stmts_deps = let val names = map (fst o fst) stmts_deps; val names_deps = map (fn ((name, _), deps) => (name, deps)) stmts_deps; @@ -437,13 +441,13 @@ |> maps snd |> distinct (op =) |> fold (insert (op =)) names; - fun new_node name (gr, (maxidx, idx_tab)) = if can (Graph.get_node gr) name - then (gr, (maxidx, idx_tab)) - else (Graph.new_node (name, (NONE, maxidx)) gr, + fun new_node name (nbe_program, (maxidx, idx_tab)) = if can (Graph.get_node nbe_program) name + then (nbe_program, (maxidx, idx_tab)) + else (Graph.new_node (name, (NONE, maxidx)) nbe_program, (maxidx + 1, Inttab.update_new (maxidx, name) idx_tab)); - fun compile gr = eqnss - |> compile_eqnss ctxt gr refl_deps - |> rpair gr; + fun compile nbe_program = eqnss + |> compile_eqnss thy nbe_program refl_deps + |> rpair nbe_program; in fold new_node refl_deps #> apfst (fold (fn (name, deps) => fold (curry Graph.add_edge name) deps) names_deps @@ -451,12 +455,12 @@ #-> fold (fn (name, univ) => (Graph.map_node name o apfst) (K (SOME univ)))) end; -fun ensure_stmts ctxt program = +fun compile_program thy program = let - fun add_stmts names (gr, (maxidx, idx_tab)) = if exists ((can o Graph.get_node) gr) names - then (gr, (maxidx, idx_tab)) - else (gr, (maxidx, idx_tab)) - |> compile_stmts ctxt (map (fn name => ((name, Graph.get_node program name), + fun add_stmts names (nbe_program, (maxidx, idx_tab)) = if exists ((can o Graph.get_node) nbe_program) names + then (nbe_program, (maxidx, idx_tab)) + else (nbe_program, (maxidx, idx_tab)) + |> compile_stmts thy (map (fn name => ((name, Graph.get_node program name), Graph.imm_succs program name)) names); in fold_rev add_stmts (Graph.strong_conn program) @@ -465,20 +469,20 @@ (** evaluation **) -(* term evaluation *) +(* term evaluation by compilation *) -fun eval_term ctxt gr deps (vs : (string * sort) list, t) = +fun compile_term thy nbe_program deps (vs : (string * sort) list, t) = let val dict_frees = maps (fn (v, sort) => map_index (curry DFree v o fst) sort) vs; in ("", (vs, [([], t)])) - |> singleton (compile_eqnss ctxt gr deps) + |> singleton (compile_eqnss thy nbe_program deps) |> snd |> (fn t => apps t (rev dict_frees)) end; -(* reification *) +(* reconstruction *) fun typ_of_itype program vs (ityco `%% itys) = let @@ -525,6 +529,29 @@ in of_univ 0 t 0 |> fst end; +(* evaluation with type reconstruction *) + +fun eval_term thy program (nbe_program, idx_tab) ((vs0, (vs, ty)), t) deps = + let + val ctxt = Syntax.init_pretty_global thy; + val string_of_term = Syntax.string_of_term (Config.put show_types true ctxt); + val ty' = typ_of_itype program vs0 ty; + fun type_infer t = singleton + (Type_Infer.infer_types ctxt (try (Type.strip_sorts o Sign.the_const_type thy)) (K NONE)) + (Type.constraint ty' t); + fun check_tvars t = + if null (Term.add_tvars t []) then t + else error ("Illegal schematic type variables in normalized term: " ^ string_of_term t); + in + compile_term thy nbe_program deps (vs, t) + |> term_of_univ thy program idx_tab + |> traced (fn t => "Normalized:\n" ^ string_of_term t) + |> type_infer + |> traced (fn t => "Types inferred:\n" ^ string_of_term t) + |> check_tvars + |> traced (fn _ => "---\n") + end; + (* function store *) structure Nbe_Functions = Code_Data @@ -533,43 +560,11 @@ val empty = (Graph.empty, (0, Inttab.empty)); ); - -(* compilation, evaluation and reification *) - -fun compile_eval thy program = - let - val ctxt = ProofContext.init_global thy; - val (gr, (_, idx_tab)) = - Nbe_Functions.change thy (ensure_stmts ctxt program); - in fn vs_t => fn deps => - vs_t - |> eval_term ctxt gr deps - |> term_of_univ thy program idx_tab - end; - - -(* evaluation with type reconstruction *) - -fun normalize thy program ((vs0, (vs, ty)), t) deps = +fun compile thy program = let - val ctxt = Syntax.init_pretty_global thy; - val ty' = typ_of_itype program vs0 ty; - fun type_infer t = - singleton - (Type_Infer.infer_types ctxt (try (Type.strip_sorts o Sign.the_const_type thy)) (K NONE)) - (Type.constraint ty' t); - val string_of_term = Syntax.string_of_term (Config.put show_types true ctxt); - fun check_tvars t = - if null (Term.add_tvars t []) then t - else error ("Illegal schematic type variables in normalized term: " ^ string_of_term t); - in - compile_eval thy program (vs, t) deps - |> traced (fn t => "Normalized:\n" ^ string_of_term t) - |> type_infer - |> traced (fn t => "Types inferred:\n" ^ string_of_term t) - |> check_tvars - |> traced (fn _ => "---\n") - end; + val (nbe_program, (_, idx_tab)) = + Nbe_Functions.change thy (compile_program thy program); + in (nbe_program, idx_tab) end; (* evaluation oracle *) @@ -583,7 +578,7 @@ val (_, raw_oracle) = Context.>>> (Context.map_theory_result (Thm.add_oracle (Binding.name "norm", fn (thy, program, vsp_ty_t, deps, ct) => - mk_equals thy ct (normalize thy program vsp_ty_t deps)))); + mk_equals thy ct (eval_term thy program (compile thy program) vsp_ty_t deps)))); fun oracle thy program vsp_ty_t deps ct = raw_oracle (thy, program, vsp_ty_t, deps, ct); @@ -601,7 +596,8 @@ (fn thy => lift_triv_classes_conv thy (Code_Thingol.dynamic_eval_conv thy (K (oracle thy))))); fun dynamic_eval_value thy = lift_triv_classes_rew thy - (no_frees_rew (Code_Thingol.dynamic_eval_value thy I (K (normalize thy)))); + (no_frees_rew (Code_Thingol.dynamic_eval_value thy I + (K (fn program => eval_term thy program (compile thy program))))); (* evaluation command *)