src/Tools/adhoc_overloading.ML
changeset 52622 e0ff1625e96d
parent 50768 2172f82de515
child 52687 72cda5eb5a39
equal deleted inserted replaced
52621:0d0c20a0a34f 52622:e0ff1625e96d
     4 Ad-hoc overloading of constants based on their types.
     4 Ad-hoc overloading of constants based on their types.
     5 *)
     5 *)
     6 
     6 
     7 signature ADHOC_OVERLOADING =
     7 signature ADHOC_OVERLOADING =
     8 sig
     8 sig
     9   val add_overloaded: string -> theory -> theory
     9   val is_overloaded: Proof.context -> string -> bool
    10   val add_variant: string -> string -> theory -> theory
    10   val generic_add_overloaded: string -> Context.generic -> Context.generic
    11 
    11   val generic_remove_overloaded: string -> Context.generic -> Context.generic
       
    12   val generic_add_variant: string -> term -> Context.generic -> Context.generic
       
    13   (*If the list of variants is empty at the end of "generic_remove_variant", then
       
    14   "generic_remove_overloaded" is called implicitly.*)
       
    15   val generic_remove_variant: string -> term -> Context.generic -> Context.generic
    12   val show_variants: bool Config.T
    16   val show_variants: bool Config.T
    13   val setup: theory -> theory
       
    14 end
    17 end
    15 
    18 
    16 structure Adhoc_Overloading: ADHOC_OVERLOADING =
    19 structure Adhoc_Overloading: ADHOC_OVERLOADING =
    17 struct
    20 struct
    18 
    21 
    19 val show_variants = Attrib.setup_config_bool @{binding show_variants} (K false);
    22 val show_variants = Attrib.setup_config_bool @{binding show_variants} (K false);
    20 
    23 
    21 
       
    22 (* errors *)
    24 (* errors *)
    23 
    25 
    24 fun duplicate_variant_err int_name ext_name =
    26 fun duplicate_variant_error oconst =
    25   error ("Constant " ^ quote int_name ^ " is already a variant of " ^ quote ext_name);
    27   error ("Duplicate variant of " ^ quote oconst);
    26 
    28 
    27 fun not_overloaded_err name =
    29 fun not_a_variant_error oconst =
    28   error ("Constant " ^ quote name ^ " is not declared as overloaded");
    30   error ("Not a variant of " ^  quote oconst);
    29 
    31 
    30 fun already_overloaded_err name =
    32 fun not_overloaded_error oconst =
    31   error ("Constant " ^ quote name ^ " is already declared as overloaded");
    33   error ("Constant " ^ quote oconst ^ " is not declared as overloaded");
    32 
    34 
    33 fun unresolved_err ctxt (c, T) t reason =
    35 fun unresolved_overloading_error ctxt (c, T) t reason =
    34   error ("Unresolved overloading of  " ^ quote c ^ " :: " ^
    36   error ("Unresolved overloading of " ^ quote c ^ " :: " ^
    35     quote (Syntax.string_of_typ ctxt T) ^ " in " ^
    37     quote (Syntax.string_of_typ ctxt T) ^ " in " ^
    36     quote (Syntax.string_of_term ctxt t) ^ " (" ^ reason ^ ")");
    38     quote (Syntax.string_of_term ctxt t) ^ " (" ^ reason ^ ")");
    37 
    39 
    38 
    40 (* generic data *)
    39 (* theory data *)
    41 
    40 
    42 fun variants_eq ((v1, T1), (v2, T2)) =
    41 structure Overload_Data = Theory_Data
    43   Term.aconv_untyped (v1, v2) andalso T1 = T2;
       
    44 
       
    45 structure Overload_Data = Generic_Data
    42 (
    46 (
    43   type T =
    47   type T =
    44     { internalize : (string * typ) list Symtab.table,
    48     {variants : (term * typ) list Symtab.table,
    45       externalize : string Symtab.table };
    49      oconsts : string Termtab.table};
    46   val empty = {internalize=Symtab.empty, externalize=Symtab.empty};
    50   val empty = {variants = Symtab.empty, oconsts = Termtab.empty};
    47   val extend = I;
    51   val extend = I;
    48 
    52 
    49   fun merge_ext int_name (ext_name1, ext_name2) =
       
    50     if ext_name1 = ext_name2 then ext_name1
       
    51     else duplicate_variant_err int_name ext_name1;
       
    52 
       
    53   fun merge
    53   fun merge
    54     ({internalize = int1, externalize = ext1},
    54     ({variants = vtab1, oconsts = otab1},
    55       {internalize = int2, externalize = ext2}) : T =
    55      {variants = vtab2, oconsts = otab2}) : T =
    56     {internalize = Symtab.merge_list (op =) (int1, int2),
    56     let
    57       externalize = Symtab.join merge_ext (ext1, ext2)};
    57       fun merge_oconsts _ (oconst1, oconst2) =
       
    58         if oconst1 = oconst2 then oconst1
       
    59         else duplicate_variant_error oconst1;
       
    60     in
       
    61       {variants = Symtab.merge_list variants_eq (vtab1, vtab2),
       
    62        oconsts = Termtab.join merge_oconsts (otab1, otab2)}
       
    63     end;
    58 );
    64 );
    59 
    65 
    60 fun map_tables f g =
    66 fun map_tables f g =
    61   Overload_Data.map (fn {internalize=int, externalize=ext} =>
    67   Overload_Data.map (fn {variants = vtab, oconsts = otab} =>
    62     {internalize=f int, externalize=g ext});
    68     {variants = f vtab, oconsts = g otab});
    63 
    69 
    64 val is_overloaded = Symtab.defined o #internalize o Overload_Data.get;
    70 val is_overloaded = Symtab.defined o #variants o Overload_Data.get o Context.Proof;
    65 val get_variants = Symtab.lookup o #internalize o Overload_Data.get;
    71 val get_variants = Symtab.lookup o #variants o Overload_Data.get o Context.Proof;
    66 val get_external = Symtab.lookup o #externalize o Overload_Data.get;
    72 val get_overloaded = Termtab.lookup o #oconsts o Overload_Data.get o Context.Proof;
    67 
    73 
    68 fun add_overloaded ext_name thy =
    74 fun generic_add_overloaded oconst context =
    69   let val _ = not (is_overloaded thy ext_name) orelse already_overloaded_err ext_name;
    75   if is_overloaded (Context.proof_of context) oconst then context
    70   in map_tables (Symtab.update (ext_name, [])) I thy end;
    76   else map_tables (Symtab.update (oconst, [])) I context;
    71 
    77 
    72 fun add_variant ext_name name thy =
    78 fun generic_remove_overloaded oconst context =
    73   let
    79   let
    74     val _ = is_overloaded thy ext_name orelse not_overloaded_err ext_name;
    80     fun remove_oconst_and_variants context oconst =
    75     val _ =
    81       let
    76       (case get_external thy name of
    82         val remove_variants =
    77         NONE => ()
    83           (case get_variants (Context.proof_of context) oconst of
    78       | SOME gen' => duplicate_variant_err name gen');
    84             NONE => I
    79     val T = Sign.the_const_type thy name;
    85           | SOME vs => fold (Termtab.remove (op =) o rpair oconst o fst) vs);
    80   in
    86       in map_tables (Symtab.delete_safe oconst) remove_variants context end;
    81     map_tables (Symtab.cons_list (ext_name, (name, T)))
    87   in
    82       (Symtab.update (name, ext_name)) thy
    88     if is_overloaded (Context.proof_of context) oconst then remove_oconst_and_variants context oconst
    83   end
    89     else not_overloaded_error oconst
    84 
    90   end;
       
    91 
       
    92 local
       
    93   fun generic_variant add oconst t context =
       
    94     let
       
    95       val ctxt = Context.proof_of context;
       
    96       val _ = if is_overloaded ctxt oconst then () else not_overloaded_error oconst;
       
    97       val T = t |> singleton (Variable.polymorphic ctxt) |> fastype_of;
       
    98       val t' = Term.map_types (K dummyT) t;
       
    99     in
       
   100       if add then
       
   101         let
       
   102           val _ =
       
   103             (case get_overloaded ctxt t' of
       
   104               NONE => ()
       
   105             | SOME oconst' => duplicate_variant_error oconst');
       
   106         in
       
   107           map_tables (Symtab.cons_list (oconst, (t', T))) (Termtab.update (t', oconst)) context
       
   108         end
       
   109       else
       
   110         let
       
   111           val _ =
       
   112             if member variants_eq (the (get_variants ctxt oconst)) (t', T) then ()
       
   113             else not_a_variant_error oconst;
       
   114         in
       
   115           map_tables (Symtab.map_entry oconst (remove1 variants_eq (t', T)))
       
   116             (Termtab.delete_safe t') context
       
   117           |> (fn context =>
       
   118             (case get_variants (Context.proof_of context) oconst of
       
   119               SOME [] => generic_remove_overloaded oconst context
       
   120             | _ => context))
       
   121         end
       
   122     end;
       
   123 in
       
   124   val generic_add_variant = generic_variant true;
       
   125   val generic_remove_variant = generic_variant false;
       
   126 end;
    85 
   127 
    86 (* check / uncheck *)
   128 (* check / uncheck *)
    87 
   129 
    88 fun unifiable_with ctxt T1 (c, T2) =
   130 fun unifiable_with thy T1 (t, T2) =
    89   let
   131   let
    90     val thy = Proof_Context.theory_of ctxt;
       
    91     val maxidx1 = Term.maxidx_of_typ T1;
   132     val maxidx1 = Term.maxidx_of_typ T1;
    92     val T2' = Logic.incr_tvar (maxidx1 + 1) T2;
   133     val T2' = Logic.incr_tvar (maxidx1 + 1) T2;
    93     val maxidx2 = Int.max (maxidx1, Term.maxidx_of_typ T2');
   134     val maxidx2 = Term.maxidx_typ T2' maxidx1;
    94   in
   135   in
    95     (Sign.typ_unify thy (T1, T2') (Vartab.empty, maxidx2); SOME c)
   136     (Sign.typ_unify thy (T1, T2') (Vartab.empty, maxidx2); SOME t)
    96     handle Type.TUNIFY => NONE
   137     handle Type.TUNIFY => NONE
    97   end;
   138   end;
    98 
   139 
    99 fun insert_internal_same ctxt t (Const (c, T)) =
   140 fun insert_variants_same ctxt t (Const (c, T)) =
   100       (case map_filter (unifiable_with ctxt T)
   141       (case map_filter (unifiable_with (Proof_Context.theory_of ctxt) T)
   101          (Same.function (get_variants (Proof_Context.theory_of ctxt)) c) of
   142          (Same.function (get_variants ctxt) c) of
   102         [] => unresolved_err ctxt (c, T) t "no instances"
   143         [] => unresolved_overloading_error ctxt (c, T) t "no instances"
   103       | [c'] => Const (c', dummyT)
   144       | [variant] => variant
   104       | _ => raise Same.SAME)
   145       | _ => raise Same.SAME)
   105   | insert_internal_same _ _ _ = raise Same.SAME;
   146   | insert_variants_same _ _ _ = raise Same.SAME;
   106 
   147 
   107 fun insert_external_same ctxt _ (Const (c, T)) =
   148 fun insert_overloaded_same ctxt variant =
   108       Const (Same.function (get_external (Proof_Context.theory_of ctxt)) c, T)
   149   let
   109   | insert_external_same _ _ _ = raise Same.SAME;
   150     val thy = Proof_Context.theory_of ctxt;
       
   151     val t = Pattern.rewrite_term thy [] [fn t =>
       
   152       Term.map_types (K dummyT) t
       
   153       |> get_overloaded ctxt
       
   154       |> Option.map (Const o rpair (fastype_of variant))] variant;
       
   155   in
       
   156     if Term.aconv_untyped (variant, t) then raise Same.SAME
       
   157     else t
       
   158   end;
   110 
   159 
   111 fun gen_check_uncheck replace ts ctxt =
   160 fun gen_check_uncheck replace ts ctxt =
   112   Same.capture (Same.map (fn t => Term_Subst.map_aterms_same (replace ctxt t) t)) ts
   161   Same.capture (Same.map replace) ts
   113   |> Option.map (rpair ctxt);
   162   |> Option.map (rpair ctxt);
   114 
   163 
   115 val check = gen_check_uncheck insert_internal_same;
   164 fun check ts ctxt = gen_check_uncheck (fn t =>
       
   165   Term_Subst.map_aterms_same (insert_variants_same ctxt t) t) ts ctxt;
   116 
   166 
   117 fun uncheck ts ctxt =
   167 fun uncheck ts ctxt =
   118   if Config.get ctxt show_variants then NONE
   168   if Config.get ctxt show_variants then NONE
   119   else gen_check_uncheck insert_external_same ts ctxt;
   169   else gen_check_uncheck (insert_overloaded_same ctxt) ts ctxt;
   120 
   170 
   121 fun reject_unresolved ts ctxt =
   171 fun reject_unresolved ts ctxt =
   122   let
   172   let
   123     val thy = Proof_Context.theory_of ctxt;
       
   124     fun check_unresolved t =
   173     fun check_unresolved t =
   125       (case filter (is_overloaded thy o fst) (Term.add_consts t []) of
   174       (case filter (is_overloaded ctxt o fst) (Term.add_consts t []) of
   126         [] => ()
   175         [] => ()
   127       | ((c, T) :: _) => unresolved_err ctxt (c, T) t "multiple instances");
   176       | ((c, T) :: _) => unresolved_overloading_error ctxt (c, T) t "multiple instances");
   128     val _ = map check_unresolved ts;
   177     val _ = map check_unresolved ts;
   129   in NONE end;
   178   in NONE end;
   130 
   179 
   131 
       
   132 (* setup *)
   180 (* setup *)
   133 
   181 
   134 val setup = Context.theory_map
   182 val _ = Context.>>
   135   (Syntax_Phases.term_check' 0 "adhoc_overloading" check
   183   (Syntax_Phases.term_check' 0 "adhoc_overloading" check
   136    #> Syntax_Phases.term_check' 1 "adhoc_overloading_unresolved_check" reject_unresolved
   184    #> Syntax_Phases.term_check' 1 "adhoc_overloading_unresolved_check" reject_unresolved
   137    #> Syntax_Phases.term_uncheck' 0 "adhoc_overloading" uncheck);
   185    #> Syntax_Phases.term_uncheck' 0 "adhoc_overloading" uncheck);
   138 
   186 
       
   187 (* commands *)
       
   188 
       
   189 fun generic_adhoc_overloading_cmd add =
       
   190   if add then
       
   191     fold (fn (oconst, ts) =>
       
   192       generic_add_overloaded oconst
       
   193       #> fold (generic_add_variant oconst) ts)
       
   194   else
       
   195     fold (fn (oconst, ts) =>
       
   196       fold (generic_remove_variant oconst) ts);
       
   197 
       
   198 fun adhoc_overloading_cmd' add args phi =
       
   199   let val args' = args
       
   200     |> map (apsnd (map_filter (fn t =>
       
   201          let val t' = Morphism.term phi t;
       
   202          in if Term.aconv_untyped (t, t') then SOME t' else NONE end)));
       
   203   in generic_adhoc_overloading_cmd add args' end;
       
   204 
       
   205 fun adhoc_overloading_cmd add raw_args lthy =
       
   206   let
       
   207     fun const_name ctxt = fst o dest_Const o Proof_Context.read_const ctxt false dummyT;
       
   208     val args =
       
   209       raw_args
       
   210       |> map (apfst (const_name lthy))
       
   211       |> map (apsnd (map (Syntax.read_term lthy)));
       
   212   in
       
   213     Local_Theory.declaration {syntax = true, pervasive = false}
       
   214       (adhoc_overloading_cmd' add args) lthy
       
   215   end;
       
   216 
       
   217 val _ =
       
   218   Outer_Syntax.local_theory @{command_spec "adhoc_overloading"}
       
   219     "add ad-hoc overloading for constants / fixed variables"
       
   220     (Parse.and_list1 (Parse.const -- Scan.repeat Parse.term) >> adhoc_overloading_cmd true);
       
   221 
       
   222 val _ =
       
   223   Outer_Syntax.local_theory @{command_spec "no_adhoc_overloading"}
       
   224     "add ad-hoc overloading for constants / fixed variables"
       
   225     (Parse.and_list1 (Parse.const -- Scan.repeat Parse.term) >> adhoc_overloading_cmd false);
       
   226 
   139 end;
   227 end;
       
   228