src/HOL/Tools/exhaustive_generators.ML
changeset 41920 d4fb7a418152
parent 41919 e180c2a9873b
child 41921 ee84fc7a61f1
equal deleted inserted replaced
41919:e180c2a9873b 41920:d4fb7a418152
     1 (*  Title:      HOL/Tools/exhaustive_generators.ML
       
     2     Author:     Lukas Bulwahn, TU Muenchen
       
     3 
       
     4 Exhaustive generators for various types.
       
     5 *)
       
     6 
       
     7 signature EXHAUSTIVE_GENERATORS =
       
     8 sig
       
     9   val compile_generator_expr:
       
    10     Proof.context -> term -> int -> term list option * Quickcheck.report option
       
    11   val compile_generator_exprs:
       
    12     Proof.context -> term list -> (int -> term list option) list
       
    13   val put_counterexample: (unit -> int -> term list option)
       
    14     -> Proof.context -> Proof.context
       
    15   val put_counterexample_batch: (unit -> (int -> term list option) list)
       
    16     -> Proof.context -> Proof.context
       
    17   val smart_quantifier : bool Config.T;
       
    18   val quickcheck_pretty : bool Config.T;
       
    19   val setup: theory -> theory
       
    20 end;
       
    21 
       
    22 structure Exhaustive_Generators : EXHAUSTIVE_GENERATORS =
       
    23 struct
       
    24 
       
    25 (* static options *)
       
    26 
       
    27 val define_foundationally = false
       
    28 
       
    29 (* dynamic options *)
       
    30 
       
    31 val (smart_quantifier, setup_smart_quantifier) =
       
    32   Attrib.config_bool "quickcheck_smart_quantifier" (K true)
       
    33 
       
    34 val (quickcheck_pretty, setup_quickcheck_pretty) =
       
    35   Attrib.config_bool "quickcheck_pretty" (K true)
       
    36  
       
    37 (** general term functions **)
       
    38 
       
    39 fun mk_measure f =
       
    40   let
       
    41     val Type ("fun", [T, @{typ nat}]) = fastype_of f 
       
    42   in
       
    43     Const (@{const_name Wellfounded.measure},
       
    44       (T --> @{typ nat}) --> HOLogic.mk_prodT (T, T) --> @{typ bool})
       
    45     $ f
       
    46   end
       
    47 
       
    48 fun mk_sumcases rT f (Type (@{type_name Sum_Type.sum}, [TL, TR])) =
       
    49   let
       
    50     val lt = mk_sumcases rT f TL
       
    51     val rt = mk_sumcases rT f TR
       
    52   in
       
    53     SumTree.mk_sumcase TL TR rT lt rt
       
    54   end
       
    55   | mk_sumcases _ f T = f T
       
    56 
       
    57 fun mk_undefined T = Const(@{const_name undefined}, T)
       
    58   
       
    59 
       
    60 (** abstract syntax **)
       
    61 
       
    62 fun termifyT T = HOLogic.mk_prodT (T, @{typ "unit => Code_Evaluation.term"});
       
    63 
       
    64 val size = @{term "i :: code_numeral"}
       
    65 val size_pred = @{term "(i :: code_numeral) - 1"}
       
    66 val size_ge_zero = @{term "(i :: code_numeral) > 0"}
       
    67 fun test_function T = Free ("f", termifyT T --> @{typ "term list option"})
       
    68 
       
    69 fun mk_none_continuation (x, y) =
       
    70   let
       
    71     val (T as Type(@{type_name "option"}, [T'])) = fastype_of x
       
    72   in
       
    73     Const (@{const_name "Quickcheck_Exhaustive.orelse"}, T --> T --> T)
       
    74       $ x $ y
       
    75   end
       
    76 
       
    77 (** datatypes **)
       
    78 
       
    79 (* constructing exhaustive generator instances on datatypes *)
       
    80 
       
    81 exception FUNCTION_TYPE;
       
    82 val exhaustiveN = "exhaustive";
       
    83 
       
    84 fun exhaustiveT T = (termifyT T --> @{typ "Code_Evaluation.term list option"})
       
    85   --> @{typ code_numeral} --> @{typ "Code_Evaluation.term list option"}
       
    86 
       
    87 fun check_allT T = (termifyT T --> @{typ "Code_Evaluation.term list option"})
       
    88   --> @{typ "Code_Evaluation.term list option"}
       
    89 
       
    90 fun mk_equations thy descr vs tycos exhaustives (Ts, Us) =
       
    91   let
       
    92     fun mk_call T =
       
    93       let
       
    94         val exhaustive = Const (@{const_name "Quickcheck_Exhaustive.exhaustive_class.exhaustive"}, exhaustiveT T)        
       
    95       in
       
    96         (T, (fn t => exhaustive $
       
    97           (HOLogic.split_const (T, @{typ "unit => Code_Evaluation.term"}, @{typ "Code_Evaluation.term list option"})
       
    98           $ absdummy (T, absdummy (@{typ "unit => Code_Evaluation.term"}, t))) $ size_pred))
       
    99       end
       
   100     fun mk_aux_call fTs (k, _) (tyco, Ts) =
       
   101       let
       
   102         val T = Type (tyco, Ts)
       
   103         val _ = if not (null fTs) then raise FUNCTION_TYPE else ()
       
   104       in
       
   105        (T, (fn t => nth exhaustives k $
       
   106           (HOLogic.split_const (T, @{typ "unit => Code_Evaluation.term"}, @{typ "Code_Evaluation.term list option"})
       
   107             $ absdummy (T, absdummy (@{typ "unit => Code_Evaluation.term"}, t))) $ size_pred))
       
   108       end
       
   109     fun mk_consexpr simpleT (c, xs) =
       
   110       let
       
   111         val (Ts, fns) = split_list xs
       
   112         val constr = Const (c, Ts ---> simpleT)
       
   113         val bounds = map (fn x => Bound (2 * x + 1)) (((length xs) - 1) downto 0)
       
   114         val term_bounds = map (fn x => Bound (2 * x)) (((length xs) - 1) downto 0)
       
   115         val Eval_App = Const ("Code_Evaluation.App", HOLogic.termT --> HOLogic.termT --> HOLogic.termT)
       
   116         val Eval_Const = Const ("Code_Evaluation.Const", HOLogic.literalT --> @{typ typerep} --> HOLogic.termT)
       
   117         val term = fold (fn u => fn t => Eval_App $ t $ (u $ @{term "()"}))
       
   118           bounds (Eval_Const $ HOLogic.mk_literal c $ HOLogic.mk_typerep (Ts ---> simpleT))
       
   119         val start_term = test_function simpleT $ 
       
   120         (HOLogic.pair_const simpleT @{typ "unit => Code_Evaluation.term"}
       
   121           $ (list_comb (constr, bounds)) $ absdummy (@{typ unit}, term))
       
   122       in fold_rev (fn f => fn t => f t) fns start_term end
       
   123     fun mk_rhs exprs =
       
   124         @{term "If :: bool => term list option => term list option => term list option"}
       
   125             $ size_ge_zero $ (foldr1 mk_none_continuation exprs) $ @{term "None :: term list option"}
       
   126     val rhss =
       
   127       Datatype_Aux.interpret_construction descr vs
       
   128         { atyp = mk_call, dtyp = mk_aux_call }
       
   129       |> (map o apfst) Type
       
   130       |> map (fn (T, cs) => map (mk_consexpr T) cs)
       
   131       |> map mk_rhs
       
   132     val lhss = map2 (fn t => fn T => t $ test_function T $ size) exhaustives (Ts @ Us);
       
   133     val eqs = map (HOLogic.mk_Trueprop o HOLogic.mk_eq) (lhss ~~ rhss)
       
   134   in
       
   135     eqs
       
   136   end
       
   137 
       
   138 (* foundational definition with the function package *)
       
   139 
       
   140 val less_int_pred = @{lemma "i > 0 ==> Code_Numeral.nat_of ((i :: code_numeral) - 1) < Code_Numeral.nat_of i" by auto}
       
   141 
       
   142 fun mk_single_measure T = HOLogic.mk_comp (@{term "Code_Numeral.nat_of"},
       
   143     Const (@{const_name "Product_Type.snd"}, T --> @{typ "code_numeral"}))
       
   144 
       
   145 fun mk_termination_measure T =
       
   146   let
       
   147     val T' = fst (HOLogic.dest_prodT (HOLogic.dest_setT T))
       
   148   in
       
   149     mk_measure (mk_sumcases @{typ nat} mk_single_measure T')
       
   150   end
       
   151 
       
   152 fun termination_tac ctxt = 
       
   153   Function_Relation.relation_tac ctxt mk_termination_measure 1
       
   154   THEN rtac @{thm wf_measure} 1
       
   155   THEN (REPEAT_DETERM (Simplifier.asm_full_simp_tac 
       
   156     (HOL_basic_ss addsimps [@{thm in_measure}, @{thm o_def}, @{thm snd_conv},
       
   157      @{thm nat_mono_iff}, less_int_pred] @ @{thms sum.cases}) 1))
       
   158 
       
   159 fun pat_completeness_auto ctxt =
       
   160   Pat_Completeness.pat_completeness_tac ctxt 1
       
   161   THEN auto_tac (clasimpset_of ctxt)    
       
   162 
       
   163 
       
   164 (* creating the instances *)
       
   165 
       
   166 fun instantiate_exhaustive_datatype config descr vs tycos prfx (names, auxnames) (Ts, Us) thy =
       
   167   let
       
   168     val _ = Datatype_Aux.message config "Creating exhaustive generators ...";
       
   169     val exhaustivesN = map (prefix (exhaustiveN ^ "_")) (names @ auxnames);
       
   170   in
       
   171     thy
       
   172     |> Class.instantiation (tycos, vs, @{sort exhaustive})
       
   173     |> (if define_foundationally then
       
   174       let
       
   175         val exhaustives = map2 (fn name => fn T => Free (name, exhaustiveT T)) exhaustivesN (Ts @ Us)
       
   176         val eqs = mk_equations thy descr vs tycos exhaustives (Ts, Us)
       
   177       in
       
   178         Function.add_function
       
   179           (map (fn (name, T) =>
       
   180               Syntax.no_syn (Binding.conceal (Binding.name name), SOME (exhaustiveT T)))
       
   181                 (exhaustivesN ~~ (Ts @ Us)))
       
   182             (map (pair (apfst Binding.conceal Attrib.empty_binding)) eqs)
       
   183           Function_Common.default_config pat_completeness_auto
       
   184         #> snd
       
   185         #> Local_Theory.restore
       
   186         #> (fn lthy => Function.prove_termination NONE (termination_tac lthy) lthy)
       
   187         #> snd
       
   188       end
       
   189     else
       
   190       fold_map (fn (name, T) => Local_Theory.define
       
   191           ((Binding.conceal (Binding.name name), NoSyn),
       
   192             (apfst Binding.conceal Attrib.empty_binding, mk_undefined (exhaustiveT T)))
       
   193         #> apfst fst) (exhaustivesN ~~ (Ts @ Us))
       
   194       #> (fn (exhaustives, lthy) =>
       
   195         let
       
   196           val eqs_t = mk_equations thy descr vs tycos exhaustives (Ts, Us)
       
   197           val eqs = map (fn eq => Goal.prove lthy ["f", "i"] [] eq
       
   198             (fn _ => Skip_Proof.cheat_tac (ProofContext.theory_of lthy))) eqs_t
       
   199         in
       
   200           fold (fn (name, eq) => Local_Theory.note
       
   201           ((Binding.conceal (Binding.qualify true prfx
       
   202              (Binding.qualify true name (Binding.name "simps"))),
       
   203              Code.add_default_eqn_attrib :: map (Attrib.internal o K)
       
   204                [Simplifier.simp_add, Nitpick_Simps.add]), [eq]) #> snd) (exhaustivesN ~~ eqs) lthy
       
   205         end))
       
   206     |> Class.prove_instantiation_exit (K (Class.intro_classes_tac []))
       
   207   end handle FUNCTION_TYPE =>
       
   208     (Datatype_Aux.message config
       
   209       "Creation of exhaustivevalue generators failed because the datatype contains a function type";
       
   210     thy)
       
   211 
       
   212 (** building and compiling generator expressions **)
       
   213 
       
   214 structure Counterexample = Proof_Data
       
   215 (
       
   216   type T = unit -> int -> term list option
       
   217   (* FIXME avoid user error with non-user text *)
       
   218   fun init _ () = error "Counterexample"
       
   219 );
       
   220 val put_counterexample = Counterexample.put;
       
   221 
       
   222 structure Counterexample_Batch = Proof_Data
       
   223 (
       
   224   type T = unit -> (int -> term list option) list
       
   225   (* FIXME avoid user error with non-user text *)
       
   226   fun init _ () = error "Counterexample"
       
   227 );
       
   228 val put_counterexample_batch = Counterexample_Batch.put;
       
   229 
       
   230 val target = "Quickcheck";
       
   231 
       
   232 fun mk_smart_generator_expr ctxt t =
       
   233   let
       
   234     val thy = ProofContext.theory_of ctxt
       
   235     val ((vnames, Ts), t') = apfst split_list (strip_abs t)
       
   236     val ([depth_name], ctxt') = Variable.variant_fixes ["depth"] ctxt
       
   237     val (names, ctxt'') = Variable.variant_fixes vnames ctxt'
       
   238     val (term_names, ctxt''') = Variable.variant_fixes (map (prefix "t_") vnames) ctxt''
       
   239     val depth = Free (depth_name, @{typ code_numeral})
       
   240     val frees = map2 (curry Free) names Ts
       
   241     val term_vars = map (fn n => Free (n, @{typ "unit => term"})) term_names 
       
   242     fun strip_imp (Const(@{const_name HOL.implies},_) $ A $ B) = apfst (cons A) (strip_imp B)
       
   243       | strip_imp A = ([], A)
       
   244     val (assms, concl) = strip_imp (subst_bounds (rev frees, t'))
       
   245     val terms = HOLogic.mk_list @{typ term} (map (fn v => v $ @{term "()"}) term_vars)
       
   246     fun mk_exhaustive_closure (free as Free (_, T), term_var) t =
       
   247       if Sign.of_sort thy (T, @{sort enum}) then
       
   248         Const (@{const_name "Quickcheck_Exhaustive.check_all_class.check_all"}, check_allT T)
       
   249           $ (HOLogic.split_const (T, @{typ "unit => term"}, @{typ "term list option"}) 
       
   250             $ lambda free (lambda term_var t))
       
   251       else
       
   252         Const (@{const_name "Quickcheck_Exhaustive.exhaustive_class.exhaustive"}, exhaustiveT T)
       
   253           $ (HOLogic.split_const (T, @{typ "unit => term"}, @{typ "term list option"}) 
       
   254             $ lambda free (lambda term_var t)) $ depth
       
   255     fun lookup v = the (AList.lookup (op =) (names ~~ (frees ~~ term_vars)) v)
       
   256     val none_t = @{term "None :: term list option"}
       
   257     fun mk_safe_if (cond, then_t, else_t) =
       
   258       @{term "Quickcheck_Exhaustive.catch_match :: term list option => term list option => term list option"} $
       
   259         (@{term "If :: bool => term list option => term list option => term list option"}
       
   260         $ cond $ then_t $ else_t) $ none_t;
       
   261     fun mk_test_term bound_vars assms =
       
   262       let
       
   263         fun vars_of t = subtract (op =) bound_vars (Term.add_free_names t [])
       
   264         val (vars, check) =
       
   265           case assms of [] =>
       
   266             (vars_of concl, (concl, none_t, @{term "Some :: term list => term list option"} $ terms))
       
   267           | assm :: assms =>
       
   268             (vars_of assm, (assm, mk_test_term (union (op =) (vars_of assm) bound_vars) assms, none_t))
       
   269       in
       
   270         fold_rev mk_exhaustive_closure (map lookup vars) (mk_safe_if check)
       
   271       end
       
   272   in lambda depth (mk_test_term [] assms) end
       
   273 
       
   274 fun mk_generator_expr ctxt t =
       
   275   let
       
   276     val Ts = (map snd o fst o strip_abs) t;
       
   277     val thy = ProofContext.theory_of ctxt
       
   278     val bound_max = length Ts - 1;
       
   279     val bounds = map_index (fn (i, ty) =>
       
   280       (2 * (bound_max - i) + 1, 2 * (bound_max - i), 2 * i, ty)) Ts;
       
   281     val result = list_comb (t, map (fn (i, _, _, _) => Bound i) bounds);
       
   282     val terms = HOLogic.mk_list @{typ term} (map (fn (_, i, _, _) => Bound i $ @{term "()"}) bounds);
       
   283     val check =
       
   284       @{term "Quickcheck_Exhaustive.catch_match :: term list option => term list option => term list option"} $
       
   285         (@{term "If :: bool => term list option => term list option => term list option"}
       
   286         $ result $ @{term "None :: term list option"} $ (@{term "Some :: term list => term list option"} $ terms))
       
   287       $ @{term "None :: term list option"};
       
   288     fun mk_exhaustive_closure (_, _, i, T) t =
       
   289       Const (@{const_name "Quickcheck_Exhaustive.exhaustive_class.exhaustive"}, exhaustiveT T)
       
   290         $ (HOLogic.split_const (T, @{typ "unit => term"}, @{typ "term list option"}) 
       
   291         $ absdummy (T, absdummy (@{typ "unit => term"}, t))) $ Bound i
       
   292   in Abs ("d", @{typ code_numeral}, fold_rev mk_exhaustive_closure bounds check) end
       
   293 
       
   294 (** post-processing of function terms **)
       
   295 
       
   296 fun dest_fun_upd (Const (@{const_name fun_upd}, _) $ t0 $ t1 $ t2) = (t0, (t1, t2))
       
   297   | dest_fun_upd t = raise TERM ("dest_fun_upd", [t])
       
   298 
       
   299 fun mk_fun_upd T1 T2 (t1, t2) t = 
       
   300   Const (@{const_name fun_upd}, (T1 --> T2) --> T1 --> T2 --> T1 --> T2) $ t $ t1 $ t2
       
   301 
       
   302 fun dest_fun_upds t =
       
   303   case try dest_fun_upd t of
       
   304     NONE =>
       
   305       (case t of
       
   306         Abs (_, _, _) => ([], t) 
       
   307       | _ => raise TERM ("dest_fun_upds", [t]))
       
   308   | SOME (t0, (t1, t2)) => apfst (cons (t1, t2)) (dest_fun_upds t0)
       
   309 
       
   310 fun make_fun_upds T1 T2 (tps, t) = fold_rev (mk_fun_upd T1 T2) tps t
       
   311 
       
   312 fun make_set T1 [] = Const (@{const_abbrev Set.empty}, T1 --> @{typ bool})
       
   313   | make_set T1 ((_, @{const False}) :: tps) = make_set T1 tps
       
   314   | make_set T1 ((t1, @{const True}) :: tps) =
       
   315     Const (@{const_name insert}, T1 --> (T1 --> @{typ bool}) --> T1 --> @{typ bool})
       
   316       $ t1 $ (make_set T1 tps)
       
   317   | make_set T1 ((_, t) :: tps) = raise TERM ("make_set", [t])
       
   318 
       
   319 fun make_coset T [] = Const (@{const_abbrev UNIV}, T --> @{typ bool})
       
   320   | make_coset T tps = 
       
   321     let
       
   322       val U = T --> @{typ bool}
       
   323       fun invert @{const False} = @{const True}
       
   324         | invert @{const True} = @{const False}
       
   325     in
       
   326       Const (@{const_name "Groups.minus_class.minus"}, U --> U --> U)
       
   327         $ Const (@{const_abbrev UNIV}, U) $ make_set T (map (apsnd invert) tps)
       
   328     end
       
   329 
       
   330 fun make_map T1 T2 [] = Const (@{const_abbrev Map.empty}, T1 --> T2)
       
   331   | make_map T1 T2 ((_, Const (@{const_name None}, _)) :: tps) = make_map T1 T2 tps
       
   332   | make_map T1 T2 ((t1, t2) :: tps) = mk_fun_upd T1 T2 (t1, t2) (make_map T1 T2 tps)
       
   333   
       
   334 fun post_process_term t =
       
   335   let
       
   336     fun map_Abs f t =
       
   337       case t of Abs (x, T, t') => Abs (x, T, f t') | _ => raise TERM ("map_Abs", [t]) 
       
   338     fun process_args t = case strip_comb t of
       
   339       (c as Const (_, _), ts) => list_comb (c, map post_process_term ts) 
       
   340   in
       
   341     case fastype_of t of
       
   342       Type (@{type_name fun}, [T1, T2]) =>
       
   343         (case try dest_fun_upds t of
       
   344           SOME (tps, t) =>
       
   345             (map (pairself post_process_term) tps, map_Abs post_process_term t)
       
   346             |> (case T2 of
       
   347               @{typ bool} => 
       
   348                 (case t of
       
   349                    Abs(_, _, @{const True}) => fst #> rev #> make_set T1
       
   350                  | Abs(_, _, @{const False}) => fst #> rev #> make_coset T1
       
   351                  | Abs(_, _, Const (@{const_name undefined}, _)) => fst #> rev #> make_set T1
       
   352                  | _ => raise TERM ("post_process_term", [t]))
       
   353             | Type (@{type_name option}, _) =>
       
   354                 (case t of
       
   355                   Abs(_, _, Const(@{const_name None}, _)) => fst #> make_map T1 T2
       
   356                 | Abs(_, _, Const (@{const_name undefined}, _)) => fst #> make_map T1 T2
       
   357                 | _ => make_fun_upds T1 T2) 
       
   358             | _ => make_fun_upds T1 T2)
       
   359         | NONE => process_args t)
       
   360     | _ => process_args t
       
   361   end
       
   362 
       
   363 (** generator compilation **)
       
   364 
       
   365 fun compile_generator_expr ctxt t =
       
   366   let
       
   367     val thy = ProofContext.theory_of ctxt
       
   368     val t' =
       
   369       (if Config.get ctxt smart_quantifier then mk_smart_generator_expr else mk_generator_expr)
       
   370         ctxt t;
       
   371     val compile = Code_Runtime.dynamic_value_strict
       
   372       (Counterexample.get, put_counterexample, "Exhaustive_Generators.put_counterexample")
       
   373       thy (SOME target) (fn proc => fn g => g #> (Option.map o map) proc) t' [];
       
   374   in
       
   375     fn size => rpair NONE (compile size |> 
       
   376       (if Config.get ctxt quickcheck_pretty then Option.map (map post_process_term) else I))
       
   377   end;
       
   378 
       
   379 fun compile_generator_exprs ctxt ts =
       
   380   let
       
   381     val thy = ProofContext.theory_of ctxt
       
   382     val mk_generator_expr =
       
   383       if Config.get ctxt smart_quantifier then mk_smart_generator_expr else mk_generator_expr
       
   384     val ts' = map (mk_generator_expr ctxt) ts;
       
   385     val compiles = Code_Runtime.dynamic_value_strict
       
   386       (Counterexample_Batch.get, put_counterexample_batch,
       
   387         "Exhaustive_Generators.put_counterexample_batch")
       
   388       thy (SOME target) (fn proc => map (fn g => g #> (Option.map o map) proc))
       
   389       (HOLogic.mk_list @{typ "code_numeral => term list option"} ts') [];
       
   390   in
       
   391     map (fn compile => fn size => compile size |> Option.map (map post_process_term)) compiles
       
   392   end;
       
   393   
       
   394   
       
   395 (** setup **)
       
   396 
       
   397 val setup =
       
   398   Datatype.interpretation
       
   399     (Quickcheck_Generators.ensure_sort_datatype (@{sort exhaustive}, instantiate_exhaustive_datatype))
       
   400   #> setup_smart_quantifier
       
   401   #> setup_quickcheck_pretty
       
   402   #> Context.theory_map (Quickcheck.add_generator ("exhaustive", compile_generator_expr))
       
   403   #> Context.theory_map (Quickcheck.add_batch_generator ("exhaustive", compile_generator_exprs));
       
   404 
       
   405 end;