src/Tools/nbe.ML
changeset 63164 72aaf69328fc
parent 63162 93e75d2f0d01
child 67149 e61557884799
equal deleted inserted replaced
63163:b561284a4214 63164:72aaf69328fc
     6 
     6 
     7 signature NBE =
     7 signature NBE =
     8 sig
     8 sig
     9   val dynamic_conv: Proof.context -> conv
     9   val dynamic_conv: Proof.context -> conv
    10   val dynamic_value: Proof.context -> term -> term
    10   val dynamic_value: Proof.context -> term -> term
    11   val static_conv: { ctxt: Proof.context, consts: string list } -> Proof.context -> conv
    11   val static_conv: { ctxt: Proof.context, consts: string list }
    12   val static_value: { ctxt: Proof.context, consts: string list } -> Proof.context -> term -> term
    12     -> Proof.context -> conv
       
    13   val static_value: { ctxt: Proof.context, consts: string list }
       
    14     -> Proof.context -> term -> term
    13 
    15 
    14   datatype Univ =
    16   datatype Univ =
    15       Const of int * Univ list               (*named (uninterpreted) constants*)
    17       Const of int * Univ list               (*named (uninterpreted) constants*)
    16     | DFree of string * int                  (*free (uninterpreted) dictionary parameters*)
    18     | DFree of string * int                  (*free (uninterpreted) dictionary parameters*)
    17     | BVar of int * Univ list
    19     | BVar of int * Univ list
    19   val apps: Univ -> Univ list -> Univ        (*explicit applications*)
    21   val apps: Univ -> Univ list -> Univ        (*explicit applications*)
    20   val abss: int -> (Univ list -> Univ) -> Univ
    22   val abss: int -> (Univ list -> Univ) -> Univ
    21                                              (*abstractions as closures*)
    23                                              (*abstractions as closures*)
    22   val same: Univ * Univ -> bool
    24   val same: Univ * Univ -> bool
    23 
    25 
    24   val put_result: (unit -> Univ list -> Univ list) -> Proof.context -> Proof.context
    26   val put_result: (unit -> Univ list -> Univ list)
       
    27     -> Proof.context -> Proof.context
    25   val trace: bool Config.T
    28   val trace: bool Config.T
    26 
    29 
    27   val add_const_alias: thm -> theory -> theory
    30   val add_const_alias: thm -> theory -> theory
    28 end;
    31 end;
    29 
    32 
   475     #> apfst (fold (fn (name, deps) => fold (curry Code_Symbol.Graph.add_edge name) deps) names_deps
   478     #> apfst (fold (fn (name, deps) => fold (curry Code_Symbol.Graph.add_edge name) deps) names_deps
   476       #> compile
   479       #> compile
   477       #-> fold (fn (name, univ) => (Code_Symbol.Graph.map_node name o apfst) (K (SOME univ))))
   480       #-> fold (fn (name, univ) => (Code_Symbol.Graph.map_node name o apfst) (K (SOME univ))))
   478   end;
   481   end;
   479 
   482 
   480 fun compile_program ctxt program =
   483 fun compile_program { ctxt, program } =
   481   let
   484   let
   482     fun add_stmts names (nbe_program, (maxidx, idx_tab)) = if exists ((can o Code_Symbol.Graph.get_node) nbe_program) names
   485     fun add_stmts names (nbe_program, (maxidx, idx_tab)) = if exists ((can o Code_Symbol.Graph.get_node) nbe_program) names
   483       then (nbe_program, (maxidx, idx_tab))
   486       then (nbe_program, (maxidx, idx_tab))
   484       else (nbe_program, (maxidx, idx_tab))
   487       else (nbe_program, (maxidx, idx_tab))
   485         |> compile_stmts ctxt (map (fn name => ((name, Code_Symbol.Graph.get_node program name),
   488         |> compile_stmts ctxt (map (fn name => ((name, Code_Symbol.Graph.get_node program name),
   552         (Type.constraint (fastype_of t_original) t');
   555         (Type.constraint (fastype_of t_original) t');
   553     fun check_tvars t' =
   556     fun check_tvars t' =
   554       if null (Term.add_tvars t' []) then t'
   557       if null (Term.add_tvars t' []) then t'
   555       else error ("Illegal schematic type variables in normalized term: " ^ string_of_term t');
   558       else error ("Illegal schematic type variables in normalized term: " ^ string_of_term t');
   556   in
   559   in
   557     compile_and_reconstruct_term
   560     Code_Preproc.timed "computing NBE expression" #ctxt compile_and_reconstruct_term
   558       { ctxt = ctxt, nbe_program = nbe_program, idx_tab = idx_tab, deps = deps, term = (vs, t) }
   561       { ctxt = ctxt, nbe_program = nbe_program, idx_tab = idx_tab, deps = deps, term = (vs, t) }
   559     |> traced ctxt (fn t => "Normalized:\n" ^ string_of_term t)
   562     |> traced ctxt (fn t => "Normalized:\n" ^ string_of_term t)
   560     |> type_infer
   563     |> type_infer
   561     |> traced ctxt (fn t => "Types inferred:\n" ^ string_of_term t)
   564     |> traced ctxt (fn t => "Types inferred:\n" ^ string_of_term t)
   562     |> check_tvars
   565     |> check_tvars
   574 
   577 
   575 fun compile ignore_cache ctxt program =
   578 fun compile ignore_cache ctxt program =
   576   let
   579   let
   577     val (nbe_program, (_, idx_tab)) =
   580     val (nbe_program, (_, idx_tab)) =
   578       Nbe_Functions.change (if ignore_cache then NONE else SOME (Proof_Context.theory_of ctxt))
   581       Nbe_Functions.change (if ignore_cache then NONE else SOME (Proof_Context.theory_of ctxt))
   579         (compile_program ctxt program);
   582         (Code_Preproc.timed "compiling NBE program" #ctxt
       
   583           compile_program { ctxt = ctxt, program = program });
   580   in (nbe_program, idx_tab) end;
   584   in (nbe_program, idx_tab) end;
   581 
   585 
   582 
   586 
   583 (* evaluation oracle *)
   587 (* evaluation oracle *)
   584 
   588 
   606     normalize_term (compile false ctxt program) ctxt));
   610     normalize_term (compile false ctxt program) ctxt));
   607 
   611 
   608 fun static_conv (ctxt_consts as { ctxt, ... }) =
   612 fun static_conv (ctxt_consts as { ctxt, ... }) =
   609   let
   613   let
   610     val conv = Code_Thingol.static_conv_thingol ctxt_consts
   614     val conv = Code_Thingol.static_conv_thingol ctxt_consts
   611       (fn { program, ... } => oracle (compile true ctxt program));
   615       (fn { program, deps = _ } => oracle (compile true ctxt program));
   612   in fn ctxt' => lift_triv_classes_conv ctxt' conv end;
   616   in fn ctxt' => lift_triv_classes_conv ctxt' conv end;
   613 
   617 
   614 fun static_value { ctxt, consts } =
   618 fun static_value { ctxt, consts } =
   615   let
   619   let
   616     val comp = Code_Thingol.static_value { ctxt = ctxt, lift_postproc = I, consts = consts }
   620     val comp = Code_Thingol.static_value { ctxt = ctxt, lift_postproc = I, consts = consts }
   617       (fn { program, ... } => normalize_term (compile false ctxt program));
   621       (fn { program, deps = _ } => normalize_term (compile false ctxt program));
   618   in fn ctxt' => lift_triv_classes_rew ctxt' (comp ctxt') end;
   622   in fn ctxt' => lift_triv_classes_rew ctxt' (comp ctxt') end;
   619 
   623 
   620 end;
   624 end;