src/Tools/nbe.ML
author haftmann
Wed Aug 15 08:57:42 2007 +0200 (2007-08-15)
changeset 24283 8ca96f4e49cd
parent 24219 e558fe311376
child 24292 26ac9fe0e80e
permissions -rw-r--r--
tuned
(*  Title:      Tools/nbe.ML
     2     ID:         $Id$
     3     Authors:    Klaus Aehlig, LMU Muenchen; Tobias Nipkow, Florian Haftmann, TU Muenchen
     4 
     5 Evaluation mechanisms for normalization by evaluation.
     6 *)
     7 
     8 (*
     9 FIXME:
    10 - get rid of BVar (?) - it is only used for terms to be evaluated, not for functions
    11 - proper purge operation - preliminary for...
    12 - really incremental code generation
    13 *)
    14 
    15 signature NBE =
    16 sig
    17   datatype Univ = 
    18       Const of string * Univ list        (*named constructors*)
    19     | Free of string * Univ list
    20     | BVar of int * Univ list
    21     | Abs of (int * (Univ list -> Univ)) * Univ list;
    22   val free: string -> Univ list -> Univ       (*free (uninterpreted) variables*)
    23   val abs: int -> (Univ list -> Univ) -> Univ list -> Univ
    24                                             (*abstractions as functions*)
    25   val app: Univ -> Univ -> Univ              (*explicit application*)
    26 
    27   val univs_ref: Univ list ref 
    28   val lookup_fun: CodeName.const -> Univ
    29 
    30   val normalization_conv: cterm -> thm
    31 
    32   val trace: bool ref
    33   val setup: theory -> theory
    34 end;
    35 
    36 structure Nbe: NBE =
    37 struct
    38 
    39 (* generic non-sense *)
    40 
    41 val trace = ref false;
    42 fun tracing f x = if !trace then (Output.tracing (f x); x) else x;
    43 
    44 
    45 (** the semantical universe **)
    46 
    47 (*
    48    Functions are given by their semantical function value. To avoid
    49    trouble with the ML-type system, these functions have the most
    50    generic type, that is "Univ list -> Univ". The calling convention is
    51    that the arguments come as a list, the last argument first. In
    52    other words, a function call that usually would look like
    53 
    54    f x_1 x_2 ... x_n   or   f(x_1,x_2, ..., x_n)
    55 
    56    would be in our convention called as
    57 
    58               f [x_n,..,x_2,x_1]
    59 
    60    Moreover, to handle functions that are still waiting for some
   arguments we have additionally a list of arguments collected so far
    62    and the number of arguments we're still waiting for.
    63 *)
    64 
    65 datatype Univ = 
    66     Const of string * Univ list        (*named constructors*)
    67   | Free of string * Univ list         (*free variables*)
    68   | BVar of int * Univ list            (*bound named variables*)
    69   | Abs of (int * (Univ list -> Univ)) * Univ list
    70                                       (*functions*);
    71 
    72 (* constructor functions *)
    73 
    74 val free = curry Free;
    75 fun abs n f ts = Abs ((n, f), ts);
    76 fun app (Abs ((1, f), xs)) x = f (x :: xs)
    77   | app (Abs ((n, f), xs)) x = Abs ((n - 1, f), x :: xs)
    78   | app (Const (name, args)) x = Const (name, x :: args)
    79   | app (Free (name, args)) x = Free (name, x :: args)
    80   | app (BVar (name, args)) x = BVar (name, x :: args);
    81 
    82 (* global functions store *)
    83 
    84 structure Nbe_Functions = CodeDataFun
    85 (struct
    86   type T = Univ Symtab.table;
    87   val empty = Symtab.empty;
    88   fun merge _ = Symtab.merge (K true);
    89   fun purge _ _ _ = Symtab.empty;
    90 end);
    91 
    92 (* sandbox communication *)
    93 
    94 val univs_ref = ref [] : Univ list ref;
    95 
    96 local
    97 
    98 val tab_ref = ref NONE : Univ Symtab.table option ref;
    99 
   100 in
   101 
   102 fun lookup_fun s = case ! tab_ref
   103  of NONE => error "compile_univs"
   104   | SOME tab => (the o Symtab.lookup tab) s;
   105 
   106 fun compile_univs tab ([], _) = []
   107   | compile_univs tab (cs, raw_s) =
   108       let
   109         val _ = univs_ref := [];
   110         val s = "Nbe.univs_ref := " ^ raw_s;
   111         val _ = tracing (fn () => "\n---generated code:\n" ^ s) ();
   112         val _ = tab_ref := SOME tab;
   113         val _ = use_text "" (Output.tracing o enclose "\n---compiler echo:\n" "\n---\n",
   114           Output.tracing o enclose "\n--- compiler echo (with error):\n" "\n---\n")
   115           (!trace) s;
   116         val _ = tab_ref := NONE;
   117         val univs = case !univs_ref of [] => error "compile_univs" | univs => univs;
   118       in cs ~~ univs end;
   119 
   120 end; (*local*)
   121 
   122 
   123 (** assembling and compiling ML code from terms **)
   124 
   125 (* abstract ML syntax *)
   126 
   127 infix 9 `$` `$$`;
   128 fun e1 `$` e2 = "(" ^ e1 ^ " " ^ e2 ^ ")";
   129 fun e `$$` es = "(" ^ e ^ " " ^ space_implode " " es ^ ")";
   130 fun ml_abs v e = "(fn" ^ v ^ " => " ^ e ^ ")";
   131 
   132 fun ml_Val v s = "val " ^ v ^ " = " ^ s;
   133 fun ml_cases t cs =
   134   "(case " ^ t ^ " of " ^ space_implode " | " (map (fn (p, t) => p ^ " => " ^ t) cs) ^ ")";
   135 fun ml_Let ds e = "let\n" ^ space_implode "\n" ds ^ " in " ^ e ^ " end";
   136 
   137 fun ml_list es = "[" ^ commas es ^ "]";
   138 
   139 fun ml_fundefs ([(name, [([], e)])]) =
   140       "val " ^ name ^ " = " ^ e ^ "\n"
   141   | ml_fundefs (eqs :: eqss) =
   142       let
   143         fun fundef (name, eqs) =
   144           let
   145             fun eqn (es, e) = name ^ " " ^ space_implode " " es ^ " = " ^ e
   146           in space_implode "\n  | " (map eqn eqs) end;
   147       in
   148         (prefix "fun " o fundef) eqs :: map (prefix "and " o fundef) eqss
   149         |> space_implode "\n"
   150         |> suffix "\n"
   151       end;
   152 
   153 (* nbe specific syntax *)
   154 
   155 local
   156   val prefix =          "Nbe.";
   157   val name_const =      prefix ^ "Const";
   158   val name_free =       prefix ^ "free";
   159   val name_abs =        prefix ^ "abs";
   160   val name_app =        prefix ^ "app";
   161   val name_lookup_fun = prefix ^ "lookup_fun";
   162 in
   163 
   164 fun nbe_const c ts = name_const `$` ("(" ^ ML_Syntax.print_string c ^ ", " ^ ml_list ts ^ ")");
   165 fun nbe_fun c = "c_" ^ translate_string (fn "." => "_" | c => c) c;
   166 fun nbe_free v = name_free `$$` [ML_Syntax.print_string v, ml_list []];
   167 fun nbe_bound v = "v_" ^ v;
   168 
   169 fun nbe_apps e es =
   170   Library.foldr (fn (s, e) => name_app `$$` [e, s]) (es, e);
   171 
   172 fun nbe_abss 0 f = f `$` ml_list []
   173   | nbe_abss n f = name_abs `$$` [string_of_int n, f, ml_list []];
   174 
   175 fun nbe_lookup c = ml_Val (nbe_fun c) (name_lookup_fun `$` ML_Syntax.print_string c);
   176 
   177 val nbe_value = "value";
   178 
   179 end;
   180 
   181 open BasicCodeThingol;
   182 
   183 (* greetings to Tarski *)
   184 
   185 fun assemble_iterm thy is_fun num_args =
   186   let
   187     fun of_iterm t =
   188       let
   189         val (t', ts) = CodeThingol.unfold_app t
   190       in of_itermapp t' (fold (cons o of_iterm) ts []) end
   191     and of_itermapp (IConst (c, (dss, _))) ts =
   192           (case num_args c
   193            of SOME n => if n <= length ts
   194                 then let val (args2, args1) = chop (length ts - n) ts
   195                 in nbe_apps (nbe_fun c `$` ml_list args1) args2
   196                 end else nbe_const c ts
   197             | NONE => if is_fun c then nbe_apps (nbe_fun c) ts
   198                 else nbe_const c ts)
   199       | of_itermapp (IVar v) ts = nbe_apps (nbe_bound v) ts
   200       | of_itermapp ((v, _) `|-> t) ts =
   201           nbe_apps (nbe_abss 1 (ml_abs (ml_list [nbe_bound v]) (of_iterm t))) ts
   202       | of_itermapp (ICase (((t, _), cs), t0)) ts =
   203           nbe_apps (ml_cases (of_iterm t) (map (pairself of_iterm) cs
   204             @ [("_", of_iterm t0)])) ts
   205   in of_iterm end;
   206 
   207 fun assemble_fun thy is_fun num_args (c, eqns) =
   208   let
   209     val assemble_arg = assemble_iterm thy (K false) (K NONE);
   210     val assemble_rhs = assemble_iterm thy is_fun num_args;
   211     fun assemble_eqn (args, rhs) =
   212       ([ml_list (map assemble_arg (rev args))], assemble_rhs rhs);
   213     val default_params = map nbe_bound
   214       (Name.invent_list [] "a" ((the o num_args) c));
   215     val default_eqn = ([ml_list default_params], nbe_const c default_params);
   216   in map assemble_eqn eqns @ [default_eqn] end;
   217 
   218 fun assemble_eqnss thy is_fun [] = ([], "")
   219   | assemble_eqnss thy is_fun eqnss =
   220       let
   221         val cs = map fst eqnss;
   222         val num_args = cs ~~ map (fn (_, (args, rhs) :: _) => length args) eqnss;
   223         val funs = fold (fold (CodeThingol.fold_constnames
   224           (insert (op =))) o map snd o snd) eqnss [];
   225         val bind_funs = map nbe_lookup (filter is_fun funs);
   226         val bind_locals = ml_fundefs (map nbe_fun cs ~~ map
   227           (assemble_fun thy is_fun (AList.lookup (op =) num_args)) eqnss);
   228         val result = ml_list (map (fn (c, n) => nbe_abss n (nbe_fun c)) num_args);
   229       in (cs, ml_Let (bind_funs @ [bind_locals]) result) end;
   230 
   231 fun assemble_eval thy is_fun (t, deps) =
   232   let
   233     val funs = CodeThingol.fold_constnames (insert (op =)) t [];
   234     val frees = CodeThingol.fold_unbound_varnames (insert (op =)) t [];
   235     val bind_funs = map nbe_lookup (filter is_fun funs);
   236     val bind_value = ml_fundefs [(nbe_value, [([ml_list (map nbe_bound frees)],
   237       assemble_iterm thy is_fun (K NONE) t)])];
   238     val result = ml_list [nbe_value `$` ml_list (map nbe_free frees)];
   239   in ([nbe_value], ml_Let (bind_funs @ [bind_value]) result) end;
   240 
   241 fun eqns_of_stmt (name, CodeThingol.Fun ([], _)) =
   242       NONE
   243   | eqns_of_stmt (name, CodeThingol.Fun (eqns, _)) =
   244       SOME (name, eqns)
   245   | eqns_of_stmt (_, CodeThingol.Datatypecons _) =
   246       NONE
   247   | eqns_of_stmt (_, CodeThingol.Datatype _) =
   248       NONE
   249   | eqns_of_stmt (_, CodeThingol.Class _) =
   250       NONE
   251   | eqns_of_stmt (_, CodeThingol.Classrel _) =
   252       NONE
   253   | eqns_of_stmt (_, CodeThingol.Classop _) =
   254       NONE
   255   | eqns_of_stmt (_, CodeThingol.Classinst _) =
   256       NONE;
   257 
   258 fun compile_stmts thy is_fun =
   259   map_filter eqns_of_stmt
   260   #> assemble_eqnss thy is_fun
   261   #> compile_univs (Nbe_Functions.get thy);
   262 
   263 fun eval_term thy is_fun =
   264   assemble_eval thy is_fun
   265   #> compile_univs (Nbe_Functions.get thy)
   266   #> the_single
   267   #> snd;
   268 
   269 
   270 (** compilation and evaluation **)
   271 
   272 (* ensure global functions *)
   273 
   274 fun ensure_funs thy code =
   275   let
   276     fun compile' stmts tab =
   277       let
   278         val compiled = compile_stmts thy (Symtab.defined tab) stmts;
   279       in Nbe_Functions.change thy (fold Symtab.update compiled) end;
   280     val nbe_tab = Nbe_Functions.get thy;
   281     val stmtss =
   282       map (AList.make (Graph.get_node code)) (rev (Graph.strong_conn code))
   283       |> (map o filter_out) (Symtab.defined nbe_tab o fst)
   284   in fold compile' stmtss nbe_tab end;
   285 
   286 (* re-conversion *)
   287 
   288 fun term_of_univ thy t =
   289   let
   290     fun of_apps bounds (t, ts) =
   291       fold_map (of_univ bounds) ts
   292       #>> (fn ts' => list_comb (t, rev ts'))
   293     and of_univ bounds (Const (name, ts)) typidx =
   294           let
   295             val SOME (const as (c, _)) = CodeName.const_rev thy name;
   296             val T = Code.default_typ thy const;
   297             val T' = map_type_tvar (fn ((v, i), S) => TypeInfer.param (typidx + i) (v, S)) T;
   298             val typidx' = typidx + maxidx_of_typ T' + 1;
   299           in of_apps bounds (Term.Const (c, T'), ts) typidx' end
   300       | of_univ bounds (Free (name, ts)) typidx =
   301           of_apps bounds (Term.Free (name, dummyT), ts) typidx
   302       | of_univ bounds (BVar (name, ts)) typidx =
   303           of_apps bounds (Bound (bounds - name - 1), ts) typidx
   304       | of_univ bounds (t as Abs _) typidx =
   305           typidx
   306           |> of_univ (bounds + 1) (app t (BVar (bounds, [])))
   307           |-> (fn t' => pair (Term.Abs ("u", dummyT, t')))
   308   in of_univ 0 t 0 |> fst end;
   309 
   310 (* evaluation with type reconstruction *)
   311 
   312 fun eval thy code t t' deps =
   313   let
   314     fun subst_Frees [] = I
   315       | subst_Frees inst =
   316           Term.map_aterms (fn (t as Term.Free (s, _)) => the_default t (AList.lookup (op =) inst s)
   317                             | t => t);
   318     val anno_vars =
   319       subst_Frees (map (fn (s, T) => (s, Term.Free (s, T))) (Term.add_frees t []))
   320       #> subst_Vars (map (fn (ixn, T) => (ixn, Var (ixn, T))) (Term.add_vars t []))
   321     fun check_tvars t = if null (Term.term_tvars t) then t else
   322       error ("Illegal schematic type variables in normalized term: "
   323         ^ setmp show_types true (Sign.string_of_term thy) t);
   324     val ty = type_of t;
   325     fun constrain t =
   326       singleton (ProofContext.infer_types_pats (ProofContext.init thy)) (TypeInfer.constrain t ty);
   327   in
   328     (t', deps)
   329     |> eval_term thy (Symtab.defined (ensure_funs thy code))
   330     |> term_of_univ thy
   331     |> tracing (fn t => "Normalized:\n" ^ setmp show_types true Display.raw_string_of_term t)
   332     |> tracing (fn _ => "Term type:\n" ^ Display.raw_string_of_typ ty)
   333     |> anno_vars
   334     |> tracing (fn t => "Vars typed:\n" ^ setmp show_types true Display.raw_string_of_term t)
   335     |> tracing (fn t => setmp show_types true (Sign.string_of_term thy) t)
   336     |> constrain
   337     |> tracing (fn t => "Types inferred:\n" ^ setmp show_types true Display.raw_string_of_term t)
   338     |> check_tvars
   339   end;
   340 
   341 (* evaluation oracle *)
   342 
   343 exception Normalization of CodeThingol.code * term * CodeThingol.iterm * string list;
   344 
   345 fun normalization_oracle (thy, Normalization (code, t, t', deps)) =
   346   Logic.mk_equals (t, eval thy code t t' deps);
   347 
   348 fun normalization_invoke thy code t t' deps =
   349   Thm.invoke_oracle_i thy "Code_Setup.normalization" (thy, Normalization (code, t, t', deps));
   350   (*FIXME get rid of hardwired theory name*)
   351 
   352 fun normalization_conv ct =
   353   let
   354     val thy = Thm.theory_of_cterm ct;
   355     fun conv code (t', ty') deps ct =
   356       let
   357         val t = Thm.term_of ct;
   358       in normalization_invoke thy code t t' deps end;
   359   in CodePackage.eval_conv thy conv ct end;
   360 
   361 (* evaluation command *)
   362 
   363 fun norm_print_term ctxt modes t =
   364   let
   365     val thy = ProofContext.theory_of ctxt;
   366     val ct = Thm.cterm_of thy t;
   367     val (_, t') = (Logic.dest_equals o Thm.prop_of o normalization_conv) ct;
   368     val ty = Term.type_of t';
   369     val p = Library.setmp print_mode (modes @ ! print_mode) (fn () =>
   370       Pretty.block [Pretty.quote (ProofContext.pretty_term ctxt t'), Pretty.fbrk,
   371         Pretty.str "::", Pretty.brk 1, Pretty.quote (ProofContext.pretty_typ ctxt ty)]) ();
   372   in Pretty.writeln p end;
   373 
   374 
   375 (** Isar setup **)
   376 
   377 fun norm_print_term_cmd (modes, raw_t) state =
   378   let val ctxt = Toplevel.context_of state
   379   in norm_print_term ctxt modes (ProofContext.read_term ctxt raw_t) end;
   380 
   381 val setup = Theory.add_oracle ("normalization", normalization_oracle)
   382 
   383 local structure P = OuterParse and K = OuterKeyword in
   384 
   385 val opt_modes = Scan.optional (P.$$$ "(" |-- P.!!! (Scan.repeat1 P.xname --| P.$$$ ")")) [];
   386 
   387 val nbeP =
   388   OuterSyntax.improper_command "normal_form" "normalize term by evaluation" K.diag
   389     (opt_modes -- P.typ >> (Toplevel.keep o norm_print_term_cmd));
   390 
   391 val _ = OuterSyntax.add_parsers [nbeP];
   392 
   393 end;
   394 
   395 end;