extended syntax allows to include datatype constructors directly in computations
authorhaftmann
Mon, 06 Feb 2017 20:56:36 +0100
changeset 64992 41e2c3617582
parent 64991 d2c79b16e133
child 64993 4fb84597ec5a
extended syntax allows to include datatype constructors directly in computations
src/HOL/ex/Computations.thy
src/Tools/Code/code_runtime.ML
--- a/src/HOL/ex/Computations.thy	Mon Feb 06 20:56:35 2017 +0100
+++ b/src/HOL/ex/Computations.thy	Mon Feb 06 20:56:36 2017 +0100
@@ -28,18 +28,25 @@
 
 in
 
-val comp_nat = @{computation "0 :: nat" Suc
-  "plus :: nat \<Rightarrow>_" "times :: nat \<Rightarrow> _" fib :: nat}
+val comp_nat = @{computation nat
+  terms: "plus :: nat \<Rightarrow>_" "times :: nat \<Rightarrow> _" fib
+  datatypes: nat}
+  (fn post => post o HOLogic.mk_nat o int_of_nat o the);
+
+val comp_numeral = @{computation nat
+  terms: "0 :: nat" "1 :: nat" "2 :: nat" "3 :: nat"}
   (fn post => post o HOLogic.mk_nat o int_of_nat o the);
 
-val comp_numeral = @{computation "0 :: nat" "1 :: nat" "2 :: nat" "3 :: nat" :: nat}
-  (fn post => post o HOLogic.mk_nat o int_of_nat o the);
-
-val comp_bool = @{computation True False HOL.conj HOL.disj HOL.implies
-  HOL.iff even "less_eq :: nat \<Rightarrow> _" "less :: nat \<Rightarrow> _" "HOL.eq :: nat \<Rightarrow> _" :: bool}
+val comp_bool = @{computation bool
+  terms: HOL.conj HOL.disj HOL.implies
+    HOL.iff even "less_eq :: nat \<Rightarrow> _" "less :: nat \<Rightarrow> _" "HOL.eq :: nat \<Rightarrow> _"
+  datatypes: bool}
   (K the);
 
-val comp_check = @{computation_check Trueprop};
+val comp_check = @{computation_check terms: Trueprop};
+
+val comp_dummy = @{computation "(nat \<times> unit) option"
+  datatypes: "(nat \<times> unit) option"}
 
 end
 \<close>
--- a/src/Tools/Code/code_runtime.ML	Mon Feb 06 20:56:35 2017 +0100
+++ b/src/Tools/Code/code_runtime.ML	Mon Feb 06 20:56:36 2017 +0100
@@ -573,14 +573,32 @@
       Long_Name.append prfx (of_term_for_typ @{typ prop})
     ]) ctxt;
 
-fun prep_terms ctxt raw_ts =
+
+fun add_all_constrs ctxt (dT as Type (tyco, Ts)) =
+  let
+    val ((vs, constrs), _) = Code.get_type (Proof_Context.theory_of ctxt) tyco;
+    val subst_TFree = the o AList.lookup (op =) (map fst vs ~~ Ts);
+    val cs = map (fn (c, (_, Ts')) =>
+      (c, (map o map_atyps) (fn TFree (v, _) => subst_TFree v) Ts'
+        ---> dT)) constrs;
+  in
+    union (op =) cs
+    #> fold (add_all_constrs ctxt) Ts
+  end;
+
+fun prep_spec ctxt (raw_ts, raw_dTs) =
   let
     val ts = map (Syntax.check_term ctxt) raw_ts;
+    val dTs = map (Syntax.check_typ ctxt) raw_dTs;
   in
-    (fold o fold_aterms)
+    []
+    |> (fold o fold_aterms)
       (fn (t as Const (cT as (_, T))) =>
         if not (monomorphic T) then error ("Polymorphic constant: " ^ Syntax.string_of_term ctxt t)
-        else insert (op =) cT | _ => I) ts []
+        else insert (op =) cT | _ => I) ts
+    |> fold (fn dT =>
+        if not (monomorphic dT) then error ("Polymorphic datatype: " ^ Syntax.string_of_typ ctxt dT)
+        else add_all_constrs ctxt dT) dTs
   end;
 
 in
@@ -591,9 +609,9 @@
     val const = Code.check_const thy raw_const;
   in (print_code ctxt const, register_const const ctxt) end;
 
-fun gen_ml_computation_antiq kind (raw_ts, raw_T) ctxt =
+fun gen_ml_computation_antiq kind (raw_T, raw_spec) ctxt =
   let
-    val cTs = prep_terms ctxt raw_ts;
+    val cTs = prep_spec ctxt raw_spec;
     val T = Syntax.check_typ ctxt raw_T;
     val _ = if not (monomorphic T)
       then error ("Polymorphic type: " ^ Syntax.string_of_typ ctxt T)
@@ -604,9 +622,9 @@
 
 val ml_computation_conv_antiq = gen_ml_computation_antiq mount_computation_convN;
 
-fun ml_computation_check_antiq raw_ts ctxt =
+fun ml_computation_check_antiq raw_spec ctxt =
   let
-    val cTs = insert (op =) (dest_Const @{const holds}) (prep_terms ctxt raw_ts);
+    val cTs = insert (op =) (dest_Const @{const holds}) (prep_spec ctxt raw_spec);
   in (print_computation_check ctxt, register_computation cTs @{typ prop} ctxt) end;
 
 end; (*local*)
@@ -701,25 +719,35 @@
 
 (** Isar setup **)
 
+local
+
+val parse_consts_spec =
+  Scan.optional (Scan.lift (Args.$$$ "terms" -- Args.colon) |-- Scan.repeat1 Args.term) []
+  -- Scan.optional (Scan.lift (Args.$$$ "datatypes"  -- Args.colon) |-- Scan.repeat1 Args.typ) [];
+
+in
+
 val _ =
   Theory.setup (ML_Antiquotation.declaration @{binding code}
     Args.term (fn _ => ml_code_antiq));
 
 val _ =
   Theory.setup (ML_Antiquotation.declaration @{binding computation}
-    (Scan.repeat Args.term --| Scan.lift (Args.$$$ "::") -- Args.typ)
+    (Args.typ -- parse_consts_spec)
        (fn _ => ml_computation_antiq));
 
 val _ =
   Theory.setup (ML_Antiquotation.declaration @{binding computation_conv}
-    (Scan.repeat Args.term --| Scan.lift (Args.$$$ "::") -- Args.typ)
+    (Args.typ -- parse_consts_spec)
        (fn _ => ml_computation_conv_antiq));
 
 val _ =
   Theory.setup (ML_Antiquotation.declaration @{binding computation_check}
-    (Scan.repeat Args.term) 
+    parse_consts_spec 
        (fn _ => ml_computation_check_antiq));
 
+end;
+
 local
 
 val parse_datatype =