src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
changeset 57017 afdf75c0de58
parent 57014 b7999893ffcc
child 57018 142950e9c7e2
equal deleted inserted replaced
57016:c44ce6f4067d 57017:afdf75c0de58
   282 (*** Standard ML version of MaSh ***)
   282 (*** Standard ML version of MaSh ***)
   283 
   283 
   284 structure MaSh_SML =
   284 structure MaSh_SML =
   285 struct
   285 struct
   286 
   286 
   287 fun max a b = if a > b then a else b
       
   288 
       
   289 exception BOTTOM of int
   287 exception BOTTOM of int
   290 
   288 
   291 fun heap cmp bnd a =
   289 fun heap cmp bnd a =
   292   let
   290   let
   293     fun maxson l i =
   291     fun maxson l i =
   294       let
   292       let val i31 = i + i + i + 1 in
   295         val i31 = i + i + i + 1
       
   296       in
       
   297         if i31 + 2 < l then
   293         if i31 + 2 < l then
   298           let
   294           let val x = Unsynchronized.ref i31 in
   299             val x = Unsynchronized.ref i31;
   295             if cmp (Array.sub (a, i31), Array.sub (a, i31 + 1)) = LESS then x := i31 + 1 else ();
   300             val _ = if cmp (Array.sub(a,i31),(Array.sub(a,i31+1))) = LESS then x := i31+1 else ();
   296             if cmp (Array.sub (a, !x), Array.sub (a, i31 + 2)) = LESS then x := i31 + 2 else ();
   301             val _ = if cmp (Array.sub(a,!x),(Array.sub (a,i31+2))) = LESS then x := i31+2 else ()
       
   302           in
       
   303             !x
   297             !x
   304           end
   298           end
   305         else
   299         else
   306           if i31 + 1 < l andalso cmp (Array.sub (a, i31), Array.sub (a, i31 + 1)) = LESS
   300           if i31 + 1 < l andalso cmp (Array.sub (a, i31), Array.sub (a, i31 + 1)) = LESS
   307           then i31 + 1 else if i31 < l then i31 else raise BOTTOM i
   301           then i31 + 1 else if i31 < l then i31 else raise BOTTOM i
   352       end
   346       end
   353 
   347 
   354     val _ = for (((l + 1) div 3) - 1)
   348     val _ = for (((l + 1) div 3) - 1)
   355 
   349 
   356     fun for2 i =
   350     fun for2 i =
   357       if i < max 2 (l - bnd) then () else
   351       if i < Integer.max 2 (l - bnd) then () else
   358       let
   352       let
   359         val e = Array.sub (a, i)
   353         val e = Array.sub (a, i)
   360         val _ = Array.update (a, i, Array.sub (a, 0))
   354         val _ = Array.update (a, i, Array.sub (a, 0))
   361         val _ = trickleup (bubble i 0) e
   355         val _ = trickleup (bubble i 0) e
   362       in
   356       in
   385   syms = symbols of the conjecture
   379   syms = symbols of the conjecture
   386 *)
   380 *)
   387 fun knn avail_num adv_max get_deps get_sym_ths knns advno syms =
   381 fun knn avail_num adv_max get_deps get_sym_ths knns advno syms =
   388   let
   382   let
   389     (* Can be later used for TFIDF *)
   383     (* Can be later used for TFIDF *)
   390     fun sym_wght _ = 1.0;
   384     fun sym_wght _ = 1.0
   391     val overlaps_sqr = Array.tabulate (avail_num, (fn i => (i, 0.0)));
   385 
       
   386     val overlaps_sqr = Array.tabulate (avail_num, (fn i => (i, 0.0)))
       
   387 
   392     fun inc_overlap j v =
   388     fun inc_overlap j v =
   393       let
   389       let
   394         val ov = snd (Array.sub (overlaps_sqr,j))
   390         val ov = snd (Array.sub (overlaps_sqr, j))
   395       in
   391       in
   396         Array.update (overlaps_sqr, j, (j, v + ov))
   392         Array.update (overlaps_sqr, j, (j, v + ov))
   397       end;
   393       end
       
   394 
   398     fun do_sym (s, con_wght) =
   395     fun do_sym (s, con_wght) =
   399       let
   396       let
   400         val sw = sym_wght s;
   397         val sw = sym_wght s
   401         val w2 = sw * sw * con_wght;
   398         val w2 = sw * sw * con_wght
       
   399 
   402         fun do_th (j, prem_wght) = if j < avail_num then inc_overlap j (w2 * prem_wght) else ()
   400         fun do_th (j, prem_wght) = if j < avail_num then inc_overlap j (w2 * prem_wght) else ()
   403       in
   401       in
   404         ignore (map do_th (get_sym_ths s))
   402         List.app do_th (get_sym_ths s)
   405       end;
   403       end
   406     val () = ignore (map do_sym syms);
   404 
   407     val () = heap (fn (a, b) => Real.compare (snd a, snd b)) knns overlaps_sqr;
   405     val _ = List.app do_sym syms
   408     val recommends = Array.tabulate (adv_max, (fn j => (j, 0.0)));
   406     val _ = heap (fn (a, b) => Real.compare (snd a, snd b)) knns overlaps_sqr
       
   407     val recommends = Array.tabulate (adv_max, rpair 0.0)
       
   408 
   409     fun inc_recommend j v =
   409     fun inc_recommend j v =
   410       if j >= adv_max then () else
   410       if j >= adv_max then ()
   411       let
   411       else Array.update (recommends, j, (j, v + snd (Array.sub (recommends, j))))
   412         val ov = snd (Array.sub (recommends,j))
   412 
   413       in
       
   414         Array.update (recommends, j, (j, v + ov))
       
   415       end;
       
   416     fun for k =
   413     fun for k =
   417       if k = knns then () else
   414       if k = knns orelse k >= adv_max then
   418       if k >= adv_max then () else
   415         ()
   419       let
   416       else
   420         val (j, o2) = Array.sub (overlaps_sqr, avail_num - k - 1);
   417         let
   421         val o1 = Math.sqrt o2;
   418           val (j, o2) = Array.sub (overlaps_sqr, avail_num - k - 1)
   422         val () = inc_recommend j o1;
   419           val o1 = Math.sqrt o2
   423         val ds = get_deps j;
   420           val _ = inc_recommend j o1
   424         val l = Real.fromInt (length ds);
   421           val ds = get_deps j
   425         val _ = map (fn d => inc_recommend d (o1 / l)) ds
   422           val l = Real.fromInt (length ds)
   426       in
   423           val _ = map (fn d => inc_recommend d (o1 / l)) ds
   427         for (k + 1)
   424         in
   428       end;
   425           for (k + 1)
   429     val () = for 0;
   426         end
   430     val () = heap (fn (a, b) => Real.compare (snd a, snd b)) advno recommends;
   427 
       
   428     val _ = for 0
       
   429     val _ = heap (Real.compare o pairself snd) advno recommends
       
   430 
   431     fun ret acc at =
   431     fun ret acc at =
   432       if at = Array.length recommends then acc else ret (Array.sub (recommends,at) :: acc) (at + 1)
   432       if at = Array.length recommends then acc else ret (Array.sub (recommends, at) :: acc) (at + 1)
   433   in
   433   in
   434     ret [] (max 0 (adv_max - advno))
   434     ret [] (Integer.max 0 (adv_max - advno))
   435   end
   435   end
   436 
   436 
   437 val knns = 40 (* FUDGE *)
   437 val knns = 40 (* FUDGE *)
   438 
   438 
   439 fun add_to_xtab key (next, tab, keys) = (next + 1, Symtab.update_new (key, next) tab, key :: keys)
   439 fun add_to_xtab key (next, tab, keys) = (next + 1, Symtab.update_new (key, next) tab, key :: keys)
   440 
   440 
   441 fun map_array_at ary f i = Array.update (ary, i, f (Array.sub (ary, i)))
   441 fun map_array_at ary f i = Array.update (ary, i, f (Array.sub (ary, i)))
   442 
   442 
   443 fun learn_and_query ctxt parents access_G max_suggs hints feats =
   443 fun query ctxt parents access_G max_suggs hints feats =
   444   let
   444   let
   445     val str_of_feat = space_implode "|"
   445     val str_of_feat = space_implode "|"
   446 
   446 
   447     val visible_facts = Graph.all_preds access_G parents
   447     val visible_facts = Graph.all_preds access_G parents
   448     val visible_fact_set = Symtab.make_set visible_facts
   448     val visible_fact_set = Symtab.make_set visible_facts
   467              add_to_xtab fact fact_xtab, feat_xtab')
   467              add_to_xtab fact fact_xtab, feat_xtab')
   468           end)
   468           end)
   469         all_nodes ([], [], (0, Symtab.empty, []), (0, Symtab.empty, []))
   469         all_nodes ([], [], (0, Symtab.empty, []), (0, Symtab.empty, []))
   470 
   470 
   471     val facts = rev rev_facts
   471     val facts = rev rev_facts
   472     val fact_ary = Array.fromList facts
   472     val fact_vec = Vector.fromList facts
   473 
   473 
   474     val deps_ary = Array.fromList (rev rev_depss)
   474     val deps_vec = Vector.fromList (rev rev_depss)
   475     val facts_ary = Array.array (num_feats, [])
   475     val facts_ary = Array.array (num_feats, [])
   476     val _ =
   476     val _ =
   477       fold (fn feats => fn fact =>
   477       fold (fn feats => fn fact =>
   478           let val fact' = fact - 1 in
   478           let val fact' = fact - 1 in
   479             List.app (fn (feat, weight) => map_array_at facts_ary (cons (fact', weight)) feat)
   479             List.app (fn (feat, weight) => map_array_at facts_ary (cons (fact', weight)) feat)
   482           end)
   482           end)
   483         featss (length featss)
   483         featss (length featss)
   484   in
   484   in
   485     trace_msg ctxt (fn () => "MaSh_SML query " ^ encode_features feats ^ " from {" ^
   485     trace_msg ctxt (fn () => "MaSh_SML query " ^ encode_features feats ^ " from {" ^
   486       elide_string 1000 (space_implode " " facts) ^ "}");
   486       elide_string 1000 (space_implode " " facts) ^ "}");
   487     knn (Array.length deps_ary) (length visible_facts) (curry Array.sub deps_ary)
   487     knn (Vector.length deps_vec) (length visible_facts) (curry Vector.sub deps_vec)
   488       (curry Array.sub facts_ary) knns max_suggs
   488       (curry Array.sub facts_ary) knns max_suggs
   489       (map_filter (fn (feat, weight) =>
   489       (map_filter (fn (feat, weight) =>
   490          Option.map (rpair weight) (Symtab.lookup feat_tab (str_of_feat feat))) feats)
   490          Option.map (rpair weight) (Symtab.lookup feat_tab (str_of_feat feat))) feats)
   491     |> map ((fn i => Array.sub (fact_ary, i)) o fst)
   491     |> map (curry Vector.sub fact_vec o fst)
   492   end
   492   end
   493 
   493 
   494 end;
   494 end;
   495 
   495 
   496 
   496 
   623 
   623 
   624 fun peek_state ctxt overlord f =
   624 fun peek_state ctxt overlord f =
   625   Synchronized.change_result global_state (perhaps (try (load_state ctxt overlord)) #> `snd #>> f)
   625   Synchronized.change_result global_state (perhaps (try (load_state ctxt overlord)) #> `snd #>> f)
   626 
   626 
   627 fun clear_state ctxt overlord =
   627 fun clear_state ctxt overlord =
   628   (* "unlearn" also removes the state file *)
   628   (* "MaSh_Py.unlearn" also removes the state file *)
   629   Synchronized.change global_state (fn _ =>
   629   Synchronized.change global_state (fn _ =>
   630     (if Config.get ctxt sml then wipe_out_mash_state_dir ()
   630     (if Config.get ctxt sml then wipe_out_mash_state_dir ()
   631      else MaSh_Py.unlearn ctxt overlord; (false, empty_state)))
   631      else MaSh_Py.unlearn ctxt overlord; (false, empty_state)))
   632 
   632 
   633 end
   633 end
   969         SOME (pref', _) => pref = pref'
   969         SOME (pref', _) => pref = pref'
   970       | NONE => false)
   970       | NONE => false)
   971     | NONE => false)
   971     | NONE => false)
   972   | is_size_def _ _ = false
   972   | is_size_def _ _ = false
   973 
   973 
   974 fun no_dependencies_for_status status =
       
   975   status = Non_Rec_Def orelse status = Rec_Def
       
   976 
       
   977 fun trim_dependencies deps =
   974 fun trim_dependencies deps =
   978   if length deps > max_dependencies then NONE else SOME deps
   975   if length deps > max_dependencies then NONE else SOME deps
   979 
   976 
   980 fun isar_dependencies_of name_tabs th =
   977 fun isar_dependencies_of name_tabs th =
   981   let val deps = thms_in_proof (SOME name_tabs) th in
   978   let val deps = thms_in_proof (SOME name_tabs) th in
  1020         |> fold (add_isar_dep facts) isar_deps
  1017         |> fold (add_isar_dep facts) isar_deps
  1021         |> map nickify
  1018         |> map nickify
  1022       val num_isar_deps = length isar_deps
  1019       val num_isar_deps = length isar_deps
  1023     in
  1020     in
  1024       if verbose andalso auto_level = 0 then
  1021       if verbose andalso auto_level = 0 then
  1025         "MaSh: " ^ quote prover ^ " on " ^ quote name ^ " with " ^ string_of_int num_isar_deps ^
  1022         Output.urgent_message ("MaSh: " ^ quote prover ^ " on " ^ quote name ^ " with " ^
  1026         " + " ^ string_of_int (length facts - num_isar_deps) ^ " facts."
  1023           string_of_int num_isar_deps ^ " + " ^ string_of_int (length facts - num_isar_deps) ^
  1027         |> Output.urgent_message
  1024           " facts.")
  1028       else
  1025       else
  1029         ();
  1026         ();
  1030       (case run_prover_for_mash ctxt params prover name facts goal of
  1027       (case run_prover_for_mash ctxt params prover name facts goal of
  1031         {outcome = NONE, used_facts, ...} =>
  1028         {outcome = NONE, used_facts, ...} =>
  1032         (if verbose andalso auto_level = 0 then
  1029         (if verbose andalso auto_level = 0 then
  1033            let val num_facts = length used_facts in
  1030            let val num_facts = length used_facts in
  1034              "Found proof with " ^ string_of_int num_facts ^ " fact" ^
  1031              Output.urgent_message ("Found proof with " ^ string_of_int num_facts ^ " fact" ^
  1035              plural_s num_facts ^ "."
  1032                plural_s num_facts ^ ".")
  1036              |> Output.urgent_message
       
  1037            end
  1033            end
  1038          else
  1034          else
  1039            ();
  1035            ();
  1040          (true, map fst used_facts))
  1036          (true, map fst used_facts))
  1041       | _ => (false, isar_deps))
  1037       | _ => (false, isar_deps))
  1185     fun chained_or_extra_features_of factor (((_, stature), th), weight) =
  1181     fun chained_or_extra_features_of factor (((_, stature), th), weight) =
  1186       [prop_of th]
  1182       [prop_of th]
  1187       |> features_of ctxt (theory_of_thm th) num_facts const_tab stature false
  1183       |> features_of ctxt (theory_of_thm th) num_facts const_tab stature false
  1188       |> map (apsnd (fn r => weight * factor * r))
  1184       |> map (apsnd (fn r => weight * factor * r))
  1189 
  1185 
  1190     val (access_G, suggs) =
  1186     fun query_args access_G =
       
  1187       let
       
  1188         val parents = maximal_wrt_access_graph access_G facts
       
  1189         val hints = chained
       
  1190           |> filter (is_fact_in_graph access_G o snd)
       
  1191           |> map (nickname_of_thm o snd)
       
  1192 
       
  1193         val goal_feats =
       
  1194           features_of ctxt thy num_facts const_tab (Local, General) true (concl_t :: hyp_ts)
       
  1195         val chained_feats = chained
       
  1196           |> map (rpair 1.0)
       
  1197           |> map (chained_or_extra_features_of chained_feature_factor)
       
  1198           |> rpair [] |-> fold (union (eq_fst (op =)))
       
  1199         val extra_feats = facts
       
  1200           |> take (Int.max (0, num_extra_feature_facts - length chained))
       
  1201           |> filter fact_has_right_theory
       
  1202           |> weight_facts_steeply
       
  1203           |> map (chained_or_extra_features_of extra_feature_factor)
       
  1204           |> rpair [] |-> fold (union (eq_fst (op =)))
       
  1205         val feats = fold (union (eq_fst (op =))) [chained_feats, extra_feats] goal_feats
       
  1206           |> debug ? sort (Real.compare o swap o pairself snd)
       
  1207       in
       
  1208         (parents, hints, feats)
       
  1209       end
       
  1210 
       
  1211     val sml = Config.get ctxt sml
       
  1212 
       
  1213     val (access_G, py_suggs) =
  1191       peek_state ctxt overlord (fn {access_G, ...} =>
  1214       peek_state ctxt overlord (fn {access_G, ...} =>
  1192         if Graph.is_empty access_G then
  1215         if Graph.is_empty access_G then
  1193           (trace_msg ctxt (K "Nothing has been learned yet"); (access_G, []))
  1216           (trace_msg ctxt (K "Nothing has been learned yet"); (access_G, []))
  1194         else
  1217         else
  1195           let
  1218           (access_G,
  1196             val parents = maximal_wrt_access_graph access_G facts
  1219            if sml then
  1197             val goal_feats =
  1220              []
  1198               features_of ctxt thy num_facts const_tab (Local, General) true (concl_t :: hyp_ts)
  1221            else
  1199             val chained_feats = chained
  1222              let val (parents, hints, feats) = query_args access_G in
  1200               |> map (rpair 1.0)
  1223                MaSh_Py.query ctxt overlord max_facts ([], hints, parents, feats)
  1201               |> map (chained_or_extra_features_of chained_feature_factor)
  1224              end))
  1202               |> rpair [] |-> fold (union (eq_fst (op =)))
  1225 
  1203             val extra_feats = facts
  1226     val sml_suggs =
  1204               |> take (Int.max (0, num_extra_feature_facts - length chained))
  1227       if sml then
  1205               |> filter fact_has_right_theory
  1228         let val (parents, hints, feats) = query_args access_G in
  1206               |> weight_facts_steeply
  1229           MaSh_SML.query ctxt parents access_G max_facts hints feats
  1207               |> map (chained_or_extra_features_of extra_feature_factor)
  1230         end
  1208               |> rpair [] |-> fold (union (eq_fst (op =)))
  1231       else
  1209             val feats = fold (union (eq_fst (op =))) [chained_feats, extra_feats] goal_feats
  1232         []
  1210               |> debug ? sort (Real.compare o swap o pairself snd)
  1233 
  1211             val hints = chained
       
  1212               |> filter (is_fact_in_graph access_G o snd)
       
  1213               |> map (nickname_of_thm o snd)
       
  1214           in
       
  1215             (access_G,
       
  1216              if Config.get ctxt sml then
       
  1217                MaSh_SML.learn_and_query ctxt parents access_G max_facts hints feats
       
  1218              else
       
  1219                MaSh_Py.query ctxt overlord max_facts ([], hints, parents, feats))
       
  1220           end)
       
  1221     val unknown = filter_out (is_fact_in_graph access_G o snd) facts
  1234     val unknown = filter_out (is_fact_in_graph access_G o snd) facts
  1222   in
  1235   in
  1223     find_mash_suggestions ctxt max_facts suggs facts chained unknown
  1236     find_mash_suggestions ctxt max_facts (py_suggs @ sml_suggs) facts chained unknown
  1224     |> pairself (map fact_of_raw_fact)
  1237     |> pairself (map fact_of_raw_fact)
  1225   end
  1238   end
  1226 
  1239 
  1227 fun learn_wrt_access_graph ctxt (name, parents, feats, deps) (learns, G) =
  1240 fun learn_wrt_access_graph ctxt (name, parents, feats, deps) (learns, G) =
  1228   let
  1241   let
  1321     else
  1334     else
  1322       let
  1335       let
  1323         val name_tabs = build_name_tables nickname_of_thm facts
  1336         val name_tabs = build_name_tables nickname_of_thm facts
  1324 
  1337 
  1325         fun deps_of status th =
  1338         fun deps_of status th =
  1326           if no_dependencies_for_status status then
  1339           if status = Non_Rec_Def orelse status = Rec_Def then
  1327             SOME []
  1340             SOME []
  1328           else if run_prover then
  1341           else if run_prover then
  1329             prover_dependencies_of ctxt params prover auto_level facts name_tabs th
  1342             prover_dependencies_of ctxt params prover auto_level facts name_tabs th
  1330             |> (fn (false, _) => NONE | (true, deps) => trim_dependencies deps)
  1343             |> (fn (false, _) => NONE | (true, deps) => trim_dependencies deps)
  1331           else
  1344           else
  1353                  MaSh_Py.relearn ctxt overlord save relearns);
  1366                  MaSh_Py.relearn ctxt overlord save relearns);
  1354               {access_G = access_G, num_known_facts = num_known_facts, dirty = dirty}
  1367               {access_G = access_G, num_known_facts = num_known_facts, dirty = dirty}
  1355             end
  1368             end
  1356 
  1369 
  1357         fun commit last learns relearns flops =
  1370         fun commit last learns relearns flops =
  1358           (if debug andalso auto_level = 0 then
  1371           (if debug andalso auto_level = 0 then Output.urgent_message "Committing..." else ();
  1359              Output.urgent_message "Committing..."
       
  1360            else
       
  1361              ();
       
  1362            map_state ctxt overlord (do_commit (rev learns) relearns flops);
  1372            map_state ctxt overlord (do_commit (rev learns) relearns flops);
  1363            if not last andalso auto_level = 0 then
  1373            if not last andalso auto_level = 0 then
  1364              let val num_proofs = length learns + length relearns in
  1374              let val num_proofs = length learns + length relearns in
  1365                "Learned " ^ string_of_int num_proofs ^ " " ^
  1375                Output.urgent_message ("Learned " ^ string_of_int num_proofs ^ " " ^
  1366                (if run_prover then "automatic" else "Isar") ^ " proof" ^
  1376                  (if run_prover then "automatic" else "Isar") ^ " proof" ^
  1367                plural_s num_proofs ^ " in the last " ^
  1377                  plural_s num_proofs ^ " in the last " ^ string_of_time commit_timeout ^ ".")
  1368                string_of_time commit_timeout ^ "."
       
  1369                |> Output.urgent_message
       
  1370              end
  1378              end
  1371            else
  1379            else
  1372              ())
  1380              ())
  1373 
  1381 
  1374         fun learn_new_fact _ (accum as (_, (_, _, true))) = accum
  1382         fun learn_new_fact _ (accum as (_, (_, _, true))) = accum
  1476     fun learn auto_level run_prover =
  1484     fun learn auto_level run_prover =
  1477       mash_learn_facts ctxt params prover true auto_level run_prover one_year facts
  1485       mash_learn_facts ctxt params prover true auto_level run_prover one_year facts
  1478       |> Output.urgent_message
  1486       |> Output.urgent_message
  1479   in
  1487   in
  1480     if run_prover then
  1488     if run_prover then
  1481       ("MaShing through " ^ string_of_int num_facts ^ " fact" ^ plural_s num_facts ^
  1489       (Output.urgent_message ("MaShing through " ^ string_of_int num_facts ^ " fact" ^
  1482        " for automatic proofs (" ^ quote prover ^ " timeout: " ^ string_of_time timeout ^
  1490          plural_s num_facts ^ " for automatic proofs (" ^ quote prover ^ " timeout: " ^
  1483        ").\n\nCollecting Isar proofs first..."
  1491          string_of_time timeout ^ ").\n\nCollecting Isar proofs first...");
  1484        |> Output.urgent_message;
       
  1485        learn 1 false;
  1492        learn 1 false;
  1486        "Now collecting automatic proofs. This may take several hours. You can safely stop the \
  1493        Output.urgent_message "Now collecting automatic proofs. This may take several hours. You \
  1487        \learning process at any point."
  1494          \can safely stop the learning process at any point.";
  1488        |> Output.urgent_message;
       
  1489        learn 0 true)
  1495        learn 0 true)
  1490     else
  1496     else
  1491       (Output.urgent_message ("MaShing through " ^ string_of_int num_facts ^ " fact" ^
  1497       (Output.urgent_message ("MaShing through " ^ string_of_int num_facts ^ " fact" ^
  1492          plural_s num_facts ^ " for Isar proofs...");
  1498          plural_s num_facts ^ " for Isar proofs...");
  1493        learn 0 false)
  1499        learn 0 false)