merged
authorbulwahn
Tue, 28 Sep 2010 13:44:06 +0200
changeset 39768 1c46d4f8afd2
parent 39767 327e463531e4 (diff)
parent 39759 b4bd83468600 (current diff)
child 39769 5bcf4253d579
child 39773 38852e989efa
merged
src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML
--- a/src/HOL/Predicate_Compile_Examples/Predicate_Compile_Tests.thy	Tue Sep 28 12:48:05 2010 +0200
+++ b/src/HOL/Predicate_Compile_Examples/Predicate_Compile_Tests.thy	Tue Sep 28 13:44:06 2010 +0200
@@ -1483,6 +1483,34 @@
 
 thm detect_switches9.equation
 
+text {* The higher-order predicate r is in an output term *}
+
+datatype result = Result bool
+
+inductive fixed_relation :: "'a => bool"
+
+inductive test_relation_in_output_terms :: "('a => bool) => 'a => result => bool"
+where
+  "test_relation_in_output_terms r x (Result (r x))"
+| "test_relation_in_output_terms r x (Result (fixed_relation x))"
+
+code_pred test_relation_in_output_terms .
+
+thm test_relation_in_output_terms.equation
+
+
+text {*
+  We want that the argument r is not treated as a higher-order relation, but simply as input.
+*}
+
+inductive test_uninterpreted_relation :: "('a => bool) => 'a list => bool"
+where
+  "list_all r xs ==> test_uninterpreted_relation r xs"
+
+code_pred (modes: i => i => bool) test_uninterpreted_relation .
+
+thm test_uninterpreted_relation.equation
+
 
 
 end
--- a/src/HOL/Tools/Predicate_Compile/code_prolog.ML	Tue Sep 28 12:48:05 2010 +0200
+++ b/src/HOL/Tools/Predicate_Compile/code_prolog.ML	Tue Sep 28 13:44:06 2010 +0200
@@ -369,7 +369,7 @@
   let
     val options = Predicate_Compile_Aux.default_options
     val mode_analysis_options =
-      {use_random = true, reorder_premises = true, infer_pos_and_neg_modes = true}
+      {use_generators = true, reorder_premises = true, infer_pos_and_neg_modes = true}
     fun infer prednames (gr, (pos_modes, neg_modes, random)) =
       let
         val (lookup_modes, lookup_neg_modes, needs_random) =
--- a/src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML	Tue Sep 28 12:48:05 2010 +0200
+++ b/src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML	Tue Sep 28 13:44:06 2010 +0200
@@ -68,7 +68,10 @@
   val prepare_intrs : options -> Proof.context -> string list -> thm list ->
     ((string * typ) list * string list * string list * (string * mode list) list *
       (string *  (Term.term list * Predicate_Compile_Aux.indprem list) list) list)
-  type mode_analysis_options = {use_random : bool, reorder_premises : bool, infer_pos_and_neg_modes : bool}  
+  type mode_analysis_options =
+   {use_generators : bool,
+    reorder_premises : bool,
+    infer_pos_and_neg_modes : bool}  
   datatype mode_derivation = Mode_App of mode_derivation * mode_derivation | Context of mode
     | Mode_Pair of mode_derivation * mode_derivation | Term of mode
   val head_mode_of : mode_derivation -> mode
@@ -431,19 +434,33 @@
 fun check_matches_type ctxt predname T ms =
   let
     fun check (m as Fun (m1, m2)) (Type("fun", [T1,T2])) = check m1 T1 andalso check m2 T2
-      | check m (T as Type("fun", _)) =
-        if body_type T = @{typ bool} then false else (m = Input orelse m = Output)
+      | check m (T as Type("fun", _)) = (m = Input orelse m = Output)
       | check (Pair (m1, m2)) (Type (@{type_name Product_Type.prod}, [T1, T2])) =
           check m1 T1 andalso check m2 T2 
       | check Input T = true
       | check Output T = true
       | check Bool @{typ bool} = true
       | check _ _ = false
+    fun check_consistent_modes ms =
+      if forall (fn Fun (m1', m2') => true | _ => false) ms then
+        pairself check_consistent_modes (split_list (map (fn Fun (m1, m2) => (m1, m2)) ms))
+        |> (fn (res1, res2) => res1 andalso res2) 
+      else if forall (fn Input => true | Output => true | Pair _ => true | _ => false) ms then
+        true
+      else if forall (fn Bool => true | _ => false) ms then
+        true
+      else
+        false
     val _ = map
       (fn mode =>
-        if (forall (uncurry check) (strip_fun_mode mode ~~ binder_types T)) then ()
+        if length (strip_fun_mode mode) = length (binder_types T)
+          andalso (forall (uncurry check) (strip_fun_mode mode ~~ binder_types T)) then ()
         else error (string_of_mode mode ^ " is not a valid mode for " ^ Syntax.string_of_typ ctxt T
         ^ " at predicate " ^ predname)) ms
+    val _ =
+     if check_consistent_modes ms then ()
+     else error (commas (map string_of_mode ms) ^
+       " are inconsistent modes for predicate " ^ predname)
   in
     ms
   end
@@ -1052,7 +1069,10 @@
 
 (** mode analysis **)
 
-type mode_analysis_options = {use_random : bool, reorder_premises : bool, infer_pos_and_neg_modes : bool}
+type mode_analysis_options =
+  {use_generators : bool,
+  reorder_premises : bool,
+  infer_pos_and_neg_modes : bool}
 
 fun is_constrt thy =
   let
@@ -1340,7 +1360,7 @@
     (* prefer non-random modes *)
     fun random_mode_ord ((t1, deriv1, mvars1), (t2, deriv2, mvars2)) =
       int_ord (if random_mode_in_deriv modes t1 deriv1 then 1 else 0,
-        if random_mode_in_deriv modes t1 deriv1 then 1 else 0)
+               if random_mode_in_deriv modes t2 deriv2 then 1 else 0)
     (* prefer modes with more input and less output *)
     fun output_mode_ord ((t1, deriv1, mvars1), (t2, deriv2, mvars2)) =
       int_ord (number_of_output_positions (head_mode_of deriv1),
@@ -1411,7 +1431,7 @@
           SOME (p, SOME (deriv, [])) => check_mode_prems ((p, deriv) :: acc_ps) rnd
             (known_vs_after p vs) (filter_out (equal p) ps)
         | SOME (p, SOME (deriv, missing_vars)) =>
-          if #use_random mode_analysis_options andalso pol then
+          if #use_generators mode_analysis_options andalso pol then
             check_mode_prems ((p, deriv) :: (map
               (fn v => (Generator (v, the (AList.lookup (op =) vTs v)), Term Output))
                 (distinct (op =) missing_vars))
@@ -1426,7 +1446,7 @@
       if forall (is_constructable vs) (in_ts @ out_ts) then
         SOME (ts, rev acc_ps, rnd)
       else
-        if #use_random mode_analysis_options andalso pol then
+        if #use_generators mode_analysis_options andalso pol then
           let
              val generators = map
               (fn v => (Generator (v, the (AList.lookup (op =) vTs v)), Term Output))
@@ -1711,9 +1731,11 @@
               fold_map (collect_non_invertible_subterms ctxt) out_ts' (names, []);
             val (out_ts''', (names'', constr_vs)) = fold_map distinct_v
               out_ts'' (names', map (rpair []) vs);
+            val processed_out_ts = map (compile_arg compilation_modifiers additional_arguments
+              ctxt param_modes) out_ts
           in
             compile_match constr_vs (eqs @ eqs') out_ts'''
-              (mk_single compfuns (HOLogic.mk_tuple out_ts))
+              (mk_single compfuns (HOLogic.mk_tuple processed_out_ts))
           end
       | compile_prems out_ts vs names ((p, deriv) :: ps) =
           let
@@ -2262,12 +2284,14 @@
     if valid_expr expr then
       ((*tracing "expression is valid";*) Seq.single st)
     else
-      ((*tracing "expression is not valid";*) Seq.empty) (*error "check_format: wrong format"*)
+      ((*tracing "expression is not valid";*) Seq.empty) (* error "check_format: wrong format" *)
   end
 
-fun prove_match options ctxt out_ts =
+fun prove_match options ctxt nargs out_ts =
   let
     val thy = ProofContext.theory_of ctxt
+    val eval_if_P =
+      @{lemma "P ==> Predicate.eval x z ==> Predicate.eval (if P then x else y) z" by simp} 
     fun get_case_rewrite t =
       if (is_constructor thy t) then let
         val case_rewrites = (#case_rewrites (Datatype.the_info thy
@@ -2281,12 +2305,22 @@
      (* make this simpset better! *)
     asm_full_simp_tac (HOL_basic_ss' addsimps simprules) 1
     THEN print_tac options "after prove_match:"
-    THEN (DETERM (TRY (EqSubst.eqsubst_tac ctxt [0] [@{thm HOL.if_P}] 1
-           THEN (REPEAT_DETERM (rtac @{thm conjI} 1 THEN (SOLVED (asm_simp_tac HOL_basic_ss' 1))))
-           THEN print_tac options "if condition to be solved:"
-           THEN (SOLVED (asm_simp_tac HOL_basic_ss' 1 THEN print_tac options "after if simp; in SOLVED:"))
-           THEN check_format thy
-           THEN print_tac options "after if simplification - a TRY block")))
+    THEN (DETERM (TRY 
+           (rtac eval_if_P 1
+           THEN (SUBPROOF (fn {context = ctxt, params, prems, asms, concl, schematics} =>
+             (REPEAT_DETERM (rtac @{thm conjI} 1
+             THEN (SOLVED (asm_simp_tac HOL_basic_ss' 1))))
+             THEN print_tac options "if condition to be solved:"
+             THEN asm_simp_tac HOL_basic_ss' 1
+             THEN TRY (
+                let
+                  val prems' = maps dest_conjunct_prem (take nargs prems)
+                in
+                  MetaSimplifier.rewrite_goal_tac
+                    (map (fn th => th RS @{thm sym} RS @{thm eq_reflection}) prems') 1
+                end
+             THEN REPEAT_DETERM (rtac @{thm refl} 1))
+             THEN print_tac options "after if simp; in SUBPROOF") ctxt 1))))
     THEN print_tac options "after if simplification"
   end;
 
@@ -2314,10 +2348,18 @@
   let
     val (in_ts, clause_out_ts) = split_mode mode ts;
     fun prove_prems out_ts [] =
-      (prove_match options ctxt out_ts)
+      (prove_match options ctxt nargs out_ts)
       THEN print_tac options "before simplifying assumptions"
       THEN asm_full_simp_tac HOL_basic_ss' 1
       THEN print_tac options "before single intro rule"
+      THEN Subgoal.FOCUS_PREMS
+             (fn {context = ctxt, params = params, prems, asms, concl, schematics} =>
+              let
+                val prems' = maps dest_conjunct_prem (take nargs prems)
+              in
+                MetaSimplifier.rewrite_goal_tac
+                  (map (fn th => th RS @{thm sym} RS @{thm eq_reflection}) prems') 1
+              end) ctxt 1
       THEN (rtac (if null clause_out_ts then @{thm singleI_unit} else @{thm singleI}) 1)
     | prove_prems out_ts ((p, deriv) :: ps) =
       let
@@ -2356,9 +2398,10 @@
                  @{thm snd_conv}, @{thm pair_collapse}, @{thm Product_Type.split_conv}]) 1
               THEN (if (is_some name) then
                   print_tac options "before applying not introduction rule"
-                  THEN rotate_tac premposition 1
-                  THEN etac (the neg_intro_rule) 1
-                  THEN rotate_tac (~premposition) 1
+                  THEN Subgoal.FOCUS_PREMS
+                    (fn {context = ctxt, params = params, prems, asms, concl, schematics} =>
+                      rtac (the neg_intro_rule) 1
+                      THEN rtac (nth prems premposition) 1) ctxt 1
                   THEN print_tac options "after applying not introduction rule"
                   THEN (EVERY (map2 (prove_param options ctxt nargs) params param_derivations))
                   THEN (REPEAT_DETERM (atac 1))
@@ -2376,7 +2419,7 @@
            THEN prove_sidecond ctxt t
            THEN print_tac options "after sidecond:"
            THEN prove_prems [] ps)
-      in (prove_match options ctxt out_ts)
+      in (prove_match options ctxt nargs out_ts)
           THEN rest_tac
       end;
     val prems_tac = prove_prems in_ts moded_ps
@@ -2675,8 +2718,9 @@
         [] =>
           let
             val T = snd (hd preds)
+            val one_mode = hd (the (AList.lookup (op =) all_modes (fst (hd preds))))
             val paramTs =
-              ho_argsT_of_typ (binder_types T)
+              ho_argsT_of one_mode (binder_types T)
             val param_names = Name.variant_list [] (map (fn i => "p" ^ string_of_int i)
               (1 upto length paramTs))
           in
@@ -2684,9 +2728,10 @@
           end
       | (intr :: _) =>
         let
-          val (p, args) = strip_comb (HOLogic.dest_Trueprop (Logic.strip_imp_concl intr)) 
+          val (p, args) = strip_comb (HOLogic.dest_Trueprop (Logic.strip_imp_concl intr))
+          val one_mode = hd (the (AList.lookup (op =) all_modes (fst (dest_Const p))))
         in
-          ho_args_of_typ (snd (dest_Const p)) args
+          ho_args_of one_mode args
         end
     val param_vs = map (fst o dest_Free) params
     fun add_clause intr clauses =
@@ -2790,7 +2835,7 @@
   add_code_equations : Proof.context -> (string * typ) list
     -> (string * thm list) list -> (string * thm list) list,
   comp_modifiers : Comp_Mod.comp_modifiers,
-  use_random : bool,
+  use_generators : bool,
   qname : bstring
   }
 
@@ -2901,10 +2946,10 @@
       (fn preds => fn thy =>
         if not (forall (defined (ProofContext.init_global thy)) preds) then
           let
-            val mode_analysis_options = {use_random = #use_random (dest_steps steps),
+            val mode_analysis_options = {use_generators = #use_generators (dest_steps steps),
               reorder_premises =
                 not (no_topmost_reordering options andalso not (null (inter (op =) preds names))),
-              infer_pos_and_neg_modes = #use_random (dest_steps steps)}
+              infer_pos_and_neg_modes = #use_generators (dest_steps steps)}
           in
             add_equations_of steps mode_analysis_options options preds thy
           end
@@ -2921,7 +2966,7 @@
   prove = prove,
   add_code_equations = add_code_equations,
   comp_modifiers = predicate_comp_modifiers,
-  use_random = false,
+  use_generators = false,
   qname = "equation"})
 
 val add_depth_limited_equations = gen_add_equations
@@ -2933,7 +2978,7 @@
   prove = prove_by_skip,
   add_code_equations = K (K I),
   comp_modifiers = depth_limited_comp_modifiers,
-  use_random = false,
+  use_generators = false,
   qname = "depth_limited_equation"})
 
 val add_annotated_equations = gen_add_equations
@@ -2945,7 +2990,7 @@
   prove = prove_by_skip,
   add_code_equations = K (K I),
   comp_modifiers = annotated_comp_modifiers,
-  use_random = false,
+  use_generators = false,
   qname = "annotated_equation"})
 
 val add_random_equations = gen_add_equations
@@ -2957,7 +3002,7 @@
   comp_modifiers = random_comp_modifiers,
   prove = prove_by_skip,
   add_code_equations = K (K I),
-  use_random = true,
+  use_generators = true,
   qname = "random_equation"})
 
 val add_depth_limited_random_equations = gen_add_equations
@@ -2969,7 +3014,7 @@
   comp_modifiers = depth_limited_random_comp_modifiers,
   prove = prove_by_skip,
   add_code_equations = K (K I),
-  use_random = true,
+  use_generators = true,
   qname = "depth_limited_random_equation"})
 
 val add_dseq_equations = gen_add_equations
@@ -2981,7 +3026,7 @@
   prove = prove_by_skip,
   add_code_equations = K (K I),
   comp_modifiers = dseq_comp_modifiers,
-  use_random = false,
+  use_generators = false,
   qname = "dseq_equation"})
 
 val add_random_dseq_equations = gen_add_equations
@@ -2999,7 +3044,7 @@
   prove = prove_by_skip,
   add_code_equations = K (K I),
   comp_modifiers = pos_random_dseq_comp_modifiers,
-  use_random = true,
+  use_generators = true,
   qname = "random_dseq_equation"})
 
 val add_new_random_dseq_equations = gen_add_equations
@@ -3017,7 +3062,7 @@
   prove = prove_by_skip,
   add_code_equations = K (K I),
   comp_modifiers = new_pos_random_dseq_comp_modifiers,
-  use_random = true,
+  use_generators = true,
   qname = "new_random_dseq_equation"})
 
 (** user interface **)