read_instantiations: proper type-inference with fixed variables, infer parameter types as well;
authorwenzelm
Mon, 23 Apr 2007 20:44:11 +0200
changeset 22772 e0788ff2e811
parent 22771 ce1fe6ca7dbb
child 22773 9bb135fa5206
read_instantiations: proper type-inference with fixed variables, infer parameter types as well; gen_prep_registration: removed unused read_terms (dead code!); tuned;
src/Pure/Isar/locale.ML
--- a/src/Pure/Isar/locale.ML	Mon Apr 23 20:44:10 2007 +0200
+++ b/src/Pure/Isar/locale.ML	Mon Apr 23 20:44:11 2007 +0200
@@ -2105,52 +2105,46 @@
 
 fun read_instantiations ctxt parms (insts, eqns) =
   let
-    (* user input *)
-    val insts = if length parms < length insts
-         then error "More arguments than parameters in instantiation."
-         else insts @ replicate (length parms - length insts) NONE;
-    val (ps, pTs) = split_list parms;
-    val pvTs = map Logic.varifyT pTs;
-
-    (* context for reading terms *)
-    val tvars = fold Term.add_tvarsT pvTs [];
-
-    (* parameter instantiations given by user *)
-    val given = map_filter (fn (_, (NONE, _)) => NONE
-         | (x, (SOME inst, T)) => SOME (x, (inst, T))) (ps ~~ (insts ~~ pvTs));
+    val thy = ProofContext.theory_of ctxt;
+
+    (* parameter instantiations *)
+    val parms' = map (apsnd Logic.varifyT) parms;
+    val d = length parms - length insts;
+    val insts =
+      if d < 0 then error "More arguments than parameters in instantiation."
+      else insts @ replicate d NONE;
+
+    val given = (parms' ~~ insts) |> map_filter
+      (fn (_, NONE) => NONE
+        | ((x, T), SOME inst) => SOME (x, (inst, T)));
     val (given_ps, given_insts) = split_list given;
 
-    (* equations given by user *)
+    (* equations *)
     val (lefts, rights) = split_list eqns;
     val max_eqs = length eqns;
-    val Ts = map (fn i => TypeInfer.param i ("x", [])) (0 upto max_eqs - 1);
-
-    val (all_vs, vinst) = ProofContext.read_termTs ctxt (K false) (K NONE)
-      (K NONE) [] (given_insts @ (lefts ~~ Ts) @ (rights ~~ Ts));
-    
-    val vars = foldl Term.add_term_tvar_ixns [] all_vs
-      |> subtract (op =) (map fst tvars)
-      |> fold Term.add_varnames all_vs
-    val _ = if null vars then ()
-         else error ("Illegal schematic variable(s) in instantiation: " ^
-           commas_quote (map Term.string_of_vname vars));
-
-    val renameT = Logic.legacy_unvarifyT;
-    val rename = Term.map_types renameT;
-
-    val (vs, ts) = chop (length given_insts) all_vs;
-
-    val instT = Symtab.empty
-      |> fold (fn ((x, 0), T) => Symtab.update (x, renameT T)) vinst;
-    val inst = Symtab.empty
-      |> fold2 (fn x => fn t => Symtab.update (x, rename t)) given_ps vs;
-
-    val (lefts', rights') = chop max_eqs (map rename ts);
-
+    val eqTs = map (fn i => TypeInfer.param i ("'a", [])) (0 upto max_eqs - 1);
+
+    (* read given insts / eqns *)
+    val all_vs = ProofContext.read_termTs ctxt (given_insts @ (lefts ~~ eqTs) @ (rights ~~ eqTs));
+    val ctxt' = ctxt |> fold Variable.declare_term all_vs;
+    val (vs, (lefts', rights')) = all_vs |> chop (length given_insts) ||> chop max_eqs;
+
+    (* infer parameter types *)
+    val tyenv = fold (fn ((_, T), t) => Sign.typ_match thy (T, Term.fastype_of t))
+      (given_insts ~~ vs) Vartab.empty;
+    val looseTs = fold (Term.add_tvarsT o Envir.typ_subst_TVars tyenv o #2) parms' [];
+    val (fixedTs, _) = Variable.invent_types (map #2 looseTs) ctxt';
+    val tyenv' =
+      fold (fn ((xi, S), v) => Vartab.update_new (xi, (S, TFree v))) (looseTs ~~ fixedTs) tyenv;
+
+    (*results*)
+    val instT = Vartab.fold (fn ((a, 0), (S, T)) =>
+      if T = TFree (a, S) then I else Symtab.update (a, T)) tyenv' Symtab.empty;
+    val inst = Symtab.make (given_ps ~~ vs);
   in ((instT, inst), lefts' ~~ rights') end;
 
 
-fun gen_prep_registration mk_ctxt read_terms test_reg activate
+fun gen_prep_registration mk_ctxt test_reg activate
     prep_attr prep_expr prep_insts
     thy_ctxt raw_attn raw_expr raw_insts =
   let
@@ -2220,12 +2214,10 @@
   in (propss, activate attn inst_elemss new_inst_elemss propss) end;
 
 fun gen_prep_global_registration mk_ctxt = gen_prep_registration ProofContext.init
-  (fn thy => fn sorts => fn used => Sign.read_def_terms (thy, K NONE, sorts) used true)
   (fn thy => fn (name, ps) => test_global_registration thy (name, map Logic.varify ps))
   global_activate_facts_elemss mk_ctxt;
 
 fun gen_prep_local_registration mk_ctxt = gen_prep_registration I
-  (fn ctxt => ProofContext.read_termTs ctxt (K false) (K NONE))
   smart_test_registration
   local_activate_facts_elemss mk_ctxt;