src/HOL/Tools/Predicate_Compile/predicate_compile_aux.ML
changeset 33623 4ec42d38224f
parent 33620 b6bf2dc5aed7
child 33626 42f69386943a
--- a/src/HOL/Tools/Predicate_Compile/predicate_compile_aux.ML	Thu Nov 12 09:10:37 2009 +0100
+++ b/src/HOL/Tools/Predicate_Compile/predicate_compile_aux.ML	Thu Nov 12 09:10:42 2009 +0100
@@ -37,6 +37,19 @@
 
 datatype mode' = Bool | Input | Output | Pair of mode' * mode' | Fun of mode' * mode'
 
+(* equality of instantiatedness with respect to equivalences:
+  Pair Input Input == Input and Pair Output Output == Output *)
+fun eq_mode' (Fun (m1, m2), Fun (m3, m4)) = eq_mode' (m1, m3) andalso eq_mode' (m2, m4)
+  | eq_mode' (Pair (m1, m2), Pair (m3, m4)) = eq_mode' (m1, m3) andalso eq_mode' (m2, m4)
+  | eq_mode' (Pair (m1, m2), Input) = eq_mode' (m1, Input) andalso eq_mode' (m2, Input)
+  | eq_mode' (Pair (m1, m2), Output) = eq_mode' (m1, Output) andalso eq_mode' (m2, Output)
+  | eq_mode' (Input, Pair (m1, m2)) = eq_mode' (Input, m1) andalso eq_mode' (Input, m2)
+  | eq_mode' (Output, Pair (m1, m2)) = eq_mode' (Output, m1) andalso eq_mode' (Output, m2)
+  | eq_mode' (Input, Input) = true
+  | eq_mode' (Output, Output) = true
+  | eq_mode' (Bool, Bool) = true
+  | eq_mode' _ = false
+
 (* name: binder_modes? *)
 fun strip_fun_mode (Fun (mode, mode')) = mode :: strip_fun_mode mode'
   | strip_fun_mode Bool = []
@@ -60,7 +73,6 @@
       | string_of_mode3 mode = string_of_mode2 mode
   in string_of_mode3 mode' end
 
-
 fun translate_mode T (iss, is) =
   let
     val Ts = binder_types T
@@ -86,45 +98,31 @@
     mk_mode (param_modes @ translate_smode Ts2 is)
   end;
 
+fun translate_mode' nparams mode' =
+  let
+    fun err () = error "translate_mode': given mode cannot be translated"
+    val (m1, m2) = chop nparams (strip_fun_mode mode')
+    val translate_to_tupled_mode =
+      (map_filter I) o (map_index (fn (i, m) =>
+        if eq_mode' (m, Input) then SOME (i + 1)
+        else if eq_mode' (m, Output) then NONE
+        else err ()))
+    val translate_to_smode =
+      (map_filter I) o (map_index (fn (i, m) =>
+        if eq_mode' (m, Input) then SOME (i + 1, NONE)
+        else if eq_mode' (m, Output) then NONE
+        else SOME (i + 1, SOME (translate_to_tupled_mode (dest_tuple_mode m)))))
+    fun translate_to_param_mode m =
+      case rev (dest_fun_mode m) of
+        Bool :: _ :: _ => SOME (translate_to_smode (strip_fun_mode m))
+      | _ => if eq_mode' (m, Input) then NONE else err ()
+  in
+    (map translate_to_param_mode m1, translate_to_smode m2)
+  end
+
 fun string_of_mode thy constname mode =
   string_of_mode' (translate_mode (Sign.the_const_type thy constname) mode)
 
-fun eq_mode' (Fun (m1, m2), Fun (m3, m4)) = eq_mode' (m1, m3) andalso eq_mode' (m2, m4)
-  | eq_mode' (Pair (m1, m2), Pair (m3, m4)) = eq_mode' (m1, m3) andalso eq_mode' (m2, m4)
-  | eq_mode' (Pair (m1, m2), Input) = eq_mode' (m1, Input) andalso eq_mode' (m2, Input)
-  | eq_mode' (Pair (m1, m2), Output) = eq_mode' (m1, Output) andalso eq_mode' (m2, Output)
-  | eq_mode' (Input, Pair (m1, m2)) = eq_mode' (Input, m1) andalso eq_mode' (Input, m2)
-  | eq_mode' (Output, Pair (m1, m2)) = eq_mode' (Output, m1) andalso eq_mode' (Output, m2)
-  | eq_mode' (Input, Input) = true
-  | eq_mode' (Output, Output) = true
-  | eq_mode' (Bool, Bool) = true
-  | eq_mode' _ = false
-(* FIXME: remove! *)
-fun eq_mode'_mode (mode', (iss, is)) =
-  let
-    val arg_modes = strip_fun_mode mode'
-    val (arg_modes1, arg_modes2) = chop (length iss) arg_modes
-    fun eq_arg Input NONE = true
-      | eq_arg _ NONE = false
-      | eq_arg mode (SOME is) =
-        let
-          val modes = dest_tuple_mode mode
-        in
-          forall (fn i => nth modes (i - 1) = Input) is
-            andalso forall (fn i => nth modes (i - 1) = Output)
-              (subtract (op =) is (1 upto length modes))
-        end
-    fun eq_mode'_smode mode' is =
-      forall (fn (i, t) => eq_arg (nth mode' (i - 1)) t) is
-        andalso forall (fn i => (nth mode' (i - 1) = Output))
-          (subtract (op =) (map fst is) (1 upto length mode'))
-  in
-    forall (fn (m, NONE) => m = Input | (m, SOME is) => eq_mode'_smode (strip_fun_mode m) is)
-      (arg_modes1 ~~ iss)
-    andalso eq_mode'_smode arg_modes2 is
-  end
-
-
 (* general syntactic functions *)
 
 (*Like dest_conj, but flattens conjunctions however nested*)
@@ -133,8 +131,6 @@
 
 fun conjuncts t = conjuncts_aux t [];
 
-(* syntactic functions *)
-
 fun is_equationlike_term (Const ("==", _) $ _ $ _) = true
   | is_equationlike_term (Const ("Trueprop", _) $ (Const ("op =", _) $ _ $ _)) = true
   | is_equationlike_term _ = false
@@ -178,6 +174,8 @@
   in nparams end;
 
 (*** check if a term contains only constructor functions ***)
+(* FIXME: constructor terms are supposed to be seen in the way the code generator
+  sees constructors.*)
 fun is_constrt thy =
   let
     val cnstrs = flat (maps
@@ -266,7 +264,8 @@
 
 datatype options = Options of {  
   expected_modes : (string * mode' list) option,
-  user_proposals : ((string * mode') * string) list,
+  proposed_modes : (string * mode' list) list,
+  proposed_names : ((string * mode') * string) list,
   show_steps : bool,
   show_proof_trace : bool,
   show_intermediate_results : bool,
@@ -282,8 +281,9 @@
 };
 
 fun expected_modes (Options opt) = #expected_modes opt
-fun user_proposal (Options opt) name mode = AList.lookup (eq_pair (op =) eq_mode')
-  (#user_proposals opt) (name, mode)
+fun proposed_modes (Options opt) name = AList.lookup (op =) (#proposed_modes opt) name
+fun proposed_names (Options opt) name mode = AList.lookup (eq_pair (op =) eq_mode')
+  (#proposed_names opt) (name, mode)
 
 fun show_steps (Options opt) = #show_steps opt
 fun show_intermediate_results (Options opt) = #show_intermediate_results opt
@@ -300,7 +300,8 @@
 
 val default_options = Options {
   expected_modes = NONE,
-  user_proposals = [],
+  proposed_modes = [],
+  proposed_names = [],
   show_steps = false,
   show_intermediate_results = false,
   show_proof_trace = false,