src/Tools/nbe.ML
author haftmann
Sat Sep 15 19:27:48 2007 +0200 (2007-09-15)
changeset 24590 733120d04233
parent 24508 c8b82fec6447
child 24612 d1b315bdb8d7
permissions -rw-r--r--
delayed evaluation
     1 (*  Title:      Tools/nbe.ML
     2     ID:         $Id$
     3     Authors:    Klaus Aehlig, LMU Muenchen; Tobias Nipkow, Florian Haftmann, TU Muenchen
     4 
     5 Evaluation mechanisms for normalization by evaluation.
     6 *)
     7 
     8 (*
     9 FIXME:
    10 - get rid of BVar (?) - it is only used for terms to be evaluated, not for functions
    11 - proper purge operation - preliminary for...
    12 - really incremental code generation
    13 *)
    14 
(*Normalization by evaluation: terms are compiled to ML code operating on
  the semantical universe Univ, the code is run, and its result is read
  back into a term.  Univ, free, abs, app, univs_ref and lookup_fun are
  exported because the generated ML code refers to them by their qualified
  names (Nbe.Const, Nbe.free, Nbe.abs, Nbe.app, Nbe.lookup_fun,
  Nbe.univs_ref).*)
signature NBE =
sig
  datatype Univ = 
      Const of string * Univ list            (*named (uninterpreted) constants*)
    | Free of string * Univ list             (*free variables, with arguments*)
    | BVar of int * Univ list                (*bound variables, introduced when
                                               reading back abstractions*)
    | Abs of (int * (Univ list -> Univ)) * Univ list;
                                             (*(number of arguments still awaited,
                                               function), arguments collected so
                                               far (last argument first)*)
  val free: string -> Univ list -> Univ       (*free (uninterpreted) variables*)
  val abs: int -> (Univ list -> Univ) -> Univ list -> Univ
                                            (*abstractions as functions*)
  val app: Univ -> Univ -> Univ              (*explicit application*)

  (*communication with the generated code: the generated code assigns its
    (delayed) results to univs_ref and obtains already compiled global
    functions via lookup_fun*)
  val univs_ref: (unit -> Univ list) ref 
  val lookup_fun: string -> Univ

  (*proves ct == t', t' being the normal form of ct computed by evaluation
    (via the "normalization" oracle)*)
  val normalization_conv: cterm -> thm

  val trace: bool ref                        (*trace generated code and intermediate terms*)
  val setup: theory -> theory                (*registers the normalization oracle*)
end;
    35 
    36 structure Nbe: NBE =
    37 struct
    38 
    39 (* generic non-sense *)
    40 
    41 val trace = ref false;
    42 fun tracing f x = if !trace then (Output.tracing (f x); x) else x;
    43 
    44 
    45 (** the semantical universe **)
    46 
    47 (*
    48    Functions are given by their semantical function value. To avoid
    49    trouble with the ML-type system, these functions have the most
    50    generic type, that is "Univ list -> Univ". The calling convention is
    51    that the arguments come as a list, the last argument first. In
    52    other words, a function call that usually would look like
    53 
    54    f x_1 x_2 ... x_n   or   f(x_1,x_2, ..., x_n)
    55 
    56    would be in our convention called as
    57 
    58               f [x_n,..,x_2,x_1]
    59 
    60    Moreover, to handle functions that are still waiting for some
    61    arguments we have additionally a list of arguments collected to far
    62    and the number of arguments we're still waiting for.
    63 *)
    64 
    65 datatype Univ = 
    66     Const of string * Univ list        (*named (uninterpreted) constants*)
    67   | Free of string * Univ list         (*free variables*)
    68   | BVar of int * Univ list            (*bound named variables*)
    69   | Abs of (int * (Univ list -> Univ)) * Univ list
    70                                       (*abstractions as closures*);
    71 
    72 (* constructor functions *)
    73 
    74 val free = curry Free;
    75 fun abs n f ts = Abs ((n, f), ts);
    76 fun app (Abs ((1, f), xs)) x = f (x :: xs)
    77   | app (Abs ((n, f), xs)) x = Abs ((n - 1, f), x :: xs)
    78   | app (Const (name, args)) x = Const (name, x :: args)
    79   | app (Free (name, args)) x = Free (name, x :: args)
    80   | app (BVar (name, args)) x = BVar (name, x :: args);
    81 
    82 (* global functions store *)
    83 
    84 structure Nbe_Functions = CodeDataFun
    85 (struct
    86   type T = Univ Symtab.table;
    87   val empty = Symtab.empty;
    88   fun merge _ = Symtab.merge (K true);
    89   fun purge _ _ _ = Symtab.empty;
    90 end);
    91 
    92 (* sandbox communication *)
    93 
    94 val univs_ref = ref (fn () => [] : Univ list);
    95 
    96 local
    97 
    98 val tab_ref = ref NONE : Univ Symtab.table option ref;
    99 
   100 in
   101 
   102 fun lookup_fun s = case ! tab_ref
   103  of NONE => error "compile_univs"
   104   | SOME tab => (the o Symtab.lookup tab) s;
   105 
   106 fun compile_univs tab ([], _) = []
   107   | compile_univs tab (cs, raw_s) =
   108       let
   109         val _ = univs_ref := (fn () => []);
   110         val s = "Nbe.univs_ref := " ^ raw_s;
   111         val _ = tracing (fn () => "\n--- generated code:\n" ^ s) ();
   112         val _ = tab_ref := SOME tab;
   113         val _ = use_text "" (Output.tracing o enclose "\n---compiler echo:\n" "\n---\n",
   114           Output.tracing o enclose "\n--- compiler echo (with error):\n" "\n---\n")
   115           (!trace) s;
   116         val _ = tab_ref := NONE;
   117         val univs = case !univs_ref () of [] => error "compile_univs" | univs => univs;
   118       in cs ~~ univs end;
   119 
   120 end; (*local*)
   121 
   122 
   123 (** assembling and compiling ML code from terms **)
   124 
   125 (* abstract ML syntax *)
   126 
   127 infix 9 `$` `$$`;
   128 fun e1 `$` e2 = "(" ^ e1 ^ " " ^ e2 ^ ")";
   129 fun e `$$` es = "(" ^ e ^ " " ^ space_implode " " es ^ ")";
   130 fun ml_abs v e = "(fn " ^ v ^ " => " ^ e ^ ")";
   131 
   132 fun ml_Val v s = "val " ^ v ^ " = " ^ s;
   133 fun ml_cases t cs =
   134   "(case " ^ t ^ " of " ^ space_implode " | " (map (fn (p, t) => p ^ " => " ^ t) cs) ^ ")";
   135 fun ml_Let ds e = "let\n" ^ space_implode "\n" ds ^ " in " ^ e ^ " end";
   136 
   137 fun ml_list es = "[" ^ commas es ^ "]";
   138 
   139 val ml_delay = ml_abs "()"
   140 
   141 fun ml_fundefs ([(name, [([], e)])]) =
   142       "val " ^ name ^ " = " ^ e ^ "\n"
   143   | ml_fundefs (eqs :: eqss) =
   144       let
   145         fun fundef (name, eqs) =
   146           let
   147             fun eqn (es, e) = name ^ " " ^ space_implode " " es ^ " = " ^ e
   148           in space_implode "\n  | " (map eqn eqs) end;
   149       in
   150         (prefix "fun " o fundef) eqs :: map (prefix "and " o fundef) eqss
   151         |> space_implode "\n"
   152         |> suffix "\n"
   153       end;
   154 
   155 (* nbe specific syntax *)
   156 
   157 local
   158   val prefix =          "Nbe.";
   159   val name_const =      prefix ^ "Const";
   160   val name_free =       prefix ^ "free";
   161   val name_abs =        prefix ^ "abs";
   162   val name_app =        prefix ^ "app";
   163   val name_lookup_fun = prefix ^ "lookup_fun";
   164 in
   165 
   166 fun nbe_const c ts = name_const `$` ("(" ^ ML_Syntax.print_string c ^ ", " ^ ml_list ts ^ ")");
   167 fun nbe_fun c = "c_" ^ translate_string (fn "." => "_" | c => c) c;
   168 fun nbe_free v = name_free `$$` [ML_Syntax.print_string v, ml_list []];
   169 fun nbe_bound v = "v_" ^ v;
   170 
   171 fun nbe_apps e es =
   172   Library.foldr (fn (s, e) => name_app `$$` [e, s]) (es, e);
   173 
   174 fun nbe_abss 0 f = f `$` ml_list []
   175   | nbe_abss n f = name_abs `$$` [string_of_int n, f, ml_list []];
   176 
   177 fun nbe_lookup c = ml_Val (nbe_fun c) (name_lookup_fun `$` ML_Syntax.print_string c);
   178 
   179 val nbe_value = "value";
   180 
   181 end;
   182 
   183 open BasicCodeThingol;
   184 
   185 (* greetings to Tarski *)
   186 
   187 fun assemble_iterm thy is_fun num_args =
   188   let
   189     fun of_iterm t =
   190       let
   191         val (t', ts) = CodeThingol.unfold_app t
   192       in of_iapp t' (fold (cons o of_iterm) ts []) end
   193     and of_iconst c ts = case num_args c
   194      of SOME n => if n <= length ts
   195           then let val (args2, args1) = chop (length ts - n) ts
   196           in nbe_apps (nbe_fun c `$` ml_list args1) args2
   197           end else nbe_const c ts
   198       | NONE => if is_fun c then nbe_apps (nbe_fun c) ts
   199           else nbe_const c ts
   200     and of_iapp (IConst (c, (dss, _))) ts = of_iconst c ts
   201       | of_iapp (IVar v) ts = nbe_apps (nbe_bound v) ts
   202       | of_iapp ((v, _) `|-> t) ts =
   203           nbe_apps (nbe_abss 1 (ml_abs (ml_list [nbe_bound v]) (of_iterm t))) ts
   204       | of_iapp (ICase (((t, _), cs), t0)) ts =
   205           nbe_apps (ml_cases (of_iterm t) (map (pairself of_iterm) cs
   206             @ [("_", of_iterm t0)])) ts
   207   in of_iterm end;
   208 
   209 fun assemble_fun thy is_fun num_args (c, eqns) =
   210   let
   211     val assemble_arg = assemble_iterm thy (K false) (K NONE);
   212     val assemble_rhs = assemble_iterm thy is_fun num_args;
   213     fun assemble_eqn (args, rhs) =
   214       ([ml_list (map assemble_arg (rev args))], assemble_rhs rhs);
   215     val default_params = map nbe_bound
   216       (Name.invent_list [] "a" ((the o num_args) c));
   217     val default_eqn = ([ml_list default_params], nbe_const c default_params);
   218   in map assemble_eqn eqns @ [default_eqn] end;
   219 
   220 fun assemble_eqnss thy is_fun ([], deps) = ([], "")
   221   | assemble_eqnss thy is_fun (eqnss, deps) =
   222       let
   223         val cs = map fst eqnss;
   224         val num_args = cs ~~ map (fn (_, (args, rhs) :: _) => length args) eqnss;
   225         val funs = fold (fold (CodeThingol.fold_constnames
   226           (insert (op =))) o map snd o snd) eqnss [];
   227         val bind_funs = map nbe_lookup (filter is_fun funs);
   228         val bind_locals = ml_fundefs (map nbe_fun cs ~~ map
   229           (assemble_fun thy is_fun (AList.lookup (op =) num_args)) eqnss);
   230         val result = ml_list (map (fn (c, n) => nbe_abss n (nbe_fun c)) num_args)
   231           |> ml_delay;
   232       in (cs, ml_Let (bind_funs @ [bind_locals]) result) end;
   233 
   234 fun assemble_eval thy is_fun (((vs, ty), t), deps) =
   235   let
   236     val funs = CodeThingol.fold_constnames (insert (op =)) t [];
   237     val frees = CodeThingol.fold_unbound_varnames (insert (op =)) t [];
   238     val bind_funs = map nbe_lookup (filter is_fun funs);
   239     val bind_value = ml_fundefs [(nbe_value, [([ml_list (map nbe_bound frees)],
   240       assemble_iterm thy is_fun (K NONE) t)])];
   241     val result = ml_list [nbe_value `$` ml_list (map nbe_free frees)]
   242       |> ml_delay;
   243   in ([nbe_value], ml_Let (bind_funs @ [bind_value]) result) end;
   244 
   245 fun eqns_of_stmt ((_, CodeThingol.Fun (_, [])), _) =
   246       NONE
   247   | eqns_of_stmt ((name, CodeThingol.Fun (_, eqns)), deps) =
   248       SOME ((name, map fst eqns), deps)
   249   | eqns_of_stmt ((_, CodeThingol.Datatypecons _), _) =
   250       NONE
   251   | eqns_of_stmt ((_, CodeThingol.Datatype _), _) =
   252       NONE
   253   | eqns_of_stmt ((_, CodeThingol.Class _), _) =
   254       NONE
   255   | eqns_of_stmt ((_, CodeThingol.Classrel _), _) =
   256       NONE
   257   | eqns_of_stmt ((_, CodeThingol.Classop _), _) =
   258       NONE
   259   | eqns_of_stmt ((_, CodeThingol.Classinst _), _) =
   260       NONE;
   261 
   262 fun compile_stmts thy is_fun =
   263   map_filter eqns_of_stmt
   264   #> split_list
   265   #> assemble_eqnss thy is_fun
   266   #> compile_univs (Nbe_Functions.get thy);
   267 
   268 fun eval_term thy is_fun =
   269   assemble_eval thy is_fun
   270   #> compile_univs (Nbe_Functions.get thy)
   271   #> the_single
   272   #> snd;
   273 
   274 
   275 (** compilation and evaluation **)
   276 
   277 (* ensure global functions *)
   278 
   279 fun ensure_funs thy code =
   280   let
   281     fun compile' stmts tab =
   282       let
   283         val compiled = compile_stmts thy (Symtab.defined tab) stmts;
   284       in Nbe_Functions.change thy (fold Symtab.update compiled) end;
   285     val nbe_tab = Nbe_Functions.get thy;
   286     val stmtss = rev (Graph.strong_conn code)
   287       |> (map o map_filter) (fn name => if Symtab.defined nbe_tab name
   288            then NONE
   289            else SOME ((name, Graph.get_node code name), Graph.imm_succs code name))
   290       |> filter_out null
   291   in fold compile' stmtss nbe_tab end;
   292 
   293 (* re-conversion *)
   294 
   295 fun term_of_univ thy t =
   296   let
   297     fun of_apps bounds (t, ts) =
   298       fold_map (of_univ bounds) ts
   299       #>> (fn ts' => list_comb (t, rev ts'))
   300     and of_univ bounds (Const (name, ts)) typidx =
   301           let
   302             val SOME c = CodeName.const_rev thy name;
   303             val T = Code.default_typ thy c;
   304             val T' = map_type_tvar (fn ((v, i), S) => TypeInfer.param (typidx + i) (v, S)) T;
   305             val typidx' = typidx + maxidx_of_typ T' + 1;
   306           in of_apps bounds (Term.Const (c, T'), ts) typidx' end
   307       | of_univ bounds (Free (name, ts)) typidx =
   308           of_apps bounds (Term.Free (name, dummyT), ts) typidx
   309       | of_univ bounds (BVar (name, ts)) typidx =
   310           of_apps bounds (Bound (bounds - name - 1), ts) typidx
   311       | of_univ bounds (t as Abs _) typidx =
   312           typidx
   313           |> of_univ (bounds + 1) (app t (BVar (bounds, [])))
   314           |-> (fn t' => pair (Term.Abs ("u", dummyT, t')))
   315   in of_univ 0 t 0 |> fst end;
   316 
   317 (* evaluation with type reconstruction *)
   318 
   319 fun eval thy code t vs_ty_t deps =
   320   let
   321     val ty = type_of t;
   322     fun subst_Frees [] = I
   323       | subst_Frees inst =
   324           Term.map_aterms (fn (t as Term.Free (s, _)) => the_default t (AList.lookup (op =) inst s)
   325                             | t => t);
   326     val anno_vars =
   327       subst_Frees (map (fn (s, T) => (s, Term.Free (s, T))) (Term.add_frees t []))
   328       #> subst_Vars (map (fn (ixn, T) => (ixn, Var (ixn, T))) (Term.add_vars t []))
   329     fun constrain t =
   330       singleton (Syntax.check_terms (ProofContext.init thy)) (TypeInfer.constrain t ty);
   331     fun check_tvars t = if null (Term.term_tvars t) then t else
   332       error ("Illegal schematic type variables in normalized term: "
   333         ^ setmp show_types true (Sign.string_of_term thy) t);
   334   in
   335     (vs_ty_t, deps)
   336     |> eval_term thy (Symtab.defined (ensure_funs thy code))
   337     |> term_of_univ thy
   338     |> tracing (fn t => "Normalized:\n" ^ setmp show_types true Display.raw_string_of_term t)
   339     |> anno_vars
   340     |> tracing (fn t => "Vars typed:\n" ^ setmp show_types true Display.raw_string_of_term t)
   341     |> constrain
   342     |> tracing (fn t => "Types inferred:\n" ^ setmp show_types true Display.raw_string_of_term t)
   343     |> check_tvars
   344     |> tracing (fn _ => "---\n")
   345   end;
   346 
   347 (* evaluation oracle *)
   348 
   349 exception Normalization of CodeThingol.code * term
   350   * (CodeThingol.typscheme * CodeThingol.iterm) * string list;
   351 
   352 fun normalization_oracle (thy, Normalization (code, t, vs_ty_t, deps)) =
   353   Logic.mk_equals (t, eval thy code t vs_ty_t deps);
   354 
   355 fun normalization_invoke thy code t vs_ty_t deps =
   356   Thm.invoke_oracle_i thy "HOL.normalization" (thy, Normalization (code, t, vs_ty_t, deps));
   357   (*FIXME get rid of hardwired theory name*)
   358 
   359 fun normalization_conv ct =
   360   let
   361     val thy = Thm.theory_of_cterm ct;
   362     fun conv code vs_ty_t deps ct =
   363       let
   364         val t = Thm.term_of ct;
   365       in normalization_invoke thy code t vs_ty_t deps end;
   366   in CodePackage.eval_conv thy conv ct end;
   367 
   368 (* evaluation command *)
   369 
   370 fun norm_print_term ctxt modes t =
   371   let
   372     val thy = ProofContext.theory_of ctxt;
   373     val ct = Thm.cterm_of thy t;
   374     val (_, t') = (Logic.dest_equals o Thm.prop_of o normalization_conv) ct;
   375     val ty = Term.type_of t';
   376     val p = Library.setmp print_mode (modes @ ! print_mode) (fn () =>
   377       Pretty.block [Pretty.quote (ProofContext.pretty_term ctxt t'), Pretty.fbrk,
   378         Pretty.str "::", Pretty.brk 1, Pretty.quote (ProofContext.pretty_typ ctxt ty)]) ();
   379   in Pretty.writeln p end;
   380 
   381 
   382 (** Isar setup **)
   383 
   384 fun norm_print_term_cmd (modes, s) state =
   385   let val ctxt = Toplevel.context_of state
   386   in norm_print_term ctxt modes (Syntax.read_term ctxt s) end;
   387 
   388 val setup = Theory.add_oracle ("normalization", normalization_oracle)
   389 
   390 local structure P = OuterParse and K = OuterKeyword in
   391 
   392 val opt_modes = Scan.optional (P.$$$ "(" |-- P.!!! (Scan.repeat1 P.xname --| P.$$$ ")")) [];
   393 
   394 val nbeP =
   395   OuterSyntax.improper_command "normal_form" "normalize term by evaluation" K.diag
   396     (opt_modes -- P.typ >> (Toplevel.keep o norm_print_term_cmd));
   397 
   398 val _ = OuterSyntax.add_parsers [nbeP];
   399 
   400 end;
   401 
   402 end;