src/HOL/Decision_Procs/approximation_generator.ML
changeset 58988 6ebf918128b9
child 59582 0fbed69ff081
equal deleted inserted replaced
58987:119680ebf37c 58988:6ebf918128b9
       
     1 (*  Title:      HOL/Decision_Procs/approximation_generator.ML
       
     2     Author:     Fabian Immler, TU Muenchen
       
     3 *)
       
     4 
       
     5 signature APPROXIMATION_GENERATOR =
       
     6 sig
       
     7   val custom_seed: int Config.T
       
     8   val precision: int Config.T
       
     9   val epsilon: real Config.T
       
    10   val approximation_generator:
       
    11     Proof.context ->
       
    12     (term * term list) list ->
       
    13     bool -> int list -> (bool * term list) option * Quickcheck.report option
       
    14   val setup: theory -> theory
       
    15 end;
       
    16 
       
    17 structure Approximation_Generator : APPROXIMATION_GENERATOR =
       
    18 struct
       
    19 
       
    20 val custom_seed = Attrib.setup_config_int @{binding quickcheck_approximation_custom_seed} (K ~1)
       
    21 
       
    22 val precision = Attrib.setup_config_int @{binding quickcheck_approximation_precision} (K 30)
       
    23 
       
    24 val epsilon = Attrib.setup_config_real @{binding quickcheck_approximation_epsilon} (K 0.0)
       
    25 
       
(* Compiled HOL random generator at type float.  Applied to a size and a seed
   (see random_float_list), it yields a pair of the random float and a thunk
   that rebuilds it as a term, together with the next seed. *)
val random_float = @{code "random_class.random::_ \<Rightarrow> _ \<Rightarrow> (float \<times> (unit \<Rightarrow> term)) \<times> _"}
       
    27 
       
    28 fun nat_of_term t =
       
    29   (HOLogic.dest_nat t handle TERM _ => snd (HOLogic.dest_number t)
       
    30     handle TERM _ => raise TERM ("nat_of_term", [t]));
       
    31 
       
    32 fun int_of_term t = snd (HOLogic.dest_number t) handle TERM _ => raise TERM ("int_of_term", [t]);
       
    33 
       
    34 fun real_of_man_exp m e = Real.fromManExp {man = Real.fromInt m, exp = e}
       
    35 
       
    36 fun mapprox_float (@{term Float} $ m $ e) = real_of_man_exp (int_of_term m) (int_of_term e)
       
    37   | mapprox_float t = Real.fromInt (snd (HOLogic.dest_number t))
       
    38       handle TERM _ => raise TERM ("mapprox_float", [t]);
       
    39 
       
    40 (* TODO: define using compiled terms? *)
       
    41 fun mapprox_floatarith (@{term Add} $ a $ b) xs = mapprox_floatarith a xs + mapprox_floatarith b xs
       
    42   | mapprox_floatarith (@{term Minus} $ a) xs = ~ (mapprox_floatarith a xs)
       
    43   | mapprox_floatarith (@{term Mult} $ a $ b) xs = mapprox_floatarith a xs * mapprox_floatarith b xs
       
    44   | mapprox_floatarith (@{term Inverse} $ a) xs = 1.0 / mapprox_floatarith a xs
       
    45   | mapprox_floatarith (@{term Cos} $ a) xs = Math.cos (mapprox_floatarith a xs)
       
    46   | mapprox_floatarith (@{term Arctan} $ a) xs = Math.atan (mapprox_floatarith a xs)
       
    47   | mapprox_floatarith (@{term Abs} $ a) xs = abs (mapprox_floatarith a xs)
       
    48   | mapprox_floatarith (@{term Max} $ a $ b) xs =
       
    49       Real.max (mapprox_floatarith a xs, mapprox_floatarith b xs)
       
    50   | mapprox_floatarith (@{term Min} $ a $ b) xs =
       
    51       Real.min (mapprox_floatarith a xs, mapprox_floatarith b xs)
       
    52   | mapprox_floatarith @{term Pi} _ = Math.pi
       
    53   | mapprox_floatarith (@{term Sqrt} $ a) xs = Math.sqrt (mapprox_floatarith a xs)
       
    54   | mapprox_floatarith (@{term Exp} $ a) xs = Math.exp (mapprox_floatarith a xs)
       
    55   | mapprox_floatarith (@{term Ln} $ a) xs = Math.ln (mapprox_floatarith a xs)
       
    56   | mapprox_floatarith (@{term Power} $ a $ n) xs =
       
    57       Math.pow (mapprox_floatarith a xs, Real.fromInt (nat_of_term n))
       
    58   | mapprox_floatarith (@{term Var} $ n) xs = nth xs (nat_of_term n)
       
    59   | mapprox_floatarith (@{term Num} $ m) _ = mapprox_float m
       
    60   | mapprox_floatarith t _ = raise TERM ("mapprox_floatarith", [t])
       
    61 
       
    62 fun mapprox_atLeastAtMost eps x a b xs =
       
    63     let
       
    64       val x' = mapprox_floatarith x xs
       
    65     in
       
    66       mapprox_floatarith a xs + eps <= x' andalso x' + eps <= mapprox_floatarith b xs
       
    67     end
       
    68 
       
    69 fun mapprox_form eps (@{term Bound} $ x $ a $ b $ f) xs =
       
    70     (not (mapprox_atLeastAtMost eps x a b xs)) orelse mapprox_form eps f xs
       
    71 | mapprox_form eps (@{term Assign} $ x $ a $ f) xs =
       
    72     (Real.!= (mapprox_floatarith x xs, mapprox_floatarith a xs)) orelse mapprox_form eps f xs
       
    73 | mapprox_form eps (@{term Less} $ a $ b) xs = mapprox_floatarith a xs + eps < mapprox_floatarith b xs
       
    74 | mapprox_form eps (@{term LessEqual} $ a $ b) xs = mapprox_floatarith a xs + eps <= mapprox_floatarith b xs
       
    75 | mapprox_form eps (@{term AtLeastAtMost} $ x $ a $ b) xs = mapprox_atLeastAtMost eps x a b xs
       
    76 | mapprox_form eps (@{term Conj} $ f $ g) xs = mapprox_form eps f xs andalso mapprox_form eps g xs
       
    77 | mapprox_form eps (@{term Disj} $ f $ g) xs = mapprox_form eps f xs orelse mapprox_form eps g xs
       
    78 | mapprox_form _ t _ = raise TERM ("mapprox_form", [t])
       
    79 
       
    80 fun dest_interpret_form (@{const "interpret_form"} $ b $ xs) = (b, xs)
       
    81   | dest_interpret_form t = raise TERM ("dest_interpret_form", [t])
       
    82 
       
    83 fun optionT t = Type (@{type_name "option"}, [t])
       
    84 fun mk_Some t = Const (@{const_name "Some"}, t --> optionT t)
       
    85 
       
    86 fun random_float_list size xs seed =
       
    87   fold (K (apsnd (random_float size) #-> (fn c => apfst (fn b => b::c)))) xs ([],seed)
       
    88 
       
    89 fun real_of_Float (@{code Float} (m, e)) =
       
    90     real_of_man_exp (@{code integer_of_int} m) (@{code integer_of_int} e)
       
    91 
       
    92 fun is_True @{term True} = true
       
    93   | is_True _ = false
       
    94 
       
(* Rewrite rules applied (via rewrite_with in approx_random) to each
   counterexample value "real (Float m e)".  They reduce the value to a readable
   real-number term: powers of two, signs pulled outward, and int_of_integer
   literals turned into plain numerals.  The list is passed as a whole to the
   simpset, so its contents and order are left untouched. *)
val postproc_form_eqs =
  @{lemma
    "real (Float 0 a) = 0"
    "real (Float (numeral m) 0) = numeral m"
    "real (Float 1 0) = 1"
    "real (Float (- 1) 0) = - 1"
    "real (Float 1 (numeral e)) = 2 ^ numeral e"
    "real (Float 1 (- numeral e)) = 1 / 2 ^ numeral e"
    "real (Float a 1) = a * 2"
    "real (Float a (-1)) = a / 2"
    "real (Float (- a) b) = - real (Float a b)"
    "real (Float (numeral m) (numeral e)) = numeral m * 2 ^ (numeral e)"
    "real (Float (numeral m) (- numeral e)) = numeral m / 2 ^ (numeral e)"
    "- (c * d::real) = -c * d"
    "- (c / d::real) = -c / d"
    "- (0::real) = 0"
    "int_of_integer (numeral k) = numeral k"
    "int_of_integer (- numeral k) = - numeral k"
    "int_of_integer 0 = 0"
    "int_of_integer 1 = 1"
    "int_of_integer (- 1) = - 1"
    by auto
  }
       
   118 
       
   119 fun rewrite_with ctxt thms = Simplifier.rewrite (put_simpset HOL_basic_ss ctxt addsimps thms)
       
   120 fun conv_term thy conv r = cterm_of thy r |> conv |> Thm.prop_of |> Logic.dest_equals |> snd
       
   121 
       
   122 fun approx_random ctxt prec eps frees e xs genuine_only size seed =
       
   123   let
       
   124     val (rs, seed') = random_float_list size xs seed
       
   125     fun mk_approx_form e ts =
       
   126       @{const "approx_form"} $
       
   127         HOLogic.mk_number @{typ nat} prec $
       
   128         e $
       
   129         (HOLogic.mk_list @{typ "(float * float) option"}
       
   130           (map (fn t => mk_Some @{typ "float * float"} $ HOLogic.mk_prod (t, t)) ts)) $
       
   131         @{term "[] :: nat list"}
       
   132   in
       
   133     (if mapprox_form eps e (map (real_of_Float o fst) rs)
       
   134     then
       
   135       let
       
   136         val ts = map (fn x => snd x ()) rs
       
   137         val ts' = map
       
   138           (AList.lookup op = (map dest_Free xs ~~ ts)
       
   139             #> the_default Term.dummy
       
   140             #> curry op $ @{term "real::float\<Rightarrow>_"}
       
   141             #> conv_term (Proof_Context.theory_of ctxt) (rewrite_with ctxt postproc_form_eqs))
       
   142           frees
       
   143       in
       
   144         if approximate ctxt (mk_approx_form e ts) |> is_True
       
   145         then SOME (true, ts')
       
   146         else (if genuine_only then NONE else SOME (false, ts'))
       
   147       end
       
   148     else NONE, seed')
       
   149   end
       
   150 
       
(* Rewrite rules applied by reify_goal to the negated goal before reification.
   They eliminate set-interval membership, equality, implication and
   equivalence, and push negations inward (flipping < and <=) so that only
   constructs expressible as a reified form remain. *)
val preproc_form_eqs =
  @{lemma
    "(a::real) \<in> {b .. c} \<longleftrightarrow> b \<le> a \<and> a \<le> c"
    "a = b \<longleftrightarrow> a \<le> b \<and> b \<le> a"
    "(p \<longrightarrow> q) \<longleftrightarrow> \<not>p \<or> q"
    "(p \<longleftrightarrow> q) \<longleftrightarrow> (p \<longrightarrow> q) \<and> (q \<longrightarrow> p)"
    "\<not> (a < b) \<longleftrightarrow> b \<le> a"
    "\<not> (a \<le> b) \<longleftrightarrow> b < a"
    "\<not> (p \<and> q) \<longleftrightarrow> \<not> p \<or> \<not> q"
    "\<not> (p \<or> q) \<longleftrightarrow> \<not> p \<and> \<not> q"
    "\<not> \<not> q \<longleftrightarrow> q"
    by auto
  }
       
   164 
       
   165 fun reify_goal ctxt t =
       
   166   HOLogic.mk_not t
       
   167     |> conv_term (Proof_Context.theory_of ctxt) (rewrite_with ctxt preproc_form_eqs)
       
   168     |> conv_term (Proof_Context.theory_of ctxt) (Reification.conv ctxt form_equations)
       
   169     |> dest_interpret_form
       
   170     ||> HOLogic.dest_list
       
   171 
       
   172 fun approximation_generator_raw ctxt t =
       
   173   let
       
   174     val iterations = Config.get ctxt Quickcheck.iterations
       
   175     val prec = Config.get ctxt precision
       
   176     val eps = Config.get ctxt epsilon
       
   177     val cs = Config.get ctxt custom_seed
       
   178     val seed = (Code_Numeral.natural_of_integer (cs + 1), Code_Numeral.natural_of_integer 1)
       
   179     val run = if cs < 0
       
   180       then (fn f => fn seed => (Random_Engine.run f, seed))
       
   181       else (fn f => fn seed => f seed)
       
   182     val frees = Term.add_frees t []
       
   183     val (e, xs) = reify_goal ctxt t
       
   184     fun single_tester b s =
       
   185       approx_random ctxt prec eps frees e xs b s |> run
       
   186     fun iterate _ _ 0 _ = NONE
       
   187       | iterate genuine_only size j seed =
       
   188         case single_tester genuine_only size seed of
       
   189           (NONE, seed') => iterate genuine_only size (j - 1) seed'
       
   190         | (SOME q, _) => SOME q
       
   191   in
       
   192     fn genuine_only => fn size => (iterate genuine_only size iterations seed, NONE)
       
   193   end
       
   194 
       
   195 fun approximation_generator ctxt [(t, _)] =
       
   196   (fn genuine_only =>
       
   197     fn [_, size] =>
       
   198       approximation_generator_raw ctxt t genuine_only
       
   199         (Code_Numeral.natural_of_integer size))
       
   200   | approximation_generator _ _ =
       
   201       error "Quickcheck-approximation does not support type variables (or finite instantiations)"
       
   202 
       
   203 val test_goals =
       
   204   Quickcheck_Common.generator_test_goal_terms
       
   205     ("approximation", (fn _ => fn _ => false, approximation_generator))
       
   206 
       
   207 val active = Attrib.setup_config_bool @{binding quickcheck_approximation_active} (K false)
       
   208 
       
   209 val setup = Context.theory_map (Quickcheck.add_tester ("approximation", (active, test_goals)))
       
   210 
       
   211 end