Expression types cleaned up, proper treatment of term patterns.
authorballarin
Tue, 25 Nov 2008 18:06:21 +0100
changeset 28885 6f6bf52e75bb
parent 28884 7cef91288634
child 28886 9cb1297b6f13
Expression types cleaned up, proper treatment of term patterns.
src/Pure/Isar/expression.ML
--- a/src/Pure/Isar/expression.ML	Mon Nov 24 21:09:31 2008 +0100
+++ b/src/Pure/Isar/expression.ML	Tue Nov 25 18:06:21 2008 +0100
@@ -7,13 +7,10 @@
 
 signature EXPRESSION =
 sig
-  type 'term map
-  type 'morph expr
-
-  val empty_expr: 'morph expr
-
-  type expression = (string * string map) expr * (Name.binding * string option * mixfix) list
-(*  type expression_i = Morphism.morphism expr * (Name.binding * typ option * mixfix) list *)
+  datatype 'term map = Positional of 'term option list | Named of (string * 'term) list;
+  type 'term expr = (string * (string * 'term map)) list;
+  type expression = string expr * (Name.binding * string option * mixfix) list;
+  type expression_i = term expr * (Name.binding * typ option * mixfix) list;
 
   (* Processing of locale statements *)
   val read_statement: Element.context list -> (string * string list) list list ->
@@ -24,16 +21,15 @@
   (* Declaring locales *)
   val add_locale: string -> bstring -> expression -> Element.context list -> theory ->
     string * Proof.context
-(*
   val add_locale_i: string -> bstring -> expression_i -> Element.context_i list -> theory ->
     string * Proof.context
-*)
+
   (* Debugging and development *)
   val parse_expression: OuterParse.token list -> expression * OuterParse.token list
 end;
 
 
-structure Expression (*: EXPRESSION *) =
+structure Expression : EXPRESSION =
 struct
 
 datatype ctxt = datatype Element.ctxt;
@@ -45,11 +41,10 @@
   Positional of 'term option list |
   Named of (string * 'term) list;
 
-datatype 'morph expr = Expr of (string * 'morph) list;
+type 'term expr = (string * (string * 'term map)) list;
 
-type expression = (string * string map) expr * (Name.binding * string option * mixfix) list;
-
-val empty_expr = Expr [];
+type expression = string expr * (Name.binding * string option * mixfix) list;
+type expression_i = term expr * (Name.binding * typ option * mixfix) list;
 
 
 (** Parsing and printing **)
@@ -76,12 +71,12 @@
     fun expr2 x = P.xname x;
     fun expr1 x = (Scan.optional prefix "" -- expr2 --
       Scan.optional instance (Named []) >> (fn ((p, l), i) => (l, (p, i)))) x;
-    fun expr0 x = (plus1_unless loc_keyword expr1 >> Expr) x;
+    fun expr0 x = (plus1_unless loc_keyword expr1) x;
   in expr0 -- P.for_fixes end;
 
 end;
 
-fun pretty_expr thy (Expr expr) =
+fun pretty_expr thy expr =
   let
     fun pretty_pos NONE = Pretty.str "_"
       | pretty_pos (SOME x) = Pretty.str x;
@@ -99,19 +94,19 @@
             Pretty.brk 1 :: pretty_ren ren);
   in Pretty.separate "+" (map pretty_rename expr) |> Pretty.block end;
 
-fun err_in_expr thy msg (Expr expr) =
+fun err_in_expr thy msg expr =
   let
     val err_msg =
       if null expr then msg
       else msg ^ "\n" ^ Pretty.string_of (Pretty.block
         [Pretty.str "The above error(s) occurred in expression:", Pretty.brk 1,
-          pretty_expr thy (Expr expr)])
+          pretty_expr thy expr])
   in error err_msg end;
 
 
 (** Internalise locale names in expr **)
 
-fun intern thy (Expr instances) = Expr (map (apfst (NewLocale.intern thy)) instances);
+fun intern thy instances =  map (apfst (NewLocale.intern thy)) instances;
 
 
 (** Parameters of expression.
@@ -142,9 +137,7 @@
 	      else insts @ replicate d NONE;
             val ps' = (ps ~~ insts') |>
               map_filter (fn (p, NONE) => SOME p | (_, SOME _) => NONE);
-          in
-            (ps', (loc', (prfx, Positional insts')))
-          end
+          in (ps', (loc', (prfx, Positional insts'))) end
       | params_inst (expr as (loc, (prfx, Named insts))) =
           let
             val _ = reject_dups "Duplicate instantiation of the following parameter(s): "
@@ -154,10 +147,8 @@
             val ps' = fold (fn (p, _) => fn ps =>
               if AList.defined match_bind ps p then AList.delete match_bind p ps
               else error (quote p ^" not a parameter of instantiated expression.")) insts ps;
-          in
-            (ps', (loc', (prfx, Named insts)))
-          end;
-    fun params_expr (Expr is) =
+          in (ps', (loc', (prfx, Named insts))) end;
+    fun params_expr is =
           let
             val (is', ps') = fold_map (fn i => fn ps =>
               let
@@ -169,9 +160,7 @@
                   else error ("Conflicting syntax for parameter" ^ quote (Name.display p) ^
                     " in expression.")) (ps, ps')
               in (i', ps'') end) is []
-          in
-            (ps', Expr is')
-          end;
+          in (ps', is') end;
 
     val (parms, expr') = params_expr expr;
 
@@ -205,12 +194,6 @@
 
 end;
 
-(* Prepare type inference problem for Syntax.check_terms *)
-
-fun varify_indexT i ty = ty |> Term.map_atyps
-  (fn TFree (a, S) => TVar ((a, i), S)
-    | TVar (ai, _) => raise TYPE ("Illegal schematic variable: " ^
-        quote (Term.string_of_vname ai), [ty], []));
 
 (* Instantiation morphism *)
 
@@ -252,37 +235,71 @@
     (prep_term ctxt, map (prep_term ctxt) ps)) concl;
 
 
-(** Type checking **)
+(** Simultaneous type inference: instantiations + elements + conclusion **)
+
+local
+
+fun mk_type T = (Logic.mk_type T, []);
+fun mk_term t = (t, []);
+fun mk_propp (p, pats) = (Syntax.type_constraint propT p, pats);
 
-fun extract_elem (Fixes fixes) = map (#2 #> the_list #> map (Logic.mk_type #> rpair [])) fixes
-  | extract_elem (Constrains csts) = map (#2 #> single #> map (Logic.mk_type #> rpair [])) csts
-  | extract_elem (Assumes asms) = map #2 asms
-  | extract_elem (Defines defs) = map (fn (_, (t, ps)) => [(t, ps)]) defs
+fun dest_type (T, []) = Logic.dest_type T;
+fun dest_term (t, []) = t;
+fun dest_propp (p, pats) = (p, pats);
+
+fun extract_inst (_, (_, ts)) = map mk_term ts;
+fun restore_inst ((l, (p, _)), cs) = (l, (p, map dest_term cs));
+
+fun extract_elem (Fixes fixes) = map (#2 #> the_list #> map mk_type) fixes
+  | extract_elem (Constrains csts) = map (#2 #> single #> map mk_type) csts
+  | extract_elem (Assumes asms) = map (#2 #> map mk_propp) asms
+  | extract_elem (Defines defs) = map (fn (_, (t, ps)) => [mk_propp (t, ps)]) defs
   | extract_elem (Notes _) = [];
 
-fun restore_elem (Fixes fixes, propps) =
-      (fixes ~~ propps) |> map (fn ((x, _, mx), propp) =>
-        (x, propp |> map (fst #> Logic.dest_type) |> try hd, mx)) |> Fixes
-  | restore_elem (Constrains csts, propps) =
-      (csts ~~ propps) |> map (fn ((x, _), propp) =>
-        (x, propp |> map (fst #> Logic.dest_type) |> hd)) |> Constrains
-  | restore_elem (Assumes asms, propps) =
-      (asms ~~ propps) |> map (fn ((b, _), propp) => (b, propp)) |> Assumes
-  | restore_elem (Defines defs, propps) =
-      (defs ~~ propps) |> map (fn ((b, _), [propp]) => (b, propp)) |> Defines
+fun restore_elem (Fixes fixes, css) =
+      (fixes ~~ css) |> map (fn ((x, _, mx), cs) =>
+        (x, cs |> map dest_type |> try hd, mx)) |> Fixes
+  | restore_elem (Constrains csts, css) =
+      (csts ~~ css) |> map (fn ((x, _), cs) =>
+        (x, cs |> map dest_type |> hd)) |> Constrains
+  | restore_elem (Assumes asms, css) =
+      (asms ~~ css) |> map (fn ((b, _), cs) => (b, map dest_propp cs)) |> Assumes
+  | restore_elem (Defines defs, css) =
+      (defs ~~ css) |> map (fn ((b, _), [c]) => (b, dest_propp c)) |> Defines
   | restore_elem (Notes notes, _) = Notes notes;
 
+fun check cs context =
+  let
+    fun prep (_, pats) (ctxt, t :: ts) =
+      let val ctxt' = Variable.auto_fixes t ctxt
+      in
+        ((t, Syntax.check_props (ProofContext.set_mode ProofContext.mode_pattern ctxt') pats),
+          (ctxt', ts))
+      end
+    val (cs', (context', _)) = fold_map prep cs
+      (context, Syntax.check_terms
+        (ProofContext.set_mode ProofContext.mode_schematic context) (map fst cs));
+  in (cs', context') end;
+
+in
+
 fun check_autofix insts elems concl ctxt =
   let
-    val instss = map (snd o snd) insts |> (map o map) (fn t => (t, []));
-    val elemss = elems |> map extract_elem;
-    val all_terms' = (burrow o burrow_fst) (Syntax.check_terms ctxt) (concl @ instss @ flat elemss); 
-(*    val (ctxt', all_props') = ProofContext.check_propp_schematic (ctxt, concl @ flat propss); *)
-    val ctxt'' = (fold o fold) (fn (t, _) => Variable.auto_fixes t) all_terms' ctxt;
-    val (concl', mores') = chop (length concl) all_terms';
-    val (insts', elems') = chop (length instss) mores';
-  in (insts' |> (map o map) fst |> curry (op ~~) insts |> map (fn ((l, (p, _)), is) => (l, (p, is))),
-    elems' |> unflat elemss |> curry (op ~~) elems |> map restore_elem, concl', ctxt'') end;
+    val inst_cs = map extract_inst insts;
+    val elem_css = map extract_elem elems;
+    val concl_cs = (map o map) mk_propp concl;
+    (* Type inference *)
+    val (inst_cs' :: css', ctxt') =
+      (fold_burrow o fold_burrow) check (inst_cs :: elem_css @ [concl_cs]) ctxt;
+    (* Re-check to resolve bindings, elements and conclusion only *)
+    val (css'', _) = (fold_burrow o fold_burrow) check css' ctxt';
+    val (elem_css'', [concl_cs'']) = chop (length elem_css) css'';
+  in
+    (map restore_inst (insts ~~ inst_cs'), map restore_elem (elems ~~ elem_css''),
+      concl_cs'', ctxt')
+  end;
+
+end;
 
 
 (** Prepare locale elements **)
@@ -412,7 +429,8 @@
         val (parm_names, parm_types) = NewLocale.params_of thy loc |>
           map (fn (b, SOME T, _) => (Name.name_of b, T)) |> split_list;
         val inst' = parse_inst parm_names inst ctxt;
-        val parm_types' = map (TypeInfer.paramify_vars o varify_indexT i) parm_types;
+        val parm_types' = map (TypeInfer.paramify_vars o
+          Term.map_type_tvar (fn ((x, _), S) => TVar ((x, i), S)) o Logic.varifyT) parm_types;
         val inst'' = map2 TypeInfer.constrain parm_types' inst';
         val insts' = insts @ [(loc, (prfx, inst''))];
         val (insts'', _, _, ctxt') = check_autofix insts' [] [] ctxt;
@@ -491,7 +509,7 @@
   let
     val thy = ProofContext.theory_of context;
 
-    val (Expr expr, fixed) = parameters_of thy (apfst (prep_expr thy) imprt);
+    val (expr, fixed) = parameters_of thy (apfst (prep_expr thy) imprt);
     val ((parms, fors, deps, elems, concl), (spec, (_, _, defs))) =
       prep_elems do_close context fixed expr elements raw_concl;
 
@@ -504,7 +522,7 @@
 
 fun prep_statement prep_ctxt elems concl ctxt =
   let
-    val (((_, (ctxt', _), _)), concl) = prep_ctxt false (Expr [], []) elems concl ctxt
+    val (((_, (ctxt', _), _)), concl) = prep_ctxt false ([], []) elems concl ctxt
   in (concl, ctxt') end;
 
 in
@@ -708,7 +726,7 @@
 in
 
 val add_locale = gen_add_locale read_context;
-(* val add_locale_i = gen_add_locale cert_context; *)
+val add_locale_i = gen_add_locale cert_context;
 
 end;