ML antiquotations for type constructors and term constants;
authorwenzelm
Sat, 11 Sep 2021 21:16:23 +0200
changeset 74291 b83fa8f3a271
parent 74290 b2ad24b5a42c
child 74292 39c98371606f
ML antiquotations for type constructors and term constants;
NEWS
etc/symbols
src/Pure/ML/ml_antiquotation.ML
src/Pure/ML/ml_antiquotations1.ML
src/Pure/ML/ml_lex.ML
src/Pure/consts.ML
--- a/NEWS	Sat Sep 11 13:04:32 2021 +0200
+++ b/NEWS	Sat Sep 11 21:16:23 2021 +0200
@@ -251,6 +251,25 @@
 * ML antiquotations \<^tvar>\<open>?'a::sort\<close> and \<^var>\<open>?x::type\<close> inline
 corresponding ML values, notably as arguments for Thm.instantiate etc.
 
+* ML antiquotations for type constructors and term constants:
+
+    \<^Type>\<open>c\<close>
+    \<^Type>\<open>c T \<dots>\<close>       \<comment> \<open>same with type arguments\<close>
+    \<^Const>\<open>c\<close>
+    \<^Const>\<open>c T \<dots>\<close>      \<comment> \<open>same with type arguments\<close>
+    \<^Const>\<open>c for t \<dots>\<close>  \<comment> \<open>same with term arguments\<close>
+    \<^Const_>\<open>c \<dots>\<close>       \<comment> \<open>same for patterns: case, let, fn\<close>
+
+Examples in HOL:
+
+  val natT = \<^Type>\<open>nat\<close>;
+  fun mk_funT (A, B) = \<^Type>\<open>fun A B\<close>;
+  val dest_funT = fn \<^Type>\<open>fun A B\<close> => (A, B);
+  fun mk_conj (A, B) = \<^Const>\<open>conj for A B\<close>;
+  val dest_conj = fn \<^Const_>\<open>conj for A B\<close> => (A, B);
+  fun mk_eq T (t, u) = \<^Const>\<open>HOL.eq T for t u\<close>;
+  val dest_eq = fn \<^Const_>\<open>HOL.eq T for t u\<close> => (T, (t, u));
+
 * The "build" combinators of various data structures help to build
 content from bottom-up, by applying an "add" function the "empty" value.
 For example:
--- a/etc/symbols	Sat Sep 11 13:04:32 2021 +0200
+++ b/etc/symbols	Sat Sep 11 21:16:23 2021 +0200
@@ -489,6 +489,9 @@
 \<^type_syntax>         argument: cartouche
 \<^var>                 argument: cartouche
 \<^oracle_name>         argument: cartouche
+\<^Const>               argument: cartouche
+\<^Const_>              argument: cartouche
+\<^Type>                argument: cartouche
 \<^code>                argument: cartouche
 \<^computation>         argument: cartouche
 \<^computation_conv>    argument: cartouche
--- a/src/Pure/ML/ml_antiquotation.ML	Sat Sep 11 13:04:32 2021 +0200
+++ b/src/Pure/ML/ml_antiquotation.ML	Sat Sep 11 21:16:23 2021 +0200
@@ -72,7 +72,7 @@
     (fn _ => fn src => fn ctxt =>
       let
         val (s, _) = Token.syntax (Scan.lift Args.embedded_input) src ctxt;
-        val tokenize = ML_Lex.tokenize_range Position.no_range;
+        val tokenize = ML_Lex.tokenize_no_range;
         val tokenize_range = ML_Lex.tokenize_range (Input.range_of s);
 
         val (decl, ctxt') = ML_Context.expand_antiquotes (ML_Lex.read_source s) ctxt;
--- a/src/Pure/ML/ml_antiquotations1.ML	Sat Sep 11 13:04:32 2021 +0200
+++ b/src/Pure/ML/ml_antiquotations1.ML	Sat Sep 11 21:16:23 2021 +0200
@@ -192,6 +192,134 @@
         in ML_Syntax.atomic (ML_Syntax.print_term const) end)));
 
 
+(* type/term constructors *)
+
+local
+
+fun read_embedded ctxt src parse =
+  let
+    val keywords = Thy_Header.get_keywords' ctxt;
+    val input = #1 (Token.syntax (Scan.lift Args.embedded_input) src ctxt);
+    val syms = Input.source_explode input;
+  in
+    (case Token.read_body keywords parse syms of
+      SOME res => res
+    | NONE => error ("Bad input" ^ Position.here (Input.pos_of input)))
+  end;
+
+fun ml_sources ctxt srcs =
+  let
+    val (decls, ctxt') = fold_map (ML_Context.expand_antiquotes o ML_Lex.read_source) srcs ctxt;
+    fun decl' ctxt'' = map (fn decl => decl ctxt'') decls;
+  in (decl', ctxt') end
+
+val parse_name = Parse.input Parse.name;
+val parse_args = Scan.repeat (Parse.input Parse.underscore || Parse.embedded_input);
+val parse_for_args =
+  Scan.optional ((Parse.position (Parse.$$$ "for") >> #2) -- Parse.!!! parse_args)
+    (Position.none, []);
+
+fun is_dummy s = Input.string_of s = "_";
+
+val ml = ML_Lex.tokenize_no_range;
+val ml_dummy = ml "_";
+fun ml_parens x = ml "(" @ x @ ml ")";
+fun ml_bracks x = ml "[" @ x @ ml "]";
+fun ml_commas xs = flat (separate (ml ", ") xs);
+val ml_list = ml_bracks o ml_commas;
+val ml_string = ml o ML_Syntax.print_string;
+fun ml_pair (x, y) = ml_parens (ml_commas [x, y]);
+
+fun type_antiquotation binding =
+  ML_Context.add_antiquotation binding true
+    (fn _ => fn src => fn ctxt =>
+      let
+        val (s, type_args) = read_embedded ctxt src (parse_name -- parse_args);
+        val pos = Input.pos_of s;
+
+        val Type (c, Ts) =
+          Proof_Context.read_type_name {proper = true, strict = true} ctxt
+            (Syntax.implode_input s);
+        val n = length Ts;
+        val _ =
+          length type_args = n orelse
+            error ("Type constructor " ^ quote (Proof_Context.markup_type ctxt c) ^
+              " takes " ^ string_of_int n ^ " argument(s)" ^ Position.here pos);
+
+        val (decls1, ctxt1) = ml_sources ctxt type_args;
+        fun decl' ctxt' =
+          let
+            val (ml_args_env, ml_args_body) = split_list (decls1 ctxt');
+            val ml_body = ml_parens (ml "Term.Type " @ ml_pair (ml_string c, ml_list ml_args_body));
+          in (flat ml_args_env, ml_body) end;
+      in (decl', ctxt1) end);
+
+fun const_antiquotation binding pattern =
+  ML_Context.add_antiquotation binding true
+    (fn _ => fn src => fn ctxt =>
+      let
+        val ((s, type_args), (for_pos, term_args)) =
+          read_embedded ctxt src (parse_name -- parse_args -- parse_for_args);
+        val _ = Context_Position.report ctxt for_pos (Markup.keyword_properties Markup.keyword1);
+
+        val Const (c, T) =
+          Proof_Context.read_const {proper = true, strict = true} ctxt
+            (Syntax.implode_input s);
+
+        val consts = Proof_Context.consts_of ctxt;
+        val type_paths = Consts.type_arguments consts c;
+        val type_params = map Term.dest_TVar (Consts.typargs consts (c, T));
+
+        val n = length type_params;
+        val m = length (Term.binder_types T);
+
+        fun err msg =
+          error ("Constant " ^ quote (Proof_Context.markup_const ctxt c) ^ msg ^
+            Position.here (Input.pos_of s));
+        val _ =
+          length type_args <> n andalso err (" takes " ^ string_of_int n ^ " type argument(s)");
+        val _ =
+          length term_args > m andalso Term.is_Type (Term.body_type T) andalso
+            err (" cannot have more than " ^ string_of_int m ^ " type argument(s)");
+
+        val (decls1, ctxt1) = ml_sources ctxt type_args;
+        val (decls2, ctxt2) = ml_sources ctxt1 term_args;
+        fun decl' ctxt' =
+          let
+            val (ml_args_env1, ml_args_body1) = split_list (decls1 ctxt');
+            val (ml_args_env2, ml_args_body2) = split_list (decls2 ctxt');
+
+            val relevant = map is_dummy type_args ~~ type_paths;
+            fun relevant_path is =
+              not pattern orelse
+                let val p = rev is
+                in relevant |> exists (fn (u, q) => not u andalso is_prefix (op =) p q) end;
+
+            val ml_typarg = the o AList.lookup (op =) (type_params ~~ ml_args_body1);
+            fun ml_typ is (Type (d, Us)) =
+                  if relevant_path is then
+                    ml "Term.Type " @
+                    ml_pair (ml_string d, ml_list (map_index (fn (i, U) => ml_typ (i :: is) U) Us))
+                  else ml_dummy
+              | ml_typ is (TVar arg) = if relevant_path is then ml_typarg arg else ml_dummy
+              | ml_typ _ (TFree _) = raise Match;
+
+            fun ml_app [] = ml "Term.Const " @ ml_pair (ml_string c, ml_typ [] T)
+              | ml_app (u :: us) = ml "Term.$ " @ ml_pair (ml_app us, u);
+
+            val ml_env = flat (ml_args_env1 @ ml_args_env2);
+            val ml_body = ml_parens (ml_app (rev ml_args_body2));
+          in (ml_env, ml_body) end;
+      in (decl', ctxt2) end);
+
+val _ = Theory.setup
+ (type_antiquotation \<^binding>\<open>Type\<close> #>
+  const_antiquotation \<^binding>\<open>Const\<close> false #>
+  const_antiquotation \<^binding>\<open>Const_\<close> true);
+
+in end;
+
+
 (* special forms *)
 
 val _ = Theory.setup
--- a/src/Pure/ML/ml_lex.ML	Sat Sep 11 13:04:32 2021 +0200
+++ b/src/Pure/ML/ml_lex.ML	Sat Sep 11 21:16:23 2021 +0200
@@ -28,6 +28,7 @@
       Source.source) Source.source
   val tokenize: string -> token list
   val tokenize_range: Position.range -> string -> token list
+  val tokenize_no_range: string -> token list
   val read_text: Symbol_Pos.text * Position.T -> token Antiquote.antiquote list
   val read: Symbol_Pos.text -> token Antiquote.antiquote list
   val read_range: Position.range -> Symbol_Pos.text -> token Antiquote.antiquote list
@@ -359,6 +360,7 @@
 
 val tokenize = Symbol.explode #> Source.of_list #> source #> Source.exhaust;
 fun tokenize_range range = tokenize #> map (set_range range);
+val tokenize_no_range = tokenize_range Position.no_range;
 
 val read_text = reader {opaque_warning = true} scan_ml_antiq o Symbol_Pos.explode;
 fun read text = read_text (text, Position.none);
--- a/src/Pure/consts.ML	Sat Sep 11 13:04:32 2021 +0200
+++ b/src/Pure/consts.ML	Sat Sep 11 21:16:23 2021 +0200
@@ -19,6 +19,7 @@
   val the_const: T -> string -> string * typ                   (*exception TYPE*)
   val the_abbreviation: T -> string -> typ * term              (*exception TYPE*)
   val type_scheme: T -> string -> typ                          (*exception TYPE*)
+  val type_arguments: T -> string -> int list list             (*exception TYPE*)
   val is_monomorphic: T -> string -> bool                      (*exception TYPE*)
   val the_constraint: T -> string -> typ                       (*exception TYPE*)
   val space_of: T -> Name_Space.T