src/Tools/nbe.ML
changeset 39392 7a0fcee7a2a3
parent 39388 fdbb2c55ffc2
child 39396 e9cad160aa0f
equal deleted inserted replaced
39388:fdbb2c55ffc2 39392:7a0fcee7a2a3
   375     val (fun_vars, fun_vals) = map_split assemble_eqns eqnss';
   375     val (fun_vars, fun_vals) = map_split assemble_eqns eqnss';
   376     val deps_vars = ml_list (map (nbe_fun 0) deps);
   376     val deps_vars = ml_list (map (nbe_fun 0) deps);
   377   in ml_abs deps_vars (ml_Let (ml_fundefs (flat fun_vars)) (ml_list fun_vals)) end;
   377   in ml_abs deps_vars (ml_Let (ml_fundefs (flat fun_vars)) (ml_list fun_vals)) end;
   378 
   378 
   379 
   379 
   380 (* code compilation *)
   380 (* compile equations *)
   381 
   381 
   382 fun compile_eqnss ctxt gr raw_deps [] = []
   382 fun compile_eqnss thy nbe_program raw_deps [] = []
   383   | compile_eqnss ctxt gr raw_deps eqnss =
   383   | compile_eqnss thy nbe_program raw_deps eqnss =
   384       let
   384       let
       
   385         val ctxt = ProofContext.init_global thy;
   385         val (deps, deps_vals) = split_list (map_filter
   386         val (deps, deps_vals) = split_list (map_filter
   386           (fn dep => Option.map (fn univ => (dep, univ)) (fst ((Graph.get_node gr dep)))) raw_deps);
   387           (fn dep => Option.map (fn univ => (dep, univ)) (fst ((Graph.get_node nbe_program dep)))) raw_deps);
   387         val idx_of = raw_deps
   388         val idx_of = raw_deps
   388           |> map (fn dep => (dep, snd (Graph.get_node gr dep)))
   389           |> map (fn dep => (dep, snd (Graph.get_node nbe_program dep)))
   389           |> AList.lookup (op =)
   390           |> AList.lookup (op =)
   390           |> (fn f => the o f);
   391           |> (fn f => the o f);
   391         val s = assemble_eqnss idx_of deps eqnss;
   392         val s = assemble_eqnss idx_of deps eqnss;
   392         val cs = map fst eqnss;
   393         val cs = map fst eqnss;
   393       in
   394       in
   398         |> (fn f => f deps_vals)
   399         |> (fn f => f deps_vals)
   399         |> (fn univs => cs ~~ univs)
   400         |> (fn univs => cs ~~ univs)
   400       end;
   401       end;
   401 
   402 
   402 
   403 
   403 (* preparing function equations *)
   404 (* extract equations from statements *)
   404 
   405 
   405 fun eqns_of_stmt (_, Code_Thingol.Fun (_, ((_, []), _))) =
   406 fun eqns_of_stmt (_, Code_Thingol.Fun (_, ((_, []), _))) =
   406       []
   407       []
   407   | eqns_of_stmt (const, Code_Thingol.Fun (_, (((vs, _), eqns), _))) =
   408   | eqns_of_stmt (const, Code_Thingol.Fun (_, (((vs, _), eqns), _))) =
   408       [(const, (vs, map fst eqns))]
   409       [(const, (vs, map fst eqns))]
   426   | eqns_of_stmt (inst, Code_Thingol.Classinst ((class, (_, arity_args)), (super_instances, (classparam_instances, _)))) =
   427   | eqns_of_stmt (inst, Code_Thingol.Classinst ((class, (_, arity_args)), (super_instances, (classparam_instances, _)))) =
   427       [(inst, (arity_args, [([], IConst (class, (([], []), [])) `$$
   428       [(inst, (arity_args, [([], IConst (class, (([], []), [])) `$$
   428         map (fn (_, (_, (inst, dss))) => IConst (inst, (([], dss), []))) super_instances
   429         map (fn (_, (_, (inst, dss))) => IConst (inst, (([], dss), []))) super_instances
   429         @ map (IConst o snd o fst) classparam_instances)]))];
   430         @ map (IConst o snd o fst) classparam_instances)]))];
   430 
   431 
   431 fun compile_stmts ctxt stmts_deps =
   432 
       
   433 (* compile whole programs *)
       
   434 
       
   435 fun compile_stmts thy stmts_deps =
   432   let
   436   let
   433     val names = map (fst o fst) stmts_deps;
   437     val names = map (fst o fst) stmts_deps;
   434     val names_deps = map (fn ((name, _), deps) => (name, deps)) stmts_deps;
   438     val names_deps = map (fn ((name, _), deps) => (name, deps)) stmts_deps;
   435     val eqnss = maps (eqns_of_stmt o fst) stmts_deps;
   439     val eqnss = maps (eqns_of_stmt o fst) stmts_deps;
   436     val refl_deps = names_deps
   440     val refl_deps = names_deps
   437       |> maps snd
   441       |> maps snd
   438       |> distinct (op =)
   442       |> distinct (op =)
   439       |> fold (insert (op =)) names;
   443       |> fold (insert (op =)) names;
   440     fun new_node name (gr, (maxidx, idx_tab)) = if can (Graph.get_node gr) name
   444     fun new_node name (nbe_program, (maxidx, idx_tab)) = if can (Graph.get_node nbe_program) name
   441       then (gr, (maxidx, idx_tab))
   445       then (nbe_program, (maxidx, idx_tab))
   442       else (Graph.new_node (name, (NONE, maxidx)) gr,
   446       else (Graph.new_node (name, (NONE, maxidx)) nbe_program,
   443         (maxidx + 1, Inttab.update_new (maxidx, name) idx_tab));
   447         (maxidx + 1, Inttab.update_new (maxidx, name) idx_tab));
   444     fun compile gr = eqnss
   448     fun compile nbe_program = eqnss
   445       |> compile_eqnss ctxt gr refl_deps
   449       |> compile_eqnss thy nbe_program refl_deps
   446       |> rpair gr;
   450       |> rpair nbe_program;
   447   in
   451   in
   448     fold new_node refl_deps
   452     fold new_node refl_deps
   449     #> apfst (fold (fn (name, deps) => fold (curry Graph.add_edge name) deps) names_deps
   453     #> apfst (fold (fn (name, deps) => fold (curry Graph.add_edge name) deps) names_deps
   450       #> compile
   454       #> compile
   451       #-> fold (fn (name, univ) => (Graph.map_node name o apfst) (K (SOME univ))))
   455       #-> fold (fn (name, univ) => (Graph.map_node name o apfst) (K (SOME univ))))
   452   end;
   456   end;
   453 
   457 
   454 fun ensure_stmts ctxt program =
   458 fun compile_program thy program =
   455   let
   459   let
   456     fun add_stmts names (gr, (maxidx, idx_tab)) = if exists ((can o Graph.get_node) gr) names
   460     fun add_stmts names (nbe_program, (maxidx, idx_tab)) = if exists ((can o Graph.get_node) nbe_program) names
   457       then (gr, (maxidx, idx_tab))
   461       then (nbe_program, (maxidx, idx_tab))
   458       else (gr, (maxidx, idx_tab))
   462       else (nbe_program, (maxidx, idx_tab))
   459         |> compile_stmts ctxt (map (fn name => ((name, Graph.get_node program name),
   463         |> compile_stmts thy (map (fn name => ((name, Graph.get_node program name),
   460           Graph.imm_succs program name)) names);
   464           Graph.imm_succs program name)) names);
   461   in
   465   in
   462     fold_rev add_stmts (Graph.strong_conn program)
   466     fold_rev add_stmts (Graph.strong_conn program)
   463   end;
   467   end;
   464 
   468 
   465 
   469 
   466 (** evaluation **)
   470 (** evaluation **)
   467 
   471 
   468 (* term evaluation *)
   472 (* term evaluation by compilation *)
   469 
   473 
   470 fun eval_term ctxt gr deps (vs : (string * sort) list, t) =
   474 fun compile_term thy nbe_program deps (vs : (string * sort) list, t) =
   471   let 
   475   let 
   472     val dict_frees = maps (fn (v, sort) => map_index (curry DFree v o fst) sort) vs;
   476     val dict_frees = maps (fn (v, sort) => map_index (curry DFree v o fst) sort) vs;
   473   in
   477   in
   474     ("", (vs, [([], t)]))
   478     ("", (vs, [([], t)]))
   475     |> singleton (compile_eqnss ctxt gr deps)
   479     |> singleton (compile_eqnss thy nbe_program deps)
   476     |> snd
   480     |> snd
   477     |> (fn t => apps t (rev dict_frees))
   481     |> (fn t => apps t (rev dict_frees))
   478   end;
   482   end;
   479 
   483 
   480 
   484 
   481 (* reification *)
   485 (* reconstruction *)
   482 
   486 
   483 fun typ_of_itype program vs (ityco `%% itys) =
   487 fun typ_of_itype program vs (ityco `%% itys) =
   484       let
   488       let
   485         val Code_Thingol.Datatype (tyco, _) = Graph.get_node program ityco;
   489         val Code_Thingol.Datatype (tyco, _) = Graph.get_node program ityco;
   486       in Type (tyco, map (typ_of_itype program vs) itys) end
   490       in Type (tyco, map (typ_of_itype program vs) itys) end
   523           |> of_univ (bounds + 1) (apps t [BVar (bounds, [])])
   527           |> of_univ (bounds + 1) (apps t [BVar (bounds, [])])
   524           |-> (fn t' => pair (Term.Abs ("u", dummyT, t')))
   528           |-> (fn t' => pair (Term.Abs ("u", dummyT, t')))
   525   in of_univ 0 t 0 |> fst end;
   529   in of_univ 0 t 0 |> fst end;
   526 
   530 
   527 
   531 
   528 (* function store *)
       
   529 
       
   530 structure Nbe_Functions = Code_Data
       
   531 (
       
   532   type T = (Univ option * int) Graph.T * (int * string Inttab.table);
       
   533   val empty = (Graph.empty, (0, Inttab.empty));
       
   534 );
       
   535 
       
   536 
       
   537 (* compilation, evaluation and reification *)
       
   538 
       
   539 fun compile_eval thy program =
       
   540   let
       
   541     val ctxt = ProofContext.init_global thy;
       
   542     val (gr, (_, idx_tab)) =
       
   543       Nbe_Functions.change thy (ensure_stmts ctxt program);
       
   544   in fn vs_t => fn deps =>
       
   545     vs_t
       
   546     |> eval_term ctxt gr deps
       
   547     |> term_of_univ thy program idx_tab
       
   548   end;
       
   549 
       
   550 
       
   551 (* evaluation with type reconstruction *)
   532 (* evaluation with type reconstruction *)
   552 
   533 
   553 fun normalize thy program ((vs0, (vs, ty)), t) deps =
   534 fun eval_term thy program (nbe_program, idx_tab) ((vs0, (vs, ty)), t) deps =
   554   let
   535   let
   555     val ctxt = Syntax.init_pretty_global thy;
   536     val ctxt = Syntax.init_pretty_global thy;
       
   537     val string_of_term = Syntax.string_of_term (Config.put show_types true ctxt);
   556     val ty' = typ_of_itype program vs0 ty;
   538     val ty' = typ_of_itype program vs0 ty;
   557     fun type_infer t =
   539     fun type_infer t = singleton
   558       singleton
   540       (Type_Infer.infer_types ctxt (try (Type.strip_sorts o Sign.the_const_type thy)) (K NONE))
   559         (Type_Infer.infer_types ctxt (try (Type.strip_sorts o Sign.the_const_type thy)) (K NONE))
   541       (Type.constraint ty' t);
   560         (Type.constraint ty' t);
       
   561     val string_of_term = Syntax.string_of_term (Config.put show_types true ctxt);
       
   562     fun check_tvars t =
   542     fun check_tvars t =
   563       if null (Term.add_tvars t []) then t
   543       if null (Term.add_tvars t []) then t
   564       else error ("Illegal schematic type variables in normalized term: " ^ string_of_term t);
   544       else error ("Illegal schematic type variables in normalized term: " ^ string_of_term t);
   565   in
   545   in
   566     compile_eval thy program (vs, t) deps
   546     compile_term thy nbe_program deps (vs, t)
       
   547     |> term_of_univ thy program idx_tab
   567     |> traced (fn t => "Normalized:\n" ^ string_of_term t)
   548     |> traced (fn t => "Normalized:\n" ^ string_of_term t)
   568     |> type_infer
   549     |> type_infer
   569     |> traced (fn t => "Types inferred:\n" ^ string_of_term t)
   550     |> traced (fn t => "Types inferred:\n" ^ string_of_term t)
   570     |> check_tvars
   551     |> check_tvars
   571     |> traced (fn _ => "---\n")
   552     |> traced (fn _ => "---\n")
   572   end;
   553   end;
   573 
   554 
       
   555 (* function store *)
       
   556 
       
   557 structure Nbe_Functions = Code_Data
       
   558 (
       
   559   type T = (Univ option * int) Graph.T * (int * string Inttab.table);
       
   560   val empty = (Graph.empty, (0, Inttab.empty));
       
   561 );
       
   562 
       
   563 fun compile thy program =
       
   564   let
       
   565     val (nbe_program, (_, idx_tab)) =
       
   566       Nbe_Functions.change thy (compile_program thy program);
       
   567   in (nbe_program, idx_tab) end;
       
   568 
   574 
   569 
   575 (* evaluation oracle *)
   570 (* evaluation oracle *)
   576 
   571 
   577 fun mk_equals thy lhs raw_rhs =
   572 fun mk_equals thy lhs raw_rhs =
   578   let
   573   let
   581     val rhs = Thm.cterm_of thy raw_rhs;
   576     val rhs = Thm.cterm_of thy raw_rhs;
   582   in Thm.mk_binop eq lhs rhs end;
   577   in Thm.mk_binop eq lhs rhs end;
   583 
   578 
   584 val (_, raw_oracle) = Context.>>> (Context.map_theory_result
   579 val (_, raw_oracle) = Context.>>> (Context.map_theory_result
   585   (Thm.add_oracle (Binding.name "norm", fn (thy, program, vsp_ty_t, deps, ct) =>
   580   (Thm.add_oracle (Binding.name "norm", fn (thy, program, vsp_ty_t, deps, ct) =>
   586     mk_equals thy ct (normalize thy program vsp_ty_t deps))));
   581     mk_equals thy ct (eval_term thy program (compile thy program) vsp_ty_t deps))));
   587 
   582 
   588 fun oracle thy program vsp_ty_t deps ct = raw_oracle (thy, program, vsp_ty_t, deps, ct);
   583 fun oracle thy program vsp_ty_t deps ct = raw_oracle (thy, program, vsp_ty_t, deps, ct);
   589 
   584 
   590 fun no_frees_rew rew t =
   585 fun no_frees_rew rew t =
   591   let
   586   let
   599 
   594 
   600 val dynamic_eval_conv = Code_Simp.no_frees_conv (Conv.tap_thy
   595 val dynamic_eval_conv = Code_Simp.no_frees_conv (Conv.tap_thy
   601   (fn thy => lift_triv_classes_conv thy (Code_Thingol.dynamic_eval_conv thy (K (oracle thy)))));
   596   (fn thy => lift_triv_classes_conv thy (Code_Thingol.dynamic_eval_conv thy (K (oracle thy)))));
   602 
   597 
   603 fun dynamic_eval_value thy = lift_triv_classes_rew thy
   598 fun dynamic_eval_value thy = lift_triv_classes_rew thy
   604   (no_frees_rew (Code_Thingol.dynamic_eval_value thy I (K (normalize thy))));
   599   (no_frees_rew (Code_Thingol.dynamic_eval_value thy I
       
   600     (K (fn program => eval_term thy program (compile thy program)))));
   605 
   601 
   606 
   602 
   607 (* evaluation command *)
   603 (* evaluation command *)
   608 
   604 
   609 fun norm_print_term ctxt modes t =
   605 fun norm_print_term ctxt modes t =