src/Tools/Nbe/nbe_eval.ML
changeset 23930 6d81e2ef69f7
child 23998 694fbb0871eb
equal deleted inserted replaced
23929:6a98d0826daf 23930:6d81e2ef69f7
       
     1 (*  Title:      Tools/Nbe/Nbe_Eval.ML
       
     2     ID:         $Id$
       
     3     Authors:    Klaus Aehlig, LMU Muenchen; Tobias Nipkow, Florian Haftmann, TU Muenchen
       
     4 
       
     5 Evaluation mechanisms for normalization by evaluation.
       
     6 *)
       
     7 
       
     8 (*
       
     9 FIXME:
       
    10 - implement purge operation proper
       
    11 - get rid of BVar (?) - it is only used for terms to be evaluated, not for functions
       
    12 *)
       
    13 
       
    signature NBE_EVAL =
    sig
      (*the semantic universe; argument lists are kept last argument first*)
      datatype Univ =
          Constr of string * Univ list  (*constructor applied to arguments*)
        | Var of string * Univ list     (*free variable applied to arguments*)
        | BVar of int * Univ list       (*bound variable, identified by binder level*)
        | Fun of (Univ list -> Univ) * Univ list * int
            (*function, arguments collected so far, number of arguments still missing*)
      val apply: Univ -> Univ -> Univ

      (*communication with the generated ML code*)
      val univs_ref: (CodegenNames.const * Univ) list ref
      val compile_univs: string -> (CodegenNames.const * Univ) list
      val lookup_fun: CodegenNames.const -> Univ

      (*precondition: the term contains no schematic Vars/TVars*)
      val eval: theory -> CodegenFuncgr.T -> term -> term

      (*switch for diagnostic output*)
      val trace: bool ref
    end;
       
    33 
       
    34 structure Nbe_Eval: NBE_EVAL =
       
    35 struct
       
    36 
       
    37 
       
    38 (* generic non-sense *)
       
    39 
       
    (*global switch for diagnostic output*)
    val trace = ref false;

    (*return x unchanged; if tracing is switched on, additionally emit the
      message f x*)
    fun tracing f x =
      (if ! trace then Output.tracing (f x) else (); x);
       
    42 
       
    43 (** the semantical universe **)
       
    44 
       
    45 (*
       
    46    Functions are given by their semantical function value. To avoid
       
    47    trouble with the ML-type system, these functions have the most
       
    48    generic type, that is "Univ list -> Univ". The calling convention is
       
    49    that the arguments come as a list, the last argument first. In
       
    50    other words, a function call that usually would look like
       
    51 
       
    52    f x_1 x_2 ... x_n   or   f(x_1,x_2, ..., x_n)
       
    53 
       
    54    would be in our convention called as
       
    55 
       
    56               f [x_n,..,x_2,x_1]
       
    57 
       
    58    Moreover, to handle functions that are still waiting for some
       
    59    arguments we have additionally a list of arguments collected so far
       
    60    and the number of arguments we're still waiting for.
       
    61 
       
    62    (?) Finally, it might happen, that a function does not get all the
       
    63    arguments it needs. In this case the function must provide means to
       
    64    present itself as a string. As this might be a heavy-weight
       
    65    operation, we delay it. (?) 
       
    66 *)
       
    67 
       
    (*semantic values; every argument list is stored last argument first*)
    datatype Univ =
        Constr of string * Univ list   (*constructor, applied to its arguments*)
      | Var of string * Univ list      (*free variable, applied to its arguments*)
      | BVar of int * Univ list        (*bound variable (binder level), applied to its arguments*)
      | Fun of (Univ list -> Univ) * Univ list * int
          (*semantic function, arguments collected so far, number of arguments still awaited*);
       
    74 
       
    (*apply semantic value u to one further argument x; x is consed onto the
      collected arguments (last argument first); a Fun waiting for exactly one
      more argument is invoked at once, otherwise its counter is decreased*)
    fun apply u x =
      case u of
          Fun (f, xs, n) =>
            if n = 1 then f (x :: xs) else Fun (f, x :: xs, n - 1)
        | Constr (name, args) => Constr (name, x :: args)
        | Var (name, args) => Var (name, x :: args)
        | BVar (level, args) => BVar (level, x :: args);
       
    80 
       
    81 
       
    82 (** global functions **)
       
    83 
       
    (*theory data: compiled global functions, keyed by constant name*)
    structure Nbe_Data = CodeDataFun
    (struct
      type T = Univ Symtab.table;
      val empty = Symtab.empty;
      (*duplicate keys are accepted (K true) rather than reported as conflicts*)
      fun merge _ tabs = Symtab.merge (K true) tabs;
      (*FIXME: every purge throws away the whole table*)
      fun purge _ _ _ = Symtab.empty;
    end);
       
    91 
       
    92 
       
    93 (** sandbox communication **)
       
    94 
       
    (*target of the assignment performed by the generated code in compile_univs*)
    val univs_ref : (string * Univ) list ref = ref [];
       
    96 
       
    (*compile raw_s, ML source denoting a list of (name, value) pairs: it is
      assigned to univs_ref by evaluating it with use_text, and the list is
      read back; empty source yields [], an empty result is an error*)
    fun compile_univs "" = []
      | compile_univs raw_s =
          let
            val code = "Nbe_Eval.univs_ref := " ^ raw_s;
            val echo = Output.tracing o enclose "\n---compiler echo:\n" "\n---\n";
            val echo_error =
              Output.tracing o enclose "\n--- compiler echo (with error!):\n" "\n---\n";
            val _ = univs_ref := [];
            val _ = tracing (fn () => "\n---generated code:\n" ^ code) ();
            val _ = use_text "" (echo, echo_error) (!trace) code;
          in
            (case ! univs_ref of
              [] => error "compile_univs"
            | univs => univs)
          end;
       
   108 
       
   (*compiled global functions made visible to generated code (through
     lookup_fun) while it is being compiled; set by ensure_funs*)
   val tab_ref = ref Symtab.empty : Univ Symtab.table ref;

   (*compiled value of global constant s; raises a descriptive error instead
     of a bare Option exception when s has not been compiled*)
   fun lookup_fun s =
     case Symtab.lookup (! tab_ref) s of
       SOME u => u
     | NONE => error ("Nbe_Eval.lookup_fun: no compiled function for constant " ^ s);
       
   111 
       
   112 
       
   113 (** printing ML syntax **)
       
   114 
       
   structure S =
   struct

     (* generic basics *)

     (*parenthesised ML application "(e1 e2)"*)
     fun app e1 e2 = "(" ^ e1 ^ " " ^ e2 ^ ")";

     (*s applied to the elements of ss from left to right*)
     fun apps s ss = Library.foldl (uncurry app) (s, ss);

     (*ML lambda "(fn v => e)"; the blank after "fn" is required whenever v is a
       plain identifier, which would otherwise run together with the keyword*)
     fun abs v e = "(fn " ^ v ^ " => " ^ e ^ ")";

     fun Val v s = "val " ^ v ^ " = " ^ s;
     fun Let ds e = "let\n" ^ space_implode "\n" ds ^ " in " ^ e ^ " end";

     (*ML string literal, tuple and list syntax*)
     val string = ML_Syntax.print_string;
     fun tup es = "(" ^ commas es ^ ")";
     fun list es = "[" ^ commas es ^ "]";

     (*group of mutually recursive function definitions: "fun" for the first,
       "and" for the others; each definition is (name, clauses), a clause being
       (argument strings, body); an empty group yields the empty string*)
     fun fundefs [] = ""
       | fundefs (eqs :: eqss) =
           let
             fun fundef (name, eqs) =
               let
                 fun eqn (es, e) = name ^ " " ^ space_implode " " es ^ " = " ^ e
               in space_implode "\n  | " (map eqn eqs) end;
           in
             (prefix "fun " o fundef) eqs :: map (prefix "and " o fundef) eqss
             |> space_implode "\n"
             |> suffix "\n"
           end;


     (* runtime names *)

     local

     val Eval =              "Nbe_Eval.";
     val Eval_Constr =       Eval ^ "Constr";
     val Eval_apply =        Eval ^ "apply";
     val Eval_Fun =          Eval ^ "Fun";
     val Eval_lookup_fun =   Eval ^ "lookup_fun";

     in

     (* nbe specific syntax *)

     (*constructor c applied to args, via Nbe_Eval.Constr*)
     fun nbe_constr c args = app Eval_Constr (tup [string c, list args]);

     (*ML identifier for constant c: prefix "c_", dots replaced by underscores*)
     fun nbe_const c = "c_" ^ translate_string (fn "." => "_" | c => c) c;

     (*ML identifier for variable v*)
     fun nbe_free v = "v_" ^ v;

     (*apply e to es via Nbe_Eval.apply; the last element of es is applied
       first, matching the last-argument-first convention*)
     fun nbe_apps e es =
       Library.foldr (fn (s, e) => app (app Eval_apply e) s) (es, e);

     (*one-argument Nbe_Eval.Fun binding variable v in e*)
     fun nbe_abs (v, e) =
       app Eval_Fun (tup [abs (list [nbe_free v]) e, list [], "1"]);

     (*entry of the result list: c paired with its value; arity 0 means the
       function is called immediately, otherwise it is wrapped in a Fun that
       waits for n arguments*)
     fun nbe_fun (c, 0) = tup [string c, app (nbe_const c) (list [])]
       | nbe_fun (c, n) = tup [string c,
           app Eval_Fun (tup [nbe_const c, list [], string_of_int n])];

     (*binding of an already compiled global constant via Nbe_Eval.lookup_fun*)
     fun nbe_lookup c = Val (nbe_const c) (app Eval_lookup_fun (string c));

     end;

   end;
       
   179 
       
   180 
       
   181 (** assembling and compiling ML representation of terms **)
       
   182 
       
   (*print a term as ML code computing its semantic value.  local_arity c is
     SOME n for constants defined (with n arguments) in the group currently
     being compiled; is_global c holds for already compiled constants; every
     other constant is treated as a constructor*)
   fun assemble_term thy is_global local_arity =
     let
       (*closure around local function c that still waits for n arguments*)
       fun local_fun c n =
         S.app "Nbe_Eval.Fun" (S.tup [S.nbe_const c, S.list [], string_of_int n]);
       fun of_term t =
         let
           val (t', ts) = strip_comb t
           (*arguments are collected last-first, matching the calling convention*)
         in of_termapp t' (fold (cons o of_term) ts []) end
       and of_termapp (Const cexpr) ts =
             let
               val c = (CodegenNames.const thy o CodegenConsts.const_of_cexpr thy) cexpr;
             in case local_arity c
              of SOME n => if n <= length ts
                   then let val (args2, args1) = chop (length ts - n) ts
                   in S.nbe_apps (S.app (S.nbe_const c) (S.list args1)) args2
                   end
                   (*partial application: keep the local function computable by
                     wrapping it in a Fun awaiting its remaining arguments,
                     instead of freezing it into a constructor term*)
                   else S.nbe_apps (local_fun c n) ts
               | NONE => if is_global c then S.nbe_apps (S.nbe_const c) ts
                   else S.nbe_constr c ts
             end
         | of_termapp (Free (v, _)) ts = S.nbe_apps (S.nbe_free v) ts
         | of_termapp (Abs abs) ts =
             let
               val (v', t') = Syntax.variant_abs abs;
             in S.nbe_apps (S.nbe_abs (v', of_term t')) ts end;
     in of_term end;
       
   206 
       
   (*print the clauses of local function c: one clause per code equation, and
     a final catch-all clause returning c as a constructor applied to its
     arguments*)
   fun assemble_fun thy is_global local_arity (c, eqns) =
     let
       (*in patterns every constant is a constructor*)
       val pattern_of = assemble_term thy (K false) (K NONE);
       val body_of = assemble_term thy is_global local_arity;
       fun clause_of (args, rhs) =
         ([S.list (map pattern_of (rev args))], body_of rhs);
       val arity = the (local_arity c);
       val fallback_args = map S.nbe_free (Name.invent_list [] "a" arity);
       val fallback = ([S.list fallback_args], S.nbe_constr c fallback_args);
     in map clause_of eqns @ [fallback] end;
       
   217 
       
   (*compile one group of mutually recursive functions, given as (constant,
     code equations) pairs, and return the compiled (name, value) pairs.
     Constants without code equations are skipped uniformly (they are then
     treated as constructors), for groups of any size -- previously only a
     singleton group was handled and larger groups raised Match*)
   fun compile thy is_global fundefs =
     let
       val defs = filter_out (null o snd) fundefs;
     in
       if null defs then []
       else
         let
           val eqnss = (map o apsnd o map) (apfst (snd o strip_comb) o Logic.dest_equals
             o Logic.unvarify o prop_of) defs;
           val cs = map fst eqnss;
           (*arity of each function = number of arguments of its first equation*)
           val arities = map (fn (c, eqns) => (c, (length o fst o hd) eqns)) eqnss;
           (*constants occurring on right-hand sides*)
           val used_cs = fold (fold (fold_aterms (fn Const cexpr =>
             insert (op =) ((CodegenNames.const thy o CodegenConsts.const_of_cexpr thy) cexpr)
               | _ => I)) o map snd o snd) eqnss [];
           val bind_globals = map S.nbe_lookup (filter is_global used_cs);
           val bind_locals = S.fundefs (map S.nbe_const cs ~~ map
             (assemble_fun thy is_global (AList.lookup (op =) arities)) eqnss);
           val result = S.list (map S.nbe_fun arities);
         in compile_univs (S.Let (bind_globals @ [bind_locals]) result) end
     end;
       
   234 
       
   235 
       
   236 (** evaluation with greetings to Tarski **)
       
   237 
       
   238 (* conversion and evaluation *)
       
   239 
       
   (*evaluate a term to a semantic value: constants are looked up via
     lookup_fun (falling back to constructors), free variables in the
     environment built by enclosing abstractions (falling back to Var),
     abstractions become one-argument Funs*)
   fun univ_of_term thy lookup_fun =
     let
       fun eval_head env (Const cexpr) =
             let
               val name = (CodegenNames.const thy o CodegenConsts.const_of_cexpr thy) cexpr;
             in
               case lookup_fun name of
                 SOME u => u
               | NONE => Constr (name, [])
             end
         | eval_head env (Free (v, _)) =
             (case AList.lookup (op =) env v of
               SOME u => u
             | NONE => Var (v, []))
         | eval_head env (Abs abs) =
             let
               val (v', body) = Syntax.variant_abs abs;
             in Fun (fn [x] => eval_term (AList.update (op =) (v', x) env) body, [], 1) end
       and eval_term env t =
         let
           val (head, args) = strip_comb t;
           val head_val = eval_head env head;
           val arg_vals = map (eval_term env) args;
         in List.foldl (fn (u, acc) => apply acc u) head_val arg_vals end;
     in eval_term [] end;
       
   260 
       
   261 
       
   262 (* ensure global functions *)
       
   263 
       
   (*compile every function that t depends on according to funcgr and that is
     not yet in the Nbe_Data table, one dependency group at a time, storing the
     results via Nbe_Data.change; returns the final table*)
   fun ensure_funs thy funcgr t =
     let
       val consts =
         fold_aterms (fn Const c => cons (CodegenConsts.const_of_cexpr thy c) | _ => I) t [];
       val nbe_tab = Nbe_Data.get thy;
       fun is_new c = not (Symtab.defined nbe_tab (CodegenNames.const thy c));
       fun with_eqns c = (CodegenNames.const thy c, CodegenFuncgr.funcs funcgr c);
       fun compile_group eqs tab =
         let
           (*make the functions compiled so far visible to the generated code*)
           val _ = tab_ref := tab;
           val compiled = compile thy (Symtab.defined tab) eqs;
         in Nbe_Data.change thy (fold Symtab.update compiled) end;
       val groups =
         CodegenFuncgr.deps funcgr consts
         |> map (filter is_new)
         |> filter_out null
         |> map (map with_eqns)
         |> tracing (fn funs => "new definitions: " ^ (commas o maps (map fst)) funs);
     in fold compile_group groups nbe_tab end;
       
   283 
       
   284 
       
   285 (* re-conversion *)
       
   286 
       
   (*read a semantic value back as a term.  Constructors become constants at
     their default type, with type variables renamed to inference parameters
     offset by a running index; free and bound variables are read back
     directly (bound ones from binder level to de Bruijn index); functions are
     applied to a fresh bound variable and wrapped in an abstraction*)
   fun term_of_univ thy t =
     let
       (*read back args (stored last-first) and apply head to them in order*)
       fun app_args lvl (head, args) idx =
         let
           val (arg_terms, idx') = fold_map (read lvl) args idx;
         in (list_comb (head, rev arg_terms), idx') end
       and read lvl (Constr (name, args)) idx =
             let
               val SOME (const as (c, _)) = CodegenNames.const_rev thy name;
               val ty = CodegenData.default_typ thy const;
               val ty' = map_type_tvar
                 (fn ((v, i), srt) => TypeInfer.param (idx + i) (v, srt)) ty;
             in app_args lvl (Const (c, ty'), args) (idx + maxidx_of_typ ty' + 1) end
         | read lvl (Var (name, args)) idx =
             app_args lvl (Free (name, dummyT), args) idx
         | read lvl (BVar (level, args)) idx =
             app_args lvl (Bound (lvl - level - 1), args) idx
         | read lvl (f as Fun _) idx =
             let
               val (body, idx') = read (lvl + 1) (apply f (BVar (lvl, []))) idx;
             in (Abs ("u", dummyT, body), idx') end;
     in fst (read 0 t 0) end;
       
   308 
       
   309 
       
   310 (* interface *)
       
   311 
       
   (*normalise t: compile all functions it depends on, evaluate it in the
     semantic universe and read the value back as a term*)
   fun eval thy funcgr t =
     let
       val compiled = ensure_funs thy funcgr t;
     in
       t
       |> univ_of_term thy (Symtab.lookup compiled)
       |> term_of_univ thy
     end;
       
   317 
       
   318 end;