tuned: more antiquotations, more abstract access to datatype typ;
authorwenzelm
Wed, 07 Aug 2024 13:25:51 +0200
changeset 80658 46eb1135f9bd
parent 80657 c6dca9d3af4e
child 80659 2191ad2d684e
tuned: more antiquotations, more abstract access to datatype typ;
src/HOL/Tools/Old_Datatype/old_datatype_aux.ML
src/HOL/Tools/Old_Datatype/old_datatype_codegen.ML
src/HOL/Tools/Old_Datatype/old_datatype_data.ML
src/HOL/Tools/Old_Datatype/old_datatype_prop.ML
src/HOL/Tools/Old_Datatype/old_primrec.ML
src/HOL/Tools/Old_Datatype/old_rep_datatype.ML
--- a/src/HOL/Tools/Old_Datatype/old_datatype_aux.ML	Wed Aug 07 12:50:22 2024 +0200
+++ b/src/HOL/Tools/Old_Datatype/old_datatype_aux.ML	Wed Aug 07 13:25:51 2024 +0200
@@ -157,7 +157,7 @@
   let
     val goal = Thm.term_of cgoal;
     val params = Logic.strip_params goal;
-    val (_, Type (tname, _)) = hd (rev params);
+    val tname = dest_Type_name (#2 (hd (rev params)));
     val exhaustion = Thm.lift_rule cgoal (exh_thm_of tname);
     val prem' = hd (Thm.prems_of exhaustion);
     val _ $ (_ $ lhs $ _) = hd (rev (Logic.strip_assums_hyp prem'));
@@ -228,25 +228,32 @@
 fun mk_fun_dtyp [] U = U
   | mk_fun_dtyp (T :: Ts) U = DtType ("fun", [T, mk_fun_dtyp Ts U]);
 
-fun name_of_typ (Type (s, Ts)) =
-      let val s' = Long_Name.base_name s in
-        space_implode "_"
-          (filter_out (equal "") (map name_of_typ Ts) @
-            [if Symbol_Pos.is_identifier s' then s' else "x"])
-      end
-  | name_of_typ _ = "";
+fun name_of_typ ty =
+  if is_Type ty then
+    let
+      val name = Long_Name.base_name (dest_Type_name ty)
+      val Ts = dest_Type_args ty
+    in
+      space_implode "_"
+        (filter_out (equal "") (map name_of_typ Ts) @
+          [if Symbol_Pos.is_identifier name then name else "x"])
+    end
+  else "";
 
 fun dtyp_of_typ _ (TFree a) = DtTFree a
-  | dtyp_of_typ _ (TVar _) = error "Illegal schematic type variable(s)"
-  | dtyp_of_typ new_dts (Type (tname, Ts)) =
-      (case AList.lookup (op =) new_dts tname of
-        NONE => DtType (tname, map (dtyp_of_typ new_dts) Ts)
-      | SOME vs =>
-          if map (try dest_TFree) Ts = map SOME vs then
-            DtRec (find_index (curry op = tname o fst) new_dts)
-          else error ("Illegal occurrence of recursive type " ^ quote tname));
+  | dtyp_of_typ new_dts T =
+      if is_TVar T then error "Illegal schematic type variable(s)"
+      else
+        let val (tname, Ts) = dest_Type T in
+          (case AList.lookup (op =) new_dts tname of
+            NONE => DtType (tname, map (dtyp_of_typ new_dts) Ts)
+          | SOME vs =>
+              if map (try dest_TFree) Ts = map SOME vs then
+                DtRec (find_index (curry op = tname o fst) new_dts)
+              else error ("Illegal occurrence of recursive type " ^ quote tname))
+        end;
 
-fun typ_of_dtyp descr (DtTFree a) = TFree a
+fun typ_of_dtyp _ (DtTFree a) = TFree a
   | typ_of_dtyp descr (DtRec i) =
       let val (s, ds, _) = the (AList.lookup (op =) descr i)
       in Type (s, map (typ_of_dtyp descr) ds) end
--- a/src/HOL/Tools/Old_Datatype/old_datatype_codegen.ML	Wed Aug 07 12:50:22 2024 +0200
+++ b/src/HOL/Tools/Old_Datatype/old_datatype_codegen.ML	Wed Aug 07 13:25:51 2024 +0200
@@ -15,7 +15,7 @@
   let
     val ctxt = Proof_Context.init_global thy
     val SOME {ctrs, injects, distincts, case_thms, ...} = Ctr_Sugar.ctr_sugar_of ctxt fcT_name
-    val Type (_, As) = body_type (fastype_of (hd ctrs))
+    val As = dest_Type_args (body_type (fastype_of (hd ctrs)))
   in
     Ctr_Sugar_Code.add_ctr_code fcT_name As (map dest_Const ctrs) injects distincts case_thms thy
   end;
--- a/src/HOL/Tools/Old_Datatype/old_datatype_data.ML	Wed Aug 07 12:50:22 2024 +0200
+++ b/src/HOL/Tools/Old_Datatype/old_datatype_data.ML	Wed Aug 07 13:25:51 2024 +0200
@@ -185,8 +185,9 @@
 
 fun all_distincts thy Ts =
   let
-    fun add_tycos (Type (tyco, Ts)) = insert (op =) tyco #> fold add_tycos Ts
-      | add_tycos _ = I;
+    fun add_tycos T =
+      if is_Type T
+      then insert (op =) (dest_Type_name T) #> fold add_tycos (dest_Type_args T) else I;
     val tycos = fold add_tycos Ts [];
   in map_filter (Option.map #distinct o get_info thy) tycos end;
 
--- a/src/HOL/Tools/Old_Datatype/old_datatype_prop.ML	Wed Aug 07 12:50:22 2024 +0200
+++ b/src/HOL/Tools/Old_Datatype/old_datatype_prop.ML	Wed Aug 07 13:25:51 2024 +0200
@@ -107,9 +107,7 @@
       if length descr' = 1 then ["P"]
       else map (fn i => "P" ^ string_of_int i) (1 upto length descr');
 
-    fun make_pred i T =
-      let val T' = T --> HOLogic.boolT
-      in Free (nth pnames i, T') end;
+    fun make_pred i T = Free (nth pnames i, T --> \<^Type>\<open>bool\<close>);
 
     fun make_ind_prem k T (cname, cargs) =
       let
@@ -161,12 +159,12 @@
         fold_rev (Logic.all o Free) frees
           (Logic.mk_implies (HOLogic.mk_Trueprop
             (HOLogic.mk_eq (Free ("y", T), list_comb (Const (cname, Ts ---> T), free_ts))),
-              HOLogic.mk_Trueprop (Free ("P", HOLogic.boolT))))
+              HOLogic.mk_Trueprop (Free ("P", \<^Type>\<open>bool\<close>))))
       end;
 
     fun make_casedist ((_, (_, _, constrs))) T =
       let val prems = map (make_casedist_prem T) constrs
-      in Logic.list_implies (prems, HOLogic.mk_Trueprop (Free ("P", HOLogic.boolT))) end;
+      in Logic.list_implies (prems, HOLogic.mk_Trueprop (Free ("P", \<^Type>\<open>bool\<close>))) end;
 
   in
     map2 make_casedist (hd descr)
@@ -298,7 +296,7 @@
     val used' = fold Term.add_tfree_namesT recTs [];
     val newTs = take (length (hd descr)) recTs;
     val T' = TFree (singleton (Name.variant_list used') "'t", \<^sort>\<open>type\<close>);
-    val P = Free ("P", T' --> HOLogic.boolT);
+    val P = Free ("P", T' --> \<^Type>\<open>bool\<close>);
 
     fun make_split (((_, (_, _, constrs)), T), comb_t) =
       let
@@ -337,7 +335,7 @@
 
     fun mk_case_cong comb =
       let
-        val Type ("fun", [T, _]) = fastype_of comb;
+        val \<^Type>\<open>fun T _\<close> = fastype_of comb;
         val M = Free ("M", T);
         val M' = Free ("M'", T);
       in
@@ -367,7 +365,7 @@
 
     fun mk_case_cong ((comb, comb'), (_, (_, _, constrs))) =
       let
-        val Type ("fun", [T, _]) = fastype_of comb;
+        val \<^Type>\<open>fun T _\<close> = fastype_of comb;
         val (_, fs) = strip_comb comb;
         val (_, gs) = strip_comb comb';
         val used = ["M", "M'"] @ map (fst o dest_Free) (fs @ gs);
--- a/src/HOL/Tools/Old_Datatype/old_primrec.ML	Wed Aug 07 12:50:22 2024 +0200
+++ b/src/HOL/Tools/Old_Datatype/old_primrec.ML	Wed Aug 07 13:25:51 2024 +0200
@@ -144,7 +144,7 @@
       (case AList.lookup (op =) eqns cname of
         NONE => (warning ("No equation for constructor " ^ quote cname ^
           "\nin definition of function " ^ quote fname);
-            (fnames', fnss', (Const (\<^const_name>\<open>undefined\<close>, dummyT)) :: fns))
+            (fnames', fnss', \<^Const>\<open>undefined dummyT\<close> :: fns))
       | SOME (ls, cargs', rs, rhs, eq) =>
           let
             val recs = filter (Old_Datatype_Aux.is_rec_type o snd) (cargs' ~~ cargs);
@@ -183,9 +183,9 @@
   (case AList.lookup (op =) fns i of
     NONE =>
       let
-        val dummy_fns = map (fn (_, cargs) => Const (\<^const_name>\<open>undefined\<close>,
-          replicate (length cargs + length (filter Old_Datatype_Aux.is_rec_type cargs))
-            dummyT ---> HOLogic.unitT)) constrs;
+        val dummy_fns = map (fn (_, cargs) => \<^Const>\<open>undefined
+          \<open>replicate (length cargs + length (filter Old_Datatype_Aux.is_rec_type cargs))
+            dummyT ---> HOLogic.unitT\<close>\<close>) constrs;
         val _ = warning ("No function definition for datatype " ^ quote tname)
       in
         (dummy_fns @ fs, defs)
--- a/src/HOL/Tools/Old_Datatype/old_rep_datatype.ML	Wed Aug 07 12:50:22 2024 +0200
+++ b/src/HOL/Tools/Old_Datatype/old_rep_datatype.ML	Wed Aug 07 13:25:51 2024 +0200
@@ -40,11 +40,10 @@
 
     fun prove_casedist_thm (i, (T, t)) =
       let
-        val dummyPs = map (fn (Var (_, Type (_, [T', T'']))) =>
-          Abs ("z", T', Const (\<^const_name>\<open>True\<close>, T''))) induct_Ps;
+        val dummyPs = map (fn Var (_, \<^Type>\<open>fun A _\<close>) => Abs ("z", A, \<^Const>\<open>True\<close>)) induct_Ps;
         val P =
           Abs ("z", T, HOLogic.imp $ HOLogic.mk_eq (Var (("a", maxidx + 1), T), Bound 0) $
-            Var (("P", 0), HOLogic.boolT));
+            Var (("P", 0), \<^Type>\<open>bool\<close>));
         val insts = take i dummyPs @ (P :: drop (i + 1) dummyPs);
       in
         Goal.prove_sorry_global thy []
@@ -102,7 +101,7 @@
     val (rec_result_Ts, reccomb_fn_Ts) = Old_Datatype_Prop.make_primrec_Ts descr used;
 
     val rec_set_Ts =
-      map (fn (T1, T2) => (reccomb_fn_Ts @ [T1, T2]) ---> HOLogic.boolT) (recTs ~~ rec_result_Ts);
+      map (fn (T1, T2) => (reccomb_fn_Ts @ [T1, T2]) ---> \<^Type>\<open>bool\<close>) (recTs ~~ rec_result_Ts);
 
     val rec_fns =
       map (uncurry (Old_Datatype_Aux.mk_Free "f")) (reccomb_fn_Ts ~~ (1 upto length reccomb_fn_Ts));
@@ -204,8 +203,8 @@
       let
         val rec_unique_ts =
           map (fn (((set_t, T1), T2), i) =>
-            Const (\<^const_name>\<open>Ex1\<close>, (T2 --> HOLogic.boolT) --> HOLogic.boolT) $
-              absfree ("y", T2) (set_t $ Old_Datatype_Aux.mk_Free "x" T1 i $ Free ("y", T2)))
+            \<^Const>\<open>Ex1 T2 for
+              \<open>absfree ("y", T2) (set_t $ Old_Datatype_Aux.mk_Free "x" T1 i $ Free ("y", T2))\<close>\<close>)
                 (rec_sets ~~ recTs ~~ rec_result_Ts ~~ (1 upto length recTs));
         val insts =
           map (fn ((i, T), t) => absfree ("x" ^ string_of_int i, T) t)
@@ -248,8 +247,7 @@
             (fn ((((name, comb), set), T), T') =>
               (Binding.name (Thm.def_name (Long_Name.base_name name)),
                 Logic.mk_equals (comb, fold_rev lambda rec_fns (absfree ("x", T)
-                 (Const (\<^const_name>\<open>The\<close>, (T' --> HOLogic.boolT) --> T') $ absfree ("y", T')
-                   (set $ Free ("x", T) $ Free ("y", T')))))))
+                 \<^Const>\<open>The T' for \<open>absfree ("y", T') (set $ Free ("x", T) $ Free ("y", T'))\<close>\<close>))))
             (reccomb_names ~~ reccombs ~~ rec_sets ~~ recTs ~~ rec_result_Ts))
       ||> Sign.parent_path;
 
@@ -303,43 +301,45 @@
         let
           val Ts = map (Old_Datatype_Aux.typ_of_dtyp descr') cargs;
           val Ts' = map mk_dummyT (filter Old_Datatype_Aux.is_rec_type cargs)
-        in Const (\<^const_name>\<open>undefined\<close>, Ts @ Ts' ---> T') end) constrs) descr';
+        in \<^Const>\<open>undefined \<open>Ts @ Ts' ---> T'\<close>\<close> end) constrs) descr';
 
     val case_names0 = map (fn s => Sign.full_bname thy1 ("case_" ^ s)) new_type_names;
 
     (* define case combinators via primrec combinators *)
 
-    fun def_case ((((i, (_, _, constrs)), T as Type (Tcon, _)), name), recname) (defs, thy) =
-      if is_some (Ctr_Sugar.ctr_sugar_of ctxt Tcon) then
-        (defs, thy)
-      else
-        let
-          val (fns1, fns2) = split_list (map (fn ((_, cargs), j) =>
-            let
-              val Ts = map (Old_Datatype_Aux.typ_of_dtyp descr') cargs;
-              val Ts' = Ts @ map mk_dummyT (filter Old_Datatype_Aux.is_rec_type cargs);
-              val frees' = map2 (Old_Datatype_Aux.mk_Free "x") Ts' (1 upto length Ts');
-              val frees = take (length cargs) frees';
-              val free = Old_Datatype_Aux.mk_Free "f" (Ts ---> T') j;
-            in
-              (free, fold_rev (absfree o dest_Free) frees' (list_comb (free, frees)))
-            end) (constrs ~~ (1 upto length constrs)));
+    fun def_case ((((i, (_, _, constrs)), T), name), recname) (defs, thy) =
+      let val Tcon = dest_Type_name T in
+        if is_some (Ctr_Sugar.ctr_sugar_of ctxt Tcon) then
+          (defs, thy)
+        else
+          let
+            val (fns1, fns2) = split_list (map (fn ((_, cargs), j) =>
+              let
+                val Ts = map (Old_Datatype_Aux.typ_of_dtyp descr') cargs;
+                val Ts' = Ts @ map mk_dummyT (filter Old_Datatype_Aux.is_rec_type cargs);
+                val frees' = map2 (Old_Datatype_Aux.mk_Free "x") Ts' (1 upto length Ts');
+                val frees = take (length cargs) frees';
+                val free = Old_Datatype_Aux.mk_Free "f" (Ts ---> T') j;
+              in
+                (free, fold_rev (absfree o dest_Free) frees' (list_comb (free, frees)))
+              end) (constrs ~~ (1 upto length constrs)));
 
-          val caseT = map (snd o dest_Free) fns1 @ [T] ---> T';
-          val fns = flat (take i case_dummy_fns) @ fns2 @ flat (drop (i + 1) case_dummy_fns);
-          val reccomb = Const (recname, (map fastype_of fns) @ [T] ---> T');
-          val decl = ((Binding.name (Long_Name.base_name name), caseT), NoSyn);
-          val def =
-            (Binding.name (Thm.def_name (Long_Name.base_name name)),
-              Logic.mk_equals (Const (name, caseT),
-                fold_rev lambda fns1
-                  (list_comb (reccomb,
-                    flat (take i case_dummy_fns) @ fns2 @ flat (drop (i + 1) case_dummy_fns)))));
-          val (def_thm, thy') =
-            thy
-            |> Sign.declare_const_global decl |> snd
-            |> Global_Theory.add_def def;
-        in (defs @ [def_thm], thy') end;
+            val caseT = map (snd o dest_Free) fns1 @ [T] ---> T';
+            val fns = flat (take i case_dummy_fns) @ fns2 @ flat (drop (i + 1) case_dummy_fns);
+            val reccomb = Const (recname, (map fastype_of fns) @ [T] ---> T');
+            val decl = ((Binding.name (Long_Name.base_name name), caseT), NoSyn);
+            val def =
+              (Binding.name (Thm.def_name (Long_Name.base_name name)),
+                Logic.mk_equals (Const (name, caseT),
+                  fold_rev lambda fns1
+                    (list_comb (reccomb,
+                      flat (take i case_dummy_fns) @ fns2 @ flat (drop (i + 1) case_dummy_fns)))));
+            val (def_thm, thy') =
+              thy
+              |> Sign.declare_const_global decl |> snd
+              |> Global_Theory.add_def def;
+          in (defs @ [def_thm], thy') end
+      end;
 
     val (case_defs, thy2) =
       fold def_case (hd descr ~~ newTs ~~ case_names0 ~~ take (length newTs) reccomb_names)
@@ -350,8 +350,8 @@
         EVERY [rewrite_goals_tac ctxt (case_defs @ map mk_meta_eq primrec_thms),
           resolve_tac ctxt [refl] 1]);
 
-    fun prove_cases (Type (Tcon, _)) ts =
-      (case Ctr_Sugar.ctr_sugar_of ctxt Tcon of
+    fun prove_cases T ts =
+      (case Ctr_Sugar.ctr_sugar_of ctxt (dest_Type_name T) of
         SOME {case_thms, ...} => case_thms
       | NONE => map prove_case ts);
 
@@ -455,8 +455,8 @@
   let
     fun prove_case_cong ((t, nchotomy), case_rewrites) =
       let
-        val Const (\<^const_name>\<open>Pure.imp\<close>, _) $ tm $ _ = t;
-        val Const (\<^const_name>\<open>Trueprop\<close>, _) $ (Const (\<^const_name>\<open>HOL.eq\<close>, _) $ _ $ Ma) = tm;
+        val \<^Const_>\<open>Pure.imp for tm _\<close> = t;
+        val \<^Const_>\<open>Trueprop for \<^Const_>\<open>HOL.eq _ for _ Ma\<close>\<close> = tm;
         val nchotomy' = nchotomy RS spec;
         val [v] = Term.add_var_names (Thm.concl_of nchotomy') [];
       in