misc tuning and clarification;
authorwenzelm
Fri, 29 Oct 2021 19:43:32 +0200
changeset 74627 c690435c47ee
parent 74626 9a1f4a7ddf9e
child 74628 cd030003efa8
misc tuning and clarification;
src/HOL/Analysis/metric_arith.ML
--- a/src/HOL/Analysis/metric_arith.ML	Fri Oct 29 19:17:24 2021 +0200
+++ b/src/HOL/Analysis/metric_arith.ML	Fri Oct 29 19:43:32 2021 +0200
@@ -13,26 +13,21 @@
 structure Metric_Arith : METRIC_ARITH =
 struct
 
-fun default d x = case x of SOME y => SOME y | NONE => d
-
 (* apply f to both cterms in ct_pair, merge results *)
 fun app_union_ct_pair f ct_pair = uncurry (union (op aconvc)) (apply2 f ct_pair)
 
 val trace = Attrib.setup_config_bool \<^binding>\<open>metric_trace\<close> (K false)
 
 fun trace_tac ctxt msg =
-  if Config.get ctxt trace then print_tac ctxt msg
-  else all_tac
+  if Config.get ctxt trace then print_tac ctxt msg else all_tac
 
 fun argo_trace_ctxt ctxt =
   if Config.get ctxt trace
   then Config.map (Argo_Tactic.trace) (K "basic") ctxt
   else ctxt
 
-fun IF_UNSOLVED' tac i = IF_UNSOLVED (tac i)
-fun REPEAT' tac i = REPEAT (tac i)
-
-fun free_in v ct = Cterms.defined (Cterms.build (Drule.add_frees_cterm ct)) v
+fun free_in v t =
+  Term.exists_subterm (fn u => u aconv Thm.term_of v) (Thm.term_of t);
 
 (* build a cterm set with elements cts of type ty *)
 fun mk_ct_set ctxt ty =
@@ -76,89 +71,68 @@
     K (trace_tac ctxt ("Simplified using " ^ @{make_string} refl_sym_simps))
   end
 
-fun is_exists ct =
-  case Thm.term_of ct of
-    Const (\<^const_name>\<open>HOL.Ex\<close>,_)$_ => true
-  | Const (\<^const_name>\<open>Trueprop\<close>,_)$_ => is_exists (Thm.dest_arg ct)
-  | _ => false
+fun is_exists \<^Const_>\<open>Ex _ for _\<close> = true
+  | is_exists \<^Const_>\<open>Trueprop for t\<close> = is_exists t
+  | is_exists _ = false
 
-fun is_forall ct =
-  case Thm.term_of ct of
-    Const (\<^const_name>\<open>HOL.All\<close>,_)$_ => true
-  | Const (\<^const_name>\<open>Trueprop\<close>,_)$_ => is_forall (Thm.dest_arg ct)
-  | _ => false
+fun is_forall \<^Const_>\<open>All _ for _\<close> = true
+  | is_forall \<^Const_>\<open>Trueprop for t\<close> = is_forall t
+  | is_forall _ = false
 
-fun dist_ty mty = mty --> mty --> \<^typ>\<open>real\<close>
 
 (* find all free points in ct of type metric_ty *)
 fun find_points ctxt metric_ty ct =
   let
     fun find ct =
-      (if Thm.typ_of_cterm ct = metric_ty then [ct] else []) @ (
-        case Thm.term_of ct of
-          _ $ _ =>
-          app_union_ct_pair find (Thm.dest_comb ct)
-        | Abs (_, _, _) =>
-          (* ensure the point doesn't contain the bound variable *)
-          let val (var, bod) = Thm.dest_abs_global ct in
-            filter (free_in var #> not) (find bod)
-          end
-        | _ => [])
-    val points = find ct
+      (if Thm.typ_of_cterm ct = metric_ty then [ct] else []) @
+      (case Thm.term_of ct of
+        _ $ _ => app_union_ct_pair find (Thm.dest_comb ct)
+      | Abs _ =>
+          (*ensure the point doesn't contain the bound variable*)
+          let val (x, body) = Thm.dest_abs_global ct
+          in filter_out (free_in x) (find body) end
+      | _ => [])
   in
-    case points of
+    (case find ct of
       [] =>
-      (* if no point can be found, invent one *)
-      let
-        val free_name = Term.variant_frees (Thm.term_of ct) [("x", metric_ty)]
-      in
-        map (Free #> Thm.cterm_of ctxt) free_name
-      end
-    | _ => points
+        (*if no point can be found, invent one*)
+        let val x = singleton (Term.variant_frees (Thm.term_of ct)) ("x", metric_ty)
+        in [Thm.cterm_of ctxt (Free x)] end
+    | points => points)
   end
 
 (* find all cterms "dist x y" in ct, where x and y have type metric_ty *)
-fun find_dist metric_ty ct =
+fun find_dist metric_ty =
   let
-    val dty = dist_ty metric_ty
     fun find ct =
-      case Thm.term_of ct of
-        Const (\<^const_name>\<open>dist\<close>, ty) $ _ $ _ =>
-        if ty = dty then [ct] else []
-      | _ $ _ =>
-        app_union_ct_pair find (Thm.dest_comb ct)
-      | Abs (_, _, _) =>
-        let val (var, bod) = Thm.dest_abs_global ct in
-          filter (free_in var #> not) (find bod)
-        end
-      | _ => []
-  in
-    find ct
-  end
+      (case Thm.term_of ct of
+        \<^Const_>\<open>dist T for _ _\<close> => if T = metric_ty then [ct] else []
+      | _ $ _ => app_union_ct_pair find (Thm.dest_comb ct)
+      | Abs _ =>
+          let val (x, body) = Thm.dest_abs_global ct
+          in filter_out (free_in x) (find body) end
+      | _ => [])
+  in find end
 
 (* find all "x=y", where x has type metric_ty *)
-fun find_eq metric_ty ct =
+fun find_eq metric_ty =
   let
     fun find ct =
-      case Thm.term_of ct of
-        Const (\<^const_name>\<open>HOL.eq\<close>, ty) $ _ $ _ =>
-          if fst (dest_funT ty) = metric_ty
-          then [ct]
+      (case Thm.term_of ct of
+        \<^Const_>\<open>HOL.eq T for _ _\<close> =>
+          if T = metric_ty then [ct]
           else app_union_ct_pair find (Thm.dest_binop ct)
       | _ $ _ => app_union_ct_pair find (Thm.dest_comb ct)
-      | Abs (_, _, _) =>
-        let val (var, bod) = Thm.dest_abs_global ct in
-          filter (free_in var #> not) (find bod)
-        end
-      | _ => []
-  in
-    find ct
-  end
+      | Abs _ =>
+          let val (x, body) = Thm.dest_abs_global ct
+          in filter_out (free_in x) (find body) end
+      | _ => [])
+  in find end
 
 (* rewrite ct of the form "dist x y" using maxdist_thm *)
 fun maxdist_conv ctxt fset_ct ct =
   let
-    val (xct, yct) = Thm.dest_binop ct
+    val (x, y) = Thm.dest_binop ct
     val solve_prems =
       rule_by_tactic ctxt (ALLGOALS (simp_tac (put_simpset HOL_ss ctxt
         addsimps @{thms finite.emptyI finite_insert empty_iff insert_iff})))
@@ -177,7 +151,7 @@
       Simplifier.rewrite (put_simpset HOL_ss ctxt addsimps @{thms real_abs_dist})
     val maxdist_thm =
       @{thm maxdist_thm} |>
-      infer_instantiate' ctxt [SOME fset_ct, SOME xct, SOME yct] |>
+      infer_instantiate' ctxt [SOME fset_ct, SOME x, SOME y] |>
       solve_prems
   in
     ((Conv.rewr_conv maxdist_thm) then_conv
@@ -195,7 +169,7 @@
 (* rewrite ct of the form "x=y" using metric_eq_thm *)
 fun metric_eq_conv ctxt fset_ct ct =
   let
-    val (xct, yct) = Thm.dest_binop ct
+    val (x, y) = Thm.dest_binop ct
     val solve_prems =
       rule_by_tactic ctxt (ALLGOALS (simp_tac (put_simpset HOL_ss ctxt
         addsimps @{thms empty_iff insert_iff})))
@@ -206,11 +180,11 @@
       Simplifier.rewrite (put_simpset HOL_ss ctxt addsimps @{thms dist_commute dist_self})
     val metric_eq_thm =
       @{thm metric_eq_thm} |>
-      infer_instantiate' ctxt [SOME xct, SOME fset_ct, SOME yct] |>
+      infer_instantiate' ctxt [SOME x, SOME fset_ct, SOME y] |>
       solve_prems
   in
     ((Conv.rewr_conv metric_eq_thm) then_conv
-    (* convert \<forall>x\<in>{x\<^sub>1,...,x\<^sub>n}. P x to P x\<^sub>1 \<and> ... \<and> P x\<^sub>n *)
+    (*convert \<forall>x\<in>{x\<^sub>1,...,x\<^sub>n}. P x to P x\<^sub>1 \<and> ... \<and> P x\<^sub>n*)
     ball_simp then_conv
     dist_refl_sym_simp) ct
   end
@@ -227,9 +201,9 @@
   let
     val points = find_points ctxt metric_ty goal
     val fset_ct = mk_ct_set ctxt metric_ty points
-    (* embed all subterms of the form "dist x y" in (\<real>\<^sup>n,dist\<^sub>\<infinity>) *)
+    (*embed all subterms of the form "dist x y" in (\<real>\<^sup>n,dist\<^sub>\<infinity>)*)
     val eq1 = map (maxdist_conv ctxt fset_ct) (find_dist metric_ty goal)
-    (* replace point equality by equality of components in \<real>\<^sup>n *)
+    (*replace point equality by equality of components in \<real>\<^sup>n*)
     val eq2 = map (metric_eq_conv ctxt fset_ct) (find_eq metric_ty goal)
   in
     (K (trace_tac ctxt "Embedding into \<real>\<^sup>n") THEN'
@@ -245,85 +219,81 @@
 
 fun basic_metric_arith_tac ctxt metric_ty =
   HEADGOAL (dist_refl_sym_tac ctxt THEN'
-  IF_UNSOLVED' (embedding_tac ctxt metric_ty) THEN'
-  IF_UNSOLVED' (pre_arith_tac ctxt) THEN'
-  IF_UNSOLVED' (lin_real_arith_tac ctxt metric_ty))
+  IF_UNSOLVED o (embedding_tac ctxt metric_ty) THEN'
+  IF_UNSOLVED o (pre_arith_tac ctxt) THEN'
+  IF_UNSOLVED o (lin_real_arith_tac ctxt metric_ty))
 
 (* tries to infer the metric space from ct from dist terms,
    if no dist terms are present, equality terms will be used *)
-fun guess_metric ctxt ct =
-let
-  fun find_dist ct =
-    case Thm.term_of ct of
-      Const (\<^const_name>\<open>dist\<close>, ty) $ _ $ _  => SOME (fst (dest_funT ty))
-    | _ $ _ =>
-      let val (s, t) = Thm.dest_comb ct in
-        default (find_dist t) (find_dist s)
-      end
-    | Abs (_, _, _) => find_dist (snd (Thm.dest_abs_global ct))
-    | _ => NONE
-  fun find_eq ct =
-    case Thm.term_of ct of
-      Const (\<^const_name>\<open>HOL.eq\<close>, ty) $ x $ _ =>
-      let val (l, r) = Thm.dest_binop ct in
-        if Sign.of_sort (Proof_Context.theory_of ctxt) (type_of x, \<^sort>\<open>metric_space\<close>)
-        then SOME (fst (dest_funT ty))
-        else default (find_dist r) (find_eq l)
-      end
-    | _ $ _ =>
-      let val (s, t) = Thm.dest_comb ct in
-        default (find_eq t) (find_eq s)
-      end
-    | Abs (_, _, _) => find_eq (snd (Thm.dest_abs_global ct))
-    | _ => NONE
-  in
-    case default (find_eq ct) (find_dist ct) of
-      SOME ty => ty
-    | NONE => error "No Metric Space was found"
-  end
+fun guess_metric ctxt tm =
+  let
+    val thy = Proof_Context.theory_of ctxt
+    fun find_dist t =
+      (case t of
+        \<^Const_>\<open>dist T for _ _\<close>  => SOME T
+      | t1 $ t2 => (case find_dist t1 of NONE => find_dist t2 | some => some)
+      | Abs _ => find_dist (#2 (Term.dest_abs_global t))
+      | _ => NONE)
+    fun find_eq t =
+      (case t of
+        \<^Const_>\<open>HOL.eq T for l r\<close> =>
+          if Sign.of_sort thy (T, \<^sort>\<open>metric_space\<close>) then SOME T
+          else (case find_eq l of NONE => find_dist r (* FIXME find_eq!? *) | some => some)
+      | t1 $ t2 => (case find_eq t1 of NONE => find_eq t2 | some => some)
+      | Abs _ => find_eq (#2 (Term.dest_abs_global t))
+      | _ => NONE)
+    in
+      (case find_dist tm of
+        SOME ty => ty
+      | NONE =>
+          case find_eq tm of
+            SOME ty => ty
+          | NONE => error "No Metric Space was found")
+    end
 
-(* eliminate \<exists> by proving the goal for a single witness from the metric space *)
-fun elim_exists ctxt goal =
+(* solve \<exists> by proving the goal for a single witness from the metric space *)
+fun exists_tac ctxt st =
   let
-    val ct = Thm.cprem_of goal 1
-    val metric_ty = guess_metric ctxt ct
-    val points = find_points ctxt metric_ty ct
+    val goal = Thm.cprem_of st 1
+    val metric_ty = guess_metric ctxt (Thm.term_of goal)
+    val points = find_points ctxt metric_ty goal
 
-    fun try_point ctxt pt =
-      let val ex_rule = infer_instantiate' ctxt [NONE, SOME pt] @{thm exI}
+    fun try_point_tac ctxt pt =
+      let
+        val ex_rule =
+          \<^instantiate>\<open>'a = \<open>Thm.ctyp_of_cterm pt\<close> and x = pt in
+            lemma (schematic) \<open>P x \<Longrightarrow> \<exists>x::'a. P x\<close> by (rule exI)\<close>
       in
         HEADGOAL (resolve_tac ctxt [ex_rule] ORELSE'
-        (* variable doesn't occur in body *)
+        (*variable doesn't occur in body*)
         resolve_tac ctxt @{thms exI}) THEN
         trace_tac ctxt ("Removed existential quantifier, try " ^ @{make_string} pt) THEN
-        try_points ctxt
+        try_points_tac ctxt
       end
-    and try_points ctxt goal = (
-      if is_exists (Thm.cprem_of goal 1) then
-        FIRST (map (try_point ctxt) points)
-      else if is_forall (Thm.cprem_of goal 1) then
+    and try_points_tac ctxt st = (
+      if is_exists (Thm.major_prem_of st) then
+        FIRST (map (try_point_tac ctxt) points)
+      else if is_forall (Thm.major_prem_of st) then
         HEADGOAL (resolve_tac ctxt @{thms HOL.allI} THEN'
         Subgoal.FOCUS (fn {context = ctxt', ...} =>
           trace_tac ctxt "Removed universal quantifier" THEN
-          try_points ctxt') ctxt)
-      else basic_metric_arith_tac ctxt metric_ty) goal
-  in
-    try_points ctxt goal
-  end
+          try_points_tac ctxt') ctxt)
+      else basic_metric_arith_tac ctxt metric_ty) st
+  in try_points_tac ctxt st end
 
 fun metric_arith_tac ctxt =
-  (* unfold common definitions to get rid of sets *)
+  (*unfold common definitions to get rid of sets*)
   unfold_tac ctxt THEN'
-  (* remove all meta-level connectives *)
-  IF_UNSOLVED' (Object_Logic.full_atomize_tac ctxt) THEN'
-  (* convert goal to prenex form *)
-  IF_UNSOLVED' (prenex_tac ctxt) THEN'
-  (* and NNF to ? *)
-  IF_UNSOLVED' (nnf_tac ctxt) THEN'
-  (* turn all universally quantified variables into free variables, by focusing the subgoal *)
-  REPEAT' (resolve_tac ctxt @{thms HOL.allI}) THEN'
-  IF_UNSOLVED' (SUBPROOF (fn {context=ctxt', ...} =>
+  (*remove all meta-level connectives*)
+  IF_UNSOLVED o (Object_Logic.full_atomize_tac ctxt) THEN'
+  (*convert goal to prenex form*)
+  IF_UNSOLVED o (prenex_tac ctxt) THEN'
+  (*and NNF to ?*)
+  IF_UNSOLVED o (nnf_tac ctxt) THEN'
+  (*turn all universally quantified variables into free variables, by focusing the subgoal*)
+  REPEAT o (resolve_tac ctxt @{thms HOL.allI}) THEN'
+  IF_UNSOLVED o SUBPROOF (fn {context = ctxt', ...} =>
     trace_tac ctxt' "Focused on subgoal" THEN
-    elim_exists ctxt') ctxt)
+    exists_tac ctxt') ctxt
 
 end