explicit a-priori detection of unsuitable terms for computations
authorhaftmann
Tue, 24 Jan 2017 22:29:36 +0100
changeset 64943 c5618df67c2a
parent 64942 bae35a568b1b
child 64944 27b1ba3ef778
explicit a-priori detection of unsuitable terms for computations
src/Tools/Code/code_runtime.ML
--- a/src/Tools/Code/code_runtime.ML	Tue Jan 24 09:39:21 2017 +0100
+++ b/src/Tools/Code/code_runtime.ML	Tue Jan 24 22:29:36 2017 +0100
@@ -28,7 +28,7 @@
     -> Proof.context -> term -> 'a Exn.result
   val dynamic_holds_conv: Proof.context -> conv
   val static_holds_conv: { ctxt: Proof.context, consts: string list } -> Proof.context -> conv
-  val experimental_computation: (Proof.context -> term -> 'a) cookie
+  val experimental_computation: (term -> 'a) cookie
     -> { ctxt: Proof.context, lift_postproc: (term -> term) -> 'a -> 'a,
            terms: term list, T: typ }
     -> Proof.context -> term -> 'a (*EXPERIMENTAL!*)
@@ -196,6 +196,8 @@
 
 (** computations -- experimental! **)
 
+fun monomorphic T = fold_atyps ((K o K) false) T true;
+
 fun typ_signatures_for T =
   let
     val (Ts, T') = strip_type T;
@@ -217,23 +219,20 @@
     val var_names = map_range (fn n => "t" ^ string_of_int (n + 1));
     fun print_lhs c xs = "Const (" ^ quote c ^ ", _)"
       |> fold (fn x => fn s => s ^ " $ " ^ x) xs
-      |> enclose "(" ")"
-      |> prefix "ctxt ";
+      |> enclose "(" ")";
     fun print_rhs c Ts T xs = eval_for_const (c, Ts ---> T)
       |> fold2 (fn T' => fn x => fn s =>
-         s ^ (" (" ^ of_term_for_typ T' ^ " ctxt " ^ x ^ ")")) Ts xs
+         s ^ (" (" ^ of_term_for_typ T' ^ " " ^ x ^ ")")) Ts xs
     fun print_eq T (c, Ts) =
       let
         val xs = var_names (length Ts);
       in print_lhs c xs ^ " = " ^ print_rhs c Ts T xs end;
-    val err_eq =
-      "ctxt t = error (" ^ quote "Bad term: " ^ " ^ Syntax.string_of_term ctxt t)";
     fun print_eqs T =
       let
         val typ_signs = typ_sign_for T;
         val name = of_term_for_typ T;
       in
-        (map (print_eq T) typ_signs @ [err_eq])
+        map (print_eq T) typ_signs
         |> map (prefix (name ^ " "))
         |> space_implode "\n  | "
       end;
@@ -306,6 +305,27 @@
     ], map (prefix (generated_computationN ^ ".")) of_terms)
   end;
 
+fun check_typ ctxt T t =
+  Syntax.check_term ctxt (Type.constraint T t);
+
+fun check_computation_input ctxt cTs t =
+  let
+    fun check t = check_comb (strip_comb t)
+    and check_comb (t as Abs _, _) =
+          error ("Bad term, contains abstraction: " ^ Syntax.string_of_term ctxt t)
+      | check_comb (t as Const (cT as (c, T)), ts) =
+          let
+            val _ = if not (member (op =) cTs cT)
+              then error ("Bad term, computation cannot proceed on constant " ^ Syntax.string_of_term ctxt t)
+              else ();
+            val _ = if not (monomorphic T)
+              then error ("Bad term, contains polymorphic constant " ^ Syntax.string_of_term ctxt t)
+              else ();
+            val _ = map check ts;
+          in () end;
+    val _ = check t;
+  in t end;
+
 fun compile_computation cookie ctxt T program evals vs_ty_evals deps =
   let
     val raw_cTs = case evals of
@@ -325,12 +345,18 @@
     val compiled_computation =
       Exn.release (run_computation_text cookie ctxt comp' vs_ty_evals []);
   in fn ctxt' =>
-    compiled_computation ctxt' o reject_vars ctxt' o Syntax.check_term ctxt' o Type.constraint T
+    check_typ ctxt' T
+    #> reject_vars ctxt'
+    #> check_computation_input ctxt cTs
+    #> compiled_computation
   end;
 
 fun experimental_computation cookie { ctxt, lift_postproc, terms = ts, T } =
   let
-    val cTs = (fold o fold_aterms) (fn Const cT => insert (op =) cT | _ => I) ts [];
+    val cTs = (fold o fold_aterms)
+      (fn (t as Const (cT as (_, T))) =>
+        if not (monomorphic T) then error ("Polymorphic constant:" ^ Syntax.string_of_term ctxt t)
+        else insert (op =) cT | _ => I) ts [];
     val vT = TFree (singleton (Name.variant_list
       (fold (fn (_, T) => fold_atyps (fn TFree (v, _) => insert (op =) v | _ => I)
         T) cTs [])) Name.aT, []);