src/HOL/Tools/Function/fun.ML
author wenzelm
Sun Oct 25 19:21:34 2009 +0100 (2009-10-25)
changeset 33171 292970b42770
parent 33101 8846318b52d0
child 33671 4b0f2599ed48
permissions -rw-r--r--
name space groups are identified by serial, not serial_string;
     1 (*  Title:      HOL/Tools/Function/fun.ML
     2     Author:     Alexander Krauss, TU Muenchen
     3 
     4 Sequential mode for function definitions
     5 Command "fun" for fully automated function definitions
     6 *)
     7 
     8 signature FUNCTION_FUN =
     9 sig
    10     val add_fun : Function_Common.function_config ->
    11       (binding * typ option * mixfix) list -> (Attrib.binding * term) list ->
    12       bool -> local_theory -> Proof.context
    13     val add_fun_cmd : Function_Common.function_config ->
    14       (binding * string option * mixfix) list -> (Attrib.binding * string) list ->
    15       bool -> local_theory -> Proof.context
    16 
    17     val setup : theory -> theory
    18 end
    19 
    20 structure Function_Fun : FUNCTION_FUN =
    21 struct
    22 
    23 open Function_Lib
    24 open Function_Common
    25 
    26 
    27 fun check_pats ctxt geq =
    28     let 
    29       fun err str = error (cat_lines ["Malformed definition:",
    30                                       str ^ " not allowed in sequential mode.",
    31                                       Syntax.string_of_term ctxt geq])
    32       val thy = ProofContext.theory_of ctxt
    33                 
    34       fun check_constr_pattern (Bound _) = ()
    35         | check_constr_pattern t =
    36           let
    37             val (hd, args) = strip_comb t
    38           in
    39             (((case Datatype.info_of_constr thy (dest_Const hd) of
    40                  SOME _ => ()
    41                | NONE => err "Non-constructor pattern")
    42               handle TERM ("dest_Const", _) => err "Non-constructor patterns");
    43              map check_constr_pattern args; 
    44              ())
    45           end
    46           
    47       val (fname, qs, gs, args, rhs) = split_def ctxt geq 
    48                                        
    49       val _ = if not (null gs) then err "Conditional equations" else ()
    50       val _ = map check_constr_pattern args
    51                   
    52                   (* just count occurrences to check linearity *)
    53       val _ = if fold (fold_aterms (fn Bound _ => Integer.add 1 | _ => I)) args 0 > length qs
    54               then err "Nonlinear patterns" else ()
    55     in
    56       ()
    57     end
    58     
    59 val by_pat_completeness_auto =
    60     Proof.global_future_terminal_proof
    61       (Method.Basic Pat_Completeness.pat_completeness,
    62        SOME (Method.Source_i (Args.src (("HOL.auto", []), Position.none))))
    63 
    64 fun termination_by method int =
    65     Function.termination_proof NONE
    66     #> Proof.global_future_terminal_proof (Method.Basic method, NONE) int
    67 
    68 fun mk_catchall fixes arity_of =
    69     let
    70       fun mk_eqn ((fname, fT), _) =
    71           let 
    72             val n = arity_of fname
    73             val (argTs, rT) = chop n (binder_types fT)
    74                                    |> apsnd (fn Ts => Ts ---> body_type fT) 
    75                               
    76             val qs = map Free (Name.invent_list [] "a" n ~~ argTs)
    77           in
    78             HOLogic.mk_eq(list_comb (Free (fname, fT), qs),
    79                           Const ("HOL.undefined", rT))
    80               |> HOLogic.mk_Trueprop
    81               |> fold_rev Logic.all qs
    82           end
    83     in
    84       map mk_eqn fixes
    85     end
    86 
    87 fun add_catchall ctxt fixes spec =
    88   let val fqgars = map (split_def ctxt) spec
    89       val arity_of = map (fn (fname,_,_,args,_) => (fname, length args)) fqgars
    90                      |> AList.lookup (op =) #> the
    91   in
    92     spec @ mk_catchall fixes arity_of
    93   end
    94 
    95 fun warn_if_redundant ctxt origs tss =
    96     let
    97         fun msg t = "Ignoring redundant equation: " ^ quote (Syntax.string_of_term ctxt t)
    98                     
    99         val (tss', _) = chop (length origs) tss
   100         fun check (t, []) = (warning (msg t); [])
   101           | check (t, s) = s
   102     in
   103         (map check (origs ~~ tss'); tss)
   104     end
   105 
   106 
   107 fun sequential_preproc (config as FunctionConfig {sequential, ...}) ctxt fixes spec =
   108       if sequential then
   109         let
   110           val (bnds, eqss) = split_list spec
   111                             
   112           val eqs = map the_single eqss
   113                     
   114           val feqs = eqs
   115                       |> tap (check_defs ctxt fixes) (* Standard checks *)
   116                       |> tap (map (check_pats ctxt))    (* More checks for sequential mode *)
   117 
   118           val compleqs = add_catchall ctxt fixes feqs   (* Completion *)
   119 
   120           val spliteqs = warn_if_redundant ctxt feqs
   121                            (Function_Split.split_all_equations ctxt compleqs)
   122 
   123           fun restore_spec thms =
   124               bnds ~~ Library.take (length bnds, Library.unflat spliteqs thms)
   125               
   126           val spliteqs' = flat (Library.take (length bnds, spliteqs))
   127           val fnames = map (fst o fst) fixes
   128           val indices = map (fn eq => find_index (curry op = (fname_of eq)) fnames) spliteqs'
   129 
   130           fun sort xs = partition_list (fn i => fn (j,_) => i = j) 0 (length fnames - 1) (indices ~~ xs)
   131                                        |> map (map snd)
   132 
   133 
   134           val bnds' = bnds @ replicate (length spliteqs - length bnds) Attrib.empty_binding
   135 
   136           (* using theorem names for case name currently disabled *)
   137           val case_names = map_index (fn (i, (_, es)) => mk_case_names i "" (length es)) 
   138                                      (bnds' ~~ spliteqs)
   139                            |> flat
   140         in
   141           (flat spliteqs, restore_spec, sort, case_names)
   142         end
   143       else
   144         Function_Common.empty_preproc check_defs config ctxt fixes spec
   145 
   146 val setup =
   147   Context.theory_map (Function_Common.set_preproc sequential_preproc)
   148 
   149 
   150 val fun_config = FunctionConfig { sequential=true, default="%x. undefined" (*FIXME dynamic scoping*), 
   151   domintros=false, partials=false, tailrec=false }
   152 
   153 fun gen_fun add config fixes statements int lthy =
   154   let val group = serial () in
   155     lthy
   156       |> LocalTheory.set_group group
   157       |> add fixes statements config
   158       |> by_pat_completeness_auto int
   159       |> LocalTheory.restore
   160       |> LocalTheory.set_group group
   161       |> termination_by (Function_Common.get_termination_prover lthy) int
   162   end;
   163 
   164 val add_fun = gen_fun Function.add_function
   165 val add_fun_cmd = gen_fun Function.add_function_cmd
   166 
   167 
   168 
   169 local structure P = OuterParse and K = OuterKeyword in
   170 
   171 val _ =
   172   OuterSyntax.local_theory' "fun" "define general recursive functions (short version)" K.thy_decl
   173   (function_parser fun_config
   174      >> (fn ((config, fixes), statements) => add_fun_cmd config fixes statements));
   175 
   176 end
   177 
   178 end