src/HOL/Library/simps_case_conv.ML
changeset 53429 9d9945941eab
parent 53426 92db671e0ac6
child 53433 3b356b7f7cad
equal deleted inserted replaced
53428:3083c611ec40 53429:9d9945941eab
    30 val strip_eq = prop_of #> HOLogic.dest_Trueprop #> HOLogic.dest_eq
    30 val strip_eq = prop_of #> HOLogic.dest_Trueprop #> HOLogic.dest_eq
    31 
    31 
    32 
    32 
    33 local
    33 local
    34 
    34 
    35 (*Creates free variables for a list of types*)
    35   fun transpose [] = []
    36 fun mk_Frees Ts ctxt =
    36     | transpose ([] :: xss) = transpose xss
    37   let
    37     | transpose xss = map hd xss :: transpose (map tl xss);
    38     val (names,ctxt') = Variable.variant_fixes (replicate (length Ts) "x") ctxt
    38 
    39     val ts = map Free (names ~~ Ts)
    39   fun same_fun (ts as _ $ _ :: _) =
    40   in (ts, ctxt') end
    40       let
       
    41         val (fs, argss) = map strip_comb ts |> split_list
       
    42         val f = hd fs
       
    43       in if forall (fn x => f = x) fs then SOME (f, argss) else NONE end
       
    44     | same_fun _ = NONE
       
    45 
       
    46   (* pats must be non-empty *)
       
    47   fun split_pat pats ctxt =
       
    48       case same_fun pats of
       
    49         NONE =>
       
    50           let
       
    51             val (name, ctxt') = yield_singleton Variable.variant_fixes "x" ctxt
       
    52             val var = Free (name, fastype_of (hd pats))
       
    53           in (((var, [var]), map single pats), ctxt') end
       
    54       | SOME (f, argss) =>
       
    55           let
       
    56             val (((def_pats, def_frees), case_patss), ctxt') =
       
    57               split_pats argss ctxt
       
    58             val def_pat = list_comb (f, def_pats)
       
    59           in (((def_pat, flat def_frees), case_patss), ctxt') end
       
    60   and
       
    61       split_pats patss ctxt =
       
    62         let
       
    63           val (splitted, ctxt') = fold_map split_pat (transpose patss) ctxt
       
    64           val r = splitted |> split_list |> apfst split_list |> apsnd (transpose #> map flat)
       
    65         in (r, ctxt') end
       
    66 
       
    67 (*
       
    68   Takes a list lhss of left hand sides (which are lists of patterns)
       
    69   and a list rhss of right hand sides. Returns
       
    70     - a single equation with a (nested) case-expression on the rhs
       
    71     - a list of all split-thms needed to split the rhs
       
    72   Patterns which have the same outer context in all lhss remain
       
    73   on the lhs of the computed equation.
       
    74 *)
       
    75 fun build_case_t fun_t lhss rhss ctxt =
       
    76   let
       
    77     val (((def_pats, def_frees), case_patss), ctxt') =
       
    78       split_pats lhss ctxt
       
    79     val pattern = map HOLogic.mk_tuple case_patss
       
    80     val case_arg = HOLogic.mk_tuple (flat def_frees)
       
    81     val cases = Case_Translation.make_case ctxt' Case_Translation.Warning Name.context
       
    82       case_arg (pattern ~~ rhss)
       
    83     val split_thms = get_split_ths (Proof_Context.theory_of ctxt') (fastype_of case_arg)
       
    84     val t = (list_comb (fun_t, def_pats), cases)
       
    85       |> HOLogic.mk_eq
       
    86       |> HOLogic.mk_Trueprop
       
    87   in ((t, split_thms), ctxt') end
    41 
    88 
    42 fun tac ctxt {splits, intros, defs} =
    89 fun tac ctxt {splits, intros, defs} =
    43   let val ctxt' = Classical.addSIs (ctxt, intros) in
    90   let val ctxt' = Classical.addSIs (ctxt, intros) in
    44     REPEAT_DETERM1 (FIRSTGOAL (split_tac splits))
    91     REPEAT_DETERM1 (FIRSTGOAL (split_tac splits))
    45     THEN Local_Defs.unfold_tac ctxt defs
    92     THEN Local_Defs.unfold_tac ctxt defs
    65     f p_21 ... p_2n = t2
   112     f p_21 ... p_2n = t2
    66     ...
   113     ...
    67     f p_mn ... p_mn = tm
   114     f p_mn ... p_mn = tm
    68   of theorems, prove a single theorem
   115   of theorems, prove a single theorem
    69     f x1 ... xn = t
   116     f x1 ... xn = t
    70   where t is a (nested) case expression. The terms p_11, ..., p_mn must
   117   where t is a (nested) case expression. f must not be a function
    71   be exhaustive, non-overlapping datatype patterns. f must not be a function
   118   application. Moreover, the terms p_11, ..., p_mn must be non-overlapping
    72   application.
   119   datatype patterns. The patterns must be exhausting up to common constructor
       
   120   contexts.
    73 *)
   121 *)
    74 fun to_case ctxt ths =
   122 fun to_case ctxt ths =
    75   let
   123   let
    76     val (iths, ctxt') = import ths ctxt
   124     val (iths, ctxt') = import ths ctxt
    77     val (fun_t, arg_ts) = hd iths |> strip_eq |> fst |> strip_comb
   125     val fun_t = hd iths |> strip_eq |> fst |> head_of
    78     val eqs = map (strip_eq #> apfst (snd o strip_comb)) iths
   126     val eqs = map (strip_eq #> apfst (snd o strip_comb)) iths
    79     val (arg_Frees, ctxt'') = mk_Frees (map fastype_of arg_ts) ctxt'
       
    80 
   127 
    81     fun hide_rhs ((pat, rhs), name) lthy = let
   128     fun hide_rhs ((pat, rhs), name) lthy = let
    82         val frees = fold Term.add_frees pat []
   129         val frees = fold Term.add_frees pat []
    83         val abs_rhs = fold absfree frees rhs
   130         val abs_rhs = fold absfree frees rhs
    84         val ((f,def), lthy') = Local_Defs.add_def
   131         val ((f,def), lthy') = Local_Defs.add_def
    85           ((Binding.name name, Mixfix.NoSyn), abs_rhs) lthy
   132           ((Binding.name name, Mixfix.NoSyn), abs_rhs) lthy
    86       in ((list_comb (f, map Free (rev frees)), def), lthy') end
   133       in ((list_comb (f, map Free (rev frees)), def), lthy') end
    87 
   134 
    88     val ((def_ts, def_thms), ctxt3) = let
   135     val ((def_ts, def_thms), ctxt2) = let
    89         val nctxt = Variable.names_of ctxt''
   136         val nctxt = Variable.names_of ctxt'
    90         val names = Name.invent nctxt "rhs" (length eqs)
   137         val names = Name.invent nctxt "rhs" (length eqs)
    91       in fold_map hide_rhs (eqs ~~ names) ctxt'' |> apfst split_list end
   138       in fold_map hide_rhs (eqs ~~ names) ctxt' |> apfst split_list end
    92 
   139 
    93     val (cases, split_thms) =
   140     val ((t, split_thms), ctxt3) = build_case_t fun_t (map fst eqs) def_ts ctxt2
    94       let
   141 
    95         val pattern = map (fst #> HOLogic.mk_tuple) eqs
       
    96         val case_arg = HOLogic.mk_tuple arg_Frees
       
    97         val cases = Case_Translation.make_case ctxt Case_Translation.Warning Name.context
       
    98           case_arg (pattern ~~ def_ts)
       
    99         val split_thms = get_split_ths (Proof_Context.theory_of ctxt3) (fastype_of case_arg)
       
   100       in (cases, split_thms) end
       
   101 
       
   102     val t = (list_comb (fun_t, arg_Frees), cases)
       
   103       |> HOLogic.mk_eq
       
   104       |> HOLogic.mk_Trueprop
       
   105     val th = Goal.prove ctxt3 [] [] t (fn {context=ctxt, ...} =>
   142     val th = Goal.prove ctxt3 [] [] t (fn {context=ctxt, ...} =>
   106           tac ctxt {splits=split_thms, intros=ths, defs=def_thms})
   143           tac ctxt {splits=split_thms, intros=ths, defs=def_thms})
   107   in th
   144   in th
   108     |> singleton (Proof_Context.export ctxt3 ctxt)
   145     |> singleton (Proof_Context.export ctxt3 ctxt)
   109     |> Goal.norm_result
   146     |> Goal.norm_result