first steps towards a new mode datastructure; new syntax for mode annotations and new output of modes
authorbulwahn
Thu, 12 Nov 2009 09:10:16 +0100
changeset 33619 d93a3cb55068
parent 33618 d8359a16e0c5
child 33620 b6bf2dc5aed7
first steps towards a new mode datastructure; new syntax for mode annotations and new output of modes
src/HOL/Tools/Predicate_Compile/predicate_compile.ML
src/HOL/Tools/Predicate_Compile/predicate_compile_aux.ML
src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML
--- a/src/HOL/Tools/Predicate_Compile/predicate_compile.ML	Thu Nov 12 09:10:07 2009 +0100
+++ b/src/HOL/Tools/Predicate_Compile/predicate_compile.ML	Thu Nov 12 09:10:16 2009 +0100
@@ -10,7 +10,7 @@
   val preprocess : Predicate_Compile_Aux.options -> string -> theory -> theory
 end;
 
-structure Predicate_Compile : PREDICATE_COMPILE =
+structure Predicate_Compile (*: PREDICATE_COMPILE*) =
 struct
 
 (* options *)
@@ -183,9 +183,24 @@
 
 val parse_mode' = gen_parse_mode parse_smode'
 
+(* New parser for modes *)
+
+(* grammar:
+E = T "=>" E | T
+T = F * T | F
+F = i | o | bool | ( E )
+*)
+fun new_parse_mode1 xs =
+  (Args.$$$ "i" >> K Input || Args.$$$ "o" >> K Output ||
+    Args.$$$ "bool" >> K Bool || Args.$$$ "(" |-- new_parse_mode3 --| Args.$$$ ")") xs
+and new_parse_mode2 xs =
+  (new_parse_mode1 --| Args.$$$ "*" -- new_parse_mode2 >> Pair || new_parse_mode1) xs
+and new_parse_mode3 xs =
+  (new_parse_mode2 --| Args.$$$ "=>" -- new_parse_mode3 >> Fun || new_parse_mode2) xs
+
 val opt_modes =
   Scan.optional (P.$$$ "(" |-- Args.$$$ "mode" |-- P.$$$ ":" |--
-    P.enum1 "," (parse_mode || parse_mode') --| P.$$$ ")" >> SOME) NONE
+    P.enum1 "," new_parse_mode3 --| P.$$$ ")" >> SOME) NONE
 
 (* Parser for options *)
 
--- a/src/HOL/Tools/Predicate_Compile/predicate_compile_aux.ML	Thu Nov 12 09:10:07 2009 +0100
+++ b/src/HOL/Tools/Predicate_Compile/predicate_compile_aux.ML	Thu Nov 12 09:10:16 2009 +0100
@@ -21,6 +21,7 @@
       (fn (i, is) =>
         string_of_int i ^ (case is of NONE => ""
     | SOME is => "p" ^ enclose "[" "]" (commas (map string_of_int is)))) js)
+(* FIXME: remove! *)
 
 fun string_of_mode (iss, is) = space_implode " -> " (map
   (fn NONE => "X"
@@ -28,10 +29,102 @@
        (iss @ [SOME is]));
 
 fun string_of_tmode (Mode (predmode, termmode, param_modes)) =
-  "predmode: " ^ (string_of_mode predmode) ^ 
+  "predmode: " ^ (string_of_mode predmode) ^
   (if null param_modes then "" else
     "; " ^ "params: " ^ commas (map (the_default "NONE" o Option.map string_of_tmode) param_modes))
 
+(* new datatype for mode *)
+
+datatype mode' = Bool | Input | Output | Pair of mode' * mode' | Fun of mode' * mode'
+
+(* name: binder_modes? *)
+fun strip_fun_mode (Fun (mode, mode')) = mode :: strip_fun_mode mode'
+  | strip_fun_mode Bool = []
+  | strip_fun_mode _ = error "Bad mode for strip_fun_mode"
+
+fun dest_fun_mode (Fun (mode, mode')) = mode :: dest_fun_mode mode'
+  | dest_fun_mode mode = [mode]
+
+fun dest_tuple_mode (Pair (mode, mode')) = mode :: dest_tuple_mode mode'
+  | dest_tuple_mode _ = []
+
+fun string_of_mode' mode' =
+  let
+    fun string_of_mode1 Input = "i"
+      | string_of_mode1 Output = "o"
+      | string_of_mode1 Bool = "bool"
+      | string_of_mode1 mode = "(" ^ (string_of_mode3 mode) ^ ")"
+    and string_of_mode2 (Pair (m1, m2))  = string_of_mode3 m1 ^ " * " ^  string_of_mode2 m2
+      | string_of_mode2 mode = string_of_mode1 mode
+    and string_of_mode3 (Fun (m1, m2)) = string_of_mode2 m1 ^ " => " ^ string_of_mode3 m2
+      | string_of_mode3 mode = string_of_mode2 mode
+  in string_of_mode3 mode' end
+
+
+fun translate_mode T (iss, is) =
+  let
+    val Ts = binder_types T
+    val (Ts1, Ts2) = chop (length iss) Ts
+    fun translate_smode Ts is =
+      let
+        fun translate_arg (i, T) =
+          case AList.lookup (op =) is (i + 1) of
+            SOME NONE => Input
+          | SOME (SOME its) =>
+            let
+              fun translate_tuple (i, T) = if member (op =) its (i + 1) then Input else Output
+            in 
+              foldr1 Pair (map_index translate_tuple (HOLogic.strip_tupleT T))
+            end
+          | NONE => Output
+      in map_index translate_arg Ts end
+    fun mk_mode arg_modes = foldr1 Fun (arg_modes @ [Bool])
+    val param_modes =
+      map (fn (T, NONE) => Input | (T, SOME is) => mk_mode (translate_smode (binder_types T) is))
+        (Ts1 ~~ iss)
+  in
+    mk_mode (param_modes @ translate_smode Ts2 is)
+  end;
+
+fun string_of_mode thy constname mode =
+  string_of_mode' (translate_mode (Sign.the_const_type thy constname) mode)
+
+fun eq_mode' (Fun (m1, m2), Fun (m3, m4)) = eq_mode' (m1, m3) andalso eq_mode' (m2, m4)
+  | eq_mode' (Pair (m1, m2), Pair (m3, m4)) = eq_mode' (m1, m3) andalso eq_mode' (m2, m4)
+  | eq_mode' (Pair (m1, m2), Input) = eq_mode' (m1, Input) andalso eq_mode' (m2, Input)
+  | eq_mode' (Pair (m1, m2), Output) = eq_mode' (m1, Output) andalso eq_mode' (m2, Output)
+  | eq_mode' (Input, Pair (m1, m2)) = eq_mode' (Input, m1) andalso eq_mode' (Input, m2)
+  | eq_mode' (Output, Pair (m1, m2)) = eq_mode' (Output, m1) andalso eq_mode' (Output, m2)
+  | eq_mode' (Input, Input) = true
+  | eq_mode' (Output, Output) = true
+  | eq_mode' (Bool, Bool) = true
+  | eq_mode' _ = false
+(* FIXME: remove! *)
+fun eq_mode'_mode (mode', (iss, is)) =
+  let
+    val arg_modes = strip_fun_mode mode'
+    val (arg_modes1, arg_modes2) = chop (length iss) arg_modes
+    fun eq_arg Input NONE = true
+      | eq_arg _ NONE = false
+      | eq_arg mode (SOME is) =
+        let
+          val modes = dest_tuple_mode mode
+        in
+          forall (fn i => nth modes (i - 1) = Input) is
+            andalso forall (fn i => nth modes (i - 1) = Output)
+              (subtract (op =) is (1 upto length modes))
+        end
+    fun eq_mode'_smode mode' is =
+      forall (fn (i, t) => eq_arg (nth mode' (i - 1)) t) is
+        andalso forall (fn i => (nth mode' (i - 1) = Output))
+          (subtract (op =) (map fst is) (1 upto length mode'))
+  in
+    forall (fn (m, NONE) => m = Input | (m, SOME is) => eq_mode'_smode (strip_fun_mode m) is)
+      (arg_modes1 ~~ iss)
+    andalso eq_mode'_smode arg_modes2 is
+  end
+
+
 (* general syntactic functions *)
 
 (*Like dest_conj, but flattens conjunctions however nested*)
@@ -70,7 +163,20 @@
 fun is_predT (T as Type("fun", [_, _])) = (snd (strip_type T) = HOLogic.boolT)
   | is_predT _ = false
 
-  
+(* guessing number of parameters *)
+fun find_indexes pred xs =
+  let
+    fun find is n [] = is
+      | find is n (x :: xs) = find (if pred x then (n :: is) else is) (n + 1) xs;
+  in rev (find [] 0 xs) end;
+
+fun guess_nparams T =
+  let
+    val argTs = binder_types T
+    val nparams = fold Integer.max
+      (map (fn x => x + 1) (find_indexes is_predT argTs)) 0
+  in nparams end;
+
 (*** check if a term contains only constructor functions ***)
 fun is_constrt thy =
   let
@@ -159,7 +265,7 @@
 (* Different options for compiler *)
 
 datatype options = Options of {  
-  expected_modes : (string * mode list) option,
+  expected_modes : (string * mode' list) option,
   show_steps : bool,
   show_proof_trace : bool,
   show_intermediate_results : bool,
--- a/src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML	Thu Nov 12 09:10:07 2009 +0100
+++ b/src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML	Thu Nov 12 09:10:16 2009 +0100
@@ -128,22 +128,15 @@
 (* destruction of intro rules *)
 
 (* FIXME: look for other place where this functionality was used before *)
-fun strip_intro_concl nparams intro = let
-  val _ $ u = Logic.strip_imp_concl intro
-  val (pred, all_args) = strip_comb u
-  val (params, args) = chop nparams all_args
-in (pred, (params, args)) end
+fun strip_intro_concl nparams intro =
+  let
+    val _ $ u = Logic.strip_imp_concl intro
+    val (pred, all_args) = strip_comb u
+    val (params, args) = chop nparams all_args
+  in (pred, (params, args)) end
 
 (** data structures **)
 
-(* new datatype for modes: *)
-(*
-datatype instantiation = Input | Output
-type arg_mode = Tuple of instantiation list | Atom of instantiation | HigherOrderMode of mode
-type mode = arg_mode list
-type tmode = Mode of mode * 
-*)
-
 fun gen_split_smode (mk_tuple, strip_tuple) smode ts =
   let
     fun split_tuple' _ _ [] = ([], [])
@@ -274,7 +267,7 @@
     (AList.lookup (op =) (snd (#functions (the_pred_data thy name))) mode)
 
 fun the_predfun_data thy name mode = case lookup_predfun_data thy name mode
-  of NONE => error ("No function defined for mode " ^ string_of_mode mode ^
+  of NONE => error ("No function defined for mode " ^ string_of_mode thy name mode ^
     " of predicate " ^ name)
    | SOME data => data;
 
@@ -291,7 +284,7 @@
   (AList.lookup (op =) (snd (#random_functions (the_pred_data thy name))) mode)
 
 fun the_random_function_data thy name mode = case lookup_random_function_data thy name mode of
-     NONE => error ("No random function defined for mode " ^ string_of_mode mode ^
+     NONE => error ("No random function defined for mode " ^ string_of_mode thy name mode ^
        " of predicate " ^ name)
    | SOME data => data
 
@@ -310,7 +303,7 @@
 
 fun the_depth_limited_function_data thy name mode =
   case lookup_depth_limited_function_data thy name mode of
-    NONE => error ("No depth-limited function defined for mode " ^ string_of_mode mode
+    NONE => error ("No depth-limited function defined for mode " ^ string_of_mode thy name mode
       ^ " of predicate " ^ name)
    | SOME data => data
 
@@ -325,7 +318,7 @@
     (AList.lookup (op =) (snd (#annotated_functions (the_pred_data thy name))) mode)
 
 fun the_annotated_function_data thy name mode = case lookup_annotated_function_data thy name mode
-  of NONE => error ("No annotated function defined for mode " ^ string_of_mode mode
+  of NONE => error ("No annotated function defined for mode " ^ string_of_mode thy name mode
     ^ " of predicate " ^ name)
    | SOME data => data
 
@@ -337,17 +330,17 @@
 
 (* diagnostic display functions *)
 
-fun print_modes options modes =
+fun print_modes options thy modes =
   if show_modes options then
     tracing ("Inferred modes:\n" ^
       cat_lines (map (fn (s, ms) => s ^ ": " ^ commas (map
-        string_of_mode ms)) modes))
+        (string_of_mode thy s) ms)) modes))
   else ()
 
 fun print_pred_mode_table string_of_entry thy pred_mode_table =
   let
-    fun print_mode pred (mode, entry) =  "mode : " ^ (string_of_mode mode)
-      ^ (string_of_entry pred mode entry)  
+    fun print_mode pred (mode, entry) =  "mode : " ^ string_of_mode thy pred mode
+      ^ string_of_entry pred mode entry
     fun print_pred (pred, modes) =
       "predicate " ^ pred ^ ": " ^ cat_lines (map (print_mode pred) modes)
     val _ = tracing (cat_lines (map print_pred pred_mode_table))
@@ -417,25 +410,29 @@
     fun print (pred, modes) u =
       let
         val _ = writeln ("predicate: " ^ pred)
-        val _ = writeln ("modes: " ^ (commas (map string_of_mode modes)))
-      in u end  
+        val _ = writeln ("modes: " ^ (commas (map (string_of_mode thy pred) modes)))
+      in u end
   in
     fold print (all_modes_of thy) ()
   end
 
 (* validity checks *)
 
-fun check_expected_modes (options : Predicate_Compile_Aux.options) modes =
-  case expected_modes options of
-    SOME (s, ms) => (case AList.lookup (op =) modes s of
-      SOME modes =>
-        if not (eq_set (op =) (ms, modes)) then
-          error ("expected modes were not inferred:\n"
-          ^ "inferred modes for " ^ s ^ ": " ^ commas (map string_of_mode modes)
-          ^ "\n expected modes for " ^ s ^ ": " ^ commas (map string_of_mode ms))
-        else ()
-      | NONE => ())
-  | NONE => ()
+fun check_expected_modes preds (options : Predicate_Compile_Aux.options) modes =
+      case expected_modes options of
+      SOME (s, ms) => (case AList.lookup (op =) modes s of
+        SOME modes =>
+          let
+            val modes' = map (translate_mode (the (AList.lookup (op =) preds s))) modes
+          in
+            if not (eq_set eq_mode' (ms, modes')) then
+              error ("expected modes were not inferred:\n"
+              ^ "  inferred modes for " ^ s ^ ": " ^ commas (map string_of_mode' modes')  ^ "\n"
+              ^ "  expected modes for " ^ s ^ ": " ^ commas (map string_of_mode' ms))
+            else ()
+          end
+        | NONE => ())
+    | NONE => ()
 
 (* importing introduction rules *)
 
@@ -653,20 +650,6 @@
   end;
 *)
 
-(* guessing number of parameters *)
-fun find_indexes pred xs =
-  let
-    fun find is n [] = is
-      | find is n (x :: xs) = find (if pred x then (n :: is) else is) (n + 1) xs;
-  in rev (find [] 0 xs) end;
-
-fun guess_nparams T =
-  let
-    val argTs = binder_types T
-    val nparams = fold Integer.max
-      (map (fn x => x + 1) (find_indexes is_predT argTs)) 0
-  in nparams end;
-
 fun add_intro thm thy = let
    val (name, T) = dest_Const (fst (strip_intro_concl 0 (prop_of thm)))
    fun cons_intro gr =
@@ -1081,7 +1064,7 @@
   if show_mode_inference options then
     let
       val _ = tracing ("Clause " ^ string_of_int (i + 1) ^ " of " ^
-      p ^ " violates mode " ^ string_of_mode m)
+      p ^ " violates mode " ^ string_of_mode thy p m)
       val _ = tracing (string_of_clause thy p (nth rs i))
     in () end
   else ()
@@ -2264,8 +2247,8 @@
     val moded_clauses =
       #infer_modes (dest_steps steps) options thy extra_modes all_modes param_vs clauses
     val modes = map (fn (p, mps) => (p, map fst mps)) moded_clauses
-    val _ = check_expected_modes options modes
-    val _ = print_modes options modes
+    val _ = check_expected_modes preds options modes
+    val _ = print_modes options thy modes
       (*val _ = print_moded_clauses thy moded_clauses*)
     val _ = print_step options "Defining executable functions..."
     val thy' = fold (#define_functions (dest_steps steps) preds) modes thy
@@ -2389,7 +2372,8 @@
   additional_arguments = K [],
   wrap_compilation =
     fn compfuns => fn s => fn T => fn mode => fn additional_arguments => fn compilation =>
-      mk_tracing ("calling predicate " ^ s ^ " with mode " ^ string_of_mode mode) compilation,
+      mk_tracing ("calling predicate " ^ s ^
+        " with mode " ^ string_of_mode' (translate_mode T mode)) compilation,
   transform_additional_arguments = K I : (indprem -> term list -> term list)
   }