src/HOL/Tools/Sledgehammer/sledgehammer_mash.ML
changeset 57103 c9e400a05c9e
parent 57102 3e6af473d666
child 57104 b93e0680a5b3
equal deleted inserted replaced
57102:3e6af473d666 57103:c9e400a05c9e
   501 
   501 
   502 fun add_to_xtab key (next, tab, keys) = (next + 1, Symtab.update_new (key, next) tab, key :: keys)
   502 fun add_to_xtab key (next, tab, keys) = (next + 1, Symtab.update_new (key, next) tab, key :: keys)
   503 
   503 
   504 fun map_array_at ary f i = Array.update (ary, i, f (Array.sub (ary, i)))
   504 fun map_array_at ary f i = Array.update (ary, i, f (Array.sub (ary, i)))
   505 
   505 
   506 fun query ctxt engine parents access_G max_suggs hints feats =
   506 fun query ctxt engine visible_facts max_suggs (learns, hints, parents, feats) =
   507   let
   507   let
   508     val visible_facts = Graph.all_preds access_G parents
       
   509     val visible_fact_set = Symtab.make_set visible_facts
   508     val visible_fact_set = Symtab.make_set visible_facts
   510 
   509 
   511     val all_nodes =
   510     val learns' =
   512       (Graph.schedule (fn _ => fn (fact, (_, feats, deps)) => (fact, feats, deps)) access_G
   511       (learns |> List.partition (Symtab.defined visible_fact_set o #1) |> op @) @
   513        |> List.partition (Symtab.defined visible_fact_set o #1) |> op @) @
       
   514       (if null hints then [] else [(".goal", feats, hints)])
   512       (if null hints then [] else [(".goal", feats, hints)])
   515 
   513 
   516     val (rev_depss, rev_featss, (num_facts, _, rev_facts), (num_feats, feat_tab, _)) =
   514     val (rev_depss, rev_featss, (num_facts, _, rev_facts), (num_feats, feat_tab, _)) =
   517       fold (fn (fact, feats, deps) =>
   515       fold (fn (fact, feats, deps) =>
   518             fn (rev_depss, rev_featss, fact_xtab as (_, fact_tab, _), feat_xtab) =>
   516             fn (rev_depss, rev_featss, fact_xtab as (_, fact_tab, _), feat_xtab) =>
   525             val (feats', feat_xtab') = fold_map add_feat feats feat_xtab
   523             val (feats', feat_xtab') = fold_map add_feat feats feat_xtab
   526           in
   524           in
   527             (map_filter (Symtab.lookup fact_tab) deps :: rev_depss, feats' :: rev_featss,
   525             (map_filter (Symtab.lookup fact_tab) deps :: rev_depss, feats' :: rev_featss,
   528              add_to_xtab fact fact_xtab, feat_xtab')
   526              add_to_xtab fact fact_xtab, feat_xtab')
   529           end)
   527           end)
   530         all_nodes ([], [], (0, Symtab.empty, []), (0, Symtab.empty, []))
   528         learns' ([], [], (0, Symtab.empty, []), (0, Symtab.empty, []))
   531 
   529 
   532     val facts = rev rev_facts
   530     val facts = rev rev_facts
   533     val fact_vec = Vector.fromList facts
   531     val fact_vec = Vector.fromList facts
   534 
   532 
   535     val deps_vec = Vector.fromList (rev rev_depss)
   533     val deps_vec = Vector.fromList (rev rev_depss)
  1267 
  1265 
  1268     val sml_suggs =
  1266     val sml_suggs =
  1269       if engine = MaSh_Py then
  1267       if engine = MaSh_Py then
  1270         []
  1268         []
  1271       else
  1269       else
  1272         let val (parents, hints, feats) = query_args access_G in
  1270         let
  1273           MaSh_SML.query ctxt engine parents access_G max_facts hints feats
  1271           val (parents, hints, feats) = query_args access_G
       
  1272           val visible_facts = Graph.all_preds access_G parents
       
  1273           val learns =
       
  1274             Graph.schedule (fn _ => fn (fact, (_, feats, deps)) => (fact, feats, deps)) access_G
       
  1275         in
       
  1276           MaSh_SML.query ctxt engine visible_facts max_facts (learns, hints, parents, feats)
  1274         end
  1277         end
  1275 
  1278 
  1276     val unknown = filter_out (is_fact_in_graph access_G o snd) facts
  1279     val unknown = filter_out (is_fact_in_graph access_G o snd) facts
  1277   in
  1280   in
  1278     find_mash_suggestions ctxt max_facts (py_suggs @ sml_suggs) facts chained unknown
  1281     find_mash_suggestions ctxt max_facts (py_suggs @ sml_suggs) facts chained unknown