(*  Title:      Tools/Nbe/Nbe_Eval.ML
    ID:         $Id$
    Authors:    Klaus Aehlig, LMU Muenchen; Tobias Nipkow, Florian Haftmann, TU Muenchen

Evaluation mechanisms for normalization by evaluation.
*)

(*
FIXME:
- get rid of BVar (?) - it is only used for terms to be evaluated, not for functions
- proper purge operation - preliminary for...
- really incremental code generation
*)

signature NBE =
sig
  datatype Univ =
      Const of string * Univ list          (*named constructors*)
    | Free of string * Univ list
    | BVar of int * Univ list
    | Abs of (int * (Univ list -> Univ)) * Univ list;
  val free: string -> Univ list -> Univ      (*free (uninterpreted) variables*)
  val abs: int -> (Univ list -> Univ) -> Univ list -> Univ
                                            (*abstractions as functions*)
  val app: Univ -> Univ -> Univ              (*explicit application*)

  val univs_ref: Univ list ref
  val lookup_fun: CodegenNames.const -> Univ

  val normalization_conv: cterm -> thm

  val trace: bool ref
  val setup: theory -> theory
end;

structure Nbe: NBE =
struct

(* generic non-sense *)

val trace = ref false;

(*emit (f x) as tracing output when !trace is set; always returns x unchanged*)
fun tracing f x = if !trace then (Output.tracing (f x); x) else x;


(** the semantical universe **)

(*
   Functions are given by their semantical function value. To avoid
   trouble with the ML-type system, these functions have the most
   generic type, that is "Univ list -> Univ". The calling convention is
   that the arguments come as a list, the last argument first. In
   other words, a function call that usually would look like

   f x_1 x_2 ... x_n   or   f(x_1,x_2, ..., x_n)

   would be in our convention called as

              f [x_n,..,x_2,x_1]

   Moreover, to handle functions that are still waiting for some
   arguments we have additionally a list of arguments collected so far
   and the number of arguments we're still waiting for.

   (?) Finally, it might happen, that a function does not get all the
   arguments it needs. In this case the function must provide means to
   present itself as a string. As this might be a heavy-weight
   operation, we delay it. (?)
*)

datatype Univ =
    Const of string * Univ list        (*named constructors*)
  | Free of string * Univ list         (*free variables*)
  | BVar of int * Univ list            (*bound named variables*)
  | Abs of (int * (Univ list -> Univ)) * Univ list
                                      (*functions*);

(* constructor functions *)

val free = curry Free;
fun abs n f ts = Abs ((n, f), ts);

(*apply a value to one further argument: an abstraction waiting for exactly
  one more argument calls its function on the collected argument list; in all
  other cases the argument is prepended (last argument first)*)
fun app (Abs ((1, f), xs)) x = f (x :: xs)
  | app (Abs ((n, f), xs)) x = Abs ((n - 1, f), x :: xs)
  | app (Const (name, args)) x = Const (name, x :: args)
  | app (Free (name, args)) x = Free (name, x :: args)
  | app (BVar (name, args)) x = BVar (name, x :: args);

(* global functions store *)

structure Nbe_Functions = CodeDataFun
(struct
  type T = Univ Symtab.table;
  val empty = Symtab.empty;
  fun merge _ = Symtab.merge (K true);
  fun purge _ _ _ = Symtab.empty;
end);

(* sandbox communication *)

(*the generated code assigns its result list to this reference*)
val univs_ref = ref [] : Univ list ref;

local

(*function table visible to generated code; only set while compile_univs runs*)
val tab_ref = ref NONE : Univ Symtab.table option ref;

in

(*look up an already compiled function by constant name; called from
  generated code during compile_univs*)
fun lookup_fun s = case ! tab_ref
 of NONE => error "lookup_fun: no function table (only available during compile_univs)"
  | SOME tab => (case Symtab.lookup tab s
     of SOME univ => univ
      | NONE => error ("lookup_fun: no compiled function for constant " ^ quote s));

(*compile the ML expression raw_s (which must denote a Univ list) with tab made
  visible to lookup_fun, and pair the resulting values with the names cs*)
fun compile_univs tab ([], _) = []
  | compile_univs tab (cs, raw_s) =
      let
        val _ = univs_ref := [];
        val s = "Nbe.univs_ref := " ^ raw_s;
        val _ = tracing (fn () => "\n---generated code:\n" ^ s) ();
        val compile = use_text "" (Output.tracing o enclose "\n---compiler echo:\n" "\n---\n",
          Output.tracing o enclose "\n--- compiler echo (with error):\n" "\n---\n")
          (!trace);
        (*setmp restores tab_ref even if compilation raises an exception*)
        val _ = setmp tab_ref (SOME tab) compile s;
        val univs = case !univs_ref of [] => error "compile_univs" | univs => univs;
      in cs ~~ univs end;

end; (*local*)


(** assembling and compiling ML code from terms **)

(* abstract ML syntax *)

infix 9 `$` `$$`;
fun e1 `$` e2 = "(" ^ e1 ^ " " ^ e2 ^ ")";
fun e `$$` es = "(" ^ e ^ " " ^ space_implode " " es ^ ")";
fun ml_abs v e = "(fn " ^ v ^ " => " ^ e ^ ")";

fun ml_Val v s = "val " ^ v ^ " = " ^ s;
fun ml_cases t cs =
  "(case " ^ t ^ " of " ^ space_implode " | " (map (fn (p, t) => p ^ " => " ^ t) cs) ^ ")";
fun ml_Let ds e = "let\n" ^ space_implode "\n" ds ^ " in " ^ e ^ " end";

fun ml_list es = "[" ^ commas es ^ "]";

(*render a group of mutually recursive function definitions; a single
  equation without arguments becomes a plain val binding*)
fun ml_fundefs ([(name, [([], e)])]) =
      "val " ^ name ^ " = " ^ e ^ "\n"
  | ml_fundefs [] = ""
  | ml_fundefs (eqs :: eqss) =
      let
        fun fundef (name, eqs) =
          let
            fun eqn (es, e) = name ^ " " ^ space_implode " " es ^ " = " ^ e
          in space_implode "\n | " (map eqn eqs) end;
      in
        (prefix "fun " o fundef) eqs :: map (prefix "and " o fundef) eqss
        |> space_implode "\n"
        |> suffix "\n"
      end;

(* nbe specific syntax *)

local
  val prefix =          "Nbe.";
  val name_const =      prefix ^ "Const";
  val name_free =       prefix ^ "free";
  val name_abs =        prefix ^ "abs";
  val name_app =        prefix ^ "app";
  val name_lookup_fun = prefix ^ "lookup_fun";
in

fun nbe_const c ts =
  name_const `$` ("(" ^ ML_Syntax.print_string c ^ ", " ^ ml_list ts ^ ")");
fun nbe_fun c = "c_" ^ translate_string (fn "." => "_" | c => c) c;
fun nbe_free v = name_free `$$` [ML_Syntax.print_string v, ml_list []];
fun nbe_bound v = "v_" ^ v;

fun nbe_apps e es =
  Library.foldr (fn (s, e) => name_app `$$` [e, s]) (es, e);

fun nbe_abss 0 f = f `$` ml_list []
  | nbe_abss n f = name_abs `$$` [string_of_int n, f, ml_list []];

fun nbe_lookup c = ml_Val (nbe_fun c) (name_lookup_fun `$` ML_Syntax.print_string c);

val nbe_value = "value";

end;

open BasicCodegenThingol;

(* greetings to Tarski *)

(*translate an intermediate term into ML source text; num_args gives the arity
  of constants defined in the current group, is_fun marks constants whose
  compiled function is bound via lookup_fun*)
fun assemble_iterm thy is_fun num_args =
  let
    fun of_iterm t =
      let
        val (t', ts) = CodegenThingol.unfold_app t
      in of_itermapp t' (fold (cons o of_iterm) ts []) end
    and of_itermapp (IConst (c, (dss, _))) ts =
          (case num_args c
           of SOME n => if n <= length ts
                then let val (args2, args1) = chop (length ts - n) ts
                in nbe_apps (nbe_fun c `$` ml_list args1) args2
                end else nbe_const c ts
            | NONE => if is_fun c then nbe_apps (nbe_fun c) ts
                else nbe_const c ts)
      | of_itermapp (IVar v) ts = nbe_apps (nbe_bound v) ts
      | of_itermapp ((v, _) `|-> t) ts =
          nbe_apps (nbe_abss 1 (ml_abs (ml_list [nbe_bound v]) (of_iterm t))) ts
      | of_itermapp (ICase (((t, _), cs), t0)) ts =
          nbe_apps (ml_cases (of_iterm t) (map (pairself of_iterm) cs
            @ [("_", of_iterm t0)])) ts
  in of_iterm end;

(*ML equations for one function, followed by a catch-all equation that yields
  the constant applied to its arguments unevaluated*)
fun assemble_fun thy is_fun num_args (c, eqns) =
  let
    val assemble_arg = assemble_iterm thy (K false) (K NONE);
    val assemble_rhs = assemble_iterm thy is_fun num_args;
    fun assemble_eqn (args, rhs) =
      ([ml_list (map assemble_arg (rev args))], assemble_rhs rhs);
    val default_params = map nbe_bound
      (Name.invent_list [] "a" ((the o num_args) c));
    val default_eqn = ([ml_list default_params], nbe_const c default_params);
  in map assemble_eqn eqns @ [default_eqn] end;

(*ML source for a group of function definitions; the resulting expression
  denotes the list of compiled functions in the order of the group's names*)
fun assemble_eqnss thy is_fun [] = ([], "")
  | assemble_eqnss thy is_fun eqnss =
      let
        val cs = map fst eqnss;
        val num_args = cs ~~ map (fn (_, (args, rhs) :: _) => length args) eqnss;
        val funs = fold (fold (CodegenThingol.fold_constnames
          (insert (op =))) o map snd o snd) eqnss [];
        val bind_funs = map nbe_lookup (filter is_fun funs);
        val bind_locals = ml_fundefs (map nbe_fun cs ~~ map
          (assemble_fun thy is_fun (AList.lookup (op =) num_args)) eqnss);
        val result = ml_list (map (fn (c, n) => nbe_abss n (nbe_fun c)) num_args);
      in (cs, ml_Let (bind_funs @ [bind_locals]) result) end;

(*ML source evaluating the term t, with its unbound variables as free values*)
fun assemble_eval thy is_fun t =
  let
    val funs = CodegenThingol.fold_constnames (insert (op =)) t [];
    val frees = CodegenThingol.fold_unbound_varnames (insert (op =)) t [];
    val bind_funs = map nbe_lookup (filter is_fun funs);
    val bind_value = ml_fundefs [(nbe_value, [([ml_list (map nbe_bound frees)],
      assemble_iterm thy is_fun (K NONE) t)])];
    val result = ml_list [nbe_value `$` ml_list (map nbe_free frees)];
  in ([nbe_value], ml_Let (bind_funs @ [bind_value]) result) end;

(*only function statements with at least one equation are compiled*)
fun eqns_of_stmt (name, CodegenThingol.Fun ([], _)) =
      NONE
  | eqns_of_stmt (name, CodegenThingol.Fun (eqns, _)) =
      SOME (name, eqns)
  | eqns_of_stmt (_, CodegenThingol.Datatypecons _) =
      NONE
  | eqns_of_stmt (_, CodegenThingol.Datatype _) =
      NONE
  | eqns_of_stmt (_, CodegenThingol.Class _) =
      NONE
  | eqns_of_stmt (_, CodegenThingol.Classrel _) =
      NONE
  | eqns_of_stmt (_, CodegenThingol.Classop _) =
      NONE
  | eqns_of_stmt (_, CodegenThingol.Classinst _) =
      NONE;

fun compile_stmts thy is_fun =
  map_filter eqns_of_stmt
  #> assemble_eqnss thy is_fun
  #> compile_univs (Nbe_Functions.get thy);

fun eval_term thy is_fun =
  assemble_eval thy is_fun
  #> compile_univs (Nbe_Functions.get thy)
  #> the_single
  #> snd;


(** compilation and evaluation **)

(* ensure global functions *)

(*compile every statement of code not yet in the global function table, one
  strongly connected component at a time (in the order of
  rev (Graph.strong_conn code)), and return the resulting table*)
fun ensure_funs thy code =
  let
    fun compile' stmts tab =
      let
        val compiled = compile_stmts thy (Symtab.defined tab) stmts;
      in Nbe_Functions.change thy (fold Symtab.update compiled) end;
    val nbe_tab = Nbe_Functions.get thy;
    val stmtss =
      map (AList.make (Graph.get_node code)) (rev (Graph.strong_conn code))
      |> (map o filter_out) (Symtab.defined nbe_tab o fst)
  in fold compile' stmtss nbe_tab end;

(* re-conversion *)

(*convert a value back into a term; constants receive their default type with
  fresh type parameters (typidx threads the next free index), abstractions are
  applied to a fresh bound variable to read off their body*)
fun term_of_univ thy t =
  let
    fun of_apps bounds (t, ts) =
      fold_map (of_univ bounds) ts
      #>> (fn ts' => list_comb (t, rev ts'))
    and of_univ bounds (Const (name, ts)) typidx =
          let
            val const as (c, _) =
              (case CodegenNames.const_rev thy name
               of SOME const => const
                | NONE => error ("term_of_univ: unknown constant " ^ quote name));
            val T = CodegenData.default_typ thy const;
            val T' = map_type_tvar (fn ((v, i), S) => TypeInfer.param (typidx + i) (v, S)) T;
            val typidx' = typidx + maxidx_of_typ T' + 1;
          in of_apps bounds (Term.Const (c, T'), ts) typidx' end
      | of_univ bounds (Free (name, ts)) typidx =
          of_apps bounds (Term.Free (name, dummyT), ts) typidx
      | of_univ bounds (BVar (name, ts)) typidx =
          of_apps bounds (Bound (bounds - name - 1), ts) typidx
      | of_univ bounds (t as Abs _) typidx =
          typidx
          |> of_univ (bounds + 1) (app t (BVar (bounds, [])))
          |-> (fn t' => pair (Term.Abs ("u", dummyT, t')))
  in of_univ 0 t 0 |> fst end;

(* evaluation with type reconstruction *)

(*normalize t' (the intermediate form of t), re-type the result against the
  type of t and reject results that still contain schematic type variables*)
fun eval thy code t t' =
  let
    fun subst_Frees [] = I
      | subst_Frees inst =
          Term.map_aterms (fn (t as Term.Free (s, _)) => the_default t (AList.lookup (op =) inst s)
                            | t => t);
    val anno_vars =
      subst_Frees (map (fn (s, T) => (s, Term.Free (s, T))) (Term.add_frees t []))
      #> subst_Vars (map (fn (ixn, T) => (ixn, Var (ixn, T))) (Term.add_vars t []))
    fun check_tvars t = if null (Term.term_tvars t) then t else
      error ("Illegal schematic type variables in normalized term: "
        ^ setmp show_types true (Sign.string_of_term thy) t);
    val ty = type_of t;
    fun constrain t =
      singleton (ProofContext.infer_types_pats (ProofContext.init thy)) (TypeInfer.constrain t ty);
  in
    t'
    |> eval_term thy (Symtab.defined (ensure_funs thy code))
    |> term_of_univ thy
    |> tracing (fn t => "Normalized:\n" ^ setmp show_types true Display.raw_string_of_term t)
    |> tracing (fn _ => "Term type:\n" ^ Display.raw_string_of_typ ty)
    |> anno_vars
    |> tracing (fn t => "Vars typed:\n" ^ setmp show_types true Display.raw_string_of_term t)
    |> tracing (fn t => setmp show_types true (Sign.string_of_term thy) t)
    |> constrain
    |> tracing (fn t => "Types inferred:\n" ^ setmp show_types true Display.raw_string_of_term t)
    |> check_tvars
  end;

(* evaluation oracle *)

exception Normalization of CodegenThingol.code * term * CodegenThingol.iterm;

fun normalization_oracle (thy, Normalization (code, t, t')) =
  Logic.mk_equals (t, eval thy code t t');

fun normalization_invoke thy code t t' =
  Thm.invoke_oracle_i thy "Nbe.normalization" (thy, Normalization (code, t, t'));

fun normalization_conv ct =
  let
    val thy = Thm.theory_of_cterm ct;
    fun conv code t' ct =
      let
        val t = Thm.term_of ct;
      in normalization_invoke thy code t t' end;
  in CodegenPackage.eval_conv thy conv ct end;

(* evaluation command *)

fun norm_print_term ctxt modes t =
  let
    val thy = ProofContext.theory_of ctxt;
    val ct = Thm.cterm_of thy t;
    val (_, t') = (Logic.dest_equals o Thm.prop_of o normalization_conv) ct;
    val ty = Term.type_of t';
    val p = Library.setmp print_mode (modes @ ! print_mode) (fn () =>
      Pretty.block [Pretty.quote (ProofContext.pretty_term ctxt t'), Pretty.fbrk,
        Pretty.str "::", Pretty.brk 1, Pretty.quote (ProofContext.pretty_typ ctxt ty)]) ();
  in Pretty.writeln p end;


(** Isar setup **)

fun norm_print_term_cmd (modes, raw_t) state =
  let val ctxt = Toplevel.context_of state
  in norm_print_term ctxt modes (ProofContext.read_term ctxt raw_t) end;

val setup = Theory.add_oracle ("normalization", normalization_oracle)

local structure P = OuterParse and K = OuterKeyword in

val opt_modes = Scan.optional (P.$$$ "(" |-- P.!!! (Scan.repeat1 P.xname --| P.$$$ ")")) [];

(*the argument is read as a term (ProofContext.read_term), so parse a term*)
val nbeP =
  OuterSyntax.improper_command "normal_form" "normalize term by evaluation" K.diag
    (opt_modes -- P.term >> (Toplevel.keep o norm_print_term_cmd));

val _ = OuterSyntax.add_parsers [nbeP];

end;

end;