changed preprocessing due to problems with LightweightJava; added transfer of thereoms; changed the type of mode to support tuples in the predicate compiler
authorbulwahn
Wed, 23 Sep 2009 16:20:12 +0200
changeset 32663 c2f63118b251
parent 32662 2faf1148c062
child 32664 5d4f32b02450
changed preprocessing due to problems with LightweightJava; added transfer of thereoms; changed the type of mode to support tuples in the predicate compiler
src/HOL/ex/predicate_compile.ML
--- a/src/HOL/ex/predicate_compile.ML	Wed Sep 23 16:20:12 2009 +0200
+++ b/src/HOL/ex/predicate_compile.ML	Wed Sep 23 16:20:12 2009 +0200
@@ -6,7 +6,9 @@
 
 signature PREDICATE_COMPILE =
 sig
-  type mode = int list option list * int list
+  type smode = (int * int list option) list
+  type mode = smode option list * smode
+  datatype tmode = Mode of mode * smode * tmode option list;
   (*val add_equations_of: bool -> string list -> theory -> theory *)
   val register_predicate : (thm list * thm * int) -> theory -> theory
   val is_registered : theory -> string -> bool
@@ -57,10 +59,9 @@
     mk_map : typ -> typ -> term -> term -> term,
     lift_pred : term -> term
   };  
-  datatype tmode = Mode of mode * int list * tmode option list;
   type moded_clause = term list * (indprem * tmode) list
   type 'a pred_mode_table = (string * (mode * 'a) list) list
-  val infer_modes : bool -> theory -> (string * (int list option list * int list) list) list
+  val infer_modes : bool -> theory -> (string * mode list) list
     -> (string * (int option list * int)) list -> string list
     -> (string * (term list * indprem list) list) list
     -> (moded_clause list) pred_mode_table
@@ -183,25 +184,51 @@
 
 (** data structures **)
 
-type smode = int list;
+type smode = (int * int list option) list;
 type mode = smode option list * smode;
-datatype tmode = Mode of mode * int list * tmode option list;
+datatype tmode = Mode of mode * smode * tmode option list;
 
-fun split_smode is ts =
+fun gen_split_smode (mk_tuple, strip_tuple) smode ts =
   let
+    fun split_tuple' _ _ [] = ([], [])
+    | split_tuple' is i (t::ts) =
+      (if i mem is then apfst else apsnd) (cons t)
+        (split_tuple' is (i+1) ts)
+    fun split_tuple is t = split_tuple' is 1 (strip_tuple t)
     fun split_smode' _ _ [] = ([], [])
-      | split_smode' is i (t::ts) = (if i mem is then apfst else apsnd) (cons t)
-          (split_smode' is (i+1) ts)
-  in split_smode' is 1 ts end
+      | split_smode' smode i (t::ts) =
+        (if i mem (map fst smode) then
+          case (the (AList.lookup (op =) smode i)) of
+            NONE => apfst (cons t)
+            | SOME is =>
+              let
+                val (ts1, ts2) = split_tuple is t
+                fun cons_tuple ts = if null ts then I else cons (mk_tuple ts)
+                in (apfst (cons_tuple ts1)) o (apsnd (cons_tuple ts2)) end
+          else apsnd (cons t))
+        (split_smode' smode (i+1) ts)
+  in split_smode' smode 1 ts end
 
-fun split_mode (iss, is) ts =
+val split_smode = gen_split_smode (HOLogic.mk_tuple, HOLogic.strip_tuple)   
+val split_smodeT = gen_split_smode (HOLogic.mk_tupleT, HOLogic.strip_tupleT)
+
+fun gen_split_mode split_smode (iss, is) ts =
   let
     val (t1, t2) = chop (length iss) ts 
   in (t1, split_smode is t2) end
 
+val split_mode = gen_split_mode split_smode
+val split_modeT = gen_split_mode split_smodeT
+
+fun string_of_smode js =
+    commas (map
+      (fn (i, is) =>
+        string_of_int i ^ (case is of NONE => ""
+    | SOME is => "p" ^ enclose "[" "]" (commas (map string_of_int is)))) js)
+
 fun string_of_mode (iss, is) = space_implode " -> " (map
   (fn NONE => "X"
-    | SOME js => enclose "[" "]" (commas (map string_of_int js)))
+    | SOME js => enclose "[" "]" (string_of_smode js))
        (iss @ [SOME is]));
 
 fun string_of_tmode (Mode (predmode, termmode, param_modes)) =
@@ -282,11 +309,11 @@
 
 val all_preds_of = Graph.keys o PredData.get
 
-val intros_of = #intros oo the_pred_data
+fun intros_of thy = map (Thm.transfer thy) o #intros o the_pred_data thy
 
 fun the_elim_of thy name = case #elim (the_pred_data thy name)
  of NONE => error ("No elimination rule for predicate " ^ quote name)
-  | SOME thm => thm 
+  | SOME thm => Thm.transfer thy thm 
   
 val has_elim = is_some o #elim oo the_pred_data;
 
@@ -367,10 +394,10 @@
     "Generator for " ^ v ^ " of Type " ^ (Syntax.string_of_typ_global thy T)
   | string_of_moded_prem thy (Negprem (ts, p), Mode (_, is, _)) =
     (Syntax.string_of_term_global thy (list_comb (p, ts))) ^
-    "(negative mode: " ^ (space_implode ", " (map string_of_int is)) ^ ")"
+    "(negative mode: " ^ string_of_smode is ^ ")"
   | string_of_moded_prem thy (Sidecond t, Mode (_, is, _)) =
     (Syntax.string_of_term_global thy t) ^
-    "(sidecond mode: " ^ (space_implode ", " (map string_of_int is)) ^ ")"    
+    "(sidecond mode: " ^ string_of_smode is ^ ")"    
   | string_of_moded_prem _ _ = error "string_of_moded_prem: unimplemented"
      
 fun print_moded_clauses thy =
@@ -435,6 +462,8 @@
 
 fun preprocess_elim thy nparams elimrule =
   let
+    val _ = Output.tracing ("Preprocessing elimination rule "
+      ^ (Display.string_of_thm_global thy elimrule))
     fun replace_eqs (Const ("Trueprop", _) $ (Const ("op =", T) $ lhs $ rhs)) =
        HOLogic.mk_Trueprop (Const (@{const_name Predicate.eq}, T) $ lhs $ rhs)
      | replace_eqs t = t
@@ -450,11 +479,21 @@
      end 
     val cases' = map preprocess_case (tl prems)
     val elimrule' = Logic.list_implies ((hd prems) :: cases', Thm.concl_of elimrule)
+    (*
+    (*val _ =  Output.tracing ("elimrule': "^ (Syntax.string_of_term_global thy elimrule'))*)
+    val bigeq = (Thm.symmetric (Conv.implies_concl_conv (MetaSimplifier.rewrite true [@{thm Predicate.eq_is_eq}])
+         (cterm_of thy elimrule')))
+    val _ = Output.tracing ("bigeq:" ^ (Display.string_of_thm_global thy bigeq))   
+    val res = 
+    Thm.equal_elim bigeq
+      
+      elimrule
+    *)
+    val t = (fn {...} => mycheat_tac thy 1)
+    val eq = Goal.prove (ProofContext.init thy) [] [] (Logic.mk_equals ((Thm.prop_of elimrule), elimrule')) t
+    val _ = Output.tracing "Preprocessed elimination rule"
   in
-    Thm.equal_elim
-      (Thm.symmetric (Conv.implies_concl_conv (MetaSimplifier.rewrite true [@{thm eq_is_eq}])
-         (cterm_of thy elimrule')))
-      elimrule
+    Thm.equal_elim eq elimrule
   end;
 
 (* special case: predicate with no introduction rule *)
@@ -629,7 +668,7 @@
 fun funT_of compfuns (iss, is) T =
   let
     val Ts = binder_types T
-    val (paramTs, (inargTs, outargTs)) = split_mode (iss, is) Ts
+    val (paramTs, (inargTs, outargTs)) = split_modeT (iss, is) Ts
     val paramTs' = map2 (fn NONE => I | SOME is => funT_of compfuns ([], is)) iss paramTs 
   in
     (paramTs' @ inargTs) ---> (mk_predT compfuns (mk_tupleT outargTs))
@@ -638,7 +677,7 @@
 fun sizelim_funT_of compfuns (iss, is) T =
   let
     val Ts = binder_types T
-    val (paramTs, (inargTs, outargTs)) = split_mode (iss, is) Ts
+    val (paramTs, (inargTs, outargTs)) = split_modeT (iss, is) Ts
     val paramTs' = map2 (fn SOME is => sizelim_funT_of compfuns ([], is) | NONE => I) iss paramTs 
   in
     (paramTs' @ inargTs @ [@{typ "code_numeral"}]) ---> (mk_predT compfuns (mk_tupleT outargTs))
@@ -868,7 +907,7 @@
 *)
 fun modes_of_term modes t =
   let
-    val ks = 1 upto length (binder_types (fastype_of t));
+    val ks = map_index (fn (i, T) => (i, NONE)) (binder_types (fastype_of t));
     val default = [Mode (([], ks), ks, [])];
     fun mk_modes name args = Option.map (maps (fn (m as (iss, is)) =>
         let
@@ -877,10 +916,10 @@
               error ("Too few arguments for inductive predicate " ^ name)
             else chop (length iss) args;
           val k = length args2;
-          val prfx = 1 upto k
+          val prfx = map (rpair NONE) (1 upto k)
         in
           if not (is_prefix op = prfx is) then [] else
-          let val is' = map (fn i => i - k) (List.drop (is, k))
+          let val is' = List.drop (is, k)
           in map (fn x => Mode (m, is', x)) (cprods (map
             (fn (NONE, _) => [NONE]
               | (SOME js, arg) => map SOME (filter
@@ -1003,7 +1042,8 @@
     (is_none o check_mode_clause with_generator thy param_vs modes gen_modes m) rs of
       ~1 => true
     | i => (Output.tracing ("Clause " ^ string_of_int (i + 1) ^ " of " ^
-      p ^ " violates mode " ^ string_of_mode m); false)) ms)
+      p ^ " violates mode " ^ string_of_mode m);
+        Output.tracing (commas (map (Syntax.string_of_term_global thy) (fst (nth rs i)))); false)) ms)
   end;
 
 fun get_modes_pred with_generator thy param_vs preds modes gen_modes (p, ms) =
@@ -1021,8 +1061,8 @@
 fun modes_of_arities arities =
   (map (fn (s, (ks, k)) => (s, cprod (cprods (map
             (fn NONE => [NONE]
-              | SOME k' => map SOME (subsets 1 k')) ks),
-            subsets 1 k))) arities)
+              | SOME k' => map SOME (map (map (rpair NONE)) (subsets 1 k'))) ks),
+    map (map (rpair NONE)) (subsets 1 k)))) arities)
   
 fun infer_modes with_generator thy extra_modes arities param_vs preds =
   let
@@ -1294,11 +1334,11 @@
 
 fun compile_pred compfuns mk_fun_of use_size thy all_vs param_vs s T mode moded_cls =
   let
-    val (Ts1, (Us1, Us2)) = split_mode mode (binder_types T)
+    val (Ts1, (Us1, Us2)) = split_modeT mode (binder_types T)
     val funT_of = if use_size then sizelim_funT_of else funT_of 
     val Ts1' = map2 (fn NONE => I | SOME is => funT_of compfuns ([], is)) (fst mode) Ts1
     val xnames = Name.variant_list (all_vs @ param_vs)
-      (map (fn i => "x" ^ string_of_int i) (snd mode));
+      (map (fn (i, NONE) => "x" ^ string_of_int i | (i, SOME s) => error "pair mode") (snd mode));
     val size_name = Name.variant (all_vs @ param_vs @ xnames) "size"
     (* termify code: val xs = map2 (fn s => fn T => Free (s, termifyT T)) xnames Us1; *)
     val xs = map2 (fn s => fn T => Free (s, T)) xnames Us1;
@@ -1390,7 +1430,7 @@
 fun create_constname_of_mode thy prefix name mode = 
   let
     fun string_of_mode mode = if null mode then "0"
-      else space_implode "_" (map string_of_int mode)
+      else space_implode "_" (map (fn (i, NONE) => string_of_int i | (i, SOME _) => error "pair mode") mode)
     val HOmode = space_implode "_and_"
       (fold (fn NONE => I | SOME mode => cons (string_of_mode mode)) (fst mode) [])
   in
@@ -1407,14 +1447,14 @@
       val mode_cbasename = Long_Name.base_name mode_cname
       val Ts = binder_types T
       val (Ts1, Ts2) = chop (length iss) Ts
-      val (Us1, Us2) =  split_smode is Ts2
+      val (Us1, Us2) =  split_smodeT is Ts2
       val Ts1' = map2 (fn NONE => I | SOME is => funT_of compfuns ([], is)) iss Ts1
       val funT = (Ts1' @ Us1) ---> (mk_predT compfuns (mk_tupleT Us2))
       val names = Name.variant_list []
         (map (fn i => "x" ^ string_of_int i) (1 upto (length Ts)));
-      val xs = map Free (names ~~ (Ts1' @ Ts2));                   
-      val (xparams, xargs) = chop (length iss) xs;
-      val (xins, xouts) = split_smode is xargs 
+      val xs = map Free (names ~~ (Ts1' @ Ts2))
+      val (xparams, xargs) = chop (length iss) xs
+      val (xins, xouts) = split_smode is xargs
       val (xparams', names') = fold_map mk_Eval_of ((xparams ~~ Ts1) ~~ iss) names
       fun mk_split_lambda [] t = lambda (Free (Name.variant names' "x", HOLogic.unitT)) t
         | mk_split_lambda [x] t = lambda x t
@@ -1572,7 +1612,8 @@
       | _ => nameTs
     val preds = preds_of t []
     val defs = map
-      (fn (pred, T) => predfun_definition_of thy pred ([], (1 upto (length (binder_types T)))))
+      (fn (pred, T) => predfun_definition_of thy pred
+        ([], map (rpair NONE) (1 upto (length (binder_types T)))))
         preds
   in 
     (* remove not_False_eq_True when simpset in prove_match is better *)
@@ -1734,7 +1775,8 @@
     | _ => nameTs
   val preds = preds_of t []
   val defs = map
-    (fn (pred, T) => predfun_definition_of thy pred ([], (1 upto (length (binder_types T)))))
+    (fn (pred, T) => predfun_definition_of thy pred 
+      ([], map (rpair NONE) (1 upto (length (binder_types T)))))
       preds
   in
    (* only simplify the one assumption *)
@@ -1991,7 +2033,7 @@
   are_not_defined = (fn thy => fn preds => true), (* TODO *)
   qname = "sizelim_equation"
   }
-  
+
 val add_quickcheck_equations = gen_add_equations
   {infer_modes = infer_modes_with_generator,
   create_definitions = rpred_create_definitions,
@@ -2036,7 +2078,6 @@
 local
 
 (* TODO: make TheoryDataFun to GenericDataFun & remove duplication of local theory and theory *)
-(* TODO: must create state to prove multiple cases *)
 fun generic_code_pred prep_const raw_const lthy =
   let
     val thy = ProofContext.theory_of lthy
@@ -2114,7 +2155,8 @@
     val user_mode = map_filter I (map_index
       (fn (i, t) => case t of Bound j => if j < length Ts then NONE
         else SOME (i+1) | _ => SOME (i+1)) args); (*FIXME dangling bounds should not occur*)
-    val modes = filter (fn Mode (_, is, _) => is = user_mode)
+    val user_mode' = map (rpair NONE) user_mode
+    val modes = filter (fn Mode (_, is, _) => is = user_mode')
       (modes_of_term (all_modes_of thy) (list_comb (pred, params)));
     val m = case modes
      of [] => error ("No mode possible for comprehension "
@@ -2122,7 +2164,7 @@
       | [m] => m
       | m :: _ :: _ => (warning ("Multiple modes possible for comprehension "
                 ^ Syntax.string_of_term_global thy t_compr); m);
-    val (inargs, outargs) = split_smode user_mode args;
+    val (inargs, outargs) = split_smode user_mode' args;
     val t_pred = list_comb (compile_expr NONE thy (m, list_comb (pred, params)), inargs);
     val t_eval = if null outargs then t_pred else let
         val outargs_bounds = map (fn Bound i => i) outargs;