src/HOL/Tools/Function/fundef_datatype.ML
changeset 33270 320a1d67b9ae
parent 33269 3b7e2dbbd684
parent 33220 11a1af478dac
child 33271 7be66dee1a5a
equal deleted inserted replaced
33269:3b7e2dbbd684 33270:320a1d67b9ae
     1 (*  Title:      HOL/Tools/Function/fundef_datatype.ML
       
     2     Author:     Alexander Krauss, TU Muenchen
       
     3 
       
     4 A package for general recursive function definitions.
       
     5 A tactic to prove completeness of datatype patterns.
       
     6 *)
       
     7 
       
     8 signature FUNDEF_DATATYPE =
       
     9 sig
       
    10     val add_fun : FundefCommon.fundef_config ->
       
    11       (binding * typ option * mixfix) list -> (Attrib.binding * term) list ->
       
    12       bool -> local_theory -> Proof.context
       
    13     val add_fun_cmd : FundefCommon.fundef_config ->
       
    14       (binding * string option * mixfix) list -> (Attrib.binding * string) list ->
       
    15       bool -> local_theory -> Proof.context
       
    16 
       
    17     val setup : theory -> theory
       
    18 end
       
    19 
       
    20 structure FundefDatatype : FUNDEF_DATATYPE =
       
    21 struct
       
    22 
       
    23 open FundefLib
       
    24 open FundefCommon
       
    25 
       
    26 
       
    27 fun check_pats ctxt geq =
       
    28     let 
       
    29       fun err str = error (cat_lines ["Malformed definition:",
       
    30                                       str ^ " not allowed in sequential mode.",
       
    31                                       Syntax.string_of_term ctxt geq])
       
    32       val thy = ProofContext.theory_of ctxt
       
    33                 
       
    34       fun check_constr_pattern (Bound _) = ()
       
    35         | check_constr_pattern t =
       
    36           let
       
    37             val (hd, args) = strip_comb t
       
    38           in
       
    39             (((case Datatype.info_of_constr thy (dest_Const hd) of
       
    40                  SOME _ => ()
       
    41                | NONE => err "Non-constructor pattern")
       
    42               handle TERM ("dest_Const", _) => err "Non-constructor patterns");
       
    43              map check_constr_pattern args; 
       
    44              ())
       
    45           end
       
    46           
       
    47       val (fname, qs, gs, args, rhs) = split_def ctxt geq 
       
    48                                        
       
    49       val _ = if not (null gs) then err "Conditional equations" else ()
       
    50       val _ = map check_constr_pattern args
       
    51                   
       
    52                   (* just count occurrences to check linearity *)
       
    53       val _ = if fold (fold_aterms (fn Bound _ => Integer.add 1 | _ => I)) args 0 > length qs
       
    54               then err "Nonlinear patterns" else ()
       
    55     in
       
    56       ()
       
    57     end
       
    58     
       
    59 val by_pat_completeness_auto =
       
    60     Proof.global_future_terminal_proof
       
    61       (Method.Basic Pat_Completeness.pat_completeness,
       
    62        SOME (Method.Source_i (Args.src (("HOL.auto", []), Position.none))))
       
    63 
       
    64 fun termination_by method int =
       
    65     Fundef.termination_proof NONE
       
    66     #> Proof.global_future_terminal_proof (Method.Basic method, NONE) int
       
    67 
       
    68 fun mk_catchall fixes arity_of =
       
    69     let
       
    70       fun mk_eqn ((fname, fT), _) =
       
    71           let 
       
    72             val n = arity_of fname
       
    73             val (argTs, rT) = chop n (binder_types fT)
       
    74                                    |> apsnd (fn Ts => Ts ---> body_type fT) 
       
    75                               
       
    76             val qs = map Free (Name.invent_list [] "a" n ~~ argTs)
       
    77           in
       
    78             HOLogic.mk_eq(list_comb (Free (fname, fT), qs),
       
    79                           Const ("HOL.undefined", rT))
       
    80               |> HOLogic.mk_Trueprop
       
    81               |> fold_rev Logic.all qs
       
    82           end
       
    83     in
       
    84       map mk_eqn fixes
       
    85     end
       
    86 
       
    87 fun add_catchall ctxt fixes spec =
       
    88   let val fqgars = map (split_def ctxt) spec
       
    89       val arity_of = map (fn (fname,_,_,args,_) => (fname, length args)) fqgars
       
    90                      |> AList.lookup (op =) #> the
       
    91   in
       
    92     spec @ mk_catchall fixes arity_of
       
    93   end
       
    94 
       
    95 fun warn_if_redundant ctxt origs tss =
       
    96     let
       
    97         fun msg t = "Ignoring redundant equation: " ^ quote (Syntax.string_of_term ctxt t)
       
    98                     
       
    99         val (tss', _) = chop (length origs) tss
       
   100         fun check (t, []) = (warning (msg t); [])
       
   101           | check (t, s) = s
       
   102     in
       
   103         (map check (origs ~~ tss'); tss)
       
   104     end
       
   105 
       
   106 
       
(* Specification preprocessor installed by setup.
   When the configuration has sequential = true:
     - each spec entry is a binding with a list of equations, and the_single
       is applied to each list (so each entry must carry exactly one);
     - the equations are checked by check_defs and then check_pats;
     - catch-all "= undefined" equations are appended (add_catchall);
     - all equations are split by FundefSplit.split_all_equations
       (presumably into non-overlapping instances, one group per input
       equation in order -- not visible here), warning about user equations
       whose group is empty.
   Returns (all split equations, restore_spec, sort, case_names), where
     - restore_spec thms regroups a flat theorem list along spliteqs and
       pairs the groups of the user-written equations with their original
       bindings (groups of the catch-all equations are dropped by the take);
     - sort xs expects one entry per split user-written equation (catch-alls
       excluded) and partitions them by the index of their function in fixes;
     - case_names holds the case names produced by mk_case_names for every
       group, catch-all groups included, flattened.
   Otherwise the work is delegated to FundefCommon.empty_preproc with
   check_defs. *)
fun sequential_preproc (config as FundefConfig {sequential, ...}) ctxt fixes spec =
      if sequential then
        let
          val (bnds, eqss) = split_list spec

          val eqs = map the_single eqss

          val feqs = eqs
                      |> tap (check_defs ctxt fixes) (* Standard checks *)
                      |> tap (map (check_pats ctxt))    (* More checks for sequential mode *)

          val compleqs = add_catchall ctxt fixes feqs   (* Completion *)

          (* One group of split equations per equation of compleqs; the
             first (length feqs) groups belong to the user's equations. *)
          val spliteqs = warn_if_redundant ctxt feqs
                           (FundefSplit.split_all_equations ctxt compleqs)

          fun restore_spec thms =
              bnds ~~ Library.take (length bnds, Library.unflat spliteqs thms)

          (* Split equations originating from user-written equations only. *)
          val spliteqs' = flat (Library.take (length bnds, spliteqs))
          val fnames = map (fst o fst) fixes
          (* Position of each such equation's function within fixes. *)
          val indices = map (fn eq => find_index (curry op = (fname_of eq)) fnames) spliteqs'

          fun sort xs = partition_list (fn i => fn (j,_) => i = j) 0 (length fnames - 1) (indices ~~ xs)
                                       |> map (map snd)


          (* Pad the bindings with empty ones for the catch-all groups so
             that they can be zipped with spliteqs below. *)
          val bnds' = bnds @ replicate (length spliteqs - length bnds) Attrib.empty_binding

          (* using theorem names for case name currently disabled *)
          val case_names = map_index (fn (i, (_, es)) => mk_case_names i "" (length es))
                                     (bnds' ~~ spliteqs)
                           |> flat
        in
          (flat spliteqs, restore_spec, sort, case_names)
        end
      else
        FundefCommon.empty_preproc check_defs config ctxt fixes spec
       
   145 
       
   146 val setup =
       
   147   Context.theory_map (FundefCommon.set_preproc sequential_preproc)
       
   148 
       
   149 
       
   150 val fun_config = FundefConfig { sequential=true, default="%x. undefined" (*FIXME dynamic scoping*), 
       
   151                                 domintros=false, tailrec=false }
       
   152 
       
   153 fun gen_fun add config fixes statements int lthy =
       
   154   let val group = serial_string () in
       
   155     lthy
       
   156       |> LocalTheory.set_group group
       
   157       |> add fixes statements config
       
   158       |> by_pat_completeness_auto int
       
   159       |> LocalTheory.restore
       
   160       |> LocalTheory.set_group group
       
   161       |> termination_by (FundefCommon.get_termination_prover lthy) int
       
   162   end;
       
   163 
       
   164 val add_fun = gen_fun Fundef.add_fundef
       
   165 val add_fun_cmd = gen_fun Fundef.add_fundef_cmd
       
   166 
       
   167 
       
   168 
       
   169 local structure P = OuterParse and K = OuterKeyword in
       
   170 
       
   171 val _ =
       
   172   OuterSyntax.local_theory' "fun" "define general recursive functions (short version)" K.thy_decl
       
   173   (fundef_parser fun_config
       
   174      >> (fn ((config, fixes), statements) => add_fun_cmd config fixes statements));
       
   175 
       
   176 end
       
   177 
       
   178 end