src/HOL/Tools/semiring_normalizer.ML
changeset 61153 3d5e01b427cb
parent 61075 f6b0d827240e
child 61694 6571c78c9667
equal deleted inserted replaced
61152:13b2fd801692 61153:3d5e01b427cb
     2     Author:     Amine Chaieb, TU Muenchen
     2     Author:     Amine Chaieb, TU Muenchen
     3 
     3 
     4 Normalization of expressions in semirings.
     4 Normalization of expressions in semirings.
     5 *)
     5 *)
     6 
     6 
     7 signature SEMIRING_NORMALIZER = 
     7 signature SEMIRING_NORMALIZER =
     8 sig
     8 sig
     9   type entry
     9   type entry
    10   val match: Proof.context -> cterm -> entry option
    10   val match: Proof.context -> cterm -> entry option
    11   val the_semiring: Proof.context -> thm -> cterm list * thm list
    11   val the_semiring: Proof.context -> thm -> cterm list * thm list
    12   val the_ring: Proof.context -> thm -> cterm list * thm list
    12   val the_ring: Proof.context -> thm -> cterm list * thm list
    13   val the_field: Proof.context -> thm -> cterm list * thm list
    13   val the_field: Proof.context -> thm -> cterm list * thm list
    14   val the_idom: Proof.context -> thm -> thm list
    14   val the_idom: Proof.context -> thm -> thm list
    15   val the_ideal: Proof.context -> thm -> thm list
    15   val the_ideal: Proof.context -> thm -> thm list
    16   val declare: thm -> {semiring: cterm list * thm list, ring: cterm list * thm list,
    16   val declare: thm -> {semiring: term list * thm list, ring: term list * thm list,
    17     field: cterm list * thm list, idom: thm list, ideal: thm list} -> declaration
    17     field: term list * thm list, idom: thm list, ideal: thm list} ->
       
    18     local_theory -> local_theory
    18 
    19 
    19   val semiring_normalize_conv: Proof.context -> conv
    20   val semiring_normalize_conv: Proof.context -> conv
    20   val semiring_normalize_ord_conv: Proof.context -> (cterm -> cterm -> bool) -> conv
    21   val semiring_normalize_ord_conv: Proof.context -> (cterm -> cterm -> bool) -> conv
    21   val semiring_normalize_wrapper: Proof.context -> entry -> conv
    22   val semiring_normalize_wrapper: Proof.context -> entry -> conv
    22   val semiring_normalize_ord_wrapper: Proof.context -> entry
    23   val semiring_normalize_ord_wrapper: Proof.context -> entry
    38        main: Proof.context -> conv,
    39        main: Proof.context -> conv,
    39        pow: Proof.context -> conv,
    40        pow: Proof.context -> conv,
    40        sub: Proof.context -> conv}
    41        sub: Proof.context -> conv}
    41 end
    42 end
    42 
    43 
    43 structure Semiring_Normalizer: SEMIRING_NORMALIZER = 
    44 structure Semiring_Normalizer: SEMIRING_NORMALIZER =
    44 struct
    45 struct
    45 
    46 
    46 (** data **)
    47 (** data **)
    47 
    48 
    48 type entry =
    49 type entry =
    (* the_ideal: the #ideal record field of whatever the_rules returns; `oo` composes the
       selector with the two-argument function the_rules (defined outside this excerpt). *)
    74 val the_ideal = #ideal oo the_rules
    75 val the_ideal = #ideal oo the_rules
    75 
    76 
    76 fun match ctxt tm =
    77 fun match ctxt tm =
    77   let
    78   let
    78     fun match_inst
    79     fun match_inst
    79         ({vars, semiring = (sr_ops, sr_rules), 
    80         ({vars, semiring = (sr_ops, sr_rules),
    80           ring = (r_ops, r_rules), field = (f_ops, f_rules), idom, ideal},
    81           ring = (r_ops, r_rules), field = (f_ops, f_rules), idom, ideal},
    81          fns) pat =
    82          fns) pat =
    82        let
    83        let
    83         fun h instT =
    84         fun h instT =
    84           let
    85           let
    90             val ring' = (map substT_cterm r_ops, map substT r_rules);
    91             val ring' = (map substT_cterm r_ops, map substT r_rules);
    91             val field' = (map substT_cterm f_ops, map substT f_rules);
    92             val field' = (map substT_cterm f_ops, map substT f_rules);
    92             val idom' = map substT idom;
    93             val idom' = map substT idom;
    93             val ideal' = map substT ideal;
    94             val ideal' = map substT ideal;
    94 
    95 
    95             val result = ({vars = vars', semiring = semiring', 
    96             val result = ({vars = vars', semiring = semiring',
    96                            ring = ring', field = field', idom = idom', ideal = ideal'}, fns);
    97                            ring = ring', field = field', idom = idom', ideal = ideal'}, fns);
    97           in SOME result end
    98           in SOME result end
    98       in (case try Thm.match (pat, tm) of
    99       in (case try Thm.match (pat, tm) of
    99            NONE => NONE
   100            NONE => NONE
   100          | SOME (instT, _) => h instT)
   101          | SOME (instT, _) => h instT)
   103     fun match_struct (_,
   104     fun match_struct (_,
   104         entry as ({semiring = (sr_ops, _), ring = (r_ops, _), field = (f_ops, _), ...}, _): entry) =
   105         entry as ({semiring = (sr_ops, _), ring = (r_ops, _), field = (f_ops, _), ...}, _): entry) =
   105       get_first (match_inst entry) (sr_ops @ r_ops @ f_ops);
   106       get_first (match_inst entry) (sr_ops @ r_ops @ f_ops);
   106   in get_first match_struct (Data.get (Context.Proof ctxt)) end;
   107   in get_first match_struct (Data.get (Context.Proof ctxt)) end;
   107 
   108 
   108   
   109 
   109 (* extra-logical functions *)
   110 (* extra-logical functions *)
   110 
   111 
   (* Simpset value: HOL_basic_ss installed into the static @{context}, extended with the
      semiring_norm theorems, then extracted with simpset_of. *)
   111 val semiring_norm_ss =
   112 val semiring_norm_ss =
   112   simpset_of (put_simpset HOL_basic_ss @{context} addsimps @{thms semiring_norm});
   113   simpset_of (put_simpset HOL_basic_ss @{context} addsimps @{thms semiring_norm});
   113 
   114 
   135      | Const (@{const_name Fields.inverse},_)$t => can HOLogic.dest_number t
   136      | Const (@{const_name Fields.inverse},_)$t => can HOLogic.dest_number t
   136      | t => can HOLogic.dest_number t
   137      | t => can HOLogic.dest_number t
   137     fun dest_const ct = ((case Thm.term_of ct of
   138     fun dest_const ct = ((case Thm.term_of ct of
   138        Const (@{const_name Rings.divide},_) $ a $ b=>
   139        Const (@{const_name Rings.divide},_) $ a $ b=>
   139         Rat.rat_of_quotient (snd (HOLogic.dest_number a), snd (HOLogic.dest_number b))
   140         Rat.rat_of_quotient (snd (HOLogic.dest_number a), snd (HOLogic.dest_number b))
   140      | Const (@{const_name Fields.inverse},_)$t => 
   141      | Const (@{const_name Fields.inverse},_)$t =>
   141                    Rat.inv (Rat.rat_of_int (snd (HOLogic.dest_number t)))
   142                    Rat.inv (Rat.rat_of_int (snd (HOLogic.dest_number t)))
   142      | t => Rat.rat_of_int (snd (HOLogic.dest_number t))) 
   143      | t => Rat.rat_of_int (snd (HOLogic.dest_number t)))
   143        handle TERM _ => error "ring_dest_const")
   144        handle TERM _ => error "ring_dest_const")
   144     fun mk_const cT x =
   145     fun mk_const cT x =
   145       let val (a, b) = Rat.quotient_of_rat x
   146       let val (a, b) = Rat.quotient_of_rat x
   146       in if b = 1 then Numeral.mk_cnumber cT a
   147       in if b = 1 then Numeral.mk_cnumber cT a
   147         else Thm.apply
   148         else Thm.apply
   (* Kind names as strings; declare passes them to check_ops/check_rules, where they end up
      in the "Expected ... for <name>" error message. *)
   163 val ringN = "ring";
   164 val ringN = "ring";
   164 val fieldN = "field";
   165 val fieldN = "field";
   165 val idomN = "idom";
   166 val idomN = "idom";
   166 
   167 
   (* declare: register a normalizer entry in Data under a key theorem.

      Both revisions of this function are interleaved below (old line, then new line).

      Old revision: `declare raw_key {...} phi context`.  The morphism phi is applied to the
      key, to the operation cterms and to all rule lists.  The list lengths are then checked
      (each list may be empty, otherwise: semiring 5 operations / 36 rules, ring 2 / 2,
      field 2 / 2, idom 2 rules), the rules are turned into meta-level rewrite rules with
      Local_Defs.meta_rewrite_rule, the variable cterms (vars) are picked out of the first
      semiring operation and semiring rules 1, 13, 20, 26, 34 and 36, and the resulting entry
      is stored with AList.update (keys compared by Thm.eq_thm).  The ideal rules are
      additionally reversed with Thm.symmetric.  field_funs is chosen when field operations
      are present, semiring_funs otherwise.

      New revision: `declare raw_key {...} lthy`.  The operations arrive as plain terms.
      Variable.auto_fixes is folded over all of them to obtain ctxt'; prepare_ops exports
      them from ctxt' back to lthy with Variable.export_terms and certifies them with
      Thm.cterm_of.  The processing described above is unchanged but now runs inside
      Local_Theory.declaration {syntax = false, pervasive = false}, which supplies phi and
      context. *)
   167 fun declare raw_key
   168 fun declare raw_key
   168     {semiring = raw_semiring, ring = raw_ring, field = raw_field, idom = raw_idom, ideal = raw_ideal}
   169     {semiring = raw_semiring0, ring = raw_ring0, field = raw_field0, idom = raw_idom, ideal = raw_ideal}
   169     phi context =
   170     lthy =
   170   let
   171   let
   171     val ctxt = Context.proof_of context;
   172     val ctxt' = fold Variable.auto_fixes (fst raw_semiring0 @ fst raw_ring0 @ fst raw_field0) lthy;
   172     val key = Morphism.thm phi raw_key;
   173     val prepare_ops = apfst (Variable.export_terms ctxt' lthy #> map (Thm.cterm_of lthy));
   173     fun morphism_ops_rules (ops, rules) = (map (Morphism.cterm phi) ops, Morphism.fact phi rules);
   174     val raw_semiring = prepare_ops raw_semiring0;
   174     val (sr_ops, sr_rules) = morphism_ops_rules raw_semiring;
   175     val raw_ring = prepare_ops raw_ring0;
   175     val (r_ops, r_rules) = morphism_ops_rules raw_ring;
   176     val raw_field = prepare_ops raw_field0;
   176     val (f_ops, f_rules) = morphism_ops_rules raw_field;
       
   177     val idom = Morphism.fact phi raw_idom;
       
   178     val ideal = Morphism.fact phi raw_ideal;
       
   179 
       
   180     fun check kind name xs n =
       
   181       null xs orelse length xs = n orelse
       
   182       error ("Expected " ^ string_of_int n ^ " " ^ kind ^ " for " ^ name);
       
   183     val check_ops = check "operations";
       
   184     val check_rules = check "rules";
       
   185     val _ =
       
   186       check_ops semiringN sr_ops 5 andalso
       
   187       check_rules semiringN sr_rules 36 andalso
       
   188       check_ops ringN r_ops 2 andalso
       
   189       check_rules ringN r_rules 2 andalso
       
   190       check_ops fieldN f_ops 2 andalso
       
   191       check_rules fieldN f_rules 2 andalso
       
   192       check_rules idomN idom 2;
       
   193 
       
   194     val mk_meta = Local_Defs.meta_rewrite_rule ctxt;
       
   195     val sr_rules' = map mk_meta sr_rules;
       
   196     val r_rules' = map mk_meta r_rules;
       
   197     val f_rules' = map mk_meta f_rules;
       
   198 
       
   199     fun rule i = nth sr_rules' (i - 1);
       
   200     
       
   201     val (cx, cy) = Thm.dest_binop (hd sr_ops);
       
   202     val cz = rule 34 |> Thm.rhs_of |> Thm.dest_arg |> Thm.dest_arg;
       
   203     val cn = rule 36 |> Thm.rhs_of |> Thm.dest_arg |> Thm.dest_arg;
       
   204     val ((clx, crx), (cly, cry)) =
       
   205       rule 13 |> Thm.rhs_of |> Thm.dest_binop |> apply2 Thm.dest_binop;
       
   206     val ((ca, cb), (cc, cd)) =
       
   207       rule 20 |> Thm.lhs_of |> Thm.dest_binop |> apply2 Thm.dest_binop;
       
   208     val cm = rule 1 |> Thm.rhs_of |> Thm.dest_arg;
       
   209     val (cp, cq) = rule 26 |> Thm.lhs_of |> Thm.dest_binop |> apply2 Thm.dest_arg;
       
   210 
       
   211     val vars = [ca, cb, cc, cd, cm, cn, cp, cq, cx, cy, cz, clx, crx, cly, cry];
       
   212 
       
   213     val semiring = (sr_ops, sr_rules');
       
   214     val ring = (r_ops, r_rules');
       
   215     val field = (f_ops, f_rules');
       
   216     val ideal' = map (Thm.symmetric o mk_meta) ideal
       
   217 
       
   218   in
   177   in
   219     context
   178     lthy |> Local_Theory.declaration {syntax = false, pervasive = false} (fn phi => fn context =>
   220     |> Data.map (AList.update Thm.eq_thm (key,
   179       let
   221         ({vars = vars, semiring = semiring, ring = ring, field = field, idom = idom, ideal = ideal'},
   180         val ctxt = Context.proof_of context;
   222           (if null f_ops then semiring_funs else field_funs))))
   181         val key = Morphism.thm phi raw_key;
   223   end
   182         fun transform_ops_rules (ops, rules) =
       
   183           (map (Morphism.cterm phi) ops, Morphism.fact phi rules);
       
   184         val (sr_ops, sr_rules) = transform_ops_rules raw_semiring;
       
   185         val (r_ops, r_rules) = transform_ops_rules raw_ring;
       
   186         val (f_ops, f_rules) = transform_ops_rules raw_field;
       
   187         val idom = Morphism.fact phi raw_idom;
       
   188         val ideal = Morphism.fact phi raw_ideal;
       
   189 
       
   190         fun check kind name xs n =
       
   191           null xs orelse length xs = n orelse
       
   192           error ("Expected " ^ string_of_int n ^ " " ^ kind ^ " for " ^ name);
       
   193         val check_ops = check "operations";
       
   194         val check_rules = check "rules";
       
   195         val _ =
       
   196           check_ops semiringN sr_ops 5 andalso
       
   197           check_rules semiringN sr_rules 36 andalso
       
   198           check_ops ringN r_ops 2 andalso
       
   199           check_rules ringN r_rules 2 andalso
       
   200           check_ops fieldN f_ops 2 andalso
       
   201           check_rules fieldN f_rules 2 andalso
       
   202           check_rules idomN idom 2;
       
   203 
       
   204         val mk_meta = Local_Defs.meta_rewrite_rule ctxt;
       
   205         val sr_rules' = map mk_meta sr_rules;
       
   206         val r_rules' = map mk_meta r_rules;
       
   207         val f_rules' = map mk_meta f_rules;
       
   208 
       
   209         fun rule i = nth sr_rules' (i - 1);
       
   210 
       
   211         val (cx, cy) = Thm.dest_binop (hd sr_ops);
       
   212         val cz = rule 34 |> Thm.rhs_of |> Thm.dest_arg |> Thm.dest_arg;
       
   213         val cn = rule 36 |> Thm.rhs_of |> Thm.dest_arg |> Thm.dest_arg;
       
   214         val ((clx, crx), (cly, cry)) =
       
   215           rule 13 |> Thm.rhs_of |> Thm.dest_binop |> apply2 Thm.dest_binop;
       
   216         val ((ca, cb), (cc, cd)) =
       
   217           rule 20 |> Thm.lhs_of |> Thm.dest_binop |> apply2 Thm.dest_binop;
       
   218         val cm = rule 1 |> Thm.rhs_of |> Thm.dest_arg;
       
   219         val (cp, cq) = rule 26 |> Thm.lhs_of |> Thm.dest_binop |> apply2 Thm.dest_arg;
       
   220 
       
   221         val vars = [ca, cb, cc, cd, cm, cn, cp, cq, cx, cy, cz, clx, crx, cly, cry];
       
   222 
       
   223         val semiring = (sr_ops, sr_rules');
       
   224         val ring = (r_ops, r_rules');
       
   225         val field = (f_ops, f_rules');
       
   226         val ideal' = map (Thm.symmetric o mk_meta) ideal
       
   227       in
       
   228         context
       
   229         |> Data.map (AList.update Thm.eq_thm (key,
       
   230             ({vars = vars, semiring = semiring, ring = ring, field = field, idom = idom, ideal = ideal'},
       
   231               (if null f_ops then semiring_funs else field_funs))))
       
   232       end)
       
   233   end;
   224 
   234 
   225 
   235 
   226 (** auxiliary **)
   236 (** auxiliary **)
   227 
   237 
   228 fun is_comb ct =
   238 fun is_comb ct =
   253   Simplifier.rewrite (put_simpset HOL_basic_ss ctxt addsimps [@{thm numeral_1_eq_1} RS sym]);
   263   Simplifier.rewrite (put_simpset HOL_basic_ss ctxt addsimps [@{thm numeral_1_eq_1} RS sym]);
   254 
   264 
   (* zerone_conv ctxt cv: run zero1_numeral_conv ctxt first, then the given conversion cv,
      then numeral01_conv ctxt, chained with then_conv. *)
   255 fun zerone_conv ctxt cv =
   265 fun zerone_conv ctxt cv =
   256   zero1_numeral_conv ctxt then_conv cv then_conv numeral01_conv ctxt;
   266   zero1_numeral_conv ctxt then_conv cv then_conv numeral01_conv ctxt;
   257 
   267 
   (* Simpset value: HOL_basic_ss in the static @{context} plus arith_simps, diff_nat_numeral,
      rel_simps, the individually listed theorems, and every theorem of @{thms numerals}
      reversed via `RS sym`. *)
   258 val nat_add_ss = simpset_of 
   268 val nat_add_ss = simpset_of
   259   (put_simpset HOL_basic_ss @{context}
   269   (put_simpset HOL_basic_ss @{context}
   260      addsimps @{thms arith_simps} @ @{thms diff_nat_numeral} @ @{thms rel_simps}
   270      addsimps @{thms arith_simps} @ @{thms diff_nat_numeral} @ @{thms rel_simps}
   261        @ @{thms if_False if_True Nat.add_0 add_Suc add_numeral_left Suc_eq_plus1}
   271        @ @{thms if_False if_True Nat.add_0 add_Suc add_numeral_left Suc_eq_plus1}
   262        @ map (fn th => th RS sym) @{thms numerals});
   272        @ map (fn th => th RS sym) @{thms numerals});
   263 
   273 
   306       in (neg_mul, sub_add, sub_tm, neg_tm, dest_sub, neg_mul |> concl |> Thm.dest_arg,
   316       in (neg_mul, sub_add, sub_tm, neg_tm, dest_sub, neg_mul |> concl |> Thm.dest_arg,
   307           sub_add |> concl |> Thm.dest_arg |> Thm.dest_arg)
   317           sub_add |> concl |> Thm.dest_arg |> Thm.dest_arg)
   308       end
   318       end
   309     | _ => (TrueI, TrueI, true_tm, true_tm, (fn t => (t,t)), true_tm, true_tm));
   319     | _ => (TrueI, TrueI, true_tm, true_tm, (fn t => (t,t)), true_tm, true_tm));
   310 
   320 
   (* Field data.  With exactly two field operations and two field rules: divide_inverse is
      the first rule, divide_tm is the divide pattern with two arguments stripped off
      (funpow 2 Thm.dest_fun), inverse_tm is the inverse pattern with one argument stripped.
      In every other case the placeholders TrueI / true_tm are used instead. *)
   311 val (divide_inverse, divide_tm, inverse_tm) = 
   321 val (divide_inverse, divide_tm, inverse_tm) =
   312   (case (f_ops, f_rules) of 
   322   (case (f_ops, f_rules) of
   313    ([divide_pat, inverse_pat], [div_inv, _]) => 
   323    ([divide_pat, inverse_pat], [div_inv, _]) =>
   314      let val div_tm = funpow 2 Thm.dest_fun divide_pat
   324      let val div_tm = funpow 2 Thm.dest_fun divide_pat
   315          val inv_tm = Thm.dest_fun inverse_pat
   325          val inv_tm = Thm.dest_fun inverse_pat
   316      in (div_inv, div_tm, inv_tm)
   326      in (div_inv, div_tm, inv_tm)
   317      end
   327      end
   318    | _ => (TrueI, true_tm, true_tm));
   328    | _ => (TrueI, true_tm, true_tm));
   (* powvar tm: one_tm when tm is a semiring constant; the base l when tm has the shape
      `pow_tm l r` with r a number; otherwise tm itself.  The fall-through works by raising
      CTERM (explicitly, or from Thm.dest_comb on a non-application) and catching it. *)
   418    fun powvar tm =
   428    fun powvar tm =
   419     if is_semiring_constant tm then one_tm
   429     if is_semiring_constant tm then one_tm
   420     else
   430     else
   421      ((let val (lopr,r) = Thm.dest_comb tm
   431      ((let val (lopr,r) = Thm.dest_comb tm
   422            val (opr,l) = Thm.dest_comb lopr
   432            val (opr,l) = Thm.dest_comb lopr
   423        in if opr aconvc pow_tm andalso is_number r then l 
   433        in if opr aconvc pow_tm andalso is_number r then l
   424           else raise CTERM ("monomial_mul_conv",[tm]) end)
   434           else raise CTERM ("monomial_mul_conv",[tm]) end)
   425      handle CTERM _ => tm)   (* FIXME !? *)
   435      handle CTERM _ => tm)   (* FIXME !? *)
   426    fun  vorder x y =
   436    fun  vorder x y =
   427     if x aconvc y then 0
   437     if x aconvc y then 0
   428     else
   438     else
   801          in if opr aconvc pow_tm andalso is_number r
   811          in if opr aconvc pow_tm andalso is_number r
   802             then
   812             then
   803               let val th1 = Drule.fun_cong_rule (Drule.arg_cong_rule opr (polynomial_conv ctxt l)) r
   813               let val th1 = Drule.fun_cong_rule (Drule.arg_cong_rule opr (polynomial_conv ctxt l)) r
   804               in Thm.transitive th1 (polynomial_pow_conv ctxt (concl th1))
   814               in Thm.transitive th1 (polynomial_pow_conv ctxt (concl th1))
   805               end
   815               end
   806          else if opr aconvc divide_tm 
   816          else if opr aconvc divide_tm
   807             then
   817             then
   808               let val th1 = Thm.combination (Drule.arg_cong_rule opr (polynomial_conv ctxt l)) 
   818               let val th1 = Thm.combination (Drule.arg_cong_rule opr (polynomial_conv ctxt l))
   809                                         (polynomial_conv ctxt r)
   819                                         (polynomial_conv ctxt r)
   810                   val th2 = (Conv.rewr_conv divide_inverse then_conv polynomial_mul_conv ctxt)
   820                   val th2 = (Conv.rewr_conv divide_inverse then_conv polynomial_mul_conv ctxt)
   811                               (Thm.rhs_of th1)
   821                               (Thm.rhs_of th1)
   812               in Thm.transitive th1 th2
   822               in Thm.transitive th1 th2
   813               end
   823               end
   (* simple_cterm_ord t u: true exactly when the underlying term of t compares as LESS than
      that of u under Term_Ord.term_ord. *)
   844 fun simple_cterm_ord t u = Term_Ord.term_ord (Thm.term_of t, Thm.term_of u) = LESS;
   854 fun simple_cterm_ord t u = Term_Ord.term_ord (Thm.term_of t, Thm.term_of u) = LESS;
   845 
   855 
   846 
   856 
   847 (* various normalizing conversions *)
   857 (* various normalizing conversions *)
   848 
   858 
   849 fun semiring_normalizers_ord_wrapper ctxt ({vars, semiring, ring, field, idom, ideal}, 
   859 fun semiring_normalizers_ord_wrapper ctxt ({vars, semiring, ring, field, idom, ideal},
   850                                      {conv, dest_const, mk_const, is_const}) ord =
   860                                      {conv, dest_const, mk_const, is_const}) ord =
   851   let
   861   let
   852     val pow_conv =
   862     val pow_conv =
   853       Conv.arg_conv (Simplifier.rewrite (put_simpset nat_exp_ss ctxt))
   863       Conv.arg_conv (Simplifier.rewrite (put_simpset nat_exp_ss ctxt))
   854       then_conv Simplifier.rewrite
   864       then_conv Simplifier.rewrite
   (* semiring_normalize_ord_wrapper ctxt entry ord: take the entry apart, rebuild it field by
      field, pass it with ord to semiring_normalizers_ord_wrapper, select the #main component
      of the result and apply it to ctxt. *)
   860 fun semiring_normalize_ord_wrapper ctxt ({vars, semiring, ring, field, idom, ideal}, {conv, dest_const, mk_const, is_const}) ord =
   870 fun semiring_normalize_ord_wrapper ctxt ({vars, semiring, ring, field, idom, ideal}, {conv, dest_const, mk_const, is_const}) ord =
   861  #main (semiring_normalizers_ord_wrapper ctxt
   871  #main (semiring_normalizers_ord_wrapper ctxt
   862   ({vars = vars, semiring = semiring, ring = ring, field = field, idom = idom, ideal = ideal},
   872   ({vars = vars, semiring = semiring, ring = ring, field = field, idom = idom, ideal = ideal},
   863    {conv = conv, dest_const = dest_const, mk_const = mk_const, is_const = is_const}) ord) ctxt;
   873    {conv = conv, dest_const = dest_const, mk_const = mk_const, is_const = is_const}) ord) ctxt;
   864 
   874 
   (* semiring_normalize_wrapper ctxt data: semiring_normalize_ord_wrapper with the ordering
      fixed to simple_cterm_ord. *)
   865 fun semiring_normalize_wrapper ctxt data = 
   875 fun semiring_normalize_wrapper ctxt data =
   866   semiring_normalize_ord_wrapper ctxt data simple_cterm_ord;
   876   semiring_normalize_ord_wrapper ctxt data simple_cterm_ord;
   867 
   877 
   (* semiring_normalize_ord_conv ctxt ord tm: look up an entry for tm with `match`.  Without
      a matching entry the term is left unchanged (Thm.reflexive tm); otherwise it is
      normalized by semiring_normalize_ord_wrapper with that entry and the ordering ord. *)
   868 fun semiring_normalize_ord_conv ctxt ord tm =
   878 fun semiring_normalize_ord_conv ctxt ord tm =
   869   (case match ctxt tm of
   879   (case match ctxt tm of
   870     NONE => Thm.reflexive tm
   880     NONE => Thm.reflexive tm
   871   | SOME res => semiring_normalize_ord_wrapper ctxt res ord tm);
   881   | SOME res => semiring_normalize_ord_wrapper ctxt res ord tm);
   872  
   882 
   (* semiring_normalize_conv ctxt: semiring_normalize_ord_conv with the ordering fixed to
      simple_cterm_ord. *)
   873 fun semiring_normalize_conv ctxt = semiring_normalize_ord_conv ctxt simple_cterm_ord;
   883 fun semiring_normalize_conv ctxt = semiring_normalize_ord_conv ctxt simple_cterm_ord;
   874 
   884 
   875 end;
   885 end;