src/HOL/Tools/Nitpick/nitpick_mono.ML
changeset 40998 bcd23ddeecef
parent 40997 67e11a73532a
child 40999 69d0d445c46a
--- a/src/HOL/Tools/Nitpick/nitpick_mono.ML	Mon Dec 06 13:30:36 2010 +0100
+++ b/src/HOL/Tools/Nitpick/nitpick_mono.ML	Mon Dec 06 13:30:38 2010 +0100
@@ -244,6 +244,8 @@
                 $ (Const (@{const_name unknown}, ran_T)) $ (t0 $ t1 $ t2 $ t3)))
   | fin_fun_body _ _ _ = NONE
 
+(* ### FIXME: make sure wellformed! *)
+
 fun fresh_mfun_for_fun_type (mdata as {max_fresh, ...} : mdata) all_minus
                             T1 T2 =
   let
@@ -404,9 +406,7 @@
     SOME (insert (op =) (aa1, aa2, cmp, xs) comps, clauses)
 
 fun add_annotation_atom_comp cmp xs aa1 aa2 (comps, clauses) =
-  (trace_msg (fn () => "*** Add " ^ string_for_annotation_atom aa1 ^ " " ^
-                       string_for_comp_op cmp ^ " " ^
-                       string_for_annotation_atom aa2);
+  (trace_msg (fn () => "*** Add " ^ string_for_comp (aa1, aa2, cmp, xs));
    case do_annotation_atom_comp cmp xs aa1 aa2 (comps, clauses) of
      NONE => (trace_msg (K "**** Unsolvable"); raise UNSOLVABLE ())
    | SOME cset => cset)
@@ -602,45 +602,51 @@
       end
   end
 
-type mtype_schema = mtyp * constraint_set
-type mtype_context =
+type mcontext =
   {bound_Ts: typ list,
    bound_Ms: mtyp list,
-   bound_frame: (int * annotation_atom) list,
+   frame: (int * annotation_atom) list,
    frees: (styp * mtyp) list,
    consts: (styp * mtyp) list}
 
-type accumulator = mtype_context * constraint_set
+fun string_for_bound ctxt Ms (j, aa) =
+  Syntax.string_of_term ctxt (Bound (length Ms - j - 1)) ^ " :\<^bsup>" ^
+  string_for_annotation_atom aa ^ "\<^esup> " ^
+  string_for_mtype (nth Ms (length Ms - j - 1))
+fun string_for_free relevant_frees ((s, _), M) =
+  if member (op =) relevant_frees s then SOME (s ^ " : " ^ string_for_mtype M)
+  else NONE
+fun string_for_mcontext ctxt t {bound_Ms, frame, frees, ...} =
+  (map_filter (string_for_free (Term.add_free_names t [])) frees @
+   map (string_for_bound ctxt bound_Ms) frame)
+  |> commas |> enclose "[" "]"
 
 val initial_gamma =
-  {bound_Ts = [], bound_Ms = [], bound_frame = [], frees = [], consts = []}
+  {bound_Ts = [], bound_Ms = [], frame = [], frees = [], consts = []}
 
-fun push_bound aa T M {bound_Ts, bound_Ms, bound_frame, frees, consts} =
+fun push_bound aa T M {bound_Ts, bound_Ms, frame, frees, consts} =
   {bound_Ts = T :: bound_Ts, bound_Ms = M :: bound_Ms,
-   bound_frame = (length bound_Ts, aa) :: bound_frame, frees = frees,
-   consts = consts}
-fun pop_bound {bound_Ts, bound_Ms, bound_frame, frees, consts} =
+   frame = frame @ [(length bound_Ts, aa)], frees = frees, consts = consts}
+fun pop_bound {bound_Ts, bound_Ms, frame, frees, consts} =
   {bound_Ts = tl bound_Ts, bound_Ms = tl bound_Ms,
-   bound_frame = bound_frame
-                 |> filter_out (fn (j, _) => j = length bound_Ts - 1),
+   frame = frame |> filter_out (fn (j, _) => j = length bound_Ts - 1),
    frees = frees, consts = consts}
   handle List.Empty => initial_gamma (* FIXME: needed? *)
 
-fun set_frame bound_frame ({bound_Ts, bound_Ms, frees, consts, ...}
-                           : mtype_context) =
-  {bound_Ts = bound_Ts, bound_Ms = bound_Ms, bound_frame = bound_frame,
-   frees = frees, consts = consts}
+fun set_frame frame ({bound_Ts, bound_Ms, frees, consts, ...} : mcontext) =
+  {bound_Ts = bound_Ts, bound_Ms = bound_Ms, frame = frame, frees = frees,
+   consts = consts}
 
 (* FIXME: make sure tracing messages are complete *)
 
-fun add_comp_frame a cmp = fold (add_annotation_atom_comp cmp [] (A a) o snd)
+fun add_comp_frame aa cmp = fold (add_annotation_atom_comp cmp [] aa o snd)
 
 fun add_bound_frame j frame =
   let
     val (new_frame, gen_frame) = List.partition (curry (op =) j o fst) frame
   in
-    add_comp_frame New Leq new_frame
-    #> add_comp_frame Gen Eq gen_frame
+    add_comp_frame (A New) Leq new_frame
+    #> add_comp_frame (A Gen) Eq gen_frame
   end
 
 fun fresh_frame ({max_fresh, ...} : mdata) fls tru =
@@ -684,10 +690,9 @@
             else (trace_msg (K "**** Unsolvable"); raise UNSOLVABLE ())
   | V x => SOME (x, (sign_for_comp_op cmp, a))
 
-val annotation_clause_from_quasi_clause =
+val assign_clause_from_quasi_clause =
   map_filter annotation_literal_from_quasi_literal
-
-val add_quasi_clause = annotation_clause_from_quasi_clause #> add_assign_clause
+val add_quasi_clause = assign_clause_from_quasi_clause #> add_assign_clause
 
 fun add_connective_var conn mk_quasi_clauses res_aa aa1 aa2 =
   (trace_msg (fn () => "*** Add " ^ string_for_annotation_atom res_aa ^ " = " ^
@@ -700,6 +705,41 @@
                    add_connective_var conn mk_quasi_clauses res_aa aa1 aa2)
                res_frame frame1 frame2)
 
+fun kill_unused_in_frame is_in (accum as ({frame, ...}, _)) =
+  let val (used_frame, unused_frame) = List.partition is_in frame in
+    accum |>> set_frame used_frame
+          ||> add_comp_frame (A Gen) Eq unused_frame
+  end
+
+fun split_frame is_in_fun (gamma as {frame, ...}, cset) =
+  let
+    fun bubble fun_frame arg_frame [] cset =
+        ((rev fun_frame, rev arg_frame), cset)
+      | bubble fun_frame arg_frame ((bound as (_, aa)) :: rest) cset =
+        if is_in_fun bound then
+          bubble (bound :: fun_frame) arg_frame rest
+                 (cset |> add_comp_frame aa Leq arg_frame)
+        else
+          bubble fun_frame (bound :: arg_frame) rest cset
+  in cset |> bubble [] [] frame ||> pair gamma end
+
+fun add_annotation_atom_comp_alt _ (A Gen) _ _ = I
+  | add_annotation_atom_comp_alt _ (A _) _ _ =
+    (trace_msg (K "*** Expected G"); raise UNSOLVABLE ())
+  | add_annotation_atom_comp_alt cmp (V x) aa1 aa2 =
+    add_annotation_atom_comp cmp [x] aa1 aa2
+
+fun add_arg_order1 ((_, aa), (_, prev_aa)) =
+  add_annotation_atom_comp_alt Neq prev_aa (A Gen) aa
+fun add_app1 fun_aa ((_, res_aa), (_, arg_aa)) =
+  add_annotation_atom_comp_alt Leq arg_aa fun_aa res_aa
+  ##> add_quasi_clause [(arg_aa, (Neq, Gen)), (res_aa, (Eq, Gen))]
+fun add_app _ [] [] = I
+  | add_app fun_aa res_frame arg_frame =
+    add_comp_frame (A New) Leq arg_frame
+    #> fold add_arg_order1 (tl arg_frame ~~ (fst (split_last arg_frame)))
+    #> fold (add_app1 fun_aa) (res_frame ~~ arg_frame)
+
 fun consider_term (mdata as {hol_ctxt = {thy, ctxt, stds, ...}, alpha_T,
                              max_fresh, ...}) =
   let
@@ -779,28 +819,29 @@
                                  MApp (bound_m, MRaw (Bound 0, M1))),
                            body_m))), accum)
       end
-    and do_connect conn mk_quasi_clauses t0 t1 t2
-                   (accum as ({bound_frame, ...}, _)) =
+    and do_connect conn mk_quasi_clauses t0 t1 t2 (accum as ({frame, ...}, _)) =
       let
-        val frame1 = fresh_frame mdata (SOME Tru) NONE bound_frame
-        val frame2 = fresh_frame mdata (SOME Fls) NONE bound_frame
+        val frame1 = fresh_frame mdata (SOME Tru) NONE frame
+        val frame2 = fresh_frame mdata (SOME Fls) NONE frame
         val (m1, accum) = accum |>> set_frame frame1 |> do_term t1
         val (m2, accum) = accum |>> set_frame frame2 |> do_term t2
       in
         (MApp (MApp (MRaw (t0, mtype_for (fastype_of t0)), m1), m2),
-         accum |>> set_frame bound_frame
+         accum |>> set_frame frame
                ||> apsnd (add_connective_frames conn mk_quasi_clauses
-                                                bound_frame frame1 frame2))
+                                                frame frame1 frame2))
       end
-    and do_term t (accum as ({bound_Ts, bound_Ms, bound_frame, frees, consts},
-                             cset)) =
-      (trace_msg (fn () => "  \<Gamma> \<turnstile> " ^
-                           Syntax.string_of_term ctxt t ^ " : _?");
+    and do_term t
+            (accum as (gamma as {bound_Ts, bound_Ms, frame, frees, consts},
+                       cset)) =
+      (trace_msg (fn () => "  " ^ string_for_mcontext ctxt t gamma ^
+                           " \<turnstile> " ^ Syntax.string_of_term ctxt t ^
+                           " : _?");
        case t of
          @{const False} =>
-         (MRaw (t, bool_M), accum ||> add_comp_frame Fls Leq bound_frame)
+         (MRaw (t, bool_M), accum ||> add_comp_frame (A Fls) Leq frame)
        | @{const True} =>
-         (MRaw (t, bool_M), accum ||> add_comp_frame Tru Leq bound_frame)
+         (MRaw (t, bool_M), accum ||> add_comp_frame (A Tru) Leq frame)
        | Const (x as (s, T)) =>
          (case AList.lookup (op =) consts x of
             SOME M => (M, accum)
@@ -900,26 +941,24 @@
                 (fresh_mtype_for_type mdata true T, accum)
               else
                 let val M = mtype_for T in
-                  (M, ({bound_Ts = bound_Ts, bound_Ms = bound_Ms,
-                        bound_frame = bound_frame, frees = frees,
-                        consts = (x, M) :: consts}, cset))
+                  (M, ({bound_Ts = bound_Ts, bound_Ms = bound_Ms, frame = frame,
+                        frees = frees, consts = (x, M) :: consts}, cset))
                 end)
            |>> curry MRaw t
-           ||> apsnd (add_comp_frame Gen Eq bound_frame)
+           ||> apsnd (add_comp_frame (A Gen) Eq frame)
          | Free (x as (_, T)) =>
            (case AList.lookup (op =) frees x of
               SOME M => (M, accum)
             | NONE =>
               let val M = mtype_for T in
-                (M, ({bound_Ts = bound_Ts, bound_Ms = bound_Ms,
-                      bound_frame = bound_frame, frees = (x, M) :: frees,
-                      consts = consts}, cset))
+                (M, ({bound_Ts = bound_Ts, bound_Ms = bound_Ms, frame = frame,
+                      frees = (x, M) :: frees, consts = consts}, cset))
               end)
-           |>> curry MRaw t ||> apsnd (add_comp_frame Gen Eq bound_frame)
+           |>> curry MRaw t ||> apsnd (add_comp_frame (A Gen) Eq frame)
          | Var _ => (trace_msg (K "*** Var"); raise UNSOLVABLE ())
          | Bound j =>
            (MRaw (t, nth bound_Ms j),
-            accum ||> add_bound_frame (length bound_Ts - j - 1) bound_frame)
+            accum ||> add_bound_frame (length bound_Ts - j - 1) frame)
          | Abs (s, T, t') =>
            (case fin_fun_body T (fastype_of1 (T :: bound_Ts, t')) t' of
               SOME t' =>
@@ -967,16 +1006,31 @@
            do_term (betapply (t2, t1)) accum
          | t1 $ t2 =>
            let
-             val (m1, accum) = do_term t1 accum
-             val (m2, accum) = do_term t2 accum
+             fun is_in t (j, _) = loose_bvar1 (t, length bound_Ts - j - 1)
+             val accum as ({frame, ...}, _) =
+               accum |> kill_unused_in_frame (is_in t)
+             val ((frame1a, frame1b), accum) = accum |> split_frame (is_in t1)
+             val frame2a = frame1a |> map (apsnd (K (A Gen)))
+             val frame2b =
+               frame1b |> map (apsnd (fn _ => V (Unsynchronized.inc max_fresh)))
+             val frame2 = frame2a @ frame2b
+             val (m1, accum) = accum |>> set_frame frame1a |> do_term t1
+             val (m2, accum) = accum |>> set_frame frame2 |> do_term t2
            in
              let
-               val M11 = mtype_of_mterm m1 |> dest_MFun |> #1
+               val (M11, aa, _) = mtype_of_mterm m1 |> dest_MFun
                val M2 = mtype_of_mterm m2
-             in (MApp (m1, m2), accum ||> add_is_sub_mtype M2 M11) end
+             in
+               (MApp (m1, m2),
+                accum |>> set_frame frame
+                      ||> add_is_sub_mtype M2 M11
+                      ||> add_app aa frame1b frame2b)
+             end
            end)
-        |> tap (fn (m, _) => trace_msg (fn () => "  \<Gamma> \<turnstile> " ^
-                                                 string_for_mterm ctxt m))
+        |> tap (fn (m, (gamma, _)) =>
+                   trace_msg (fn () => "  " ^ string_for_mcontext ctxt t gamma ^
+                                       " \<turnstile> " ^
+                                       string_for_mterm ctxt m))
   in do_term end
 
 fun force_minus_funs 0 _ = I
@@ -1007,7 +1061,7 @@
   let
     val mtype_for = fresh_mtype_for_type mdata false
     val do_term = consider_term mdata
-    fun do_formula sn t accum =
+    fun do_formula sn t (accum as (gamma, _)) =
         let
           fun do_quantifier (quant_x as (quant_s, _)) abs_s abs_T body_t =
             let
@@ -1029,9 +1083,9 @@
               Plus => do_term t accum
             | Minus => consider_general_equals mdata false x t1 t2 accum
         in
-          (trace_msg (fn () => "  \<Gamma> \<turnstile> " ^
-                               Syntax.string_of_term ctxt t ^ " : o\<^sup>" ^
-                               string_for_sign sn ^ "?");
+          (trace_msg (fn () => "  " ^ string_for_mcontext ctxt t gamma ^
+                               " \<turnstile> " ^ Syntax.string_of_term ctxt t ^
+                               " : o\<^sup>" ^ string_for_sign sn ^ "?");
            case t of
              Const (x as (@{const_name all}, _)) $ Abs (s1, T1, t1) =>
              do_quantifier x s1 T1 t1
@@ -1081,8 +1135,9 @@
                do_term t accum
            | _ => do_term t accum)
         end
-        |> tap (fn (m, _) =>
-                   trace_msg (fn () => "\<Gamma> \<turnstile> " ^
+        |> tap (fn (m, (gamma, _)) =>
+                   trace_msg (fn () => string_for_mcontext ctxt t gamma ^
+                                       " \<turnstile> " ^
                                        string_for_mterm ctxt m ^ " : o\<^sup>" ^
                                        string_for_sign sn))
   in do_formula end
@@ -1166,7 +1221,7 @@
 fun string_for_mtype_of_term ctxt asgs t M =
   Syntax.string_of_term ctxt t ^ " : " ^ string_for_mtype (resolve_mtype asgs M)
 
-fun print_mtype_context ctxt asgs ({frees, consts, ...} : mtype_context) =
+fun print_mcontext ctxt asgs ({frees, consts, ...} : mcontext) =
   trace_msg (fn () =>
       map (fn (x, M) => string_for_mtype_of_term ctxt asgs (Free x) M) frees @
       map (fn (x, M) => string_for_mtype_of_term ctxt asgs (Const x) M) consts
@@ -1192,7 +1247,7 @@
       ([], accum) |> fold (amass (consider_definitional_axiom mdata)) def_ts
   in
     case solve calculus (!max_fresh) cset of
-      SOME asgs => (print_mtype_context ctxt asgs gamma;
+      SOME asgs => (print_mcontext ctxt asgs gamma;
                     SOME (asgs, (nondef_ms, def_ms), !constr_mcache))
     | _ => NONE
   end