experimental computations: use arbitrary generated code for RHSs, not just constants
authorhaftmann
Sun, 22 Jan 2017 21:39:16 +0100
changeset 64940 19ca3644ec46
parent 64939 c8626f7fae06
child 64941 730bc1bcf27c
experimental computations: use arbitrary generated code for RHSs, not just constants
src/Tools/Code/code_runtime.ML
--- a/src/Tools/Code/code_runtime.ML	Mon Jan 23 17:35:37 2017 +0100
+++ b/src/Tools/Code/code_runtime.ML	Sun Jan 22 21:39:16 2017 +0100
@@ -28,18 +28,10 @@
     -> Proof.context -> term -> 'a Exn.result
   val dynamic_holds_conv: Proof.context -> conv
   val static_holds_conv: { ctxt: Proof.context, consts: string list } -> Proof.context -> conv
-  val fully_static_value: (Proof.context -> term -> 'a) cookie
-    -> { ctxt: Proof.context, lift_postproc: (term -> term) -> 'a -> 'a,
-           consts: (string * typ) list, T: typ }
-    -> Proof.context -> term -> 'a option (*EXPERIMENTAL!*)
-  val fully_static_value_strict: (Proof.context -> term -> 'a) cookie
+  val experimental_computation: (Proof.context -> term -> 'a) cookie
     -> { ctxt: Proof.context, lift_postproc: (term -> term) -> 'a -> 'a,
-           consts: (string * typ) list, T: typ }
+           terms: term list, T: typ }
     -> Proof.context -> term -> 'a (*EXPERIMENTAL!*)
-  val fully_static_value_exn: (Proof.context -> term -> 'a) cookie
-    -> { ctxt: Proof.context, lift_postproc: (term -> term) -> 'a -> 'a,
-           consts: (string * typ) list, T: typ }
-    -> Proof.context -> term -> 'a Exn.result (*EXPERIMENTAL!*)
   val code_reflect: (string * string list option) list -> string list -> string
     -> string option -> theory -> theory
   datatype truth = Holds
@@ -62,7 +54,6 @@
 val s_Holds = Long_Name.append this "Holds";
 
 val target = "Eval";
-val structure_generated = "Generated_Code";
 
 datatype truth = Holds;
 
@@ -203,7 +194,156 @@
 end; (*local*)
 
 
-(** fully static evaluation -- still with limited coverage! **)
+(** computations -- experimental! **)
+
+fun typ_signatures_for T =
+  let
+    val (Ts, T') = strip_type T;
+  in map_range (fn n => (drop n Ts ---> T', take n Ts)) (length Ts + 1) end;
+
+fun typ_signatures cTs =
+  let
+    fun add (c, T) =
+      fold (fn (T, Ts) => Typtab.map_default (T, []) (cons (c, Ts)))
+        (typ_signatures_for T);
+  in
+    Typtab.empty
+    |> fold add cTs
+    |> Typtab.lookup_list
+  end;
+
+fun print_of_term_funs { typ_sign_for, eval_for_const, of_term_for_typ } Ts =
+  let
+    val var_names = map_range (fn n => "t" ^ string_of_int (n + 1));
+    fun print_lhs c xs = "Const (" ^ quote c ^ ", _)"
+      |> fold (fn x => fn s => s ^ " $ " ^ x) xs
+      |> enclose "(" ")"
+      |> prefix "ctxt ";
+    fun print_rhs c Ts T xs = eval_for_const (c, Ts ---> T)
+      |> fold2 (fn T' => fn x => fn s =>
+         s ^ (" (" ^ of_term_for_typ T' ^ " ctxt " ^ x ^ ")")) Ts xs
+    fun print_eq T (c, Ts) =
+      let
+        val xs = var_names (length Ts);
+      in print_lhs c xs ^ " = " ^ print_rhs c Ts T xs end;
+    val err_eq =
+      "ctxt t = error (" ^ quote "Bad term: " ^ " ^ Syntax.string_of_term ctxt t)";
+    fun print_eqs T =
+      let
+        val typ_signs = typ_sign_for T;
+        val name = of_term_for_typ T;
+      in
+        (map (print_eq T) typ_signs @ [err_eq])
+        |> map (prefix (name ^ " "))
+        |> space_implode "\n  | "
+      end;
+  in
+    map print_eqs Ts
+    |> space_implode "\nand "
+    |> prefix "fun "
+  end;
+
+local
+
+fun tycos_of (Type (tyco, Ts)) = maps tycos_of Ts @ [tyco]
+  | tycos_of _ = [];
+
+val ml_name_of = Name.desymbolize NONE o Long_Name.base_name;
+
+in
+
+fun of_term_for_typ Ts =
+  let
+    val names = Ts
+      |> map (suffix "_of_term" o space_implode "_" o map ml_name_of o tycos_of)
+      |> Name.variant_list [];
+  in the o AList.lookup (op =) (Ts ~~ names) end;
+
+fun eval_for_const ctxt cTs =
+  let
+    fun symbol_list (c, T) = c :: maps tycos_of (Sign.const_typargs (Proof_Context.theory_of ctxt) (c, T))
+    val names = cTs
+      |> map (prefix "eval_" o space_implode "_" o map ml_name_of o symbol_list)
+      |> Name.variant_list [];
+  in the o AList.lookup (op =) (cTs ~~ names) end;
+
+end;
+
+val generated_computationN = "Generated_Computation";
+
+fun print_computation_code ctxt compiled_value T cTs =
+  let
+    val typ_sign_for = typ_signatures cTs;
+    fun add_typ T Ts =
+      if member (op =) Ts T
+      then Ts
+      else Ts
+        |> cons T
+        |> fold (fold add_typ o snd) (typ_sign_for T);
+    val Ts = add_typ T [];
+    val of_term_for_typ' = of_term_for_typ Ts;
+    val of_terms = map of_term_for_typ' Ts;
+    val eval_for_const' = eval_for_const ctxt cTs;
+    val evals = map eval_for_const' cTs;
+    val eval_tuple = enclose "(" ")" (commas evals);
+    val eval_abs = space_implode "" (map (fn s => "fn " ^ s ^ " => ") evals);
+    val of_term_code = print_of_term_funs {
+      typ_sign_for = typ_sign_for,
+      eval_for_const = eval_for_const',
+      of_term_for_typ = of_term_for_typ' } Ts;
+  in
+    (cat_lines [
+      "structure " ^ generated_computationN ^ " =",
+      "struct",
+      "",
+      "val " ^ eval_tuple ^ " = " ^ compiled_value ^ " ()",
+      "  (" ^ eval_abs,
+      "    " ^ eval_tuple ^ ");",
+      "",
+      of_term_code,
+      "",
+      "end"
+    ], map (prefix (generated_computationN ^ ".")) of_terms)
+  end;
+
+fun compile_computation cookie ctxt T program evals vs_ty_evals deps =
+  let
+    val raw_cTs = case evals of
+        Abs (_, _, t) => (snd o strip_comb) t
+      | _ => error ("Bad term after preprocessing: " ^ Syntax.string_of_term ctxt evals);
+    val cTs = map (fn Const cT => cT
+      | t => error ("No constant after preprocessing: " ^ Syntax.string_of_term ctxt t)) raw_cTs;
+    fun comp' vs_ty_evals =
+      let
+        val (generated_code, compiled_value) =
+          build_computation_text ctxt NONE program deps vs_ty_evals;
+        val (of_term_code, of_terms) = print_computation_code ctxt compiled_value T cTs;
+      in
+        (generated_code ^ "\n" ^ of_term_code,
+          enclose "(" ")" ("fn () => " ^ List.last of_terms))
+      end;
+    val compiled_computation =
+      Exn.release (run_computation_text cookie ctxt comp' vs_ty_evals []);
+  in fn ctxt' =>
+    compiled_computation ctxt' o reject_vars ctxt' o Syntax.check_term ctxt' o Type.constraint T
+  end;
+
+fun experimental_computation cookie { ctxt, lift_postproc, terms = ts, T } =
+  let
+    val cTs = (fold o fold_aterms) (fn Const cT => insert (op =) cT | _ => I) ts [];
+    val vT = TFree (singleton (Name.variant_list
+      (fold (fn (_, T) => fold_atyps (fn TFree (v, _) => insert (op =) v | _ => I)
+        T) cTs [])) Name.aT, []);
+    val evals = Abs ("eval", map snd cTs ---> vT, list_comb (Bound 0, map Const cTs));
+    val computation = Code_Thingol.dynamic_value ctxt
+      (K I) (compile_computation cookie ctxt T) evals;
+  in
+    Code_Preproc.static_value { ctxt = ctxt, lift_postproc = lift_postproc, consts = [] }
+      (K computation)
+  end;
+
+
+(** code antiquotation **)
 
 fun evaluation_code ctxt module_name program tycos consts =
   let
@@ -225,107 +365,6 @@
         | SOME tyco' => (tyco, tyco')) tycos tycos';
   in (ml_code, (tycos_map, consts_map)) end;
 
-fun typ_signatures_for T =
-  let
-    val (Ts, T') = strip_type T;
-  in map_range (fn n => (drop n Ts ---> T', take n Ts)) (length Ts + 1) end;
-
-fun typ_signatures cTs =
-  let
-    fun add (c, T) =
-      fold (fn (T, Ts) => Typtab.map_default (T, []) (cons (c, Ts)))
-        (typ_signatures_for T);
-  in
-    Typtab.empty
-    |> fold add cTs
-    |> Typtab.lookup_list
-  end;
-
-fun print_of_term_funs { typ_sign_for, ml_name_for_const, ml_name_for_typ } Ts =
-  let
-    val var_names = map_range (fn n => "t" ^ string_of_int (n + 1));
-    fun print_lhs c xs = "Const (" ^ quote c ^ ", _)"
-      |> fold (fn x => fn s => s ^ " $ " ^ x) xs
-      |> enclose "(" ")"
-      |> prefix "ctxt ";
-    fun print_rhs c Ts xs = ml_name_for_const c
-      |> fold2 (fn T => fn x => fn s =>
-         s ^ (" (" ^ ml_name_for_typ T ^ " ctxt " ^ x ^ ")")) Ts xs
-    fun print_eq (c, Ts) =
-      let
-        val xs = var_names (length Ts);
-      in print_lhs c xs ^ " = " ^ print_rhs c Ts xs end;
-    val err_eq =
-      "ctxt t = error (" ^ quote "Bad term: " ^ " ^ Syntax.string_of_term ctxt t)";
-    fun print_eqs T =
-      let
-        val typ_signs = typ_sign_for T;
-        val name = ml_name_for_typ T;
-      in
-        (map print_eq typ_signs @ [err_eq])
-        |> map (prefix (name ^ " "))
-        |> space_implode "\n  | "
-      end;
-  in
-    map print_eqs Ts
-    |> space_implode "\nand "
-    |> prefix "fun "
-    |> pair (map ml_name_for_typ Ts)
-  end;
-
-fun print_of_term ctxt ml_name_for_const T cTs =
-  let
-    val typ_sign_for = typ_signatures cTs;
-    fun add_typ T Ts =
-      if member (op =) Ts T
-      then Ts
-      else Ts
-        |> cons T
-        |> fold (fold add_typ o snd) (typ_sign_for T);
-    val Ts = add_typ T [];
-    fun tycos_of (Type (tyco, Ts)) = maps tycos_of Ts @ [tyco]
-      | tycos_of _ = [];
-    val ml_name_of = Name.desymbolize NONE o Long_Name.base_name;
-    val ml_names = map (suffix "_of_term" o space_implode "_" o map ml_name_of o tycos_of) Ts
-      |> Name.variant_list [];
-    val ml_name_for_typ = the o AList.lookup (op =) (Ts ~~ ml_names);
-  in
-    print_of_term_funs { typ_sign_for = typ_sign_for,
-      ml_name_for_const = ml_name_for_const,
-      ml_name_for_typ = ml_name_for_typ } Ts
-  end;
-
-fun compile_computation cookie ctxt cs_code cTs T { program, deps } =
-  let
-    val (context_code, (_, const_map)) =
-      evaluation_code ctxt structure_generated program [] cs_code;
-    val ml_name_for_const = the o AList.lookup (op =) const_map;
-    val (ml_names, of_term_code) = print_of_term ctxt ml_name_for_const T cTs
-    val of_term = value ctxt cookie (context_code ^ "\n" ^ of_term_code, List.last ml_names);
-  in
-    Code_Preproc.timed_value "computing" 
-      (fn ctxt' => fn t => fn _ => fn _ => Exn.interruptible_capture (of_term ctxt') t)
-  end;
-
-fun fully_static_value_exn cookie { ctxt, lift_postproc, consts, T } =
-  let
-    val thy = Proof_Context.theory_of ctxt;
-    val cs_code = map (Axclass.unoverload_const thy) consts;
-    val cTs = map2 (fn (_, T) => fn c => (c, T)) consts cs_code;
-    val computation = Code_Thingol.static_value { ctxt = ctxt,
-      lift_postproc = Exn.map_res o lift_postproc, consts = cs_code }
-      (compile_computation cookie ctxt cs_code cTs T);
-  in fn ctxt' =>
-    computation ctxt' o reject_vars ctxt' o Syntax.check_term ctxt' o Type.constraint T
-  end;
-
-fun fully_static_value_strict cookie x = Exn.release oo fully_static_value_exn cookie x;
-
-fun fully_static_value cookie x = partiality_as_none oo fully_static_value_exn cookie x;
-
-
-(** code antiquotation **)
-
 local
 
 structure Code_Antiq_Data = Proof_Data
@@ -345,7 +384,7 @@
     val consts' = fold (insert (op =)) new_consts consts;
     val program = Code_Thingol.consts_program ctxt consts';
     val acc_code = Lazy.lazy (fn () =>
-      evaluation_code ctxt structure_generated program tycos' consts'
+      evaluation_code ctxt Code_Target.generatedN program tycos' consts'
       |> apsnd snd);
   in Code_Antiq_Data.put ((tycos', consts'), (false, acc_code)) ctxt end;
 
@@ -553,7 +592,7 @@
   |-> (fn ([Const (const, _)], _) =>
     Code_Target.set_printings (Constant (const,
       [(target, SOME (Code_Printer.simple_const_syntax (0, (K o K o K o Code_Printer.str) ml_name)))]))
-  #> tap (fn thy => Code_Target.produce_code (Proof_Context.init_global thy) false [const] target NONE structure_generated []));
+  #> tap (fn thy => Code_Target.produce_code (Proof_Context.init_global thy) false [const] target NONE Code_Target.generatedN []));
 
 fun process_file filepath (definienda, thy) =
   let