generalized alternative functions to alternative compilation to handle arithmetic functions better
authorbulwahn
Mon, 29 Mar 2010 17:30:54 +0200
changeset 36038 385f706eff24
parent 36037 b1b21a8f6362
child 36039 affb6e1041e1
generalized alternative functions to alternative compilation to handle arithmetic functions better
src/HOL/Tools/Predicate_Compile/predicate_compile_aux.ML
src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML
--- a/src/HOL/Tools/Predicate_Compile/predicate_compile_aux.ML	Mon Mar 29 17:30:54 2010 +0200
+++ b/src/HOL/Tools/Predicate_Compile/predicate_compile_aux.ML	Mon Mar 29 17:30:54 2010 +0200
@@ -427,12 +427,11 @@
   | _ => error "equals_conv"  
 *)
 
-(* Different options for compiler *)
+(* Different compilations *)
 
 datatype compilation = Pred | Depth_Limited | Random | Depth_Limited_Random | DSeq | Annotated
   | Pos_Random_DSeq | Neg_Random_DSeq | New_Pos_Random_DSeq | New_Neg_Random_DSeq
 
-
 fun negative_compilation_of Pos_Random_DSeq = Neg_Random_DSeq
   | negative_compilation_of Neg_Random_DSeq = Pos_Random_DSeq
   | negative_compilation_of New_Pos_Random_DSeq = New_Neg_Random_DSeq
@@ -455,7 +454,7 @@
   | Neg_Random_DSeq => "neg_random_dseq"
   | New_Pos_Random_DSeq => "new_pos_random dseq"
   | New_Neg_Random_DSeq => "new_neg_random_dseq"
-  
+
 val compilation_names = [("pred", Pred),
   ("random", Random),
   ("depth_limited", Depth_Limited),
@@ -463,7 +462,15 @@
   (*("annotated", Annotated),*)
   ("dseq", DSeq), ("random_dseq", Pos_Random_DSeq),
   ("new_random_dseq", New_Pos_Random_DSeq)]
-  
+
+val non_random_compilations = [Pred, Depth_Limited, DSeq, Annotated]
+
+
+val random_compilations = [Random, Depth_Limited_Random,
+  Pos_Random_DSeq, Neg_Random_DSeq, New_Pos_Random_DSeq, New_Neg_Random_DSeq]
+
+(* Different options for compiler *)
+
 (*datatype compilation_options =
   Pred | Random of int | Depth_Limited of int | DSeq of int | Annotated*)
 
--- a/src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML	Mon Mar 29 17:30:54 2010 +0200
+++ b/src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML	Mon Mar 29 17:30:54 2010 +0200
@@ -28,13 +28,29 @@
   val intros_of : theory -> string -> thm list
   val add_intro : thm -> theory -> theory
   val set_elim : thm -> theory -> theory
+  datatype compilation_funs = CompilationFuns of {
+    mk_predT : typ -> typ,
+    dest_predT : typ -> typ,
+    mk_bot : typ -> term,
+    mk_single : term -> term,
+    mk_bind : term * term -> term,
+    mk_sup : term * term -> term,
+    mk_if : term -> term,
+    mk_not : term -> term,
+    mk_map : typ -> typ -> term -> term -> term
+  };
   val register_alternative_function : string -> Predicate_Compile_Aux.mode -> string -> theory -> theory
-  val alternative_function_of : theory -> string -> Predicate_Compile_Aux.mode -> string option
+  val alternative_compilation_of : theory -> string -> Predicate_Compile_Aux.mode ->
+    (compilation_funs -> typ -> term) option
+  val functional_compilation : string -> Predicate_Compile_Aux.mode -> compilation_funs -> typ -> term
+  val force_modes_and_functions : string ->
+    (Predicate_Compile_Aux.mode * (string * bool)) list -> theory -> theory
+  val force_modes_and_compilations : string ->
+    (Predicate_Compile_Aux.mode * ((compilation_funs -> typ -> term) * bool)) list -> theory -> theory
   val preprocess_intro : theory -> thm -> thm
   val print_stored_rules : theory -> unit
   val print_all_modes : Predicate_Compile_Aux.compilation -> theory -> unit
   val mk_casesrule : Proof.context -> term -> thm list -> term
-  
   val eval_ref : (unit -> term Predicate.pred) option Unsynchronized.ref
   val random_eval_ref : (unit -> int * int -> term Predicate.pred * (int * int))
     option Unsynchronized.ref
@@ -48,22 +64,8 @@
     (unit -> int -> int -> int * int -> int -> (term * int) Lazy_Sequence.lazy_sequence)
       option Unsynchronized.ref
   val code_pred_intro_attrib : attribute
-  
   (* used by Quickcheck_Generator *) 
   (* temporary for testing of the compilation *)
-  
-  datatype compilation_funs = CompilationFuns of {
-    mk_predT : typ -> typ,
-    dest_predT : typ -> typ,
-    mk_bot : typ -> term,
-    mk_single : term -> term,
-    mk_bind : term * term -> term,
-    mk_sup : term * term -> term,
-    mk_if : term -> term,
-    mk_not : term -> term,
-    mk_map : typ -> typ -> term -> term -> term
-  };
-  
   val pred_compfuns : compilation_funs
   val randompred_compfuns : compilation_funs
   val new_randompred_compfuns : compilation_funs
@@ -722,22 +724,6 @@
     PredData.map (Graph.map_node name (map_pred_data set))
   end
 
-(* registration of alternative function names *)
-
-structure Alt_Names_Data = Theory_Data
-(
-  type T = (mode * string) list Symtab.table;
-  val empty = Symtab.empty;
-  val extend = I;
-  val merge = Symtab.merge (op =);
-);
-
-fun register_alternative_function pred_name mode fun_name =
-  Alt_Names_Data.map (Symtab.insert_list (op =) (pred_name, (mode, fun_name)))
-
-fun alternative_function_of thy pred_name mode =
-  AList.lookup eq_mode (Symtab.lookup_list (Alt_Names_Data.get thy) pred_name) mode
-
 (* datastructures and setup for generic compilation *)
 
 datatype compilation_funs = CompilationFuns of {
@@ -762,6 +748,56 @@
 fun mk_not (CompilationFuns funs) = #mk_not funs
 fun mk_map (CompilationFuns funs) = #mk_map funs
 
+(* registration of alternative function names *)
+
+structure Alt_Compilations_Data = Theory_Data
+(
+  type T = (mode * (compilation_funs -> typ -> term)) list Symtab.table;
+  val empty = Symtab.empty;
+  val extend = I;
+  val merge = Symtab.merge (K true);
+);
+
+fun alternative_compilation_of thy pred_name mode =
+  AList.lookup eq_mode (Symtab.lookup_list (Alt_Compilations_Data.get thy) pred_name) mode
+
+fun force_modes_and_compilations pred_name compilations =
+  let
+    (* thm refl is a dummy thm *)
+    val modes = map fst compilations
+    val (needs_random, non_random_modes) = pairself (map fst)
+      (List.partition (fn (m, (fun_name, random)) => random) compilations)
+    val non_random_dummys = map (rpair "dummy") non_random_modes
+    val all_dummys = map (rpair "dummy") modes
+    val dummy_function_names = map (rpair all_dummys) Predicate_Compile_Aux.random_compilations
+      @ map (rpair non_random_dummys) Predicate_Compile_Aux.non_random_compilations
+    val alt_compilations = map (apsnd fst) compilations
+  in
+    PredData.map (Graph.new_node
+      (pred_name, mk_pred_data (([], SOME @{thm refl}), (dummy_function_names, ([], needs_random)))))
+    #> Alt_Compilations_Data.map (Symtab.insert (K false) (pred_name, alt_compilations))
+  end
+
+fun functional_compilation fun_name mode compfuns T =
+  let
+    val (inpTs, outpTs) = split_map_modeT (fn _ => fn T => (SOME T, NONE))
+      mode (binder_types T)
+    val bs = map (pair "x") inpTs
+    val bounds = map Bound (rev (0 upto (length bs) - 1))
+    val f = Const (fun_name, inpTs ---> HOLogic.mk_tupleT outpTs)
+  in list_abs (bs, mk_single compfuns (list_comb (f, bounds))) end
+
+fun register_alternative_function pred_name mode fun_name =
+  Alt_Compilations_Data.map (Symtab.insert_list (eq_pair eq_mode (K false))
+    (pred_name, (mode, functional_compilation fun_name mode)))
+
+fun force_modes_and_functions pred_name fun_names =
+  force_modes_and_compilations pred_name
+    (map (fn (mode, (fun_name, random)) => (mode, (functional_compilation fun_name mode, random)))
+    fun_names)
+
+(* structures for different compilations *)
+
 structure PredicateCompFuns =
 struct
 
@@ -1906,7 +1942,7 @@
     val name' = Name.variant (name :: names) "y";
     val T = HOLogic.mk_tupleT (map fastype_of out_ts);
     val U = fastype_of success_t;
-    val U' = dest_predT compfuns U;
+    val U' = dest_predT compfuns U;        
     val v = Free (name, T);
     val v' = Free (name', T);
   in
@@ -1937,15 +1973,8 @@
         (t, Term Input) => SOME t
       | (t, Term Output) => NONE
       | (Const (name, T), Context mode) =>
-        (case alternative_function_of (ProofContext.theory_of ctxt) name mode of
-          SOME alt_function_name =>
-            let
-              val (inpTs, outpTs) = split_map_modeT (fn _ => fn T => (SOME T, NONE))
-                mode (binder_types T)
-              val bs = map (pair "x") inpTs
-              val bounds = map Bound (rev (0 upto (length bs) - 1))
-              val f = Const (alt_function_name, inpTs ---> HOLogic.mk_tupleT outpTs)
-            in SOME (list_abs (bs, mk_single compfuns (list_comb (f, bounds)))) end
+        (case alternative_compilation_of (ProofContext.theory_of ctxt) name mode of
+          SOME alt_comp => SOME (alt_comp compfuns T)
         | NONE =>
           SOME (Const (function_name_of (Comp_Mod.compilation compilation_modifiers)
             (ProofContext.theory_of ctxt) name mode,