src/HOL/Library/simps_case_conv.ML
changeset 69568 de09a7261120
parent 63352 4eaf35781b23
child 69593 3dda49e08b9d
     1.1 --- a/src/HOL/Library/simps_case_conv.ML	Sun Dec 30 10:30:41 2018 +0100
     1.2 +++ b/src/HOL/Library/simps_case_conv.ML	Tue Jan 01 17:04:53 2019 +0100
     1.3 @@ -31,88 +31,6 @@
     1.4  
     1.5  val strip_eq = Thm.prop_of #> HOLogic.dest_Trueprop #> HOLogic.dest_eq
     1.6  
     1.7 -
     1.8 -local
     1.9 -
    1.10 -  fun transpose [] = []
    1.11 -    | transpose ([] :: xss) = transpose xss
    1.12 -    | transpose xss = map hd xss :: transpose (map tl xss);
    1.13 -
    1.14 -  fun same_fun single_ctrs (ts as _ $ _ :: _) =
    1.15 -      let
    1.16 -        val (fs, argss) = map strip_comb ts |> split_list
    1.17 -        val f = hd fs
    1.18 -        fun is_single_ctr (Const (name, _)) = member (op =) single_ctrs name
    1.19 -          | is_single_ctr _ = false
    1.20 -      in if not (is_single_ctr f) andalso forall (fn x => f = x) fs then SOME (f, argss) else NONE end
    1.21 -    | same_fun _ _ = NONE
    1.22 -
    1.23 -  (* pats must be non-empty *)
    1.24 -  fun split_pat single_ctrs pats ctxt =
    1.25 -      case same_fun single_ctrs pats of
    1.26 -        NONE =>
    1.27 -          let
    1.28 -            val (name, ctxt') = yield_singleton Variable.variant_fixes "x" ctxt
    1.29 -            val var = Free (name, fastype_of (hd pats))
    1.30 -          in (((var, [var]), map single pats), ctxt') end
    1.31 -      | SOME (f, argss) =>
    1.32 -          let
    1.33 -            val (((def_pats, def_frees), case_patss), ctxt') =
    1.34 -              split_pats single_ctrs argss ctxt
    1.35 -            val def_pat = list_comb (f, def_pats)
    1.36 -          in (((def_pat, flat def_frees), case_patss), ctxt') end
    1.37 -  and
    1.38 -      split_pats single_ctrs patss ctxt =
    1.39 -        let
    1.40 -          val (splitted, ctxt') = fold_map (split_pat single_ctrs) (transpose patss) ctxt
    1.41 -          val r = splitted |> split_list |> apfst split_list |> apsnd (transpose #> map flat)
    1.42 -        in (r, ctxt') end
    1.43 -
    1.44 -(*
    1.45 -  Takes a list lhss of left hand sides (which are lists of patterns)
    1.46 -  and a list rhss of right hand sides. Returns
    1.47 -    - a single equation with a (nested) case-expression on the rhs
    1.48 -    - a list of all split-thms needed to split the rhs
    1.49 -  Patterns which have the same outer context in all lhss remain
    1.50 -  on the lhs of the computed equation.
    1.51 -*)
    1.52 -fun build_case_t fun_t lhss rhss ctxt =
    1.53 -  let
    1.54 -    val single_ctrs =
    1.55 -      get_type_infos ctxt (map fastype_of (flat lhss))
    1.56 -      |> map_filter (fn ti => case #ctrs ti of [Const (name, _)] => SOME name | _ => NONE)
    1.57 -    val (((def_pats, def_frees), case_patss), ctxt') =
    1.58 -      split_pats single_ctrs lhss ctxt
    1.59 -    val pattern = map HOLogic.mk_tuple case_patss
    1.60 -    val case_arg = HOLogic.mk_tuple (flat def_frees)
    1.61 -    val cases = Case_Translation.make_case ctxt' Case_Translation.Warning Name.context
    1.62 -      case_arg (pattern ~~ rhss)
    1.63 -    val split_thms = get_split_ths ctxt' [fastype_of case_arg]
    1.64 -    val t = (list_comb (fun_t, def_pats), cases)
    1.65 -      |> HOLogic.mk_eq
    1.66 -      |> HOLogic.mk_Trueprop
    1.67 -  in ((t, split_thms), ctxt') end
    1.68 -
    1.69 -fun tac ctxt {splits, intros, defs} =
    1.70 -  let val ctxt' = Classical.addSIs (ctxt, intros) in
    1.71 -    REPEAT_DETERM1 (FIRSTGOAL (split_tac ctxt splits))
    1.72 -    THEN Local_Defs.unfold_tac ctxt defs
    1.73 -    THEN safe_tac ctxt'
    1.74 -  end
    1.75 -
    1.76 -fun import [] ctxt = ([], ctxt)
    1.77 -  | import (thm :: thms) ctxt =
    1.78 -    let
    1.79 -      val fun_ct = strip_eq #> fst #> strip_comb #> fst #> Logic.mk_term
    1.80 -        #> Thm.cterm_of ctxt
    1.81 -      val ct = fun_ct thm
    1.82 -      val cts = map fun_ct thms
    1.83 -      val pairs = map (fn s => (s,ct)) cts
    1.84 -      val thms' = map (fn (th,p) => Thm.instantiate (Thm.match p) th) (thms ~~ pairs)
    1.85 -    in Variable.import true (thm :: thms') ctxt |> apfst snd end
    1.86 -
    1.87 -in
    1.88 -
    1.89  (*
    1.90    For a list
    1.91      f p_11 ... p_1n = t1
    1.92 @@ -122,39 +40,24 @@
    1.93    of theorems, prove a single theorem
    1.94      f x1 ... xn = t
    1.95    where t is a (nested) case expression. f must not be a function
    1.96 -  application. Moreover, the terms p_11, ..., p_mn must be non-overlapping
    1.97 -  datatype patterns. The patterns must be exhausting up to common constructor
    1.98 -  contexts.
    1.99 +  application.
   1.100  *)
   1.101  fun to_case ctxt ths =
   1.102    let
   1.103 -    val (iths, ctxt') = import ths ctxt
   1.104 -    val fun_t = hd iths |> strip_eq |> fst |> head_of
   1.105 -    val eqs = map (strip_eq #> apfst (snd o strip_comb)) iths
   1.106 -
   1.107 -    fun hide_rhs ((pat, rhs), name) lthy =
   1.108 +    fun ctr_count (ctr, T) = 
   1.109        let
   1.110 -        val frees = fold Term.add_frees pat []
   1.111 -        val abs_rhs = fold absfree frees rhs
   1.112 -        val ([(f, (_, def))], lthy') = lthy
   1.113 -          |> Local_Defs.define [((Binding.name name, NoSyn), (Binding.empty_atts, abs_rhs))]
   1.114 -      in ((list_comb (f, map Free (rev frees)), def), lthy') end
   1.115 -
   1.116 -    val ((def_ts, def_thms), ctxt2) =
   1.117 -      let val names = Name.invent (Variable.names_of ctxt') "rhs" (length eqs)
   1.118 -      in fold_map hide_rhs (eqs ~~ names) ctxt' |> apfst split_list end
   1.119 -
   1.120 -    val ((t, split_thms), ctxt3) = build_case_t fun_t (map fst eqs) def_ts ctxt2
   1.121 -
   1.122 -    val th = Goal.prove ctxt3 [] [] t (fn {context=ctxt, ...} =>
   1.123 -          tac ctxt {splits=split_thms, intros=ths, defs=def_thms})
   1.124 -  in th
   1.125 -    |> singleton (Proof_Context.export ctxt3 ctxt)
   1.126 -    |> Goal.norm_result ctxt
   1.127 +        val tyco = body_type T |> dest_Type |> fst
   1.128 +        val info = Ctr_Sugar.ctr_sugar_of ctxt tyco
   1.129 +        val _ = if is_none info then error ("Pattern match on non-constructor constant " ^ ctr) else ()
   1.130 +      in
   1.131 +        info |> the |> #ctrs |> length
   1.132 +      end
   1.133 +    val thms = Case_Converter.to_case ctxt Case_Converter.keep_constructor_context ctr_count ths
   1.134 +  in
   1.135 +    case thms of SOME thms => hd thms
   1.136 +      | _ => error ("Conversion to case expression failed.")
   1.137    end
   1.138  
   1.139 -end
   1.140 -
   1.141  local
   1.142  
   1.143  fun was_split t =