src/HOL/Tools/function_package/fundef_datatype.ML
author wenzelm
Sat Oct 06 16:50:04 2007 +0200 (2007-10-06)
changeset 24867 e5b55d7be9bb
parent 24466 619f78b717cb
child 24920 2a45e400fdad
permissions -rw-r--r--
simplified interfaces for outer syntax;
     1 (*  Title:      HOL/Tools/function_package/fundef_datatype.ML
     2     ID:         $Id$
     3     Author:     Alexander Krauss, TU Muenchen
     4 
     5 A package for general recursive function definitions.
     6 A tactic to prove completeness of datatype patterns.
     7 *)
     8 
     9 signature FUNDEF_DATATYPE =
    10 sig
    11     val pat_complete_tac: int -> tactic
    12 
    13     val pat_completeness : method
    14     val setup : theory -> theory
    15 end
    16 
    17 structure FundefDatatype: FUNDEF_DATATYPE =
    18 struct
    19 
    20 open FundefLib
    21 open FundefCommon
    22 
    23 
    24 fun check_pats ctxt geq =
    25     let 
    26       fun err str = error (cat_lines ["Malformed definition:",
    27                                       str ^ " not allowed in sequential mode.",
    28                                       ProofContext.string_of_term ctxt geq])
    29       val thy = ProofContext.theory_of ctxt
    30                 
    31       fun check_constr_pattern (Bound _) = ()
    32         | check_constr_pattern t =
    33           let
    34             val (hd, args) = strip_comb t
    35           in
    36             (((case DatatypePackage.datatype_of_constr thy (fst (dest_Const hd)) of
    37                  SOME _ => ()
    38                | NONE => err "Non-constructor pattern")
    39               handle TERM ("dest_Const", _) => err "Non-constructor patterns");
    40              map check_constr_pattern args; 
    41              ())
    42           end
    43           
    44       val (fname, qs, gs, args, rhs) = split_def ctxt geq 
    45                                        
    46       val _ = if not (null gs) then err "Conditional equations" else ()
    47       val _ = map check_constr_pattern args
    48                   
    49                   (* just count occurrences to check linearity *)
    50       val _ = if fold (fold_aterms (fn Bound _ => curry (op +) 1 | _ => I)) args 0 > length qs
    51               then err "Nonlinear patterns" else ()
    52     in
    53       ()
    54     end
    55     
    56 
    57 fun mk_argvar i T = Free ("_av" ^ (string_of_int i), T)
    58 fun mk_patvar i T = Free ("_pv" ^ (string_of_int i), T)
    59 
    60 fun inst_free var inst thm =
    61     forall_elim inst (forall_intr var thm)
    62 
    63 
    64 fun inst_case_thm thy x P thm =
    65     let
    66         val [Pv, xv] = term_vars (prop_of thm)
    67     in
    68         cterm_instantiate [(cterm_of thy xv, cterm_of thy x), (cterm_of thy Pv, cterm_of thy P)] thm
    69     end
    70 
    71 
    72 fun invent_vars constr i =
    73     let
    74         val Ts = binder_types (fastype_of constr)
    75         val j = i + length Ts
    76         val is = i upto (j - 1)
    77         val avs = map2 mk_argvar is Ts
    78         val pvs = map2 mk_patvar is Ts
    79     in
    80         (avs, pvs, j)
    81     end
    82 
    83 
    84 fun filter_pats thy cons pvars [] = []
    85   | filter_pats thy cons pvars (([], thm) :: pts) = raise Match
    86   | filter_pats thy cons pvars ((pat :: pats, thm) :: pts) =
    87     case pat of
    88         Free _ => let val inst = list_comb (cons, pvars)
    89                  in (inst :: pats, inst_free (cterm_of thy pat) (cterm_of thy inst) thm)
    90                     :: (filter_pats thy cons pvars pts) end
    91       | _ => if fst (strip_comb pat) = cons
    92              then (pat :: pats, thm) :: (filter_pats thy cons pvars pts)
    93              else filter_pats thy cons pvars pts
    94 
    95 
    96 fun inst_constrs_of thy (T as Type (name, _)) =
    97         map (fn (Cn,CT) => Envir.subst_TVars (Sign.typ_match thy (body_type CT, T) Vartab.empty) (Const (Cn, CT)))
    98             (the (DatatypePackage.get_datatype_constrs thy name))
    99   | inst_constrs_of thy _ = raise Match
   100 
   101 
   102 fun transform_pat thy avars c_assum ([] , thm) = raise Match
   103   | transform_pat thy avars c_assum (pat :: pats, thm) =
   104     let
   105         val (_, subps) = strip_comb pat
   106         val eqs = map (cterm_of thy o HOLogic.mk_Trueprop o HOLogic.mk_eq) (avars ~~ subps)
   107         val a_eqs = map assume eqs
   108         val c_eq_pat = simplify (HOL_basic_ss addsimps a_eqs) c_assum
   109     in
   110         (subps @ pats, fold_rev implies_intr eqs
   111                                 (implies_elim thm c_eq_pat))
   112     end
   113 
   114 
   115 exception COMPLETENESS
   116 
   117 fun constr_case thy P idx (v :: vs) pats cons =
   118     let
   119         val (avars, pvars, newidx) = invent_vars cons idx
   120         val c_hyp = cterm_of thy (HOLogic.mk_Trueprop (HOLogic.mk_eq (v, list_comb (cons, avars))))
   121         val c_assum = assume c_hyp
   122         val newpats = map (transform_pat thy avars c_assum) (filter_pats thy cons pvars pats)
   123     in
   124         o_alg thy P newidx (avars @ vs) newpats
   125               |> implies_intr c_hyp
   126               |> fold_rev (forall_intr o cterm_of thy) avars
   127     end
   128   | constr_case _ _ _ _ _ _ = raise Match
   129 and o_alg thy P idx [] (([], Pthm) :: _)  = Pthm
   130   | o_alg thy P idx (v :: vs) [] = raise COMPLETENESS
   131   | o_alg thy P idx (v :: vs) pts =
   132     if forall (is_Free o hd o fst) pts (* Var case *)
   133     then o_alg thy P idx vs (map (fn (pv :: pats, thm) =>
   134                                (pats, refl RS (inst_free (cterm_of thy pv) (cterm_of thy v) thm))) pts)
   135     else (* Cons case *)
   136          let
   137              val T = fastype_of v
   138              val (tname, _) = dest_Type T
   139              val {exhaustion=case_thm, ...} = DatatypePackage.the_datatype thy tname
   140              val constrs = inst_constrs_of thy T
   141              val c_cases = map (constr_case thy P idx (v :: vs) pts) constrs
   142          in
   143              inst_case_thm thy v P case_thm
   144                            |> fold (curry op COMP) c_cases
   145          end
   146   | o_alg _ _ _ _ _ = raise Match
   147 
   148 
   149 fun prove_completeness thy x P qss pats =
   150     let
   151         fun mk_assum qs pat = Logic.mk_implies (HOLogic.mk_Trueprop (HOLogic.mk_eq (x,pat)),
   152                                                 HOLogic.mk_Trueprop P)
   153                                                |> fold_rev mk_forall qs
   154                                                |> cterm_of thy
   155 
   156         val hyps = map2 mk_assum qss pats
   157 
   158         fun inst_hyps hyp qs = fold (forall_elim o cterm_of thy) qs (assume hyp)
   159 
   160         val assums = map2 inst_hyps hyps qss
   161     in
   162         o_alg thy P 2 [x] (map2 (pair o single) pats assums)
   163               |> fold_rev implies_intr hyps
   164     end
   165 
   166 
   167 
   168 fun pat_complete_tac i thm =
   169     let
   170       val thy = theory_of_thm thm
   171 
   172         val subgoal = nth (prems_of thm) (i - 1)   (* FIXME SUBGOAL tactical *)
   173 
   174         val ([P, x], subgf) = dest_all_all subgoal
   175 
   176         val assums = Logic.strip_imp_prems subgf
   177 
   178         fun pat_of assum =
   179             let
   180                 val (qs, imp) = dest_all_all assum
   181             in
   182                 case Logic.dest_implies imp of
   183                     (_ $ (_ $ _ $ pat), _) => (qs, pat)
   184                   | _ => raise COMPLETENESS
   185             end
   186 
   187         val (qss, pats) = split_list (map pat_of assums)
   188 
   189         val complete_thm = prove_completeness thy x P qss pats
   190                                               |> forall_intr (cterm_of thy x)
   191                                               |> forall_intr (cterm_of thy P)
   192     in
   193         Seq.single (Drule.compose_single(complete_thm, i, thm))
   194     end
   195     handle COMPLETENESS => Seq.empty
   196 
   197 
   198 val pat_completeness = Method.SIMPLE_METHOD' pat_complete_tac
   199 
   200 val by_pat_completeness_simp =
   201     Proof.global_terminal_proof
   202       (Method.Basic (K pat_completeness, Position.none),
   203        SOME (Method.Source_i (Args.src (("HOL.auto", []), Position.none))))
   204 
   205 val termination_by_lexicographic_order =
   206     FundefPackage.setup_termination_proof NONE
   207     #> Proof.global_terminal_proof
   208       (Method.Basic (LexicographicOrder.lexicographic_order [], Position.none), NONE)
   209 
   210 fun mk_catchall fixes arities =
   211     let
   212       fun mk_eqn ((fname, fT), _) =
   213           let 
   214             val n = the (Symtab.lookup arities fname)
   215             val (argTs, rT) = chop n (binder_types fT)
   216                                    |> apsnd (fn Ts => Ts ---> body_type fT) 
   217                               
   218             val qs = map Free (Name.invent_list [] "a" n ~~ argTs)
   219           in
   220             HOLogic.mk_eq(list_comb (Free (fname, fT), qs),
   221                           Const ("HOL.undefined", rT))
   222               |> HOLogic.mk_Trueprop
   223               |> fold_rev mk_forall qs
   224           end
   225     in
   226       map mk_eqn fixes
   227     end
   228 
   229 fun add_catchall ctxt fixes spec =
   230     let 
   231       val catchalls = mk_catchall fixes (mk_arities (map (split_def ctxt) (map snd spec)))
   232     in
   233       spec @ map (pair true) catchalls
   234     end
   235 
   236 fun warn_if_redundant ctxt origs tss =
   237     let
   238         fun msg t = "Ignoring redundant equation: " ^ quote (ProofContext.string_of_term ctxt t)
   239                     
   240         val (tss', _) = chop (length origs) tss
   241         fun check ((_, t), []) = (Output.warning (msg t); [])
   242           | check ((_, t), s) = s
   243     in
   244         (map check (origs ~~ tss'); tss)
   245     end
   246 
   247 
   248 fun sequential_preproc (config as FundefConfig {sequential, ...}) flags ctxt fixes spec =
   249     let
   250       val enabled = sequential orelse exists I flags
   251     in 
   252       if enabled then
   253         let
   254           val flags' = if sequential then map (K true) flags else flags
   255 
   256           val (nas, eqss) = split_list spec
   257                             
   258           val eqs = map the_single eqss
   259                     
   260           val feqs = eqs
   261                            |> tap (check_defs ctxt fixes) (* Standard checks *)
   262                            |> tap (map (check_pats ctxt))    (* More checks for sequential mode *)
   263                            |> curry op ~~ flags'
   264 
   265     val compleqs = add_catchall ctxt fixes feqs   (* Completion *)
   266 
   267     val spliteqs = warn_if_redundant ctxt feqs
   268              (FundefSplit.split_some_equations ctxt compleqs)
   269 
   270           fun restore_spec thms =
   271               nas ~~ Library.take (length nas, Library.unflat spliteqs thms)
   272               
   273           val spliteqs' = flat (Library.take (length nas, spliteqs))
   274           val fnames = map (fst o fst) fixes
   275           val indices = map (fn eq => find_index (curry op = (fname_of eq)) fnames) spliteqs'
   276 
   277           fun sort xs = partition_list (fn i => fn (j,_) => i = j) 0 (length fnames - 1) (indices ~~ xs)
   278                                        |> map (map snd)
   279 
   280         in
   281           (flat spliteqs, restore_spec, sort)
   282         end
   283       else
   284         FundefCommon.empty_preproc check_defs config flags ctxt fixes spec
   285     end
   286 
   287 val setup =
   288     Method.add_methods [("pat_completeness", Method.no_args pat_completeness, 
   289                          "Completeness prover for datatype patterns")]
   290     #> Context.theory_map (FundefCommon.set_preproc sequential_preproc)
   291 
   292 
   293 val fun_config = FundefConfig { sequential=true, default="%x. arbitrary", 
   294                                 target=NONE, domintros=false, tailrec=false }
   295 
   296 
   297 local structure P = OuterParse and K = OuterKeyword in
   298 
   299 fun fun_cmd config fixes statements flags lthy =
   300     lthy
   301       |> FundefPackage.add_fundef fixes statements config flags
   302       |> by_pat_completeness_simp
   303       |> termination_by_lexicographic_order
   304 
   305 val _ =
   306   OuterSyntax.command "fun" "define general recursive functions (short version)" K.thy_decl
   307   (fundef_parser fun_config
   308      >> (fn ((config, fixes), (flags, statements)) =>
   309             (Toplevel.local_theory (target_of config) (fun_cmd config fixes statements flags))));
   310 
   311 end
   312 
   313 end