src/Tools/Code/code_runtime.ML
changeset 64959 9ca021bd718d
parent 64958 85b87da722ab
child 64987 1985502518ce
     1.1 --- a/src/Tools/Code/code_runtime.ML	Thu Jan 26 16:25:32 2017 +0100
     1.2 +++ b/src/Tools/Code/code_runtime.ML	Fri Jan 27 22:27:03 2017 +0100
     1.3 @@ -28,10 +28,8 @@
     1.4      -> Proof.context -> term -> 'a Exn.result
     1.5    val dynamic_holds_conv: Proof.context -> conv
     1.6    val static_holds_conv: { ctxt: Proof.context, consts: string list } -> Proof.context -> conv
     1.7 -  val experimental_computation: (term -> 'a) cookie
     1.8 -    -> { ctxt: Proof.context, lift_postproc: (term -> term) -> 'a -> 'a,
     1.9 -           terms: term list, T: typ }
    1.10 -    -> Proof.context -> term -> 'a (*EXPERIMENTAL!*)
    1.11 +  val mount_computation: Proof.context -> (string * typ) list -> typ
    1.12 +    -> (term -> 'b) -> ((term -> term) -> 'b -> 'a) -> Proof.context -> term -> 'a
    1.13    val code_reflect: (string * string list option) list -> string list -> string
    1.14      -> string option -> theory -> theory
    1.15    datatype truth = Holds
    1.16 @@ -49,9 +47,9 @@
    1.17  
    1.18  (* technical prerequisites *)
    1.19  
    1.20 -val this = "Code_Runtime";
    1.21 -val s_truth = Long_Name.append this "truth";
    1.22 -val s_Holds = Long_Name.append this "Holds";
    1.23 +val thisN = "Code_Runtime";
    1.24 +val truthN = Long_Name.append thisN "truth";
    1.25 +val HoldsN = Long_Name.append thisN "Holds";
    1.26  
    1.27  val target = "Eval";
    1.28  
    1.29 @@ -60,10 +58,10 @@
    1.30  val _ = Theory.setup
    1.31    (Code_Target.add_derived_target (target, [(Code_ML.target_SML, I)])
    1.32    #> Code_Target.set_printings (Type_Constructor (@{type_name prop},
    1.33 -    [(target, SOME (0, (K o K o K) (Code_Printer.str s_truth)))]))
    1.34 +    [(target, SOME (0, (K o K o K) (Code_Printer.str truthN)))]))
    1.35    #> Code_Target.set_printings (Constant (@{const_name Code_Generator.holds},
    1.36 -    [(target, SOME (Code_Printer.plain_const_syntax s_Holds))]))
    1.37 -  #> Code_Target.add_reserved target this
    1.38 +    [(target, SOME (Code_Printer.plain_const_syntax HoldsN))]))
    1.39 +  #> Code_Target.add_reserved target thisN
    1.40    #> fold (Code_Target.add_reserved target) ["oo", "ooo", "oooo", "upto", "downto", "orf", "andf"]);
    1.41         (*avoid further pervasive infix names*)
    1.42  
    1.43 @@ -153,7 +151,7 @@
    1.44    fun init _ = empty;
    1.45  );
    1.46  val put_truth = Truth_Result.put;
    1.47 -val truth_cookie = (Truth_Result.get, put_truth, Long_Name.append this "put_truth");
    1.48 +val truth_cookie = (Truth_Result.get, put_truth, Long_Name.append thisN "put_truth");
    1.49  
    1.50  local
    1.51  
    1.52 @@ -231,6 +229,8 @@
    1.53  
    1.54  in
    1.55  
    1.56 +val covered_constsN = "covered_consts";
    1.57 +
    1.58  fun of_term_for_typ Ts =
    1.59    let
    1.60      val names = Ts
    1.61 @@ -277,6 +277,8 @@
    1.62  
    1.63  (* code generation for of the universal morphism *)
    1.64  
    1.65 +val print_const = ML_Syntax.print_pair ML_Syntax.print_string ML_Syntax.print_typ;
    1.66 +
    1.67  fun print_of_term_funs { typ_sign_for, eval_for_const, of_term_for_typ } Ts =
    1.68    let
    1.69      val var_names = map_range (fn n => "t" ^ string_of_int (n + 1));
    1.70 @@ -305,42 +307,47 @@
    1.71      |> prefix "fun "
    1.72    end;
    1.73  
    1.74 -fun print_computation_code ctxt compiled_value requested_Ts cTs =
    1.75 -  let
    1.76 -    val proper_cTs = map_filter I cTs;
    1.77 -    val typ_sign_for = typ_signatures proper_cTs;
    1.78 -    fun add_typ T Ts =
    1.79 -      if member (op =) Ts T
    1.80 -      then Ts
    1.81 -      else Ts
    1.82 -        |> cons T
    1.83 -        |> fold (fold add_typ o snd) (typ_sign_for T);
    1.84 -    val required_Ts = fold add_typ requested_Ts [];
    1.85 -    val of_term_for_typ' = of_term_for_typ required_Ts;
    1.86 -    val eval_for_const' = eval_for_const ctxt proper_cTs;
    1.87 -    val eval_for_const'' = the_default "_" o Option.map eval_for_const';
    1.88 -    val eval_tuple = enclose "(" ")" (commas (map eval_for_const' proper_cTs));
    1.89 -    fun mk_abs s = "fn " ^ s ^ " => ";
    1.90 -    val eval_abs = space_implode ""
    1.91 -      (map (mk_abs o eval_for_const'') cTs);
    1.92 -    val of_term_code = print_of_term_funs {
    1.93 -      typ_sign_for = typ_sign_for,
    1.94 -      eval_for_const = eval_for_const',
    1.95 -      of_term_for_typ = of_term_for_typ' } required_Ts;
    1.96 -  in
    1.97 -    (cat_lines [
    1.98 -      "structure " ^ generated_computationN ^ " =",
    1.99 -      "struct",
   1.100 -      "",
   1.101 -      "val " ^ eval_tuple ^ " = " ^ compiled_value ^ " ()",
   1.102 -      "  (" ^ eval_abs,
   1.103 -      "    " ^ eval_tuple ^ ");",
   1.104 -      "",
   1.105 -      of_term_code,
   1.106 -      "",
   1.107 -      "end"
   1.108 -    ], map (prefix (generated_computationN ^ ".") o of_term_for_typ') requested_Ts)
   1.109 -  end;
   1.110 +fun print_computation_code ctxt compiled_value [] requested_Ts = ("", [])
   1.111 +  | print_computation_code ctxt compiled_value cTs requested_Ts =
   1.112 +      let
   1.113 +        val proper_cTs = map_filter I cTs;
   1.114 +        val typ_sign_for = typ_signatures proper_cTs;
   1.115 +        fun add_typ T Ts =
   1.116 +          if member (op =) Ts T
   1.117 +          then Ts
   1.118 +          else Ts
   1.119 +            |> cons T
   1.120 +            |> fold (fold add_typ o snd) (typ_sign_for T);
   1.121 +        val required_Ts = fold add_typ requested_Ts [];
   1.122 +        val of_term_for_typ' = of_term_for_typ required_Ts;
   1.123 +        val eval_for_const' = eval_for_const ctxt proper_cTs;
   1.124 +        val eval_for_const'' = the_default "_" o Option.map eval_for_const';
   1.125 +        val eval_tuple = enclose "(" ")" (commas (map eval_for_const' proper_cTs));
   1.126 +        fun mk_abs s = "fn " ^ s ^ " => ";
   1.127 +        val eval_abs = space_implode ""
   1.128 +          (map (mk_abs o eval_for_const'') cTs);
   1.129 +        val of_term_code = print_of_term_funs {
   1.130 +          typ_sign_for = typ_sign_for,
   1.131 +          eval_for_const = eval_for_const',
   1.132 +          of_term_for_typ = of_term_for_typ' } required_Ts;
   1.133 +        val of_term_names = map (Long_Name.append generated_computationN
   1.134 +          o of_term_for_typ') requested_Ts;
   1.135 +      in
   1.136 +        cat_lines [
   1.137 +          "structure " ^ generated_computationN ^ " =",
   1.138 +          "struct",
   1.139 +          "",
   1.140 +          "val " ^ covered_constsN ^ " = " ^ ML_Syntax.print_list print_const proper_cTs ^ ";",
   1.141 +          "",
   1.142 +          "val " ^ eval_tuple ^ " = " ^ compiled_value ^ " ()",
   1.143 +          "  (" ^ eval_abs,
   1.144 +          "    " ^ eval_tuple ^ ");",
   1.145 +          "",
   1.146 +          of_term_code,
   1.147 +          "",
   1.148 +          "end"
   1.149 +        ] |> rpair of_term_names
   1.150 +      end;
   1.151  
   1.152  fun mount_computation ctxt cTs T raw_computation lift_postproc =
   1.153    Code_Preproc.static_value { ctxt = ctxt, lift_postproc = lift_postproc, consts = [] }
   1.154 @@ -350,48 +357,12 @@
   1.155        #> check_computation_input ctxt' cTs
   1.156        #> raw_computation);
   1.157  
   1.158 -fun compile_computation cookie ctxt T program evals vs_ty_evals deps =
   1.159 -  let
   1.160 -    fun the_const (Const cT) = cT
   1.161 -      | the_const t = error ("No constant after preprocessing: " ^ Syntax.string_of_term ctxt t)
   1.162 -    val raw_cTs = case evals of
   1.163 -        Abs (_, _, t) => (map the_const o snd o strip_comb) t
   1.164 -      | _ => error ("Bad term after preprocessing: " ^ Syntax.string_of_term ctxt evals);
   1.165 -    val cTs = fold_rev (fn cT => fn cTs =>
   1.166 -      (if member (op =) cTs (SOME cT) then NONE else SOME cT) :: cTs) raw_cTs [];
   1.167 -    fun comp' vs_ty_evals =
   1.168 -      let
   1.169 -        val (generated_code, compiled_value) =
   1.170 -          build_compilation_text ctxt NONE program deps vs_ty_evals;
   1.171 -        val (of_term_code, [of_term]) = print_computation_code ctxt compiled_value [T] cTs;
   1.172 -      in
   1.173 -        (generated_code ^ "\n" ^ of_term_code,
   1.174 -          enclose "(" ")" ("fn () => " ^ of_term))
   1.175 -      end;
   1.176 -    val compiled_computation =
   1.177 -      Exn.release (run_compilation_text cookie ctxt comp' vs_ty_evals []);
   1.178 -  in (map_filter I cTs, compiled_computation) end;
   1.179 +
   1.180 +(** variants of universal runtime code generation **)
   1.181  
   1.182 -fun experimental_computation cookie { ctxt, lift_postproc, terms = ts, T } =
   1.183 -  let
   1.184 -    val _ = if not (monomorphic T)
   1.185 -      then error ("Polymorphic type: " ^ Syntax.string_of_typ ctxt T)
   1.186 -      else ();
   1.187 -    val cTs = (fold o fold_aterms)
   1.188 -      (fn (t as Const (cT as (_, T))) =>
   1.189 -        if not (monomorphic T) then error ("Polymorphic constant:" ^ Syntax.string_of_term ctxt t)
   1.190 -        else insert (op =) cT | _ => I) ts [];
   1.191 -    val evals = Abs ("eval", map snd cTs ---> TFree (Name.aT, []), list_comb (Bound 0, map Const cTs));
   1.192 -    val (cTs, raw_computation) = Code_Thingol.dynamic_value ctxt
   1.193 -      (K I) (compile_computation cookie ctxt T) evals;
   1.194 -  in
   1.195 -    mount_computation ctxt cTs T raw_computation lift_postproc
   1.196 -  end;
   1.197 +(*FIXME consolidate variants*)
   1.198  
   1.199 -
   1.200 -(** code antiquotation **)
   1.201 -
   1.202 -fun runtime_code ctxt module_name program tycos consts =
   1.203 +fun runtime_code'' ctxt module_name program tycos consts =
   1.204    let
   1.205      val thy = Proof_Context.theory_of ctxt;
   1.206      val (ml_modules, target_names) =
   1.207 @@ -411,63 +382,159 @@
   1.208          | SOME tyco' => (tyco, tyco')) tycos tycos';
   1.209    in (ml_code, (tycos_map, consts_map)) end;
   1.210  
   1.211 +fun runtime_code' ctxt some_module_name named_tycos named_consts computation_Ts program evals vs_ty_evals deps =
   1.212 +  let
   1.213 +    val thy = Proof_Context.theory_of ctxt;
   1.214 +    fun the_const (Const cT) = cT
   1.215 +      | the_const t = error ("No constant after preprocessing: " ^ Syntax.string_of_term ctxt t)
   1.216 +    val raw_computation_cTs = case evals of
   1.217 +        Abs (_, _, t) => (map the_const o snd o strip_comb) t
   1.218 +      | _ => error ("Bad term after preprocessing: " ^ Syntax.string_of_term ctxt evals);
   1.219 +    val computation_cTs = fold_rev (fn cT => fn cTs =>
   1.220 +      (if member (op =) cTs (SOME cT) then NONE else SOME cT) :: cTs) raw_computation_cTs [];
   1.221 +    val consts' = fold (fn NONE => I | SOME (c, _) => insert (op =) c)
   1.222 +      computation_cTs named_consts;
   1.223 +    val program' = Code_Thingol.consts_program ctxt consts';
   1.224 +      (*FIXME insufficient interfaces require double invocation of code generator*)
   1.225 +    val ((ml_modules, compiled_value), deresolve) =
   1.226 +      Code_Target.compilation_text' ctxt target some_module_name program'
   1.227 +        (map Code_Symbol.Type_Constructor named_tycos @ map Code_Symbol.Constant consts' @ deps) true vs_ty_evals;
   1.228 +        (*FIXME constrain signature*)
   1.229 +    fun deresolve_const c = case (deresolve o Code_Symbol.Constant) c of
   1.230 +          NONE => error ("Constant " ^ (quote o Code.string_of_const thy) c ^
   1.231 +            "\nhas a user-defined serialization")
   1.232 +        | SOME c' => c';
   1.233 +    fun deresolve_tyco tyco = case (deresolve o Code_Symbol.Type_Constructor) tyco of
   1.234 +          NONE => error ("Type " ^ quote (Proof_Context.markup_type ctxt tyco) ^
   1.235 +            "\nhas a user-defined serialization")
   1.236 +        | SOME c' => c';
   1.237 +    val tyco_names =  map deresolve_const named_tycos;
   1.238 +    val const_names = map deresolve_const named_consts;
   1.239 +    val generated_code = space_implode "\n\n" (map snd ml_modules);
   1.240 +    val (of_term_code, of_term_names) =
   1.241 +      print_computation_code ctxt compiled_value computation_cTs computation_Ts;
   1.242 +    val compiled_computation = generated_code ^ "\n" ^ of_term_code;
   1.243 +  in
   1.244 +    compiled_computation
   1.245 +    |> rpair { tyco_map = named_tycos ~~ tyco_names,
   1.246 +      const_map = named_consts ~~ const_names,
   1.247 +      of_term_map = computation_Ts ~~ of_term_names }
   1.248 +  end;
   1.249 +
   1.250 +fun funs_of_maps { tyco_map, const_map, of_term_map } =
   1.251 +  { name_for_tyco = the o AList.lookup (op =) tyco_map,
   1.252 +    name_for_const = the o AList.lookup (op =) const_map,
   1.253 +    of_term_for_typ = the o AList.lookup (op =) of_term_map };
   1.254 +
   1.255 +fun runtime_code ctxt some_module_name named_tycos named_consts computation_Ts program evals vs_ty_evals deps =
   1.256 +  runtime_code' ctxt some_module_name named_tycos named_consts computation_Ts program evals vs_ty_evals deps
   1.257 +  ||> funs_of_maps;
   1.258 +
   1.259 +
   1.260 +(** code and computation antiquotations **)
   1.261 +
   1.262 +val mount_computationN = Long_Name.append thisN "mount_computation";
   1.263 +
   1.264  local
   1.265  
   1.266  structure Code_Antiq_Data = Proof_Data
   1.267  (
   1.268    type T = { named_consts: string list,
   1.269 -    first_occurrence: bool,
   1.270 -    generated_code: {
   1.271 -      code: string,
   1.272 -      name_for_const: string -> string
   1.273 -    } lazy
   1.274 +    computation_Ts: typ list, computation_cTs: (string * typ) list,
   1.275 +    position_index: int,
   1.276 +    generated_code: (string * {
   1.277 +      name_for_tyco: string -> string,
   1.278 +      name_for_const: string -> string,
   1.279 +      of_term_for_typ: typ -> string
   1.280 +    }) lazy
   1.281    };
   1.282    val empty: T = { named_consts = [],
   1.283 -    first_occurrence = true,
   1.284 -    generated_code = Lazy.value {
   1.285 -      code = "",
   1.286 -      name_for_const = I
   1.287 -    }
   1.288 +    computation_Ts = [], computation_cTs = [],
   1.289 +    position_index = 0,
   1.290 +    generated_code = Lazy.lazy (fn () => raise Fail "empty")
   1.291    };
   1.292    fun init _ = empty;
   1.293  );
   1.294  
   1.295 -val is_first_occurrence = #first_occurrence o Code_Antiq_Data.get;
   1.296 +val current_position_index = #position_index o Code_Antiq_Data.get;
   1.297  
   1.298 -fun lazy_code ctxt consts = Lazy.lazy (fn () =>
   1.299 +fun register { named_consts, computation_Ts, computation_cTs } ctxt =
   1.300    let
   1.301 -    val program = Code_Thingol.consts_program ctxt consts;
   1.302 -    val (code, (_, consts_map)) =
   1.303 -      runtime_code ctxt Code_Target.generatedN program [] consts
   1.304 -  in { code = code, name_for_const = the o AList.lookup (op =) consts_map } end);
   1.305 -
   1.306 -fun register_const const ctxt =
   1.307 -  let
   1.308 -    val consts = insert (op =) const ((#named_consts o Code_Antiq_Data.get) ctxt);
   1.309 +    val data = Code_Antiq_Data.get ctxt;
   1.310 +    val named_consts' = union (op =) named_consts (#named_consts data);
   1.311 +    val computation_Ts' = union (op =) computation_Ts (#computation_Ts data);
   1.312 +    val computation_cTs' = union (op =) computation_cTs (#computation_cTs data);
   1.313 +    val position_index' = #position_index data + 1;
   1.314 +    fun generated_code' () =
   1.315 +      let
   1.316 +        val evals = Abs ("eval", map snd computation_cTs' --->
   1.317 +          TFree (Name.aT, []), list_comb (Bound 0, map Const computation_cTs'));
   1.318 +      in Code_Thingol.dynamic_value ctxt
   1.319 +        (K I) (runtime_code ctxt NONE [] named_consts' computation_Ts') evals
   1.320 +      end;
   1.321    in
   1.322      ctxt
   1.323      |> Code_Antiq_Data.put { 
   1.324 -        named_consts = consts,
   1.325 -        first_occurrence = false,
   1.326 -        generated_code = lazy_code ctxt consts
   1.327 +        named_consts = named_consts',
   1.328 +        computation_Ts = computation_Ts',
   1.329 +        computation_cTs = computation_cTs',
   1.330 +        position_index = position_index',
   1.331 +        generated_code = Lazy.lazy generated_code'
   1.332        }
   1.333    end;
   1.334  
   1.335 -fun print_code is_first_occ const ctxt =
   1.336 +fun register_const const =
   1.337 +  register { named_consts = [const],
   1.338 +    computation_Ts = [],
   1.339 +    computation_cTs = [] };
   1.340 +
   1.341 +fun register_computation cTs T =
   1.342 +  register { named_consts = [],
   1.343 +    computation_Ts = [T],
   1.344 +    computation_cTs = cTs };
   1.345 +
   1.346 +fun print body_code_for ctxt ctxt' =
   1.347    let
   1.348 -    val { code, name_for_const } = (Lazy.force o #generated_code o Code_Antiq_Data.get) ctxt;
   1.349 -    val context_code = if is_first_occ then code else "";
   1.350 -    val body_code = ML_Context.struct_name ctxt ^ "." ^ name_for_const const;
   1.351 +    val position_index = current_position_index ctxt;
   1.352 +    val (code, name_ofs) = (Lazy.force o #generated_code o Code_Antiq_Data.get) ctxt';
   1.353 +    val context_code = if position_index = 0 then code else "";
   1.354 +    val body_code = body_code_for name_ofs (ML_Context.struct_name ctxt') position_index;
   1.355    in (context_code, body_code) end;
   1.356  
   1.357 +fun print_code ctxt const =
   1.358 +  print (fn { name_for_const, ... } => fn prfx => fn _ =>
   1.359 +    Long_Name.append prfx (name_for_const const)) ctxt;
   1.360 +
   1.361 +fun print_computation ctxt T =
   1.362 +  print (fn { of_term_for_typ, ... } => fn prfx => fn _ =>
   1.363 +    space_implode " " [
   1.364 +      mount_computationN,
   1.365 +      "(Context.proof_of (Context.the_generic_context ()))",
   1.366 +      Long_Name.implode [prfx, generated_computationN, covered_constsN],
   1.367 +      (ML_Syntax.atomic o ML_Syntax.print_typ) T,
   1.368 +      Long_Name.append prfx (of_term_for_typ T)
   1.369 +    ]) ctxt;
   1.370 +
   1.371  in
   1.372  
   1.373  fun ml_code_antiq raw_const ctxt =
   1.374    let
   1.375      val thy = Proof_Context.theory_of ctxt;
   1.376      val const = Code.check_const thy raw_const;
   1.377 -    val is_first = is_first_occurrence ctxt;
   1.378 -  in (print_code is_first const, register_const const ctxt) end;
   1.379 +  in (print_code ctxt const, register_const const ctxt) end;
   1.380 +
   1.381 +fun ml_computation_antiq (raw_ts, raw_T) ctxt =
   1.382 +  let
   1.383 +    val ts = map (Syntax.check_term ctxt) raw_ts;
   1.384 +    val T = Syntax.check_typ ctxt raw_T;
   1.385 +    val _ = if not (monomorphic T)
   1.386 +      then error ("Polymorphic type: " ^ Syntax.string_of_typ ctxt T)
   1.387 +      else ();
   1.388 +    val cTs = (fold o fold_aterms)
   1.389 +      (fn (t as Const (cT as (_, T))) =>
   1.390 +        if not (monomorphic T) then error ("Polymorphic constant: " ^ Syntax.string_of_term ctxt t)
   1.391 +        else insert (op =) cT | _ => I) ts [];
   1.392 +  in (print_computation ctxt T, register_computation cTs T ctxt) end;
   1.393  
   1.394  end; (*local*)
   1.395  
   1.396 @@ -548,7 +615,7 @@
   1.397      val functions = map (prep_const thy) raw_functions;
   1.398      val consts = constrs @ functions;
   1.399      val program = Code_Thingol.consts_program ctxt consts;
   1.400 -    val result = runtime_code ctxt module_name program tycos consts
   1.401 +    val result = runtime_code'' ctxt module_name program tycos consts
   1.402        |> (apsnd o apsnd) (chop (length constrs));
   1.403    in
   1.404      thy
   1.405 @@ -562,7 +629,12 @@
   1.406  (** Isar setup **)
   1.407  
   1.408  val _ =
   1.409 -  Theory.setup (ML_Antiquotation.declaration @{binding code} Args.term (fn _ => ml_code_antiq));
   1.410 +  Theory.setup (ML_Antiquotation.declaration @{binding code}
   1.411 +    Args.term (fn _ => ml_code_antiq));
   1.412 +
   1.413 +val _ =
   1.414 +  Theory.setup (ML_Antiquotation.declaration @{binding computation}
   1.415 +    (Scan.repeat Args.term --| Scan.lift (Args.$$$ "::") -- Args.typ) (fn _ => ml_computation_antiq));
   1.416  
   1.417  local
   1.418