explicit check for field sort, to anticipate situation where syntactic checking alone will not be sufficient any longer
authorhaftmann
Mon, 01 Jun 2015 18:59:21 +0200
changeset 60351 5cdf3903a302
parent 60350 9251f82337d6
child 60352 d46de31a50c4
explicit check for field sort, to anticipate situation where syntactic checking alone will not be sufficient any longer
src/HOL/Tools/lin_arith.ML
--- a/src/HOL/Tools/lin_arith.ML	Mon Jun 01 18:59:21 2015 +0200
+++ b/src/HOL/Tools/lin_arith.ML	Mon Jun 01 18:59:21 2015 +0200
@@ -136,7 +136,9 @@
 
    returns either (SOME term, associated multiplicity) or (NONE, constant)
 *)
-fun demult (inj_consts : (string * typ) list) : term * Rat.rat -> term option * Rat.rat =
+fun of_field_sort thy U = Sign.of_sort thy (U, @{sort inverse});
+
+fun demult thy (inj_consts : (string * typ) list) : term * Rat.rat -> term option * Rat.rat =
 let
   fun demult ((mC as Const (@{const_name Groups.times}, _)) $ s $ t, m) =
       (case s of Const (@{const_name Groups.times}, _) $ s1 $ s2 =>
@@ -150,23 +152,26 @@
               (SOME t', m'') => (SOME (mC $ s' $ t'), m'')
             | (NONE,    m'') => (SOME s', m''))
         | (NONE,    m') => demult (t, m')))
-    | demult ((mC as Const (@{const_name Fields.divide}, _)) $ s $ t, m) =
+    | demult (atom as (mC as Const (@{const_name Fields.divide}, T)) $ s $ t, m) =
       (* FIXME: Shouldn't we simplify nested quotients, e.g. '(s/t)/u' could
          become 's/(t*u)', and '(s*t)/u' could become 's*(t/u)' ?   Note that
          if we choose to do so here, the simpset used by arith must be able to
          perform the same simplifications. *)
       (* quotient 's / t', where the denominator t can be NONE *)
       (* Note: will raise Rat.DIVZERO iff m' is Rat.zero *)
-      let val (os',m') = demult (s, m);
+      if of_field_sort thy (domain_type T) then
+        let
+          val (os',m') = demult (s, m);
           val (ot',p) = demult (t, Rat.one)
-      in (case (os',ot') of
+        in (case (os',ot') of
             (SOME s', SOME t') => SOME (mC $ s' $ t')
           | (SOME s', NONE) => SOME s'
           | (NONE, SOME t') =>
                SOME (mC $ Const (@{const_name Groups.one}, domain_type (snd (dest_Const mC))) $ t')
           | (NONE, NONE) => NONE,
           Rat.mult m' (Rat.inv p))
-      end
+        end
+      else (SOME atom, m)
     (* terms that evaluate to numeric constants *)
     | demult (Const (@{const_name Groups.uminus}, _) $ t, m) = demult (t, Rat.neg m)
     | demult (Const (@{const_name Groups.zero}, _), _) = (NONE, Rat.zero)
@@ -188,7 +193,7 @@
     | demult (atom, m) = (SOME atom, m)
 in demult end;
 
-fun decomp0 (inj_consts : (string * typ) list) (rel : string, lhs : term, rhs : term) :
+fun decomp0 thy (inj_consts : (string * typ) list) (rel : string, lhs : term, rhs : term) :
             ((term * Rat.rat) list * Rat.rat * string * (term * Rat.rat) list * Rat.rat) option =
 let
   (* Turns a term 'all' and associated multiplicity 'm' into a list 'p' of
@@ -211,13 +216,15 @@
     | poly (Const (@{const_name Suc}, _) $ t, m, (p, i)) =
         poly (t, m, (p, Rat.add i m))
     | poly (all as Const (@{const_name Groups.times}, _) $ _ $ _, m, pi as (p, i)) =
-        (case demult inj_consts (all, m) of
+        (case demult thy inj_consts (all, m) of
            (NONE,   m') => (p, Rat.add i m')
          | (SOME u, m') => add_atom u m' pi)
-    | poly (all as Const (@{const_name Fields.divide}, _) $ _ $ _, m, pi as (p, i)) =
-        (case demult inj_consts (all, m) of
-           (NONE,   m') => (p, Rat.add i m')
-         | (SOME u, m') => add_atom u m' pi)
+    | poly (all as Const (@{const_name Fields.divide}, T) $ _ $ _, m, pi as (p, i)) =
+        if of_field_sort thy (domain_type T) then 
+          (case demult thy inj_consts (all, m) of
+             (NONE,   m') => (p, Rat.add i m')
+           | (SOME u, m') => add_atom u m' pi)
+        else add_atom all m pi
     | poly (all as Const f $ x, m, pi) =
         if member (op =) inj_consts f then poly (x, m, pi) else add_atom all m pi
     | poly (all, m, pi) =
@@ -240,12 +247,12 @@
       else if member (op =) discrete D then (true, true) else (false, false)
   | allows_lin_arith sg discrete U = (of_lin_arith_sort sg U, false);
 
-fun decomp_typecheck (thy, discrete, inj_consts) (T : typ, xxx) : decomp option =
+fun decomp_typecheck thy (discrete, inj_consts) (T : typ, xxx) : decomp option =
   case T of
     Type ("fun", [U, _]) =>
       (case allows_lin_arith thy discrete U of
         (true, d) =>
-          (case decomp0 inj_consts xxx of
+          (case decomp0 thy inj_consts xxx of
             NONE                   => NONE
           | SOME (p, i, rel, q, j) => SOME (p, i, rel, q, j, d))
       | (false, _) =>
@@ -255,20 +262,20 @@
 fun negate (SOME (x, i, rel, y, j, d)) = SOME (x, i, "~" ^ rel, y, j, d)
   | negate NONE                        = NONE;
 
-fun decomp_negation data
-  ((Const (@{const_name Trueprop}, _)) $ (Const (rel, T) $ lhs $ rhs)) : decomp option =
-      decomp_typecheck data (T, (rel, lhs, rhs))
-  | decomp_negation data ((Const (@{const_name Trueprop}, _)) $
-  (Const (@{const_name Not}, _) $ (Const (rel, T) $ lhs $ rhs))) =
-      negate (decomp_typecheck data (T, (rel, lhs, rhs)))
-  | decomp_negation data _ =
+fun decomp_negation thy data
+      ((Const (@{const_name Trueprop}, _)) $ (Const (rel, T) $ lhs $ rhs)) : decomp option =
+      decomp_typecheck thy data (T, (rel, lhs, rhs))
+  | decomp_negation thy data
+      ((Const (@{const_name Trueprop}, _)) $ (Const (@{const_name Not}, _) $ (Const (rel, T) $ lhs $ rhs))) =
+      negate (decomp_typecheck thy data (T, (rel, lhs, rhs)))
+  | decomp_negation thy data _ =
       NONE;
 
 fun decomp ctxt : term -> decomp option =
   let
     val thy = Proof_Context.theory_of ctxt
     val {discrete, inj_consts, ...} = get_arith_data ctxt
-  in decomp_negation (thy, discrete, inj_consts) end;
+  in decomp_negation thy (discrete, inj_consts) end;
 
 fun domain_is_nat (_ $ (Const (_, T) $ _ $ _)) = nT T
   | domain_is_nat (_ $ (Const (@{const_name Not}, _) $ (Const (_, T) $ _ $ _))) = nT T