src/HOL/Tools/Function/fun.ML
author wenzelm
Sun Oct 25 19:21:34 2009 +0100 (2009-10-25)
changeset 33171 292970b42770
parent 33101 8846318b52d0
child 33671 4b0f2599ed48
permissions -rw-r--r--
name space groups are identified by serial, not serial_string;
krauss@33098
     1
(*  Title:      HOL/Tools/Function/fun.ML
krauss@33098
     2
    Author:     Alexander Krauss, TU Muenchen
krauss@33098
     3
krauss@33098
     4
Sequential mode for function definitions
krauss@33098
     5
Command "fun" for fully automated function definitions
krauss@33098
     6
*)
krauss@33098
     7
krauss@33098
     8
signature FUNCTION_FUN =
sig
  (* "fun" on already-elaborated input: fixes carry optional types, the
     specification carries terms.  The bool argument is forwarded to the
     terminal proofs (presumably the "interactive" flag — confirm). *)
  val add_fun : Function_Common.function_config ->
    (binding * typ option * mixfix) list -> (Attrib.binding * term) list ->
    bool -> local_theory -> Proof.context

  (* Same as add_fun, but fixes and specification are given as unparsed
     strings; this is the entry point used by the outer-syntax command. *)
  val add_fun_cmd : Function_Common.function_config ->
    (binding * string option * mixfix) list -> (Attrib.binding * string) list ->
    bool -> local_theory -> Proof.context

  (* Registers this module's sequential-mode preprocessor with the theory. *)
  val setup : theory -> theory
end
krauss@33098
    19
krauss@33098
    20
structure Function_Fun : FUNCTION_FUN =
struct

open Function_Lib
open Function_Common


(* Check that one defining equation geq is admissible in sequential mode.
   It must be unconditional, its argument patterns must be built from
   datatype constructors and pattern variables only, and the patterns must
   be linear.  Any violation is reported via error, with the offending
   equation printed. *)
fun check_pats ctxt geq =
  let
    fun err str = error (cat_lines ["Malformed definition:",
                                    str ^ " not allowed in sequential mode.",
                                    Syntax.string_of_term ctxt geq])
    val thy = ProofContext.theory_of ctxt

    (* A bare Bound is accepted as a pattern variable.  Presumably split_def
       returns the arguments with the quantified variables qs as loose
       Bounds — confirm.  Any other term must have a registered datatype
       constructor as its head, and its arguments are checked recursively. *)
    fun check_constr_pattern (Bound _) = ()
      | check_constr_pattern t =
          let
            val (hd, args) = strip_comb t
            val is_constr =
              (case hd of
                Const c => is_some (Datatype.info_of_constr thy c)
              | _ => false)
          in
            (if is_constr then () else err "Non-constructor pattern");
            List.app check_constr_pattern args
          end

    val (_, qs, gs, args, _) = split_def ctxt geq

    val _ = if null gs then () else err "Conditional equations"
    val _ = List.app check_constr_pattern args

    (* Linearity check by counting: if the arguments contain more Bound
       occurrences than there are quantified variables, some variable must
       occur more than once. *)
    val _ = if fold (fold_aterms (fn Bound _ => Integer.add 1 | _ => I)) args 0 > length qs
            then err "Nonlinear patterns" else ()
  in
    ()
  end

(* Terminal proof for the obligation left open by the definition step in
   gen_fun: pat_completeness as the initial method, "HOL.auto" as the
   terminal method. *)
val by_pat_completeness_auto =
  Proof.global_future_terminal_proof
    (Method.Basic Pat_Completeness.pat_completeness,
     SOME (Method.Source_i (Args.src (("HOL.auto", []), Position.none))))

(* Open a termination proof and close it with the given method.  The NONE
   argument presumably selects the most recently defined function —
   confirm against Function.termination_proof. *)
fun termination_by method int =
  Function.termination_proof NONE
  #> Proof.global_future_terminal_proof (Method.Basic method, NONE) int

(* Build one catch-all equation  !!a1 ... an. f a1 ... an = undefined  for
   each fixed function f.  Here n = arity_of f, and the result type is
   whatever remains of f's type after n arguments. *)
fun mk_catchall fixes arity_of =
  let
    fun mk_eqn ((fname, fT), _) =
      let
        val n = arity_of fname
        val (argTs, rT) = chop n (binder_types fT)
                          |> apsnd (fn Ts => Ts ---> body_type fT)
        val qs = map Free (Name.invent_list [] "a" n ~~ argTs)
      in
        HOLogic.mk_eq (list_comb (Free (fname, fT), qs),
                       Const ("HOL.undefined", rT))
          |> HOLogic.mk_Trueprop
          |> fold_rev Logic.all qs
      end
  in
    map mk_eqn fixes
  end

(* Append the catch-all equations to spec.  Each function's arity is the
   argument count of its first equation in spec.
   NOTE(review): `the` raises Option if a fixed function has no equation
   in spec; confirm that an earlier check rules this out. *)
fun add_catchall ctxt fixes spec =
  let
    val fqgars = map (split_def ctxt) spec
    val arity_of = map (fn (fname, _, _, args, _) => (fname, length args)) fqgars
                   |> AList.lookup (op =) #> the
  in
    spec @ mk_catchall fixes arity_of
  end

(* Warn about every original equation whose entry in tss is empty, and
   return tss unchanged.  Only the first (length origs) entries are
   inspected.  In sequential_preproc, any later entries come from the
   appended catch-all equations. *)
fun warn_if_redundant ctxt origs tss =
  let
    fun msg t = "Ignoring redundant equation: " ^ quote (Syntax.string_of_term ctxt t)

    val (orig_splits, _) = chop (length origs) tss

    fun check (t, []) = warning (msg t)
      | check _ = ()
  in
    List.app check (origs ~~ orig_splits);
    tss
  end


(* Preprocessor installed by setup.  In sequential mode it performs these
   steps:
     - check the equations (standard checks plus check_pats),
     - complete them with catch-all equations,
     - split them via Function_Split.split_all_equations,
     - return the flattened equations, restore_spec, sort and case_names.
   Otherwise it defers to Function_Common.empty_preproc. *)
fun sequential_preproc (config as FunctionConfig {sequential, ...}) ctxt fixes spec =
  if sequential then
    let
      val (bnds, eqss) = split_list spec

      (* the_single: each specification entry must hold exactly one equation *)
      val eqs = map the_single eqss

      val feqs = eqs
        |> tap (check_defs ctxt fixes)          (* standard checks *)
        |> tap (List.app (check_pats ctxt))     (* additional checks for sequential mode *)

      val compleqs = add_catchall ctxt fixes feqs   (* completion *)

      val spliteqs = warn_if_redundant ctxt feqs
        (Function_Split.split_all_equations ctxt compleqs)

      (* Regroup theorems along spliteqs.  Only the first (length bnds)
         groups are kept, i.e. those of the user's own equations. *)
      fun restore_spec thms =
        bnds ~~ Library.take (length bnds, Library.unflat spliteqs thms)

      val spliteqs' = flat (Library.take (length bnds, spliteqs))
      val fnames = map (fst o fst) fixes
      val indices = map (fn eq => find_index (curry op = (fname_of eq)) fnames) spliteqs'

      (* group xs by the index of the function each equation defines *)
      fun sort xs =
        partition_list (fn i => fn (j, _) => i = j) 0 (length fnames - 1) (indices ~~ xs)
        |> map (map snd)

      val bnds' = bnds @ replicate (length spliteqs - length bnds) Attrib.empty_binding

      (* using theorem names for case names is currently disabled *)
      val case_names =
        map_index (fn (i, (_, es)) => mk_case_names i "" (length es)) (bnds' ~~ spliteqs)
        |> flat
    in
      (flat spliteqs, restore_spec, sort, case_names)
    end
  else
    Function_Common.empty_preproc check_defs config ctxt fixes spec

val setup =
  Context.theory_map (Function_Common.set_preproc sequential_preproc)


(* Default configuration of the "fun" command: sequential mode on. *)
val fun_config = FunctionConfig { sequential=true, default="%x. undefined" (*FIXME dynamic scoping*),
  domintros=false, partials=false, tailrec=false }

(* Common driver for add_fun / add_fun_cmd, run in three steps:
     - define the function with `add`,
     - discharge the remaining proof via by_pat_completeness_auto,
     - prove termination with the configured termination prover.
   Both local-theory phases are tagged with the same fresh serial as
   their name-space group. *)
fun gen_fun add config fixes statements int lthy =
  let val group = serial () in
    lthy
      |> LocalTheory.set_group group
      |> add fixes statements config
      |> by_pat_completeness_auto int
      |> LocalTheory.restore
      |> LocalTheory.set_group group
      |> termination_by (Function_Common.get_termination_prover lthy) int
  end;

val add_fun = gen_fun Function.add_function
val add_fun_cmd = gen_fun Function.add_function_cmd



local structure P = OuterParse and K = OuterKeyword in

(* Outer-syntax command "fun": parse with fun_config as default configuration *)
val _ =
  OuterSyntax.local_theory' "fun" "define general recursive functions (short version)" K.thy_decl
  (function_parser fun_config
     >> (fn ((config, fixes), statements) => add_fun_cmd config fixes statements));

end

end