src/HOL/Tools/Function/fun.ML
author wenzelm
Sat Mar 22 18:19:57 2014 +0100 (2014-03-22)
changeset 56254 a2dd9200854d
parent 54407 e95831757903
child 58826 2ed2eaabe3df
permissions -rw-r--r--
more antiquotations;
krauss@33098
     1
(*  Title:      HOL/Tools/Function/fun.ML
krauss@33098
     2
    Author:     Alexander Krauss, TU Muenchen
krauss@33098
     3
krauss@41114
     4
Command "fun": Function definitions with pattern splitting/completion
krauss@41114
     5
and automated termination proofs.
krauss@33098
     6
*)
krauss@33098
     7
krauss@33098
     8
signature FUNCTION_FUN =
sig
  (* Default configuration used by the "fun" command
     (sequential = true, default = NONE, domintros/partials off). *)
  val fun_config : Function_Common.function_config

  (* Define functions from already-parsed terms; afterwards pattern
     completeness and termination are proved automatically. *)
  val add_fun :
    (binding * typ option * mixfix) list
    -> (Attrib.binding * term) list
    -> Function_Common.function_config
    -> local_theory -> Proof.context

  (* Same as add_fun but from unparsed strings; the bool argument is passed
     through to Function.add_function_cmd. *)
  val add_fun_cmd :
    (binding * string option * mixfix) list
    -> (Attrib.binding * string) list
    -> Function_Common.function_config
    -> bool -> local_theory -> Proof.context

  (* Installs sequential_preproc as the function package's preprocessor. *)
  val setup : theory -> theory
end
krauss@33098
    20
krauss@33098
    21
structure Function_Fun : FUNCTION_FUN =
krauss@33098
    22
struct
krauss@33098
    23
krauss@33099
    24
open Function_Lib
krauss@33099
    25
open Function_Common
krauss@33098
    26
krauss@33098
    27
krauss@33098
    28
(* Reject an equation that sequential mode cannot handle.
   Errors (via `error`) on the first violation found, in this order:
   conditional equations, non-constructor patterns, nonlinear patterns. *)
fun check_pats ctxt geq =
  let
    fun err str = error (cat_lines ["Malformed definition:",
      str ^ " not allowed in sequential mode.",
      Syntax.string_of_term ctxt geq])

    (* True iff the head term is a constant listed among the free constructors
       (Ctr_Sugar) of its result type. Non-Const entries in ctrs are simply
       skipped instead of raising Match. *)
    fun is_constructor (Const (hd_s, hd_T)) =
          (case body_type hd_T of
            Type (Tname, _) =>
              (case Ctr_Sugar.ctr_sugar_of ctxt Tname of
                SOME {ctrs, ...} =>
                  exists (fn Const (s, _) => s = hd_s | _ => false) ctrs
              | NONE => false)
          | _ => false)
      | is_constructor _ = false

    (* Bound subterms are accepted as pattern variables (the linearity check
       below counts them against qs); anything else must be a constructor
       applied to further valid patterns. *)
    fun check_constr_pattern (Bound _) = ()
      | check_constr_pattern t =
          let
            val (hd, args) = strip_comb t
          in
            if is_constructor hd then List.app check_constr_pattern args
            else err "Non-constructor pattern"
          end

    val (_, qs, gs, args, _) = split_def ctxt (K true) geq

    val _ = if not (null gs) then err "Conditional equations" else ()
    val _ = List.app check_constr_pattern args

    (* just count occurrences to check linearity *)
    val _ = if fold (fold_aterms (fn Bound _ => Integer.add 1 | _ => I)) args 0 > length qs
      then err "Nonlinear patterns" else ()
  in
    ()
  end
krauss@34232
    64
krauss@33098
    65
(* For every fixed function f of arity n (as given by arity_of), build the
   catch-all clause  !!a1 ... an. f a1 ... an = undefined. *)
fun mk_catchall fixes arity_of =
  let
    fun catchall_eqn ((fname, fT), _) =
      let
        val arity = arity_of fname
        (* first `arity` argument types; the remaining ones form the result type *)
        val (argTs, restTs) = chop arity (binder_types fT)
        val resT = restTs ---> body_type fT
        val vars = map Free (Name.invent Name.context "a" arity ~~ argTs)
        val lhs = list_comb (Free (fname, fT), vars)
        val rhs = Const (@{const_name undefined}, resT)
      in
        fold_rev Logic.all vars (HOLogic.mk_Trueprop (HOLogic.mk_eq (lhs, rhs)))
      end
  in
    map catchall_eqn fixes
  end
krauss@33098
    83
krauss@33098
    84
(* Append catch-all clauses (mk_catchall) to the given equations. The arity
   of each function is the argument count of its first equation in spec;
   `the` raises Option if a fixed function has no equation at all. *)
fun add_catchall ctxt fixes spec =
  let
    val split_eqs = map (split_def ctxt (K true)) spec
    val arities = map (fn (fname, _, _, args, _) => (fname, length args)) split_eqs
    fun arity_of fname = the (AList.lookup (op =) arities fname)
  in
    spec @ mk_catchall fixes arity_of
  end
krauss@33098
    91
krauss@48099
    92
(* Post-split checks. tss holds one group of split equations per input
   equation; the first (length origs) groups belong to the user's equations
   origs, any further groups are reported as missing patterns (in
   sequential_preproc they stem from the added catch-all clauses).
   Warns about missing patterns (at most three shown explicitly) and fails
   on a user equation whose group is empty, i.e. a redundant clause. *)
fun further_checks ctxt origs tss =
  let
    val pretty_term = Syntax.string_of_term ctxt

    val (user_groups, added) = chop (length origs) tss
    val missing = flat added

    val (shown, hidden) = chop 3 missing
    val summary =
      if null hidden then []
      else ["(" ^ string_of_int (length hidden) ^ " more)"]
    val _ =
      if null missing then ()
      else warning (cat_lines
        ("Missing patterns in function definition:" :: map pretty_term shown @ summary))

    fun check_redundant (t, ts) =
      if null ts then
        error (cat_lines ["Equation is redundant (covered by preceding clauses):", pretty_term t])
      else ()
    val _ = List.app check_redundant (origs ~~ user_groups)
  in
    ()
  end
krauss@33098
   112
krauss@33099
   113
(* Preprocessor for the function package.
   Without `sequential`, defers to Function_Common.empty_preproc.
   With it:
   - each spec entry must carry exactly one equation (the_single);
   - the equations are checked (check_defs, check_pats);
   - they are completed with catch-all clauses (add_catchall);
   - they are split by Function_Split.split_all_equations and checked
     again (further_checks).
   Returns (all split equations, restore_spec, sort, case_names). *)
fun sequential_preproc (config as FunctionConfig {sequential, ...}) ctxt fixes spec =
  if not sequential then
    Function_Common.empty_preproc check_defs config ctxt fixes spec
  else
    let
      val (bnds, eqss) = split_list spec
      val n_orig = length bnds

      val feqs = map the_single eqss
      val _ = check_defs ctxt fixes feqs (* Standard checks *)
      val _ = List.app (check_pats ctxt) feqs (* More checks for sequential mode *)

      val completed = add_catchall ctxt fixes feqs (* Completion *)
      val spliteqs = Function_Split.split_all_equations ctxt completed
      val _ = further_checks ctxt feqs spliteqs

      (* Regroup a flat theorem list like spliteqs and keep only the groups
         of the user's equations, paired with their bindings. *)
      fun restore_spec thms =
        bnds ~~ take n_orig (unflat spliteqs thms)

      (* For each split case of a user equation: index of its function in fixes. *)
      val fnames = map (fst o fst) fixes
      val user_cases = flat (take n_orig spliteqs)
      val fn_indices = map (fn eq => find_index (curry op = (fname_of eq)) fnames) user_cases

      (* Partition a list aligned with user_cases into one sub-list per function. *)
      fun sort_by_function xs =
        map (map snd)
          (partition_list (fn i => fn (j, _) => i = j) 0 (length fnames - 1) (fn_indices ~~ xs))

      (* One binding per split group (catch-all groups get empty bindings). The
         binding itself is ignored: using theorem names for case name is
         currently disabled, so names are derived from the group index only. *)
      val padded_bnds = bnds @ replicate (length spliteqs - n_orig) Attrib.empty_binding
      val case_names =
        flat (map_index (fn (i, (_, es)) => mk_case_names i "" (length es))
          (padded_bnds ~~ spliteqs))
    in
      (flat spliteqs, restore_spec, sort_by_function, case_names)
    end
krauss@33098
   150
krauss@33098
   151
(* Install sequential_preproc as the preprocessor of the function package. *)
val setup =
  sequential_preproc
  |> Function_Common.set_preproc
  |> Context.theory_map
krauss@33098
   153
krauss@33098
   154
krauss@41417
   155
(* Configuration of the "fun" command: sequential mode enabled, no default
   value, domintros and partials switched off. *)
val fun_config =
  FunctionConfig {
    default = NONE,
    domintros = false,
    partials = false,
    sequential = true }
krauss@33098
   157
wenzelm@44239
   158
(* Run the definition step `add`, discharging pattern completeness with
   pat_completeness_tac on goal 1 followed by auto_tac; then prove
   termination with the termination prover registered in the context. *)
fun gen_add_fun add lthy =
  let
    fun pat_completeness_auto ctxt =
      Pat_Completeness.pat_completeness_tac ctxt 1 THEN auto_tac ctxt
    fun auto_termination ctxt =
      Function.prove_termination NONE (Function_Common.get_termination_prover ctxt) ctxt

    val (_, lthy_defined) = add pat_completeness_auto lthy
    val (_, lthy_terminated) = auto_termination lthy_defined
  in
    lthy_terminated
  end
krauss@33098
   171
wenzelm@44239
   172
(* Term-level entry point: Function.add_function plus the automatic proofs
   of gen_add_fun. *)
fun add_fun fixes specs config =
  gen_add_fun (Function.add_function fixes specs config)

(* String-level entry point used by the "fun" command; `int` is forwarded
   to Function.add_function_cmd after the completeness tactic. *)
fun add_fun_cmd fixes specs config int =
  gen_add_fun (fn tac => Function.add_function_cmd fixes specs config tac int)
krauss@33098
   174
krauss@33098
   175
krauss@33098
   176
krauss@33098
   177
(* Outer syntax for "fun": parsed with function_parser (given fun_config,
   presumably as the default configuration) and defined via add_fun_cmd. *)
val _ =
  let
    fun define ((config, fixes), statements) = add_fun_cmd fixes statements config
  in
    Outer_Syntax.local_theory' @{command_spec "fun"}
      "define general recursive functions (short version)"
      (function_parser fun_config >> define)
  end
krauss@33098
   182
krauss@33098
   183
end