src/HOL/Library/simps_case_conv.ML
author wenzelm
Tue Dec 31 14:29:16 2013 +0100 (2013-12-31)
changeset 54883 dd04a8b654fc
parent 54401 f6950854855d
child 55997 9dc5ce83202c
permissions -rw-r--r--
proper context for norm_hhf and derived operations;
clarified tool context in some boundary cases;
     1 (*  Title:      HOL/Library/simps_case_conv.ML
     2     Author:     Lars Noschinski, TU Muenchen
     3     Author:     Gerwin Klein, NICTA
     4 
     5 Convert function specifications between the representation as a list
     6 of equations (with patterns on the lhs) and a single equation (with a
     7 nested case expression on the rhs).
     8 *)
     9 
signature SIMPS_CASE_CONV =
sig
  (* Combine several pattern-matching equations of one function into a
     single equation whose rhs is a (nested) case expression. *)
  val to_case: Proof.context -> thm list -> thm
  (* Split the case expressions of an equation with the given split rules
     and return the resulting theorems. *)
  val gen_to_simps: Proof.context -> thm list -> thm -> thm list
  (* Like gen_to_simps, using the split rules (as known to Ctr_Sugar) of the
     type constructors occurring in the type of the defined function. *)
  val to_simps: Proof.context -> thm -> thm list
end
    16 
    17 structure Simps_Case_Conv: SIMPS_CASE_CONV =
    18 struct
    19 
    20 (* Collects all type constructors in a type *)
    21 fun collect_Tcons (Type (name,Ts)) = name :: maps collect_Tcons Ts
    22   | collect_Tcons (TFree _) = []
    23   | collect_Tcons (TVar _) = []
    24 
    25 fun get_split_ths ctxt = collect_Tcons
    26     #> distinct (op =)
    27     #> map_filter (Ctr_Sugar.ctr_sugar_of ctxt)
    28     #> map #split
    29 
    30 val strip_eq = prop_of #> HOLogic.dest_Trueprop #> HOLogic.dest_eq
    31 
    32 
    33 local
    34 
    35   fun transpose [] = []
    36     | transpose ([] :: xss) = transpose xss
    37     | transpose xss = map hd xss :: transpose (map tl xss);
    38 
    39   fun same_fun (ts as _ $ _ :: _) =
    40       let
    41         val (fs, argss) = map strip_comb ts |> split_list
    42         val f = hd fs
    43       in if forall (fn x => f = x) fs then SOME (f, argss) else NONE end
    44     | same_fun _ = NONE
    45 
  (* split_pat pats ctxt: pats are the terms at one argument position,
     taken from all equations; pats must be non-empty (hd pats is used).
     Returns (((lhs_pat, frees), case_patss), ctxt') where
       - lhs_pat is the term kept at this position on the lhs of the
         combined equation,
       - frees are the fresh variables occurring in lhs_pat (build_case_t
         tuples them into the case scrutinee),
       - case_patss holds, per equation, the patterns matched against frees.
     If all pats apply the same head (same_fun), the head stays on the lhs
     and the arguments are split recursively via split_pats. Otherwise one
     fresh variable (a variant of "x", fixed in ctxt', typed like the first
     pattern) replaces the whole position and each pattern becomes a
     one-element case pattern list. *)
  fun split_pat pats ctxt =
      case same_fun pats of
        NONE =>
          let
            val (name, ctxt') = yield_singleton Variable.variant_fixes "x" ctxt
            val var = Free (name, fastype_of (hd pats))
          in (((var, [var]), map single pats), ctxt') end
      | SOME (f, argss) =>
          let
            val (((def_pats, def_frees), case_patss), ctxt') =
              split_pats argss ctxt
            val def_pat = list_comb (f, def_pats)
          in (((def_pat, flat def_frees), case_patss), ctxt') end
  and
      (* split_pats patss ctxt: patss holds, per equation, the list of its
         argument patterns. Each argument position (a column of patss) is
         handled by split_pat, threading the context. The results are then
         regrouped: def_pats and def_frees are per position, and the case
         patterns are per equation again (flattened over all positions). *)
      split_pats patss ctxt =
        let
          val (splitted, ctxt') = fold_map split_pat (transpose patss) ctxt
          val r = splitted |> split_list |> apfst split_list |> apsnd (transpose #> map flat)
        in (r, ctxt') end
    66 
    67 (*
    68   Takes a list lhss of left hand sides (which are lists of patterns)
    69   and a list rhss of right hand sides. Returns
    70     - a single equation with a (nested) case-expression on the rhs
    71     - a list of all split-thms needed to split the rhs
    72   Patterns which have the same outer context in all lhss remain
    73   on the lhs of the computed equation.
    74 *)
    75 fun build_case_t fun_t lhss rhss ctxt =
    76   let
    77     val (((def_pats, def_frees), case_patss), ctxt') =
    78       split_pats lhss ctxt
    79     val pattern = map HOLogic.mk_tuple case_patss
    80     val case_arg = HOLogic.mk_tuple (flat def_frees)
    81     val cases = Case_Translation.make_case ctxt' Case_Translation.Warning Name.context
    82       case_arg (pattern ~~ rhss)
    83     val split_thms = get_split_ths ctxt' (fastype_of case_arg)
    84     val t = (list_comb (fun_t, def_pats), cases)
    85       |> HOLogic.mk_eq
    86       |> HOLogic.mk_Trueprop
    87   in ((t, split_thms), ctxt') end
    88 
    89 fun tac ctxt {splits, intros, defs} =
    90   let val ctxt' = Classical.addSIs (ctxt, intros) in
    91     REPEAT_DETERM1 (FIRSTGOAL (split_tac splits))
    92     THEN Local_Defs.unfold_tac ctxt defs
    93     THEN safe_tac ctxt'
    94   end
    95 
    96 fun import [] ctxt = ([], ctxt)
    97   | import (thm :: thms) ctxt =
    98     let
    99       val fun_ct = strip_eq #> fst #> strip_comb #> fst #> Logic.mk_term
   100         #> Thm.cterm_of (Proof_Context.theory_of ctxt)
   101       val ct = fun_ct thm
   102       val cts = map fun_ct thms
   103       val pairs = map (fn s => (s,ct)) cts
   104       val thms' = map (fn (th,p) => Thm.instantiate (Thm.match p) th) (thms ~~ pairs)
   105     in Variable.import true (thm :: thms') ctxt |> apfst snd end
   106 
   107 in
   108 
   109 (*
   110   For a list
   111     f p_11 ... p_1n = t1
   112     f p_21 ... p_2n = t2
   113     ...
   114     f p_mn ... p_mn = tm
   115   of theorems, prove a single theorem
   116     f x1 ... xn = t
   117   where t is a (nested) case expression. f must not be a function
   118   application. Moreover, the terms p_11, ..., p_mn must be non-overlapping
   119   datatype patterns. The patterns must be exhausting up to common constructor
   120   contexts.
   121 *)
   122 fun to_case ctxt ths =
   123   let
   124     val (iths, ctxt') = import ths ctxt
   125     val fun_t = hd iths |> strip_eq |> fst |> head_of
   126     val eqs = map (strip_eq #> apfst (snd o strip_comb)) iths
   127 
   128     fun hide_rhs ((pat, rhs), name) lthy = let
   129         val frees = fold Term.add_frees pat []
   130         val abs_rhs = fold absfree frees rhs
   131         val ((f,def), lthy') = Local_Defs.add_def
   132           ((Binding.name name, Mixfix.NoSyn), abs_rhs) lthy
   133       in ((list_comb (f, map Free (rev frees)), def), lthy') end
   134 
   135     val ((def_ts, def_thms), ctxt2) = let
   136         val nctxt = Variable.names_of ctxt'
   137         val names = Name.invent nctxt "rhs" (length eqs)
   138       in fold_map hide_rhs (eqs ~~ names) ctxt' |> apfst split_list end
   139 
   140     val ((t, split_thms), ctxt3) = build_case_t fun_t (map fst eqs) def_ts ctxt2
   141 
   142     val th = Goal.prove ctxt3 [] [] t (fn {context=ctxt, ...} =>
   143           tac ctxt {splits=split_thms, intros=ths, defs=def_thms})
   144   in th
   145     |> singleton (Proof_Context.export ctxt3 ctxt)
   146     |> Goal.norm_result ctxt
   147   end
   148 
   149 end
   150 
   151 local
   152 
   153 fun was_split t =
   154   let
   155     val is_free_eq_imp = is_Free o fst o HOLogic.dest_eq
   156               o fst o HOLogic.dest_imp
   157     val get_conjs = HOLogic.dest_conj o HOLogic.dest_Trueprop
   158     fun dest_alls (Const ("HOL.All", _) $ Abs (_, _, t)) = dest_alls t
   159       | dest_alls t = t
   160   in forall (is_free_eq_imp o dest_alls) (get_conjs t) end
   161         handle TERM _ => false
   162 
   163 fun apply_split ctxt split thm = Seq.of_list
   164   let val ((_,thm'), ctxt') = Variable.import false [thm] ctxt in
   165     (Variable.export ctxt' ctxt) (filter (was_split o prop_of) (thm' RL [split]))
   166   end
   167 
   168 fun forward_tac rules t = Seq.of_list ([t] RL rules)
   169 
   170 val refl_imp = refl RSN (2, mp)
   171 
   172 val get_rules_once_split =
   173   REPEAT (forward_tac [conjunct1, conjunct2])
   174     THEN REPEAT (forward_tac [spec])
   175     THEN (forward_tac [refl_imp])
   176 
   177 fun do_split ctxt split =
   178   let
   179     val split' = split RS iffD1;
   180     val split_rhs = concl_of (hd (snd (fst (Variable.import false [split'] ctxt))))
   181   in if was_split split_rhs
   182      then DETERM (apply_split ctxt split') THEN get_rules_once_split
   183      else raise TERM ("malformed split rule", [split_rhs])
   184   end
   185 
   186 val atomize_meta_eq = forward_tac [meta_eq_to_obj_eq]
   187 
   188 in
   189 
   190 fun gen_to_simps ctxt splitthms thm =
   191   Seq.list_of ((TRY atomize_meta_eq
   192                  THEN (REPEAT (FIRST (map (do_split ctxt) splitthms)))) thm)
   193 
   194 fun to_simps ctxt thm =
   195   let
   196     val T = thm |> strip_eq |> fst |> strip_comb |> fst |> fastype_of
   197     val splitthms = get_split_ths ctxt T
   198   in gen_to_simps ctxt splitthms thm end
   199 
   200 
   201 end
   202 
   203 fun case_of_simps_cmd (bind, thms_ref) lthy =
   204   let
   205     val thy = Proof_Context.theory_of lthy
   206     val bind' = apsnd (map (Attrib.intern_src thy)) bind
   207     val thm = (Attrib.eval_thms lthy) thms_ref |> to_case lthy
   208   in
   209     Local_Theory.note (bind', [thm]) lthy |> snd
   210   end
   211 
   212 fun simps_of_case_cmd ((bind, thm_ref), splits_ref) lthy =
   213   let
   214     val thy = Proof_Context.theory_of lthy
   215     val bind' = apsnd (map (Attrib.intern_src thy)) bind
   216     val thm = singleton (Attrib.eval_thms lthy) thm_ref
   217     val simps = if null splits_ref
   218       then to_simps lthy thm
   219       else gen_to_simps lthy (Attrib.eval_thms lthy splits_ref) thm
   220   in
   221     Local_Theory.note (bind', simps) lthy |> snd
   222   end
   223 
   224 val _ =
   225   Outer_Syntax.local_theory @{command_spec "case_of_simps"}
   226     "turns a list of equations into a case expression"
   227     (Parse_Spec.opt_thm_name ":"  -- Parse_Spec.xthms1 >> case_of_simps_cmd)
   228 
(* Parser for the argument "(splits: facts)"; yields the fact references
   given after "splits:". *)
val parse_splits = @{keyword "("} |-- Parse.reserved "splits" |-- @{keyword ":"} |--
  Parse_Spec.xthms1 --| @{keyword ")"}
   231 
   232 val _ =
   233   Outer_Syntax.local_theory @{command_spec "simps_of_case"}
   234     "perform case split on rule"
   235     (Parse_Spec.opt_thm_name ":"  -- Parse_Spec.xthm --
   236       Scan.optional parse_splits [] >> simps_of_case_cmd)
   237 
   238 end
   239