src/Tools/nbe.ML
changeset 39399 267235a14938
parent 39396 e9cad160aa0f
child 39436 4a7d09da2b9c
equal deleted inserted replaced
39398:2e30660a2e21 39399:267235a14938
     6 
     6 
     7 signature NBE =
     7 signature NBE =
     8 sig
     8 sig
     9   val dynamic_eval_conv: conv
     9   val dynamic_eval_conv: conv
    10   val dynamic_eval_value: theory -> term -> term
    10   val dynamic_eval_value: theory -> term -> term
       
    11   val static_eval_conv: theory -> string list -> conv
    11 
    12 
    12   datatype Univ =
    13   datatype Univ =
    13       Const of int * Univ list               (*named (uninterpreted) constants*)
    14       Const of int * Univ list               (*named (uninterpreted) constants*)
    14     | DFree of string * int                  (*free (uninterpreted) dictionary parameters*)
    15     | DFree of string * int                  (*free (uninterpreted) dictionary parameters*)
    15     | BVar of int * Univ list
    16     | BVar of int * Univ list
   226       end;
   227       end;
   227 
   228 
   228 
   229 
   229 (* nbe specific syntax and sandbox communication *)
   230 (* nbe specific syntax and sandbox communication *)
   230 
   231 
   231 structure Univs = Proof_Data(
   232 structure Univs = Proof_Data (
   232   type T = unit -> Univ list -> Univ list
   233   type T = unit -> Univ list -> Univ list
   233   fun init thy () = error "Univs"
   234   fun init _ () = error "Univs"
   234 );
   235 );
   235 val put_result = Univs.put;
   236 val put_result = Univs.put;
   236 
   237 
   237 local
   238 local
   238   val prefix =      "Nbe.";
   239   val prefix =      "Nbe.";
   430         @ map (IConst o snd o fst) classparam_instances)]))];
   431         @ map (IConst o snd o fst) classparam_instances)]))];
   431 
   432 
   432 
   433 
   433 (* compile whole programs *)
   434 (* compile whole programs *)
   434 
   435 
       
   436 fun ensure_const_idx name (nbe_program, (maxidx, idx_tab)) =
       
   437   if can (Graph.get_node nbe_program) name
       
   438   then (nbe_program, (maxidx, idx_tab))
       
   439   else (Graph.new_node (name, (NONE, maxidx)) nbe_program,
       
   440     (maxidx + 1, Inttab.update_new (maxidx, name) idx_tab));
       
   441 
   435 fun compile_stmts thy stmts_deps =
   442 fun compile_stmts thy stmts_deps =
   436   let
   443   let
   437     val names = map (fst o fst) stmts_deps;
   444     val names = map (fst o fst) stmts_deps;
   438     val names_deps = map (fn ((name, _), deps) => (name, deps)) stmts_deps;
   445     val names_deps = map (fn ((name, _), deps) => (name, deps)) stmts_deps;
   439     val eqnss = maps (eqns_of_stmt o fst) stmts_deps;
   446     val eqnss = maps (eqns_of_stmt o fst) stmts_deps;
   440     val refl_deps = names_deps
   447     val refl_deps = names_deps
   441       |> maps snd
   448       |> maps snd
   442       |> distinct (op =)
   449       |> distinct (op =)
   443       |> fold (insert (op =)) names;
   450       |> fold (insert (op =)) names;
   444     fun new_node name (nbe_program, (maxidx, idx_tab)) = if can (Graph.get_node nbe_program) name
       
   445       then (nbe_program, (maxidx, idx_tab))
       
   446       else (Graph.new_node (name, (NONE, maxidx)) nbe_program,
       
   447         (maxidx + 1, Inttab.update_new (maxidx, name) idx_tab));
       
   448     fun compile nbe_program = eqnss
   451     fun compile nbe_program = eqnss
   449       |> compile_eqnss thy nbe_program refl_deps
   452       |> compile_eqnss thy nbe_program refl_deps
   450       |> rpair nbe_program;
   453       |> rpair nbe_program;
   451   in
   454   in
   452     fold new_node refl_deps
   455     fold ensure_const_idx refl_deps
   453     #> apfst (fold (fn (name, deps) => fold (curry Graph.add_edge name) deps) names_deps
   456     #> apfst (fold (fn (name, deps) => fold (curry Graph.add_edge name) deps) names_deps
   454       #> compile
   457       #> compile
   455       #-> fold (fn (name, univ) => (Graph.map_node name o apfst) (K (SOME univ))))
   458       #-> fold (fn (name, univ) => (Graph.map_node name o apfst) (K (SOME univ))))
   456   end;
   459   end;
   457 
   460 
   558 (
   561 (
   559   type T = (Univ option * int) Graph.T * (int * string Inttab.table);
   562   type T = (Univ option * int) Graph.T * (int * string Inttab.table);
   560   val empty = (Graph.empty, (0, Inttab.empty));
   563   val empty = (Graph.empty, (0, Inttab.empty));
   561 );
   564 );
   562 
   565 
   563 fun compile thy program =
   566 fun compile ignore_cache thy program =
   564   let
   567   let
   565     val (nbe_program, (_, idx_tab)) =
   568     val (nbe_program, (_, idx_tab)) =
   566       Nbe_Functions.change thy (compile_program thy program);
   569       Nbe_Functions.change (if ignore_cache then NONE else SOME thy)
       
   570         (compile_program thy program);
   567   in (nbe_program, idx_tab) end;
   571   in (nbe_program, idx_tab) end;
   568 
   572 
   569 
   573 
   570 (* dynamic evaluation oracle *)
   574 (* dynamic evaluation oracle *)
   571 
   575 
   575     val eq = Thm.cterm_of thy (Term.Const ("==", ty --> ty --> propT));
   579     val eq = Thm.cterm_of thy (Term.Const ("==", ty --> ty --> propT));
   576     val rhs = Thm.cterm_of thy raw_rhs;
   580     val rhs = Thm.cterm_of thy raw_rhs;
   577   in Thm.mk_binop eq lhs rhs end;
   581   in Thm.mk_binop eq lhs rhs end;
   578 
   582 
   579 val (_, raw_oracle) = Context.>>> (Context.map_theory_result
   583 val (_, raw_oracle) = Context.>>> (Context.map_theory_result
   580   (Thm.add_oracle (Binding.name "norm", fn (thy, program, vsp_ty_t, deps, ct) =>
   584   (Thm.add_oracle (Binding.name "normalization_by_evaluation",
   581     mk_equals thy ct (eval_term thy program (compile thy program) vsp_ty_t deps))));
   585     fn (thy, program, nbe_program_idx_tab, vsp_ty_t, deps, ct) =>
   582 
   586       mk_equals thy ct (eval_term thy program nbe_program_idx_tab vsp_ty_t deps))));
   583 fun oracle thy program vsp_ty_t deps ct = raw_oracle (thy, program, vsp_ty_t, deps, ct);
   587 
       
   588 fun oracle thy program nbe_program_idx_tab vsp_ty_t deps ct =
       
   589   raw_oracle (thy, program, nbe_program_idx_tab, vsp_ty_t, deps, ct);
   584 
   590 
   585 fun no_frees_rew rew t =
   591 fun no_frees_rew rew t =
   586   let
   592   let
   587     val frees = map Free (Term.add_frees t []);
   593     val frees = map Free (Term.add_frees t []);
   588   in
   594   in
   590     |> fold_rev lambda frees
   596     |> fold_rev lambda frees
   591     |> rew
   597     |> rew
   592     |> curry (Term.betapplys o swap) frees
   598     |> curry (Term.betapplys o swap) frees
   593   end;
   599   end;
   594 
   600 
   595 val dynamic_eval_conv = Code_Simp.no_frees_conv (Conv.tap_thy
   601 val dynamic_eval_conv = Conv.tap_thy (fn thy => Code_Simp.no_frees_conv
   596   (fn thy => lift_triv_classes_conv thy (Code_Thingol.dynamic_eval_conv thy (K (oracle thy)))));
   602   (lift_triv_classes_conv thy (Code_Thingol.dynamic_eval_conv thy
       
   603     (K (fn program => oracle thy program (compile false thy program))))));
   597 
   604 
   598 fun dynamic_eval_value thy = lift_triv_classes_rew thy
   605 fun dynamic_eval_value thy = lift_triv_classes_rew thy
   599   (no_frees_rew (Code_Thingol.dynamic_eval_value thy I
   606   (no_frees_rew (Code_Thingol.dynamic_eval_value thy I
   600     (K (fn program => eval_term thy program (compile thy program)))));
   607     (K (fn program => eval_term thy program (compile false thy program)))));
       
   608 
       
   609 fun static_eval_conv thy consts = Code_Simp.no_frees_conv
       
   610   (lift_triv_classes_conv thy (Code_Thingol.static_eval_conv thy consts
       
   611     (K (fn program => oracle thy program (compile true thy program)))));
   601 
   612 
   602 
   613 
   603 (** setup **)
   614 (** setup **)
   604 
   615 
   605 val setup = Value.add_evaluator ("nbe", dynamic_eval_value o ProofContext.theory_of);
   616 val setup = Value.add_evaluator ("nbe", dynamic_eval_value o ProofContext.theory_of);