src/HOL/Tools/Quickcheck/quickcheck_common.ML
author bulwahn
Mon Dec 05 12:36:20 2011 +0100 (2011-12-05)
changeset 45763 3bb2bdf654f7
parent 45761 90028fd2f1fa
child 45765 cb6ddee6a463
permissions -rw-r--r--
random reporting compilation returns if counterexample is genuine or potentially spurious, and takes genuine_only option as argument
bulwahn@41927
     1
(*  Title:      HOL/Tools/Quickcheck/quickcheck_common.ML
bulwahn@41927
     2
    Author:     Florian Haftmann, Lukas Bulwahn, TU Muenchen
bulwahn@41927
     3
wenzelm@41938
     4
Common functions for quickcheck's generators.
bulwahn@41927
     5
*)
bulwahn@41927
     6
bulwahn@41927
     7
(* Interface of the generator-independent infrastructure shared by
   Quickcheck's test generators. *)
signature QUICKCHECK_COMMON =
sig
  (* HOL term utilities *)
  val strip_imp : term -> (term list * term)
  val reflect_bool : bool -> term
  val mk_fun_upd : typ -> typ -> term * term -> term -> term
  (* presentation of counterexample terms (function updates as sets and maps) *)
  val post_process_term : term -> term
  (* instantiation of free type variables in goals with concrete types *)
  val instantiate_goals : Proof.context -> (string * typ) list -> (term * term list) list
    -> (typ option * (term * term list)) list list
  (* compilation and execution of testing functions *)
  type compile_generator =
    Proof.context -> (term * term list) list -> bool -> int list
      -> (bool * term list) option * Quickcheck.report option
  val mk_safe_if : term -> term -> term * term * (bool -> term) -> bool -> term
  val gen_mk_parametric_generator_expr :
    (((Proof.context -> term * term list -> term) * term) * typ)
      -> Proof.context -> (term * term list) list -> term
  val test_term : string * compile_generator
    -> Proof.context -> bool -> term * term list -> Quickcheck.result
  val generator_test_goal_terms :
    string * compile_generator -> Proof.context -> bool -> (string * typ) list
      -> (term * term list) list -> Quickcheck.result list
  val collect_results : ('a -> Quickcheck.result) -> 'a list -> Quickcheck.result list
    -> Quickcheck.result list
  (* definition of auxiliary functions in a local theory *)
  val define_functions : ((term list -> term list) * (Proof.context -> tactic) option)
    -> string -> string list -> string list -> typ list -> Proof.context -> Proof.context
  (* sort constraints and datatype instances *)
  val perhaps_constrain : theory -> (typ * sort) list -> (string * sort) list
    -> (string * sort -> string * sort) option
  val ensure_sort_datatype :
    ((sort * sort) * sort) * (Datatype.config -> Datatype.descr -> (string * sort) list
      -> string list -> string -> string list * string list -> typ list * typ list -> theory -> theory)
    -> Datatype.config -> string list -> theory -> theory
end;
bulwahn@41927
    36
bulwahn@41927
    37
structure Quickcheck_Common : QUICKCHECK_COMMON =
bulwahn@41927
    38
struct
bulwahn@41927
    39
bulwahn@42214
    40
(* static options *)
bulwahn@42214
    41
bulwahn@42214
    42
(* When true and a termination tactic is supplied, define_functions defines the
   functions with the function package; otherwise (the current setting) each
   function is defined as "undefined" and its equations are asserted without
   proof. *)
val define_foundationally = false
bulwahn@42214
    43
bulwahn@42195
    44
(* HOLogic's term functions *)
bulwahn@42195
    45
bulwahn@42195
    46
(* Splits a right-nested chain of object-level implications
   A1 --> ... --> An --> C into ([A1, ..., An], C). *)
fun strip_imp t =
  case t of
    Const (@{const_name HOL.implies}, _) $ prem $ concl =>
      let val (prems, goal) = strip_imp concl in (prem :: prems, goal) end
  | _ => ([], t)
bulwahn@42195
    48
bulwahn@45721
    49
(* Reflects an ML boolean as the corresponding HOL constant True or False. *)
fun reflect_bool true = @{term "True"}
  | reflect_bool false = @{term "False"}
bulwahn@45721
    50
bulwahn@42214
    51
(* The HOL constant "undefined" at type T. *)
fun mk_undefined (T : typ) = Const (@{const_name undefined}, T)
bulwahn@45159
    52
bulwahn@45159
    53
(* testing functions: testing with increasing sizes (and cardinalities) *)
bulwahn@45159
    54
bulwahn@45159
    55
(* A generator compiler: given a context and the terms to test (each paired with
   its evaluation terms) it yields a testing function.  That function takes the
   genuine_only flag and size parameters (callers below pass [1, k - 1] or
   [card, size - 1]) and returns an optional counterexample, flagged true when
   genuine and false when potentially spurious, plus an optional report. *)
type compile_generator =
  Proof.context -> (term * term list) list -> bool -> int list
    -> (bool * term list) option * Quickcheck.report option
bulwahn@45159
    57
bulwahn@45159
    58
(* Rejects terms that cannot be tested: the term must contain neither type
   variables (schematic or free) nor schematic term variables.  Fails with
   error otherwise. *)
fun check_test_term t =
  if not (null (Term.add_tvars t [])) orelse not (null (Term.add_tfrees t [])) then
    error "Term to be tested contains type variables"
  else if not (null (Term.add_vars t [])) then
    error "Term to be tested contains schematic variables"
  else ()
bulwahn@45159
    65
bulwahn@45159
    66
(* Runs e () and returns its result together with (description, CPU time in
   milliseconds). *)
fun cpu_time description e =
  let
    val ({cpu = cpu_spent, ...}, result) = Timing.timing e ()
  in
    (result, (description, Time.toMilliseconds cpu_spent))
  end
bulwahn@45159
    69
bulwahn@45420
    70
(* Tests a single term t (with its evaluation terms eval_terms) using the
   generator compiler compile, for data sizes 1 up to Quickcheck.size.

   name is the generator name used in messages.  If catch_code_errors is true,
   compilation is wrapped in try, and a failed compilation is reported as
   "Conjecture is not executable".  A counterexample flagged genuine ends the
   search.  A potentially spurious one is reported via Quickcheck.message, so it
   respects the quiet option like all other messages here.  Testing then resumes
   at the same size, asking for genuine counterexamples only.

   Returns the accumulated Quickcheck.result: timings, per-size reports and the
   response. *)
fun test_term (name, compile) ctxt catch_code_errors (t, eval_terms) =
  let
    val genuine_only = Config.get ctxt Quickcheck.genuine_only
    val _ = check_test_term t
    val names = Term.add_free_names t []
    val current_result = Unsynchronized.ref Quickcheck.empty_result
    val act = if catch_code_errors then try else (fn f => SOME o f)
    val (test_fun, comp_time) = cpu_time "quickcheck compilation"
        (fn () => act (compile ctxt) [(t, eval_terms)]);
    val _ = Quickcheck.add_timing comp_time current_result
    (* reports the potentially spurious counterexample ts (values for the free
       variables, followed by values for the evaluation terms) *)
    fun report_potential_counterexample ts =
      let
        val (ts1, ts2) = chop (length names) ts
        val (eval_terms', _) = chop (length ts2) eval_terms
        val cex = SOME ((false, names ~~ ts1), eval_terms' ~~ ts2)
      in
        Quickcheck.message ctxt (Pretty.string_of (Quickcheck.pretty_counterex ctxt false cex));
        Quickcheck.message ctxt "Quickcheck continues to find a genuine counterexample..."
      end
    fun with_size test_fun genuine_only k =
      if k > Config.get ctxt Quickcheck.size then
        NONE
      else
        let
          val _ = Quickcheck.message ctxt
            ("[Quickcheck-" ^ name ^ "] Test data size: " ^ string_of_int k)
          val ((result, report), timing) =
            cpu_time ("size " ^ string_of_int k) (fn () => test_fun genuine_only [1, k - 1])
          val _ = Quickcheck.add_timing timing current_result
          val _ = Quickcheck.add_report k report current_result
        in
          case result of
            NONE => with_size test_fun genuine_only (k + 1)
          | SOME (true, ts) => SOME (true, ts)
          | SOME (false, ts) =>
              (* retry the same size, accepting genuine counterexamples only *)
              (report_potential_counterexample ts; with_size test_fun true k)
        end;
  in
    case test_fun of
      NONE => (Quickcheck.message ctxt ("Conjecture is not executable with Quickcheck-" ^ name);
        !current_result)
    | SOME test_fun =>
      let
        val (response, exec_time) =
          cpu_time "quickcheck execution" (fn () => with_size test_fun genuine_only 1)
        val _ = Quickcheck.add_response names eval_terms response current_result
        val _ = Quickcheck.add_timing exec_time current_result
      in !current_result end
  end;
bulwahn@45159
   122
bulwahn@45159
   123
(* Batch validation: compiles validators for all terms ts and runs each for
   sizes 1 up to Quickcheck.size under the Quickcheck.timeout limit.  A
   validator passes if it succeeds at every size; a timeout counts as passing.
   Returns the optional list of outcomes with the compilation and execution
   timings. *)
fun validate_terms ctxt ts =
  let
    val _ = List.app check_test_term ts
    val max_size = Config.get ctxt Quickcheck.size
    val (test_funs, comp_time) = cpu_time "quickcheck batch compilation"
      (fn () => Quickcheck.mk_batch_validator ctxt ts)
    (* true iff tester succeeds at every size from k up to max_size *)
    fun passes_from tester k =
      k > max_size orelse (tester k andalso passes_from tester (k + 1))
    fun run test_fun =
      TimeLimit.timeLimit (seconds (Config.get ctxt Quickcheck.timeout))
        (fn () => passes_from test_fun 1) ()
      handle TimeLimit.TimeOut => true
    val (results, exec_time) =
      cpu_time "quickcheck batch execution" (fn () => Option.map (map run) test_funs)
  in
    (results, [comp_time, exec_time])
  end
bulwahn@42214
   140
  
bulwahn@45159
   141
(* Batch testing: compiles testers for all terms ts and runs each for sizes 1
   up to Quickcheck.size under the Quickcheck.timeout limit, stopping at the
   first size that yields a counterexample; a timeout yields NONE.  Found values
   are paired with the free variable names of the corresponding term.  Returns
   the optional list of outcomes with the compilation and execution timings. *)
fun test_terms ctxt ts =
  let
    val _ = List.app check_test_term ts
    val max_size = Config.get ctxt Quickcheck.size
    val namess = map (fn t => Term.add_free_names t []) ts
    val (test_funs, comp_time) =
      cpu_time "quickcheck batch compilation" (fn () => Quickcheck.mk_batch_tester ctxt ts)
    fun search tester k =
      if k > max_size then NONE
      else
        (case tester k of
          NONE => search tester (k + 1)
        | found => found)
    fun run test_fun =
      TimeLimit.timeLimit (seconds (Config.get ctxt Quickcheck.timeout))
        (fn () => search test_fun 1) ()
      handle TimeLimit.TimeOut => NONE
    val (results, exec_time) =
      cpu_time "quickcheck batch execution" (fn () => Option.map (map run) test_funs)
    fun label names = Option.map (fn ts => names ~~ ts)
  in
    (Option.map (map2 label namess) results, [comp_time, exec_time])
  end
bulwahn@45159
   160
bulwahn@45420
   161
(* Tests a list ts of instances of one goal (each with its evaluation terms).
   A cardinality card selects the card-th element of ts (1-based).

   If all free variables of the first instance have sort Enum.enum, only the
   cardinalities are enumerated, with size 0.  Otherwise all pairs
   (card, size) up to Quickcheck.size are enumerated, ordered by card + size.
   A genuine counterexample ends the search and is recorded in the result.  A
   potentially spurious one is reported via Quickcheck.message, which respects
   the quiet option.  The enumeration then resumes at the same (card, size),
   asking for genuine counterexamples only.

   An empty ts has nothing to test and yields Quickcheck.empty_result; without
   this guard, hd below would raise Empty. *)
fun test_term_with_cardinality (name, compile) ctxt catch_code_errors ts =
  if null ts then Quickcheck.empty_result
  else
  let
    val genuine_only = Config.get ctxt Quickcheck.genuine_only
    val thy = Proof_Context.theory_of ctxt
    val (ts', eval_terms) = split_list ts
    val _ = List.app check_test_term ts'
    val names = Term.add_free_names (hd ts') []
    val Ts = map snd (Term.add_frees (hd ts') [])
    val current_result = Unsynchronized.ref Quickcheck.empty_result
    fun test_card_size test_fun genuine_only (card, size) =
      (* FIXME: why decrement size by one? *)
      let
        val _ =
          Quickcheck.message ctxt ("[Quickcheck-" ^ name ^ "] Test " ^
            (if size = 0 then "" else "data size: " ^ string_of_int (size - 1) ^ " and ") ^
            "cardinality: " ^ string_of_int card)
        val (response, timing) =
          cpu_time ("size " ^ string_of_int size ^ " and card " ^ string_of_int card)
            (fn () => fst (test_fun genuine_only [card, size - 1]))
        val _ = Quickcheck.add_timing timing current_result
      in
        Option.map (pair (card, size)) response
      end
    val enumeration_card_size =
      if forall (fn T => Sign.of_sort thy (T,  ["Enum.enum"])) Ts then
        (* size does not matter *)
        map (rpair 0) (1 upto (length ts))
      else
        (* size does matter *)
        map_product pair (1 upto (length ts)) (1 upto (Config.get ctxt Quickcheck.size))
        |> sort (fn ((c1, s1), (c2, s2)) => int_ord ((c1 + s1), (c2 + s2)))
    val act = if catch_code_errors then try else (fn f => SOME o f)
    val (test_fun, comp_time) = cpu_time "quickcheck compilation" (fn () => act (compile ctxt) ts)
    val _ = Quickcheck.add_timing comp_time current_result
    (* reports the potentially spurious counterexample ts found at cardinality card *)
    fun report_potential_counterexample card ts =
      let
        val (ts1, ts2) = chop (length names) ts
        val (eval_terms', _) = chop (length ts2) (nth eval_terms (card - 1))
        val cex = SOME ((false, names ~~ ts1), eval_terms' ~~ ts2)
      in
        Quickcheck.message ctxt (Pretty.string_of (Quickcheck.pretty_counterex ctxt false cex));
        Quickcheck.message ctxt "Quickcheck continues to find a genuine counterexample..."
      end
  in
    case test_fun of
      NONE => (Quickcheck.message ctxt ("Conjecture is not executable with Quickcheck-" ^ name);
        !current_result)
    | SOME test_fun =>
      let
        fun test genuine_only enum =
          case get_first (test_card_size test_fun genuine_only) enum of
            SOME ((card, _), (true, ts)) =>
              Quickcheck.add_response names (nth eval_terms (card - 1)) (SOME (true, ts))
                current_result
          | SOME ((card, size), (false, ts)) =>
              (report_potential_counterexample card ts;
               (* resume at the same (card, size), genuine counterexamples only *)
               test true (snd (take_prefix (fn x => x <> (card, size)) enum)))
          | NONE => ()
      in (test genuine_only enumeration_card_size; !current_result) end
  end
bulwahn@45159
   217
bulwahn@45159
   218
(* The first Quickcheck.finite_type_size types among Enum.finite_1, ...,
   Enum.finite_5. instantiate_goals uses them as default instances when the
   finite_types option is set. *)
fun get_finite_types ctxt =
  let
    val finite_types =
      [@{typ "Enum.finite_1"}, @{typ "Enum.finite_2"}, @{typ "Enum.finite_3"},
       @{typ "Enum.finite_4"}, @{typ "Enum.finite_5"}]
    val (selected, _) = chop (Config.get ctxt Quickcheck.finite_type_size) finite_types
  in selected end
bulwahn@45159
   222
bulwahn@45159
   223
(* Raised by monomorphic_term when the type substituted for a type variable
   does not have that variable's sort; carries the error message.  Handled in
   instantiate_goals. *)
exception WELLSORTED of string
bulwahn@45159
   224
bulwahn@45159
   225
(* Returns a function on terms that replaces every free type variable v by its
   entry in insts, or by default_T if insts has none.  Raises WELLSORTED if the
   chosen type does not have the variable's sort. *)
fun monomorphic_term thy insts default_T =
  let
    fun instantiate_tfree (v, S) =
      let
        val T = TFree (v, S)
        val T' = the_default default_T (AList.lookup (op =) insts v)
      in
        if not (Sign.of_sort thy (T', S)) then
          raise (WELLSORTED ("For instantiation with default_type " ^
            Syntax.string_of_typ_global thy default_T ^
            ":\n" ^ Syntax.string_of_typ_global thy T' ^
            " to be substituted for variable " ^
            Syntax.string_of_typ_global thy T ^ " does not have sort " ^
            Syntax.string_of_sort_global thy S))
        else T'
      end
    fun subst (TFree v) = instantiate_tfree v
      | subst T = T
  in map_types (map_atyps subst) end;
bulwahn@45159
   241
bulwahn@45159
   242
(* Outcome of instantiating one goal at one type in instantiate_goals: either
   the sort error message, or the instantiated goal with its evaluation terms. *)
datatype wellsorted_error = Wellsorted_Error of string | Term of term * term list
bulwahn@45159
   243
bulwahn@45440
   244
(* minimalistic preprocessing *)
bulwahn@45440
   245
bulwahn@45440
   246
(* Strips leading object-level universal quantifiers, returning the bound
   variables as (name, type) pairs, outermost first, together with the body,
   in which they remain loose bound variables. *)
fun strip_all t =
  case t of
    Const (@{const_name HOL.All}, _) $ Abs (a, T, body) =>
      apfst (cons (a, T)) (strip_all body)
  | _ => ([], t);
bulwahn@45440
   251
bulwahn@45440
   252
(* Minimal preprocessing of a goal before testing.  The goal is atomized, then
   rewritten with the reversed all_simps and with not_ex, not_all and
   imp_conjL.  Finally its leading universally bound variables are replaced by
   fresh free variables. *)
fun preprocess ctxt t =
  let
    val thy = Proof_Context.theory_of ctxt
    fun dest_eq_thm thm = HOLogic.dest_eq (HOLogic.dest_Trueprop (Thm.prop_of thm))
    val reversed_all_simps = map (swap o dest_eq_thm) @{thms all_simps}
    val other_rules = map dest_eq_thm [@{thm not_ex}, @{thm not_all}, @{thm imp_conjL}]
    val atomized = Object_Logic.atomize_term thy t
    val t' = Pattern.rewrite_term thy (reversed_all_simps @ other_rules) [] atomized
    val (vs, body) = strip_all t'
    val frees = map Free (Variable.variant_frees ctxt [t'] vs)
  in
    subst_bounds (rev frees, body)
  end
bulwahn@45440
   264
bulwahn@45440
   265
(* instantiation of type variables with concrete types *)
bulwahn@45440
   266
 
bulwahn@45159
   267
(* Instantiates the free type variables of each goal (with its evaluation terms).

   For each goal, the result lists (SOME T, (goal, eval_terms)) for every
   default type T at which the instantiation is well-sorted.  The default types
   are the finite types if the finite_types option is set, and
   Quickcheck.default_type otherwise; insts gives explicit instances.  A goal
   without free type variables yields [(NONE, ...)].  Each goal, but not its
   evaluation terms, is preprocessed.

   Fails with the collected sort errors if the single goal has no well-sorted
   instantiation.  Otherwise, unless quiet, the sort errors are shown as a
   warning.  No warning is issued when there are none. *)
fun instantiate_goals lthy insts goals =
  let
    fun map_goal_and_eval_terms f (check_goal, eval_terms) = (f check_goal, map f eval_terms)
    val thy = Proof_Context.theory_of lthy
    val default_insts =
      if Config.get lthy Quickcheck.finite_types then get_finite_types else Quickcheck.default_type
    val inst_goals =
      map (fn (check_goal, eval_terms) =>
        if not (null (Term.add_tfree_names check_goal [])) then
          map (fn T =>
            (pair (SOME T) o Term o apfst (preprocess lthy))
              (map_goal_and_eval_terms (monomorphic_term thy insts T) (check_goal, eval_terms))
              handle WELLSORTED s => (SOME T, Wellsorted_Error s)) (default_insts lthy)
        else
          [(NONE, Term (preprocess lthy check_goal, eval_terms))]) goals
    val error_msg =
      cat_lines
        (maps (map_filter (fn (_, Term _) => NONE | (_, Wellsorted_Error s) => SOME s)) inst_goals)
    fun is_wellsorted_term (T, Term t) = SOME (T, t)
      | is_wellsorted_term (_, Wellsorted_Error _) = NONE
    val correct_inst_goals =
      case map (map_filter is_wellsorted_term) inst_goals of
        [[]] => error error_msg
      | xs => xs
    (* error_msg is empty when every instantiation was well-sorted *)
    val _ =
      if Config.get lthy Quickcheck.quiet orelse error_msg = "" then ()
      else warning error_msg
  in
    correct_inst_goals
  end
bulwahn@45159
   295
bulwahn@45718
   296
(* compilation of testing functions *)
bulwahn@45718
   297
bulwahn@45763
   298
(* Builds the term
     Quickcheck.catch_match (if cond then then_t else else_t genuine)
                            (if genuine_only then none else else_t false)
   at the type of then_t.  NOTE(review): catch_match presumably evaluates its
   second argument when the first fails with a pattern-match error; confirm
   against the code-generator setup. *)
fun mk_safe_if genuine_only none (cond, then_t, else_t) genuine =
  let
    val T = fastype_of then_t
    fun mk_if b x y = Const (@{const_name "If"}, @{typ bool} --> T --> T --> T) $ b $ x $ y
    val checked = mk_if cond then_t (else_t genuine)
    val fallback = mk_if genuine_only none (else_t false)
  in
    Const (@{const_name "Quickcheck.catch_match"}, T --> T --> T) $ checked $ fallback
  end
bulwahn@45718
   307
bulwahn@45159
   308
(* Applies f to the elements in order, pushing each result onto results (most
   recent first).  Stops right after the first result that contains a
   counterexample. *)
fun collect_results f ts results =
  case ts of
    [] => results
  | t :: rest =>
      let
        val result = f t
      in
        if Quickcheck.found_counterexample result then result :: results
        else collect_results f rest (result :: results)
      end
bulwahn@45159
   318
bulwahn@45420
   319
(* Tests the goals with the generator (name, compile).  If a goal is an
   equation (after its premises), its non-Free sides are added as evaluation
   terms.

   With finite_types, each goal's finite-type instances are tested together
   through test_term_with_cardinality, or with test_term when there is no type
   variable.  Otherwise all instances are tested individually with test_term.
   In both modes, a goal that has no well-sorted instantiation is skipped.
   Testing stops at the first counterexample; see collect_results. *)
fun generator_test_goal_terms (name, compile) ctxt catch_code_errors insts goals =
  let
    fun add_eval_term t ts = if is_Free t then ts else ts @ [t]
    fun add_equation_eval_terms (t, eval_terms) =
      case try HOLogic.dest_eq (snd (strip_imp t)) of
        SOME (lhs, rhs) => (t, add_eval_term lhs (add_eval_term rhs eval_terms))
      | NONE => (t, eval_terms)
    fun test_term' goal =
      case goal of
        [(NONE, t)] => test_term (name, compile) ctxt catch_code_errors t
      | ts => test_term_with_cardinality (name, compile) ctxt catch_code_errors (map snd ts)
    val goals' = instantiate_goals ctxt insts goals
      |> map (map (apsnd add_equation_eval_terms))
  in
    if Config.get ctxt Quickcheck.finite_types then
      (* skip goals without any well-sorted instance, as the flattening below
         does; they would otherwise reach test_term_with_cardinality as [] *)
      collect_results test_term' (filter_out null goals') []
    else
      collect_results (test_term (name, compile) ctxt catch_code_errors)
        (maps (map snd) goals') []
  end;
bulwahn@45159
   339
bulwahn@42214
   340
(* defining functions *)
bulwahn@42214
   341
bulwahn@42214
   342
(* Proves pattern completeness of the first goal, then applies auto. *)
fun pat_completeness_auto ctxt =
  let
    val pat_tac = Pat_Completeness.pat_completeness_tac ctxt 1
    val auto = auto_tac ctxt
  in pat_tac THEN auto end
bulwahn@42214
   344
bulwahn@42214
   345
(* Defines functions named names, of types Ts, in a local theory.  Their
   equations are computed by mk_equations from the function terms.  All
   bindings are concealed.

   If define_foundationally holds and termination_tac is given, the function
   package is used.  It proves pattern completeness with pat_completeness_auto
   and termination with the given tactic.  Otherwise each function is first
   defined as "undefined".  Its equations are then asserted without proof
   (Skip_Proof.cheat_tac), fixing argnames, and noted as prfx.name.simps with
   the default code equation, simp and nitpick_simp attributes. *)
fun define_functions (mk_equations, termination_tac) prfx argnames names Ts =
  let
    val fixes = names ~~ Ts
    fun concealed name = Binding.conceal (Binding.name name)
    val concealed_empty_binding = apfst Binding.conceal Attrib.empty_binding
    (* defines name :: T as undefined; yields the defined constant *)
    fun define_undefined (name, T) =
      Local_Theory.define
        ((concealed name, NoSyn), (concealed_empty_binding, mk_undefined T))
      #> apfst fst
    (* notes eq under the binding prfx.name.simps *)
    fun note_simps (name, eq) =
      Local_Theory.note
        ((Binding.conceal
          (Binding.qualify true prfx
            (Binding.qualify true name (Binding.name "simps"))),
          Code.add_default_eqn_attrib :: @{attributes [simp, nitpick_simp]}), [eq])
      #> snd
    (* asserts the equations for the defined constants without proof *)
    fun assert_equations (consts, lthy) =
      let
        val cheat = fn _ => Skip_Proof.cheat_tac (Proof_Context.theory_of lthy)
        val eqs = map (fn eq => Goal.prove lthy argnames [] eq cheat) (mk_equations consts)
      in
        fold note_simps (names ~~ eqs) lthy
      end
  in
    if define_foundationally andalso is_some termination_tac then
      let
        val eqs_t = mk_equations (map Free fixes)
      in
        Function.add_function
          (map (fn (name, T) => (concealed name, SOME T, NoSyn)) fixes)
          (map (pair concealed_empty_binding) eqs_t)
          Function_Common.default_config pat_completeness_auto
        #> snd
        #> (fn lthy => Function.prove_termination NONE (the termination_tac lthy) lthy)
        #> snd
      end
    else
      fold_map define_undefined fixes #> assert_equations
  end
bulwahn@42214
   377
bulwahn@41935
   378
(** ensuring sort constraints **)
bulwahn@41935
   379
bulwahn@41927
   380
(* Starting from the sorts of the type variables raw_vs, strengthens them so
   that each (T, sort) in insts is satisfied; T is varified, so its type
   variables are looked up with index 0.  Returns a function mapping (v, _) to
   v with its resulting sort, or NONE if Sorts.meet_sort reports a class
   error. *)
fun perhaps_constrain thy insts raw_vs =
  let
    val algebra = Sign.classes_of thy
    fun meet (T, sort) = Sorts.meet_sort algebra (Logic.varifyT_global T, sort)
    val initial = fold (fn (v, sort) => Vartab.update ((v, 0), sort)) raw_vs Vartab.empty
    val vtab = fold meet insts initial
    fun constrain (v, _) = (v, the (Vartab.lookup vtab (v, 0)))
  in SOME constrain
  end handle Sorts.CLASS_ERROR _ => NONE;
bulwahn@41927
   389
bulwahn@42229
   390
(* Ensures that the datatypes raw_tycos are instances of sort.  Nothing is done
   if some type constructor already has an instance.  Otherwise, if the type
   arguments can be constrained consistently (perhaps_constrain), the
   datatypes are instantiated with instantiate_datatype; if not, thy is
   returned unchanged. *)
fun ensure_sort_datatype (((sort_vs, aux_sort), sort), instantiate_datatype) config raw_tycos thy =
  let
    val algebra = Sign.classes_of thy;
    val (descr, raw_vs, tycos, prfx, (names, auxnames), raw_TUs) = Datatype.the_descr thy raw_tycos
    val vs = map (fn (v, S) => (v, Sorts.inter_sort algebra (sort_vs, S))) raw_vs
    fun insts_of target_sort constr =
      Datatype_Aux.interpret_construction descr vs constr
      |> maps snd |> maps snd |> flat |> map (rpair target_sort)
    val insts =
      insts_of sort { atyp = single, dtyp = (K o K o K) [] }
      @ insts_of aux_sort { atyp = K [], dtyp = K o K }
    val has_inst = exists (fn tyco => can (Sorts.mg_domain algebra tyco) sort) tycos;
  in
    if has_inst then thy
    else
      case perhaps_constrain thy insts vs of
        NONE => thy
      | SOME constrain =>
          instantiate_datatype config descr
            (map constrain vs) tycos prfx (names, auxnames)
            ((pairself o map o map_atyps) (fn TFree v => TFree (constrain v)) raw_TUs) thy
  end;
bulwahn@41935
   407
  
bulwahn@42159
   408
(** generic parametric compilation **)
bulwahn@42159
   409
bulwahn@42159
   410
(* Builds an abstraction over a code_numeral index (Bound 0 in the body).  For
   index i, counted from 1, it evaluates to the generator expression for the
   i-th element of ts; for any other index it evaluates to out_of_bounds.  T is
   the type of the generator expressions. *)
fun gen_mk_parametric_generator_expr ((mk_generator_expr, out_of_bounds), T) ctxt ts =
  let
    val if_t = Const (@{const_name "If"}, @{typ bool} --> T --> T --> T)
    fun index_is i =
      HOLogic.eq_const @{typ code_numeral} $ Bound 0 $ HOLogic.mk_number @{typ code_numeral} i
    fun mk_if (i, goal) else_t = if_t $ index_is i $ mk_generator_expr ctxt goal $ else_t
    val body = fold_rev mk_if ((1 upto length ts) ~~ ts) out_of_bounds
  in
    absdummy @{typ "code_numeral"} body
  end
bulwahn@42159
   419
bulwahn@41935
   420
(** post-processing of function terms **)
bulwahn@41935
   421
bulwahn@41935
   422
(* Decomposes fun_upd t0 t1 t2 into (t0, (t1, t2)); raises TERM otherwise. *)
fun dest_fun_upd t =
  case t of
    Const (@{const_name fun_upd}, _) $ f $ x $ y => (f, (x, y))
  | _ => raise TERM ("dest_fun_upd", [t])
bulwahn@41935
   424
bulwahn@41935
   425
(* Builds fun_upd t t1 t2 for a function t of type T1 => T2. *)
fun mk_fun_upd T1 T2 (t1, t2) t =
  let
    val fT = T1 --> T2
    val upd = Const (@{const_name fun_upd}, fT --> T1 --> T2 --> fT)
  in
    upd $ t $ t1 $ t2
  end
bulwahn@41935
   427
bulwahn@41935
   428
(* Decomposes nested fun_upd terms into the list of (argument, value) updates,
   outermost first, and the innermost function, which must be an abstraction.
   Raises TERM otherwise. *)
fun dest_fun_upds t =
  case try dest_fun_upd t of
    SOME (t0, upd) =>
      let val (upds, default) = dest_fun_upds t0 in (upd :: upds, default) end
  | NONE =>
      (case t of
        Abs _ => ([], t)
      | _ => raise TERM ("dest_fun_upds", [t]))
bulwahn@41935
   435
bulwahn@41935
   436
(* Rebuilds nested fun_upd terms from updates tps (outermost first) over t. *)
fun make_fun_upds T1 T2 (tps, t) =
  List.foldr (fn (tp, acc) => mk_fun_upd T1 T2 tp acc) t tps
bulwahn@41935
   437
bulwahn@41935
   438
(* Builds a set term from (element, membership) pairs.  Elements mapped to True
   are inserted, those mapped to False are dropped, and any other value raises
   TERM. *)
fun make_set T1 tps =
  case tps of
    [] => Const (@{const_abbrev Set.empty}, T1 --> @{typ bool})
  | (t1, v) :: rest =>
      (case v of
        @{const False} => make_set T1 rest
      | @{const True} =>
          Const (@{const_name insert}, T1 --> (T1 --> @{typ bool}) --> T1 --> @{typ bool})
            $ t1 $ make_set T1 rest
      | _ => raise TERM ("make_set", [v]))
bulwahn@41935
   444
bulwahn@41935
   445
(* Builds UNIV - S, where S contains the elements mapped to False by the
   (element, membership) pairs; with no pairs the result is just UNIV. *)
fun make_coset T tps =
  if null tps then Const (@{const_abbrev UNIV}, T --> @{typ bool})
  else
    let
      val U = T --> @{typ bool}
      fun invert @{const False} = @{const True}
        | invert @{const True} = @{const False}
      val minus = Const (@{const_name "Groups.minus_class.minus"}, U --> U --> U)
    in
      minus $ Const (@{const_abbrev UNIV}, U) $ make_set T (map (apsnd invert) tps)
    end
bulwahn@41935
   455
bulwahn@41935
   456
(* Builds a map term from updates, skipping updates to None; the base is
   Map.empty. *)
fun make_map T1 T2 tps =
  case tps of
    [] => Const (@{const_abbrev Map.empty}, T1 --> T2)
  | (_, Const (@{const_name None}, _)) :: rest => make_map T1 T2 rest
  | upd :: rest => mk_fun_upd T1 T2 upd (make_map T1 T2 rest)
bulwahn@41935
   459
  
bulwahn@41935
   460
(* Post-processes a counterexample term for presentation.  A function value
   built from fun_upd over a default abstraction is rebuilt as follows:

   - Boolean-valued, with default False or undefined: a set.
   - Boolean-valued, with default True: UNIV minus a set.
   - Boolean-valued, with any other default: raises TERM.
   - Option-valued, with default None or undefined: a map.
   - Otherwise: nested fun_upd terms.

   Other terms are processed argument-wise below their constant head. *)
fun post_process_term t =
  let
    fun map_Abs f (Abs (x, T, body)) = Abs (x, T, f body)
      | map_Abs _ u = raise TERM ("map_Abs", [u])
    fun process_args u =
      case strip_comb u of
        (c as Const (_, _), args) => list_comb (c, map post_process_term args)
    (* selects how to rebuild (updates, default) for range type T2 *)
    fun rebuild_with T1 T2 default =
      case (T2, default) of
        (@{typ bool}, Abs (_, _, @{const False})) => fst #> rev #> make_set T1
      | (@{typ bool}, Abs (_, _, @{const True})) => fst #> rev #> make_coset T1
      | (@{typ bool}, Abs (_, _, Const (@{const_name undefined}, _))) => fst #> rev #> make_set T1
      | (@{typ bool}, _) => raise TERM ("post_process_term", [default])
      | (Type (@{type_name option}, _), Abs (_, _, Const (@{const_name None}, _))) =>
          fst #> make_map T1 T2
      | (Type (@{type_name option}, _), Abs (_, _, Const (@{const_name undefined}, _))) =>
          fst #> make_map T1 T2
      | _ => make_fun_upds T1 T2
  in
    case fastype_of t of
      Type (@{type_name fun}, [T1, T2]) =>
        (case try dest_fun_upds t of
          SOME (tps, default) =>
            let
              val processed =
                (map (pairself post_process_term) tps, map_Abs post_process_term default)
            in
              rebuild_with T1 T2 default processed
            end
        | NONE => process_args t)
    | _ => process_args t
  end
bulwahn@41927
   488
bulwahn@41927
   489
end;