src/HOL/Library/simps_case_conv.ML
author wenzelm
Wed, 22 Jun 2016 10:40:53 +0200
changeset 63344 c9910404cc8a
parent 62969 9f394a16c557
child 63345 70b2313f9c52
permissions -rw-r--r--
tuned signature; tuned;

(*  Title:      HOL/Library/simps_case_conv.ML
    Author:     Lars Noschinski, TU Muenchen
    Author:     Gerwin Klein, NICTA

Convert function specifications between the representation as a list
of equations (with patterns on the lhs) and a single equation (with a
nested case expression on the rhs).
*)

signature SIMPS_CASE_CONV =
sig
  (* combine a list of pattern equations into a single equation whose rhs is a case expression *)
  val to_case: Proof.context -> thm list -> thm
  (* split a case equation into a list of equations, using the given split rules *)
  val gen_to_simps: Proof.context -> thm list -> thm -> thm list
  (* like gen_to_simps, with the split rules looked up from the type of the function head *)
  val to_simps: Proof.context -> thm -> thm list
end

structure Simps_Case_Conv: SIMPS_CASE_CONV =
struct

(* Names of all type constructors occurring in a type, outermost first;
   a name is listed once per occurrence. Type variables contribute nothing. *)
fun collect_Tcons T =
  (case T of
    Type (name, Ts) => name :: maps collect_Tcons Ts
  | _ => [])

(* Constructor-sugar records for the type constructors occurring in Ts:
   every constructor name is looked up once, names without an entry are dropped *)
fun get_type_infos ctxt Ts =
  let val tycons = distinct (op =) (maps collect_Tcons Ts)
  in map_filter (Ctr_Sugar.ctr_sugar_of ctxt) tycons end

(* The split rule stored in the constructor-sugar record of each type constructor in Ts *)
fun get_split_ths ctxt Ts = map #split (get_type_infos ctxt Ts)

(* Both sides of a theorem whose proposition is a HOL equation under Trueprop *)
fun strip_eq thm = HOLogic.dest_eq (HOLogic.dest_Trueprop (Thm.prop_of thm))


local

  (* Transposes a list of rows into a list of columns. Leading empty rows are
     skipped; once the first row is non-empty, every row must be non-empty. *)
  fun transpose [] = []
    | transpose (rows as first :: rest) =
        if null first then transpose rest
        else map hd rows :: transpose (map tl rows);

  (* If the first term is an application and all terms have the same head, which is
     not a constant named in single_ctrs, returns that head together with the
     argument lists of all terms; otherwise NONE. *)
  fun same_fun single_ctrs (ts as (_ $ _) :: _) =
        let
          val (heads, argss) = split_list (map strip_comb ts)
          val head = hd heads
          val head_is_single_ctr =
            (case head of
              Const (c, _) => member (op =) single_ctrs c
            | _ => false)
        in
          if head_is_single_ctr then NONE
          else if forall (fn h => head = h) heads then SOME (head, argss)
          else NONE
        end
    | same_fun _ _ = NONE

  (* split_pat handles one argument position: pats are the patterns found at that
     position in all equations, and must be non-empty.
       - If same_fun finds a common head, the head stays in the default pattern and
         its arguments are split recursively.
       - Otherwise a fresh variable (based on the name "x") replaces the position,
         and each pattern becomes a case pattern of its own.
     The result is ((default pattern, fresh variables), case patterns per equation)
     together with the extended context. *)
  fun split_pat single_ctrs pats ctxt =
        (case same_fun single_ctrs pats of
          SOME (head, argss) =>
            let
              val (((arg_pats, arg_frees), case_patss), ctxt') =
                split_pats single_ctrs argss ctxt
            in (((list_comb (head, arg_pats), flat arg_frees), case_patss), ctxt') end
        | NONE =>
            let
              val (x, ctxt') = ctxt |> yield_singleton Variable.variant_fixes "x"
              val free = Free (x, fastype_of (hd pats))
            in (((free, [free]), map single pats), ctxt') end)
  (* split_pats handles all argument positions: patss has one pattern list per
     equation; it is transposed so that split_pat sees one position at a time, and
     the case patterns are regrouped per equation afterwards. *)
  and split_pats single_ctrs patss ctxt =
        let
          val (results, ctxt') = fold_map (split_pat single_ctrs) (transpose patss) ctxt
          val (defs, casess) = split_list results
          val def_part = split_list defs
          val case_part = map flat (transpose casess)
        in ((def_part, case_part), ctxt') end

(*
  Arguments: the function term fun_t, a list lhss of left hand sides (each a
  list of patterns) and the list rhss of the corresponding right hand sides.
  Result (paired with the extended context):
    - one equation whose rhs is a (nested) case expression
    - the split rules for the type of the case argument
  Pattern parts with the same outer context in all lhss stay on the lhs of
  the resulting equation; constructors of types with exactly one constructor
  are never kept there.
*)
fun build_case_t fun_t lhss rhss ctxt =
  let
    val pat_types = map fastype_of (flat lhss)
    val single_ctrs =
      map_filter (fn info => (case #ctrs info of [Const (c, _)] => SOME c | _ => NONE))
        (get_type_infos ctxt pat_types)
    val (((def_pats, def_frees), case_patss), ctxt') = split_pats single_ctrs lhss ctxt
    val case_arg = HOLogic.mk_tuple (flat def_frees)
    val clauses = map HOLogic.mk_tuple case_patss ~~ rhss
    val case_t =
      Case_Translation.make_case ctxt' Case_Translation.Warning Name.context case_arg clauses
    val split_thms = get_split_ths ctxt' [fastype_of case_arg]
    val lhs = list_comb (fun_t, def_pats)
    val eq = HOLogic.mk_Trueprop (HOLogic.mk_eq (lhs, case_t))
  in ((eq, split_thms), ctxt') end

(* Proof of the case equation: split with the given rules at least once, unfold
   the given definitions, then run safe_tac with intros added as safe introduction rules *)
fun tac ctxt {splits, intros, defs} =
  let
    val split_all = REPEAT_DETERM1 (FIRSTGOAL (split_tac ctxt splits))
    val unfold_defs = Local_Defs.unfold_tac ctxt defs
    val finish = safe_tac (Classical.addSIs (ctxt, intros))
  in split_all THEN unfold_defs THEN finish end

(* Imports the equations into the context. Before that, every theorem but the
   first is instantiated so that the head of its lhs matches the head of the
   lhs of the first theorem. *)
fun import [] ctxt = ([], ctxt)
  | import (thm :: thms) ctxt =
    let
      (* head of the lhs, wrapped by Logic.mk_term and certified *)
      fun head_ct th =
        let val (lhs, _) = strip_eq th
        in Thm.cterm_of ctxt (Logic.mk_term (head_of lhs)) end
      val target = head_ct thm
      val heads = map head_ct thms
      val thms' =
        map2 (fn th => fn head => Thm.instantiate (Thm.match (head, target)) th) thms heads
    in apfst snd (Variable.import true (thm :: thms') ctxt) end

in

(*
  Given a non-empty list of theorems
    f p_11 ... p_1n = t1
    f p_21 ... p_2n = t2
    ...
    f p_m1 ... p_mn = tm
  prove the single theorem
    f x1 ... xn = t
  where t is a (nested) case expression. f must not be a function
  application. Moreover, the terms p_11, ..., p_mn must be non-overlapping
  datatype patterns. The patterns must be exhausting up to common constructor
  contexts.

  Each right hand side is first hidden behind a local definition (named
  after "rhs") abstracted over the free variables of its patterns; these
  definitions are unfolded again inside the proof.
*)
fun to_case ctxt ths =
  let
    val (iths, ctxt') = import ths ctxt
    val fun_t = head_of (fst (strip_eq (hd iths)))
    val eqs =
      map (fn th => let val (lhs, rhs) = strip_eq th in (snd (strip_comb lhs), rhs) end) iths

    fun hide_rhs ((pats, rhs), name) lthy =
      let
        val frees = fold Term.add_frees pats []
        val closed_rhs = fold absfree frees rhs
        val spec = ((Binding.name name, Mixfix.NoSyn), (Thm.empty_binding, closed_rhs))
        val ([(f, (_, def))], lthy') = Local_Defs.define [spec] lthy
      in ((list_comb (f, map Free (rev frees)), def), lthy') end

    val rhs_names = Name.invent (Variable.names_of ctxt') "rhs" (length eqs)
    val (hidden, ctxt2) = fold_map hide_rhs (eqs ~~ rhs_names) ctxt'
    val (def_ts, def_thms) = split_list hidden

    val ((t, split_thms), ctxt3) = build_case_t fun_t (map fst eqs) def_ts ctxt2

    val th = Goal.prove ctxt3 [] [] t (fn {context = goal_ctxt, ...} =>
      tac goal_ctxt {splits = split_thms, intros = ths, defs = def_thms})
  in
    Goal.norm_result ctxt (singleton (Proof_Context.export ctxt3 ctxt) th)
  end

end

local

(* Checks that t is a Trueprop conjunction in which every conjunct, below any
   leading universal quantifiers, is an implication whose premise is an equation
   with a free variable on its lhs. Terms of any other shape yield false. *)
fun was_split t =
  let
    fun skip_alls (Const (@{const_name All}, _) $ Abs (_, _, body)) = skip_alls body
      | skip_alls body = body
    fun is_free_eq_imp conj =
      let
        val (prem, _) = HOLogic.dest_imp (skip_alls conj)
        val (lhs, _) = HOLogic.dest_eq prem
      in is_Free lhs end
  in forall is_free_eq_imp (HOLogic.dest_conj (HOLogic.dest_Trueprop t)) end
  handle TERM _ => false

(* Resolves the imported theorem with the split rule, keeps the results that
   have the shape accepted by was_split, and exports them back to ctxt *)
fun apply_split ctxt split thm =
  let
    val ((_, imported), ctxt') = Variable.import false [thm] ctxt
    val candidates = imported RL [split]
    val accepted = filter (fn th => was_split (Thm.prop_of th)) candidates
  in Seq.of_list (Variable.export ctxt' ctxt accepted) end

(* All results of resolving the theorem t with the given rules, as a sequence *)
fun forward_tac rules t = [t] RL rules |> Seq.of_list

(* refl resolved against the second premise of mp; presumably the resulting rule
   discharges an implication whose premise is a reflexive equation -- TODO confirm *)
val refl_imp = refl RSN (2, mp)

(* Takes a split result apart: repeatedly project conjuncts, repeatedly
   apply spec, and finally resolve once with refl_imp *)
val get_rules_once_split =
  let
    val take_conjuncts = REPEAT (forward_tac [conjunct1, conjunct2])
    val drop_quantifiers = REPEAT (forward_tac [spec])
    val discharge_eq = forward_tac [refl_imp]
  in take_conjuncts THEN drop_quantifiers THEN discharge_eq end

(* One splitting step for the given split rule. The rule is first composed with
   iffD1; if that fails, or the conclusion of the composed rule does not have
   the shape accepted by was_split, TERM "malformed split rule" is raised. *)
fun do_split ctxt split =
  let
    fun malformed t = raise TERM ("malformed split rule", [t])
  in
    (case try (op RS) (split, iffD1) of
      SOME split' =>
        let
          val ((_, imported), _) = Variable.import false [split'] ctxt
          val split_rhs = Thm.concl_of (hd imported)
        in
          if was_split split_rhs
          then DETERM (apply_split ctxt split') THEN get_rules_once_split
          else malformed split_rhs
        end
    | NONE => malformed (Thm.prop_of split))
  end

(* Resolves a theorem with meta_eq_to_obj_eq; yields no result if that fails *)
fun atomize_meta_eq th = forward_tac [meta_eq_to_obj_eq] th

in

(* Splits thm into a list of equations by repeatedly applying the first
   applicable rule from splitthms. Occurrences of Drule.dummy_thm among the
   split rules are ignored. thm is first passed through atomize_meta_eq where
   that succeeds. *)
fun gen_to_simps ctxt splitthms thm =
  let
    val real_splits = filter_out (fn th => Thm.eq_thm (th, Drule.dummy_thm)) splitthms
    val split_step = FIRST (map (do_split ctxt) real_splits)
    val convert = TRY atomize_meta_eq THEN REPEAT split_step
  in Seq.list_of (convert thm) end

(* Like gen_to_simps, with the split rules taken from the type constructors
   occurring in the type of the head of the lhs of thm *)
fun to_simps ctxt thm =
  let
    val (lhs, _) = strip_eq thm
    val head_T = fastype_of (head_of lhs)
  in gen_to_simps ctxt (get_split_ths ctxt [head_T]) thm end


end

(* Command implementation: evaluates the theorem references, converts the
   resulting equations with to_case and notes the theorem under the given
   binding, whose attribute sources are checked first *)
fun case_of_simps_cmd (bind, thms_ref) lthy =
  let
    val (name, raw_atts) = bind
    val atts = map (Attrib.check_src lthy) raw_atts
    val case_thm = to_case lthy (Attrib.eval_thms lthy thms_ref)
    val (_, lthy') = Local_Theory.note ((name, atts), [case_thm]) lthy
  in lthy' end

(* Command implementation: evaluates the single theorem reference and splits
   it into equations -- with to_simps if no split rules were given, otherwise
   with gen_to_simps and the evaluated split rules -- and notes the result
   under the given binding, whose attribute sources are checked first *)
fun simps_of_case_cmd ((bind, thm_ref), splits_ref) lthy =
  let
    val (name, raw_atts) = bind
    val atts = map (Attrib.check_src lthy) raw_atts
    val thm = singleton (Attrib.eval_thms lthy) thm_ref
    val simps =
      (case splits_ref of
        [] => to_simps lthy thm
      | _ => gen_to_simps lthy (Attrib.eval_thms lthy splits_ref) thm)
    val (_, lthy') = Local_Theory.note ((name, atts), simps) lthy
  in lthy' end

(* Registers the local theory command case_of_simps: an optional theorem name
   terminated by ":", followed by the theorem references accepted by
   Parse.thms1; the parsed arguments are handed to case_of_simps_cmd *)
val _ =
  Outer_Syntax.local_theory @{command_keyword case_of_simps}
    "turn a list of equations into a case expression"
    (Parse_Spec.opt_thm_name ":"  -- Parse.thms1 >> case_of_simps_cmd)

(* Parser for the option "(splits: thms)": the surrounding keywords and the
   reserved word "splits" are discarded, the theorem references are returned *)
val parse_splits = @{keyword "("} |-- Parse.reserved "splits" |-- @{keyword ":"} |--
  Parse.thms1 --| @{keyword ")"}

(* Registers the local theory command simps_of_case: an optional theorem name
   terminated by ":", one theorem reference, and an optional splits option
   (empty list if absent); the parsed arguments are handed to simps_of_case_cmd *)
val _ =
  Outer_Syntax.local_theory @{command_keyword simps_of_case}
    "perform case split on rule"
    (Parse_Spec.opt_thm_name ":"  -- Parse.thm --
      Scan.optional parse_splits [] >> simps_of_case_cmd)

end