diff -r 8b7258c61649 -r 40f5c6b2e8aa src/Pure/ML/ml_instantiate.ML --- a/src/Pure/ML/ml_instantiate.ML Thu Oct 28 13:20:45 2021 +0200 +++ b/src/Pure/ML/ml_instantiate.ML Thu Oct 28 18:37:33 2021 +0200 @@ -14,6 +14,10 @@ val instantiate_ctyp: Position.T -> cinsts -> ctyp -> ctyp val instantiate_term: insts -> term -> term val instantiate_cterm: Position.T -> cinsts -> cterm -> cterm + val instantiate_thm: Position.T -> cinsts -> thm -> thm + val instantiate_thms: Position.T -> cinsts -> thm list -> thm list + val get_thms: Proof.context -> int -> thm list + val get_thm: Proof.context -> int -> thm end; structure ML_Instantiate: ML_INSTANTIATE = @@ -41,14 +45,41 @@ val inst = Vars.make ((map o apfst o apsnd) instantiateT (#2 insts)); in Term_Subst.instantiate_beta (instT, inst) end; -fun instantiate_cterm pos (cinsts: cinsts) ct = +fun make_cinsts (cinsts: cinsts) = let val cinstT = TVars.make (#1 cinsts); val instantiateT = Term_Subst.instantiateT (TVars.map (K Thm.typ_of) cinstT); val cinst = Vars.make ((map o apfst o apsnd) instantiateT (#2 cinsts)); - in Thm.instantiate_beta_cterm (cinstT, cinst) ct end + in (cinstT, cinst) end; + +fun instantiate_cterm pos cinsts ct = + Thm.instantiate_beta_cterm (make_cinsts cinsts) ct handle CTERM (msg, args) => Exn.reraise (CTERM (msg ^ Position.here pos, args)); +fun instantiate_thm pos cinsts th = + Thm.instantiate_beta (make_cinsts cinsts) th + handle THM (msg, i, args) => Exn.reraise (THM (msg ^ Position.here pos, i, args)); + +fun instantiate_thms pos cinsts = map (instantiate_thm pos cinsts); + + +(* context data *) + +structure Data = Proof_Data +( + type T = int * thm list Inttab.table; + fun init _ = (0, Inttab.empty); +); + +fun put_thms ths ctxt = + let + val (i, thms) = Data.get ctxt; + val ctxt' = ctxt |> Data.put (i + 1, Inttab.update (i, ths) thms); + in (i, ctxt') end; + +fun get_thms ctxt i = the (Inttab.lookup (#2 (Data.get ctxt)) i); +fun get_thm ctxt i = the_single (get_thms ctxt i); + (* ML antiquotation *) @@ -57,7 +88,7 @@ val make_keywords = Thy_Header.get_keywords' #> Keyword.no_major_keywords - #> Keyword.add_major_keywords ["typ", "term", "prop", "ctyp", "cterm", "cprop"]; + #> Keyword.add_major_keywords ["typ", "term", "prop", "ctyp", "cterm", "cprop", "lemma"]; val parse_inst_name = Parse.position (Parse.type_ident >> pair true || Parse.name >> pair false); @@ -77,6 +108,7 @@ val ml_list = ml_bracks o ml_commas; fun ml_pair (x, y) = ml_parens (ml_commas [x, y]); val ml_list_pair = ml_list o ListPair.map ml_pair; +val ml_here = ML_Syntax.atomic o ML_Syntax.print_position; fun get_tfree envT (a, pos) = (case AList.lookup (op =) envT a of @@ -89,15 +121,19 @@ (Context_Position.reports ctxt (map (pair pos) (Syntax_Phases.markup_free ctxt x)); (x, T)) | NONE => error ("No occurrence of variable " ^ quote x ^ Position.here pos)); -fun missing_instT envT instT = +fun missing_instT pos envT instT = (case filter_out (fn (a, _) => exists (fn (b, _) => a = b) instT) envT of [] => () - | bad => error ("No instantiation for free type variable(s) " ^ commas_quote (map #1 bad))); + | bad => + error ("No instantiation for free type variable(s) " ^ commas_quote (map #1 bad) ^ + Position.here pos)); -fun missing_inst env inst = +fun missing_inst pos env inst = (case filter_out (fn (a, _) => exists (fn (b, _) => a = b) inst) env of [] => () - | bad => error ("No instantiation for free variable(s) " ^ commas_quote (map #1 bad))); + | bad => + error ("No instantiation for free variable(s) " ^ commas_quote (map #1 bad) ^ + Position.here pos)); fun make_instT (a, pos) T = (case try (Term.dest_TVar o Logic.dest_type) T of @@ -109,25 +145,32 @@ NONE => error ("Not a free variable " ^ quote a ^ Position.here pos) | SOME v => ml (ML_Syntax.print_pair ML_Syntax.print_indexname ML_Syntax.print_typ v)); -fun make_env t = (Term.add_tfrees t [], Term.add_frees t []); +fun make_env ts = + let + val envT = fold Term.add_tfrees ts []; + val env = fold Term.add_frees ts []; + in (envT, env) end; -fun prepare_insts ctxt1 ctxt0 (instT, inst) t = +fun prepare_insts pos {schematic} ctxt1 ctxt0 (instT, inst) ts = let - val (envT, env) = make_env t; + val (envT, env) = make_env ts; val freesT = map (Logic.mk_type o TFree o get_tfree envT) instT; val frees = map (Free o check_free ctxt1 env) inst; - val (t' :: varsT, vars) = - Variable.export_terms ctxt1 ctxt0 (t :: freesT @ frees) - |> chop (1 + length freesT); + val (ts', (varsT, vars)) = + Variable.export_terms ctxt1 ctxt0 (ts @ freesT @ frees) + |> chop (length ts) ||> chop (length freesT); + val ml_insts = (map2 make_instT instT varsT, map2 make_inst inst vars); + in + if schematic then () + else + let val (envT', env') = make_env ts' in + missing_instT pos (subtract (eq_fst op =) envT' envT) instT; + missing_inst pos (subtract (eq_fst op =) env' env) inst + end; + (ml_insts, ts') + end; - val (envT', env') = make_env t'; - val _ = missing_instT (subtract (eq_fst op =) envT' envT) instT; - val _ = missing_inst (subtract (eq_fst op =) env' env) inst; - - val ml_insts = (map2 make_instT instT varsT, map2 make_inst inst vars); - in (ml_insts, t') end; - -fun prepare_ml range (kind, ml1, ml2) ml_val (ml_instT, ml_inst) ctxt = +fun prepare_ml range kind ml1 ml2 ml_val (ml_instT, ml_inst) ctxt = let val (ml_name, ctxt') = ML_Context.variant kind ctxt; val ml_env = ml ("val " ^ ml_name ^ " = ") @ ml ml1 @ ml_parens (ml ml_val) @ ml ";\n"; @@ -137,40 +180,55 @@ ml_range range (ML_Context.struct_name ctxt ^ "." ^ ml_name)); in ((ml_env, ml_body), ctxt') end; -fun prepare_type range (arg, s) insts ctxt = +fun prepare_type range ((((kind, pos), ml1, ml2), schematic), s) insts ctxt = let val T = Syntax.read_typ ctxt s; val t = Logic.mk_type T; val ctxt1 = Proof_Context.augment t ctxt; - val (ml_insts, T') = prepare_insts ctxt1 ctxt insts t ||> Logic.dest_type; - in prepare_ml range arg (ML_Syntax.print_typ T') ml_insts ctxt end; + val (ml_insts, T') = + prepare_insts pos schematic ctxt1 ctxt insts [t] ||> (the_single #> Logic.dest_type); + in prepare_ml range kind ml1 ml2 (ML_Syntax.print_typ T') ml_insts ctxt end; -fun prepare_term read range (arg, (s, fixes)) insts ctxt = +fun prepare_term read range ((((kind, pos), ml1, ml2), schematic), (s, fixes)) insts ctxt = let val ctxt' = #2 (Proof_Context.add_fixes_cmd fixes ctxt); val t = read ctxt' s; val ctxt1 = Proof_Context.augment t ctxt'; - val (ml_insts, t') = prepare_insts ctxt1 ctxt insts t; - in prepare_ml range arg (ML_Syntax.print_term t') ml_insts ctxt end; - -val ml_here = ML_Syntax.atomic o ML_Syntax.print_position; + val (ml_insts, t') = prepare_insts pos schematic ctxt1 ctxt insts [t] ||> the_single; + in prepare_ml range kind ml1 ml2 (ML_Syntax.print_term t') ml_insts ctxt end; -fun typ_ml (kind, _: Position.T) = (kind, "", "ML_Instantiate.instantiate_typ "); -fun term_ml (kind, _: Position.T) = (kind, "", "ML_Instantiate.instantiate_term "); +fun prepare_lemma range ((pos, schematic), make_lemma) insts ctxt = + let + val (ths, (props, ctxt1)) = make_lemma ctxt + val (i, thms_ctxt) = put_thms ths ctxt; + val ml_insts = #1 (prepare_insts pos schematic ctxt1 ctxt insts props); + val (ml1, ml2) = + if length ths = 1 + then ("ML_Instantiate.get_thm ML_context", "ML_Instantiate.instantiate_thm " ^ ml_here pos) + else ("ML_Instantiate.get_thms ML_context", "ML_Instantiate.instantiate_thms " ^ ml_here pos); + in prepare_ml range "lemma" ml1 ml2 (ML_Syntax.print_int i) ml_insts thms_ctxt end; + +fun typ_ml (kind, pos: Position.T) = ((kind, pos), "", "ML_Instantiate.instantiate_typ "); +fun term_ml (kind, pos: Position.T) = ((kind, pos), "", "ML_Instantiate.instantiate_term "); fun ctyp_ml (kind, pos) = - (kind, "ML_Instantiate.make_ctyp ML_context", "ML_Instantiate.instantiate_ctyp " ^ ml_here pos); + ((kind, pos), + "ML_Instantiate.make_ctyp ML_context", "ML_Instantiate.instantiate_ctyp " ^ ml_here pos); fun cterm_ml (kind, pos) = - (kind, "ML_Instantiate.make_cterm ML_context", "ML_Instantiate.instantiate_cterm " ^ ml_here pos); + ((kind, pos), + "ML_Instantiate.make_cterm ML_context", "ML_Instantiate.instantiate_cterm " ^ ml_here pos); val command_name = Parse.position o Parse.command_name; +val parse_schematic = Args.mode "schematic" >> (fn b => {schematic = b}); + fun parse_body range = - (command_name "typ" >> typ_ml || command_name "ctyp" >> ctyp_ml) -- + (command_name "typ" >> typ_ml || command_name "ctyp" >> ctyp_ml) -- parse_schematic -- Parse.!!! Parse.typ >> prepare_type range || - (command_name "term" >> term_ml || command_name "cterm" >> cterm_ml) - -- Parse.!!! (Parse.term -- Parse.for_fixes) >> prepare_term Syntax.read_term range || - (command_name "prop" >> term_ml || command_name "cprop" >> cterm_ml) - -- Parse.!!! (Parse.term -- Parse.for_fixes) >> prepare_term Syntax.read_prop range; + (command_name "term" >> term_ml || command_name "cterm" >> cterm_ml) -- parse_schematic -- + Parse.!!! (Parse.term -- Parse.for_fixes) >> prepare_term Syntax.read_term range || + (command_name "prop" >> term_ml || command_name "cprop" >> cterm_ml) -- parse_schematic -- + Parse.!!! (Parse.term -- Parse.for_fixes) >> prepare_term Syntax.read_prop range || + (command_name "lemma" >> #2) -- parse_schematic -- ML_Thms.embedded_lemma >> prepare_lemma range; val _ = Theory.setup (ML_Context.add_antiquotation \<^binding>\instantiate\ true