src/Pure/Isar/locale.ML
changeset 19783 82f365a14960
parent 19780 dce2168b0ea4
child 19810 dae765e552ce
--- a/src/Pure/Isar/locale.ML	Tue Jun 06 09:28:24 2006 +0200
+++ b/src/Pure/Isar/locale.ML	Tue Jun 06 10:05:57 2006 +0200
@@ -151,7 +151,8 @@
        (cf. [1], normalisation of locale expressions.)
     *)
   import: expr,                                                     (*dynamic import*)
-  elems: (Element.context_i * stamp) list,                          (*static content*)
+  elems: (Element.context_i * stamp) list,
+    (* Static content, neither Fixes nor Constrains elements *)
   params: ((string * typ) * mixfix) list,                           (*all params*)
   lparams: string list,                                             (*local parmas*)
   term_syntax: ((Proof.context -> Proof.context) * stamp) list, (* FIXME depend on morphism *)
@@ -489,6 +490,23 @@
 fun err_in_locale' ctxt msg ids' = err_in_locale ctxt msg (map fst ids');
 
 
+fun pretty_ren NONE = Pretty.str "_"
+  | pretty_ren (SOME (x, NONE)) = Pretty.str x
+  | pretty_ren (SOME (x, SOME syn)) =
+      Pretty.block [Pretty.str x, Pretty.brk 1, Syntax.pretty_mixfix syn];
+
+fun pretty_expr thy (Locale name) = Pretty.str (extern thy name)
+  | pretty_expr thy (Rename (expr, xs)) =
+      Pretty.block [pretty_expr thy expr, Pretty.brk 1, Pretty.block (map pretty_ren xs |> Pretty.breaks)]
+  | pretty_expr thy (Merge es) =
+      Pretty.separate "+" (map (pretty_expr thy) es) |> Pretty.block;
+
+fun err_in_expr _ msg (Merge []) = error msg
+  | err_in_expr ctxt msg expr =
+    error (msg ^ "\n" ^ Pretty.string_of (Pretty.block
+      [Pretty.str "The error(s) above occured in locale expression:", Pretty.brk 1,
+       pretty_expr (ProofContext.theory_of ctxt) expr]));
+
 
 (** structured contexts: rename + merge + implicit type instantiation **)
 
@@ -552,18 +570,6 @@
 
 local
 
-fun unique_parms ctxt elemss =
-  let
-    val param_decls =
-      maps (fn (((name, (ps, qs)), _), _) => map (rpair (name, ps)) qs) elemss
-      |> Symtab.make_list |> Symtab.dest;
-  in
-    (case find_first (fn (_, ids) => length ids > 1) param_decls of
-      SOME (q, ids) => err_in_locale ctxt ("Multiple declaration of parameter " ^ quote q)
-          (map (apsnd (map fst)) ids)
-    | NONE => map (apfst (apfst (apsnd #1))) elemss)
-  end;
-
 fun unify_parms ctxt fixed_parms raw_parmss =
   let
     val thy = ProofContext.theory_of ctxt;
@@ -612,7 +618,7 @@
 
 (* like unify_elemss, but does not touch mode, additional
    parameter c_parms for enforcing further constraints (eg. syntax) *)
-(* FIXME avoid code duplication *)
+(* FIXME avoid code duplication *) (* FIXME: avoid stipulating comments *)
 
 fun unify_elemss' _ _ [] [] = []
   | unify_elemss' _ [] [elems] [] = [elems]
@@ -627,6 +633,84 @@
       in map inst (elemss ~~ Library.take (length elemss, envs)) end;
 
 
+(* params_of_expr:
+   Compute parameters (with types and syntax) of locale expression.
+*)
+
+fun params_of_expr ctxt fixed_params expr (prev_parms, prev_types, prev_syn) =
+  let
+    val thy = ProofContext.theory_of ctxt;
+
+    fun renaming (SOME x :: xs) (y :: ys) = (y, x) :: renaming xs ys
+      | renaming (NONE :: xs) (y :: ys) = renaming xs ys
+      | renaming [] _ = []
+      | renaming xs [] = error ("Too many arguments in renaming: " ^
+          commas (map (fn NONE => "_" | SOME x => quote (fst x)) xs));
+
+    fun merge_tenvs fixed tenv1 tenv2 =
+        let
+          val [env1, env2] = unify_parms ctxt fixed
+                [tenv1 |> Symtab.dest |> map (apsnd SOME),
+                 tenv2 |> Symtab.dest |> map (apsnd SOME)]
+        in
+          Symtab.merge (op =) (Symtab.map (Element.instT_type env1) tenv1,
+            Symtab.map (Element.instT_type env2) tenv2)
+        end;
+
+    fun merge_syn expr syn1 syn2 =
+        Symtab.merge (op =) (syn1, syn2)
+        handle Symtab.DUPS xs => err_in_expr ctxt
+          ("Conflicting syntax for parameter(s): " ^ commas_quote xs) expr;
+            
+    fun params_of (expr as Locale name) =
+          let
+            val {import, params, ...} = the_locale thy name;
+            val parms = map (fst o fst) params;
+            val (parms', types', syn') = params_of import;
+            val all_parms = merge_lists parms' parms;
+            val all_types = merge_tenvs [] types' (params |> map fst |> Symtab.make);
+            val all_syn = merge_syn expr syn' (params |> map (apfst fst) |> Symtab.make);
+          in (all_parms, all_types, all_syn) end
+      | params_of (expr as Rename (e, xs)) =
+          let
+            val (parms', types', syn') = params_of e;
+            val ren = renaming xs parms';
+            (* renaming may reduce number of parameters *)
+            val new_parms = map (Element.rename ren) parms' |> distinct (op =);
+            val ren_syn = syn' |> Symtab.dest |> map (Element.rename_var ren);
+            val new_syn = fold (Symtab.insert (op =)) ren_syn Symtab.empty
+                handle Symtab.DUP x =>
+                  err_in_expr ctxt ("Conflicting syntax for parameter: " ^ quote x) expr;
+            val syn_types = map (apsnd (fn mx => SOME (Type.freeze_type (#1 (TypeInfer.paramify_dummies (TypeInfer.mixfixT mx) 0))))) (Symtab.dest new_syn);
+            val ren_types = types' |> Symtab.dest |> map (apfst (Element.rename ren));
+            val (env :: _) = unify_parms ctxt [] 
+                ((ren_types |> map (apsnd SOME)) :: map single syn_types);
+            val new_types = fold (Symtab.insert (op =))
+                (map (apsnd (Element.instT_type env)) ren_types) Symtab.empty;
+          in (new_parms, new_types, new_syn) end
+      | params_of (Merge es) =
+          fold (fn e => fn (parms, types, syn) =>
+                   let
+                     val (parms', types', syn') = params_of e
+                   in
+                     (merge_lists parms parms', merge_tenvs [] types types',
+                      merge_syn e syn syn')
+                   end)
+            es ([], Symtab.empty, Symtab.empty)
+
+      val (parms, types, syn) = params_of expr;
+    in
+      (merge_lists prev_parms parms, merge_tenvs fixed_params prev_types types,
+       merge_syn expr prev_syn syn)
+    end;
+
+fun make_params_ids params = [(("", params), ([], Assumed []))];
+fun make_raw_params_elemss (params, tenv, syn) =
+    [((("", map (fn p => (p, Symtab.lookup tenv p)) params), Assumed []),
+      Int [Fixes (map (fn p =>
+        (p, Symtab.lookup tenv p, Symtab.lookup syn p |> the)) params)])];
+
+
 (* flatten_expr:
    Extend list of identifiers by those new in locale expression expr.
    Compute corresponding list of lists of locale elements (one entry per
@@ -659,14 +743,11 @@
           commas (map (fn NONE => "_" | SOME x => quote (fst x)) xs));
 
     fun rename_parms top ren ((name, ps), (parms, mode)) =
-      let val ps' = map (Element.rename ren) ps in
-        (case duplicates (op =) ps' of
-          [] => ((name, ps'),
-                 if top then (map (Element.rename ren) parms,
-                   map_mode (map (Element.rename_witness ren)) mode)
-                 else (parms, mode))
-        | dups => err_in_locale ctxt ("Duplicate parameters: " ^ commas_quote dups) [(name, ps')])
-      end;
+        ((name, map (Element.rename ren) ps),
+         if top
+         then (map (Element.rename ren) parms,
+               map_mode (map (Element.rename_witness ren)) mode)
+         else (parms, mode));
 
     (* add registrations of (name, ps), recursively; adjust hyps of witnesses *)
 
@@ -739,8 +820,8 @@
 
             val ids'' = distinct (eq_fst (op =)) (map (rename_parms top ren) ids');
             val parms'' = distinct (op =) (maps (#2 o #1) ids'');
-            val syn'' = syn' |> Symtab.dest |> map (Element.rename_var ren) |> Symtab.make;
-            (* check for conflicting syntax? *)
+            val syn'' = fold (Symtab.insert (op =))
+                (map (Element.rename_var ren) (Symtab.dest syn')) Symtab.empty;
           in (ids'', parms'', syn'') end
       | identify top (Merge es) =
           fold (fn e => fn (ids, parms, syn) =>
@@ -765,8 +846,8 @@
         fun lookup_syn x = (case Symtab.lookup syn x of SOME Structure => NONE | opt => opt);
         val ren = map #1 ps' ~~ map (fn x => (x, lookup_syn x)) xs;
         val (params', elems') =
-          if null ren then ((ps', qs), map #1 elems)
-          else ((map (apfst (Element.rename ren)) ps', map (Element.rename ren) qs),
+          if null ren then ((ps'(*, qs*)), map #1 elems)
+          else ((map (apfst (Element.rename ren)) ps'(*, map (Element.rename ren) qs*)),
             map (Element.rename_ctxt ren o #1) elems);
         val elems'' = elems' |> map (Element.map_ctxt
           {var = I, typ = I, term = I, fact = I, attrib = I,
@@ -781,8 +862,8 @@
     val (ids, _, syn) = identify true expr;
     val idents = gen_rems (eq_fst (op =)) (ids, prev_idents);
     val syntax = merge_syntax ctxt ids (syn, prev_syntax);
-    (* add types to params, check for unique params and unify them *)
-    val raw_elemss = unique_parms ctxt (map (eval syntax) idents);
+    (* add types to params and unify them *)
+    val raw_elemss = (*unique_parms ctxt*) (map (eval syntax) idents);
     val elemss = unify_elemss' ctxt [] raw_elemss (map (apsnd mixfix_type) (Symtab.dest syntax));
     (* replace params in ids by params from axioms,
        adjust types in mode *)
@@ -909,52 +990,13 @@
   | intern_expr thy (Rename (expr, xs)) = Rename (intern_expr thy expr, xs);
 
 
-(* experimental code for type inference *)
-
-local
-
-fun declare_int_elem (ctxt, Fixes fixes) =
-      (ctxt |> ProofContext.add_fixes_i (map (fn (x, T, mx) =>
-        (x, Option.map (Term.map_type_tfree (TypeInfer.param 0)) T, mx)) fixes) |> snd, [])
-  | declare_int_elem (ctxt, _) = (ctxt, []);
-
-fun declare_ext_elem prep_vars (ctxt, Fixes fixes) =
-      let val (vars, _) = prep_vars fixes ctxt
-      in (ctxt |> ProofContext.add_fixes_i vars |> snd, []) end
-  | declare_ext_elem prep_vars (ctxt, Constrains csts) =
-      let val (_, ctxt') = prep_vars (map (fn (x, T) => (x, SOME T, NoSyn)) csts) ctxt
-      in (ctxt', []) end
-  | declare_ext_elem _ (ctxt, Assumes asms) = (ctxt, map #2 asms)
-  | declare_ext_elem _ (ctxt, Defines defs) = (ctxt, map (fn (_, (t, ps)) => [(t, ps)]) defs)
-  | declare_ext_elem _ (ctxt, Notes facts) = (ctxt, []);
-
-fun declare_elems prep_vars (ctxt, (((name, ps), Assumed _), elems)) =
-    let val (ctxt', propps) =
-      (case elems of
-        Int es => foldl_map declare_int_elem (ctxt, es)
-      | Ext e => foldl_map (declare_ext_elem prep_vars) (ctxt, [e]))
-      handle ERROR msg => err_in_locale ctxt msg [(name, map fst ps)]
-    in (ctxt', propps) end
-  | declare_elems _ (ctxt, ((_, Derived _), elems)) = (ctxt, []);
-
-in
-
-(* The Plan:
-- tell context about parameters and their syntax (possibly also types)
-- add declarations to context
-- retrieve parameter types
-*)
-
-end; (* local *)
-
-
 (* propositions and bindings *)
 
 (* flatten (ctxt, prep_expr) ((ids, syn), expr)
    normalises expr (which is either a locale
    expression or a single context element) wrt.
    to the list ids of already accumulated identifiers.
-   It returns (ids', syn', elemss) where ids' is an extension of ids
+   It returns ((ids', syn'), elemss) where ids' is an extension of ids
    with identifiers generated for expr, and elemss is the list of
    context elements generated from expr.
    syn and syn' are symtabs mapping parameter names to their syntax.  syn'
@@ -977,7 +1019,7 @@
          merge_syntax ctxt ids'
            (syn, Symtab.make (map (fn fx => (#1 fx, #3 fx)) fixes))
            handle Symtab.DUPS xs => err_in_locale ctxt
-             ("Conflicting syntax for parameters: " ^ commas_quote xs)
+             ("Conflicting syntax (3) for parameters: " ^ commas_quote xs)
              (map #1 ids')),
          [((("", map (rpair NONE o #1) fixes), Assumed []), Ext (Fixes fixes))])
       end
@@ -1297,9 +1339,13 @@
 fun parameters_of_expr thy expr =
   let
     val ctxt = ProofContext.init thy;
+    val pts = params_of_expr ctxt [] (intern_expr thy expr)
+        ([], Symtab.empty, Symtab.empty);
+    val raw_params_elemss = make_raw_params_elemss pts;
     val ((_, syn), raw_elemss) = flatten (ctxt, intern_expr thy)
         (([], Symtab.empty), Expr expr);
-    val ((parms, _, _), _) = read_elemss false ctxt [] raw_elemss [];
+    val ((parms, _, _), _) =
+        read_elemss false ctxt [] (raw_params_elemss @ raw_elemss) [];
   in map (fn p as (n, _) => (p, Symtab.lookup syn n |> the)) parms end;
 
 fun local_asms_of thy name =
@@ -1320,24 +1366,43 @@
   let
     val thy = ProofContext.theory_of context;
 
-    val ((import_ids, import_syn), raw_import_elemss) =
+    val (import_params, import_tenv, import_syn) =
+      params_of_expr context fixed_params (prep_expr thy import)
+        ([], Symtab.empty, Symtab.empty);
+    val includes = map_filter (fn Expr e => SOME e | Elem _ => NONE) elements;
+    val (incl_params, incl_tenv, incl_syn) = fold (params_of_expr context fixed_params)
+      (map (prep_expr thy) includes) (import_params, import_tenv, import_syn);
+
+    val ((import_ids, _), raw_import_elemss) =
       flatten (context, prep_expr thy) (([], Symtab.empty), Expr import);
     (* CB: normalise "includes" among elements *)
     val ((ids, syn), raw_elemsss) = foldl_map (flatten (context, prep_expr thy))
-      ((import_ids, import_syn), elements);
+      ((import_ids, incl_syn), elements);
 
     val raw_elemss = flat raw_elemsss;
     (* CB: raw_import_elemss @ raw_elemss is the normalised list of
        context elements obtained from import and elements. *)
+    (* Now additional elements for parameters are inserted. *)
+    val import_params_ids = make_params_ids import_params;
+    val incl_params_ids =
+        make_params_ids (incl_params \\ import_params);
+    val raw_import_params_elemss =
+        make_raw_params_elemss (import_params, incl_tenv, incl_syn);
+    val raw_incl_params_elemss =
+        make_raw_params_elemss (incl_params \\ import_params, incl_tenv, incl_syn);
     val ((parms, all_elemss, concl), (spec, (_, _, defs))) = prep_elemss do_close
-      context fixed_params (raw_import_elemss @ raw_elemss) raw_concl;
+      context fixed_params
+      (raw_import_params_elemss @ raw_import_elemss @ raw_incl_params_elemss @ raw_elemss) raw_concl;
+
     (* replace extended ids (for axioms) by ids *)
+    val (import_ids', incl_ids) = chop (length import_ids) ids;
+    val add_ids = import_params_ids @ import_ids' @ incl_params_ids @ incl_ids;
     val all_elemss' = map (fn (((_, ps), _), (((n, ps'), mode), elems)) =>
         (((n, map (fn p => (p, (the o AList.lookup (op =) ps') p)) ps), mode), elems))
-      (ids ~~ all_elemss);
+      (add_ids ~~ all_elemss);
+    (* CB: all_elemss and parms contain the correct parameter types *)
 
-    (* CB: all_elemss and parms contain the correct parameter types *)
-    val (ps, qs) = chop (length raw_import_elemss) all_elemss';
+    val (ps, qs) = chop (length raw_import_params_elemss + length raw_import_elemss) all_elemss';
     val (import_ctxt, (import_elemss, _)) =
       activate_facts prep_facts (context, ps);
 
@@ -1348,7 +1413,8 @@
                            | ((_, Derived _), _) => []) qs);
     val cstmt = map (cterm_of thy) stmt;
   in
-    ((((import_ctxt, import_elemss), (ctxt, elemss, syn)), (parms, spec, defs)), (cstmt, concl))
+    ((((import_ctxt, import_elemss), (ctxt, elemss, syn)),
+      (parms, spec, defs)), (cstmt, concl))
   end;
 
 fun prep_statement prep_locale prep_ctxt raw_locale elems concl ctxt =
@@ -1716,14 +1782,14 @@
     val export = ProofContext.export_view predicate_statement ctxt thy_ctxt;
     val facts' = facts |> map (fn (a, ths) => ((a, []), [(map export ths, [])]));
     val elems' = maps #2 (filter (equal "" o #1 o #1) elemss');
-
+    val elems'' = map_filter (fn (Fixes _) => NONE | e => SOME e) elems';
     val thy' = pred_thy
       |> PureThy.note_thmss_qualified "" bname facts' |> snd
       |> declare_locale name
       |> put_locale name
        {predicate = predicate,
         import = import,
-        elems = map (fn e => (e, stamp ())) elems',
+        elems = map (fn e => (e, stamp ())) elems'',
         params = params_of elemss' |> map (fn (x, SOME T) => ((x, T), the (Symtab.lookup syn x))),
         lparams = map #1 (params_of body_elemss),
         term_syntax = [],
@@ -1918,10 +1984,14 @@
     val thy = ProofContext.theory_of ctxt;
 
     val ctxt' = ctxt |> ProofContext.theory_of |> ProofContext.init;
+    val pts = params_of_expr ctxt' [] (intern_expr thy expr)
+          ([], Symtab.empty, Symtab.empty);
+    val params_ids = make_params_ids (#1 pts);
+    val raw_params_elemss = make_raw_params_elemss pts;
     val ((ids, _), raw_elemss) = flatten (ctxt', intern_expr thy)
           (([], Symtab.empty), Expr expr);
     val ((parms, all_elemss, _), (_, (_, defs, _))) =
-          read_elemss false ctxt' [] raw_elemss [];
+          read_elemss false ctxt' [] (raw_params_elemss @ raw_elemss) [];
 
     (** compute instantiation **)
 
@@ -1975,11 +2045,13 @@
     val insts = (tinst, inst);
     (* Note: insts contain no vars. *)
 
+
     (** compute proof obligations **)
 
     (* restore "small" ids *)
     val ids' = map (fn ((n, ps), (_, mode)) =>
-          ((n, map (fn p => Free (p, (the o AList.lookup (op =) parms) p)) ps), mode)) ids;
+          ((n, map (fn p => Free (p, (the o AList.lookup (op =) parms) p)) ps), mode))
+        (params_ids @ ids);
     (* instantiate ids and elements *)
     val inst_elemss = (ids' ~~ all_elemss) |> map (fn (((n, ps), _), ((_, mode), elems)) =>
       ((n, map (Element.inst_term insts) ps),