Theory Reflective_Field

theory Reflective_Field
imports Commutative_Ring
(*  Title:      HOL/Decision_Procs/Reflective_Field.thy
    Author:     Stefan Berghofer

Reducing equalities in fields to equalities in rings.
*)

theory Reflective_Field
imports "~~/src/HOL/Decision_Procs/Commutative_Ring"
begin

(* Reflected syntax of field expressions: integer constants, variables
   (FVar n stands for the n-th element of the environment list, see
   feval), the ring operations, negation, division, and powers with a
   natural-number exponent. *)
datatype fexpr =
    FCnst int
  | FVar nat
  | FAdd fexpr fexpr
  | FSub fexpr fexpr
  | FMul fexpr fexpr
  | FNeg fexpr
  | FDiv fexpr fexpr
  | FPow fexpr nat

(* Environment lookup: nth_el xs n is the n-th (0-based) element of xs.
   An index past the end of the list evaluates to the zero of the field
   instead of being left unspecified. *)
fun (in field) nth_el :: "'a list ⇒ nat ⇒ 'a" where
  "nth_el [] n = 𝟬"
| "nth_el (x # xs) 0 = x"
| "nth_el (x # xs) (Suc n) = nth_el xs n"

(* Single-equation form of nth_el on a Cons list. Field_Tac.nth_el_conv
   instantiates this lemma to walk through the environment one element
   at a time, deciding n = 0 at each step. *)
lemma (in field) nth_el_Cons:
  "nth_el (x # xs) n = (if n = 0 then x else nth_el xs (n - 1))"
  by (cases n) simp_all

(* Environment lookups land in the carrier whenever in_carrier xs holds
   (in_carrier comes from Commutative_Ring). This includes the
   out-of-range case, which evaluates to zero. *)
lemma (in field) nth_el_closed [simp]:
  "in_carrier xs ⟹ nth_el xs n ∈ carrier R"
  by (induct xs n rule: nth_el.induct) (simp_all add: in_carrier_def)

(* Denotation of an fexpr in the field R under the environment xs.
   The order of these equations matters: Field_Tac.feval_conv
   destructures feval.simps @ feval_Cnst positionally (as registered
   via Field_Simps at the end of this theory). *)
primrec (in field) feval :: "'a list ⇒ fexpr ⇒ 'a"
where
  "feval xs (FCnst c) = «c»"
| "feval xs (FVar n) = nth_el xs n"
| "feval xs (FAdd a b) = feval xs a ⊕ feval xs b"
| "feval xs (FSub a b) = feval xs a ⊖ feval xs b"
| "feval xs (FMul a b) = feval xs a ⊗ feval xs b"
| "feval xs (FNeg a) = ⊖ feval xs a"
| "feval xs (FDiv a b) = feval xs a ⊘ feval xs b"
| "feval xs (FPow a n) = feval xs a (^) n"

(* Special cases for literal constants: FCnst 0, FCnst 1 and
   FCnst (numeral n) evaluate directly to zero, one and «numeral n».
   feval_conv uses these as equations 9-11, after the 8 feval.simps. *)
lemma (in field) feval_Cnst:
  "feval xs (FCnst 0) = 𝟬"
  "feval xs (FCnst 1) = 𝟭"
  "feval xs (FCnst (numeral n)) = «numeral n»"
  by simp_all

(* Polynomial (division-free) expressions, split into mutually recursive
   types:
   - pexpr1 holds constants, variables, sums, differences and negations;
   - pexpr2 holds products and powers.
   This lets isin and split_aux treat every non-multiplicative
   expression uniformly through a single PExpr1 pattern. *)
datatype pexpr =
    PExpr1 pexpr1
  | PExpr2 pexpr2
and pexpr1 =
    PCnst int
  | PVar nat
  | PAdd pexpr pexpr
  | PSub pexpr pexpr
  | PNeg pexpr
and pexpr2 =
    PMul pexpr pexpr
  | PPow pexpr nat

(* Flat case distinction on pexpr that hides the PExpr1/PExpr2 nesting,
   with one named case per constructor of pexpr1 and pexpr2. *)
lemma pexpr_cases [case_names PCnst PVar PAdd PSub PNeg PMul PPow]:
  assumes
    "⋀c. e = PExpr1 (PCnst c) ⟹ P"
    "⋀n. e = PExpr1 (PVar n) ⟹ P"
    "⋀e1 e2. e = PExpr1 (PAdd e1 e2) ⟹ P"
    "⋀e1 e2. e = PExpr1 (PSub e1 e2) ⟹ P"
    "⋀e'. e = PExpr1 (PNeg e') ⟹ P"
    "⋀e1 e2. e = PExpr2 (PMul e1 e2) ⟹ P"
    "⋀e' n. e = PExpr2 (PPow e' n) ⟹ P"
  shows P
proof (cases e)
  case (PExpr1 e')
  then show ?thesis
    apply (cases e')
    apply simp_all
    apply (erule assms)+
    done
next
  case (PExpr2 e')
  then show ?thesis
    apply (cases e')
    apply simp_all
    apply (erule assms)+
    done
qed

(* Case distinction on two pexprs at once, covering all 7 x 7
   combinations with names such as PPow_PPow. It is used by the
   correctness proofs of the smart constructors. *)
lemmas pexpr_cases2 = pexpr_cases [case_product pexpr_cases]

(* Denotation of a pexpr in the field R under the environment xs.
   Equation order matters: Field_Tac.peval_conv destructures
   peval.simps @ peval_Cnst positionally (equations 1-7 here). *)
fun (in field) peval :: "'a list ⇒ pexpr ⇒ 'a"
where
  "peval xs (PExpr1 (PCnst c)) = «c»"
| "peval xs (PExpr1 (PVar n)) = nth_el xs n"
| "peval xs (PExpr1 (PAdd a b)) = peval xs a ⊕ peval xs b"
| "peval xs (PExpr1 (PSub a b)) = peval xs a ⊖ peval xs b"
| "peval xs (PExpr1 (PNeg a)) = ⊖ peval xs a"
| "peval xs (PExpr2 (PMul a b)) = peval xs a ⊗ peval xs b"
| "peval xs (PExpr2 (PPow a n)) = peval xs a (^) n"

(* Special cases for literal constants, including negative numerals.
   peval_conv uses these as equations 8-11, dispatching on the result
   of strip_numeral. *)
lemma (in field) peval_Cnst:
  "peval xs (PExpr1 (PCnst 0)) = 𝟬"
  "peval xs (PExpr1 (PCnst 1)) = 𝟭"
  "peval xs (PExpr1 (PCnst (numeral n))) = «numeral n»"
  "peval xs (PExpr1 (PCnst (- numeral n))) = ⊖ «numeral n»"
  by simp_all

(* Evaluation of any pexpr stays in the carrier. The PExpr1/PExpr2
   variants are stated as well because the mutual induction over
   pexpr/pexpr1/pexpr2 needs them. *)
lemma (in field) peval_closed [simp]:
  "in_carrier xs ⟹ peval xs e ∈ carrier R"
  "in_carrier xs ⟹ peval xs (PExpr1 e1) ∈ carrier R"
  "in_carrier xs ⟹ peval xs (PExpr2 e2) ∈ carrier R"
  by (induct e and e1 and e2) simp_all

(* Smart constructor for powers:
   - e^0 becomes the constant 1;
   - e^1 is e;
   - a power of an integer constant is folded into a constant;
   - anything else is kept as PPow e n. *)
definition npepow :: "pexpr ⇒ nat ⇒ pexpr"
where
  "npepow e n =
     (if n = 0 then PExpr1 (PCnst 1)
      else if n = 1 then e
      else (case e of
          PExpr1 (PCnst c) ⇒ PExpr1 (PCnst (c ^ n))
        | _ ⇒ PExpr2 (PPow e n)))"

(* npepow evaluates like the plain PPow it replaces. *)
lemma (in field) npepow_correct:
  "in_carrier xs ⟹ peval xs (npepow e n) = peval xs (PExpr2 (PPow e n))"
  by (cases e rule: pexpr_cases)
    (simp_all add: npepow_def)

(* Smart constructor for products:
   - a constant 0 on either side absorbs the product;
   - a constant 1 is dropped;
   - two constants are multiplied out;
   - x^n * y^n is rewritten to (x * y)^n via a recursive call, which is
     why this is a "fun" rather than a "definition".
   All other combinations produce a plain PMul. *)
fun npemul :: "pexpr ⇒ pexpr ⇒ pexpr"
where
  "npemul x y = (case x of
       PExpr1 (PCnst c) ⇒
         if c = 0 then x
         else if c = 1 then y else
           (case y of
              PExpr1 (PCnst d) ⇒ PExpr1 (PCnst (c * d))
            | _ ⇒ PExpr2 (PMul x y))
     | PExpr2 (PPow e1 n) ⇒
         (case y of
            PExpr2 (PPow e2 m) ⇒
              if n = m then npepow (npemul e1 e2) n
              else PExpr2 (PMul x y)
          | PExpr1 (PCnst d) ⇒
              if d = 0 then y
              else if d = 1 then x
              else PExpr2 (PMul x y)
          | _ ⇒ PExpr2 (PMul x y))
     | _ ⇒ (case y of
         PExpr1 (PCnst d) ⇒
           if d = 0 then y
           else if d = 1 then x
           else PExpr2 (PMul x y)
       | _ ⇒ PExpr2 (PMul x y)))"

(* npemul evaluates like the plain PMul it replaces. Only the PPow/PPow
   case needs the recursive equation (unfolded once, for exactly those
   arguments) and nat_pow_distr; all other cases are closed by simp. *)
lemma (in field) npemul_correct:
  "in_carrier xs ⟹ peval xs (npemul e1 e2) = peval xs (PExpr2 (PMul e1 e2))"
proof (induct e1 e2 rule: npemul.induct)
  case (1 x y)
  then show ?case
  proof (cases x y rule: pexpr_cases2)
    case (PPow_PPow e n e' m)
    then show ?thesis
    by (simp add: 1 npepow_correct nat_pow_distr
      npemul.simps [of "PExpr2 (PPow e n)" "PExpr2 (PPow e' m)"]
      del: npemul.simps)
  qed simp_all
qed

(* From here on, npemul is reasoned about through npemul_correct instead
   of being unfolded by the simplifier. *)
declare npemul.simps [simp del]

(* Smart constructor for sums: a constant 0 on either side is dropped and
   two constants are added; otherwise a plain PAdd is built. *)
definition npeadd :: "pexpr ⇒ pexpr ⇒ pexpr"
where
  "npeadd x y = (case x of
       PExpr1 (PCnst c) ⇒
         if c = 0 then y else
           (case y of
              PExpr1 (PCnst d) ⇒ PExpr1 (PCnst (c + d))
            | _ ⇒ PExpr1 (PAdd x y))
     | _ ⇒ (case y of
         PExpr1 (PCnst d) ⇒
           if d = 0 then x
           else PExpr1 (PAdd x y)
       | _ ⇒ PExpr1 (PAdd x y)))"

(* npeadd evaluates like the plain PAdd it replaces. *)
lemma (in field) npeadd_correct:
  "in_carrier xs ⟹ peval xs (npeadd e1 e2) = peval xs (PExpr1 (PAdd e1 e2))"
  by (cases e1 e2 rule: pexpr_cases2) (simp_all add: npeadd_def)

(* Smart constructor for differences:
   - x - 0 is x;
   - two constants are subtracted;
   - 0 - y becomes PNeg y;
   - otherwise a plain PSub is built. *)
definition npesub :: "pexpr ⇒ pexpr ⇒ pexpr"
where
  "npesub x y = (case y of
       PExpr1 (PCnst d) ⇒
         if d = 0 then x else
           (case x of
              PExpr1 (PCnst c) ⇒ PExpr1 (PCnst (c - d))
            | _ ⇒ PExpr1 (PSub x y))
     | _ ⇒ (case x of
         PExpr1 (PCnst c) ⇒
           if c = 0 then PExpr1 (PNeg y)
           else PExpr1 (PSub x y)
       | _ ⇒ PExpr1 (PSub x y)))"

(* npesub evaluates like the plain PSub it replaces. *)
lemma (in field) npesub_correct:
  "in_carrier xs ⟹ peval xs (npesub e1 e2) = peval xs (PExpr1 (PSub e1 e2))"
  by (cases e1 e2 rule: pexpr_cases2) (simp_all add: npesub_def)

(* Smart constructor for negation: integer constants are negated
   directly; anything else is wrapped in PNeg. *)
definition npeneg :: "pexpr ⇒ pexpr"
where
  "npeneg e = (case e of
       PExpr1 (PCnst c) ⇒ PExpr1 (PCnst (- c))
     | _ ⇒ PExpr1 (PNeg e))"

(* npeneg evaluates like the plain PNeg it replaces. Unlike the other
   smart-constructor lemmas, this one is stated without an in_carrier
   assumption. *)
lemma (in field) npeneg_correct:
  "peval xs (npeneg e) = peval xs (PExpr1 (PNeg e))"
  by (cases e rule: pexpr_cases) (simp_all add: npeneg_def)

(* Case rule for an optional pair: either None, or Some (p, q) with the
   pair already split. The proofs in this file do not use it; they split
   option/prod values via option.split_asm and prod.split_asm instead. *)
lemma option_pair_cases [case_names None Some]:
  assumes
    "x = None ⟹ P"
    "⋀p q. x = Some (p, q) ⟹ P"
  shows P
proof (cases x)
  case None
  then show ?thesis by (rule assms)
next
  case (Some r)
  then show ?thesis
    apply (cases r)
    apply simp
    by (rule assms)
qed

(* isin e n e' m searches for the factor e^n inside e'^m.

   On success it returns Some (p, r), where p <= n is the part of the
   exponent n that could NOT be matched and r is the remaining cofactor,
   so that e'^m = e^(n - p) * r (see isin_correct). It returns None if e
   does not occur.

   Occurrence is syntactic equality of PExpr1 subterms. Searching for a
   PExpr2 pattern e always ends in None (last equation). *)
fun isin :: "pexpr ⇒ nat ⇒ pexpr ⇒ nat ⇒ (nat * pexpr) option"
where
  "isin e n (PExpr2 (PMul e1 e2)) m =
     (case isin e n e1 m of
        Some (k, e3) ⇒
          if k = 0 then Some (0, npemul e3 (npepow e2 m))
          else (case isin e k e2 m of
              Some (l, e4) ⇒ Some (l, npemul e3 e4)
            | None ⇒ Some (k, npemul e3 (npepow e2 m)))
      | None ⇒ (case isin e n e2 m of
          Some (k, e3) ⇒ Some (k, npemul (npepow e1 m) e3)
        | None ⇒ None))"
| "isin e n (PExpr2 (PPow e' k)) m =
     (if k = 0 then None else isin e n e' (k * m))"
| "isin (PExpr1 e) n (PExpr1 e') m =
     (if e = e' then
        if n >= m then Some (n - m, PExpr1 (PCnst 1))
        else Some (0, npepow (PExpr1 e) (m - n))
      else None)"
| "isin (PExpr2 e) n (PExpr1 e') m = None"

(* Specification of isin: a successful search factors e'^m as
   e^(n - p) * e'' with p <= n. The proof uses isin's own induction
   rule and closes every case with force plus the power laws. *)
lemma (in field) isin_correct:
  assumes "in_carrier xs"
  and "isin e n e' m = Some (p, e'')"
  shows
    "peval xs (PExpr2 (PPow e' m)) =
     peval xs (PExpr2 (PMul (PExpr2 (PPow e (n - p))) e''))"
    "p ≤ n"
  using assms
  by (induct e n e' m arbitrary: p e'' rule: isin.induct)
    (force
       simp add:
         nat_pow_distr nat_pow_pow nat_pow_mult m_ac
         npemul_correct npepow_correct
       split: option.split_asm prod.split_asm if_split_asm)+

(* isin_correct specialised to m = 1, stated directly in terms of peval. *)
lemma (in field) isin_correct':
  "in_carrier xs ⟹ isin e n e' 1 = Some (p, e'') ⟹
   peval xs e' = peval xs e (^) (n - p) ⊗ peval xs e''"
  "in_carrier xs ⟹ isin e n e' 1 = Some (p, e'') ⟹ p ≤ n"
  using isin_correct [where m=1]
  by simp_all

(* split_aux e1 n e2 returns (left, common, right) such that
     e1^n = left * common   and   e2 = right * common
   (see split_aux_correct). Each atomic factor of e1 is looked up in the
   current remainder of e2 with isin. In the product case, the remainder
   after splitting e1 (right1) is the input for splitting e2, so common
   factors are never counted twice. *)
fun split_aux :: "pexpr ⇒ nat ⇒ pexpr ⇒ pexpr × pexpr × pexpr"
where
  "split_aux (PExpr2 (PMul e1 e2)) n e =
     (let
        (left1, common1, right1) = split_aux e1 n e;
        (left2, common2, right2) = split_aux e2 n right1
      in (npemul left1 left2, npemul common1 common2, right2))"
| "split_aux (PExpr2 (PPow e' m)) n e =
     (if m = 0 then (PExpr1 (PCnst 1), PExpr1 (PCnst 1), e)
      else split_aux e' (m * n) e)"
| "split_aux (PExpr1 e') n e =
     (case isin (PExpr1 e') n e 1 of
        Some (m, e'') ⇒
          (if m = 0 then (PExpr1 (PCnst 1), npepow (PExpr1 e') n, e'')
           else (npepow (PExpr1 e') m, npepow (PExpr1 e') (n - m), e''))
      | None ⇒ (npepow (PExpr1 e') n, PExpr1 (PCnst 1), e))"

(* Names for the three components of split_aux e1 1 e2. Constants named
   Left and Right that are already in scope are hidden first, so these
   abbreviations can take over the names. *)
hide_const Left Right

abbreviation Left :: "pexpr ⇒ pexpr ⇒ pexpr" where
  "Left e1 e2 ≡ fst (split_aux e1 (Suc 0) e2)"

abbreviation Common :: "pexpr ⇒ pexpr ⇒ pexpr" where
  "Common e1 e2 ≡ fst (snd (split_aux e1 (Suc 0) e2))"

abbreviation Right :: "pexpr ⇒ pexpr ⇒ pexpr" where
  "Right e1 e2 ≡ snd (snd (split_aux e1 (Suc 0) e2))"

(* Simplified induction rule for split_aux. In the product case, the
   second hypothesis is stated directly in terms of the right component of
   the first split. This avoids the let/tuple-pattern hypotheses that
   split_aux.induct produces, which are discharged here via refl and
   prod.collapse. *)
lemma split_aux_induct [case_names 1 2 3]:
  assumes I1: "⋀e1 e2 n e. P e1 n e ⟹ P e2 n (snd (snd (split_aux e1 n e))) ⟹
    P (PExpr2 (PMul e1 e2)) n e"
  and I2: "⋀e' m n e. (m ≠ 0 ⟹ P e' (m * n) e) ⟹ P (PExpr2 (PPow e' m)) n e"
  and I3: "⋀e' n e. P (PExpr1 e') n e"
  shows "P x y z"
proof (induct x y z rule: split_aux.induct)
  case 1
  from 1(1) 1(2) [OF refl prod.collapse prod.collapse]
  show ?case by (rule I1)
next
  case 2
  then show ?case by (rule I2)
next
  case 3
  then show ?case by (rule I3)
qed

(* Specification of split_aux:
     e1^n = left * common   and   e2 = right * common. *)
lemma (in field) split_aux_correct:
  "in_carrier xs ⟹
   peval xs (PExpr2 (PPow e1 n)) =
   peval xs (PExpr2 (PMul (fst (split_aux e1 n e2)) (fst (snd (split_aux e1 n e2)))))"
  "in_carrier xs ⟹
   peval xs e2 =
   peval xs (PExpr2 (PMul (snd (snd (split_aux e1 n e2))) (fst (snd (split_aux e1 n e2)))))"
  by (induct e1 n e2 rule: split_aux_induct)
    (auto simp add: split_beta
       nat_pow_distr nat_pow_pow nat_pow_mult m_ac
       npemul_correct npepow_correct isin_correct'
       split: option.split)

(* split_aux_correct for n = 1, phrased with Left/Common/Right. *)
lemma (in field) split_aux_correct':
  "in_carrier xs ⟹
   peval xs e1 = peval xs (Left e1 e2) ⊗ peval xs (Common e1 e2)"
  "in_carrier xs ⟹
   peval xs e2 = peval xs (Right e1 e2) ⊗ peval xs (Common e1 e2)"
  using split_aux_correct [where n=1]
  by simp_all

(* Normalise a field expression to (numerator, denominator, conditions).
   The conditions list collects polynomials that must evaluate to nonzero
   for the normal form to be valid (see fnorm_correct).

   split_aux cancels common factors:
   - FAdd/FSub: between the two denominators, giving the denominator
     left * right * common;
   - FMul: between each numerator and the other denominator;
   - FDiv: between the numerators and between the denominators. FDiv is
     the only case that adds a new condition (the divisor's numerator
     yn). *)
fun fnorm :: "fexpr ⇒ pexpr × pexpr × pexpr list"
where
  "fnorm (FCnst c) = (PExpr1 (PCnst c), PExpr1 (PCnst 1), [])"
| "fnorm (FVar n) = (PExpr1 (PVar n), PExpr1 (PCnst 1), [])"
| "fnorm (FAdd e1 e2) =
     (let
        (xn, xd, xc) = fnorm e1;
        (yn, yd, yc) = fnorm e2;
        (left, common, right) = split_aux xd 1 yd
      in
        (npeadd (npemul xn right) (npemul yn left),
         npemul left (npemul right common),
         List.union xc yc))"
| "fnorm (FSub e1 e2) =
     (let
        (xn, xd, xc) = fnorm e1;
        (yn, yd, yc) = fnorm e2;
        (left, common, right) = split_aux xd 1 yd
      in
        (npesub (npemul xn right) (npemul yn left),
         npemul left (npemul right common),
         List.union xc yc))"
| "fnorm (FMul e1 e2) =
     (let
        (xn, xd, xc) = fnorm e1;
        (yn, yd, yc) = fnorm e2;
        (left1, common1, right1) = split_aux xn 1 yd;
        (left2, common2, right2) = split_aux yn 1 xd
      in
        (npemul left1 left2,
         npemul right2 right1,
         List.union xc yc))"
| "fnorm (FNeg e) =
     (let (n, d, c) = fnorm e
      in (npeneg n, d, c))"
| "fnorm (FDiv e1 e2) =
     (let
        (xn, xd, xc) = fnorm e1;
        (yn, yd, yc) = fnorm e2;
        (left1, common1, right1) = split_aux xn 1 yn;
        (left2, common2, right2) = split_aux xd 1 yd
      in
        (npemul left1 right2,
         npemul left2 right1,
         List.insert yn (List.union xc yc)))"
| "fnorm (FPow e m) =
     (let (n, d, c) = fnorm e
      in (npepow n m, npepow d m, c))"

(* Names for the numerator, denominator and side conditions of fnorm e. *)
abbreviation Num :: "fexpr ⇒ pexpr" where
  "Num e ≡ fst (fnorm e)"

abbreviation Denom :: "fexpr ⇒ pexpr" where
  "Denom e ≡ fst (snd (fnorm e))"

abbreviation Cond :: "fexpr ⇒ pexpr list" where
  "Cond e ≡ snd (snd (fnorm e))"

(* nonzero xs ps holds iff every polynomial in ps evaluates to a nonzero
   value. Field_Tac.nonzero_conv destructures
   nonzero.simps @ [nonzero_singleton] positionally. *)
primrec (in field) nonzero :: "'a list ⇒ pexpr list ⇒ bool"
where
  "nonzero xs [] = True"
| "nonzero xs (p # ps) = (peval xs p ≠ 𝟬 ∧ nonzero xs ps)"

(* Singleton case. nonzero_conv uses it so that the last condition does
   not produce a trailing "∧ True". *)
lemma (in field) nonzero_singleton:
  "nonzero xs [p] = (peval xs p ≠ 𝟬)"
  by simp

(* nonzero distributes over list append. *)
lemma (in field) nonzero_append:
  "nonzero xs (ps @ qs) = (nonzero xs ps ∧ nonzero xs qs)"
  by (induct ps) simp_all

(* Requiring an element of ps to be nonzero in addition to nonzero xs ps
   adds nothing. *)
lemma (in field) nonzero_idempotent:
  "p ∈ set ps ⟹ (peval xs p ≠ 𝟬 ∧ nonzero xs ps) = nonzero xs ps"
  by (induct ps) auto

(* List.insert only adds p when it is absent, but either way the result
   requires exactly "p nonzero and all of ps nonzero". *)
lemma (in field) nonzero_insert:
  "nonzero xs (List.insert p ps) = (peval xs p ≠ 𝟬 ∧ nonzero xs ps)"
  by (simp add: List.insert_def nonzero_idempotent)

(* nonzero distributes over List.union, which fnorm uses to merge side
   conditions. *)
lemma (in field) nonzero_union:
  "nonzero xs (List.union ps qs) = (nonzero xs ps ∧ nonzero xs qs)"
  by (induct ps rule: rev_induct)
    (auto simp add: List.union_def nonzero_insert nonzero_append)

(* Soundness of fnorm: if in_carrier xs holds and the side conditions
   Cond e are all nonzero, then e evaluates to Num e / Denom e, and
   Denom e is nonzero.

   The proof is by induction on e. Each binary case instantiates
   split_aux_correct' on the relevant pair of numerators/denominators
   (fact "split"). It first rewrites feval into a fraction whose
   numerator and denominator share an explicit common factor, then
   matches that against the smart-constructor output. *)
lemma (in field) fnorm_correct:
  assumes "in_carrier xs"
  and "nonzero xs (Cond e)"
  shows "feval xs e = peval xs (Num e) ⊘ peval xs (Denom e)"
  and "peval xs (Denom e) ≠ 𝟬"
  using assms
proof (induct e)
  case (FCnst c) {
    case 1
    show ?case by simp
  next
    case 2
    show ?case by simp
  }
next
  case (FVar n) {
    case 1
    then show ?case by simp
  next
    case 2
    show ?case by simp
  }
next
  (* Sum: Denom e1 = left * common and Denom e2 = right * common. *)
  case (FAdd e1 e2)
  note split = split_aux_correct' [where xs=xs and
    e1="Denom e1" and e2="Denom e2"]
  {
    case 1
    let ?left = "peval xs (Left (Denom e1) (Denom e2))"
    let ?common = "peval xs (Common (Denom e1) (Denom e2))"
    let ?right = "peval xs (Right (Denom e1) (Denom e2))"
    from 1 FAdd
    have "feval xs (FAdd e1 e2) =
      (?common ⊗ (peval xs (Num e1) ⊗ ?right ⊕ peval xs (Num e2) ⊗ ?left)) ⊘
      (?common ⊗ (?left ⊗ (?right ⊗ ?common)))"
      by (simp add: split_beta split nonzero_union
        add_frac_eq r_distr m_ac)
    also from 1 FAdd have "… =
      peval xs (Num (FAdd e1 e2)) ⊘ peval xs (Denom (FAdd e1 e2))"
      by (simp add: split_beta split nonzero_union npeadd_correct npemul_correct integral_iff)
    finally show ?case .
  next
    case 2
    with FAdd show ?case
      by (simp add: split_beta split nonzero_union npemul_correct integral_iff)
  }
next
  (* Difference: same decomposition as FAdd. *)
  case (FSub e1 e2)
  note split = split_aux_correct' [where xs=xs and
    e1="Denom e1" and e2="Denom e2"]
  {
    case 1
    let ?left = "peval xs (Left (Denom e1) (Denom e2))"
    let ?common = "peval xs (Common (Denom e1) (Denom e2))"
    let ?right = "peval xs (Right (Denom e1) (Denom e2))"
    from 1 FSub
    have "feval xs (FSub e1 e2) =
      (?common ⊗ (peval xs (Num e1) ⊗ ?right ⊖ peval xs (Num e2) ⊗ ?left)) ⊘
      (?common ⊗ (?left ⊗ (?right ⊗ ?common)))"
      by (simp add: split_beta split nonzero_union
        diff_frac_eq r_diff_distr m_ac)
    also from 1 FSub have "… =
      peval xs (Num (FSub e1 e2)) ⊘ peval xs (Denom (FSub e1 e2))"
      by (simp add: split_beta split nonzero_union npesub_correct npemul_correct integral_iff)
    finally show ?case .
  next
    case 2
    with FSub show ?case
      by (simp add: split_beta split nonzero_union npemul_correct integral_iff)
  }
next
  (* Product: each numerator is split against the other denominator. *)
  case (FMul e1 e2)
  note split =
    split_aux_correct' [where xs=xs and
      e1="Num e1" and e2="Denom e2"]
    split_aux_correct' [where xs=xs and
      e1="Num e2" and e2="Denom e1"]
  {
    case 1
    let ?left1 = "peval xs (Left (Num e1) (Denom e2))"
    let ?common1 = "peval xs (Common (Num e1) (Denom e2))"
    let ?right1 = "peval xs (Right (Num e1) (Denom e2))"
    let ?left2 = "peval xs (Left (Num e2) (Denom e1))"
    let ?common2 = "peval xs (Common (Num e2) (Denom e1))"
    let ?right2 = "peval xs (Right (Num e2) (Denom e1))"
    from 1 FMul
    have "feval xs (FMul e1 e2) =
      ((?common1 ⊗ ?common2) ⊗ (?left1 ⊗ ?left2)) ⊘
      ((?common1 ⊗ ?common2) ⊗ (?right2 ⊗ ?right1))"
      by (simp add: split_beta split nonzero_union
        nonzero_divide_divide_eq_left m_ac)
    also from 1 FMul have "… =
      peval xs (Num (FMul e1 e2)) ⊘ peval xs (Denom (FMul e1 e2))"
      by (simp add: split_beta split nonzero_union npemul_correct integral_iff)
    finally show ?case .
  next
    case 2
    with FMul show ?case
      by (simp add: split_beta split nonzero_union npemul_correct integral_iff)
  }
next
  case (FNeg e)
  {
    case 1
    with FNeg show ?case
      by (simp add: split_beta npeneg_correct)
  next
    case 2
    with FNeg show ?case
      by (simp add: split_beta)
  }
next
  (* Quotient: the numerators are split against each other, and so are
     the denominators; Num e2 being nonzero comes from Cond via
     nonzero_insert. *)
  case (FDiv e1 e2)
  note split =
    split_aux_correct' [where xs=xs and
      e1="Num e1" and e2="Num e2"]
    split_aux_correct' [where xs=xs and
      e1="Denom e1" and e2="Denom e2"]
  {
    case 1
    let ?left1 = "peval xs (Left (Num e1) (Num e2))"
    let ?common1 = "peval xs (Common (Num e1) (Num e2))"
    let ?right1 = "peval xs (Right (Num e1) (Num e2))"
    let ?left2 = "peval xs (Left (Denom e1) (Denom e2))"
    let ?common2 = "peval xs (Common (Denom e1) (Denom e2))"
    let ?right2 = "peval xs (Right (Denom e1) (Denom e2))"
    from 1 FDiv
    have "feval xs (FDiv e1 e2) =
      ((?common1 ⊗ ?common2) ⊗ (?left1 ⊗ ?right2)) ⊘
      ((?common1 ⊗ ?common2) ⊗ (?left2 ⊗ ?right1))"
      by (simp add: split_beta split nonzero_union nonzero_insert
        nonzero_divide_divide_eq m_ac)
    also from 1 FDiv have "… =
      peval xs (Num (FDiv e1 e2)) ⊘ peval xs (Denom (FDiv e1 e2))"
      by (simp add: split_beta split nonzero_union nonzero_insert npemul_correct integral_iff)
    finally show ?case .
  next
    case 2
    with FDiv show ?case
      by (simp add: split_beta split nonzero_union nonzero_insert npemul_correct integral_iff)
  }
next
  case (FPow e n)
  {
    case 1
    with FPow show ?case
      by (simp add: split_beta nonzero_power_divide npepow_correct)
  next
    case 2
    with FPow show ?case
      by (simp add: split_beta npepow_correct)
  }
qed

(* If the normalised numerator of e evaluates to zero and the side
   conditions hold, then e itself evaluates to zero (by fnorm_correct). *)
lemma (in field) feval_eq0:
  assumes "in_carrier xs"
  and "fnorm e = (n, d, c)"
  and "nonzero xs c"
  and "peval xs n = 𝟬"
  shows "feval xs e = 𝟬"
  using assms fnorm_correct [of xs e]
  by simp

(* Under the side conditions, an fexpr evaluates into the carrier.

   Only the FDiv case is nontrivial: the divisor must be shown nonzero.
   This uses fnorm_correct together with the condition that Num e2 is
   nonzero, which fnorm records in Cond. *)
lemma (in field) fexpr_in_carrier:
  assumes "in_carrier xs"
  and "nonzero xs (Cond e)"
  shows "feval xs e ∈ carrier R"
  using assms
proof (induct e)
  case (FDiv e1 e2)
  then have "feval xs e1 ∈ carrier R" "feval xs e2 ∈ carrier R"
    "peval xs (Num e2) ≠ 𝟬" "nonzero xs (Cond e2)"
    by (simp_all add: nonzero_union nonzero_insert split: prod.split_asm)
  from ‹in_carrier xs› ‹nonzero xs (Cond e2)›
  have "feval xs e2 = peval xs (Num e2) ⊘ peval xs (Denom e2)"
    by (rule fnorm_correct)
  moreover from ‹in_carrier xs› ‹nonzero xs (Cond e2)›
  have "peval xs (Denom e2) ≠ 𝟬" by (rule fnorm_correct)
  ultimately have "feval xs e2 ≠ 𝟬" using ‹peval xs (Num e2) ≠ 𝟬› ‹in_carrier xs›
    by (simp add: divide_eq_0_iff)
  with ‹feval xs e1 ∈ carrier R› ‹feval xs e2 ∈ carrier R›
  show ?case by simp
qed (simp_all add: nonzero_union split: prod.split_asm)

(* Core lemma behind the field method. Under the side conditions c of
   fnorm (FSub e e'), e = e' holds iff the normalised numerator n of
   e - e' evaluates to zero. field_tac instantiates this theorem (it is
   registered as the last component of Field_Simps) and then leaves
   "peval xs n = 0" to the ring method. *)
lemma (in field) feval_eq:
  assumes "in_carrier xs"
  and "fnorm (FSub e e') = (n, d, c)"
  and "nonzero xs c"
  shows "(feval xs e = feval xs e') = (peval xs n = 𝟬)"
proof -
  from assms have "nonzero xs (Cond e)" "nonzero xs (Cond e')"
    by (auto simp add: nonzero_union split: prod.split_asm)
  with assms fnorm_correct [of xs "FSub e e'"]
  have "feval xs e ⊖ feval xs e' = peval xs n ⊘ peval xs d"
    "peval xs d ≠ 𝟬"
    by simp_all
  show ?thesis
  proof
    assume "feval xs e = feval xs e'"
    with ‹feval xs e ⊖ feval xs e' = peval xs n ⊘ peval xs d›
      ‹in_carrier xs› ‹nonzero xs (Cond e')›
    have "peval xs n ⊘ peval xs d = 𝟬"
      by (simp add: fexpr_in_carrier minus_eq r_neg)
    with ‹peval xs d ≠ 𝟬› ‹in_carrier xs›
    show "peval xs n = 𝟬"
      by (simp add: divide_eq_0_iff)
  next
    assume "peval xs n = 𝟬"
    with ‹feval xs e ⊖ feval xs e' = peval xs n ⊘ peval xs d› ‹peval xs d ≠ 𝟬›
      ‹nonzero xs (Cond e)› ‹nonzero xs (Cond e')› ‹in_carrier xs›
    show "feval xs e = feval xs e'"
      by (simp add: eq_diff0 fexpr_in_carrier)
  qed
qed

ML ‹
(* Rebuild HOL numerals of type nat / int from the code-generated
   natural numbers and integers returned by the compiled fnorm. *)
fun term_of_nat n = HOLogic.mk_number @{typ nat} (@{code integer_of_nat} n);

fun term_of_int k = HOLogic.mk_number @{typ int} (@{code integer_of_int} k);

(* Translate an ML value of the code-generated pexpr datatype (and its
   mutually recursive pexpr1/pexpr2 parts) back into the corresponding
   HOL term, constructor by constructor. *)
fun term_of_pexpr (@{code PExpr1} x) = @{term PExpr1} $ term_of_pexpr1 x
  | term_of_pexpr (@{code PExpr2} x) = @{term PExpr2} $ term_of_pexpr2 x
and term_of_pexpr1 (@{code PCnst} k) = @{term PCnst} $ term_of_int k
  | term_of_pexpr1 (@{code PVar} n) = @{term PVar} $ term_of_nat n
  | term_of_pexpr1 (@{code PAdd} (x, y)) = @{term PAdd} $ term_of_pexpr x $ term_of_pexpr y
  | term_of_pexpr1 (@{code PSub} (x, y)) = @{term PSub} $ term_of_pexpr x $ term_of_pexpr y
  | term_of_pexpr1 (@{code PNeg} x) = @{term PNeg} $ term_of_pexpr x
and term_of_pexpr2 (@{code PMul} (x, y)) = @{term PMul} $ term_of_pexpr x $ term_of_pexpr y
  | term_of_pexpr2 (@{code PPow} (x, n)) = @{term PPow} $ term_of_pexpr x $ term_of_nat n

(* Rebuild the HOL triple (numerator, denominator, side conditions)
   from the evaluated result of fnorm. *)
fun term_of_result (num, (denom, conds)) =
  let
    val num_t = term_of_pexpr num;
    val denom_t = term_of_pexpr denom;
    val conds_t = HOLogic.mk_list @{typ pexpr} (map term_of_pexpr conds);
  in HOLogic.mk_prod (num_t, HOLogic.mk_prod (denom_t, conds_t)) end;

(* The "fnorm" oracle asserts ct == t for the given cterm and term
   without a proof. It is only visible inside this local block. The sole
   caller, cv, passes the HOL term rebuilt from the evaluated result of
   fnorm, so trust in the field method rests on the compiled code used by
   computation_conv. *)
local

fun fnorm (ctxt, ct, t) = Thm.mk_binop @{cterm "Pure.eq :: pexpr × pexpr × pexpr list ⇒ pexpr × pexpr × pexpr list ⇒ prop"}
  ct (Thm.cterm_of ctxt t);

val (_, raw_fnorm_oracle) = Context.>>> (Context.map_theory_result
  (Thm.add_oracle (@{binding fnorm}, fnorm)));

fun fnorm_oracle ctxt ct t = raw_fnorm_oracle (ctxt, ct, t);

in

(* cv ctxt ct: evaluate ct, a term of type pexpr × pexpr × pexpr list
   (field_tac passes "fnorm (FSub e e')"), with compiled code, and return
   the oracle theorem ct == result. *)
val cv = @{computation_conv "pexpr × pexpr × pexpr list"
  terms: fnorm nat_of_integer Code_Target_Nat.natural
    "0::nat" "1::nat" "2::nat" "3::nat"
    "0::int" "1::int" "2::int" "3::int" "-1::int"
  datatypes: fexpr int integer num}
  (fn ctxt => fn result => fn ct => fnorm_oracle ctxt ct (term_of_result result))

end
›

ML ‹
(* Interface of the field decision procedure.
   - Field_Simps: generic data mapping a field structure term to its
     evaluation equations.
   - eq_field_simps: equality on Field_Simps entries.
   - field_tac in_prem: reduce a field equation (in the conclusion, or
     in a premise if in_prem) to a ring equation plus side conditions. *)
signature FIELD_TAC =
sig
  structure Field_Simps:
  sig
    type T
    val get: Context.generic -> T
    val put: T -> Context.generic -> Context.generic
    val map: (T -> T) -> Context.generic -> Context.generic
  end
  val eq_field_simps:
    (term * (thm list * thm list * thm list * thm * thm)) *
    (term * (thm list * thm list * thm list * thm * thm)) -> bool
  val field_tac: bool -> Proof.context -> int -> tactic
end

structure Field_Tac : FIELD_TAC =
struct

open Ring_Tac;

(* Return the structure argument R if t is headed by one of the
   locale-level field operations (add, a_minus, mult, a_inv, pow, m_div,
   zero, one, of_integer), NONE otherwise. field_tac uses this to decide
   whether an equation lives in a locale field. *)
fun field_struct (Const (@{const_name Ring.ring.add}, _) $ R $ _ $ _) = SOME R
  | field_struct (Const (@{const_name Ring.a_minus}, _) $ R $ _ $ _) = SOME R
  | field_struct (Const (@{const_name Group.monoid.mult}, _) $ R $ _ $ _) = SOME R
  | field_struct (Const (@{const_name Ring.a_inv}, _) $ R $ _) = SOME R
  | field_struct (Const (@{const_name Group.pow}, _) $ R $ _ $ _) = SOME R
  | field_struct (Const (@{const_name Algebra_Aux.m_div}, _) $ R $ _ $ _) = SOME R
  | field_struct (Const (@{const_name Ring.ring.zero}, _) $ R) = SOME R
  | field_struct (Const (@{const_name Group.monoid.one}, _) $ R) = SOME R
  | field_struct (Const (@{const_name Algebra_Aux.of_integer}, _) $ R $ _) = SOME R
  | field_struct _ = NONE;

(* Reify a term over the locale-level field operations into an fexpr.
   A free variable x becomes FVar i, where i is the position of x in vs;
   the exponent of pow and the argument of of_integer are copied
   unchanged. Any other term is rejected with an error. *)
fun reif_fexpr vs (Const (@{const_name Ring.ring.add}, _) $ _ $ a $ b) =
      @{const FAdd} $ reif_fexpr vs a $ reif_fexpr vs b
  | reif_fexpr vs (Const (@{const_name Ring.a_minus}, _) $ _ $ a $ b) =
      @{const FSub} $ reif_fexpr vs a $ reif_fexpr vs b
  | reif_fexpr vs (Const (@{const_name Group.monoid.mult}, _) $ _ $ a $ b) =
      @{const FMul} $ reif_fexpr vs a $ reif_fexpr vs b
  | reif_fexpr vs (Const (@{const_name Ring.a_inv}, _) $ _ $ a) =
      @{const FNeg} $ reif_fexpr vs a
  | reif_fexpr vs (Const (@{const_name Group.pow}, _) $ _ $ a $ n) =
      @{const FPow} $ reif_fexpr vs a $ n
  | reif_fexpr vs (Const (@{const_name Algebra_Aux.m_div}, _) $ _ $ a $ b) =
      @{const FDiv} $ reif_fexpr vs a $ reif_fexpr vs b
  | reif_fexpr vs (Free x) =
      @{const FVar} $ HOLogic.mk_number HOLogic.natT (find_index (equal x) vs)
  | reif_fexpr vs (Const (@{const_name Ring.ring.zero}, _) $ _) =
      @{term "FCnst 0"}
  | reif_fexpr vs (Const (@{const_name Group.monoid.one}, _) $ _) =
      @{term "FCnst 1"}
  | reif_fexpr vs (Const (@{const_name Algebra_Aux.of_integer}, _) $ _ $ n) =
      @{const FCnst} $ n
  | reif_fexpr _ _ = error "reif_fexpr: bad expression";

(* Type-class counterpart of reif_fexpr: reifies terms built from HOL's
   plus, minus, times, uminus, power, divide, 0, 1 and numerals. The
   catch-all error names this function (reif_fexpr'), so a failure here
   can be told apart from one in the locale-level reifier above. *)
fun reif_fexpr' vs (Const (@{const_name Groups.plus}, _) $ a $ b) =
      @{const FAdd} $ reif_fexpr' vs a $ reif_fexpr' vs b
  | reif_fexpr' vs (Const (@{const_name Groups.minus}, _) $ a $ b) =
      @{const FSub} $ reif_fexpr' vs a $ reif_fexpr' vs b
  | reif_fexpr' vs (Const (@{const_name Groups.times}, _) $ a $ b) =
      @{const FMul} $ reif_fexpr' vs a $ reif_fexpr' vs b
  | reif_fexpr' vs (Const (@{const_name Groups.uminus}, _) $ a) =
      @{const FNeg} $ reif_fexpr' vs a
  | reif_fexpr' vs (Const (@{const_name Power.power}, _) $ a $ n) =
      @{const FPow} $ reif_fexpr' vs a $ n
  | reif_fexpr' vs (Const (@{const_name divide}, _) $ a $ b) =
      @{const FDiv} $ reif_fexpr' vs a $ reif_fexpr' vs b
  | reif_fexpr' vs (Free x) =
      @{const FVar} $ HOLogic.mk_number HOLogic.natT (find_index (equal x) vs)
  | reif_fexpr' vs (Const (@{const_name zero_class.zero}, _)) =
      @{term "FCnst 0"}
  | reif_fexpr' vs (Const (@{const_name one_class.one}, _)) =
      @{term "FCnst 1"}
  | reif_fexpr' vs (Const (@{const_name numeral}, _) $ b) =
      @{const FCnst} $ (@{const numeral (int)} $ b)
  | reif_fexpr' _ _ = error "reif_fexpr': bad expression";

(* Equality on Field_Simps entries: alpha-equivalent keys and pairwise
   equal theorems in every component. Used when merging nets. *)
fun eq_field_simps
  ((t, (ths1, ths2, ths3, th4, th)),
   (t', (ths1', ths2', ths3', th4', th'))) =
    t aconv t' andalso
    eq_list Thm.eq_thm (ths1, ths1') andalso
    eq_list Thm.eq_thm (ths2, ths2') andalso
    eq_list Thm.eq_thm (ths3, ths3') andalso
    Thm.eq_thm (th4, th4') andalso
    Thm.eq_thm (th, th');

(* Discrimination net from a field structure term to its registered
   equations: (feval equations, peval equations, nonzero equations,
   nth_el_Cons, feval_eq). *)
structure Field_Simps = Generic_Data
(struct
  type T = (term * (thm list * thm list * thm list * thm * thm)) Net.net
  val empty = Net.empty
  val extend = I
  val merge = Net.merge eq_field_simps
end);

(* Retrieve the equations registered for the structure t, transferred
   to the current theory. In the type-class case (optcT = SOME cT) they
   are further instantiated with cT via Ring_Tac.inst and normalised.
   Fails if no entry matches t. *)
fun get_field_simps ctxt optcT t =
  (case get_matching_rules ctxt (Field_Simps.get (Context.Proof ctxt)) t of
     SOME (ths1, ths2, ths3, th4, th) =>
       let val tr =
         Thm.transfer (Proof_Context.theory_of ctxt) #>
         (case optcT of NONE => I | SOME cT => inst [cT] [] #> norm)
       in (map tr ths1, map tr ths2, map tr ths3, tr th4, tr th) end
   | NONE => error "get_field_simps: lookup failed");

(* Conversion for nth_el ys n: unfold nth_el_Cons, decide n = 0 with
   nat_eq_conv, and otherwise recurse on the tail with n - 1. Only the
   Cons case is handled; indices produced by reification are positions
   in the variable list, so the Nil case is presumably never reached. *)
fun nth_el_conv (_, _, _, nth_el_Cons, _) =
  let
    val a = type_of_eqn nth_el_Cons;
    val If_conv_a = If_conv a;

    fun conv ys n = (case strip_app ys of
      (@{const_name Cons}, [x, xs]) =>
        transitive'
          (inst [] [x, xs, n] nth_el_Cons)
          (If_conv_a (args2 nat_eq_conv)
             Thm.reflexive
             (cong2' conv Thm.reflexive (args2 nat_minus_conv))))
  in conv end;

(* Conversion evaluating feval xs x structurally. The 11 equations are
   feval.simps (1-8) followed by feval_Cnst (9-11), which handle literal
   0, 1 and numerals. *)
fun feval_conv (rls as
      ([feval_simps_1, feval_simps_2, feval_simps_3,
        feval_simps_4, feval_simps_5, feval_simps_6,
        feval_simps_7, feval_simps_8, feval_simps_9,
        feval_simps_10, feval_simps_11],
       _, _, _, _)) =
  let
    val nth_el_conv' = nth_el_conv rls;

    fun conv xs x = (case strip_app x of
        (@{const_name FCnst}, [c]) => (case strip_app c of
            (@{const_name zero_class.zero}, _) => inst [] [xs] feval_simps_9
          | (@{const_name one_class.one}, _) => inst [] [xs] feval_simps_10
          | (@{const_name numeral}, [n]) => inst [] [xs, n] feval_simps_11
          | _ => inst [] [xs, c] feval_simps_1)
      | (@{const_name FVar}, [n]) =>
          transitive' (inst [] [xs, n] feval_simps_2) (args2 nth_el_conv')
      | (@{const_name FAdd}, [a, b]) =>
          transitive' (inst [] [xs, a, b] feval_simps_3)
            (cong2 (args2 conv) (args2 conv))
      | (@{const_name FSub}, [a, b]) =>
          transitive' (inst [] [xs, a, b] feval_simps_4)
            (cong2 (args2 conv) (args2 conv))
      | (@{const_name FMul}, [a, b]) =>
          transitive' (inst [] [xs, a, b] feval_simps_5)
            (cong2 (args2 conv) (args2 conv))
      | (@{const_name FNeg}, [a]) =>
          transitive' (inst [] [xs, a] feval_simps_6)
            (cong1 (args2 conv))
      | (@{const_name FDiv}, [a, b]) =>
          transitive' (inst [] [xs, a, b] feval_simps_7)
            (cong2 (args2 conv) (args2 conv))
      | (@{const_name FPow}, [a, n]) =>
          transitive' (inst [] [xs, a, n] feval_simps_8)
            (cong2 (args2 conv) Thm.reflexive))
  in conv end;

(* Conversion evaluating peval xs x structurally. The 11 equations are
   peval.simps (1-7) followed by peval_Cnst (8-11), which handle literal
   0, 1, numerals and negated numerals. *)
fun peval_conv (rls as
      (_,
       [peval_simps_1, peval_simps_2, peval_simps_3,
        peval_simps_4, peval_simps_5, peval_simps_6,
        peval_simps_7, peval_simps_8, peval_simps_9,
        peval_simps_10, peval_simps_11],
       _, _, _)) =
  let
    val nth_el_conv' = nth_el_conv rls;

    fun conv xs x = (case strip_app x of
        (@{const_name PExpr1}, [e]) => (case strip_app e of
            (@{const_name PCnst}, [c]) => (case strip_numeral c of
                (@{const_name zero_class.zero}, _) => inst [] [xs] peval_simps_8
              | (@{const_name one_class.one}, _) => inst [] [xs] peval_simps_9
              | (@{const_name numeral}, [n]) => inst [] [xs, n] peval_simps_10
              | (@{const_name uminus}, [n]) => inst [] [xs, n] peval_simps_11
              | _ => inst [] [xs, c] peval_simps_1)
          | (@{const_name PVar}, [n]) =>
              transitive' (inst [] [xs, n] peval_simps_2) (args2 nth_el_conv')
          | (@{const_name PAdd}, [a, b]) =>
              transitive' (inst [] [xs, a, b] peval_simps_3)
                (cong2 (args2 conv) (args2 conv))
          | (@{const_name PSub}, [a, b]) =>
              transitive' (inst [] [xs, a, b] peval_simps_4)
                (cong2 (args2 conv) (args2 conv))
          | (@{const_name PNeg}, [a]) =>
              transitive' (inst [] [xs, a] peval_simps_5)
                (cong1 (args2 conv)))
      | (@{const_name PExpr2}, [e]) => (case strip_app e of
            (@{const_name PMul}, [a, b]) =>
              transitive' (inst [] [xs, a, b] peval_simps_6)
                (cong2 (args2 conv) (args2 conv))
          | (@{const_name PPow}, [a, n]) =>
              transitive' (inst [] [xs, a, n] peval_simps_7)
                (cong2 (args2 conv) Thm.reflexive)))
  in conv end;

(* Conversion turning nonzero xs qs into a conjunction of
   "peval xs p ~= 0" facts with each peval evaluated. The last list
   element uses nonzero_singleton, so no trailing True is produced. *)
fun nonzero_conv (rls as
      (_, _,
       [nonzero_Nil, nonzero_Cons, nonzero_singleton],
       _, _)) =
  let
    val peval_conv' = peval_conv rls;

    fun conv xs qs = (case strip_app qs of
        (@{const_name Nil}, []) => inst [] [xs] nonzero_Nil
      | (@{const_name Cons}, [p, ps]) => (case Thm.term_of ps of
            Const (@{const_name Nil}, _) =>
              transitive' (inst [] [xs, p] nonzero_singleton)
                (cong1 (cong2 (args2 peval_conv') Thm.reflexive))
          | _ => transitive' (inst [] [xs, p, ps] nonzero_Cons)
              (cong2 (cong1 (cong2 (args2 peval_conv') Thm.reflexive)) (args2 conv))))
  in conv end;

(* field_tac in_prem ctxt i: find an equation t = u over a field, in the
   conclusion or (if in_prem) in a premise. The structure is either a
   locale field detected by field_struct, or a type of sort field.
   Then:
   1. reify both sides over the free variables of that type;
   2. compute fnorm (FSub e e') with cv;
   3. instantiate the registered feval_eq theorem;
   4. convert its feval/peval/nonzero parts back to the original terms.
   Resolving with the result replaces the equation by
   "numerator = 0" and leaves the nonzero side conditions as a further
   premise. *)
fun field_tac in_prem ctxt =
  SUBGOAL (fn (g, i) =>
    let
      val (prems, concl) = Logic.strip_horn g;
      fun find_eq s = (case s of
          (_ $ (Const (@{const_name HOL.eq}, Type (_, [T, _])) $ t $ u)) =>
            (case (field_struct t, field_struct u) of
               (SOME R, _) => SOME ((t, u), R, T, NONE, mk_in_carrier ctxt R [], reif_fexpr)
             | (_, SOME R) => SOME ((t, u), R, T, NONE, mk_in_carrier ctxt R [], reif_fexpr)
             | _ =>
                 if Sign.of_sort (Proof_Context.theory_of ctxt) (T, @{sort field})
                 then SOME ((t, u), mk_ring T, T, SOME T, K @{thm in_carrier_trivial}, reif_fexpr')
                 else NONE)
        | _ => NONE);
      val ((t, u), R, T, optT, mkic, reif) =
        (case get_first find_eq
           (if in_prem then prems else [concl]) of
           SOME q => q
         | NONE => error "cannot determine field");
      val rls as (_, _, _, _, feval_eq) =
        get_field_simps ctxt (Option.map (Thm.ctyp_of ctxt) optT) R;
      val xs = [] |> Term.add_frees t |> Term.add_frees u |> filter (equal T o snd);
      val cxs = Thm.cterm_of ctxt (HOLogic.mk_list T (map Free xs));
      val ce = Thm.cterm_of ctxt (reif xs t);
      val ce' = Thm.cterm_of ctxt (reif xs u);
      val fnorm = cv ctxt
        (Thm.apply @{cterm fnorm} (Thm.apply (Thm.apply @{cterm FSub} ce) ce'));
      val (_, [n, dc]) = strip_app (Thm.rhs_of fnorm);
      val (_, [_, c]) = strip_app dc;
      val th =
        Conv.fconv_rule (Conv.concl_conv 1 (Conv.arg_conv
          (binop_conv
             (binop_conv
                (K (feval_conv rls cxs ce)) (K (feval_conv rls cxs ce')))
             (Conv.arg1_conv (K (peval_conv rls cxs n))))))
        ([mkic xs,
          mk_obj_eq fnorm,
          mk_obj_eq (nonzero_conv rls cxs c) RS @{thm iffD2}] MRS
         feval_eq);
      val th' = Drule.rotate_prems 1
        (th RS (if in_prem then @{thm iffD1} else @{thm iffD2}));
    in
      if in_prem then
        dresolve_tac ctxt [th'] 1 THEN defer_tac 1
      else
        resolve_tac ctxt [th'] 1
    end);

end
›

(* Register this locale's evaluation equations with Field_Tac, keyed by
   the structure R (transformed by the declaration morphism, so every
   interpretation gets its own entry). The order of the tuple components
   and of the facts inside them must match the positional patterns in
   feval_conv, peval_conv, nonzero_conv and nth_el_conv, and the final
   feval_eq used by field_tac. *)
context field begin

local_setup ‹
Local_Theory.declaration {syntax = false, pervasive = false}
  (fn phi => Field_Tac.Field_Simps.map (Ring_Tac.insert_rules Field_Tac.eq_field_simps
    (Morphism.term phi @{term R},
     (Morphism.fact phi @{thms feval.simps [meta] feval_Cnst [meta]},
      Morphism.fact phi @{thms peval.simps [meta] peval_Cnst [meta]},
      Morphism.fact phi @{thms nonzero.simps [meta] nonzero_singleton [meta]},
      singleton (Morphism.fact phi) @{thm nth_el_Cons [meta]},
      singleton (Morphism.fact phi) @{thm feval_eq}))))
›

end

(* Proof method "field": Field_Tac.field_tac first reduces the field
   equation to a ring equation, then Ring_Tac.ring_tac is called with the
   supplied extra theorems. With the "(prems)" mode, the method works on
   an equation premise instead of the conclusion. *)
method_setup field = ‹
  Scan.lift (Args.mode "prems") -- Attrib.thms >> (fn (in_prem, thms) => fn ctxt =>
    SIMPLE_METHOD' (Field_Tac.field_tac in_prem ctxt THEN' Ring_Tac.ring_tac in_prem thms ctxt))
› "reduce equations over fields to equations over rings"

end