ML antiquotation for generated computations
authorhaftmann
Fri, 27 Jan 2017 22:27:03 +0100
changeset 64959 9ca021bd718d
parent 64958 85b87da722ab
child 64960 8be78855ee7a
ML antiquotation for generated computations
src/HOL/ROOT
src/HOL/ex/Computations.thy
src/Tools/Code/code_runtime.ML
src/Tools/Code/code_target.ML
--- a/src/HOL/ROOT	Thu Jan 26 16:25:32 2017 +0100
+++ b/src/HOL/ROOT	Fri Jan 27 22:27:03 2017 +0100
@@ -536,6 +536,7 @@
     "~~/src/HOL/Library/Refute"
     "~~/src/HOL/Library/Transitive_Closure_Table"
     Cartouche_Examples
+    Computations
   theories
     Commands
     Adhoc_Overloading_Examples
--- /dev/null	Thu Jan 01 00:00:00 1970 +0000
+++ b/src/HOL/ex/Computations.thy	Fri Jan 27 22:27:03 2017 +0100
@@ -0,0 +1,63 @@
+(*  Title:      HOL/ex/Computations.thy
+    Author:     Florian Haftmann, TU Muenchen
+*)
+
+section \<open>Simple example for computations generated by the code generator\<close>
+
+theory Computations
+  imports Nat Fun_Def Num Code_Numeral
+begin
+
+fun even :: "nat \<Rightarrow> bool"
+  where "even 0 \<longleftrightarrow> True"
+      | "even (Suc 0) \<longleftrightarrow> False"
+      | "even (Suc (Suc n)) \<longleftrightarrow> even n"
+  
+fun fib :: "nat \<Rightarrow> nat"
+  where "fib 0 = 0"
+      | "fib (Suc 0) = Suc 0"
+      | "fib (Suc (Suc n)) = fib (Suc n) + fib n"
+
+declare [[ML_source_trace]]
+
+ML \<open>
+local 
+
+fun int_of_nat @{code "0 :: nat"} = 0
+  | int_of_nat (@{code Suc} n) = int_of_nat n + 1;
+
+in
+
+val comp_nat = @{computation "0 :: nat" Suc
+  "plus :: nat \<Rightarrow>_" "times :: nat \<Rightarrow> _" fib :: nat}
+  (fn post => post o HOLogic.mk_nat o int_of_nat);
+
+val comp_numeral = @{computation "0 :: nat" "1 :: nat" "2 :: nat" "3 :: nat" :: nat}
+  (fn post => post o HOLogic.mk_nat o int_of_nat);
+
+val comp_bool = @{computation True False HOL.conj HOL.disj HOL.implies
+  HOL.iff even "less_eq :: nat \<Rightarrow> _" "less :: nat \<Rightarrow> _" "HOL.eq :: nat \<Rightarrow> _" :: bool }
+  (K I)
+
+end
+\<close>
+
+declare [[ML_source_trace = false]]
+  
+ML_val \<open>
+  comp_nat @{context} @{term "fib (Suc (Suc (Suc 0)) * Suc (Suc (Suc 0))) + Suc 0"}
+  |> Syntax.string_of_term @{context}
+  |> writeln
+\<close>
+  
+ML_val \<open>
+  comp_bool @{context} @{term "fib (Suc (Suc (Suc 0)) * Suc (Suc (Suc 0))) + Suc 0 < fib (Suc (Suc 0))"}
+\<close>
+
+ML_val \<open>
+  comp_numeral @{context} @{term "Suc 42 + 7"}
+  |> Syntax.string_of_term @{context}
+  |> writeln
+\<close>
+
+end
--- a/src/Tools/Code/code_runtime.ML	Thu Jan 26 16:25:32 2017 +0100
+++ b/src/Tools/Code/code_runtime.ML	Fri Jan 27 22:27:03 2017 +0100
@@ -28,10 +28,8 @@
     -> Proof.context -> term -> 'a Exn.result
   val dynamic_holds_conv: Proof.context -> conv
   val static_holds_conv: { ctxt: Proof.context, consts: string list } -> Proof.context -> conv
-  val experimental_computation: (term -> 'a) cookie
-    -> { ctxt: Proof.context, lift_postproc: (term -> term) -> 'a -> 'a,
-           terms: term list, T: typ }
-    -> Proof.context -> term -> 'a (*EXPERIMENTAL!*)
+  val mount_computation: Proof.context -> (string * typ) list -> typ
+    -> (term -> 'b) -> ((term -> term) -> 'b -> 'a) -> Proof.context -> term -> 'a
   val code_reflect: (string * string list option) list -> string list -> string
     -> string option -> theory -> theory
   datatype truth = Holds
@@ -49,9 +47,9 @@
 
 (* technical prerequisites *)
 
-val this = "Code_Runtime";
-val s_truth = Long_Name.append this "truth";
-val s_Holds = Long_Name.append this "Holds";
+val thisN = "Code_Runtime";
+val truthN = Long_Name.append thisN "truth";
+val HoldsN = Long_Name.append thisN "Holds";
 
 val target = "Eval";
 
@@ -60,10 +58,10 @@
 val _ = Theory.setup
   (Code_Target.add_derived_target (target, [(Code_ML.target_SML, I)])
   #> Code_Target.set_printings (Type_Constructor (@{type_name prop},
-    [(target, SOME (0, (K o K o K) (Code_Printer.str s_truth)))]))
+    [(target, SOME (0, (K o K o K) (Code_Printer.str truthN)))]))
   #> Code_Target.set_printings (Constant (@{const_name Code_Generator.holds},
-    [(target, SOME (Code_Printer.plain_const_syntax s_Holds))]))
-  #> Code_Target.add_reserved target this
+    [(target, SOME (Code_Printer.plain_const_syntax HoldsN))]))
+  #> Code_Target.add_reserved target thisN
   #> fold (Code_Target.add_reserved target) ["oo", "ooo", "oooo", "upto", "downto", "orf", "andf"]);
        (*avoid further pervasive infix names*)
 
@@ -153,7 +151,7 @@
   fun init _ = empty;
 );
 val put_truth = Truth_Result.put;
-val truth_cookie = (Truth_Result.get, put_truth, Long_Name.append this "put_truth");
+val truth_cookie = (Truth_Result.get, put_truth, Long_Name.append thisN "put_truth");
 
 local
 
@@ -231,6 +229,8 @@
 
 in
 
+val covered_constsN = "covered_consts";
+
 fun of_term_for_typ Ts =
   let
     val names = Ts
@@ -277,6 +277,8 @@
 
 (* code generation for of the universal morphism *)
 
+val print_const = ML_Syntax.print_pair ML_Syntax.print_string ML_Syntax.print_typ;
+
 fun print_of_term_funs { typ_sign_for, eval_for_const, of_term_for_typ } Ts =
   let
     val var_names = map_range (fn n => "t" ^ string_of_int (n + 1));
@@ -305,42 +307,47 @@
     |> prefix "fun "
   end;
 
-fun print_computation_code ctxt compiled_value requested_Ts cTs =
-  let
-    val proper_cTs = map_filter I cTs;
-    val typ_sign_for = typ_signatures proper_cTs;
-    fun add_typ T Ts =
-      if member (op =) Ts T
-      then Ts
-      else Ts
-        |> cons T
-        |> fold (fold add_typ o snd) (typ_sign_for T);
-    val required_Ts = fold add_typ requested_Ts [];
-    val of_term_for_typ' = of_term_for_typ required_Ts;
-    val eval_for_const' = eval_for_const ctxt proper_cTs;
-    val eval_for_const'' = the_default "_" o Option.map eval_for_const';
-    val eval_tuple = enclose "(" ")" (commas (map eval_for_const' proper_cTs));
-    fun mk_abs s = "fn " ^ s ^ " => ";
-    val eval_abs = space_implode ""
-      (map (mk_abs o eval_for_const'') cTs);
-    val of_term_code = print_of_term_funs {
-      typ_sign_for = typ_sign_for,
-      eval_for_const = eval_for_const',
-      of_term_for_typ = of_term_for_typ' } required_Ts;
-  in
-    (cat_lines [
-      "structure " ^ generated_computationN ^ " =",
-      "struct",
-      "",
-      "val " ^ eval_tuple ^ " = " ^ compiled_value ^ " ()",
-      "  (" ^ eval_abs,
-      "    " ^ eval_tuple ^ ");",
-      "",
-      of_term_code,
-      "",
-      "end"
-    ], map (prefix (generated_computationN ^ ".") o of_term_for_typ') requested_Ts)
-  end;
+fun print_computation_code ctxt compiled_value [] requested_Ts = ("", [])
+  | print_computation_code ctxt compiled_value cTs requested_Ts =
+      let
+        val proper_cTs = map_filter I cTs;
+        val typ_sign_for = typ_signatures proper_cTs;
+        fun add_typ T Ts =
+          if member (op =) Ts T
+          then Ts
+          else Ts
+            |> cons T
+            |> fold (fold add_typ o snd) (typ_sign_for T);
+        val required_Ts = fold add_typ requested_Ts [];
+        val of_term_for_typ' = of_term_for_typ required_Ts;
+        val eval_for_const' = eval_for_const ctxt proper_cTs;
+        val eval_for_const'' = the_default "_" o Option.map eval_for_const';
+        val eval_tuple = enclose "(" ")" (commas (map eval_for_const' proper_cTs));
+        fun mk_abs s = "fn " ^ s ^ " => ";
+        val eval_abs = space_implode ""
+          (map (mk_abs o eval_for_const'') cTs);
+        val of_term_code = print_of_term_funs {
+          typ_sign_for = typ_sign_for,
+          eval_for_const = eval_for_const',
+          of_term_for_typ = of_term_for_typ' } required_Ts;
+        val of_term_names = map (Long_Name.append generated_computationN
+          o of_term_for_typ') requested_Ts;
+      in
+        cat_lines [
+          "structure " ^ generated_computationN ^ " =",
+          "struct",
+          "",
+          "val " ^ covered_constsN ^ " = " ^ ML_Syntax.print_list print_const proper_cTs ^ ";",
+          "",
+          "val " ^ eval_tuple ^ " = " ^ compiled_value ^ " ()",
+          "  (" ^ eval_abs,
+          "    " ^ eval_tuple ^ ");",
+          "",
+          of_term_code,
+          "",
+          "end"
+        ] |> rpair of_term_names
+      end;
 
 fun mount_computation ctxt cTs T raw_computation lift_postproc =
   Code_Preproc.static_value { ctxt = ctxt, lift_postproc = lift_postproc, consts = [] }
@@ -350,48 +357,12 @@
       #> check_computation_input ctxt' cTs
       #> raw_computation);
 
-fun compile_computation cookie ctxt T program evals vs_ty_evals deps =
-  let
-    fun the_const (Const cT) = cT
-      | the_const t = error ("No constant after preprocessing: " ^ Syntax.string_of_term ctxt t)
-    val raw_cTs = case evals of
-        Abs (_, _, t) => (map the_const o snd o strip_comb) t
-      | _ => error ("Bad term after preprocessing: " ^ Syntax.string_of_term ctxt evals);
-    val cTs = fold_rev (fn cT => fn cTs =>
-      (if member (op =) cTs (SOME cT) then NONE else SOME cT) :: cTs) raw_cTs [];
-    fun comp' vs_ty_evals =
-      let
-        val (generated_code, compiled_value) =
-          build_compilation_text ctxt NONE program deps vs_ty_evals;
-        val (of_term_code, [of_term]) = print_computation_code ctxt compiled_value [T] cTs;
-      in
-        (generated_code ^ "\n" ^ of_term_code,
-          enclose "(" ")" ("fn () => " ^ of_term))
-      end;
-    val compiled_computation =
-      Exn.release (run_compilation_text cookie ctxt comp' vs_ty_evals []);
-  in (map_filter I cTs, compiled_computation) end;
+
+(** variants of universal runtime code generation **)
 
-fun experimental_computation cookie { ctxt, lift_postproc, terms = ts, T } =
-  let
-    val _ = if not (monomorphic T)
-      then error ("Polymorphic type: " ^ Syntax.string_of_typ ctxt T)
-      else ();
-    val cTs = (fold o fold_aterms)
-      (fn (t as Const (cT as (_, T))) =>
-        if not (monomorphic T) then error ("Polymorphic constant:" ^ Syntax.string_of_term ctxt t)
-        else insert (op =) cT | _ => I) ts [];
-    val evals = Abs ("eval", map snd cTs ---> TFree (Name.aT, []), list_comb (Bound 0, map Const cTs));
-    val (cTs, raw_computation) = Code_Thingol.dynamic_value ctxt
-      (K I) (compile_computation cookie ctxt T) evals;
-  in
-    mount_computation ctxt cTs T raw_computation lift_postproc
-  end;
+(*FIXME consolidate variants*)
 
-
-(** code antiquotation **)
-
-fun runtime_code ctxt module_name program tycos consts =
+fun runtime_code'' ctxt module_name program tycos consts =
   let
     val thy = Proof_Context.theory_of ctxt;
     val (ml_modules, target_names) =
@@ -411,63 +382,159 @@
         | SOME tyco' => (tyco, tyco')) tycos tycos';
   in (ml_code, (tycos_map, consts_map)) end;
 
+fun runtime_code' ctxt some_module_name named_tycos named_consts computation_Ts program evals vs_ty_evals deps =
+  let
+    val thy = Proof_Context.theory_of ctxt;
+    fun the_const (Const cT) = cT
+      | the_const t = error ("No constant after preprocessing: " ^ Syntax.string_of_term ctxt t)
+    val raw_computation_cTs = case evals of
+        Abs (_, _, t) => (map the_const o snd o strip_comb) t
+      | _ => error ("Bad term after preprocessing: " ^ Syntax.string_of_term ctxt evals);
+    val computation_cTs = fold_rev (fn cT => fn cTs =>
+      (if member (op =) cTs (SOME cT) then NONE else SOME cT) :: cTs) raw_computation_cTs [];
+    val consts' = fold (fn NONE => I | SOME (c, _) => insert (op =) c)
+      computation_cTs named_consts;
+    val program' = Code_Thingol.consts_program ctxt consts';
+      (*FIXME insufficient interfaces require double invocation of code generator*)
+    val ((ml_modules, compiled_value), deresolve) =
+      Code_Target.compilation_text' ctxt target some_module_name program'
+        (map Code_Symbol.Type_Constructor named_tycos @ map Code_Symbol.Constant consts' @ deps) true vs_ty_evals;
+        (*FIXME constrain signature*)
+    fun deresolve_const c = case (deresolve o Code_Symbol.Constant) c of
+          NONE => error ("Constant " ^ (quote o Code.string_of_const thy) c ^
+            "\nhas a user-defined serialization")
+        | SOME c' => c';
+    fun deresolve_tyco tyco = case (deresolve o Code_Symbol.Type_Constructor) tyco of
+          NONE => error ("Type " ^ quote (Proof_Context.markup_type ctxt tyco) ^
+            "\nhas a user-defined serialization")
+        | SOME c' => c';
+    val tyco_names =  map deresolve_const named_tycos;
+    val const_names = map deresolve_const named_consts;
+    val generated_code = space_implode "\n\n" (map snd ml_modules);
+    val (of_term_code, of_term_names) =
+      print_computation_code ctxt compiled_value computation_cTs computation_Ts;
+    val compiled_computation = generated_code ^ "\n" ^ of_term_code;
+  in
+    compiled_computation
+    |> rpair { tyco_map = named_tycos ~~ tyco_names,
+      const_map = named_consts ~~ const_names,
+      of_term_map = computation_Ts ~~ of_term_names }
+  end;
+
+fun funs_of_maps { tyco_map, const_map, of_term_map } =
+  { name_for_tyco = the o AList.lookup (op =) tyco_map,
+    name_for_const = the o AList.lookup (op =) const_map,
+    of_term_for_typ = the o AList.lookup (op =) of_term_map };
+
+fun runtime_code ctxt some_module_name named_tycos named_consts computation_Ts program evals vs_ty_evals deps =
+  runtime_code' ctxt some_module_name named_tycos named_consts computation_Ts program evals vs_ty_evals deps
+  ||> funs_of_maps;
+
+
+(** code and computation antiquotations **)
+
+val mount_computationN = Long_Name.append thisN "mount_computation";
+
 local
 
 structure Code_Antiq_Data = Proof_Data
 (
   type T = { named_consts: string list,
-    first_occurrence: bool,
-    generated_code: {
-      code: string,
-      name_for_const: string -> string
-    } lazy
+    computation_Ts: typ list, computation_cTs: (string * typ) list,
+    position_index: int,
+    generated_code: (string * {
+      name_for_tyco: string -> string,
+      name_for_const: string -> string,
+      of_term_for_typ: typ -> string
+    }) lazy
   };
   val empty: T = { named_consts = [],
-    first_occurrence = true,
-    generated_code = Lazy.value {
-      code = "",
-      name_for_const = I
-    }
+    computation_Ts = [], computation_cTs = [],
+    position_index = 0,
+    generated_code = Lazy.lazy (fn () => raise Fail "empty")
   };
   fun init _ = empty;
 );
 
-val is_first_occurrence = #first_occurrence o Code_Antiq_Data.get;
+val current_position_index = #position_index o Code_Antiq_Data.get;
 
-fun lazy_code ctxt consts = Lazy.lazy (fn () =>
+fun register { named_consts, computation_Ts, computation_cTs } ctxt =
   let
-    val program = Code_Thingol.consts_program ctxt consts;
-    val (code, (_, consts_map)) =
-      runtime_code ctxt Code_Target.generatedN program [] consts
-  in { code = code, name_for_const = the o AList.lookup (op =) consts_map } end);
-
-fun register_const const ctxt =
-  let
-    val consts = insert (op =) const ((#named_consts o Code_Antiq_Data.get) ctxt);
+    val data = Code_Antiq_Data.get ctxt;
+    val named_consts' = union (op =) named_consts (#named_consts data);
+    val computation_Ts' = union (op =) computation_Ts (#computation_Ts data);
+    val computation_cTs' = union (op =) computation_cTs (#computation_cTs data);
+    val position_index' = #position_index data + 1;
+    fun generated_code' () =
+      let
+        val evals = Abs ("eval", map snd computation_cTs' --->
+          TFree (Name.aT, []), list_comb (Bound 0, map Const computation_cTs'));
+      in Code_Thingol.dynamic_value ctxt
+        (K I) (runtime_code ctxt NONE [] named_consts' computation_Ts') evals
+      end;
   in
     ctxt
     |> Code_Antiq_Data.put { 
-        named_consts = consts,
-        first_occurrence = false,
-        generated_code = lazy_code ctxt consts
+        named_consts = named_consts',
+        computation_Ts = computation_Ts',
+        computation_cTs = computation_cTs',
+        position_index = position_index',
+        generated_code = Lazy.lazy generated_code'
       }
   end;
 
-fun print_code is_first_occ const ctxt =
+fun register_const const =
+  register { named_consts = [const],
+    computation_Ts = [],
+    computation_cTs = [] };
+
+fun register_computation cTs T =
+  register { named_consts = [],
+    computation_Ts = [T],
+    computation_cTs = cTs };
+
+fun print body_code_for ctxt ctxt' =
   let
-    val { code, name_for_const } = (Lazy.force o #generated_code o Code_Antiq_Data.get) ctxt;
-    val context_code = if is_first_occ then code else "";
-    val body_code = ML_Context.struct_name ctxt ^ "." ^ name_for_const const;
+    val position_index = current_position_index ctxt;
+    val (code, name_ofs) = (Lazy.force o #generated_code o Code_Antiq_Data.get) ctxt';
+    val context_code = if position_index = 0 then code else "";
+    val body_code = body_code_for name_ofs (ML_Context.struct_name ctxt') position_index;
   in (context_code, body_code) end;
 
+fun print_code ctxt const =
+  print (fn { name_for_const, ... } => fn prfx => fn _ =>
+    Long_Name.append prfx (name_for_const const)) ctxt;
+
+fun print_computation ctxt T =
+  print (fn { of_term_for_typ, ... } => fn prfx => fn _ =>
+    space_implode " " [
+      mount_computationN,
+      "(Context.proof_of (Context.the_generic_context ()))",
+      Long_Name.implode [prfx, generated_computationN, covered_constsN],
+      (ML_Syntax.atomic o ML_Syntax.print_typ) T,
+      Long_Name.append prfx (of_term_for_typ T)
+    ]) ctxt;
+
 in
 
 fun ml_code_antiq raw_const ctxt =
   let
     val thy = Proof_Context.theory_of ctxt;
     val const = Code.check_const thy raw_const;
-    val is_first = is_first_occurrence ctxt;
-  in (print_code is_first const, register_const const ctxt) end;
+  in (print_code ctxt const, register_const const ctxt) end;
+
+fun ml_computation_antiq (raw_ts, raw_T) ctxt =
+  let
+    val ts = map (Syntax.check_term ctxt) raw_ts;
+    val T = Syntax.check_typ ctxt raw_T;
+    val _ = if not (monomorphic T)
+      then error ("Polymorphic type: " ^ Syntax.string_of_typ ctxt T)
+      else ();
+    val cTs = (fold o fold_aterms)
+      (fn (t as Const (cT as (_, T))) =>
+        if not (monomorphic T) then error ("Polymorphic constant: " ^ Syntax.string_of_term ctxt t)
+        else insert (op =) cT | _ => I) ts [];
+  in (print_computation ctxt T, register_computation cTs T ctxt) end;
 
 end; (*local*)
 
@@ -548,7 +615,7 @@
     val functions = map (prep_const thy) raw_functions;
     val consts = constrs @ functions;
     val program = Code_Thingol.consts_program ctxt consts;
-    val result = runtime_code ctxt module_name program tycos consts
+    val result = runtime_code'' ctxt module_name program tycos consts
       |> (apsnd o apsnd) (chop (length constrs));
   in
     thy
@@ -562,7 +629,12 @@
 (** Isar setup **)
 
 val _ =
-  Theory.setup (ML_Antiquotation.declaration @{binding code} Args.term (fn _ => ml_code_antiq));
+  Theory.setup (ML_Antiquotation.declaration @{binding code}
+    Args.term (fn _ => ml_code_antiq));
+
+val _ =
+  Theory.setup (ML_Antiquotation.declaration @{binding computation}
+    (Scan.repeat Args.term --| Scan.lift (Args.$$$ "::") -- Args.typ) (fn _ => ml_computation_antiq));
 
 local
 
--- a/src/Tools/Code/code_target.ML	Thu Jan 26 16:25:32 2017 +0100
+++ b/src/Tools/Code/code_target.ML	Fri Jan 27 22:27:03 2017 +0100
@@ -31,6 +31,9 @@
   val compilation_text: Proof.context -> string -> Code_Thingol.program
     -> Code_Symbol.T list -> bool -> ((string * class list) list * Code_Thingol.itype) * Code_Thingol.iterm
     -> (string * string) list * string
+  val compilation_text': Proof.context -> string -> string option -> Code_Thingol.program
+    -> Code_Symbol.T list -> bool -> ((string * class list) list * Code_Thingol.itype) * Code_Thingol.iterm
+    -> ((string * string) list * string) * (Code_Symbol.T -> string option)
 
   type serializer
   type literals = Code_Printer.literals
@@ -414,17 +417,20 @@
     val (program_code, deresolve) =
       produce (mounted_serializer program (if all_public then [] else [Code_Symbol.value]));
     val value_name = the (deresolve Code_Symbol.value);
-  in (program_code, value_name) end;
+  in ((program_code, value_name), deresolve) end;
 
-fun compilation_text ctxt target_name program syms =
+fun compilation_text' ctxt target_name some_module_name program syms =
   let
     val (mounted_serializer, (_, prepared_program)) =
-      mount_serializer ctxt target_name NONE generatedN [] program syms;
+      mount_serializer ctxt target_name NONE (the_default generatedN some_module_name) [] program syms;
   in
     Code_Preproc.timed_exec "serializing"
     (fn () => dynamic_compilation_text mounted_serializer prepared_program syms) ctxt
   end;
 
+fun compilation_text ctxt target_name program syms =
+  fst oo compilation_text' ctxt target_name NONE program syms
+  
 end; (* local *)