src/HOL/Tools/smallvalue_generators.ML
author haftmann
Wed Dec 08 13:34:50 2010 +0100 (2010-12-08)
changeset 41075 4bed56dc95fb
parent 40913 99a4ef20704d
child 41085 a549ff1d4070
permissions -rw-r--r--
primitive definitions of bot/top/inf/sup for bool and fun are named with canonical suffix `_def` rather than `_eq`
     1 (*  Title:      HOL/Tools/smallvalue_generators.ML
     2     Author:     Lukas Bulwahn, TU Muenchen
     3 
     4 Generators for small values for various types.
     5 *)
     6 
     7 signature SMALLVALUE_GENERATORS =
     8 sig
     9   val compile_generator_expr:
    10     Proof.context -> term -> int -> term list option * Quickcheck.report option
    11   val put_counterexample: (unit -> int -> term list option)
    12     -> Proof.context -> Proof.context
    13   val smart_quantifier : bool Config.T;
    14   val setup: theory -> theory
    15 end;
    16 
    17 structure Smallvalue_Generators : SMALLVALUE_GENERATORS =
    18 struct
    19 
    20 (* static options *)
    21 
    22 val define_foundationally = false
    23 
    24 (* dynamic options *)
    25 
    26 val (smart_quantifier, setup_smart_quantifier) =
    27   Attrib.config_bool "quickcheck_smart_quantifier" (K true)
    28 
    29 (** general term functions **)
    30 
    31 fun mk_measure f =
    32   let
    33     val Type ("fun", [T, @{typ nat}]) = fastype_of f 
    34   in
    35     Const (@{const_name Wellfounded.measure},
    36       (T --> @{typ nat}) --> HOLogic.mk_prodT (T, T) --> @{typ bool})
    37     $ f
    38   end
    39 
    40 fun mk_sumcases rT f (Type (@{type_name Sum_Type.sum}, [TL, TR])) =
    41   let
    42     val lt = mk_sumcases rT f TL
    43     val rt = mk_sumcases rT f TR
    44   in
    45     SumTree.mk_sumcase TL TR rT lt rt
    46   end
    47   | mk_sumcases _ f T = f T
    48 
    49 fun mk_undefined T = Const(@{const_name undefined}, T)
    50   
    51 
    52 (** abstract syntax **)
    53 
    54 fun termifyT T = HOLogic.mk_prodT (T, @{typ "unit => Code_Evaluation.term"});
    55 
    56 val size = @{term "i :: code_numeral"}
    57 val size_pred = @{term "(i :: code_numeral) - 1"}
    58 val size_ge_zero = @{term "(i :: code_numeral) > 0"}
    59 fun test_function T = Free ("f", termifyT T --> @{typ "term list option"})
    60 
    61 fun mk_none_continuation (x, y) =
    62   let
    63     val (T as Type(@{type_name "option"}, [T'])) = fastype_of x
    64   in
    65     Const (@{const_name "Smallcheck.orelse"}, T --> T --> T)
    66       $ x $ y
    67   end
    68 
    69 (** datatypes **)
    70 
    71 (* constructing smallvalue generator instances on datatypes *)
    72 
    73 exception FUNCTION_TYPE;
    74 
    75 val smallN = "small";
    76 
    77 fun smallT T = (T --> @{typ "Code_Evaluation.term list option"}) --> @{typ code_numeral}
    78   --> @{typ "Code_Evaluation.term list option"}
    79 
    80 val full_smallN = "full_small";
    81 
    82 fun full_smallT T = (termifyT T --> @{typ "Code_Evaluation.term list option"})
    83   --> @{typ code_numeral} --> @{typ "Code_Evaluation.term list option"}
    84  
    85 fun mk_equations thy descr vs tycos smalls (Ts, Us) =
    86   let
    87     fun mk_small_call T =
    88       let
    89         val small = Const (@{const_name "Smallcheck.full_small_class.full_small"}, full_smallT T)        
    90       in
    91         (T, (fn t => small $
    92           (HOLogic.split_const (T, @{typ "unit => Code_Evaluation.term"}, @{typ "Code_Evaluation.term list option"})
    93           $ absdummy (T, absdummy (@{typ "unit => Code_Evaluation.term"}, t))) $ size_pred))
    94       end
    95     fun mk_small_aux_call fTs (k, _) (tyco, Ts) =
    96       let
    97         val T = Type (tyco, Ts)
    98         val _ = if not (null fTs) then raise FUNCTION_TYPE else ()
    99         val small = nth smalls k
   100       in
   101        (T, (fn t => small $
   102           (HOLogic.split_const (T, @{typ "unit => Code_Evaluation.term"}, @{typ "Code_Evaluation.term list option"})
   103             $ absdummy (T, absdummy (@{typ "unit => Code_Evaluation.term"}, t))) $ size_pred))  
   104       end
   105     fun mk_consexpr simpleT (c, xs) =
   106       let
   107         val (Ts, fns) = split_list xs
   108         val constr = Const (c, Ts ---> simpleT)
   109         val bounds = map (fn x => Bound (2 * x + 1)) (((length xs) - 1) downto 0)
   110         val term_bounds = map (fn x => Bound (2 * x)) (((length xs) - 1) downto 0)
   111         val Eval_App = Const ("Code_Evaluation.App", HOLogic.termT --> HOLogic.termT --> HOLogic.termT)
   112         val Eval_Const = Const ("Code_Evaluation.Const", HOLogic.literalT --> @{typ typerep} --> HOLogic.termT)
   113         val term = fold (fn u => fn t => Eval_App $ t $ (u $ @{term "()"}))
   114           bounds (Eval_Const $ HOLogic.mk_literal c $ HOLogic.mk_typerep (Ts ---> simpleT))
   115         val start_term = test_function simpleT $ 
   116         (HOLogic.pair_const simpleT @{typ "unit => Code_Evaluation.term"}
   117           $ (list_comb (constr, bounds)) $ absdummy (@{typ unit}, term))
   118       in fold_rev (fn f => fn t => f t) fns start_term end
   119     fun mk_rhs exprs =
   120         @{term "If :: bool => term list option => term list option => term list option"}
   121             $ size_ge_zero $ (foldr1 mk_none_continuation exprs) $ @{term "None :: term list option"}
   122     val rhss =
   123       Datatype_Aux.interpret_construction descr vs
   124         { atyp = mk_small_call, dtyp = mk_small_aux_call }
   125       |> (map o apfst) Type
   126       |> map (fn (T, cs) => map (mk_consexpr T) cs)
   127       |> map mk_rhs
   128     val lhss = map2 (fn t => fn T => t $ test_function T $ size) smalls (Ts @ Us);
   129     val eqs = map (HOLogic.mk_Trueprop o HOLogic.mk_eq) (lhss ~~ rhss)
   130   in
   131     eqs
   132   end
   133 
   134 (* foundational definition with the function package *)
   135 
   136 val less_int_pred = @{lemma "i > 0 ==> Code_Numeral.nat_of ((i :: code_numeral) - 1) < Code_Numeral.nat_of i" by auto}
   137 
   138 fun mk_single_measure T = HOLogic.mk_comp (@{term "Code_Numeral.nat_of"},
   139     Const (@{const_name "Product_Type.snd"}, T --> @{typ "code_numeral"}))
   140 
   141 fun mk_termination_measure T =
   142   let
   143     val T' = fst (HOLogic.dest_prodT (HOLogic.dest_setT T))
   144   in
   145     mk_measure (mk_sumcases @{typ nat} mk_single_measure T')
   146   end
   147 
   148 fun termination_tac ctxt = 
   149   Function_Relation.relation_tac ctxt mk_termination_measure 1
   150   THEN rtac @{thm wf_measure} 1
   151   THEN (REPEAT_DETERM (Simplifier.asm_full_simp_tac 
   152     (HOL_basic_ss addsimps [@{thm in_measure}, @{thm o_def}, @{thm snd_conv},
   153      @{thm nat_mono_iff}, less_int_pred] @ @{thms sum.cases}) 1))
   154 
   155 fun pat_completeness_auto ctxt =
   156   Pat_Completeness.pat_completeness_tac ctxt 1
   157   THEN auto_tac (clasimpset_of ctxt)    
   158 
   159 
   160 (* creating the instances *)
   161 
   162 fun instantiate_smallvalue_datatype config descr vs tycos prfx (names, auxnames) (Ts, Us) thy =
   163   let
   164     val _ = Datatype_Aux.message config "Creating smallvalue generators ...";
   165     val smallsN = map (prefix (full_smallN ^ "_")) (names @ auxnames);
   166   in
   167     thy
   168     |> Class.instantiation (tycos, vs, @{sort full_small})
   169     |> (if define_foundationally then
   170       let
   171         val smalls = map2 (fn name => fn T => Free (name, full_smallT T)) smallsN (Ts @ Us)
   172         val eqs = mk_equations thy descr vs tycos smalls (Ts, Us)
   173       in
   174         Function.add_function
   175           (map (fn (name, T) =>
   176               Syntax.no_syn (Binding.conceal (Binding.name name), SOME (full_smallT T)))
   177                 (smallsN ~~ (Ts @ Us)))
   178             (map (pair (apfst Binding.conceal Attrib.empty_binding)) eqs)
   179           Function_Common.default_config pat_completeness_auto
   180         #> snd
   181         #> Local_Theory.restore
   182         #> (fn lthy => Function.prove_termination NONE (termination_tac lthy) lthy)
   183         #> snd
   184       end
   185     else
   186       fold_map (fn (name, T) => Local_Theory.define
   187           ((Binding.conceal (Binding.name name), NoSyn),
   188             (apfst Binding.conceal Attrib.empty_binding, mk_undefined (full_smallT T)))
   189         #> apfst fst) (smallsN ~~ (Ts @ Us))
   190       #> (fn (smalls, lthy) =>
   191         let
   192           val eqs_t = mk_equations thy descr vs tycos smalls (Ts, Us)
   193           val eqs = map (fn eq => Goal.prove lthy ["f", "i"] [] eq
   194             (fn _ => Skip_Proof.cheat_tac (ProofContext.theory_of lthy))) eqs_t
   195         in
   196           fold (fn (name, eq) => Local_Theory.note
   197           ((Binding.conceal (Binding.qualify true prfx
   198              (Binding.qualify true name (Binding.name "simps"))),
   199              Code.add_default_eqn_attrib :: map (Attrib.internal o K)
   200                [Simplifier.simp_add, Nitpick_Simps.add]), [eq]) #> snd) (smallsN ~~ eqs) lthy
   201         end))
   202     |> Class.prove_instantiation_exit (K (Class.intro_classes_tac []))
   203   end handle FUNCTION_TYPE =>
   204     (Datatype_Aux.message config
   205       "Creation of smallvalue generators failed because the datatype contains a function type";
   206     thy)
   207 
   208 (** building and compiling generator expressions **)
   209 
   210 structure Counterexample = Proof_Data (
   211   type T = unit -> int -> term list option
   212   fun init _ () = error "Counterexample"
   213 );
   214 val put_counterexample = Counterexample.put;
   215 
   216 val target = "Quickcheck";
   217 
   218 fun mk_smart_generator_expr ctxt t =
   219   let
   220     val ((vnames, Ts), t') = apfst split_list (strip_abs t)
   221     val ([depth_name], ctxt') = Variable.variant_fixes ["depth"] ctxt
   222     val (names, ctxt'') = Variable.variant_fixes vnames ctxt'
   223     val (term_names, ctxt''') = Variable.variant_fixes (map (prefix "t_") vnames) ctxt''
   224     val depth = Free (depth_name, @{typ code_numeral})
   225     val frees = map2 (curry Free) names Ts
   226     val term_vars = map (fn n => Free (n, @{typ "unit => term"})) term_names 
   227     fun strip_imp (Const(@{const_name HOL.implies},_) $ A $ B) = apfst (cons A) (strip_imp B)
   228       | strip_imp A = ([], A)
   229     val (assms, concl) = strip_imp (subst_bounds (rev frees, t'))
   230     val terms = HOLogic.mk_list @{typ term} (map (fn v => v $ @{term "()"}) term_vars)
   231     fun mk_small_closure (free as Free (_, T), term_var) t =
   232       Const (@{const_name "Smallcheck.full_small_class.full_small"}, full_smallT T)
   233         $ (HOLogic.split_const (T, @{typ "unit => term"}, @{typ "term list option"}) 
   234           $ lambda free (lambda term_var t)) $ depth
   235     fun lookup v = the (AList.lookup (op =) (names ~~ (frees ~~ term_vars)) v)
   236     val none_t = @{term "None :: term list option"}
   237     fun mk_safe_if (cond, then_t, else_t) =
   238       @{term "Smallcheck.catch_match :: term list option => term list option => term list option"} $
   239         (@{term "If :: bool => term list option => term list option => term list option"}
   240         $ cond $ then_t $ else_t) $ none_t;
   241     fun mk_test_term bound_vars assms =
   242       let
   243         fun vars_of t = subtract (op =) bound_vars (Term.add_free_names t [])
   244         val (vars, check) =
   245           case assms of [] =>
   246             (vars_of concl, (concl, none_t, @{term "Some :: term list => term list option"} $ terms))
   247           | assm :: assms =>
   248             (vars_of assm, (assm, mk_test_term (union (op =) (vars_of assm) bound_vars) assms, none_t))
   249       in
   250         fold_rev mk_small_closure (map lookup vars) (mk_safe_if check)
   251       end
   252   in lambda depth (mk_test_term [] assms) end
   253 
   254 fun mk_generator_expr ctxt t =
   255   let
   256     val Ts = (map snd o fst o strip_abs) t;
   257     val thy = ProofContext.theory_of ctxt
   258     val bound_max = length Ts - 1;
   259     val bounds = map_index (fn (i, ty) =>
   260       (2 * (bound_max - i) + 1, 2 * (bound_max - i), 2 * i, ty)) Ts;
   261     val result = list_comb (t, map (fn (i, _, _, _) => Bound i) bounds);
   262     val terms = HOLogic.mk_list @{typ term} (map (fn (_, i, _, _) => Bound i $ @{term "()"}) bounds);
   263     val check =
   264       @{term "Smallcheck.catch_match :: term list option => term list option => term list option"} $
   265         (@{term "If :: bool => term list option => term list option => term list option"}
   266         $ result $ @{term "None :: term list option"} $ (@{term "Some :: term list => term list option"} $ terms))
   267       $ @{term "None :: term list option"};
   268     fun mk_small_closure (_, _, i, T) t =
   269       Const (@{const_name "Smallcheck.full_small_class.full_small"}, full_smallT T)
   270         $ (HOLogic.split_const (T, @{typ "unit => term"}, @{typ "term list option"}) 
   271         $ absdummy (T, absdummy (@{typ "unit => term"}, t))) $ Bound i
   272   in Abs ("d", @{typ code_numeral}, fold_rev mk_small_closure bounds check) end
   273 
   274 fun compile_generator_expr ctxt t =
   275   let
   276     val thy = ProofContext.theory_of ctxt
   277     val t' =
   278       (if Config.get ctxt smart_quantifier then mk_smart_generator_expr else mk_generator_expr)
   279         ctxt t;
   280     val compile = Code_Runtime.dynamic_value_strict
   281       (Counterexample.get, put_counterexample, "Smallvalue_Generators.put_counterexample")
   282       thy (SOME target) (fn proc => fn g => g #> (Option.map o map) proc) t' [];
   283   in fn size => rpair NONE (compile size) end;
   284 
   285 (** setup **)
   286 
   287 val setup =
   288   Datatype.interpretation
   289     (Quickcheck_Generators.ensure_sort_datatype (@{sort full_small}, instantiate_smallvalue_datatype))
   290   #> setup_smart_quantifier
   291   #> Context.theory_map
   292     (Quickcheck.add_generator ("exhaustive", compile_generator_expr));
   293 
   294 end;