author paulson
Fri, 20 Aug 2004 12:20:09 +0200
changeset 15150 c7af682b9ee5
child 15250 217bececa2bd
permissions -rw-r--r--
fix to eliminate excessive case-splits in the recursion equations, by Luca Dixon

(* -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- *) 
(*  Title:      TFL/casesplit.ML
    Author:     Lucas Dixon, University of Edinburgh
    Date:       17 Aug 2004
*)
(* -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- *) 
(*  Description:

    A structure that defines tactics for performing case splits
    programmatically.

    casesplit_free :
      string * Term.typ -> int -> Thm.thm -> Thm.thm Seq.seq

    casesplit_name : 
      string -> int -> Thm.thm -> Thm.thm Seq.seq

    These use the induction theorem associated with the recursive data
    type to be split. 

    The structure includes a function that tries to recursively split a
    conjecture into a list of sub-theorems: 

    splitto : Thm.thm list -> Thm.thm -> Thm.thm
*)
(* -=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=-=- *) 

(* logic-specific *)
signature CASE_SPLIT_DATA =
sig
  (* remove / add the object-logic truth judgement around a proposition *)
  val dest_Trueprop : Term.term -> Term.term
  val mk_Trueprop : Term.term -> Term.term
  (* parse a string into a certified term relative to a signature *)
  val read_cterm : Sign.sg -> string -> Thm.cterm
end;

(* for HOL *)
structure CaseSplitData_HOL : CASE_SPLIT_DATA = 
struct
(* every operation is taken directly from HOLogic *)
val dest_Trueprop = HOLogic.dest_Trueprop;
val mk_Trueprop = HOLogic.mk_Trueprop;
val read_cterm = HOLogic.read_cterm;
end;

signature CASE_SPLIT =
sig
  (* failure to find a free to split on *)
  exception find_split_exp of string

  (* getting a case split thm from the induction thm *)
  val case_thm_of_ty : Sign.sg -> Term.typ -> Thm.thm
  val cases_thm_of_induct_thm : Thm.thm -> Thm.thm

  (* case split tactics: split subgoal i (1-based) on the given variable *)
  val casesplit_free :
      string * Term.typ -> int -> Thm.thm -> Thm.thm Seq.seq
  val casesplit_name : string -> int -> Thm.thm -> Thm.thm Seq.seq

  (* finding a free var to split *)
  val find_term_split :
      Term.term * Term.term -> (string * Term.typ) Library.option
  val find_thm_split :
      Thm.thm -> int -> Thm.thm -> (string * Term.typ) Library.option
  val find_thms_split :
      Thm.thm list -> int -> Thm.thm -> (string * Term.typ) Library.option

  (* try to recursively split conjectured thm to given list of thms *)
  val splitto : Thm.thm list -> Thm.thm -> Thm.thm

  (* for use with the recdef package *)
  val derive_init_eqs :
      Sign.sg ->
      (Thm.thm * int) list -> Term.term list -> (Thm.thm * int) list
end;

functor CaseSplitFUN(Data : CASE_SPLIT_DATA) =
struct

(* beta-eta contract the theorem: first beta-convert its proposition,
   then eta-convert the result, carrying the theorem along each
   conversion with equal_elim. *)
fun beta_eta_contract thm = 
    let
      val thm2 = equal_elim (Thm.beta_conversion true (Thm.cprop_of thm)) thm
      val thm3 = equal_elim (Thm.eta_conversion (Thm.cprop_of thm2)) thm2
    in thm3 end;

(* make a casethm from an induction thm: in every goal, repeatedly
   eliminate premises with Drule.thin_rl (presumably the induction
   hypotheses) and keep the first result of the sequence. *)
val cases_thm_of_induct_thm = 
     Seq.hd o (ALLGOALS (fn i => REPEAT (etac Drule.thin_rl i)));

(* get the case_thm (my version) from a type: look ty up in the datatype
   table of sgn and derive a case theorem from its induction theorem.
   Raises ERROR_MESSAGE when ty is a type variable or not a datatype. *)
fun case_thm_of_ty sgn ty  = 
    let 
      val dtypestab = DatatypePackage.get_datatypes_sg sgn;
      val ty_str = case ty of 
                     Type(ty_str, _) => ty_str
                   | TFree(s,_)  => raise ERROR_MESSAGE 
                                            ("Free type: " ^ s)   
                   | TVar((s,i),_) => raise ERROR_MESSAGE 
                                            ("Free variable: " ^ s)   
      val dt = case (Symtab.lookup (dtypestab,ty_str))
                of Some dt => dt
                 | None => raise ERROR_MESSAGE ("Not a Datatype: " ^ ty_str)
    in
      cases_thm_of_induct_thm (#induction dt)
    end;

(* scratch code kept for reference; "t" is not bound here, so it must not
   be compiled:
 val ty = (snd o hd o map Term.dest_Free o Term.term_frees) t;  
*)

(* for use when there are no prems to the subgoal *)
(* does a case split on the given variable: builds an instance of the
   case theorem for the type of (vstr,ty) whose conclusion is the goal
   term gt with the variable abstracted out, then beta-eta contracts it.
   Raises ERROR_MESSAGE if the case theorem's conclusion does not have
   the shape  _ $ (?P ?D). *)
fun mk_casesplit_goal_thm sgn (vstr,ty) gt = 
    let 
      val x = Free(vstr,ty)
      val abst = Abs(vstr, ty, Term.abstract_over (x, gt));

      val ctermify = Thm.cterm_of sgn;
      val ctypify = Thm.ctyp_of sgn;
      val case_thm = case_thm_of_ty sgn ty;

      val abs_ct = ctermify abst;
      val free_ct = ctermify x;

      val tsig = Sign.tsig_of sgn;
      (* P is the predicate variable and D the split variable of the case
         theorem; match D's type against ty to get the type instantiation *)
      val (Pv, Dv, type_insts) = 
          case (Thm.concl_of case_thm) of 
            (_ $ ((Pv as Var(P,Pty)) $ (Dv as Var(D, Dty)))) => 
            (Pv, Dv, 
             Vartab.dest (Type.typ_match tsig (Vartab.empty, (Dty, ty))))
          | _ => raise ERROR_MESSAGE ("not a valid case thm");
      val type_cinsts = map (apsnd ctypify) type_insts;
      val cPv = ctermify (Sign.inst_term_tvars sgn type_insts Pv);
      val cDv = ctermify (Sign.inst_term_tvars sgn type_insts Dv);
    in
      (* instantiate types first so that P and D can then be instantiated
         with terms of the matching types *)
      beta_eta_contract 
        (case_thm
           |> Thm.instantiate (type_cinsts, []) 
           |> Thm.instantiate ([], [(cPv, abs_ct), (cDv, free_ct)]))
    end;

(* for use when there are no prems to the subgoal *)
(* does a case split on the given variable (Free fv) in subgoal i
   (1-based) of th, by resolving with the case-split theorem *)
fun casesplit_free fv i th = 
    let
      val gt = Data.dest_Trueprop (nth_elem( i - 1, Thm.prems_of th));
      val sgn = Thm.sign_of_thm th;
    in
      Tactic.rtac (mk_casesplit_goal_thm sgn fv gt) i th
    end;

(* for use when there are no prems to the subgoal *)
(* does a case split on the free variable named vstr in subgoal i;
   raises ERROR_MESSAGE if the subgoal has no free variable of that name *)
fun casesplit_name vstr i th = 
    let
      val gt = Data.dest_Trueprop (nth_elem( i - 1, Thm.prems_of th));
      fun getter x = let val (n,ty) = Term.dest_Free x in 
                       if vstr = n then Some (n,ty) else None end;
    in
      case Library.get_first getter (Term.term_frees gt) of
        Some fv => casesplit_free fv i th
      | None => raise ERROR_MESSAGE ("no such variable " ^ vstr)
    end;

(* small example: 
Goal "P (x :: nat) & (C y --> Q (y :: nat))";
by (rtac (thm "conjI") 1);
val th = topthm();
val i = 2;
val vstr = "y";

by (casesplit_name "y" 2);

val th = topthm();
val i = 1;
val th' = casesplit_name "x" i th;
*)

(* the find_XXX_split functions simply do a lightweight (I think)
   form of term matching to find where to do the next split *)

(* assuming two terms are identical except for a free in one at a
subterm, and a constant (or application/abstraction) in the other, ie
assuming that one term is a split of the other, give back the free
variable that has been split. Raises find_split_exp as soon as the
terms are seen to differ in any other way. *)
exception find_split_exp of string
fun find_term_split (Free v, _ $ _) = Some v
  | find_term_split (Free v, Const _) = Some v
  | find_term_split (Free v, Abs _) = Some v (* do we really want this case? *)
  | find_term_split (Free v, Free v2) = 
    (* identical frees belong to the unsplit context *)
    if v = v2 then None else (* keep searching *)
    raise find_split_exp (* stop now *)
            "Terms are not identical upto a free varaible! (Other)"
  | find_term_split (Free v, Var _) = 
    (* NOTE(review): assumes the split theorems use schematic variables
       at positions that need no split - keep searching *)
    None
  | find_term_split (a $ b, a2 $ b2) = 
    (case find_term_split (a, a2) of 
       None => find_term_split (b,b2)  
     | vopt => vopt)
  | find_term_split (Abs(_,ty,t1), Abs(_,ty2,t2)) = 
    find_term_split (t1, t2)
  | find_term_split (Const (x,ty), Const(x2,ty2)) = 
    if x = x2 then None else (* keep searching *)
    raise find_split_exp (* stop now *)
            "Terms are not identical upto a free varaible! (Consts)"
  | find_term_split (Bound i, Bound j) =     
    if i = j then None else (* keep searching *)
    raise find_split_exp (* stop now *)
            "Terms are not identical upto a free varaible! (Bound)"
  | find_term_split (a, b) = 
    raise find_split_exp (* stop now *)
            "Terms are not identical upto a free varaible! (Other)";

(* assume that "splitth" is a case split form of subgoal i of "genth",
then look for a free variable to split, breaking the subgoal closer to
splitth. Returns None if the terms do not line up. *)
fun find_thm_split splitth i genth =
    find_term_split (Logic.get_goal (Thm.prop_of genth) i, 
                     Thm.concl_of splitth) handle find_split_exp _ => None;

(* as above but searches "splitths" for a theorem that suggests a case split *)
fun find_thms_split splitths i genth =
    Library.get_first (fn sth => find_thm_split sth i genth) splitths;

(* split the subgoal i of "genth" until we get to a member of
splitths. Assumes that genth will be a general form of splitths, that
can be case-split, as needed. Otherwise fails. Note: We assume that
all of "splitths" are split to the same level, and thus it doesn't
matter which one we choose to look for the next split. Simply add
search on splitthms and split variable, to change this.  *)
(* Note: possible efficiency measure: when a case theorem is no longer
useful, drop it? *)
(* Note: This should not be a separate tactic but integrated into the
case split done during recdef's case analysis, this would avoid us
having to (re)search for variables to split. *)
fun splitto splitths genth = 
    let
      val _ = assert (not (null splitths)) "splitto: no given splitths";
      val sgn = Thm.sign_of_thm genth;

      (* check if we are a member of splitths - FIXME: quicker and 
      more flexible with discrim net. *)
      fun solve_by_splitth th split = biresolution false [(false,split)] 1 th;

      (* case split subgoal 1 of th, then recurse on each resulting case *)
      fun split th = 
          (case find_thms_split splitths 1 th of 
             None => raise ERROR_MESSAGE "splitto: cannot find variable to split on"
           | Some v => 
             let
               val gt = Data.dest_Trueprop (nth_elem(0, Thm.prems_of th));
               val split_thm = mk_casesplit_goal_thm sgn v gt;
               val (subthms, expf) = IsaND.fixed_subgoal_thms split_thm;
             in
               expf (map recsplitf subthms)
             end)

      and recsplitf th = 
          (* note: multiple unifiers! we only take the first element,
             probably fine -- there is probably only one anyway. *)
          (case Library.get_first (Seq.pull o solve_by_splitth th) splitths of
             None => split th
           | Some (solved_th, more) => solved_th)
    in
      recsplitf genth
    end;

(* Note: We don't do this if wf conditions fail to be solved, as each
case may have a different wf condition - we could group the conditions
together and say that they must be true to solve the general case,
but that would hide from the user which sub-case they were related
to. Probably this is not important, and it would work fine, but I
prefer leaving more fine grain control to the user. *)

(* derive eqs, assuming strict, ie the rules have no assumptions = all
   the well-foundedness conditions have been solved. *)
local
  (* the theorems among "rules" that are tagged with index i *)
  fun get_related_thms i = 
      mapfilter ((fn (r,x) => if x = i then Some r else None));
  (* a single related rule is the equation itself; several related
     rules are combined by case splitting the trivial theorem th *)
  fun solve_eq (th, [], i) = 
      raise ERROR_MESSAGE "derive_init_eqs: missing rules"
    | solve_eq (th, [a], i) = (a, i)
    | solve_eq (th, splitths as (_ :: _), i) = (splitto splitths th,i);
in
fun derive_init_eqs sgn rules eqs = 
    let
      val eqths = map (Thm.trivial o (Thm.cterm_of sgn) o Data.mk_Trueprop) 
                      eqs;
      (* pair each equation with its 0-based position and related rules *)
      fun tag _ [] = []
        | tag i (e :: es) = (e, get_related_thms i rules, i) :: tag (i + 1) es;
    in
      map solve_eq (tag 0 eqths)
    end;
end;

(* stale fragment, not compiled:
    val (rs_hwfc, unhidefs) = Library.split_list (map hide_prems rules)
    (map2 (op |>) (ths, expfs))
*)

end;


(* The case-split tools instantiated for the HOL object logic. *)
structure CaseSplit = CaseSplitFUN(CaseSplitData_HOL);