src/HOL/Tools/SMT2/smt2_normalize.ML
changeset 56103 6689512f3710
parent 56100 0dc5f68a7802
child 56104 fd6e132ee4fb
equal deleted inserted replaced
56102:439dda276b3f 56103:6689512f3710
    56 (** atomization **)
    56 (** atomization **)
    57 
    57 
    58 fun atomize_conv ctxt ct =
    58 fun atomize_conv ctxt ct =
    59   (case Thm.term_of ct of
    59   (case Thm.term_of ct of
    60     @{const "==>"} $ _ $ _ =>
    60     @{const "==>"} $ _ $ _ =>
    61       Conv.binop_conv (atomize_conv ctxt) then_conv
    61       Conv.binop_conv (atomize_conv ctxt) then_conv Conv.rewr_conv @{thm atomize_imp}
    62       Conv.rewr_conv @{thm atomize_imp}
       
    63   | Const (@{const_name "=="}, _) $ _ $ _ =>
    62   | Const (@{const_name "=="}, _) $ _ $ _ =>
    64       Conv.binop_conv (atomize_conv ctxt) then_conv
    63       Conv.binop_conv (atomize_conv ctxt) then_conv Conv.rewr_conv @{thm atomize_eq}
    65       Conv.rewr_conv @{thm atomize_eq}
       
    66   | Const (@{const_name all}, _) $ Abs _ =>
    64   | Const (@{const_name all}, _) $ Abs _ =>
    67       Conv.binder_conv (atomize_conv o snd) ctxt then_conv
    65       Conv.binder_conv (atomize_conv o snd) ctxt then_conv Conv.rewr_conv @{thm atomize_all}
    68       Conv.rewr_conv @{thm atomize_all}
       
    69   | _ => Conv.all_conv) ct
    66   | _ => Conv.all_conv) ct
    70 
    67 
    71 val setup_atomize =
    68 val setup_atomize =
    72   fold SMT2_Builtin.add_builtin_fun_ext'' [@{const_name "==>"},
    69   fold SMT2_Builtin.add_builtin_fun_ext'' [@{const_name "==>"}, @{const_name "=="},
    73     @{const_name "=="}, @{const_name all}, @{const_name Trueprop}]
    70     @{const_name all}, @{const_name Trueprop}]
    74 
    71 
    75 
    72 
    76 (** unfold special quantifiers **)
    73 (** unfold special quantifiers **)
    77 
    74 
    78 local
    75 local
    79   val ex1_def = mk_meta_eq @{lemma
    76   val special_quants = [
    80     "Ex1 = (%P. EX x. P x & (ALL y. P y --> y = x))"
    77     (@{const_name Ex1}, @{thm Ex1_def_raw}),
    81     by (rule ext) (simp only: Ex1_def)}
    78     (@{const_name Ball}, @{thm Ball_def_raw}),
    82 
    79     (@{const_name Bex}, @{thm Bex_def_raw})]
    83   val ball_def = mk_meta_eq @{lemma "Ball = (%A P. ALL x. x : A --> P x)"
       
    84     by (rule ext)+ (rule Ball_def)}
       
    85 
       
    86   val bex_def = mk_meta_eq @{lemma "Bex = (%A P. EX x. x : A & P x)"
       
    87     by (rule ext)+ (rule Bex_def)}
       
    88 
       
    89   val special_quants = [(@{const_name Ex1}, ex1_def),
       
    90     (@{const_name Ball}, ball_def), (@{const_name Bex}, bex_def)]
       
    91   
    80   
    92   fun special_quant (Const (n, _)) = AList.lookup (op =) special_quants n
    81   fun special_quant (Const (n, _)) = AList.lookup (op =) special_quants n
    93     | special_quant _ = NONE
    82     | special_quant _ = NONE
    94 
    83 
    95   fun special_quant_conv _ ct =
    84   fun special_quant_conv _ ct =
    99 in
    88 in
   100 
    89 
   101 fun unfold_special_quants_conv ctxt =
    90 fun unfold_special_quants_conv ctxt =
   102   SMT2_Util.if_exists_conv (is_some o special_quant) (Conv.top_conv special_quant_conv ctxt)
    91   SMT2_Util.if_exists_conv (is_some o special_quant) (Conv.top_conv special_quant_conv ctxt)
   103 
    92 
   104 val setup_unfolded_quants =
    93 val setup_unfolded_quants = fold (SMT2_Builtin.add_builtin_fun_ext'' o fst) special_quants
   105   fold (SMT2_Builtin.add_builtin_fun_ext'' o fst) special_quants
       
   106 
    94 
   107 end
    95 end
   108 
    96 
   109 
    97 
   110 (** trigger inference **)
    98 (** trigger inference **)
   200   fun mk_list (ccons, cnil) f cts = fold_rev (Thm.mk_binop ccons o f) cts cnil
   188   fun mk_list (ccons, cnil) f cts = fold_rev (Thm.mk_binop ccons o f) cts cnil
   201   val mk_pat_list = mk_list (mk_clist @{typ SMT2.pattern})
   189   val mk_pat_list = mk_list (mk_clist @{typ SMT2.pattern})
   202   val mk_mpat_list = mk_list (mk_clist @{typ "SMT2.pattern list"})  
   190   val mk_mpat_list = mk_list (mk_clist @{typ "SMT2.pattern list"})  
   203   fun mk_trigger ctss = mk_mpat_list (mk_pat_list mk_pat) ctss
   191   fun mk_trigger ctss = mk_mpat_list (mk_pat_list mk_pat) ctss
   204 
   192 
   205   val trigger_eq =
   193   val trigger_eq = mk_meta_eq @{lemma "p = SMT2.trigger t p" by (simp add: trigger_def)}
   206     mk_meta_eq @{lemma "p = SMT2.trigger t p" by (simp add: trigger_def)}
       
   207 
   194 
   208   fun insert_trigger_conv [] ct = Conv.all_conv ct
   195   fun insert_trigger_conv [] ct = Conv.all_conv ct
   209     | insert_trigger_conv ctss ct =
   196     | insert_trigger_conv ctss ct =
   210         let val (ctr, cp) = Thm.dest_binop (Thm.rhs_of trigger_eq) ||> rpair ct
   197         let val (ctr, cp) = Thm.dest_binop (Thm.rhs_of trigger_eq) ||> rpair ct
   211         in Thm.instantiate ([], [cp, (ctr, mk_trigger ctss)]) trigger_eq end
   198         in Thm.instantiate ([], [cp, (ctr, mk_trigger ctss)]) trigger_eq end
   280   fun under_trigger_conv cv ct =
   267   fun under_trigger_conv cv ct =
   281     (case Thm.term_of ct of
   268     (case Thm.term_of ct of
   282       @{const SMT2.trigger} $ _ $ _ => Conv.arg_conv cv
   269       @{const SMT2.trigger} $ _ $ _ => Conv.arg_conv cv
   283     | _ => cv) ct
   270     | _ => cv) ct
   284 
   271 
   285   val weight_eq =
   272   val weight_eq = mk_meta_eq @{lemma "p = SMT2.weight i p" by (simp add: weight_def)}
   286     mk_meta_eq @{lemma "p = SMT2.weight i p" by (simp add: weight_def)}
       
   287   fun mk_weight_eq w =
   273   fun mk_weight_eq w =
   288     let val cv = Thm.dest_arg1 (Thm.rhs_of weight_eq)
   274     let val cv = Thm.dest_arg1 (Thm.rhs_of weight_eq)
   289     in
   275     in Thm.instantiate ([], [(cv, Numeral.mk_cnumber @{ctyp int} w)]) weight_eq end
   290       Thm.instantiate ([], [(cv, Numeral.mk_cnumber @{ctyp int} w)]) weight_eq
       
   291     end
       
   292 
   276 
   293   fun add_weight_conv NONE _ = Conv.all_conv
   277   fun add_weight_conv NONE _ = Conv.all_conv
   294     | add_weight_conv (SOME weight) ctxt =
   278     | add_weight_conv (SOME weight) ctxt =
   295         let val cv = Conv.rewr_conv (mk_weight_eq weight)
   279         let val cv = Conv.rewr_conv (mk_weight_eq weight)
   296         in SMT2_Util.under_quant_conv (K (under_trigger_conv cv)) ctxt end
   280         in SMT2_Util.under_quant_conv (K (under_trigger_conv cv)) ctxt end
   346 
   330 
   347 local
   331 local
   348   fun is_case_bool (Const (@{const_name "bool.case_bool"}, _)) = true
   332   fun is_case_bool (Const (@{const_name "bool.case_bool"}, _)) = true
   349     | is_case_bool _ = false
   333     | is_case_bool _ = false
   350 
   334 
   351   val thm = mk_meta_eq @{lemma
       
   352     "case_bool = (%x y P. if P then x else y)" by (rule ext)+ simp}
       
   353 
       
   354   fun unfold_conv _ =
   335   fun unfold_conv _ =
   355     SMT2_Util.if_true_conv (is_case_bool o Term.head_of) (expand_head_conv (Conv.rewr_conv thm))
   336     SMT2_Util.if_true_conv (is_case_bool o Term.head_of)
       
   337       (expand_head_conv (Conv.rewr_conv @{thm case_bool_if}))
   356 in
   338 in
   357 
   339 
   358 fun rewrite_case_bool_conv ctxt =
   340 fun rewrite_case_bool_conv ctxt =
   359   SMT2_Util.if_exists_conv is_case_bool (Conv.top_conv unfold_conv ctxt)
   341   SMT2_Util.if_exists_conv is_case_bool (Conv.top_conv unfold_conv ctxt)
   360 
   342 
   361 val setup_case_bool =
   343 val setup_case_bool = SMT2_Builtin.add_builtin_fun_ext'' @{const_name "bool.case_bool"}
   362   SMT2_Builtin.add_builtin_fun_ext'' @{const_name "bool.case_bool"}
       
   363 
   344 
   364 end
   345 end
   365 
   346 
   366 
   347 
   367 (** unfold abs, min and max **)
   348 (** unfold abs, min and max **)
   368 
   349 
   369 local
   350 local
   370   val abs_def = mk_meta_eq @{lemma "abs = (%a::'a::abs_if. if a < 0 then - a else a)"
   351   val defs = [
   371     by (rule ext) (rule abs_if)}
   352     (@{const_name min}, @{thm min_def_raw}),
   372 
   353     (@{const_name max}, @{thm max_def_raw}),
   373   val min_def = mk_meta_eq @{lemma "min = (%a b. if a <= b then a else b)"
   354     (@{const_name abs}, @{thm abs_if_raw})]
   374     by (rule ext)+ (rule min_def)}
       
   375 
       
   376   val max_def = mk_meta_eq  @{lemma "max = (%a b. if a <= b then b else a)"
       
   377     by (rule ext)+ (rule max_def)}
       
   378 
       
   379   val defs = [(@{const_name min}, min_def), (@{const_name max}, max_def),
       
   380     (@{const_name abs}, abs_def)]
       
   381 
   355 
   382   fun abs_min_max ctxt (Const (n, Type (@{type_name fun}, [T, _]))) =
   356   fun abs_min_max ctxt (Const (n, Type (@{type_name fun}, [T, _]))) =
   383         (case AList.lookup (op =) defs n of
   357         (case AList.lookup (op =) defs n of
   384           NONE => NONE
   358           NONE => NONE
   385         | SOME thm => if SMT2_Builtin.is_builtin_typ_ext ctxt T then SOME thm else NONE)
   359         | SOME thm => if SMT2_Builtin.is_builtin_typ_ext ctxt T then SOME thm else NONE)
   400 
   374 
   401 
   375 
   402 (** embedding of standard natural number operations into integer operations **)
   376 (** embedding of standard natural number operations into integer operations **)
   403 
   377 
   404 local
   378 local
   405   val nat_embedding = @{lemma
   379   val nat_embedding = @{thms nat_int int_nat_nneg int_nat_neg}
   406     "ALL n. nat (int n) = n"
       
   407     "ALL i. i >= 0 --> int (nat i) = i"
       
   408     "ALL i. i < 0 --> int (nat i) = 0"
       
   409     by simp_all}
       
   410 
   380 
   411   val simple_nat_ops = [
   381   val simple_nat_ops = [
   412     @{const less (nat)}, @{const less_eq (nat)},
   382     @{const less (nat)}, @{const less_eq (nat)},
   413     @{const Suc}, @{const plus (nat)}, @{const minus (nat)}]
   383     @{const Suc}, @{const plus (nat)}, @{const minus (nat)}]
   414 
   384 
   427   val is_nat_const = member (op aconv) nat_consts
   397   val is_nat_const = member (op aconv) nat_consts
   428 
   398 
   429   fun is_nat_const' @{const of_nat (int)} = true
   399   fun is_nat_const' @{const of_nat (int)} = true
   430     | is_nat_const' t = is_nat_const t
   400     | is_nat_const' t = is_nat_const t
   431 
   401 
   432   val expands = map mk_meta_eq @{lemma
   402   val expands = map mk_meta_eq @{thms nat_zero_as_int nat_one_as_int nat_numeral_as_int
   433     "0 = nat 0"
   403     nat_less_as_int nat_leq_as_int Suc_as_int nat_plus_as_int nat_minus_as_int nat_times_as_int
   434     "1 = nat 1"
   404     nat_div_as_int nat_mod_as_int}
   435     "(numeral :: num => nat) = (%i. nat (numeral i))"
   405 
   436     "op < = (%a b. int a < int b)"
   406   val ints = map mk_meta_eq @{thms int_0 int_1 int_Suc int_plus int_minus int_mult zdiv_int
   437     "op <= = (%a b. int a <= int b)"
   407     zmod_int}
   438     "Suc = (%a. nat (int a + 1))"
   408   val int_if = mk_meta_eq @{lemma "int (if P then n else m) = (if P then int n else int m)" by simp}
   439     "op + = (%a b. nat (int a + int b))"
       
   440     "op - = (%a b. nat (int a - int b))"
       
   441     "op * = (%a b. nat (int a * int b))"
       
   442     "op div = (%a b. nat (int a div int b))"
       
   443     "op mod = (%a b. nat (int a mod int b))"
       
   444     by (fastforce simp add: nat_mult_distrib nat_div_distrib nat_mod_distrib)+}
       
   445 
       
   446   val ints = map mk_meta_eq @{lemma
       
   447     "int 0 = 0"
       
   448     "int 1 = 1"
       
   449     "int (Suc n) = int n + 1"
       
   450     "int (n + m) = int n + int m"
       
   451     "int (n - m) = int (nat (int n - int m))"
       
   452     "int (n * m) = int n * int m"
       
   453     "int (n div m) = int n div int m"
       
   454     "int (n mod m) = int n mod int m"
       
   455     by (auto simp add: int_mult zdiv_int zmod_int)}
       
   456 
       
   457   val int_if = mk_meta_eq @{lemma
       
   458     "int (if P then n else m) = (if P then int n else int m)"
       
   459     by simp}
       
   460 
   409 
   461   fun mk_number_eq ctxt i lhs =
   410   fun mk_number_eq ctxt i lhs =
   462     let
   411     let
   463       val eq = SMT2_Util.mk_cequals lhs (Numeral.mk_cnumber @{ctyp int} i)
   412       val eq = SMT2_Util.mk_cequals lhs (Numeral.mk_cnumber @{ctyp int} i)
   464       val ctxt' = put_simpset HOL_ss ctxt addsimps @{thms Int.int_numeral}
   413       val ctxt' = put_simpset HOL_ss ctxt addsimps @{thms Int.int_numeral}