src/HOL/Tools/function_package/pattern_split.ML
changeset 21237 b803f9870e97
parent 21051 c49467a9c1e1
child 24584 01e83ffa6c54
equal deleted inserted replaced
21236:890fafbcf8b0 21237:b803f9870e97
    10 *)
    10 *)
    11 
    11 
    (* Interface for splitting function equations by their left-hand-side
       patterns. *)
    signature FUNDEF_SPLIT =
    sig
      (* Each equation carries a flag; equations flagged [false] are emitted
         unchanged as a singleton list. *)
      val split_some_equations : Proof.context -> (bool * term) list -> term list list

      (* Same as split_some_equations with every equation flagged [true]. *)
      val split_all_equations : Proof.context -> term list -> term list list
    end
    20 
    20 
    21 structure FundefSplit : FUNDEF_SPLIT =
    21 structure FundefSplit : FUNDEF_SPLIT =
    22 struct
    22 struct
    23 
    23 
    34     end
    34     end
    35 
    35 
    (* Applies [t] to one variable per argument type of [t] (the binder types
       of its type), each obtained from new_var. The variable list [vs] is
       threaded through new_var. Returns the final variable list paired with
       the fully applied term. *)
    fun saturate ctx vs t =
        let
          fun apply_to_new_var argT (used, acc) =
              let
                val (used', v) = new_var ctx used argT
              in
                (used', acc $ v)
              end
        in
          fold apply_to_new_var (binder_types (fastype_of t)) (vs, t)
        end
    39 
    39          
    40 
    40          
    (* This is copied from "fundef_datatype.ML" *)
    (* Returns the constructors of the datatype named by [T] as Const terms.
       Each constructor's type variables are instantiated by matching its
       result type against [T]. Fails via `the` if no constructors are
       registered for the type name. Types that are not `Type` are printed
       and rejected with sys_error. *)
    fun inst_constrs_of thy (T as Type (name, _)) =
        let
          fun instantiate (cname, cT) =
              let
                val tyenv = Sign.typ_match thy (body_type cT, T) Vartab.empty
              in
                Envir.subst_TVars tyenv (Const (cname, cT))
              end
          val constrs = the (DatatypePackage.get_datatype_constrs thy name)
        in
          map instantiate constrs
        end
      | inst_constrs_of thy t = (print t; sys_error "inst_constrs_of")
    46 
    46                             
    47 
    47                             
    48 
    48                             
    49 
    49 
    (* Combines two (variable list, substitution list) pairs. The variable
       lists are merged using aconv as equality; the substitution lists are
       appended. *)
    fun join ((vars1, subst1), (vars2, subst2)) =
        let
          val vars = merge (op aconv) (vars1, vars2)
        in
          (vars, subst1 @ subst2)
        end

    (* Joins every pair in the cartesian product of [xs] and [ys]. *)
    fun join_product (xs, ys) = product xs ys |> map join
    52 
    52 
    53 fun join_list [] = []
    53 fun join_list [] = []
    89 
    89 
    (* p - q *)
    (* Subtracts the left-hand-side pattern of [eq2] (q) from equation [eq1]
       (p). The two left-hand sides are passed to pattern_subtract_subst; for
       each (variables, rewrite rules) pair it returns:
       - the body of eq1 is rewritten with those rules, and
       - the result is universally quantified over its free variables that
         also occur in the pair's variable list.
       Presumably this yields the instances of p not covered by q.
       NOTE(review): assumes dest_all_all returns (vars, _ $ (_ $ lhs $ _)),
       e.g. a Trueprop-wrapped equation. This shape is taken from the
       patterns below; confirm. *)
    fun pattern_subtract ctx eq2 eq1 =
        let
          val thy = ProofContext.theory_of ctx

          val (vs, feq1 as (_ $ (_ $ lhs1 $ _))) = dest_all_all eq1
          val (_, _ $ (_ $ lhs2 $ _)) = dest_all_all eq2

          (* Rewrite eq1's body with sigma, then rebind the relevant frees. *)
          fun requantify (vars, sigma) =
              let
                val inst = Pattern.rewrite_term thy sigma [] feq1
                val bound = (map Free (frees_in_term ctx inst)) inter vars
              in
                fold_rev mk_forall bound inst
              end
        in
          map requantify (pattern_subtract_subst ctx vs lhs1 lhs2)
        end
   109 
   109       
   110 
   110 
   (* ps - p' *)
   (* Subtracts the pattern of equation [p'] from each equation in [ps]
      (see pattern_subtract) and concatenates the resulting equation lists. *)
   fun pattern_subtract_from_many ctx p' ps =
       flat (map (pattern_subtract ctx p') ps)
   114 
   114 
   126         | split_aux prev ((false, eq) :: es) = [eq]
   126         | split_aux prev ((false, eq) :: es) = [eq]
   127                                                :: split_aux (eq :: prev) es
   127                                                :: split_aux (eq :: prev) es
   128     in
   128     in
   129       split_aux [] eqns
   129       split_aux [] eqns
   130     end
   130     end
   131 
   131     
   (* Splits every equation: each one is flagged [true] and the flagged list
      is handed to split_some_equations. *)
   fun split_all_equations ctx eqns =
       split_some_equations ctx (map (fn eqn => (true, eqn)) eqns)
   134 
   134 
   135 
   135 
   136 
   136