support for "lemma";
authorwenzelm
Thu, 28 Oct 2021 18:37:33 +0200
changeset 74606 40f5c6b2e8aa
parent 74605 8b7258c61649
child 74607 7f6178b655a8
support for "lemma"; support for "schematic" mode; clarified error position;
src/Pure/ML/ml_instantiate.ML
src/Pure/ML/ml_thms.ML
--- a/src/Pure/ML/ml_instantiate.ML	Thu Oct 28 13:20:45 2021 +0200
+++ b/src/Pure/ML/ml_instantiate.ML	Thu Oct 28 18:37:33 2021 +0200
@@ -14,6 +14,10 @@
   val instantiate_ctyp: Position.T -> cinsts -> ctyp -> ctyp
   val instantiate_term: insts -> term -> term
   val instantiate_cterm: Position.T -> cinsts -> cterm -> cterm
+  val instantiate_thm: Position.T -> cinsts -> thm -> thm
+  val instantiate_thms: Position.T -> cinsts -> thm list -> thm list
+  val get_thms: Proof.context -> int -> thm list
+  val get_thm: Proof.context -> int -> thm
 end;
 
 structure ML_Instantiate: ML_INSTANTIATE =
@@ -41,14 +45,41 @@
     val inst = Vars.make ((map o apfst o apsnd) instantiateT (#2 insts));
   in Term_Subst.instantiate_beta (instT, inst) end;
 
-fun instantiate_cterm pos (cinsts: cinsts) ct =
+fun make_cinsts (cinsts: cinsts) =
   let
     val cinstT = TVars.make (#1 cinsts);
     val instantiateT = Term_Subst.instantiateT (TVars.map (K Thm.typ_of) cinstT);
     val cinst = Vars.make ((map o apfst o apsnd) instantiateT (#2 cinsts));
-  in Thm.instantiate_beta_cterm (cinstT, cinst) ct end
+  in (cinstT, cinst) end;
+
+fun instantiate_cterm pos cinsts ct =
+  Thm.instantiate_beta_cterm (make_cinsts cinsts) ct
   handle CTERM (msg, args) => Exn.reraise (CTERM (msg ^ Position.here pos, args));
 
+fun instantiate_thm pos cinsts th =
+  Thm.instantiate_beta (make_cinsts cinsts) th
+  handle THM (msg, i, args) => Exn.reraise (THM (msg ^ Position.here pos, i, args));
+
+fun instantiate_thms pos cinsts = map (instantiate_thm pos cinsts);
+
+
+(* context data *)
+
+structure Data = Proof_Data
+(
+  type T = int * thm list Inttab.table;
+  fun init _ = (0, Inttab.empty);
+);
+
+fun put_thms ths ctxt =
+  let
+    val (i, thms) = Data.get ctxt;
+    val ctxt' = ctxt |> Data.put (i + 1, Inttab.update (i, ths) thms);
+  in (i, ctxt') end;
+
+fun get_thms ctxt i = the (Inttab.lookup (#2 (Data.get ctxt)) i);
+fun get_thm ctxt i = the_single (get_thms ctxt i);
+
 
 (* ML antiquotation *)
 
@@ -57,7 +88,7 @@
 val make_keywords =
   Thy_Header.get_keywords'
   #> Keyword.no_major_keywords
-  #> Keyword.add_major_keywords ["typ", "term", "prop", "ctyp", "cterm", "cprop"];
+  #> Keyword.add_major_keywords ["typ", "term", "prop", "ctyp", "cterm", "cprop", "lemma"];
 
 val parse_inst_name = Parse.position (Parse.type_ident >> pair true || Parse.name >> pair false);
 
@@ -77,6 +108,7 @@
 val ml_list = ml_bracks o ml_commas;
 fun ml_pair (x, y) = ml_parens (ml_commas [x, y]);
 val ml_list_pair = ml_list o ListPair.map ml_pair;
+val ml_here = ML_Syntax.atomic o ML_Syntax.print_position;
 
 fun get_tfree envT (a, pos) =
   (case AList.lookup (op =) envT a of
@@ -89,15 +121,19 @@
       (Context_Position.reports ctxt (map (pair pos) (Syntax_Phases.markup_free ctxt x)); (x, T))
   | NONE => error ("No occurrence of variable " ^ quote x ^ Position.here pos));
 
-fun missing_instT envT instT =
+fun missing_instT pos envT instT =
   (case filter_out (fn (a, _) => exists (fn (b, _) => a = b) instT) envT of
     [] => ()
-  | bad => error ("No instantiation for free type variable(s) " ^ commas_quote (map #1 bad)));
+  | bad =>
+      error ("No instantiation for free type variable(s) " ^ commas_quote (map #1 bad) ^
+        Position.here pos));
 
-fun missing_inst env inst =
+fun missing_inst pos env inst =
   (case filter_out (fn (a, _) => exists (fn (b, _) => a = b) inst) env of
     [] => ()
-  | bad => error ("No instantiation for free variable(s) " ^ commas_quote (map #1 bad)));
+  | bad =>
+      error ("No instantiation for free variable(s) " ^ commas_quote (map #1 bad) ^
+        Position.here pos));
 
 fun make_instT (a, pos) T =
   (case try (Term.dest_TVar o Logic.dest_type) T of
@@ -109,25 +145,32 @@
     NONE => error ("Not a free variable " ^ quote a ^ Position.here pos)
   | SOME v => ml (ML_Syntax.print_pair ML_Syntax.print_indexname ML_Syntax.print_typ v));
 
-fun make_env t = (Term.add_tfrees t [], Term.add_frees t []);
+fun make_env ts =
+  let
+    val envT = fold Term.add_tfrees ts [];
+    val env = fold Term.add_frees ts [];
+  in (envT, env) end;
 
-fun prepare_insts ctxt1 ctxt0 (instT, inst) t =
+fun prepare_insts pos {schematic} ctxt1 ctxt0 (instT, inst) ts =
   let
-    val (envT, env) = make_env t;
+    val (envT, env) = make_env ts;
     val freesT = map (Logic.mk_type o TFree o get_tfree envT) instT;
     val frees = map (Free o check_free ctxt1 env) inst;
-    val (t' :: varsT, vars) =
-      Variable.export_terms ctxt1 ctxt0 (t :: freesT @ frees)
-      |> chop (1 + length freesT);
+    val (ts', (varsT, vars)) =
+      Variable.export_terms ctxt1 ctxt0 (ts @ freesT @ frees)
+      |> chop (length ts) ||> chop (length freesT);
+    val ml_insts = (map2 make_instT instT varsT, map2 make_inst inst vars);
+  in
+    if schematic then ()
+    else
+      let val (envT', env') = make_env ts' in
+        missing_instT pos (subtract (eq_fst op =) envT' envT) instT;
+        missing_inst pos (subtract (eq_fst op =) env' env) inst
+      end;
+    (ml_insts, ts')
+  end;
 
-    val (envT', env') = make_env t';
-    val _ = missing_instT (subtract (eq_fst op =) envT' envT) instT;
-    val _ = missing_inst (subtract (eq_fst op =) env' env) inst;
-
-    val ml_insts = (map2 make_instT instT varsT, map2 make_inst inst vars);
-  in (ml_insts, t') end;
-
-fun prepare_ml range (kind, ml1, ml2) ml_val (ml_instT, ml_inst) ctxt =
+fun prepare_ml range kind ml1 ml2 ml_val (ml_instT, ml_inst) ctxt =
   let
     val (ml_name, ctxt') = ML_Context.variant kind ctxt;
     val ml_env = ml ("val " ^ ml_name ^ " = ") @ ml ml1 @ ml_parens (ml ml_val) @ ml ";\n";
@@ -137,40 +180,55 @@
         ml_range range (ML_Context.struct_name ctxt ^ "." ^ ml_name));
   in ((ml_env, ml_body), ctxt') end;
 
-fun prepare_type range (arg, s) insts ctxt =
+fun prepare_type range ((((kind, pos), ml1, ml2), schematic), s) insts ctxt =
   let
     val T = Syntax.read_typ ctxt s;
     val t = Logic.mk_type T;
     val ctxt1 = Proof_Context.augment t ctxt;
-    val (ml_insts, T') = prepare_insts ctxt1 ctxt insts t ||> Logic.dest_type;
-  in prepare_ml range arg (ML_Syntax.print_typ T') ml_insts ctxt end;
+    val (ml_insts, T') =
+      prepare_insts pos schematic ctxt1 ctxt insts [t] ||> (the_single #> Logic.dest_type);
+  in prepare_ml range kind ml1 ml2 (ML_Syntax.print_typ T') ml_insts ctxt end;
 
-fun prepare_term read range (arg, (s, fixes)) insts ctxt =
+fun prepare_term read range ((((kind, pos), ml1, ml2), schematic), (s, fixes)) insts ctxt =
   let
     val ctxt' = #2 (Proof_Context.add_fixes_cmd fixes ctxt);
     val t = read ctxt' s;
     val ctxt1 = Proof_Context.augment t ctxt';
-    val (ml_insts, t') = prepare_insts ctxt1 ctxt insts t;
-  in prepare_ml range arg (ML_Syntax.print_term t') ml_insts ctxt end;
-
-val ml_here = ML_Syntax.atomic o ML_Syntax.print_position;
+    val (ml_insts, t') = prepare_insts pos schematic ctxt1 ctxt insts [t] ||> the_single;
+  in prepare_ml range kind ml1 ml2 (ML_Syntax.print_term t') ml_insts ctxt end;
 
-fun typ_ml (kind, _: Position.T) = (kind, "", "ML_Instantiate.instantiate_typ ");
-fun term_ml (kind, _: Position.T) = (kind, "", "ML_Instantiate.instantiate_term ");
+fun prepare_lemma range ((pos, schematic), make_lemma) insts ctxt =
+  let
+    val (ths, (props, ctxt1)) = make_lemma ctxt
+    val (i, thms_ctxt) = put_thms ths ctxt;
+    val ml_insts = #1 (prepare_insts pos schematic ctxt1 ctxt insts props);
+    val (ml1, ml2) =
+      if length ths = 1
+      then ("ML_Instantiate.get_thm ML_context", "ML_Instantiate.instantiate_thm " ^ ml_here pos)
+      else ("ML_Instantiate.get_thms ML_context", "ML_Instantiate.instantiate_thms " ^ ml_here pos);
+  in prepare_ml range "lemma" ml1 ml2 (ML_Syntax.print_int i) ml_insts thms_ctxt end;
+
+fun typ_ml (kind, pos: Position.T) = ((kind, pos), "", "ML_Instantiate.instantiate_typ ");
+fun term_ml (kind, pos: Position.T) = ((kind, pos), "", "ML_Instantiate.instantiate_term ");
 fun ctyp_ml (kind, pos) =
-  (kind, "ML_Instantiate.make_ctyp ML_context", "ML_Instantiate.instantiate_ctyp " ^ ml_here pos);
+  ((kind, pos),
+    "ML_Instantiate.make_ctyp ML_context", "ML_Instantiate.instantiate_ctyp " ^ ml_here pos);
 fun cterm_ml (kind, pos) =
-  (kind, "ML_Instantiate.make_cterm ML_context", "ML_Instantiate.instantiate_cterm " ^ ml_here pos);
+  ((kind, pos),
+    "ML_Instantiate.make_cterm ML_context", "ML_Instantiate.instantiate_cterm " ^ ml_here pos);
 
 val command_name = Parse.position o Parse.command_name;
 
+val parse_schematic = Args.mode "schematic" >> (fn b => {schematic = b});
+
 fun parse_body range =
-  (command_name "typ" >> typ_ml || command_name "ctyp" >> ctyp_ml) --
+  (command_name "typ" >> typ_ml || command_name "ctyp" >> ctyp_ml) -- parse_schematic --
     Parse.!!! Parse.typ >> prepare_type range ||
-  (command_name "term" >> term_ml || command_name "cterm" >> cterm_ml)
-    -- Parse.!!! (Parse.term -- Parse.for_fixes) >> prepare_term Syntax.read_term range ||
-  (command_name "prop" >> term_ml || command_name "cprop" >> cterm_ml)
-    -- Parse.!!! (Parse.term -- Parse.for_fixes) >> prepare_term Syntax.read_prop range;
+  (command_name "term" >> term_ml || command_name "cterm" >> cterm_ml) -- parse_schematic --
+    Parse.!!! (Parse.term -- Parse.for_fixes) >> prepare_term Syntax.read_term range ||
+  (command_name "prop" >> term_ml || command_name "cprop" >> cterm_ml) -- parse_schematic --
+    Parse.!!! (Parse.term -- Parse.for_fixes) >> prepare_term Syntax.read_prop range ||
+  (command_name "lemma" >> #2) -- parse_schematic -- ML_Thms.embedded_lemma >> prepare_lemma range;
 
 val _ = Theory.setup
   (ML_Context.add_antiquotation \<^binding>\<open>instantiate\<close> true
--- a/src/Pure/ML/ml_thms.ML	Thu Oct 28 13:20:45 2021 +0200
+++ b/src/Pure/ML/ml_thms.ML	Thu Oct 28 18:37:33 2021 +0200
@@ -8,7 +8,7 @@
 sig
   val the_attributes: Proof.context -> int -> Token.src list
   val the_thmss: Proof.context -> thm list list
-  val embedded_lemma: (Proof.context -> thm list) parser
+  val embedded_lemma: (Proof.context -> thm list * (term list * Proof.context)) parser
   val get_stored_thms: unit -> thm list
   val get_stored_thm: unit -> thm
   val store_thms: string * thm list -> unit
@@ -83,8 +83,9 @@
         let
           val _ = Context_Position.reports ctxt reports;
 
-          val stmt_ctxt = #2 (Proof_Context.add_fixes_cmd fixes ctxt);
-          val stmt = burrow (map (rpair []) o Syntax.read_props stmt_ctxt) raw_stmt;
+          val fixes_ctxt = #2 (Proof_Context.add_fixes_cmd fixes ctxt);
+          val stmt = burrow (map (rpair []) o Syntax.read_props fixes_ctxt) raw_stmt;
+          val stmt_ctxt = (fold o fold) (Proof_Context.augment o #1) stmt fixes_ctxt;
 
           val prep_result = Goal.norm_result ctxt #> not is_open ? Thm.close_derivation \<^here>;
           fun after_qed res goal_ctxt =
@@ -97,12 +98,12 @@
           val thms =
             Proof_Context.get_fact thms_ctxt
               (Facts.named (Proof_Context.full_name thms_ctxt (Binding.name Auto_Bind.thisN)))
-        in thms end);
+        in (thms, (map #1 (flat stmt), stmt_ctxt)) end);
 
 val _ = Theory.setup
   (ML_Antiquotation.declaration \<^binding>\<open>lemma\<close> (Scan.lift embedded_lemma)
     (fn _ => fn make_lemma => fn ctxt =>
-      let val thms = make_lemma ctxt
+      let val thms = #1 (make_lemma ctxt)
       in thm_binding "lemma" (length thms = 1) thms ctxt end));