src/HOL/Library/simps_case_conv.ML
author wenzelm
Mon Dec 28 17:43:30 2015 +0100 (2015-12-28)
changeset 61952 546958347e05
parent 61813 b84688dd7f6b
child 62969 9f394a16c557
permissions -rw-r--r--
prefer symbols for "Union", "Inter";
     1 (*  Title:      HOL/Library/simps_case_conv.ML
     2     Author:     Lars Noschinski, TU Muenchen
     3     Author:     Gerwin Klein, NICTA
     4 
     5 Convert function specifications between the representation as a list
     6 of equations (with patterns on the lhs) and a single equation (with a
     7 nested case expression on the rhs).
     8 *)
     9 
    10 signature SIMPS_CASE_CONV =
    11 sig
    12   val to_case: Proof.context -> thm list -> thm
    13   val gen_to_simps: Proof.context -> thm list -> thm -> thm list
    14   val to_simps: Proof.context -> thm -> thm list
    15 end
    16 
    17 structure Simps_Case_Conv: SIMPS_CASE_CONV =
    18 struct
    19 
    20 (* Collects all type constructors in a type *)
    21 fun collect_Tcons (Type (name,Ts)) = name :: maps collect_Tcons Ts
    22   | collect_Tcons (TFree _) = []
    23   | collect_Tcons (TVar _) = []
    24 
    25 fun get_type_infos ctxt =
    26     maps collect_Tcons
    27     #> distinct (op =)
    28     #> map_filter (Ctr_Sugar.ctr_sugar_of ctxt)
    29 
    30 fun get_split_ths ctxt = get_type_infos ctxt #> map #split
    31 
    32 val strip_eq = Thm.prop_of #> HOLogic.dest_Trueprop #> HOLogic.dest_eq
    33 
    34 
    35 local
    36 
    37   fun transpose [] = []
    38     | transpose ([] :: xss) = transpose xss
    39     | transpose xss = map hd xss :: transpose (map tl xss);
    40 
    41   fun same_fun single_ctrs (ts as _ $ _ :: _) =
    42       let
    43         val (fs, argss) = map strip_comb ts |> split_list
    44         val f = hd fs
    45         fun is_single_ctr (Const (name, _)) = member (op =) single_ctrs name
    46           | is_single_ctr _ = false
    47       in if not (is_single_ctr f) andalso forall (fn x => f = x) fs then SOME (f, argss) else NONE end
    48     | same_fun _ _ = NONE
    49 
    50   (* pats must be non-empty *)
    51   fun split_pat single_ctrs pats ctxt =
    52       case same_fun single_ctrs pats of
    53         NONE =>
    54           let
    55             val (name, ctxt') = yield_singleton Variable.variant_fixes "x" ctxt
    56             val var = Free (name, fastype_of (hd pats))
    57           in (((var, [var]), map single pats), ctxt') end
    58       | SOME (f, argss) =>
    59           let
    60             val (((def_pats, def_frees), case_patss), ctxt') =
    61               split_pats single_ctrs argss ctxt
    62             val def_pat = list_comb (f, def_pats)
    63           in (((def_pat, flat def_frees), case_patss), ctxt') end
    64   and
    65       split_pats single_ctrs patss ctxt =
    66         let
    67           val (splitted, ctxt') = fold_map (split_pat single_ctrs) (transpose patss) ctxt
    68           val r = splitted |> split_list |> apfst split_list |> apsnd (transpose #> map flat)
    69         in (r, ctxt') end
    70 
    71 (*
    72   Takes a list lhss of left hand sides (which are lists of patterns)
    73   and a list rhss of right hand sides. Returns
    74     - a single equation with a (nested) case-expression on the rhs
    75     - a list of all split-thms needed to split the rhs
    76   Patterns which have the same outer context in all lhss remain
    77   on the lhs of the computed equation.
    78 *)
    79 fun build_case_t fun_t lhss rhss ctxt =
    80   let
    81     val single_ctrs =
    82       get_type_infos ctxt (map fastype_of (flat lhss))
    83       |> map_filter (fn ti => case #ctrs ti of [Const (name, _)] => SOME name | _ => NONE)
    84     val (((def_pats, def_frees), case_patss), ctxt') =
    85       split_pats single_ctrs lhss ctxt
    86     val pattern = map HOLogic.mk_tuple case_patss
    87     val case_arg = HOLogic.mk_tuple (flat def_frees)
    88     val cases = Case_Translation.make_case ctxt' Case_Translation.Warning Name.context
    89       case_arg (pattern ~~ rhss)
    90     val split_thms = get_split_ths ctxt' [fastype_of case_arg]
    91     val t = (list_comb (fun_t, def_pats), cases)
    92       |> HOLogic.mk_eq
    93       |> HOLogic.mk_Trueprop
    94   in ((t, split_thms), ctxt') end
    95 
    96 fun tac ctxt {splits, intros, defs} =
    97   let val ctxt' = Classical.addSIs (ctxt, intros) in
    98     REPEAT_DETERM1 (FIRSTGOAL (split_tac ctxt splits))
    99     THEN Local_Defs.unfold_tac ctxt defs
   100     THEN safe_tac ctxt'
   101   end
   102 
   103 fun import [] ctxt = ([], ctxt)
   104   | import (thm :: thms) ctxt =
   105     let
   106       val fun_ct = strip_eq #> fst #> strip_comb #> fst #> Logic.mk_term
   107         #> Thm.cterm_of ctxt
   108       val ct = fun_ct thm
   109       val cts = map fun_ct thms
   110       val pairs = map (fn s => (s,ct)) cts
   111       val thms' = map (fn (th,p) => Thm.instantiate (Thm.match p) th) (thms ~~ pairs)
   112     in Variable.import true (thm :: thms') ctxt |> apfst snd end
   113 
   114 in
   115 
   116 (*
   117   For a list
   118     f p_11 ... p_1n = t1
   119     f p_21 ... p_2n = t2
   120     ...
   121     f p_mn ... p_mn = tm
   122   of theorems, prove a single theorem
   123     f x1 ... xn = t
   124   where t is a (nested) case expression. f must not be a function
   125   application. Moreover, the terms p_11, ..., p_mn must be non-overlapping
   126   datatype patterns. The patterns must be exhausting up to common constructor
   127   contexts.
   128 *)
   129 fun to_case ctxt ths =
   130   let
   131     val (iths, ctxt') = import ths ctxt
   132     val fun_t = hd iths |> strip_eq |> fst |> head_of
   133     val eqs = map (strip_eq #> apfst (snd o strip_comb)) iths
   134 
   135     fun hide_rhs ((pat, rhs), name) lthy =
   136       let
   137         val frees = fold Term.add_frees pat []
   138         val abs_rhs = fold absfree frees rhs
   139         val ((f,def), lthy') = Local_Defs.add_def
   140           ((Binding.name name, Mixfix.NoSyn), abs_rhs) lthy
   141       in ((list_comb (f, map Free (rev frees)), def), lthy') end
   142 
   143     val ((def_ts, def_thms), ctxt2) =
   144       let val names = Name.invent (Variable.names_of ctxt') "rhs" (length eqs)
   145       in fold_map hide_rhs (eqs ~~ names) ctxt' |> apfst split_list end
   146 
   147     val ((t, split_thms), ctxt3) = build_case_t fun_t (map fst eqs) def_ts ctxt2
   148 
   149     val th = Goal.prove ctxt3 [] [] t (fn {context=ctxt, ...} =>
   150           tac ctxt {splits=split_thms, intros=ths, defs=def_thms})
   151   in th
   152     |> singleton (Proof_Context.export ctxt3 ctxt)
   153     |> Goal.norm_result ctxt
   154   end
   155 
   156 end
   157 
   158 local
   159 
   160 fun was_split t =
   161   let
   162     val is_free_eq_imp = is_Free o fst o HOLogic.dest_eq o fst o HOLogic.dest_imp
   163     val get_conjs = HOLogic.dest_conj o HOLogic.dest_Trueprop
   164     fun dest_alls (Const (@{const_name All}, _) $ Abs (_, _, t)) = dest_alls t
   165       | dest_alls t = t
   166   in forall (is_free_eq_imp o dest_alls) (get_conjs t) end
   167   handle TERM _ => false
   168 
   169 fun apply_split ctxt split thm = Seq.of_list
   170   let val ((_,thm'), ctxt') = Variable.import false [thm] ctxt in
   171     (Variable.export ctxt' ctxt) (filter (was_split o Thm.prop_of) (thm' RL [split]))
   172   end
   173 
   174 fun forward_tac rules t = Seq.of_list ([t] RL rules)
   175 
   176 val refl_imp = refl RSN (2, mp)
   177 
   178 val get_rules_once_split =
   179   REPEAT (forward_tac [conjunct1, conjunct2])
   180     THEN REPEAT (forward_tac [spec])
   181     THEN (forward_tac [refl_imp])
   182 
   183 fun do_split ctxt split =
   184   case try op RS (split, iffD1) of
   185     NONE => raise TERM ("malformed split rule", [Thm.prop_of split])
   186   | SOME split' =>
   187       let val split_rhs = Thm.concl_of (hd (snd (fst (Variable.import false [split'] ctxt))))
   188       in if was_split split_rhs
   189          then DETERM (apply_split ctxt split') THEN get_rules_once_split
   190          else raise TERM ("malformed split rule", [split_rhs])
   191       end
   192 
   193 val atomize_meta_eq = forward_tac [meta_eq_to_obj_eq]
   194 
   195 in
   196 
   197 fun gen_to_simps ctxt splitthms thm =
   198   let val splitthms' = filter (fn t => not (Thm.eq_thm (t, Drule.dummy_thm))) splitthms
   199   in
   200     Seq.list_of ((TRY atomize_meta_eq THEN (REPEAT (FIRST (map (do_split ctxt) splitthms')))) thm)
   201   end
   202 
   203 fun to_simps ctxt thm =
   204   let
   205     val T = thm |> strip_eq |> fst |> strip_comb |> fst |> fastype_of
   206     val splitthms = get_split_ths ctxt [T]
   207   in gen_to_simps ctxt splitthms thm end
   208 
   209 
   210 end
   211 
   212 fun case_of_simps_cmd (bind, thms_ref) lthy =
   213   let
   214     val bind' = apsnd (map (Attrib.check_src lthy)) bind
   215     val thm = Attrib.eval_thms lthy thms_ref |> to_case lthy
   216   in
   217     Local_Theory.note (bind', [thm]) lthy |> snd
   218   end
   219 
   220 fun simps_of_case_cmd ((bind, thm_ref), splits_ref) lthy =
   221   let
   222     val bind' = apsnd (map (Attrib.check_src lthy)) bind
   223     val thm = singleton (Attrib.eval_thms lthy) thm_ref
   224     val simps = if null splits_ref
   225       then to_simps lthy thm
   226       else gen_to_simps lthy (Attrib.eval_thms lthy splits_ref) thm
   227   in
   228     Local_Theory.note (bind', simps) lthy |> snd
   229   end
   230 
   231 val _ =
   232   Outer_Syntax.local_theory @{command_keyword case_of_simps}
   233     "turn a list of equations into a case expression"
   234     (Parse_Spec.opt_thm_name ":"  -- Parse.xthms1 >> case_of_simps_cmd)
   235 
   236 val parse_splits = @{keyword "("} |-- Parse.reserved "splits" |-- @{keyword ":"} |--
   237   Parse.xthms1 --| @{keyword ")"}
   238 
   239 val _ =
   240   Outer_Syntax.local_theory @{command_keyword simps_of_case}
   241     "perform case split on rule"
   242     (Parse_Spec.opt_thm_name ":"  -- Parse.xthm --
   243       Scan.optional parse_splits [] >> simps_of_case_cmd)
   244 
   245 end
   246