src/HOL/Tools/Function/partial_function.ML
author haftmann
Tue Oct 13 09:21:15 2015 +0200 (2015-10-13)
changeset 61424 c3658c18b7bc
parent 61121 efe8b18306b7
child 61841 4d3527b94f2a
permissions -rw-r--r--
prod_case as canonical name for product type eliminator
(*  Title:      HOL/Tools/Function/partial_function.ML
    Author:     Alexander Krauss, TU Muenchen

Partial function definitions based on least fixed points in ccpos.
*)

     7 signature PARTIAL_FUNCTION =
     8 sig
     9   val init: string -> term -> term -> thm -> thm -> thm option -> declaration
    10   val mono_tac: Proof.context -> int -> tactic
    11   val add_partial_function: string -> (binding * typ option * mixfix) list ->
    12     Attrib.binding * term -> local_theory -> local_theory
    13   val add_partial_function_cmd: string -> (binding * string option * mixfix) list ->
    14     Attrib.binding * string -> local_theory -> local_theory
    15 end;
    16 
    17 structure Partial_Function: PARTIAL_FUNCTION =
    18 struct
    19 
    20 (*** Context Data ***)
    21 
    22 datatype setup_data = Setup_Data of
    23  {fixp: term,
    24   mono: term,
    25   fixp_eq: thm,
    26   fixp_induct: thm,
    27   fixp_induct_user: thm option};
    28 
    29 fun transform_setup_data phi (Setup_Data {fixp, mono, fixp_eq, fixp_induct, fixp_induct_user}) =
    30   let
    31     val term = Morphism.term phi;
    32     val thm = Morphism.thm phi;
    33   in
    34     Setup_Data
    35      {fixp = term fixp, mono = term mono, fixp_eq = thm fixp_eq,
    36       fixp_induct = thm fixp_induct, fixp_induct_user = Option.map thm fixp_induct_user}
    37   end;
    38 
    39 structure Modes = Generic_Data
    40 (
    41   type T = setup_data Symtab.table;
    42   val empty = Symtab.empty;
    43   val extend = I;
    44   fun merge data = Symtab.merge (K true) data;
    45 )
    46 
    47 fun init mode fixp mono fixp_eq fixp_induct fixp_induct_user phi =
    48   let
    49     val data' =
    50       Setup_Data
    51        {fixp = fixp, mono = mono, fixp_eq = fixp_eq,
    52         fixp_induct = fixp_induct, fixp_induct_user = fixp_induct_user}
    53       |> transform_setup_data (phi $> Morphism.trim_context_morphism);
    54   in Modes.map (Symtab.update (mode, data')) end;
    55 
    56 val known_modes = Symtab.keys o Modes.get o Context.Proof;
    57 
    58 fun lookup_mode ctxt =
    59   Symtab.lookup (Modes.get (Context.Proof ctxt))
    60   #> Option.map (transform_setup_data (Morphism.transfer_morphism (Proof_Context.theory_of ctxt)));
    61 
    62 
    63 (*** Automated monotonicity proofs ***)
    64 
    65 fun strip_cases ctac = ctac #> Seq.map snd;
    66 
    67 (*rewrite conclusion with k-th assumtion*)
    68 fun rewrite_with_asm_tac ctxt k =
    69   Subgoal.FOCUS (fn {context = ctxt', prems, ...} =>
    70     Local_Defs.unfold_tac ctxt' [nth prems k]) ctxt;
    71 
    72 fun dest_case ctxt t =
    73   case strip_comb t of
    74     (Const (case_comb, _), args) =>
    75       (case Ctr_Sugar.ctr_sugar_of_case ctxt case_comb of
    76          NONE => NONE
    77        | SOME {case_thms, ...} =>
    78            let
    79              val lhs = Thm.prop_of (hd case_thms)
    80                |> HOLogic.dest_Trueprop |> HOLogic.dest_eq |> fst;
    81              val arity = length (snd (strip_comb lhs));
    82              val conv = funpow (length args - arity) Conv.fun_conv
    83                (Conv.rewrs_conv (map mk_meta_eq case_thms));
    84            in
    85              SOME (nth args (arity - 1), conv)
    86            end)
    87   | _ => NONE;
    88 
    89 (*split on case expressions*)
    90 val split_cases_tac = Subgoal.FOCUS_PARAMS (fn {context = ctxt, ...} =>
    91   SUBGOAL (fn (t, i) => case t of
    92     _ $ (_ $ Abs (_, _, body)) =>
    93       (case dest_case ctxt body of
    94          NONE => no_tac
    95        | SOME (arg, conv) =>
    96            let open Conv in
    97               if Term.is_open arg then no_tac
    98               else ((DETERM o strip_cases o Induct.cases_tac ctxt false [[SOME arg]] NONE [])
    99                 THEN_ALL_NEW (rewrite_with_asm_tac ctxt 0)
   100                 THEN_ALL_NEW eresolve_tac ctxt @{thms thin_rl}
   101                 THEN_ALL_NEW (CONVERSION
   102                   (params_conv ~1 (fn ctxt' =>
   103                     arg_conv (arg_conv (abs_conv (K conv) ctxt'))) ctxt))) i
   104            end)
   105   | _ => no_tac) 1);
   106 
   107 (*monotonicity proof: apply rules + split case expressions*)
   108 fun mono_tac ctxt =
   109   K (Local_Defs.unfold_tac ctxt [@{thm curry_def}])
   110   THEN' (TRY o REPEAT_ALL_NEW
   111    (resolve_tac ctxt (rev (Named_Theorems.get ctxt @{named_theorems partial_function_mono}))
   112      ORELSE' split_cases_tac ctxt));
   113 
   114 
   115 (*** Auxiliary functions ***)
   116 
   117 (*Returns t $ u, but instantiates the type of t to make the
   118 application type correct*)
   119 fun apply_inst ctxt t u =
   120   let
   121     val thy = Proof_Context.theory_of ctxt;
   122     val T = domain_type (fastype_of t);
   123     val T' = fastype_of u;
   124     val subst = Sign.typ_match thy (T, T') Vartab.empty
   125       handle Type.TYPE_MATCH => raise TYPE ("apply_inst", [T, T'], [t, u])
   126   in
   127     map_types (Envir.norm_type subst) t $ u
   128   end;
   129 
   130 fun head_conv cv ct =
   131   if can Thm.dest_comb ct then Conv.fun_conv (head_conv cv) ct else cv ct;
   132 
   133 
   134 (*** currying transformation ***)
   135 
   136 fun curry_const (A, B, C) =
   137   Const (@{const_name Product_Type.curry},
   138     [HOLogic.mk_prodT (A, B) --> C, A, B] ---> C);
   139 
   140 fun mk_curry f =
   141   case fastype_of f of
   142     Type ("fun", [Type (_, [S, T]), U]) =>
   143       curry_const (S, T, U) $ f
   144   | T => raise TYPE ("mk_curry", [T], [f]);
   145 
   146 (* iterated versions. Nonstandard left-nested tuples arise naturally
   147 from "split o split o split"*)
   148 fun curry_n arity = funpow (arity - 1) mk_curry;
   149 fun uncurry_n arity = funpow (arity - 1) HOLogic.mk_case_prod;
   150 
   151 val curry_uncurry_ss =
   152   simpset_of (put_simpset HOL_basic_ss @{context}
   153     addsimps [@{thm Product_Type.curry_case_prod}, @{thm Product_Type.case_prod_curry}])
   154 
   155 val split_conv_ss =
   156   simpset_of (put_simpset HOL_basic_ss @{context}
   157     addsimps [@{thm Product_Type.split_conv}]);
   158 
   159 val curry_K_ss =
   160   simpset_of (put_simpset HOL_basic_ss @{context}
   161     addsimps [@{thm Product_Type.curry_K}]);
   162 
   163 (* instantiate generic fixpoint induction and eliminate the canonical assumptions;
   164   curry induction predicate *)
   165 fun specialize_fixp_induct ctxt args fT fT_uc F curry uncurry mono_thm f_def rule =
   166   let
   167     val ([P], ctxt') = Variable.variant_fixes ["P"] ctxt
   168     val P_inst = Abs ("f", fT_uc, Free (P, fT --> HOLogic.boolT) $ (curry $ Bound 0))
   169   in
   170     (* FIXME ctxt vs. ctxt' (!?) *)
   171     rule
   172     |> infer_instantiate' ctxt
   173       ((map o Option.map) (Thm.cterm_of ctxt) [SOME uncurry, NONE, SOME curry, NONE, SOME P_inst])
   174     |> Tactic.rule_by_tactic ctxt
   175       (Simplifier.simp_tac (put_simpset curry_uncurry_ss ctxt) 3 (* discharge U (C f) = f *)
   176        THEN Simplifier.simp_tac (put_simpset curry_K_ss ctxt) 4 (* simplify bot case *)
   177        THEN Simplifier.full_simp_tac (put_simpset curry_uncurry_ss ctxt) 5) (* simplify induction step *)
   178     |> (fn thm => thm OF [mono_thm, f_def])
   179     |> Conv.fconv_rule (Conv.concl_conv ~1    (* simplify conclusion *)
   180          (Raw_Simplifier.rewrite ctxt false [mk_meta_eq @{thm Product_Type.curry_case_prod}]))
   181     |> singleton (Variable.export ctxt' ctxt)
   182   end
   183 
   184 fun mk_curried_induct args ctxt inst_rule =
   185   let
   186     val cert = Thm.cterm_of ctxt
   187     val ([P], ctxt') = Variable.variant_fixes ["P"] ctxt
   188 
   189     val split_paired_all_conv =
   190       Conv.every_conv (replicate (length args - 1) (Conv.rewr_conv @{thm split_paired_all}))
   191 
   192     val split_params_conv =
   193       Conv.params_conv ~1 (fn ctxt' =>
   194         Conv.implies_conv split_paired_all_conv Conv.all_conv)
   195 
   196     val (P_var, x_var) =
   197        Thm.prop_of inst_rule |> Logic.strip_imp_concl |> HOLogic.dest_Trueprop
   198       |> strip_comb |> apsnd hd
   199       |> apply2 dest_Var
   200     val P_rangeT = range_type (snd P_var)
   201     val PT = map (snd o dest_Free) args ---> P_rangeT
   202     val x_inst = cert (foldl1 HOLogic.mk_prod args)
   203     val P_inst = cert (uncurry_n (length args) (Free (P, PT)))
   204 
   205     val inst_rule' = inst_rule
   206       |> Tactic.rule_by_tactic ctxt
   207         (Simplifier.simp_tac (put_simpset curry_uncurry_ss ctxt) 4
   208          THEN Simplifier.simp_tac (put_simpset curry_uncurry_ss ctxt) 3
   209          THEN CONVERSION (split_params_conv ctxt
   210            then_conv (Conv.forall_conv (K split_paired_all_conv) ctxt)) 3)
   211       |> Thm.instantiate ([], [(P_var, P_inst), (x_var, x_inst)])
   212       |> Simplifier.full_simplify (put_simpset split_conv_ss ctxt)
   213       |> singleton (Variable.export ctxt' ctxt)
   214   in
   215     inst_rule'
   216   end;
   217 
   218 
   219 (*** partial_function definition ***)
   220 
   221 fun gen_add_partial_function prep mode fixes_raw eqn_raw lthy =
   222   let
   223     val setup_data = the (lookup_mode lthy mode)
   224       handle Option.Option => error (cat_lines ["Unknown mode " ^ quote mode ^ ".",
   225         "Known modes are " ^ commas_quote (known_modes lthy) ^ "."]);
   226     val Setup_Data {fixp, mono, fixp_eq, fixp_induct, fixp_induct_user} = setup_data;
   227 
   228     val ((fixes, [(eq_abinding, eqn)]), _) = prep fixes_raw [eqn_raw] lthy;
   229     val ((_, plain_eqn), args_ctxt) = Variable.focus NONE eqn lthy;
   230 
   231     val ((f_binding, fT), mixfix) = the_single fixes;
   232     val fname = Binding.name_of f_binding;
   233 
   234     val (lhs, rhs) = HOLogic.dest_eq (HOLogic.dest_Trueprop plain_eqn);
   235     val (head, args) = strip_comb lhs;
   236     val argnames = map (fst o dest_Free) args;
   237     val F = fold_rev lambda (head :: args) rhs;
   238 
   239     val arity = length args;
   240     val (aTs, bTs) = chop arity (binder_types fT);
   241 
   242     val tupleT = foldl1 HOLogic.mk_prodT aTs;
   243     val fT_uc = tupleT :: bTs ---> body_type fT;
   244     val f_uc = Var ((fname, 0), fT_uc);
   245     val x_uc = Var (("x", 0), tupleT);
   246     val uncurry = lambda head (uncurry_n arity head);
   247     val curry = lambda f_uc (curry_n arity f_uc);
   248 
   249     val F_uc =
   250       lambda f_uc (uncurry_n arity (F $ curry_n arity f_uc));
   251 
   252     val mono_goal = apply_inst lthy mono (lambda f_uc (F_uc $ f_uc $ x_uc))
   253       |> HOLogic.mk_Trueprop
   254       |> Logic.all x_uc;
   255 
   256     val mono_thm = Goal.prove_internal lthy [] (Thm.cterm_of lthy mono_goal)
   257         (K (mono_tac lthy 1))
   258     val inst_mono_thm = Thm.forall_elim (Thm.cterm_of lthy x_uc) mono_thm
   259 
   260     val f_def_rhs = curry_n arity (apply_inst lthy fixp F_uc);
   261     val f_def_binding =
   262       if Config.get lthy Function_Lib.function_defs then (Binding.name (Thm.def_name fname), [])
   263       else Attrib.empty_binding;
   264     val ((f, (_, f_def)), lthy') = Local_Theory.define
   265       ((f_binding, mixfix), (f_def_binding, f_def_rhs)) lthy;
   266 
   267     val eqn = HOLogic.mk_eq (list_comb (f, args),
   268         Term.betapplys (F, f :: args))
   269       |> HOLogic.mk_Trueprop;
   270 
   271     val unfold =
   272       (infer_instantiate' lthy' (map (SOME o Thm.cterm_of lthy') [uncurry, F, curry]) fixp_eq
   273         OF [inst_mono_thm, f_def])
   274       |> Tactic.rule_by_tactic lthy' (Simplifier.simp_tac (put_simpset curry_uncurry_ss lthy') 1);
   275 
   276     val specialized_fixp_induct =
   277       specialize_fixp_induct lthy' args fT fT_uc F curry uncurry inst_mono_thm f_def fixp_induct
   278       |> Drule.rename_bvars' (map SOME (fname :: fname :: argnames));
   279 
   280     val mk_raw_induct =
   281       infer_instantiate' args_ctxt
   282         ((map o Option.map) (Thm.cterm_of args_ctxt) [SOME uncurry, NONE, SOME curry])
   283       #> mk_curried_induct args args_ctxt
   284       #> singleton (Variable.export args_ctxt lthy')
   285       #> (fn thm => infer_instantiate' lthy'
   286           [SOME (Thm.cterm_of lthy' F)] thm OF [inst_mono_thm, f_def])
   287       #> Drule.rename_bvars' (map SOME (fname :: argnames @ argnames))
   288 
   289     val raw_induct = Option.map mk_raw_induct fixp_induct_user
   290     val rec_rule = let open Conv in
   291       Goal.prove lthy' (map (fst o dest_Free) args) [] eqn (fn _ =>
   292         CONVERSION ((arg_conv o arg1_conv o head_conv o rewr_conv) (mk_meta_eq unfold)) 1
   293         THEN resolve_tac lthy' @{thms refl} 1) end;
   294   in
   295     lthy'
   296     |> Local_Theory.note (eq_abinding, [rec_rule])
   297     |-> (fn (_, rec') =>
   298       Spec_Rules.add Spec_Rules.Equational ([f], rec')
   299       #> Local_Theory.note ((Binding.qualify true fname (Binding.name "simps"), []), rec') #> snd)
   300     |> (Local_Theory.note ((Binding.qualify true fname (Binding.name "mono"), []), [mono_thm]) #> snd)
   301     |> (case raw_induct of NONE => I | SOME thm =>
   302          Local_Theory.note ((Binding.qualify true fname (Binding.name "raw_induct"), []), [thm]) #> snd)
   303     |> (Local_Theory.note ((Binding.qualify true fname (Binding.name "fixp_induct"), []), [specialized_fixp_induct]) #> snd)
   304   end;
   305 
   306 val add_partial_function = gen_add_partial_function Specification.check_spec;
   307 val add_partial_function_cmd = gen_add_partial_function Specification.read_spec;
   308 
   309 val mode = @{keyword "("} |-- Parse.xname --| @{keyword ")"};
   310 
   311 val _ =
   312   Outer_Syntax.local_theory @{command_keyword partial_function} "define partial function"
   313     ((mode -- (Parse.fixes -- (Parse.where_ |-- Parse_Spec.spec)))
   314       >> (fn (mode, (fixes, spec)) => add_partial_function_cmd mode fixes spec));
   315 
   316 end;