beginning support for type instantiation;
authorwenzelm
Fri, 14 Dec 2001 11:55:34 +0100
changeset 12502 9e7f72e25022
parent 12501 36b2ac65e18d
child 12503 52994bfef01b
beginning support for type instantiation; tuned internal arrangements;
src/Pure/Isar/locale.ML
--- a/src/Pure/Isar/locale.ML	Fri Dec 14 11:54:47 2001 +0100
+++ b/src/Pure/Isar/locale.ML	Fri Dec 14 11:55:34 2001 +0100
@@ -29,12 +29,13 @@
   type locale
   val intern: Sign.sg -> xstring -> string
   val cond_extern: Sign.sg -> string -> xstring
+  val the_locale: theory -> string -> locale
   val attribute: ('att -> context attribute) -> ('typ, 'term, 'thm, 'att) elem_expr
     -> ('typ, 'term, 'thm, context attribute) elem_expr
-  val activate_elements: context attribute element list -> context -> context
-  val activate_elements_i: context attribute element_i list -> context -> context
-  val activate_locale: xstring -> context -> context
-  val activate_locale_i: string -> context -> context
+  val activate_context: context -> expr * context attribute element list ->
+   (context -> context) * (context -> context)
+  val activate_context_i: context -> expr * context attribute element_i list ->
+   (context -> context) * (context -> context)
   val add_locale: bstring -> expr -> context attribute element list -> theory -> theory
   val add_locale_i: bstring -> expr -> context attribute element_i list -> theory -> theory
   val print_locales: theory -> unit
@@ -47,6 +48,7 @@
 struct
 
 
+
 (** locale elements and expressions **)
 
 type context = ProofContext.context;
@@ -73,7 +75,7 @@
 type locale =
  {import: expr,                                                         (*dynamic import*)
   elems: ((typ, term, thm list, context attribute) elem * stamp) list,  (*static content*)
-  params: (string * typ option) list * (string * typ option) list,      (*all vs. local params*)
+  params: (string * typ option) list * string list,                     (*all vs. local params*)
   text: (string * typ) list * term list}                                (*logical representation*)
 
 fun make_locale import elems params text =
@@ -130,11 +132,11 @@
   let
     fun prt_id (name, parms) = [Pretty.block (Pretty.breaks (map Pretty.str (name :: parms)))];
     val prt_ids = flat (separate [Pretty.str " +", Pretty.brk 1] (map prt_id ids));
-  in
-    if null ids then raise ProofContext.CONTEXT (msg, ctxt)
-    else raise ProofContext.CONTEXT (msg ^ "\n" ^ Pretty.string_of (Pretty.block
-      (Pretty.str "The error(s) above occurred in locale:" :: Pretty.brk 1 :: prt_ids)), ctxt)
-  end;
+    val err_msg =
+      if null ids then msg
+      else msg ^ "\n" ^ Pretty.string_of (Pretty.block
+        (Pretty.str "The error(s) above occurred in locale:" :: Pretty.brk 1 :: prt_ids));
+  in raise ProofContext.CONTEXT (err_msg, ctxt) end;
 
 
 
@@ -162,13 +164,11 @@
   | Notes facts =>
       Notes (map (fn (a, bs) => (prep_name ctxt a, map (apfst (prep_thms ctxt)) bs)) facts);
 
+in
+
 fun read_elem x = prep_elem ProofContext.read_vars ProofContext.read_propp ProofContext.get_thms x;
 fun cert_elem x = prep_elem ProofContext.cert_vars ProofContext.cert_propp (K I) x;
 
-fun read_att attrib (x, srcs) = (x, map attrib srcs);
-
-in
-
 fun read_expr ctxt (Locale xname) = Locale (intern (ProofContext.sign_of ctxt) xname)
   | read_expr ctxt (Merge exprs) = Merge (map (read_expr ctxt) exprs)
   | read_expr ctxt (Rename (expr, xs)) = Rename (read_expr ctxt expr, xs);
@@ -179,6 +179,13 @@
 fun cert_element ctxt (Elem e) = Elem (cert_elem ctxt e)
   | cert_element ctxt (Expr e) = Expr e;
 
+end;
+
+
+(* internalize attributes *)
+
+local fun read_att attrib (x, srcs) = (x, map attrib srcs) in
+
 fun attribute _ (Elem (Fixes fixes)) = Elem (Fixes fixes)
   | attribute attrib (Elem (Assumes asms)) = Elem (Assumes (map (apfst (read_att attrib)) asms))
   | attribute attrib (Elem (Defines defs)) = Elem (Defines (map (apfst (read_att attrib)) defs))
@@ -202,7 +209,7 @@
   let
     val {sign, hyps, prop, maxidx, ...} = Thm.rep_thm th;
     val cert = Thm.cterm_of sign;
-    val (xs, Ts) = Library.split_list (foldl Drule.add_frees ([], prop :: hyps));
+    val (xs, Ts) = Library.split_list (foldl Term.add_frees ([], prop :: hyps));
     val xs' = map (rename ren) xs;
     fun cert_frees names = map (cert o Free) (names ~~ Ts);
     fun cert_vars names = map (cert o Var o apfst (rpair (maxidx + 1))) (names ~~ Ts);
@@ -239,8 +246,99 @@
   end;
 
 
+(* type instantiation *)
+
+fun inst_type [] T = T
+  | inst_type env T = Term.map_type_tfree (fn v => if_none (assoc (env, v)) (TFree v)) T;
+
+fun inst_term [] t = t
+  | inst_term env t = Term.map_term_types (inst_type env) t;
+
+fun inst_thm [] th = th
+  | inst_thm env th =
+      let
+        val {sign, hyps, prop, maxidx, ...} = Thm.rep_thm th;
+        val cert = Thm.cterm_of sign and certT = Thm.ctyp_of sign;
+        val names = foldr Term.add_term_tfree_names (prop :: hyps, []);
+        val env' = filter (fn ((a, _), _) => a mem_string names) env;
+      in
+        if null env' then th
+        else
+          th
+          |> Drule.implies_intr_list (map cert hyps)
+          |> Drule.tvars_intr_list names
+          |> (fn (th', al) => th' |>
+            Thm.instantiate ((map (fn ((a, _), T) => (the (assoc (al, a)), certT T)) env'), []))
+          |> (fn th'' => Drule.implies_elim_list th''
+              (map (Thm.assume o cert o inst_term env') hyps))
+      end;
+
+fun inst_elem env (Fixes fixes) =
+      Fixes (map (fn (x, T, mx) => (x, apsome (inst_type env) T, mx)) fixes)
+  | inst_elem env (Assumes asms) = Assumes (map (apsnd (map (fn (t, (ps, qs)) =>
+      (inst_term env t, (map (inst_term env) ps, map (inst_term env) qs))))) asms)
+  | inst_elem env (Defines defs) = Defines (map (apsnd (fn (t, ps) =>
+      (inst_term env t, map (inst_term env) ps))) defs)
+  | inst_elem env (Notes facts) = Notes (map (apsnd (map (apfst (map (inst_thm env))))) facts);
+
+
 (* evaluation *)
 
+fun frozen_tvars ctxt Ts =
+  let
+    val tvars = rev (foldl Term.add_tvarsT ([], Ts));
+    val tfrees = map TFree
+      (Term.invent_type_names (ProofContext.used_types ctxt) (length tvars) ~~ map #2 tvars);
+  in map #1 tvars ~~ tfrees end;
+
+fun unify_parms ctxt raw_parmss =
+  let
+    val tsig = Sign.tsig_of (ProofContext.sign_of ctxt);
+    val maxidx = length raw_parmss;
+    val idx_parmss = (0 upto maxidx - 1) ~~ raw_parmss;
+
+    fun varify i = Term.map_type_tfree (fn (a, S) => TVar ((a, i), S));
+    fun varify_parms (i, ps) =
+      mapfilter (fn (_, None) => None | (x, Some T) => Some (x, varify i T)) ps;
+    val parms = flat (map varify_parms idx_parmss);
+
+    fun unify T ((env, maxidx), U) = Type.unify tsig maxidx env (U, T);  (*should never fail*)
+    fun unify_list (envir, T :: Us) = foldl (unify T) (envir, Us)
+      | unify_list (envir, []) = envir;
+    val (unifier, _) = foldl unify_list
+      ((Vartab.empty, maxidx), map #2 (Symtab.dest (Symtab.make_multi parms)));
+
+    val parms' = map (apsnd (Envir.norm_type unifier)) (gen_distinct eq_fst parms);
+    val unifier' = Vartab.extend (unifier, frozen_tvars ctxt (map #2 parms'));
+
+    fun inst_parms (i, ps) =
+      foldr Term.add_typ_tfrees (mapfilter snd ps, [])
+      |> mapfilter (fn (a, S) =>
+          let val T = Envir.norm_type unifier' (TVar ((a, i), S))
+          in if T = TFree (a, S) then None else Some ((a, S), T) end);
+  in map inst_parms idx_parmss end;
+
+fun unique_parms ctxt elemss =
+  let
+    val param_decls =
+      flat (map (fn ((name, (ps, qs)), _) => map (rpair (name, ps)) qs) elemss)
+      |> Symtab.make_multi |> Symtab.dest;
+  in
+    (case find_first (fn (_, ids) => length ids > 1) param_decls of
+      Some (q, ids) => err_in_locale ctxt ("Multiple declaration of parameter " ^ quote q)
+          (map (apsnd (map fst)) ids)
+    | None => map (apfst (apsnd #1)) elemss)
+  end;
+
+fun inst_types _ [elems] = [elems]
+  | inst_types ctxt elemss =
+      let
+        val envs = unify_parms ctxt (map (#2 o #1) elemss);
+        fun inst (((name, ps), elems), env) =
+          ((name, map (apsnd (apsome (inst_type env))) ps), map (inst_elem env) elems);
+      in map inst (elemss ~~ envs) end;
+
+
 fun eval_expr ctxt expr =
   let
     val thy = ProofContext.theory_of ctxt;
@@ -278,29 +376,42 @@
 
     fun eval (name, xs) =
       let
-        val {params = (ps, _), elems, ...} = the_locale thy name;
+        val {params = (ps, qs), elems, ...} = the_locale thy name;
         val ren = filter_out (op =) (map #1 ps ~~ xs);
-        val (ps', elems') =
-          if null ren then (ps, map #1 elems)
-          else (map (apfst (rename ren)) ps, map (rename_elem ren o #1) elems);
-      in ((name, ps'), map (qualify_elem [NameSpace.base name, space_implode "_" xs]) elems') end;
+        val (params', elems') =
+          if null ren then ((ps, qs), map #1 elems)
+          else ((map (apfst (rename ren)) ps, map (rename ren) qs),
+            map (rename_elem ren o #1) elems);
+        val elems'' = map (qualify_elem [NameSpace.base name, space_implode "_" xs]) elems';
+      in ((name, params'), elems'') end;
 
-    (* FIXME unify types *)
-
-    val (idents, parms) = identify (([], []), expr);
-  in (map eval idents, parms) end;
-
-fun eval_element _ (Elem e) = [(("", []), [e])]
-  | eval_element ctxt (Expr e) = #1 (eval_expr ctxt e);
+    val raw_elemss = unique_parms ctxt (map eval (#1 (identify (([], []), expr))));
+    val elemss = inst_types ctxt raw_elemss;
+  in elemss end;
 
 
 
 (** activation **)
 
-(* internal elems *)
+(* internalize elems *)
+
+fun declare_elem gen =
+  let
+    val gen_typ = if gen then Term.map_type_tfree (Type.param []) else I;
+    val gen_term = if gen then Term.map_term_types gen_typ else I;
+
+    fun declare (Fixes fixes) = ProofContext.add_syntax fixes o
+          ProofContext.fix_direct (map (fn (x, T, _) => ([x], apsome gen_typ T)) fixes)
+      | declare (Assumes asms) = (fn ctxt => #1 (ProofContext.bind_propp_i
+          (ctxt, map (map (fn (t, (ps, ps')) =>
+            (gen_term t, (map gen_term ps, map gen_term ps'))) o #2) asms)))
+      | declare (Defines defs) = (fn ctxt => #1 (ProofContext.bind_propp_i
+          (ctxt, map (fn (_, (t, ps)) => [(gen_term t, (map gen_term ps, []))]) defs)))
+      | declare (Notes _) = I;
+  in declare end;
 
 fun activate_elem (Fixes fixes) = ProofContext.add_syntax fixes o
-      ProofContext.fix_direct (map (fn (x, T, mx) => ([x], T)) fixes)
+      ProofContext.fix_direct (map (fn (x, T, _) => ([x], T)) fixes)
   | activate_elem (Assumes asms) =
       #1 o ProofContext.assume_i ProofContext.export_assume asms o
       ProofContext.fix_frees (flat (map (map #1 o #2) asms))
@@ -310,24 +421,67 @@
         in ((if name = "" then Thm.def_name c else name, atts), [(t', (ps, []))]) end) defs) ctxt))
   | activate_elem (Notes facts) = #1 o ProofContext.have_thmss facts;
 
-fun activate_elems es ctxt = foldl (fn (c, e) => activate_elem e c) (ctxt, es);
 
-fun activate_locale_elems named_elems = ProofContext.qualified (fn context =>
-  foldl (fn (ctxt, ((name, ps), es)) =>    (* FIXME type inst *)
-    activate_elems es ctxt handle ProofContext.CONTEXT (msg, ctxt) =>
+fun perform_elems f named_elems = ProofContext.qualified (fn context =>
+  foldl (fn (ctxt, ((name, ps), es)) =>
+    foldl (fn (c, e) => f e c) (ctxt, es) handle ProofContext.CONTEXT (msg, ctxt) =>
       err_in_locale ctxt msg [(name, map fst ps)]) (context, named_elems));
 
+fun declare_elemss gen = perform_elems (declare_elem gen);
+fun activate_elemss x = perform_elems activate_elem x;
+
 
-(* external elements and locales *)
+(* context specifications: import expression + external elements *)
+
+local
+
+fun close_frees ctxt t =
+  let val frees = rev (filter_out (ProofContext.is_fixed ctxt o #1) (Term.add_frees ([], t)))
+  in Term.list_all_free (frees, t) end;
+
+(*quantify dangling frees, strip term bindings*)
+fun closeup ctxt (Assumes asms) = Assumes (asms |> map (fn (a, propps) =>
+      (a, map (fn (t, _) => (close_frees ctxt t, ([], []))) propps)))
+  | closeup ctxt (Defines defs) = Defines (defs |> map (fn (a, (t, _)) =>
+      (a, (close_frees ctxt (#2 (ProofContext.cert_def ctxt t)), []))))
+  | closeup ctxt elem = elem;
+
+fun prepare_context prep_elem prep_expr close context (import, elements) =
+  let
+    fun prep_element (ctxt, Elem raw_elem) =
+          let val elem = (if close then closeup ctxt else I) (prep_elem ctxt raw_elem)
+          in (ctxt |> declare_elem false elem, [(("", []), [elem])]) end
+      | prep_element (ctxt, Expr raw_expr) =
+          let
+            val expr = prep_expr ctxt raw_expr;
+            val named_elemss = eval_expr ctxt expr;
+          in (ctxt |> declare_elemss true named_elemss, named_elemss) end;
 
-fun gen_activate_elements prep_element raw_elements context =
-  foldl (fn (ctxt, e) => activate_locale_elems (eval_element ctxt (prep_element ctxt e)) ctxt)
-    (context, raw_elements);
+    val (import_ctxt, import_elemss) = prep_element (context, Expr import);
+    val (elements_ctxt, elements_elemss) =
+      apsnd flat (foldl_map prep_element (import_ctxt, elements));
+
+    val xs = flat (map (map #1 o (#2 o #1)) (import_elemss @ elements_elemss));
+    val env = frozen_tvars elements_ctxt (mapfilter (ProofContext.default_type elements_ctxt) xs);
+
+    fun inst_elems ((name, ps), elems) = ((name, ps), elems);  (* FIXME *)
+
+  in (map inst_elems import_elemss, map inst_elems elements_elemss) end;
 
-val activate_elements = gen_activate_elements read_element;
-val activate_elements_i = gen_activate_elements cert_element;
-val activate_locale_i = activate_elements_i o single o Expr o Locale;
-val activate_locale = activate_elements o single o Expr o Locale;
+fun gen_activate_context prep_elem prep_expr ctxt args =
+  pairself activate_elemss (prepare_context prep_elem prep_expr false ctxt args);
+
+in
+
+val read_context = prepare_context read_elem read_expr true;
+val cert_context = prepare_context cert_elem (K I) true;
+val activate_context = gen_activate_context read_elem read_expr;
+val activate_context_i = gen_activate_context cert_elem (K I);
+
+fun activate_locale name ctxt =
+  #1 (activate_context_i ctxt (Locale name, [])) ctxt;
+
+end;
 
 
 
@@ -338,9 +492,8 @@
     val sg = Theory.sign_of thy;
     val thy_ctxt = ProofContext.init thy;
 
-    val expr = read_expr thy_ctxt raw_expr;
-    val elems = #1 (eval_expr thy_ctxt expr);
-    val ctxt = activate_locale_elems elems thy_ctxt;
+    val elemss = #1 (read_context thy_ctxt (raw_expr, []));
+    val ctxt = activate_elemss elemss thy_ctxt;
 
     val prt_typ = Pretty.quote o ProofContext.pretty_typ ctxt;
     val prt_term = Pretty.quote o ProofContext.pretty_term ctxt;
@@ -367,7 +520,7 @@
       | prt_elem (Defines defs) = items "defines" (map prt_def defs)
       | prt_elem (Notes facts) = items "notes" (map prt_fact facts);
   in
-    Pretty.big_list "locale elements:" (map (Pretty.chunks o prt_elem) (flat (map #2 elems)))
+    Pretty.big_list "locale elements:" (map (Pretty.chunks o prt_elem) (flat (map #2 elemss)))
     |> Pretty.writeln
   end;
 
@@ -375,59 +528,42 @@
 
 (** define locales **)
 
-(* closeup -- quantify dangling frees *)
-
-fun close_frees_wrt ctxt t =
-  let val frees = rev (filter_out (ProofContext.is_fixed ctxt o #1) (Drule.add_frees ([], t)))
-  in curry Term.list_all_free frees end;
-
-fun closeup ctxt (Assumes asms) = Assumes (asms |> map (fn (a, propps) =>
-      (a, propps |> map (fn (t, (ps1, ps2)) =>
-        let val close = close_frees_wrt ctxt t in (close t, (map close ps1, map close ps2)) end))))
-  | closeup ctxt (Defines defs) = Defines (defs |> map (fn (a, (t, ps)) =>
-      let
-        val (_, t') = ProofContext.cert_def ctxt t;
-        val close = close_frees_wrt ctxt t';
-      in (a, (close t', map close ps)) end))
-  | closeup ctxt elem = elem;
-
-
 (* add_locale(_i) *)
 
-fun gen_add_locale prep_expr prep_element bname raw_import raw_elements thy =
+local
+
+fun gen_add_locale prep_context prep_expr bname raw_import raw_body thy =
   let
     val sign = Theory.sign_of thy;
     val name = Sign.full_name sign bname;
-    val _ =
-      if is_none (get_locale thy name) then () else
-      error ("Duplicate definition of locale " ^ quote name);
+    val _ = conditional (is_some (get_locale thy name)) (fn () =>
+      error ("Duplicate definition of locale " ^ quote name));
 
     val thy_ctxt = ProofContext.init thy;
 
+    val (import_elemss, body_elemss) = prep_context thy_ctxt (raw_import, raw_body);
     val import = prep_expr thy_ctxt raw_import;
-    val (import_elems, import_params) = eval_expr thy_ctxt import;
-    val import_ctxt = activate_locale_elems import_elems thy_ctxt;
+    val import_elemss = eval_expr thy_ctxt import;
 
-    fun prep (ctxt, raw_element) =
-      let val elems = map (apsnd (map (closeup ctxt)))
-        (eval_element ctxt (prep_element ctxt raw_element))
-      in (activate_locale_elems elems ctxt, flat (map #2 elems)) end;
-    val (locale_ctxt, elemss) = foldl_map prep (import_ctxt, raw_elements);
+    val import_ctxt = thy_ctxt |> activate_elemss import_elemss;
+    val body_ctxt = import_ctxt |> activate_elemss body_elemss;
 
-    val elems = flat elemss;
-    val local_params =  (* FIXME lookup final types *)
-      flat (map (fn Fixes fixes => map (fn (x, T, _) => (x, T)) fixes | _ => []) elems);
-    val params = map (rpair None) import_params @ local_params;  (* FIXME *)
+    val elems = flat (map #2 body_elemss);
+    val (import_parms, body_parms) = pairself (flat o map (#2 o #1)) (import_elemss, body_elemss);
     val text = ([], []);  (* FIXME *)
   in
     thy
     |> declare_locale name
     |> put_locale name (make_locale import (map (fn e => (e, stamp ())) elems)
-      (params, local_params) text)
+      (import_parms @ body_parms, map #1 body_parms) text)
   end;
 
-val add_locale = gen_add_locale read_expr read_element;
-val add_locale_i = gen_add_locale (K I) (K I);
+in
+
+val add_locale = gen_add_locale read_context read_expr;
+val add_locale_i = gen_add_locale cert_context (K I);
+
+end;
 
 
 
@@ -439,7 +575,7 @@
     val note = Notes (map (fn ((a, ths), atts) =>
       ((a, atts), [(map (curry Thm.name_thm a) ths, [])])) args);
   in
-    thy |> ProofContext.init |> activate_locale_i name |> activate_elem note;  (*test attributes!*)
+    thy |> ProofContext.init |> activate_locale name |> activate_elem note;  (*test attributes!*)
     thy |> put_locale name (make_locale import (elems @ [(note, stamp ())]) params text)
   end;