more computation antiquotations
authorhaftmann
Mon, 06 Feb 2017 20:56:33 +0100
changeset 64989 40c36a4aee1f
parent 64988 93aaff2b0ae0
child 64990 c6a7de505796
more computation antiquotations
src/HOL/ex/Computations.thy
src/Tools/Code/code_runtime.ML
--- a/src/HOL/ex/Computations.thy	Mon Feb 06 20:56:32 2017 +0100
+++ b/src/HOL/ex/Computations.thy	Mon Feb 06 20:56:33 2017 +0100
@@ -36,9 +36,11 @@
   (fn post => post o HOLogic.mk_nat o int_of_nat o the);
 
 val comp_bool = @{computation True False HOL.conj HOL.disj HOL.implies
-  HOL.iff even "less_eq :: nat \<Rightarrow> _" "less :: nat \<Rightarrow> _" "HOL.eq :: nat \<Rightarrow> _" :: bool }
+  HOL.iff even "less_eq :: nat \<Rightarrow> _" "less :: nat \<Rightarrow> _" "HOL.eq :: nat \<Rightarrow> _" :: bool}
   (K the);
 
+val comp_check = @{computation_check Trueprop};
+
 end
 \<close>
 
@@ -55,6 +57,10 @@
 \<close>
 
 ML_val \<open>
+  comp_check @{context} @{cprop "fib (Suc (Suc (Suc 0)) * Suc (Suc (Suc 0))) + Suc 0 > fib (Suc (Suc 0))"}
+\<close>
+  
+ML_val \<open>
   comp_numeral @{context} @{term "Suc 42 + 7"}
   |> Syntax.string_of_term @{context}
   |> writeln
--- a/src/Tools/Code/code_runtime.ML	Mon Feb 06 20:56:32 2017 +0100
+++ b/src/Tools/Code/code_runtime.ML	Mon Feb 06 20:56:33 2017 +0100
@@ -34,6 +34,10 @@
   val put_truth: (unit -> truth) -> Proof.context -> Proof.context
   val mount_computation: Proof.context -> (string * typ) list -> typ
     -> (term -> 'ml) -> ((term -> term) -> 'ml option -> 'a) -> Proof.context -> term -> 'a
+  val mount_computation_conv: Proof.context -> (string * typ) list -> typ
+    -> (term -> 'ml) -> ('ml -> conv) -> Proof.context -> conv
+  val mount_computation_check: Proof.context -> (string * typ) list
+    -> (term -> truth) -> Proof.context -> conv
   val polyml_as_definition: (binding * typ) list -> Path.T list -> theory -> theory
   val trace: bool Config.T
 end;
@@ -361,6 +365,29 @@
   Code_Preproc.static_value { ctxt = ctxt, lift_postproc = lift_postproc, consts = [] }
     (K (checked_computation cTs T raw_computation));
 
+fun mount_computation_conv ctxt cTs T raw_computation conv =
+  Code_Preproc.static_conv { ctxt = ctxt, consts = [] }
+    (K (fn ctxt' => fn t =>
+      case checked_computation cTs T raw_computation ctxt' t of
+        SOME x => conv x
+      | NONE => Conv.all_conv));
+
+local
+
+fun holds ct = Thm.mk_binop @{cterm "Pure.eq :: prop \<Rightarrow> prop \<Rightarrow> prop"}
+  ct @{cprop "PROP Code_Generator.holds"};
+
+val (_, holds_oracle) = Context.>>> (Context.map_theory_result
+  (Thm.add_oracle (@{binding holds}, holds)));
+
+in
+
+fun mount_computation_check ctxt cTs raw_computation =
+  mount_computation_conv ctxt cTs @{typ prop} raw_computation
+    (K holds_oracle);
+
+end;
+
 
 (** variants of universal runtime code generation **)
 
@@ -437,9 +464,11 @@
 
 (** code and computation antiquotations **)
 
-val mount_computationN = prefix_this "mount_computation";
+local
 
-local
+val mount_computationN = prefix_this "mount_computation";
+val mount_computation_convN = prefix_this "mount_computation_conv";
+val mount_computation_checkN = prefix_this "mount_computation_check";
 
 structure Code_Antiq_Data = Proof_Data
 (
@@ -502,23 +531,42 @@
     val position_index = current_position_index ctxt;
     val (code, name_ofs) = (Lazy.force o #generated_code o Code_Antiq_Data.get) ctxt';
     val context_code = if position_index = 0 then code else "";
-    val body_code = body_code_for name_ofs (ML_Context.struct_name ctxt') position_index;
+    val body_code = body_code_for name_ofs (ML_Context.struct_name ctxt');
   in (context_code, body_code) end;
 
 fun print_code ctxt const =
-  print (fn { name_for_const, ... } => fn prfx => fn _ =>
+  print (fn { name_for_const, ... } => fn prfx =>
     Long_Name.append prfx (name_for_const const)) ctxt;
 
-fun print_computation ctxt T =
-  print (fn { of_term_for_typ, ... } => fn prfx => fn _ =>
+fun print_computation kind ctxt T =
+  print (fn { of_term_for_typ, ... } => fn prfx =>
     space_implode " " [
-      mount_computationN,
+      kind,
       "(Context.proof_of (Context.the_generic_context ()))",
       Long_Name.implode [prfx, generated_computationN, covered_constsN],
       (ML_Syntax.atomic o ML_Syntax.print_typ) T,
       Long_Name.append prfx (of_term_for_typ T)
     ]) ctxt;
 
+fun print_computation_check ctxt =
+  print (fn { of_term_for_typ, ... } => fn prfx =>
+    space_implode " " [
+      mount_computation_checkN,
+      "(Context.proof_of (Context.the_generic_context ()))",
+      Long_Name.implode [prfx, generated_computationN, covered_constsN],
+      Long_Name.append prfx (of_term_for_typ @{typ prop})
+    ]) ctxt;
+
+fun prep_terms ctxt raw_ts =
+  let
+    val ts = map (Syntax.check_term ctxt) raw_ts;
+  in
+    (fold o fold_aterms)
+      (fn (t as Const (cT as (_, T))) =>
+        if not (monomorphic T) then error ("Polymorphic constant: " ^ Syntax.string_of_term ctxt t)
+        else insert (op =) cT | _ => I) ts []
+  end;
+
 in
 
 fun ml_code_antiq raw_const ctxt =
@@ -527,18 +575,23 @@
     val const = Code.check_const thy raw_const;
   in (print_code ctxt const, register_const const ctxt) end;
 
-fun ml_computation_antiq (raw_ts, raw_T) ctxt =
+fun gen_ml_computation_antiq kind (raw_ts, raw_T) ctxt =
   let
-    val ts = map (Syntax.check_term ctxt) raw_ts;
-    val cTs = (fold o fold_aterms)
-      (fn (t as Const (cT as (_, T))) =>
-        if not (monomorphic T) then error ("Polymorphic constant: " ^ Syntax.string_of_term ctxt t)
-        else insert (op =) cT | _ => I) ts [];
+    val cTs = prep_terms ctxt raw_ts;
     val T = Syntax.check_typ ctxt raw_T;
     val _ = if not (monomorphic T)
       then error ("Polymorphic type: " ^ Syntax.string_of_typ ctxt T)
       else ();
-  in (print_computation ctxt T, register_computation cTs T ctxt) end;
+  in (print_computation kind ctxt T, register_computation cTs T ctxt) end;
+
+val ml_computation_antiq = gen_ml_computation_antiq mount_computationN;
+
+val ml_computation_conv_antiq = gen_ml_computation_antiq mount_computation_convN;
+
+fun ml_computation_check_antiq raw_ts ctxt =
+  let
+    val cTs = insert (op =) (dest_Const @{const holds}) (prep_terms ctxt raw_ts);
+  in (print_computation_check ctxt, register_computation cTs @{typ prop} ctxt) end;
 
 end; (*local*)
 
@@ -641,6 +694,16 @@
     (Scan.repeat Args.term --| Scan.lift (Args.$$$ "::") -- Args.typ)
        (fn _ => ml_computation_antiq));
 
+val _ =
+  Theory.setup (ML_Antiquotation.declaration @{binding computation_conv}
+    (Scan.repeat Args.term --| Scan.lift (Args.$$$ "::") -- Args.typ)
+       (fn _ => ml_computation_conv_antiq));
+
+val _ =
+  Theory.setup (ML_Antiquotation.declaration @{binding computation_check}
+    (Scan.repeat Args.term) 
+       (fn _ => ml_computation_check_antiq));
+
 local
 
 val parse_datatype =