src/HOL/Tools/Function/pattern_split.ML
author haftmann
Tue Oct 20 16:13:01 2009 +0200 (2009-10-20)
changeset 33037 b22e44496dc2
parent 32952 aeb1e44fbc19
child 33038 8f9594c31de4
permissions -rw-r--r--
replaced old_style infixes eq_set, subset, union, inter and variants by generic versions
haftmann@31775
     1
(*  Title:      HOL/Tools/Function/pattern_split.ML
krauss@20270
     2
    Author:     Alexander Krauss, TU Muenchen
krauss@20270
     3
wenzelm@20344
     4
A package for general recursive function definitions.
krauss@20270
     5
wenzelm@20344
     6
Automatic splitting of overlapping constructor patterns. This is a preprocessing step which
krauss@20270
     7
turns a specification with overlaps into an overlap-free specification.
krauss@20270
     8
krauss@20270
     9
*)
krauss@20270
    10
wenzelm@20344
    11
signature FUNDEF_SPLIT =
sig
  (* Given equations tagged with a "split" flag, return one list of equations
     per input equation, in input order.  An equation flagged false is returned
     unchanged as a singleton list; an equation flagged true is replaced by
     instances of it whose left-hand sides do not overlap any earlier equation
     of the input (flagged or not). *)
  val split_some_equations : Proof.context -> (bool * term) list -> term list list

  (* Same as split_some_equations with every equation flagged true. *)
  val split_all_equations : Proof.context -> term list -> term list list
end
krauss@20270
    19
wenzelm@20344
    20
structure FundefSplit : FUNDEF_SPLIT =
struct

open FundefLib

(* Variable management: fresh names are chosen relative to the proof context
   [ctx] together with an explicit list [vs] of Free terms already in use. *)

(* new_var ctx vs T: a fresh Free of type T (name based on "v") that avoids
   the variables in vs.  Returns the extended variable list and the new Free. *)
fun new_var ctx vs T =
    let
      (* variant_frees is given exactly one (name, type) pair, so it yields
         exactly one variable. *)
      val [v] = Variable.variant_frees ctx vs [("v", T)]
    in
      (Free v :: vs, Free v)
    end

(* saturate ctx vs t: apply t to one fresh variable per argument type in
   binder_types (fastype_of t), e.g. turning a constructor constant into a
   fully applied pattern.  Returns the extended variable list and the term. *)
fun saturate ctx vs t =
    fold (fn T => fn (vs, t) => new_var ctx vs T |> apsnd (curry op $ t))
         (binder_types (fastype_of t)) (vs, t)

(* inst_constrs_of thy T: the constructors of datatype T as constants whose
   types are instantiated to match T.  Raises TYPE if T is not a type
   constructor application; [the] raises Option if no constructors are
   registered for it.  (Copied from "fundef_datatype.ML".) *)
fun inst_constrs_of thy (T as Type (name, _)) =
    map (fn (Cn, CT) =>
            Envir.subst_term_types (Sign.typ_match thy (body_type CT, T) Vartab.empty)
                                   (Const (Cn, CT)))
        (the (Datatype.get_constrs thy name))
  | inst_constrs_of _ T = raise TYPE ("inst_constrs_of", [T], [])

(* pattern_subtract_subst ctx vs t t': describe the instances of pattern t
   that are not covered by pattern t'.
   The result is a list of alternatives (vs', subst):
   - vs' is the variable list extended by any fresh variables introduced;
   - subst maps variables of t to saturated constructor patterns.
   If t and t' clash (different heads at some position), t is kept whole:
   the result is [(vs, [])].  A position where t' is a Free contributes no
   alternatives, since t' covers everything there. *)
fun pattern_subtract_subst ctx vs t t' =
    let
      (* Declared locally: exceptions are generative in SML, so every
         (recursive) invocation raises and handles its own DISJ. *)
      exception DISJ

      fun pattern_subtract_subst_aux _ _ (Free _) = []
        | pattern_subtract_subst_aux vs (v as (Free (_, T))) t' =
          let
            (* Instantiate v with one saturated constructor of its type and
               subtract t' from that instance. *)
            fun split_on_constr constr =
                let
                  val (vs', t) = saturate ctx vs constr
                  val substs = pattern_subtract_subst ctx vs' t t'
                in
                  map (fn (vs, subst) => (vs, (v, t) :: subst)) substs
                end
          in
            maps split_on_constr (inst_constrs_of (ProofContext.theory_of ctx) T)
          end
        | pattern_subtract_subst_aux vs t t' =
          let
            val (C, ps) = strip_comb t
            val (C', qs) = strip_comb t'
          in
            if C = C'
            then flat (map2 (pattern_subtract_subst_aux vs) ps qs)
            else raise DISJ
          end
    in
      pattern_subtract_subst_aux vs t t'
      handle DISJ => [(vs, [])]
    end


(* pattern_subtract ctx eq2 eq1: the equation eq1 minus eq2, i.e. instances of
   eq1 obtained by subtracting the left-hand side of eq2 from that of eq1.
   Both equations are decomposed with dest_all_all and must have the shape
   _ $ (_ $ lhs $ _) -- presumably Trueprop (lhs = rhs); confirm with callers.
   Each instance is re-quantified (Logic.all) over those free variables of the
   instantiated equation that belong to the alternative's variable list. *)
fun pattern_subtract ctx eq2 eq1 =
    let
      val thy = ProofContext.theory_of ctx

      val (vs, feq1 as (_ $ (_ $ lhs1 $ _))) = dest_all_all eq1
      val (_,  _ $ (_ $ lhs2 $ _)) = dest_all_all eq2

      val substs = pattern_subtract_subst ctx vs lhs1 lhs2

      fun instantiate (vs', sigma) =
          let
            val t = Pattern.rewrite_term thy sigma [] feq1
          in
            fold_rev Logic.all (gen_inter (op =) (map Free (frees_in_term ctx t), vs')) t
          end
    in
      map instantiate substs
    end


(* pattern_subtract_from_many ctx p' ps: subtract equation p' from every
   equation in ps, concatenating the results. *)
fun pattern_subtract_from_many ctx p' =
    maps (pattern_subtract ctx p')

(* pattern_subtract_many ctx ps' eqs: subtract every equation of ps' from eqs.
   fold_rev processes ps' starting from its last element. *)
fun pattern_subtract_many ctx ps' =
    fold_rev (pattern_subtract_from_many ctx) ps'


(* split_some_equations ctx eqns: one result list per input equation.
   An equation flagged true has all earlier equations subtracted from it;
   one flagged false is kept as [eq].  Every equation, flagged or not, is
   subtracted from the flagged equations that follow it.  [prev] holds the
   earlier equations, most recent first. *)
fun split_some_equations ctx eqns =
    let
      fun split_aux _ [] = []
        | split_aux prev ((true, eq) :: es) = pattern_subtract_many ctx prev [eq]
                                              :: split_aux (eq :: prev) es
        | split_aux prev ((false, eq) :: es) = [eq]
                                               :: split_aux (eq :: prev) es
    in
      split_aux [] eqns
    end

(* split_all_equations ctx eqns: split_some_equations with every equation
   flagged true. *)
fun split_all_equations ctx =
    split_some_equations ctx o map (pair true)

end