src/HOL/Tools/Function/fun.ML
author krauss
Fri Feb 25 16:59:48 2011 +0100 (2011-02-25)
changeset 41846 b368a7aee46a
parent 41417 211dbd42f95d
child 42361 23f352990944
permissions -rw-r--r--
removed support for tail-recursion from function package (now implemented by partial_function)
krauss@33098
     1
(*  Title:      HOL/Tools/Function/fun.ML
krauss@33098
     2
    Author:     Alexander Krauss, TU Muenchen
krauss@33098
     3
krauss@41114
     4
Command "fun": Function definitions with pattern splitting/completion
krauss@41114
     5
and automated termination proofs.
krauss@33098
     6
*)
krauss@33098
     7
krauss@33098
     8
signature FUNCTION_FUN =
sig
  (* Define (sequential) recursive functions from already-parsed terms,
     then attempt an automatic termination proof. *)
  val add_fun : (binding * typ option * mixfix) list ->
    (Attrib.binding * term) list -> Function_Common.function_config ->
    local_theory -> Proof.context

  (* Variant of add_fun taking unparsed strings; used by the "fun" command. *)
  val add_fun_cmd : (binding * string option * mixfix) list ->
    (Attrib.binding * string) list -> Function_Common.function_config ->
    local_theory -> Proof.context

  (* Installs this structure's preprocessor (sequential mode) in the theory. *)
  val setup : theory -> theory
end
krauss@33098
    19
krauss@33098
    20
structure Function_Fun : FUNCTION_FUN =
krauss@33098
    21
struct
krauss@33098
    22
krauss@33099
    23
open Function_Lib
krauss@33099
    24
open Function_Common
krauss@33098
    25
krauss@33098
    26
krauss@33098
    27
(* Extra well-formedness checks for a single equation [geq] that only apply
   in sequential mode.  Raises an error (via err) if the equation
   - has premises ("Conditional equations"),
   - contains a pattern whose head is not a datatype constructor
     ("Non-constructor pattern"), or
   - uses a pattern variable more than once ("Nonlinear patterns"). *)
fun check_pats ctxt geq =
  let
    fun err str = error (cat_lines ["Malformed definition:",
      str ^ " not allowed in sequential mode.",
      Syntax.string_of_term ctxt geq])
    val thy = ProofContext.theory_of ctxt

    (* Bound subterms are accepted as pattern variables (presumably split_def
       leaves the quantified variables as loose bounds -- TODO confirm); any
       other subterm must be a constructor applied to further patterns. *)
    fun check_constr_pattern (Bound _) = ()
      | check_constr_pattern t =
          let
            val (hd, args) = strip_comb t
            (* Match the head directly rather than catching the TERM
               exception that dest_Const raises for non-constant heads. *)
            val is_constr =
              (case hd of
                Const c => is_some (Datatype.info_of_constr thy c)
              | _ => false)
          in
            if is_constr then List.app check_constr_pattern args
            else err "Non-constructor pattern"
          end

    val (_, qs, gs, args, _) = split_def ctxt (K true) geq

    val _ = if not (null gs) then err "Conditional equations" else ()
    val _ = List.app check_constr_pattern args

    (* just count occurrences to check linearity: more Bound occurrences than
       quantified variables means some variable occurs twice *)
    val _ = if fold (fold_aterms (fn Bound _ => Integer.add 1 | _ => I)) args 0 > length qs
      then err "Nonlinear patterns" else ()
  in
    ()
  end
krauss@34232
    58
krauss@33098
    59
(* Build one catch-all equation per fixed function f:
     !!a1 ... an. f a1 ... an = undefined
   where n = arity_of f.  The right-hand side has the type of f applied to
   its first n arguments.  The mixfix component of each fix is ignored. *)
fun mk_catchall fixes arity_of =
  let
    fun catchall_eqn ((fname, fT), _) =
      let
        val n = arity_of fname
        val (argTs, restTs) = chop n (binder_types fT)
        val rT = restTs ---> body_type fT
        val frees = map Free (Name.invent_list [] "a" n ~~ argTs)
        val lhs = list_comb (Free (fname, fT), frees)
        val rhs = Const ("HOL.undefined", rT)
      in
        fold_rev Logic.all frees (HOLogic.mk_Trueprop (HOLogic.mk_eq (lhs, rhs)))
      end
  in
    map catchall_eqn fixes
  end
krauss@33098
    77
krauss@33098
    78
(* Completion: append to spec the catch-all equations from mk_catchall.
   Each function's arity is taken from the number of arguments in its
   equations in spec (looked up with AList.lookup); a fixed function without
   any equation in spec makes the lookup fail with Option. *)
fun add_catchall ctxt fixes spec =
  let
    val arities = spec
      |> map (split_def ctxt (K true))
      |> map (fn (fname, _, _, args, _) => (fname, length args))
    fun arity_of fname = the (AList.lookup (op =) arities fname)
  in
    spec @ mk_catchall fixes arity_of
  end
krauss@33098
    85
krauss@33098
    86
(* Emit a warning for every original equation in origs whose list of split
   equations (the corresponding prefix of tss) is empty, i.e. which was
   entirely subsumed by earlier equations.  Returns tss unchanged.
   Raises if tss has fewer entries than origs (from ~~). *)
fun warn_if_redundant ctxt origs tss =
  let
    fun msg t = "Ignoring redundant equation: " ^ quote (Syntax.string_of_term ctxt t)

    (* only the entries belonging to origs are checked; the rest are ignored *)
    val (tss', _) = chop (length origs) tss
    fun check (t, []) = warning (msg t)
      | check _ = ()
  in
    (List.app check (origs ~~ tss'); tss)
  end
krauss@33098
    96
krauss@33099
    97
(* Preprocessor installed by setup.  In sequential mode the user's equations
   are checked (check_defs, check_pats), completed with catch-all clauses
   (add_catchall) and split via Function_Split.split_all_equations.  The
   result is (split equations, restore_spec, sort, case_names).  Outside
   sequential mode the default Function_Common.empty_preproc is used. *)
fun sequential_preproc (config as FunctionConfig {sequential, ...}) ctxt fixes spec =
  if sequential then
    let
      val (bnds, eqss) = split_list spec

      (* sequential mode expects exactly one equation per binding *)
      val eqs = map the_single eqss

      val feqs = eqs
        |> tap (check_defs ctxt fixes) (* Standard checks *)
        |> tap (List.app (check_pats ctxt)) (* More checks for sequential mode *)

      val compleqs = add_catchall ctxt fixes feqs (* Completion *)

      (* Presumably one list of split equations per element of compleqs:
         first the user's equations, then the appended catch-all clauses. *)
      val spliteqs = warn_if_redundant ctxt feqs
        (Function_Split.split_all_equations ctxt compleqs)

      (* Pair each user binding with the theorems for its split equations;
         the groups belonging to the trailing catch-all clauses are dropped. *)
      fun restore_spec thms =
        bnds ~~ take (length bnds) (unflat spliteqs thms)

      val spliteqs' = flat (take (length bnds) spliteqs)
      val fnames = map (fst o fst) fixes
      val indices = map (fn eq => find_index (curry op = (fname_of eq)) fnames) spliteqs'

      (* Group xs (aligned with spliteqs') by the position of the defined
         function's name in fnames. *)
      fun sort xs = partition_list (fn i => fn (j,_) => i = j) 0 (length fnames - 1) (indices ~~ xs)
        |> map (map snd)

      (* Catch-all clauses get empty bindings so bnds' lines up with spliteqs. *)
      val bnds' = bnds @ replicate (length spliteqs - length bnds) Attrib.empty_binding

      (* using theorem names for case name currently disabled *)
      val case_names = map_index (fn (i, (_, es)) => mk_case_names i "" (length es))
        (bnds' ~~ spliteqs) |> flat
    in
      (flat spliteqs, restore_spec, sort, case_names)
    end
  else
    Function_Common.empty_preproc check_defs config ctxt fixes spec
krauss@33098
   134
krauss@33098
   135
(* Register sequential_preproc as the function package's preprocessor. *)
fun setup thy =
  Context.theory_map (Function_Common.set_preproc sequential_preproc) thy
krauss@33098
   137
krauss@33098
   138
krauss@41417
   139
(* Default configuration handed to function_parser for the "fun" command:
   sequential mode on, no default value, domintros and partials off. *)
val fun_config =
  FunctionConfig {
    default = NONE,
    domintros = false,
    partials = false,
    sequential = true }
krauss@33098
   141
krauss@36523
   142
(* Shared implementation of add_fun / add_fun_cmd.  Defines the functions
   with [add] (a Function.add_function variant), passing a tactic that runs
   pat_completeness_tac followed by auto for the generated proof obligation.
   It then restores the local theory and proves termination using the
   termination prover from Function_Common. *)
fun gen_add_fun add fixes statements config lthy =
  let
    fun pat_completeness_auto ctxt =
      Pat_Completeness.pat_completeness_tac ctxt 1
      THEN auto_tac (clasimpset_of ctxt)

    fun prove_termination lthy' =
      Function.prove_termination NONE
        (Function_Common.get_termination_prover lthy' lthy') lthy'

    val (_, defined) = add fixes statements config pat_completeness_auto lthy
    val restored = Local_Theory.restore defined
    val (_, terminated) = prove_termination restored
  in
    terminated
  end
krauss@33098
   156
krauss@36523
   157
(* Term-level and string-level entry points; see gen_add_fun. *)
fun add_fun fixes statements config lthy =
  gen_add_fun Function.add_function fixes statements config lthy

fun add_fun_cmd fixes statements config lthy =
  gen_add_fun Function.add_function_cmd fixes statements config lthy
krauss@33098
   159
krauss@33098
   160
krauss@33098
   161
krauss@33098
   162
(* Outer syntax: the "fun" command, parsed with fun_config as default. *)
val _ =
  let
    fun fun_cmd ((config, fixes), statements) =
      add_fun_cmd fixes statements config
  in
    Outer_Syntax.local_theory "fun" "define general recursive functions (short version)"
      Keyword.thy_decl
      (function_parser fun_config >> fun_cmd)
  end
krauss@33098
   167
krauss@33098
   168
end