localized and modernized adhoc-overloading (patch by Christian Sternagel);
authorwenzelm
Fri Jul 12 16:19:05 2013 +0200 (2013-07-12)
changeset 52622e0ff1625e96d
parent 52621 0d0c20a0a34f
child 52623 fee0db8cf60d
localized and modernized adhoc-overloading (patch by Christian Sternagel);
src/HOL/Imperative_HOL/Heap_Monad.thy
src/HOL/Library/Monad_Syntax.thy
src/Tools/Adhoc_Overloading.thy
src/Tools/adhoc_overloading.ML
     1.1 --- a/src/HOL/Imperative_HOL/Heap_Monad.thy	Fri Jul 12 15:51:25 2013 +0200
     1.2 +++ b/src/HOL/Imperative_HOL/Heap_Monad.thy	Fri Jul 12 16:19:05 2013 +0200
     1.3 @@ -274,10 +274,8 @@
     1.4                    Some (x, h') \<Rightarrow> execute (g x) h'
     1.5                  | None \<Rightarrow> None)"
     1.6  
     1.7 -setup {*
     1.8 -  Adhoc_Overloading.add_variant 
     1.9 -    @{const_name Monad_Syntax.bind} @{const_name Heap_Monad.bind}
    1.10 -*}
    1.11 +adhoc_overloading
    1.12 +  Monad_Syntax.bind Heap_Monad.bind
    1.13  
    1.14  lemma execute_bind [execute_simps]:
    1.15    "execute f h = Some (x, h') \<Longrightarrow> execute (f \<guillemotright>= g) h = execute (g x) h'"
     2.1 --- a/src/HOL/Library/Monad_Syntax.thy	Fri Jul 12 15:51:25 2013 +0200
     2.2 +++ b/src/HOL/Library/Monad_Syntax.thy	Fri Jul 12 16:19:05 2013 +0200
     2.3 @@ -69,12 +69,7 @@
     2.4    "_do_block (_do_final e)" => "e"
     2.5    "(m >> n)" => "(m >>= (\<lambda>_. n))"
     2.6  
     2.7 -setup {*
     2.8 -  Adhoc_Overloading.add_overloaded @{const_name bind}
     2.9 -  #> Adhoc_Overloading.add_variant @{const_name bind} @{const_name Set.bind}
    2.10 -  #> Adhoc_Overloading.add_variant @{const_name bind} @{const_name Predicate.bind}
    2.11 -  #> Adhoc_Overloading.add_variant @{const_name bind} @{const_name Option.bind}
    2.12 -  #> Adhoc_Overloading.add_variant @{const_name bind} @{const_name List.bind}
    2.13 -*}
    2.14 +adhoc_overloading
    2.15 +  bind Set.bind Predicate.bind Option.bind List.bind
    2.16  
    2.17  end
     3.1 --- a/src/Tools/Adhoc_Overloading.thy	Fri Jul 12 15:51:25 2013 +0200
     3.2 +++ b/src/Tools/Adhoc_Overloading.thy	Fri Jul 12 16:19:05 2013 +0200
     3.3 @@ -6,10 +6,10 @@
     3.4  
     3.5  theory Adhoc_Overloading
     3.6  imports Pure
     3.7 +keywords "adhoc_overloading" :: thy_decl and  "no_adhoc_overloading" :: thy_decl
     3.8  begin
     3.9  
    3.10  ML_file "adhoc_overloading.ML"
    3.11 -setup Adhoc_Overloading.setup
    3.12  
    3.13  end
    3.14  
     4.1 --- a/src/Tools/adhoc_overloading.ML	Fri Jul 12 15:51:25 2013 +0200
     4.2 +++ b/src/Tools/adhoc_overloading.ML	Fri Jul 12 16:19:05 2013 +0200
     4.3 @@ -6,11 +6,14 @@
     4.4  
     4.5  signature ADHOC_OVERLOADING =
     4.6  sig
     4.7 -  val add_overloaded: string -> theory -> theory
     4.8 -  val add_variant: string -> string -> theory -> theory
     4.9 -
    4.10 +  val is_overloaded: Proof.context -> string -> bool
    4.11 +  val generic_add_overloaded: string -> Context.generic -> Context.generic
    4.12 +  val generic_remove_overloaded: string -> Context.generic -> Context.generic
    4.13 +  val generic_add_variant: string -> term -> Context.generic -> Context.generic
    4.14 +  (*If the list of variants is empty at the end of "generic_remove_variant", then
    4.15 +  "generic_remove_overloaded" is called implicitly.*)
    4.16 +  val generic_remove_variant: string -> term -> Context.generic -> Context.generic
    4.17    val show_variants: bool Config.T
    4.18 -  val setup: theory -> theory
    4.19  end
    4.20  
    4.21  structure Adhoc_Overloading: ADHOC_OVERLOADING =
    4.22 @@ -18,122 +21,208 @@
    4.23  
    4.24  val show_variants = Attrib.setup_config_bool @{binding show_variants} (K false);
    4.25  
    4.26 -
    4.27  (* errors *)
    4.28  
    4.29 -fun duplicate_variant_err int_name ext_name =
    4.30 -  error ("Constant " ^ quote int_name ^ " is already a variant of " ^ quote ext_name);
    4.31 +fun duplicate_variant_error oconst =
    4.32 +  error ("Duplicate variant of " ^ quote oconst);
    4.33  
    4.34 -fun not_overloaded_err name =
    4.35 -  error ("Constant " ^ quote name ^ " is not declared as overloaded");
    4.36 +fun not_a_variant_error oconst =
    4.37 +  error ("Not a variant of " ^  quote oconst);
    4.38  
    4.39 -fun already_overloaded_err name =
    4.40 -  error ("Constant " ^ quote name ^ " is already declared as overloaded");
    4.41 +fun not_overloaded_error oconst =
    4.42 +  error ("Constant " ^ quote oconst ^ " is not declared as overloaded");
    4.43  
    4.44 -fun unresolved_err ctxt (c, T) t reason =
    4.45 -  error ("Unresolved overloading of  " ^ quote c ^ " :: " ^
    4.46 +fun unresolved_overloading_error ctxt (c, T) t reason =
    4.47 +  error ("Unresolved overloading of " ^ quote c ^ " :: " ^
    4.48      quote (Syntax.string_of_typ ctxt T) ^ " in " ^
    4.49      quote (Syntax.string_of_term ctxt t) ^ " (" ^ reason ^ ")");
    4.50  
    4.51 +(* generic data *)
    4.52  
    4.53 -(* theory data *)
    4.54 +fun variants_eq ((v1, T1), (v2, T2)) =
    4.55 +  Term.aconv_untyped (v1, v2) andalso T1 = T2;
    4.56  
    4.57 -structure Overload_Data = Theory_Data
    4.58 +structure Overload_Data = Generic_Data
    4.59  (
    4.60    type T =
    4.61 -    { internalize : (string * typ) list Symtab.table,
    4.62 -      externalize : string Symtab.table };
    4.63 -  val empty = {internalize=Symtab.empty, externalize=Symtab.empty};
    4.64 +    {variants : (term * typ) list Symtab.table,
    4.65 +     oconsts : string Termtab.table};
    4.66 +  val empty = {variants = Symtab.empty, oconsts = Termtab.empty};
    4.67    val extend = I;
    4.68  
    4.69 -  fun merge_ext int_name (ext_name1, ext_name2) =
    4.70 -    if ext_name1 = ext_name2 then ext_name1
    4.71 -    else duplicate_variant_err int_name ext_name1;
    4.72 -
    4.73    fun merge
    4.74 -    ({internalize = int1, externalize = ext1},
    4.75 -      {internalize = int2, externalize = ext2}) : T =
    4.76 -    {internalize = Symtab.merge_list (op =) (int1, int2),
    4.77 -      externalize = Symtab.join merge_ext (ext1, ext2)};
    4.78 +    ({variants = vtab1, oconsts = otab1},
    4.79 +     {variants = vtab2, oconsts = otab2}) : T =
    4.80 +    let
    4.81 +      fun merge_oconsts _ (oconst1, oconst2) =
    4.82 +        if oconst1 = oconst2 then oconst1
    4.83 +        else duplicate_variant_error oconst1;
    4.84 +    in
    4.85 +      {variants = Symtab.merge_list variants_eq (vtab1, vtab2),
    4.86 +       oconsts = Termtab.join merge_oconsts (otab1, otab2)}
    4.87 +    end;
    4.88  );
    4.89  
    4.90  fun map_tables f g =
    4.91 -  Overload_Data.map (fn {internalize=int, externalize=ext} =>
    4.92 -    {internalize=f int, externalize=g ext});
    4.93 +  Overload_Data.map (fn {variants = vtab, oconsts = otab} =>
    4.94 +    {variants = f vtab, oconsts = g otab});
    4.95 +
    4.96 +val is_overloaded = Symtab.defined o #variants o Overload_Data.get o Context.Proof;
    4.97 +val get_variants = Symtab.lookup o #variants o Overload_Data.get o Context.Proof;
    4.98 +val get_overloaded = Termtab.lookup o #oconsts o Overload_Data.get o Context.Proof;
    4.99 +
   4.100 +fun generic_add_overloaded oconst context =
   4.101 +  if is_overloaded (Context.proof_of context) oconst then context
   4.102 +  else map_tables (Symtab.update (oconst, [])) I context;
   4.103  
   4.104 -val is_overloaded = Symtab.defined o #internalize o Overload_Data.get;
   4.105 -val get_variants = Symtab.lookup o #internalize o Overload_Data.get;
   4.106 -val get_external = Symtab.lookup o #externalize o Overload_Data.get;
   4.107 -
   4.108 -fun add_overloaded ext_name thy =
   4.109 -  let val _ = not (is_overloaded thy ext_name) orelse already_overloaded_err ext_name;
   4.110 -  in map_tables (Symtab.update (ext_name, [])) I thy end;
   4.111 +fun generic_remove_overloaded oconst context =
   4.112 +  let
   4.113 +    fun remove_oconst_and_variants context oconst =
   4.114 +      let
   4.115 +        val remove_variants =
   4.116 +          (case get_variants (Context.proof_of context) oconst of
   4.117 +            NONE => I
   4.118 +          | SOME vs => fold (Termtab.remove (op =) o rpair oconst o fst) vs);
   4.119 +      in map_tables (Symtab.delete_safe oconst) remove_variants context end;
   4.120 +  in
   4.121 +    if is_overloaded (Context.proof_of context) oconst then remove_oconst_and_variants context oconst
   4.122 +    else not_overloaded_error oconst
   4.123 +  end;
   4.124  
   4.125 -fun add_variant ext_name name thy =
   4.126 -  let
   4.127 -    val _ = is_overloaded thy ext_name orelse not_overloaded_err ext_name;
   4.128 -    val _ =
   4.129 -      (case get_external thy name of
   4.130 -        NONE => ()
   4.131 -      | SOME gen' => duplicate_variant_err name gen');
   4.132 -    val T = Sign.the_const_type thy name;
   4.133 -  in
   4.134 -    map_tables (Symtab.cons_list (ext_name, (name, T)))
   4.135 -      (Symtab.update (name, ext_name)) thy
   4.136 -  end
   4.137 -
   4.138 +local
   4.139 +  fun generic_variant add oconst t context =
   4.140 +    let
   4.141 +      val ctxt = Context.proof_of context;
   4.142 +      val _ = if is_overloaded ctxt oconst then () else not_overloaded_error oconst;
   4.143 +      val T = t |> singleton (Variable.polymorphic ctxt) |> fastype_of;
   4.144 +      val t' = Term.map_types (K dummyT) t;
   4.145 +    in
   4.146 +      if add then
   4.147 +        let
   4.148 +          val _ =
   4.149 +            (case get_overloaded ctxt t' of
   4.150 +              NONE => ()
   4.151 +            | SOME oconst' => duplicate_variant_error oconst');
   4.152 +        in
   4.153 +          map_tables (Symtab.cons_list (oconst, (t', T))) (Termtab.update (t', oconst)) context
   4.154 +        end
   4.155 +      else
   4.156 +        let
   4.157 +          val _ =
   4.158 +            if member variants_eq (the (get_variants ctxt oconst)) (t', T) then ()
   4.159 +            else not_a_variant_error oconst;
   4.160 +        in
   4.161 +          map_tables (Symtab.map_entry oconst (remove1 variants_eq (t', T)))
   4.162 +            (Termtab.delete_safe t') context
   4.163 +          |> (fn context =>
   4.164 +            (case get_variants (Context.proof_of context) oconst of
   4.165 +              SOME [] => generic_remove_overloaded oconst context
   4.166 +            | _ => context))
   4.167 +        end
   4.168 +    end;
   4.169 +in
   4.170 +  val generic_add_variant = generic_variant true;
   4.171 +  val generic_remove_variant = generic_variant false;
   4.172 +end;
   4.173  
   4.174  (* check / uncheck *)
   4.175  
   4.176 -fun unifiable_with ctxt T1 (c, T2) =
   4.177 +fun unifiable_with thy T1 (t, T2) =
   4.178    let
   4.179 -    val thy = Proof_Context.theory_of ctxt;
   4.180      val maxidx1 = Term.maxidx_of_typ T1;
   4.181      val T2' = Logic.incr_tvar (maxidx1 + 1) T2;
   4.182 -    val maxidx2 = Int.max (maxidx1, Term.maxidx_of_typ T2');
   4.183 +    val maxidx2 = Term.maxidx_typ T2' maxidx1;
   4.184    in
   4.185 -    (Sign.typ_unify thy (T1, T2') (Vartab.empty, maxidx2); SOME c)
   4.186 +    (Sign.typ_unify thy (T1, T2') (Vartab.empty, maxidx2); SOME t)
   4.187      handle Type.TUNIFY => NONE
   4.188    end;
   4.189  
   4.190 -fun insert_internal_same ctxt t (Const (c, T)) =
   4.191 -      (case map_filter (unifiable_with ctxt T)
   4.192 -         (Same.function (get_variants (Proof_Context.theory_of ctxt)) c) of
   4.193 -        [] => unresolved_err ctxt (c, T) t "no instances"
   4.194 -      | [c'] => Const (c', dummyT)
   4.195 +fun insert_variants_same ctxt t (Const (c, T)) =
   4.196 +      (case map_filter (unifiable_with (Proof_Context.theory_of ctxt) T)
   4.197 +         (Same.function (get_variants ctxt) c) of
   4.198 +        [] => unresolved_overloading_error ctxt (c, T) t "no instances"
   4.199 +      | [variant] => variant
   4.200        | _ => raise Same.SAME)
   4.201 -  | insert_internal_same _ _ _ = raise Same.SAME;
   4.202 +  | insert_variants_same _ _ _ = raise Same.SAME;
   4.203  
   4.204 -fun insert_external_same ctxt _ (Const (c, T)) =
   4.205 -      Const (Same.function (get_external (Proof_Context.theory_of ctxt)) c, T)
   4.206 -  | insert_external_same _ _ _ = raise Same.SAME;
   4.207 +fun insert_overloaded_same ctxt variant =
   4.208 +  let
   4.209 +    val thy = Proof_Context.theory_of ctxt;
   4.210 +    val t = Pattern.rewrite_term thy [] [fn t =>
   4.211 +      Term.map_types (K dummyT) t
   4.212 +      |> get_overloaded ctxt
   4.213 +      |> Option.map (Const o rpair (fastype_of variant))] variant;
   4.214 +  in
   4.215 +    if Term.aconv_untyped (variant, t) then raise Same.SAME
   4.216 +    else t
   4.217 +  end;
   4.218  
   4.219  fun gen_check_uncheck replace ts ctxt =
   4.220 -  Same.capture (Same.map (fn t => Term_Subst.map_aterms_same (replace ctxt t) t)) ts
   4.221 +  Same.capture (Same.map replace) ts
   4.222    |> Option.map (rpair ctxt);
   4.223  
   4.224 -val check = gen_check_uncheck insert_internal_same;
   4.225 +fun check ts ctxt = gen_check_uncheck (fn t =>
   4.226 +  Term_Subst.map_aterms_same (insert_variants_same ctxt t) t) ts ctxt;
   4.227  
   4.228  fun uncheck ts ctxt =
   4.229    if Config.get ctxt show_variants then NONE
   4.230 -  else gen_check_uncheck insert_external_same ts ctxt;
   4.231 +  else gen_check_uncheck (insert_overloaded_same ctxt) ts ctxt;
   4.232  
   4.233  fun reject_unresolved ts ctxt =
   4.234    let
   4.235 -    val thy = Proof_Context.theory_of ctxt;
   4.236      fun check_unresolved t =
   4.237 -      (case filter (is_overloaded thy o fst) (Term.add_consts t []) of
   4.238 +      (case filter (is_overloaded ctxt o fst) (Term.add_consts t []) of
   4.239          [] => ()
   4.240 -      | ((c, T) :: _) => unresolved_err ctxt (c, T) t "multiple instances");
   4.241 +      | ((c, T) :: _) => unresolved_overloading_error ctxt (c, T) t "multiple instances");
   4.242      val _ = map check_unresolved ts;
   4.243    in NONE end;
   4.244  
   4.245 -
   4.246  (* setup *)
   4.247  
   4.248 -val setup = Context.theory_map
   4.249 +val _ = Context.>>
   4.250    (Syntax_Phases.term_check' 0 "adhoc_overloading" check
   4.251     #> Syntax_Phases.term_check' 1 "adhoc_overloading_unresolved_check" reject_unresolved
   4.252     #> Syntax_Phases.term_uncheck' 0 "adhoc_overloading" uncheck);
   4.253  
   4.254 +(* commands *)
   4.255 +
   4.256 +fun generic_adhoc_overloading_cmd add =
   4.257 +  if add then
   4.258 +    fold (fn (oconst, ts) =>
   4.259 +      generic_add_overloaded oconst
   4.260 +      #> fold (generic_add_variant oconst) ts)
   4.261 +  else
   4.262 +    fold (fn (oconst, ts) =>
   4.263 +      fold (generic_remove_variant oconst) ts);
   4.264 +
   4.265 +fun adhoc_overloading_cmd' add args phi =
   4.266 +  let val args' = args
   4.267 +    |> map (apsnd (map_filter (fn t =>
   4.268 +         let val t' = Morphism.term phi t;
   4.269 +         in if Term.aconv_untyped (t, t') then SOME t' else NONE end)));
   4.270 +  in generic_adhoc_overloading_cmd add args' end;
   4.271 +
   4.272 +fun adhoc_overloading_cmd add raw_args lthy =
   4.273 +  let
   4.274 +    fun const_name ctxt = fst o dest_Const o Proof_Context.read_const ctxt false dummyT;
   4.275 +    val args =
   4.276 +      raw_args
   4.277 +      |> map (apfst (const_name lthy))
   4.278 +      |> map (apsnd (map (Syntax.read_term lthy)));
   4.279 +  in
   4.280 +    Local_Theory.declaration {syntax = true, pervasive = false}
   4.281 +      (adhoc_overloading_cmd' add args) lthy
   4.282 +  end;
   4.283 +
   4.284 +val _ =
   4.285 +  Outer_Syntax.local_theory @{command_spec "adhoc_overloading"}
   4.286 +    "add ad-hoc overloading for constants / fixed variables"
   4.287 +    (Parse.and_list1 (Parse.const -- Scan.repeat Parse.term) >> adhoc_overloading_cmd true);
   4.288 +
   4.289 +val _ =
   4.290 +  Outer_Syntax.local_theory @{command_spec "no_adhoc_overloading"}
   4.291 +    "add ad-hoc overloading for constants / fixed variables"
   4.292 +    (Parse.and_list1 (Parse.const -- Scan.repeat Parse.term) >> adhoc_overloading_cmd false);
   4.293 +
   4.294  end;
   4.295 +