375 val (fun_vars, fun_vals) = map_split assemble_eqns eqnss'; |
375 val (fun_vars, fun_vals) = map_split assemble_eqns eqnss'; |
376 val deps_vars = ml_list (map (nbe_fun 0) deps); |
376 val deps_vars = ml_list (map (nbe_fun 0) deps); |
377 in ml_abs deps_vars (ml_Let (ml_fundefs (flat fun_vars)) (ml_list fun_vals)) end; |
377 in ml_abs deps_vars (ml_Let (ml_fundefs (flat fun_vars)) (ml_list fun_vals)) end; |
378 |
378 |
379 |
379 |
380 (* code compilation *) |
380 (* compile equations *) |
381 |
381 |
382 fun compile_eqnss ctxt gr raw_deps [] = [] |
382 fun compile_eqnss thy nbe_program raw_deps [] = [] |
383 | compile_eqnss ctxt gr raw_deps eqnss = |
383 | compile_eqnss thy nbe_program raw_deps eqnss = |
384 let |
384 let |
|
385 val ctxt = ProofContext.init_global thy; |
385 val (deps, deps_vals) = split_list (map_filter |
386 val (deps, deps_vals) = split_list (map_filter |
386 (fn dep => Option.map (fn univ => (dep, univ)) (fst ((Graph.get_node gr dep)))) raw_deps); |
387 (fn dep => Option.map (fn univ => (dep, univ)) (fst ((Graph.get_node nbe_program dep)))) raw_deps); |
387 val idx_of = raw_deps |
388 val idx_of = raw_deps |
388 |> map (fn dep => (dep, snd (Graph.get_node gr dep))) |
389 |> map (fn dep => (dep, snd (Graph.get_node nbe_program dep))) |
389 |> AList.lookup (op =) |
390 |> AList.lookup (op =) |
390 |> (fn f => the o f); |
391 |> (fn f => the o f); |
391 val s = assemble_eqnss idx_of deps eqnss; |
392 val s = assemble_eqnss idx_of deps eqnss; |
392 val cs = map fst eqnss; |
393 val cs = map fst eqnss; |
393 in |
394 in |
426 | eqns_of_stmt (inst, Code_Thingol.Classinst ((class, (_, arity_args)), (super_instances, (classparam_instances, _)))) = |
427 | eqns_of_stmt (inst, Code_Thingol.Classinst ((class, (_, arity_args)), (super_instances, (classparam_instances, _)))) = |
427 [(inst, (arity_args, [([], IConst (class, (([], []), [])) `$$ |
428 [(inst, (arity_args, [([], IConst (class, (([], []), [])) `$$ |
428 map (fn (_, (_, (inst, dss))) => IConst (inst, (([], dss), []))) super_instances |
429 map (fn (_, (_, (inst, dss))) => IConst (inst, (([], dss), []))) super_instances |
429 @ map (IConst o snd o fst) classparam_instances)]))]; |
430 @ map (IConst o snd o fst) classparam_instances)]))]; |
430 |
431 |
431 fun compile_stmts ctxt stmts_deps = |
432 |
|
433 (* compile whole programs *) |
|
434 |
|
435 fun compile_stmts thy stmts_deps = |
432 let |
436 let |
433 val names = map (fst o fst) stmts_deps; |
437 val names = map (fst o fst) stmts_deps; |
434 val names_deps = map (fn ((name, _), deps) => (name, deps)) stmts_deps; |
438 val names_deps = map (fn ((name, _), deps) => (name, deps)) stmts_deps; |
435 val eqnss = maps (eqns_of_stmt o fst) stmts_deps; |
439 val eqnss = maps (eqns_of_stmt o fst) stmts_deps; |
436 val refl_deps = names_deps |
440 val refl_deps = names_deps |
437 |> maps snd |
441 |> maps snd |
438 |> distinct (op =) |
442 |> distinct (op =) |
439 |> fold (insert (op =)) names; |
443 |> fold (insert (op =)) names; |
440 fun new_node name (gr, (maxidx, idx_tab)) = if can (Graph.get_node gr) name |
444 fun new_node name (nbe_program, (maxidx, idx_tab)) = if can (Graph.get_node nbe_program) name |
441 then (gr, (maxidx, idx_tab)) |
445 then (nbe_program, (maxidx, idx_tab)) |
442 else (Graph.new_node (name, (NONE, maxidx)) gr, |
446 else (Graph.new_node (name, (NONE, maxidx)) nbe_program, |
443 (maxidx + 1, Inttab.update_new (maxidx, name) idx_tab)); |
447 (maxidx + 1, Inttab.update_new (maxidx, name) idx_tab)); |
444 fun compile gr = eqnss |
448 fun compile nbe_program = eqnss |
445 |> compile_eqnss ctxt gr refl_deps |
449 |> compile_eqnss thy nbe_program refl_deps |
446 |> rpair gr; |
450 |> rpair nbe_program; |
447 in |
451 in |
448 fold new_node refl_deps |
452 fold new_node refl_deps |
449 #> apfst (fold (fn (name, deps) => fold (curry Graph.add_edge name) deps) names_deps |
453 #> apfst (fold (fn (name, deps) => fold (curry Graph.add_edge name) deps) names_deps |
450 #> compile |
454 #> compile |
451 #-> fold (fn (name, univ) => (Graph.map_node name o apfst) (K (SOME univ)))) |
455 #-> fold (fn (name, univ) => (Graph.map_node name o apfst) (K (SOME univ)))) |
452 end; |
456 end; |
453 |
457 |
454 fun ensure_stmts ctxt program = |
458 fun compile_program thy program = |
455 let |
459 let |
456 fun add_stmts names (gr, (maxidx, idx_tab)) = if exists ((can o Graph.get_node) gr) names |
460 fun add_stmts names (nbe_program, (maxidx, idx_tab)) = if exists ((can o Graph.get_node) nbe_program) names |
457 then (gr, (maxidx, idx_tab)) |
461 then (nbe_program, (maxidx, idx_tab)) |
458 else (gr, (maxidx, idx_tab)) |
462 else (nbe_program, (maxidx, idx_tab)) |
459 |> compile_stmts ctxt (map (fn name => ((name, Graph.get_node program name), |
463 |> compile_stmts thy (map (fn name => ((name, Graph.get_node program name), |
460 Graph.imm_succs program name)) names); |
464 Graph.imm_succs program name)) names); |
461 in |
465 in |
462 fold_rev add_stmts (Graph.strong_conn program) |
466 fold_rev add_stmts (Graph.strong_conn program) |
463 end; |
467 end; |
464 |
468 |
465 |
469 |
466 (** evaluation **) |
470 (** evaluation **) |
467 |
471 |
468 (* term evaluation *) |
472 (* term evaluation by compilation *) |
469 |
473 |
470 fun eval_term ctxt gr deps (vs : (string * sort) list, t) = |
474 fun compile_term thy nbe_program deps (vs : (string * sort) list, t) = |
471 let |
475 let |
472 val dict_frees = maps (fn (v, sort) => map_index (curry DFree v o fst) sort) vs; |
476 val dict_frees = maps (fn (v, sort) => map_index (curry DFree v o fst) sort) vs; |
473 in |
477 in |
474 ("", (vs, [([], t)])) |
478 ("", (vs, [([], t)])) |
475 |> singleton (compile_eqnss ctxt gr deps) |
479 |> singleton (compile_eqnss thy nbe_program deps) |
476 |> snd |
480 |> snd |
477 |> (fn t => apps t (rev dict_frees)) |
481 |> (fn t => apps t (rev dict_frees)) |
478 end; |
482 end; |
479 |
483 |
480 |
484 |
481 (* reification *) |
485 (* reconstruction *) |
482 |
486 |
483 fun typ_of_itype program vs (ityco `%% itys) = |
487 fun typ_of_itype program vs (ityco `%% itys) = |
484 let |
488 let |
485 val Code_Thingol.Datatype (tyco, _) = Graph.get_node program ityco; |
489 val Code_Thingol.Datatype (tyco, _) = Graph.get_node program ityco; |
486 in Type (tyco, map (typ_of_itype program vs) itys) end |
490 in Type (tyco, map (typ_of_itype program vs) itys) end |
523 |> of_univ (bounds + 1) (apps t [BVar (bounds, [])]) |
527 |> of_univ (bounds + 1) (apps t [BVar (bounds, [])]) |
524 |-> (fn t' => pair (Term.Abs ("u", dummyT, t'))) |
528 |-> (fn t' => pair (Term.Abs ("u", dummyT, t'))) |
525 in of_univ 0 t 0 |> fst end; |
529 in of_univ 0 t 0 |> fst end; |
526 |
530 |
527 |
531 |
528 (* function store *) |
|
529 |
|
530 structure Nbe_Functions = Code_Data |
|
531 ( |
|
532 type T = (Univ option * int) Graph.T * (int * string Inttab.table); |
|
533 val empty = (Graph.empty, (0, Inttab.empty)); |
|
534 ); |
|
535 |
|
536 |
|
537 (* compilation, evaluation and reification *) |
|
538 |
|
539 fun compile_eval thy program = |
|
540 let |
|
541 val ctxt = ProofContext.init_global thy; |
|
542 val (gr, (_, idx_tab)) = |
|
543 Nbe_Functions.change thy (ensure_stmts ctxt program); |
|
544 in fn vs_t => fn deps => |
|
545 vs_t |
|
546 |> eval_term ctxt gr deps |
|
547 |> term_of_univ thy program idx_tab |
|
548 end; |
|
549 |
|
550 |
|
551 (* evaluation with type reconstruction *) |
532 (* evaluation with type reconstruction *) |
552 |
533 |
553 fun normalize thy program ((vs0, (vs, ty)), t) deps = |
534 fun eval_term thy program (nbe_program, idx_tab) ((vs0, (vs, ty)), t) deps = |
554 let |
535 let |
555 val ctxt = Syntax.init_pretty_global thy; |
536 val ctxt = Syntax.init_pretty_global thy; |
|
537 val string_of_term = Syntax.string_of_term (Config.put show_types true ctxt); |
556 val ty' = typ_of_itype program vs0 ty; |
538 val ty' = typ_of_itype program vs0 ty; |
557 fun type_infer t = |
539 fun type_infer t = singleton |
558 singleton |
540 (Type_Infer.infer_types ctxt (try (Type.strip_sorts o Sign.the_const_type thy)) (K NONE)) |
559 (Type_Infer.infer_types ctxt (try (Type.strip_sorts o Sign.the_const_type thy)) (K NONE)) |
541 (Type.constraint ty' t); |
560 (Type.constraint ty' t); |
|
561 val string_of_term = Syntax.string_of_term (Config.put show_types true ctxt); |
|
562 fun check_tvars t = |
542 fun check_tvars t = |
563 if null (Term.add_tvars t []) then t |
543 if null (Term.add_tvars t []) then t |
564 else error ("Illegal schematic type variables in normalized term: " ^ string_of_term t); |
544 else error ("Illegal schematic type variables in normalized term: " ^ string_of_term t); |
565 in |
545 in |
566 compile_eval thy program (vs, t) deps |
546 compile_term thy nbe_program deps (vs, t) |
|
547 |> term_of_univ thy program idx_tab |
567 |> traced (fn t => "Normalized:\n" ^ string_of_term t) |
548 |> traced (fn t => "Normalized:\n" ^ string_of_term t) |
568 |> type_infer |
549 |> type_infer |
569 |> traced (fn t => "Types inferred:\n" ^ string_of_term t) |
550 |> traced (fn t => "Types inferred:\n" ^ string_of_term t) |
570 |> check_tvars |
551 |> check_tvars |
571 |> traced (fn _ => "---\n") |
552 |> traced (fn _ => "---\n") |
572 end; |
553 end; |
573 |
554 |
|
555 (* function store *) |
|
556 |
|
557 structure Nbe_Functions = Code_Data |
|
558 ( |
|
559 type T = (Univ option * int) Graph.T * (int * string Inttab.table); |
|
560 val empty = (Graph.empty, (0, Inttab.empty)); |
|
561 ); |
|
562 |
|
563 fun compile thy program = |
|
564 let |
|
565 val (nbe_program, (_, idx_tab)) = |
|
566 Nbe_Functions.change thy (compile_program thy program); |
|
567 in (nbe_program, idx_tab) end; |
|
568 |
574 |
569 |
575 (* evaluation oracle *) |
570 (* evaluation oracle *) |
576 |
571 |
577 fun mk_equals thy lhs raw_rhs = |
572 fun mk_equals thy lhs raw_rhs = |
578 let |
573 let |
581 val rhs = Thm.cterm_of thy raw_rhs; |
576 val rhs = Thm.cterm_of thy raw_rhs; |
582 in Thm.mk_binop eq lhs rhs end; |
577 in Thm.mk_binop eq lhs rhs end; |
583 |
578 |
584 val (_, raw_oracle) = Context.>>> (Context.map_theory_result |
579 val (_, raw_oracle) = Context.>>> (Context.map_theory_result |
585 (Thm.add_oracle (Binding.name "norm", fn (thy, program, vsp_ty_t, deps, ct) => |
580 (Thm.add_oracle (Binding.name "norm", fn (thy, program, vsp_ty_t, deps, ct) => |
586 mk_equals thy ct (normalize thy program vsp_ty_t deps)))); |
581 mk_equals thy ct (eval_term thy program (compile thy program) vsp_ty_t deps)))); |
587 |
582 |
588 fun oracle thy program vsp_ty_t deps ct = raw_oracle (thy, program, vsp_ty_t, deps, ct); |
583 fun oracle thy program vsp_ty_t deps ct = raw_oracle (thy, program, vsp_ty_t, deps, ct); |
589 |
584 |
590 fun no_frees_rew rew t = |
585 fun no_frees_rew rew t = |
591 let |
586 let |