src/HOL/Tools/function_package/fundef_datatype.ML
author krauss
Tue Nov 07 11:53:55 2006 +0100 (2006-11-07)
changeset 21211 5370cfbf3070
parent 21051 c49467a9c1e1
child 21240 8e75fb38522c
permissions -rw-r--r--
Preparations for making "lexicographic_order" part of "fun"
     1 (*  Title:      HOL/Tools/function_package/fundef_datatype.ML
     2     ID:         $Id$
     3     Author:     Alexander Krauss, TU Muenchen
     4 
     5 A package for general recursive function definitions.
     6 A tactic to prove completeness of datatype patterns.
     7 *)
     8 
     9 signature FUNDEF_DATATYPE =
    10 sig
    11     val pat_complete_tac: int -> tactic
    12 
    13     val pat_completeness : method
    14     val setup : theory -> theory
    15 end
    16 
    17 structure FundefDatatype : FUNDEF_DATATYPE =
    18 struct
    19 
    20 open FundefLib
    21 open FundefCommon
    22 
    23 fun mk_argvar i T = Free ("_av" ^ (string_of_int i), T)
    24 fun mk_patvar i T = Free ("_pv" ^ (string_of_int i), T)
    25 
    26 fun inst_free var inst thm =
    27     forall_elim inst (forall_intr var thm)
    28 
    29 
    30 fun inst_case_thm thy x P thm =
    31     let
    32         val [Pv, xv] = term_vars (prop_of thm)
    33     in
    34         cterm_instantiate [(cterm_of thy xv, cterm_of thy x), (cterm_of thy Pv, cterm_of thy P)] thm
    35     end
    36 
    37 
    38 fun invent_vars constr i =
    39     let
    40         val Ts = binder_types (fastype_of constr)
    41         val j = i + length Ts
    42         val is = i upto (j - 1)
    43         val avs = map2 mk_argvar is Ts
    44         val pvs = map2 mk_patvar is Ts
    45     in
    46         (avs, pvs, j)
    47     end
    48 
    49 
    50 fun filter_pats thy cons pvars [] = []
    51   | filter_pats thy cons pvars (([], thm) :: pts) = raise Match
    52   | filter_pats thy cons pvars ((pat :: pats, thm) :: pts) =
    53     case pat of
    54         Free _ => let val inst = list_comb (cons, pvars)
    55                  in (inst :: pats, inst_free (cterm_of thy pat) (cterm_of thy inst) thm)
    56                     :: (filter_pats thy cons pvars pts) end
    57       | _ => if fst (strip_comb pat) = cons
    58              then (pat :: pats, thm) :: (filter_pats thy cons pvars pts)
    59              else filter_pats thy cons pvars pts
    60 
    61 
    62 fun inst_constrs_of thy (T as Type (name, _)) =
    63         map (fn (Cn,CT) => Envir.subst_TVars (Sign.typ_match thy (body_type CT, T) Vartab.empty) (Const (Cn, CT)))
    64             (the (DatatypePackage.get_datatype_constrs thy name))
    65   | inst_constrs_of thy _ = raise Match
    66 
    67 
    68 fun transform_pat thy avars c_assum ([] , thm) = raise Match
    69   | transform_pat thy avars c_assum (pat :: pats, thm) =
    70     let
    71         val (_, subps) = strip_comb pat
    72         val eqs = map (cterm_of thy o HOLogic.mk_Trueprop o HOLogic.mk_eq) (avars ~~ subps)
    73         val a_eqs = map assume eqs
    74         val c_eq_pat = simplify (HOL_basic_ss addsimps a_eqs) c_assum
    75     in
    76         (subps @ pats, fold_rev implies_intr eqs
    77                                 (implies_elim thm c_eq_pat))
    78     end
    79 
    80 
    81 exception COMPLETENESS
    82 
    83 fun constr_case thy P idx (v :: vs) pats cons =
    84     let
    85         val (avars, pvars, newidx) = invent_vars cons idx
    86         val c_hyp = cterm_of thy (HOLogic.mk_Trueprop (HOLogic.mk_eq (v, list_comb (cons, avars))))
    87         val c_assum = assume c_hyp
    88         val newpats = map (transform_pat thy avars c_assum) (filter_pats thy cons pvars pats)
    89     in
    90         o_alg thy P newidx (avars @ vs) newpats
    91               |> implies_intr c_hyp
    92               |> fold_rev (forall_intr o cterm_of thy) avars
    93     end
    94   | constr_case _ _ _ _ _ _ = raise Match
    95 and o_alg thy P idx [] (([], Pthm) :: _)  = Pthm
    96   | o_alg thy P idx (v :: vs) [] = raise COMPLETENESS
    97   | o_alg thy P idx (v :: vs) pts =
    98     if forall (is_Free o hd o fst) pts (* Var case *)
    99     then o_alg thy P idx vs (map (fn (pv :: pats, thm) =>
   100                                (pats, refl RS (inst_free (cterm_of thy pv) (cterm_of thy v) thm))) pts)
   101     else (* Cons case *)
   102          let
   103              val T = fastype_of v
   104              val (tname, _) = dest_Type T
   105              val {exhaustion=case_thm, ...} = DatatypePackage.the_datatype thy tname
   106              val constrs = inst_constrs_of thy T
   107              val c_cases = map (constr_case thy P idx (v :: vs) pts) constrs
   108          in
   109              inst_case_thm thy v P case_thm
   110                            |> fold (curry op COMP) c_cases
   111          end
   112   | o_alg _ _ _ _ _ = raise Match
   113 
   114 
   115 fun prove_completeness thy x P qss pats =
   116     let
   117         fun mk_assum qs pat = Logic.mk_implies (HOLogic.mk_Trueprop (HOLogic.mk_eq (x,pat)),
   118                                                 HOLogic.mk_Trueprop P)
   119                                                |> fold_rev mk_forall qs
   120                                                |> cterm_of thy
   121 
   122         val hyps = map2 mk_assum qss pats
   123 
   124         fun inst_hyps hyp qs = fold (forall_elim o cterm_of thy) qs (assume hyp)
   125 
   126         val assums = map2 inst_hyps hyps qss
   127     in
   128         o_alg thy P 2 [x] (map2 (pair o single) pats assums)
   129               |> fold_rev implies_intr hyps
   130     end
   131 
   132 
   133 
   134 fun pat_complete_tac i thm =
   135     let
   136       val thy = theory_of_thm thm
   137 
   138         val subgoal = nth (prems_of thm) (i - 1)   (* FIXME SUBGOAL tactical *)
   139 
   140         val ([P, x], subgf) = dest_all_all subgoal
   141 
   142         val assums = Logic.strip_imp_prems subgf
   143 
   144         fun pat_of assum =
   145             let
   146                 val (qs, imp) = dest_all_all assum
   147             in
   148                 case Logic.dest_implies imp of
   149                     (_ $ (_ $ _ $ pat), _) => (qs, pat)
   150                   | _ => raise COMPLETENESS
   151             end
   152 
   153         val (qss, pats) = split_list (map pat_of assums)
   154 
   155         val complete_thm = prove_completeness thy x P qss pats
   156                                               |> forall_intr (cterm_of thy x)
   157                                               |> forall_intr (cterm_of thy P)
   158     in
   159         Seq.single (Drule.compose_single(complete_thm, i, thm))
   160     end
   161     handle COMPLETENESS => Seq.empty
   162 
   163 
   164 val pat_completeness = Method.SIMPLE_METHOD (pat_complete_tac 1)
   165 
   166 val by_pat_completeness_simp =
   167     Proof.global_terminal_proof
   168       (Method.Basic (K pat_completeness),
   169        SOME (Method.Source (Args.src (("simp_all", []), Position.none))))
   170          (* FIXME avoid dynamic scoping of method name! *)
   171 
   172 fun termination_by_lexicographic_order name =
   173     FundefPackage.setup_termination_proof (SOME name)
   174     #> Proof.global_terminal_proof (Method.Basic (K LexicographicOrder.lexicographic_order), NONE)
   175 
   176 val setup =
   177     Method.add_methods [("pat_completeness", Method.no_args pat_completeness, "Completeness prover for datatype patterns")]
   178 
   179 
   180 
   181 
   182 local structure P = OuterParse and K = OuterKeyword in
   183 
   184 
   185 fun or_list1 s = P.enum1 "|" s
   186 val otherwise = P.$$$ "(" |-- P.$$$ "otherwise" --| P.$$$ ")"
   187 val statement_ow = P.and_list1 (P.opt_thm_name ":" -- Scan.repeat1 (P.prop -- Scan.optional (otherwise >> K true) false))
   188 val statements_ow = or_list1 statement_ow
   189 
   190 
   191 fun fun_cmd fixes statements lthy =
   192     lthy
   193       |> FundefPackage.add_fundef fixes statements FundefCommon.fun_config
   194       ||> by_pat_completeness_simp
   195       (*|-> termination_by_lexicographic_order*) |> snd
   196 
   197 
   198 val funP =
   199   OuterSyntax.command "fun" "define general recursive functions (short version)" K.thy_decl
   200   ((P.opt_locale_target -- P.fixes --| P.$$$ "where" -- statements_ow)
   201      >> (fn ((target, fixes), statements) =>
   202             (Toplevel.local_theory target (fun_cmd fixes statements))));
   203 
   204 val _ = OuterSyntax.add_parsers [funP];
   205 end
   206 
   207 end