src/HOL/Tools/Function/partial_function.ML
changeset 40107 374f3ef9f940
child 40169 11ea439d947f
equal deleted inserted replaced
40106:c58951943cba 40107:374f3ef9f940
       
     1 (*  Title:      HOL/Tools/Function/partial_function.ML
       
     2     Author:     Alexander Krauss, TU Muenchen
       
     3 
       
     4 Partial function definitions based on least fixed points in ccpos.
       
     5 *)
       
     6 
       
      7 signature PARTIAL_FUNCTION =
       
      8 sig
       
         (*Theory setup: installs the Mono_Rules named-theorem collection
           ("partial_function_mono").*)
      9   val setup: theory -> theory
       
         (*init fixp mono fixp_eq: declaration that registers a partial_function
           mode consisting of a fixpoint operator, a monotonicity predicate and the
           theorem used to unfold the fixpoint.  The mode name is taken from the
           name prefix of the morphism; nothing is registered if that prefix is
           empty.*)
     10   val init: term -> term -> thm -> declaration
       
     11 
       
         (*add_partial_function mode fixes (binding, eqn): define the single fixed
           function by the recursive equation eqn in the given (registered) mode;
           the derived equation is noted under binding and under <name>.rec.*)
     12   val add_partial_function: string -> (binding * typ option * mixfix) list ->
       
     13     Attrib.binding * term -> local_theory -> local_theory
       
     14 
       
         (*As add_partial_function, but the fixes and the equation are given as
           strings that are read in the local theory.*)
     15   val add_partial_function_cmd: string -> (binding * string option * mixfix) list ->
       
     16     Attrib.binding * string -> local_theory -> local_theory
       
     17 end;
       
    18 
       
    19 
       
    20 structure Partial_Function: PARTIAL_FUNCTION =
       
    21 struct
       
    22 
       
    23 (*** Context Data ***)
       
    24 
       
    25 structure Modes = Generic_Data
       
    26 (
       
    27   type T = ((term * term) * thm) Symtab.table;
       
    28   val empty = Symtab.empty;
       
    29   val extend = I;
       
    30   fun merge (a, b) = Symtab.merge (K true) (a, b);
       
    31 )
       
    32 
       
    33 fun init fixp mono fixp_eq phi =
       
    34   let
       
    35     val term = Morphism.term phi;
       
    36     val data' = ((term fixp, term mono), Morphism.thm phi fixp_eq);
       
    37     val mode = (* extract mode identifier from morphism prefix! *)
       
    38       Binding.prefix_of (Morphism.binding phi (Binding.empty))
       
    39       |> map fst |> space_implode ".";
       
    40   in
       
    41     if mode = "" then I
       
    42     else Modes.map (Symtab.update (mode, data'))
       
    43   end
       
    44 
       
    45 val known_modes = Symtab.keys o Modes.get o Context.Proof;
       
    46 val lookup_mode = Symtab.lookup o Modes.get o Context.Proof;
       
    47 
       
    48 
       
    49 structure Mono_Rules = Named_Thms
       
    50 (
       
    51   val name = "partial_function_mono";
       
    52   val description = "monotonicity rules for partial function definitions";
       
    53 );
       
    54 
       
    55 
       
    56 (*** Automated monotonicity proofs ***)
       
    57 
       
    58 fun strip_cases ctac = ctac #> Seq.map snd;
       
    59 
       
    60 (*rewrite conclusion with k-th assumtion*)
       
    61 fun rewrite_with_asm_tac ctxt k =
       
    62   Subgoal.FOCUS (fn {context=ctxt', prems, ...} =>
       
    63     Local_Defs.unfold_tac ctxt' [nth prems k]) ctxt;
       
    64 
       
    65 fun dest_case thy t =
       
    66   case strip_comb t of
       
    67     (Const (case_comb, _), args) =>
       
    68       (case Datatype.info_of_case thy case_comb of
       
    69          NONE => NONE
       
    70        | SOME {case_rewrites, ...} =>
       
    71            let
       
    72              val lhs = prop_of (hd case_rewrites)
       
    73                |> HOLogic.dest_Trueprop |> HOLogic.dest_eq |> fst;
       
    74              val arity = length (snd (strip_comb lhs));
       
    75              val conv = funpow (length args - arity) Conv.fun_conv
       
    76                (Conv.rewrs_conv (map mk_meta_eq case_rewrites));
       
    77            in
       
    78              SOME (nth args (arity - 1), conv)
       
    79            end)
       
    80   | _ => NONE;
       
    81 
       
     82 (*Split on case expressions.  Works on the given subgoal with its
         parameters fixed (Subgoal.FOCUS_PARAMS).  If the goal term has the shape
         _ $ (_ $ Abs (_, _, body)) -- presumably Trueprop (mono_const (%f. body));
         NOTE(review): confirm against the goals built in gen_add_partial_function --
         and dest_case recognises body as a datatype case expression whose
         scrutinee contains no loose bound variables, the tactic performs case
         analysis on the scrutinee, rewrites each new subgoal's conclusion with its
         first premise, removes a premise via thin_rl, and finally rewrites the case
         combinator under the abstraction with the conversion from dest_case.
         In every other situation it fails with no_tac.*)
       
     83 val split_cases_tac = Subgoal.FOCUS_PARAMS (fn {context=ctxt, ...} =>
       
     84   SUBGOAL (fn (t, i) => case t of
       
     85     _ $ (_ $ Abs (_, _, body)) =>
       
     86       (case dest_case (ProofContext.theory_of ctxt) body of
       
     87          NONE => no_tac
       
     88        | SOME (arg, conv) =>
       
     89            let open Conv in
       
                   (*the scrutinee must not refer to bound variables of body*)
     90               if not (null (loose_bnos arg)) then no_tac
       
                   (*case analysis on the scrutinee; strip_cases drops the case
                     names and DETERM keeps only the first outcome*)
     91               else ((DETERM o strip_cases o Induct.cases_tac ctxt false [[SOME arg]] NONE [])
       
                     (*in each case, rewrite with premise 0 -- presumably the new
                       case equation*)
     92                 THEN_ALL_NEW (rewrite_with_asm_tac ctxt 0)
       
                     (*then discard a premise -- presumably that same equation*)
     93                 THEN_ALL_NEW etac @{thm thin_rl}
       
                     (*reduce the case combinator: through all parameters (~1), two
                       argument positions, then under the abstraction*)
     94                 THEN_ALL_NEW (CONVERSION
       
     95                   (params_conv ~1 (fn ctxt' =>
       
     96                     arg_conv (arg_conv (abs_conv (K conv) ctxt'))) ctxt))) i
       
     97            end)
       
     98   | _ => no_tac) 1);
       
    99 
       
   100 (*monotonicity proof: apply rules + split case expressions*)
       
   101 fun mono_tac ctxt =
       
   102   K (Local_Defs.unfold_tac ctxt [@{thm curry_def}])
       
   103   THEN' (TRY o REPEAT_ALL_NEW
       
   104    (resolve_tac (Mono_Rules.get ctxt)
       
   105      ORELSE' split_cases_tac ctxt));
       
   106 
       
   107 
       
   108 (*** Auxiliary functions ***)
       
   109 
       
   110 (*positional instantiation with computed type substitution.
       
   111   internal version of  attribute "[of s t u]".*)
       
   112 fun cterm_instantiate' cts thm =
       
   113   let
       
   114     val thy = Thm.theory_of_thm thm;
       
   115     val vs = rev (Term.add_vars (prop_of thm) [])
       
   116       |> map (Thm.cterm_of thy o Var);
       
   117   in
       
   118     cterm_instantiate (zip_options vs cts) thm
       
   119   end;
       
   120 
       
   121 (*Returns t $ u, but instantiates the type of t to make the
       
   122 application type correct*)
       
   123 fun apply_inst ctxt t u =
       
   124   let
       
   125     val thy = ProofContext.theory_of ctxt;
       
   126     val T = domain_type (fastype_of t);
       
   127     val T' = fastype_of u;
       
   128     val subst = Type.typ_match (Sign.tsig_of thy) (T, T') Vartab.empty
       
   129       handle Type.TYPE_MATCH => raise TYPE ("apply_inst", [T, T'], [t, u])
       
   130   in
       
   131     map_types (Envir.norm_type subst) t $ u
       
   132   end;
       
   133 
       
   134 fun head_conv cv ct =
       
   135   if can Thm.dest_comb ct then Conv.fun_conv (head_conv cv) ct else cv ct;
       
   136 
       
   137 
       
   138 (*** currying transformation ***)
       
   139 
       
   140 fun curry_const (A, B, C) =
       
   141   Const (@{const_name Product_Type.curry},
       
   142     [HOLogic.mk_prodT (A, B) --> C, A, B] ---> C);
       
   143 
       
   144 fun mk_curry f =
       
   145   case fastype_of f of
       
   146     Type ("fun", [Type (_, [S, T]), U]) =>
       
   147       curry_const (S, T, U) $ f
       
   148   | T => raise TYPE ("mk_curry", [T], [f]);
       
   149 
       
   150 (* iterated versions. Nonstandard left-nested tuples arise naturally
       
   151 from "split o split o split"*)
       
   152 fun curry_n arity = funpow (arity - 1) mk_curry;
       
   153 fun uncurry_n arity = funpow (arity - 1) HOLogic.mk_split;
       
   154 
       
   155 val curry_uncurry_ss = HOL_basic_ss addsimps
       
   156   [@{thm Product_Type.curry_split}, @{thm Product_Type.split_curry}]
       
   157 
       
   158 
       
   159 (*** partial_function definition ***)
       
   160 
       
   161 fun gen_add_partial_function prep mode fixes_raw eqn_raw lthy =
       
   162   let
       
   163     val ((fixp, mono), fixp_eq) = the (lookup_mode lthy mode)
       
   164       handle Option.Option => error (cat_lines ["Unknown mode " ^ quote mode ^ ".",
       
   165         "Known modes are " ^ commas_quote (known_modes lthy) ^ "."]);
       
   166 
       
   167     val ((fixes, [(eq_abinding, eqn)]), _) = prep fixes_raw [eqn_raw] lthy;
       
   168     val (_, _, plain_eqn) = Function_Lib.dest_all_all_ctx lthy eqn;
       
   169 
       
   170     val ((f_binding, fT), mixfix) = the_single fixes;
       
   171     val fname = Binding.name_of f_binding;
       
   172 
       
   173     val cert = cterm_of (ProofContext.theory_of lthy);
       
   174     val (lhs, rhs) = HOLogic.dest_eq (HOLogic.dest_Trueprop plain_eqn);
       
   175     val (head, args) = strip_comb lhs;
       
   176     val F = fold_rev lambda (head :: args) rhs;
       
   177 
       
   178     val arity = length args;
       
   179     val (aTs, bTs) = chop arity (binder_types fT);
       
   180 
       
   181     val tupleT = foldl1 HOLogic.mk_prodT aTs;
       
   182     val fT_uc = tupleT :: bTs ---> body_type fT;
       
   183     val f_uc = Var ((fname, 0), fT_uc);
       
   184     val x_uc = Var (("x", 0), tupleT);
       
   185     val uncurry = lambda head (uncurry_n arity head);
       
   186     val curry = lambda f_uc (curry_n arity f_uc);
       
   187 
       
   188     val F_uc =
       
   189       lambda f_uc (uncurry_n arity (F $ curry_n arity f_uc));
       
   190 
       
   191     val mono_goal = apply_inst lthy mono (lambda f_uc (F_uc $ f_uc $ x_uc))
       
   192       |> HOLogic.mk_Trueprop
       
   193       |> Logic.all x_uc;
       
   194 
       
   195     val mono_thm = Goal.prove_internal [] (cert mono_goal)
       
   196         (K (mono_tac lthy 1))
       
   197       |> Thm.forall_elim (cert x_uc);
       
   198 
       
   199     val f_def_rhs = curry_n arity (apply_inst lthy fixp F_uc);
       
   200     val f_def_binding = Binding.conceal (Binding.name (Thm.def_name fname));
       
   201     val ((f, (_, f_def)), lthy') = Local_Theory.define
       
   202       ((f_binding, mixfix), ((f_def_binding, []), f_def_rhs)) lthy;
       
   203 
       
   204     val eqn = HOLogic.mk_eq (list_comb (f, args),
       
   205         Term.betapplys (F, f :: args))
       
   206       |> HOLogic.mk_Trueprop;
       
   207 
       
   208     val unfold =
       
   209       (cterm_instantiate' (map (SOME o cert) [uncurry, F, curry]) fixp_eq
       
   210         OF [mono_thm, f_def])
       
   211       |> Tactic.rule_by_tactic lthy (Simplifier.simp_tac curry_uncurry_ss 1);
       
   212 
       
   213     val rec_rule = let open Conv in
       
   214       Goal.prove lthy' (map (fst o dest_Free) args) [] eqn (fn _ =>
       
   215         CONVERSION ((arg_conv o arg1_conv o head_conv o rewr_conv) (mk_meta_eq unfold)) 1
       
   216         THEN rtac @{thm refl} 1) end;
       
   217   in
       
   218     lthy'
       
   219     |> Local_Theory.note (eq_abinding, [rec_rule])
       
   220     |-> (fn (_, rec') =>
       
   221       Local_Theory.note ((Binding.qualify true fname (Binding.name "rec"), []), rec'))
       
   222     |> snd
       
   223   end;
       
   224 
       
   225 val add_partial_function = gen_add_partial_function Specification.check_spec;
       
   226 val add_partial_function_cmd = gen_add_partial_function Specification.read_spec;
       
   227 
       
   228 val mode = Parse.$$$ "(" |-- Parse.xname --| Parse.$$$ ")";
       
   229 
       
   230 val _ = Outer_Syntax.local_theory
       
   231   "partial_function" "define partial function" Keyword.thy_goal
       
   232   ((mode -- (Parse.fixes -- (Parse.where_ |-- Parse_Spec.spec)))
       
   233      >> (fn (mode, (fixes, spec)) => add_partial_function_cmd mode fixes spec));
       
   234 
       
   235 
       
   236 val setup = Mono_Rules.setup;
       
   237 
       
   238 end