src/HOL/Tools/Predicate_Compile/predicate_compile_aux.ML
changeset 33619 d93a3cb55068
parent 33473 3b275a0bf18c
child 33620 b6bf2dc5aed7
--- a/src/HOL/Tools/Predicate_Compile/predicate_compile_aux.ML	Thu Nov 12 09:10:07 2009 +0100
+++ b/src/HOL/Tools/Predicate_Compile/predicate_compile_aux.ML	Thu Nov 12 09:10:16 2009 +0100
@@ -21,6 +21,7 @@
       (fn (i, is) =>
         string_of_int i ^ (case is of NONE => ""
     | SOME is => "p" ^ enclose "[" "]" (commas (map string_of_int is)))) js)
+(* FIXME: remove! *)
 
 fun string_of_mode (iss, is) = space_implode " -> " (map
   (fn NONE => "X"
@@ -28,10 +29,102 @@
        (iss @ [SOME is]));
 
 fun string_of_tmode (Mode (predmode, termmode, param_modes)) =
-  "predmode: " ^ (string_of_mode predmode) ^ 
+  "predmode: " ^ (string_of_mode predmode) ^
   (if null param_modes then "" else
     "; " ^ "params: " ^ commas (map (the_default "NONE" o Option.map string_of_tmode) param_modes))
 
+(* new datatype for mode *)
+
+datatype mode' = Bool | Input | Output | Pair of mode' * mode' | Fun of mode' * mode'
+
+(* name: binder_modes? *)
+fun strip_fun_mode (Fun (mode, mode')) = mode :: strip_fun_mode mode'
+  | strip_fun_mode Bool = []
+  | strip_fun_mode _ = error "Bad mode for strip_fun_mode"
+
+fun dest_fun_mode (Fun (mode, mode')) = mode :: dest_fun_mode mode'
+  | dest_fun_mode mode = [mode]
+
+fun dest_tuple_mode (Pair (mode, mode')) = mode :: dest_tuple_mode mode'
+  | dest_tuple_mode _ = []
+
+fun string_of_mode' mode' =
+  let
+    fun string_of_mode1 Input = "i"
+      | string_of_mode1 Output = "o"
+      | string_of_mode1 Bool = "bool"
+      | string_of_mode1 mode = "(" ^ (string_of_mode3 mode) ^ ")"
+    and string_of_mode2 (Pair (m1, m2))  = string_of_mode3 m1 ^ " * " ^  string_of_mode2 m2
+      | string_of_mode2 mode = string_of_mode1 mode
+    and string_of_mode3 (Fun (m1, m2)) = string_of_mode2 m1 ^ " => " ^ string_of_mode3 m2
+      | string_of_mode3 mode = string_of_mode2 mode
+  in string_of_mode3 mode' end
+
+
+fun translate_mode T (iss, is) =
+  let
+    val Ts = binder_types T
+    val (Ts1, Ts2) = chop (length iss) Ts
+    fun translate_smode Ts is =
+      let
+        fun translate_arg (i, T) =
+          case AList.lookup (op =) is (i + 1) of
+            SOME NONE => Input
+          | SOME (SOME its) =>
+            let
+              fun translate_tuple (i, T) = if member (op =) its (i + 1) then Input else Output
+            in 
+              foldr1 Pair (map_index translate_tuple (HOLogic.strip_tupleT T))
+            end
+          | NONE => Output
+      in map_index translate_arg Ts end
+    fun mk_mode arg_modes = foldr1 Fun (arg_modes @ [Bool])
+    val param_modes =
+      map (fn (T, NONE) => Input | (T, SOME is) => mk_mode (translate_smode (binder_types T) is))
+        (Ts1 ~~ iss)
+  in
+    mk_mode (param_modes @ translate_smode Ts2 is)
+  end;
+
+fun string_of_mode thy constname mode =
+  string_of_mode' (translate_mode (Sign.the_const_type thy constname) mode)
+
+fun eq_mode' (Fun (m1, m2), Fun (m3, m4)) = eq_mode' (m1, m3) andalso eq_mode' (m2, m4)
+  | eq_mode' (Pair (m1, m2), Pair (m3, m4)) = eq_mode' (m1, m3) andalso eq_mode' (m2, m4)
+  | eq_mode' (Pair (m1, m2), Input) = eq_mode' (m1, Input) andalso eq_mode' (m2, Input)
+  | eq_mode' (Pair (m1, m2), Output) = eq_mode' (m1, Output) andalso eq_mode' (m2, Output)
+  | eq_mode' (Input, Pair (m1, m2)) = eq_mode' (Input, m1) andalso eq_mode' (Input, m2)
+  | eq_mode' (Output, Pair (m1, m2)) = eq_mode' (Output, m1) andalso eq_mode' (Output, m2)
+  | eq_mode' (Input, Input) = true
+  | eq_mode' (Output, Output) = true
+  | eq_mode' (Bool, Bool) = true
+  | eq_mode' _ = false
+(* FIXME: remove! *)
+fun eq_mode'_mode (mode', (iss, is)) =
+  let
+    val arg_modes = strip_fun_mode mode'
+    val (arg_modes1, arg_modes2) = chop (length iss) arg_modes
+    fun eq_arg Input NONE = true
+      | eq_arg _ NONE = false
+      | eq_arg mode (SOME is) =
+        let
+          val modes = dest_tuple_mode mode
+        in
+          forall (fn i => nth modes (i - 1) = Input) is
+            andalso forall (fn i => nth modes (i - 1) = Output)
+              (subtract (op =) is (1 upto length modes))
+        end
+    fun eq_mode'_smode mode' is =
+      forall (fn (i, t) => eq_arg (nth mode' (i - 1)) t) is
+        andalso forall (fn i => (nth mode' (i - 1) = Output))
+          (subtract (op =) (map fst is) (1 upto length mode'))
+  in
+    forall (fn (m, NONE) => m = Input | (m, SOME is) => eq_mode'_smode (strip_fun_mode m) is)
+      (arg_modes1 ~~ iss)
+    andalso eq_mode'_smode arg_modes2 is
+  end
+
+
 (* general syntactic functions *)
 
 (*Like dest_conj, but flattens conjunctions however nested*)
@@ -70,7 +163,20 @@
 fun is_predT (T as Type("fun", [_, _])) = (snd (strip_type T) = HOLogic.boolT)
   | is_predT _ = false
 
-  
+(* guessing number of parameters *)
+fun find_indexes pred xs =
+  let
+    fun find is n [] = is
+      | find is n (x :: xs) = find (if pred x then (n :: is) else is) (n + 1) xs;
+  in rev (find [] 0 xs) end;
+
+fun guess_nparams T =
+  let
+    val argTs = binder_types T
+    val nparams = fold Integer.max
+      (map (fn x => x + 1) (find_indexes is_predT argTs)) 0
+  in nparams end;
+
 (*** check if a term contains only constructor functions ***)
 fun is_constrt thy =
   let
@@ -159,7 +265,7 @@
 (* Different options for compiler *)
 
 datatype options = Options of {  
-  expected_modes : (string * mode list) option,
+  expected_modes : (string * mode' list) option,
   show_steps : bool,
   show_proof_trace : bool,
   show_intermediate_results : bool,