src/HOL/Tools/Function/pattern_split.ML
author blanchet
Wed, 08 Jun 2011 16:20:18 +0200
changeset 43293 a80cdc4b27a3
parent 42483 39eefaef816a
child 54406 a2d18fea844a
permissions -rw-r--r--
made "query" type systems a bit more sound -- local facts, e.g. the negated conjecture, may invalidate the infinity check, e.g. if we are proving that there exist two values of an infinite type, we can use the negated conjecture that there is only one value to derive unsound proofs unless the type is properly encoded

(*  Title:      HOL/Tools/Function/pattern_split.ML
    Author:     Alexander Krauss, TU Muenchen

Fairly ad-hoc pattern splitting.
*)

signature FUNCTION_SPLIT =
sig
  (* One output list per input equation, in input order.  An equation flagged
     true is replaced by its instances that no preceding equation covers;
     an equation flagged false is returned unchanged as a singleton. *)
  val split_some_equations : Proof.context -> (bool * term) list -> term list list

  (* split_some_equations with every equation flagged for splitting. *)
  val split_all_equations : Proof.context -> term list -> term list list
end

structure Function_Split : FUNCTION_SPLIT =
struct

open Function_Lib

(* Create a single fresh free variable of type T, a variant of the name "v"
   chosen by Variable.variant_frees with respect to ctxt and the terms vs.
   Returns the avoid-list extended by the new variable, and the variable. *)
fun new_var ctxt vs T =
  let
    val [v] = Variable.variant_frees ctxt vs [("v", T)]
  in
    (Free v :: vs, Free v)
  end

(* Apply t to one fresh variable per argument type in the binder prefix of
   its type, threading the avoid-list vs through new_var.  Returns the final
   avoid-list and the fully applied term. *)
fun saturate ctxt vs t =
  fold (fn T => fn (vs, t) => new_var ctxt vs T |> apsnd (curry op $ t))
    (binder_types (fastype_of t)) (vs, t)


(* The constructors of the datatype T, with their types instantiated so that
   their result type matches T.  Copied from "pat_completeness.ML".
   Raises TYPE if T is not a type constructor application; fails via `the`
   (exception Option) if no constructors are registered for name. *)
fun inst_constrs_of thy (T as Type (name, _)) =
  map (fn (Cn, CT) =>
    Envir.subst_term_types (Sign.typ_match thy (body_type CT, T) Vartab.empty) (Const (Cn, CT)))
    (the (Datatype.get_constrs thy name))
  | inst_constrs_of _ T = raise TYPE ("inst_constrs_of", [T], [])


(* Pattern difference t - t' on left-hand-side terms.  The result is a list
   of (avoid-list, substitution) pairs.  Each substitution instantiates
   variables of t with constructor applications over fresh variables.
   Together, the instances describe the cases of t that t' does not match:
   - t' is a variable: it matches everything, so nothing remains ([]);
   - t is a variable v (and t' is not): v is split into every constructor of
     its datatype, and t' is subtracted from each instance recursively; every
     result records the binding of v;
   - otherwise the heads of t and t' must coincide; the arguments are then
     subtracted pairwise and the results concatenated.  If some heads differ,
     the two patterns are disjoint and t survives unchanged: [(vs, [])]. *)
fun pattern_subtract_subst ctxt vs t t' =
  let
    (* Declared locally.  The recursive call in the variable case has its own
       handler, so disjointness found there only affects that constructor
       instance rather than the whole traversal. *)
    exception DISJ
    val thy = Proof_Context.theory_of ctxt

    fun subtract_aux _ _ (Free _) = []
      | subtract_aux vs (v as Free (_, T)) t' =
          let
            fun split_with constr =
              let
                val (vs', inst) = saturate ctxt vs constr
              in
                pattern_subtract_subst ctxt vs' inst t'
                |> map (fn (vs'', subst) => (vs'', (v, inst) :: subst))
              end
          in
            maps split_with (inst_constrs_of thy T)
          end
      | subtract_aux vs t t' =
          let
            val (head, args) = strip_comb t
            val (head', args') = strip_comb t'
          in
            (* map2 assumes that equal heads carry equal numbers of arguments *)
            if head = head'
            then flat (map2 (subtract_aux vs) args args')
            else raise DISJ
          end
  in
    subtract_aux vs t t'
    handle DISJ => [(vs, [])]
  end

(* Equation difference eq1 - eq2: the instances of eq1 whose left-hand side
   is not matched by the left-hand side of eq2.  Both equations are taken
   apart with dest_all_all and matched against the shape _ $ (_ $ lhs $ _)
   (presumably Trueprop (lhs = rhs) under meta-quantifiers -- confirm).
   Each remaining case is built by rewriting the body of eq1 with the case's
   substitution.  The result is then re-quantified over those of its free
   variables that are in the case's avoid-list and not fixed in ctxt.
   The argument order (eq2 first) allows partial application over many eq1. *)
fun pattern_subtract ctxt eq2 eq1 =
  let
    val thy = Proof_Context.theory_of ctxt

    val (vs, feq1 as (_ $ (_ $ lhs1 $ _))) = dest_all_all eq1
    val (_, _ $ (_ $ lhs2 $ _)) = dest_all_all eq2

    fun instantiate (vs', sigma) =
      let
        val t = Pattern.rewrite_term thy sigma [] feq1
        val xs = fold_aterms
          (fn x as Free (a, _) =>
              if not (Variable.is_fixed ctxt a) andalso member (op =) vs' x
              then insert (op =) x else I
            | _ => I) t []
      in fold Logic.all xs t end
  in
    map instantiate (pattern_subtract_subst ctxt vs lhs1 lhs2)
  end

(* ps - p': subtract the equation p' from every equation of a list, and
   concatenate the results. *)
fun pattern_subtract_from_many ctxt p' =
  maps (pattern_subtract ctxt p')

(* Subtract every equation of ps' from a list of equations.  fold_rev applies
   the last element of ps' first.  split_some_equations passes its previous
   equations most-recent-first, so the earliest equation is subtracted first. *)
fun pattern_subtract_many ctxt ps' =
  fold_rev (pattern_subtract_from_many ctxt) ps'

(* One output list per input equation, in input order.  An equation flagged
   true is replaced by its instances that are not covered by any preceding
   equation (split or not).  An equation flagged false is kept as [eq]. *)
fun split_some_equations ctxt eqns =
  let
    (* prev holds the already-processed equations, most recent first *)
    fun split_aux _ [] = []
      | split_aux prev ((true, eq) :: es) =
          pattern_subtract_many ctxt prev [eq] :: split_aux (eq :: prev) es
      | split_aux prev ((false, eq) :: es) =
          [eq] :: split_aux (eq :: prev) es
  in
    split_aux [] eqns
  end

(* split_some_equations with every equation flagged for splitting. *)
fun split_all_equations ctxt =
  split_some_equations ctxt o map (pair true)


end