src/HOL/Eisbach/eisbach_rule_insts.ML
author wenzelm
Fri Oct 27 13:50:08 2017 +0200 (22 months ago)
changeset 66924 b4d4027f743b
parent 62162 dca35981c8fb
child 69593 3dda49e08b9d
permissions -rw-r--r--
more permissive;
     1 (*  Title:      HOL/Eisbach/eisbach_rule_insts.ML
     2     Author:     Daniel Matichuk, NICTA/UNSW
     3 
     4 Eisbach-aware variants of the "where" and "of" attributes.
     5 
     6 Alternate syntax for rule_insts.ML participates in token closures by
     7 examining the behaviour of Rule_Insts.where_rule and instantiating token
     8 values accordingly. Instantiations in re-interpretation are done with
     9 infer_instantiate.
    10 *)
    11 
    12 structure Eisbach_Rule_Insts: sig end =
    13 struct
    14 
    (*Yield a theorem transformer that overwrites the tags of its argument
      with the tags of thm.*)
    fun restore_tags thm =
      let val tags = Thm.get_tags thm
      in Thm.map_tags (fn _ => tags) end;
    16 
    (*Encode a type as a term: Logic.mk_type, wrapped by Logic.mk_term, then
      by Logic.protect.  term_type_cases undoes exactly this nesting.*)
    fun mk_term_type_internal T = Logic.protect (Logic.mk_term (Logic.mk_type T));
    18 
    (*Case distinction on an encoded instantiation t: apply f to the type if t
      decodes as an encoded type, otherwise apply g to the term if t decodes as
      an encoded term; raise Fail if neither decoding succeeds.*)
    fun term_type_cases f g t =
      let
        val decode_type = try (Logic.dest_type o Logic.dest_term o Logic.unprotect);
        val decode_term = try Logic.dest_term;
      in
        (case decode_type t of
          SOME T => f T
        | NONE =>
            (case decode_term t of
              SOME u => g u
            | NONE => raise Fail "Lost encoded instantiation"))
      end
    26 
    (*Conjoin thm with an encoding of all its schematic type and term variables.
      Returns the collected variables (type variables first) together with the
      extended theorem; get_thm_insts splits such a theorem again.*)
    fun add_thm_insts ctxt thm =
      let
        val type_vars = Thm.fold_terms Term.add_tvars thm [];
        val term_vars = Thm.fold_terms Term.add_vars thm [];

        (*type variables come first, in the same order as in the result pair*)
        val encoded =
          map (fn v => mk_term_type_internal (TVar v)) type_vars @
          map (fn v => Logic.mk_term (Var v)) term_vars;

        val marker = Drule.mk_term (Thm.cterm_of ctxt (Logic.mk_conjunction_list encoded));
      in
        ((type_vars, term_vars), Conjunction.intr thm marker)
      end;
    40 
    (*Split a theorem built by add_thm_insts into the underlying theorem and
      the decoded instantiations, as a pair (types, terms) in the order in
      which they occur in the encoding.*)
    fun get_thm_insts thm =
      let
        val (thm', marker) = Conjunction.elim thm;
        val encoded = Logic.dest_conjunction_list (Thm.term_of (Drule.dest_term marker));

        (*sort one encoded entry into the type bucket or the term bucket*)
        fun classify entry (tys, ts) =
          term_type_cases (fn T => (T :: tys, ts)) (fn u => (tys, u :: ts)) entry;

        (*fold conses the entries in reverse, so both buckets are reversed back*)
        val (rev_tys, rev_ts) = fold classify encoded ([], []);
      in
        (thm', (rev rev_tys, rev rev_ts))
      end;
    56 
    (*Instantiate thm with a list of (indexname, term) pairs.  An indexname that
      names a schematic type variable of thm expects its term to encode a type
      (read via Logic.dest_type); one that names a schematic term variable is
      instantiated with the term itself.  Any other indexname is an error.
      The tags of the original thm are restored on the result.*)
    fun instantiate_xis ctxt insts thm =
      let
        val type_vars = Thm.fold_terms Term.add_tvars thm [];
        val term_vars = Thm.fold_terms Term.add_vars thm [];

        val sort_of = AList.lookup (op =) type_vars;
        val is_term_var = AList.defined (op =) term_vars;

        (*type variables take precedence over term variables of the same name*)
        fun add_inst (xi, t) (instT, inst) =
          (case sort_of xi of
            SOME S => (((xi, S), Thm.ctyp_of ctxt (Logic.dest_type t)) :: instT, inst)
          | NONE =>
              if is_term_var xi then (instT, (xi, Thm.cterm_of ctxt t) :: inst)
              else error "indexname not found in thm");

        val (instT, inst) = fold add_inst insts ([], []);

        (*types first, then terms via infer_instantiate*)
        val type_instantiated = Thm.instantiate (instT, []) thm;
        val instantiated = infer_instantiate ctxt inst type_instantiated COMP_INCR asm_rl;
      in
        restore_tags thm (Thm.adjust_maxidx_thm ~1 instantiated)
      end;
    78 
    79 
    (*The instantiation forms understood by read_instantiate_closed.*)
    datatype rule_inst =
      (*source text per variable, each with a callback that receives the
        resulting term, plus the parsed fixes*)
      Named_Insts of
        ((indexname * string) * (term -> unit)) list *
        (binding * string option * mixfix) list
      (*terms already paired with the variable they instantiate*)
    | Term_Insts of (indexname * term) list
      (*positional terms in two lists; of_rule matches the second list against
        the variables of the conclusion*)
    | Unchecked_Term_Insts of term option list * term option list;
    84 
    (*Encode a pair of terms as a conjunction of their Logic.mk_term wrappers.*)
    fun mk_pair pair = Logic.mk_conjunction (apply2 Logic.mk_term pair);
    86 
    (*Inverse of mk_pair: split the conjunction and unwrap both components.*)
    fun dest_pair t =
      let val (A, B) = Logic.dest_conjunction t
      in (Logic.dest_term A, Logic.dest_term B) end;
    88 
    (*Adapt the callback f of a named instantiation: instead of the bare term t
      it receives the pair (Var (xi, type of t), t) encoded with mk_pair, so the
      variable name travels with the term.*)
    fun embed_indexname ((xi, s), f) =
      ((xi, s), fn t => f (mk_pair (Var (xi, fastype_of t), t)));
    92 
    (*Decode a pair produced via embed_indexname into (indexname, term).*)
    fun unembed_indexname t =
      let val (v, u) = dest_pair t
      in (fst (Term.dest_Var v), u) end;
    94 
    (*Turn the parsed arguments of "where" into a rule_inst.  If every value is
      already a real value, the embedded (indexname, term) pairs are decoded
      into Term_Insts; otherwise the parse values and their callbacks are kept
      as Named_Insts together with the fixes.*)
    fun read_where_insts (insts, fixes) =
      let
        fun closed (_, v) = Parse_Tools.is_real_val v;
        fun from_real (_, v) = unembed_indexname (Parse_Tools.the_real_val v);
        fun from_parse (xi, p) =
          embed_indexname ((xi, Parse_Tools.the_parse_val p), Parse_Tools.the_parse_fun p);
      in
        if forall closed insts
        then Term_Insts (map from_real insts)
        else Named_Insts (map from_parse insts, fixes)
      end;
   104 
   (*Match positional instantiations against the schematic variables of thm:
     args against the variables of the full proposition, concl_args against
     those of the conclusion.  A NONE entry skips one variable; it is an error
     to supply more entries than there are variables.*)
   fun of_rule thm (args, concl_args) =
     let
       fun vars_of t = rev (Term.add_vars t []);

       fun zip_vars vars insts =
         (case (vars, insts) of
           (_, []) => []
         | (_ :: vars', NONE :: insts') => zip_vars vars' insts'
         | ((xi, _) :: vars', SOME t :: insts') => (xi, t) :: zip_vars vars' insts'
         | ([], _ :: _) => error "More instantiations than variables in theorem");
     in
       zip_vars (vars_of (Thm.full_prop_of thm)) args @
       zip_vars (vars_of (Thm.concl_of thm)) concl_args
     end;
   115 
   (*Parser for one positional argument of "of": Parse_Tools.name_term wrapped
     in Args.maybe, hence an optional value.*)
   fun inst toks = Args.maybe Parse_Tools.name_term toks;

   (*Parser for the keyword "concl" followed by a colon.*)
   fun concl toks = (Args.$$$ "concl" -- Args.colon) toks;
   118 
   (*Read the unparsed positional instantiations of an unchecked "of" right
     away: the sources are read in the context extended by the fixes and
     exported back, every resulting term is passed to the callback of its
     argument, and the terms are returned as Unchecked_Term_Insts, split again
     at the original boundary between insts and concl_inst.*)
   fun close_unchecked_insts context ((insts, concl_inst), fixes) =
     let
       val ctxt = Context.proof_of context;
       val (_, fixes_ctxt) = Proof_Context.add_fixes_cmd fixes ctxt;

       val all_insts = insts @ concl_inst;

       fun read_all sources =
         Variable.export_terms fixes_ctxt ctxt (Syntax.read_terms fixes_ctxt sources);

       (*absent arguments become empty lists for burrow and NONE afterwards*)
       val terms =
         all_insts
         |> map (fn NONE => [] | SOME p => [Parse_Tools.the_parse_val p])
         |> burrow read_all
         |> map (fn [t] => SOME t | _ => NONE);

       (*report each term that was read to the callback of its argument*)
       val _ =
         ListPair.app
           (fn (SOME p, SOME t) => Parse_Tools.the_parse_fun p t | _ => ())
           (all_insts, terms);

       val (prem_terms, concl_terms) = chop (length insts) terms;
     in Unchecked_Term_Insts (prem_terms, concl_terms) end;
   136 
   (*Turn the parsed arguments of "of" into a function from the theorem to a
     rule_inst.  Four cases, by whether all given values are already real
     values and whether checking is requested:
       real, checked:       decode the embedded pairs into Term_Insts;
       real, unchecked:     take the real values as Unchecked_Term_Insts;
       unparsed, checked:   match sources against the theorem's variables
                            (of_rule) and keep them as Named_Insts;
       unparsed, unchecked: read the sources immediately, independent of the
                            theorem (close_unchecked_insts).
     Only the last case does its work before the theorem is supplied.*)
   fun read_of_insts checked context ((insts, concl_insts), fixes) =
     let
       val all_insts = insts @ concl_insts;

       fun closed NONE = true
         | closed (SOME v) = Parse_Tools.is_real_val v;

       fun real_vals xs = map (Option.map Parse_Tools.the_real_val) xs;
       fun parse_parts p = (Parse_Tools.the_parse_val p, Parse_Tools.the_parse_fun p);
     in
       (case (forall closed all_insts, checked) of
         (true, true) =>
           (fn _ =>
             Term_Insts
               (map (unembed_indexname o Parse_Tools.the_real_val) (map_filter I all_insts)))
       | (true, false) =>
           (fn _ => Unchecked_Term_Insts (real_vals insts, real_vals concl_insts))
       | (false, true) =>
           (fn thm =>
             let
               val positional = apply2 (map (Option.map parse_parts)) (insts, concl_insts);
               val named =
                 of_rule thm positional
                 |> map (fn (xi, (nm, f)) => embed_indexname ((xi, nm), f));
             in Named_Insts (named, fixes) end)
       | (false, false) =>
           let val result = close_unchecked_insts context ((insts, concl_insts), fixes)
           in fn _ => result end)
     end;
   162 
   163 
   (*Instantiate thm according to a rule_inst.

     Named_Insts: the sources are read by Rule_Insts.where_rule on a copy of
     thm that carries an encoding of its schematic variables (add_thm_insts).
     The instance of each variable is decoded again afterwards (get_thm_insts)
     and passed to the callback stored with the corresponding instantiation.
     It is an error if a named variable has no decoded instance.

     Unchecked_Term_Insts: the positional terms are matched against the
     variables of thm by of_rule and applied with instantiate_xis.

     Term_Insts: applied directly with instantiate_xis.*)
   fun read_instantiate_closed ctxt (Named_Insts (insts, fixes)) thm =
         let
           val named = map (fn ((xi, src), _) => ((xi, Position.none), src)) insts;

           val ((tyvars, tvars), encoded_thm) = add_thm_insts ctxt thm;
           val (thm', (tys, ts)) =
             get_thm_insts (Rule_Insts.where_rule ctxt named fixes encoded_thm);

           (*pair each original variable name with its decoded instance*)
           val tyinst = ListPair.map (fn ((xi, _), T) => (xi, T)) (tyvars, tys);
           val tinst = ListPair.map (fn ((xi, _), t) => (xi, t)) (tvars, ts);

           (*type variables are looked up first; a type is handed over as a term*)
           fun notify ((xi, _), f) =
             (case AList.lookup (op =) tyinst xi of
               SOME T => f (Logic.mk_type T)
             | NONE =>
                 (case AList.lookup (op =) tinst xi of
                   SOME t => f t
                 | NONE => error "Lost indexname in instantiated theorem"));

           (*the callbacks are run for their effect only: app, not map*)
           val _ = List.app notify insts;
         in
           restore_tags thm thm'
         end
     | read_instantiate_closed ctxt (Unchecked_Term_Insts insts) thm =
         let
           val (xis, ts) = ListPair.unzip (of_rule thm insts);
           (*import/export round trip in a context that knows the maxidx of thm;
             presumably this keeps schematic variables of the given terms apart
             from those of thm -- TODO confirm*)
           val ctxt' = Variable.declare_maxidx (Thm.maxidx_of thm) ctxt;
           val (ts', ctxt'') = Variable.import_terms false ts ctxt';
           val ts'' = Variable.export_terms ctxt'' ctxt ts';
           val insts' = ListPair.zip (xis, ts'');
         in instantiate_xis ctxt insts' thm end
     | read_instantiate_closed ctxt (Term_Insts insts) thm =
         instantiate_xis ctxt insts thm;
   199 
   (*Register the Eisbach-aware "where" attribute: a non-empty "and"-separated
     list of "var = term" instantiations followed by the fixes.  The arguments
     are converted once by read_where_insts; the instantiation itself happens
     when the attribute is applied to a theorem.*)
   val _ =
     let
       val where_inst = Args.var -- (Args.$$$ "=" |-- Parse_Tools.name_term);
       val where_args = Scan.lift (Parse.and_list1 where_inst -- Parse.for_fixes);

       fun where_attribute args =
         let val insts = read_where_insts args in
           fn (context, thm) =>
             (NONE, SOME (read_instantiate_closed (Context.proof_of context) insts thm))
         end;
     in
       Theory.setup
         (Attrib.setup @{binding "where"} (where_args >> where_attribute)
           "named instantiation of theorem")
     end;
   211 
   (*Register the Eisbach-aware "of" attribute: the "unchecked" mode flag,
     positional instantiations, an optional "concl:" section and the fixes.
     In unchecked mode a free dummy theorem is mapped to Drule.free_dummy_thm
     without instantiating anything.*)
   val _ =
     let
       val positional_insts =
         Scan.repeat (Scan.unless concl inst) --
         Scan.optional (concl |-- Scan.repeat inst) [] --
         Parse.for_fixes;
       val of_args = Scan.lift (Args.mode "unchecked" -- positional_insts) -- Scan.state;

       (*parse_context is the context at parse time; the attribute itself
         receives the context of its application*)
       fun of_attribute ((unchecked, args), parse_context) =
         let
           val read_insts = read_of_insts (not unchecked) parse_context args;

           fun instantiate (context, thm) =
             if Thm.is_free_dummy thm andalso unchecked
             then Drule.free_dummy_thm
             else read_instantiate_closed (Context.proof_of context) (read_insts thm) thm;
         in
           fn arg => (NONE, SOME (instantiate arg))
         end;
     in
       Theory.setup
         (Attrib.setup @{binding "of"} (of_args >> of_attribute)
           "positional instantiation of theorem")
     end;
   232 
   233 end;