src/HOL/Tools/Function/fun.ML
author krauss
Sat Jan 02 23:18:58 2010 +0100 (2010-01-02)
changeset 34232 36a2a3029fd3
parent 33957 e9afca2118d4
child 36519 46bf776a81e0
permissions -rw-r--r--
new year's resolution: reindented code in function package
krauss@33098
     1
(*  Title:      HOL/Tools/Function/fun.ML
krauss@33098
     2
    Author:     Alexander Krauss, TU Muenchen
krauss@33098
     3
krauss@33098
     4
Sequential mode for function definitions
krauss@33098
     5
Command "fun" for fully automated function definitions
krauss@33098
     6
*)
krauss@33098
     7
krauss@33098
     8
signature FUNCTION_FUN =
sig
  (* "fun"-style definition from internal fixes (typ option) and
     specification terms. *)
  val add_fun : Function_Common.function_config
    -> (binding * typ option * mixfix) list
    -> (Attrib.binding * term) list
    -> bool -> local_theory -> Proof.context

  (* Variant of add_fun taking unparsed strings; this is the entry point
     used by the outer-syntax command "fun". *)
  val add_fun_cmd : Function_Common.function_config
    -> (binding * string option * mixfix) list
    -> (Attrib.binding * string) list
    -> bool -> local_theory -> Proof.context

  (* Installs the sequential-mode preprocessor into a theory. *)
  val setup : theory -> theory
end
krauss@33098
    19
krauss@33098
    20
structure Function_Fun : FUNCTION_FUN =
krauss@33098
    21
struct
krauss@33098
    22
krauss@33099
    23
open Function_Lib
krauss@33099
    24
open Function_Common
krauss@33098
    25
krauss@33098
    26
krauss@33098
    27
(* Check that the equation geq is admissible in sequential mode.
   Fails via "error" if the equation has premises, if an argument pattern
   has a head that is not a datatype constructor (Bound terms are accepted
   as pattern variables), or if the patterns contain more Bound occurrences
   than there are quantified variables (i.e. a variable is repeated).
   Returns () otherwise. *)
fun check_pats ctxt geq =
  let
    fun err str = error (cat_lines ["Malformed definition:",
      str ^ " not allowed in sequential mode.",
      Syntax.string_of_term ctxt geq])
    val thy = ProofContext.theory_of ctxt

    (* Walk a pattern: Bound leaves are accepted, every other node must be an
       application of a known datatype constructor.  A head that is not a
       constant at all (dest_Const fails) is reported the same way. *)
    fun check_constr_pattern (Bound _) = ()
      | check_constr_pattern t =
          let
            val (hd, args) = strip_comb t
          in
            (((case Datatype.info_of_constr thy (dest_Const hd) of
                 SOME _ => ()
               | NONE => err "Non-constructor pattern")
              handle TERM ("dest_Const", _) => err "Non-constructor pattern");
             app check_constr_pattern args)
          end

    val (_, qs, gs, args, _) = split_def ctxt geq

    val _ = if null gs then () else err "Conditional equations"
    val _ = app check_constr_pattern args

    (* Linearity by counting: with qs quantified variables, any surplus of
       Bound occurrences in the patterns means some variable is repeated. *)
    val _ =
      if fold (fold_aterms (fn Bound _ => Integer.add 1 | _ => I)) args 0 > length qs
      then err "Nonlinear patterns" else ()
  in
    ()
  end
krauss@34232
    58
krauss@33098
    59
(* Terminal proof step: the initial method is pat_completeness, and any
   remaining goals are closed by the "HOL.auto" method. *)
val by_pat_completeness_auto =
  let
    val initial_method = Method.Basic Pat_Completeness.pat_completeness
    val closing_method = Method.Source_i (Args.src (("HOL.auto", []), Position.none))
  in
    Proof.global_future_terminal_proof (initial_method, SOME closing_method)
  end
krauss@33098
    63
krauss@33098
    64
(* Open a termination goal via Function.termination_proof NONE, then
   discharge it terminally with the given proof method. *)
fun termination_by method int =
  let
    val open_goal = Function.termination_proof NONE
    val close_goal = Proof.global_future_terminal_proof (Method.Basic method, NONE) int
  in
    open_goal #> close_goal
  end
krauss@33098
    67
krauss@33098
    68
(* For every fixed function (fname, fT), build the catch-all equation
     !!a ... . fname a ... = undefined
   where the number of arguments is arity_of fname and the argument types
   are the leading binder types of fT. *)
fun mk_catchall fixes arity_of =
  let
    fun catchall_eqn ((fname, fT), _) =
      let
        val arity = arity_of fname
        val (argTs, remainingTs) = chop arity (binder_types fT)
        (* result type after applying exactly arity arguments *)
        val resT = remainingTs ---> body_type fT

        val vars = map Free (Name.invent_list [] "a" arity ~~ argTs)
        val lhs = list_comb (Free (fname, fT), vars)
        val rhs = Const ("HOL.undefined", resT)
      in
        fold_rev Logic.all vars (HOLogic.mk_Trueprop (HOLogic.mk_eq (lhs, rhs)))
      end
  in
    map catchall_eqn fixes
  end
krauss@33098
    86
krauss@33098
    87
(* Append catch-all equations (see mk_catchall) to spec.  The arity of each
   function is the number of arguments in its equations in spec; a function
   without equations makes the lookup fail with Option. *)
fun add_catchall ctxt fixes spec =
  let
    val arities =
      spec
      |> map (split_def ctxt)
      |> map (fn (fname, _, _, args, _) => (fname, length args))
    fun arity_of fname = the (AList.lookup (op =) arities fname)
  in
    spec @ mk_catchall fixes arity_of
  end
krauss@33098
    94
krauss@33098
    95
(* Emit a warning for every equation in origs whose corresponding entry in
   tss (the equations it was split into) is empty.  Only the first
   (length origs) entries of tss are inspected; in sequential_preproc the
   remaining ones come from catch-all completion.  Returns tss unchanged. *)
fun warn_if_redundant ctxt origs tss =
  let
    fun msg t = "Ignoring redundant equation: " ^ quote (Syntax.string_of_term ctxt t)

    val (tss', _) = chop (length origs) tss

    (* effect only: nothing is collected *)
    fun check (t, []) = warning (msg t)
      | check _ = ()
  in
    (app check (origs ~~ tss'); tss)
  end
krauss@33098
   105
krauss@33099
   106
(* Specification preprocessor installed by "setup".
   In sequential mode the equations are checked (check_defs, check_pats),
   completed with catch-all equations (add_catchall), and split by
   Function_Split.split_all_equations.  Result:
     (all split equations, restore_spec, sort, case_names)
   where restore_spec reattaches proved theorems to the original statement
   bindings, and sort groups items by the defined function of the
   corresponding split equation.  Outside sequential mode this delegates to
   Function_Common.empty_preproc. *)
fun sequential_preproc (config as FunctionConfig {sequential, ...}) ctxt fixes spec =
  if sequential then
    let
      val (bnds, eqss) = split_list spec

      (* every statement must hold exactly one equation (the_single) *)
      val eqs = map the_single eqss

      val feqs = eqs
        |> tap (check_defs ctxt fixes) (* Standard checks *)
        |> tap (app (check_pats ctxt)) (* More checks for sequential mode *)

      val compleqs = add_catchall ctxt fixes feqs (* Completion *)

      (* one list of split equations per completed equation; warns about user
         equations that split into nothing *)
      val spliteqs = warn_if_redundant ctxt feqs
        (Function_Split.split_all_equations ctxt compleqs)

      (* regroup theorems like spliteqs and keep only the groups belonging to
         the user's statements (catch-all groups come after them) *)
      fun restore_spec thms =
        bnds ~~ take (length bnds) (unflat spliteqs thms)

      val spliteqs' = flat (take (length bnds) spliteqs)
      val fnames = map (fst o fst) fixes
      val indices = map (fn eq => find_index (curry op = (fname_of eq)) fnames) spliteqs'

      (* partition xs by the index (into fnames) of the defined function of
         the corresponding equation in spliteqs' *)
      fun sort xs = partition_list (fn i => fn (j,_) => i = j) 0 (length fnames - 1) (indices ~~ xs)
        |> map (map snd)

      (* catch-all statements have no user binding: pad with empty ones *)
      val bnds' = bnds @ replicate (length spliteqs - length bnds) Attrib.empty_binding

      (* using theorem names for case name currently disabled *)
      val case_names = map_index (fn (i, (_, es)) => mk_case_names i "" (length es))
        (bnds' ~~ spliteqs) |> flat
    in
      (flat spliteqs, restore_spec, sort, case_names)
    end
  else
    Function_Common.empty_preproc check_defs config ctxt fixes spec
krauss@33098
   143
krauss@33098
   144
(* Theory setup: register sequential_preproc as the function-package
   preprocessor. *)
val setup =
  Function_Common.set_preproc sequential_preproc
  |> Context.theory_map
krauss@33098
   146
krauss@33098
   147
krauss@33099
   148
(* Default configuration for "fun": sequential mode on; domintros, partials
   and tailrec off; default value "%x. undefined". *)
val fun_config = FunctionConfig
  { sequential = true,
    default = "%x. undefined" (*FIXME dynamic scoping*),
    domintros = false,
    partials = false,
    tailrec = false }
krauss@33098
   150
krauss@33098
   151
(* Common driver for add_fun / add_fun_cmd: define via add, prove pattern
   completeness/compatibility automatically, restore the local theory, and
   prove termination with the prover configured in the original lthy. *)
fun gen_fun add config fixes statements int lthy =
  let
    val defined =
      lthy
      |> add fixes statements config
      |> by_pat_completeness_auto int
      |> Local_Theory.restore
  in
    termination_by (Function_Common.get_termination_prover lthy) int defined
  end
krauss@33098
   157
krauss@33099
   158
(* Entry point on internal terms and types. *)
fun add_fun config = gen_fun Function.add_function config

(* Entry point on unparsed strings (used by the "fun" command). *)
fun add_fun_cmd config = gen_fun Function.add_function_cmd config
krauss@33098
   160
krauss@33098
   161
krauss@33098
   162
krauss@33098
   163
local structure K = OuterKeyword in

(* Register the outer-syntax command "fun": its arguments are parsed by
   function_parser fun_config and passed to add_fun_cmd. *)
val _ =
  OuterSyntax.local_theory' "fun" "define general recursive functions (short version)" K.thy_decl
  (function_parser fun_config
     >> (fn ((config, fixes), statements) => add_fun_cmd config fixes statements));

end
krauss@33098
   171
krauss@33098
   172
end