src/Tools/nbe.ML
changeset 25924 f974a1c64348
parent 25865 a141d6bfd398
child 25935 ce3cd5f0c4ee
equal deleted inserted replaced
25923:5fe4b543512e 25924:f974a1c64348
    10   val norm_conv: cterm -> thm
    10   val norm_conv: cterm -> thm
    11   val norm_term: theory -> term -> term
    11   val norm_term: theory -> term -> term
    12 
    12 
    13   datatype Univ =
    13   datatype Univ =
    14       Const of string * Univ list            (*named (uninterpreted) constants*)
    14       Const of string * Univ list            (*named (uninterpreted) constants*)
    15     | Free of string * Univ list
    15     | Free of string * Univ list             (*free (uninterpreted) variables*)
    16     | DFree of string                        (*free (uninterpreted) dictionary parameters*)
    16     | DFree of string * int                  (*free (uninterpreted) dictionary parameters*)
    17     | BVar of int * Univ list
    17     | BVar of int * Univ list
    18     | Abs of (int * (Univ list -> Univ)) * Univ list;
    18     | Abs of (int * (Univ list -> Univ)) * Univ list;
    19   val free: string -> Univ                   (*free (uninterpreted) variables*)
    19   val apps: Univ -> Univ list -> Univ        (*explicit applications*)
    20   val app: Univ -> Univ -> Univ              (*explicit application*)
       
    21   val abs: int -> (Univ list -> Univ) -> Univ
    20   val abs: int -> (Univ list -> Univ) -> Univ
    22                                              (*abstractions as closures*)
    21                                             (*abstractions as closures*)
    23 
    22 
    24   val univs_ref: (unit -> Univ list -> Univ list) option ref
    23   val univs_ref: (unit -> Univ list -> Univ list) option ref
       
    24   val norm_invoke: theory -> CodeThingol.code -> term
       
    25     -> CodeThingol.typscheme * CodeThingol.iterm -> string list -> thm
    25   val trace: bool ref
    26   val trace: bool ref
       
    27 
    26   val setup: theory -> theory
    28   val setup: theory -> theory
    27 end;
    29 end;
    28 
    30 
    29 structure Nbe: NBE =
    31 structure Nbe: NBE =
    30 struct
    32 struct
    56 *)
    58 *)
    57 
    59 
    58 datatype Univ =
    60 datatype Univ =
    59     Const of string * Univ list        (*named (uninterpreted) constants*)
    61     Const of string * Univ list        (*named (uninterpreted) constants*)
    60   | Free of string * Univ list         (*free variables*)
    62   | Free of string * Univ list         (*free variables*)
    61   | DFree of string                    (*free (uninterpreted) dictionary parameters*)
    63   | DFree of string * int              (*free (uninterpreted) dictionary parameters*)
    62   | BVar of int * Univ list            (*bound named variables*)
    64   | BVar of int * Univ list            (*bound named variables*)
    63   | Abs of (int * (Univ list -> Univ)) * Univ list
    65   | Abs of (int * (Univ list -> Univ)) * Univ list
    64                                       (*abstractions as closures*);
    66                                       (*abstractions as closures*);
    65 
    67 
    66 (* constructor functions *)
    68 (* constructor functions *)
    67 
    69 
    68 fun free v = Free (v, []);
       
    69 fun abs n f = Abs ((n, f), []);
    70 fun abs n f = Abs ((n, f), []);
    70 fun app (Abs ((1, f), xs)) x = f (x :: xs)
    71 fun apps (Abs ((n, f), xs)) ys = let val k = n - length ys in
    71   | app (Abs ((n, f), xs)) x = Abs ((n - 1, f), x :: xs)
    72       if k = 0 then f (ys @ xs)
    72   | app (Const (name, args)) x = Const (name, x :: args)
    73       else if k < 0 then
    73   | app (Free (name, args)) x = Free (name, x :: args)
    74         let val (zs, ws) = chop (~ k) ys
    74   | app (BVar (name, args)) x = BVar (name, x :: args);
    75         in apps (f (ws @ xs)) zs end
       
    76       else Abs ((k, f), ys @ xs) end (*note: reverse convention also for apps!*)
       
    77   | apps (Const (name, xs)) ys = Const (name, ys @ xs)
       
    78   | apps (Free (name, xs)) ys = Free (name, ys @ xs)
       
    79   | apps (BVar (name, xs)) ys = BVar (name, ys @ xs);
    75 
    80 
    76 (* universe graph *)
    81 (* universe graph *)
    77 
    82 
    78 type univ_gr = Univ option Graph.T;
    83 type univ_gr = Univ option Graph.T;
    79 val compiled : univ_gr -> string -> bool = can o Graph.get_node;
    84 val compiled : univ_gr -> string -> bool = can o Graph.get_node;
   112 (* nbe specific syntax *)
   117 (* nbe specific syntax *)
   113 
   118 
   114 local
   119 local
   115   val prefix =          "Nbe.";
   120   val prefix =          "Nbe.";
   116   val name_const =      prefix ^ "Const";
   121   val name_const =      prefix ^ "Const";
   117   val name_free =       prefix ^ "free";
       
   118   val name_dfree =      prefix ^ "DFree";
       
   119   val name_abs =        prefix ^ "abs";
   122   val name_abs =        prefix ^ "abs";
   120   val name_app =        prefix ^ "app";
   123   val name_apps =       prefix ^ "apps";
   121   val name_lookup_fun = prefix ^ "lookup_fun";
       
   122 in
   124 in
   123 
   125 
   124 fun nbe_const c ts =
   126 fun nbe_fun' c = "c_" ^ translate_string (fn "." => "_" | c => c) c;
   125   name_const `$` ("(" ^ ML_Syntax.print_string c ^ ", " ^ ml_list ts ^ ")");
   127 val nbe_fun = nbe_fun'; (*FIXME!*)
   126 fun nbe_fun c = "c_" ^ translate_string (fn "." => "_" | c => c) c;
       
   127 fun nbe_free v = name_free `$` ML_Syntax.print_string v;
       
   128 fun nbe_dfree v = name_dfree `$` ML_Syntax.print_string v;
       
   129 fun nbe_dict v n = "d_" ^ v ^ "_" ^ string_of_int n;
   128 fun nbe_dict v n = "d_" ^ v ^ "_" ^ string_of_int n;
   130 fun nbe_bound v = "v_" ^ v;
   129 fun nbe_bound v = "v_" ^ v;
   131 
   130 val nbe_value = "";
   132 fun nbe_apps e es =
   131 
   133   Library.foldr (fn (s, e) => name_app `$$` [e, s]) (es, e);
   132 (*note: these three are the "turning spots" where proper argument order is established!*)
       
   133 fun nbe_apps t [] = t
       
   134   | nbe_apps t ts = name_apps `$$` [t, ml_list (rev ts)];
       
   135 fun nbe_apps_local c ts = nbe_fun c `$` ml_list (rev ts);
       
   136 fun nbe_apps_constr c ts =
       
   137   name_const `$` ("(" ^ ML_Syntax.print_string c ^ ", " ^ ml_list (rev ts) ^ ")");
       
   138 
   134 
   139 
   135 fun nbe_abss 0 f = f `$` ml_list []
   140 fun nbe_abss 0 f = f `$` ml_list []
   136   | nbe_abss n f = name_abs `$$` [string_of_int n, f];
   141   | nbe_abss n f = name_abs `$$` [string_of_int n, f];
   137 
       
   138 val nbe_value = "value";
       
   139 
   142 
   140 end;
   143 end;
   141 
   144 
   142 open BasicCodeThingol;
   145 open BasicCodeThingol;
   143 
   146 
   152       Output.tracing o enclose "\n--- compiler echo (with error):\n" "\n---\n")
   155       Output.tracing o enclose "\n--- compiler echo (with error):\n" "\n---\n")
   153       (!trace) ("Nbe.univs_ref", univs_ref);
   156       (!trace) ("Nbe.univs_ref", univs_ref);
   154 
   157 
   155 (* code generation *)
   158 (* code generation *)
   156 
   159 
   157 fun assemble_idict (DictConst (inst, dss)) =
   160 datatype const_kind = Local of int | Global | Constr;
   158       nbe_apps (nbe_fun inst) ((maps o map) assemble_idict dss)
   161 
   159   | assemble_idict (DictVar (supers, (v, (n, _)))) =
   162 fun assemble_constapp kind c dss ts = 
   160       fold_rev (fn super => nbe_apps (nbe_fun super) o single) supers (nbe_dict v n);
   163       let
   161 
   164         val ts' = (maps o map) (assemble_idict kind) dss @ ts;
   162 fun assemble_iterm is_fun num_args =
   165       in case kind c
       
   166        of Local n => if n <= length ts'
       
   167             then let val (ts1, ts2) = chop n ts'
       
   168             in nbe_apps (nbe_apps_local c ts1) ts2
       
   169             end else nbe_apps (nbe_abss n (nbe_fun c)) ts'
       
   170         | Global => nbe_apps (nbe_fun c) ts'
       
   171         | Constr => nbe_apps_constr c ts'
       
   172       end
       
   173 and assemble_idict kind (DictConst (inst, dss)) =
       
   174       assemble_constapp kind inst dss []
       
   175   | assemble_idict kind (DictVar (supers, (v, (n, _)))) =
       
   176       fold_rev (fn super => assemble_constapp kind super [] o single) supers (nbe_dict v n);
       
   177 
       
   178 fun assemble_iterm kind =
   163   let
   179   let
   164     fun of_iterm t =
   180     fun of_iterm t =
   165       let
   181       let
   166         val (t', ts) = CodeThingol.unfold_app t
   182         val (t', ts) = CodeThingol.unfold_app t
   167       in of_iapp t' (fold (cons o of_iterm) ts []) end
   183       in of_iapp t' (fold_rev (cons o of_iterm) ts []) end
   168     and of_iconst c ts = case num_args c
   184     and of_iapp (IConst (c, (dss, _))) ts = assemble_constapp kind c dss ts
   169      of SOME n => if n <= length ts
       
   170           then let val (args2, args1) = chop (length ts - n) ts
       
   171           in nbe_apps (nbe_fun c `$` ml_list args1) args2
       
   172           end else nbe_const c ts
       
   173       | NONE => if is_fun c then nbe_apps (nbe_fun c) ts
       
   174           else nbe_const c ts
       
   175     and of_iapp (IConst (c, (dss, _))) ts = of_iconst c
       
   176           (ts @ rev ((maps o map) assemble_idict dss))
       
   177       | of_iapp (IVar v) ts = nbe_apps (nbe_bound v) ts
   185       | of_iapp (IVar v) ts = nbe_apps (nbe_bound v) ts
   178       | of_iapp ((v, _) `|-> t) ts =
   186       | of_iapp ((v, _) `|-> t) ts =
   179           nbe_apps (nbe_abss 1 (ml_abs (ml_list [nbe_bound v]) (of_iterm t))) ts
   187           nbe_apps (nbe_abss 1 (ml_abs (ml_list [nbe_bound v]) (of_iterm t))) ts
   180       | of_iapp (ICase (((t, _), cs), t0)) ts =
   188       | of_iapp (ICase (((t, _), cs), t0)) ts =
   181           nbe_apps (ml_cases (of_iterm t) (map (pairself of_iterm) cs
   189           nbe_apps (ml_cases (of_iterm t) (map (pairself of_iterm) cs
   182             @ [("_", of_iterm t0)])) ts
   190             @ [("_", of_iterm t0)])) ts
   183   in of_iterm end;
   191   in of_iterm end;
   184 
   192 
   185 fun assemble_fun gr num_args (c, (vs, eqns)) =
   193 fun assemble_eqns kind (c, (vs, eqns)) =
   186   let
   194   let
   187     val assemble_arg = assemble_iterm (K false) (K NONE);
   195     val dict_args = maps (fn (v, sort) => map_index (nbe_dict v o fst) sort) vs;
   188     val assemble_rhs = assemble_iterm (is_some o Graph.get_node gr) num_args;
   196     val assemble_arg = assemble_iterm (K Constr);
   189     val dict_params = maps (fn (v, sort) => map_index (nbe_dict v o fst) sort) vs
   197     val assemble_rhs = assemble_iterm kind;
   190       |> rev;
       
   191     fun assemble_eqn (args, rhs) =
   198     fun assemble_eqn (args, rhs) =
   192       ([ml_list (map assemble_arg (rev args) @ dict_params)], assemble_rhs rhs);
   199       ([ml_list (rev (dict_args @ map assemble_arg args))], assemble_rhs rhs);
   193     val default_params = map nbe_bound (Name.invent_list [] "a" ((the o num_args) c));
   200     val default_args = dict_args @ map nbe_bound (Name.invent_list [] "a" ((length o fst o hd) eqns));
   194     val default_eqn = ([ml_list default_params], nbe_const c default_params);
   201     val default_eqn = ([ml_list (rev default_args)], nbe_apps_constr c default_args);
   195   in map assemble_eqn eqns @ [default_eqn] end;
   202   in (nbe_fun' c, map assemble_eqn eqns @ [default_eqn]) end;
   196 
   203 
   197 fun assemble_eqnss gr deps [] = ([], ("", []))
   204 fun assemble_eqnss gr deps [] = ([], ("", []))
   198   | assemble_eqnss gr deps eqnss =
   205   | assemble_eqnss gr deps eqnss =
   199       let
   206       let
   200         val cs = map fst eqnss;
   207         val cs = map fst eqnss;
   201         val num_args = cs ~~ map (fn (_, (vs, (args, rhs) :: _)) =>
   208         val num_args = cs ~~ map (fn (_, (vs, (args, rhs) :: _)) =>
   202           length (maps snd vs) + length args) eqnss;
   209           length (maps snd vs) + length args) eqnss;
       
   210         fun kind c = case AList.lookup (op =) num_args c
       
   211          of SOME n => Local n
       
   212           | NONE => if (is_some o Option.join o try (Graph.get_node gr)) c
       
   213               then Global else Constr;
   203         val deps' = filter (is_some o Option.join o try (Graph.get_node gr)) deps;
   214         val deps' = filter (is_some o Option.join o try (Graph.get_node gr)) deps;
   204         val bind_deps = ml_list (map nbe_fun deps');
   215         val bind_deps = ml_list (map nbe_fun' deps');
   205         val bind_locals = ml_fundefs (map nbe_fun cs ~~ map
   216         val bind_locals = ml_fundefs (map (assemble_eqns kind) eqnss);
   206           (assemble_fun gr (AList.lookup (op =) num_args)) eqnss);
       
   207         val result = ml_list (map (fn (c, n) => nbe_abss n (nbe_fun c)) num_args);
   217         val result = ml_list (map (fn (c, n) => nbe_abss n (nbe_fun c)) num_args);
   208         val arg_deps = map (the o Graph.get_node gr) deps';
   218         val arg_deps = map (the o Graph.get_node gr) deps';
   209       in (cs, (ml_abs bind_deps (ml_Let [bind_locals] result), arg_deps)) end;
   219       in (cs, (ml_abs bind_deps (ml_Let [bind_locals] result), arg_deps)) end;
   210 
   220 
   211 fun compile_eqnss gr deps eqnss = case assemble_eqnss gr deps eqnss
   221 fun compile_eqnss gr deps eqnss = case assemble_eqnss gr deps eqnss
   261     fun add_stmts names gr = if exists (compiled gr) names then gr else gr
   271     fun add_stmts names gr = if exists (compiled gr) names then gr else gr
   262       |> compile_stmts (map (fn name => ((name, Graph.get_node code name),
   272       |> compile_stmts (map (fn name => ((name, Graph.get_node code name),
   263           Graph.imm_succs code name)) names);
   273           Graph.imm_succs code name)) names);
   264   in fold_rev add_stmts (Graph.strong_conn code) end;
   274   in fold_rev add_stmts (Graph.strong_conn code) end;
   265 
   275 
   266 fun assemble_eval gr deps ((vs, ty), t) =
   276 fun eval_term gr deps ((vs, ty), t) =
   267   let
   277   let 
   268     val frees = CodeThingol.fold_unbound_varnames (insert (op =)) t [];
   278     val frees = CodeThingol.fold_unbound_varnames (insert (op =)) t []
   269     val deps' = filter (is_some o Option.join o try (Graph.get_node gr)) deps;
   279     val frees' = map (fn v => Free (v, [])) frees;
   270     val bind_deps = ml_list (map nbe_fun deps');
   280     val dict_frees = maps (fn (v, sort) => map_index (curry DFree v o fst) sort) vs;
   271     val dict_params = maps (fn (v, sort) => map_index (nbe_dict v o fst) sort) vs
   281   in
   272       |> rev;
   282     (nbe_value, (vs, [(map IVar frees, t)]))
   273     val bind_value = ml_fundefs [(nbe_value,
   283     |> singleton (compile_eqnss gr deps)
   274       [([ml_list (map nbe_bound frees @ dict_params)],
   284     |> snd
   275         assemble_iterm (is_some o Graph.get_node gr) (K NONE) t)])];
   285     |> (fn t => apps t (rev (dict_frees @ frees')))
   276     val result = ml_list [nbe_value `$` ml_list
   286   end;
   277       (map nbe_free frees @ map nbe_dfree dict_params)];
       
   278     val arg_deps = map (the o Graph.get_node gr) deps';
       
   279   in (ml_abs bind_deps (ml_Let [bind_value] result), arg_deps) end;
       
   280 
       
   281 fun eval_term gr deps t' =
       
   282   let
       
   283     val (s, args) = assemble_eval gr deps t';
       
   284   in the_single (compile s args) end;
       
   285 
   287 
   286 
   288 
   287 (** evaluation **)
   289 (** evaluation **)
   288 
   290 
   289 (* reification *)
   291 (* reification *)
   313           of_apps bounds (Term.Free (name, dummyT), ts) typidx
   315           of_apps bounds (Term.Free (name, dummyT), ts) typidx
   314       | of_univ bounds (BVar (name, ts)) typidx =
   316       | of_univ bounds (BVar (name, ts)) typidx =
   315           of_apps bounds (Bound (bounds - name - 1), ts) typidx
   317           of_apps bounds (Bound (bounds - name - 1), ts) typidx
   316       | of_univ bounds (t as Abs _) typidx =
   318       | of_univ bounds (t as Abs _) typidx =
   317           typidx
   319           typidx
   318           |> of_univ (bounds + 1) (app t (BVar (bounds, [])))
   320           |> of_univ (bounds + 1) (apps t [BVar (bounds, [])])
   319           |-> (fn t' => pair (Term.Abs ("u", dummyT, t')))
   321           |-> (fn t' => pair (Term.Abs ("u", dummyT, t')))
   320   in of_univ 0 t 0 |> fst end;
   322   in of_univ 0 t 0 |> fst end;
   321 
   323 
   322 (* function store *)
   324 (* function store *)
   323 
   325