src/HOL/Mirabelle/Actions/mirabelle_sledgehammer_filter.ML
changeset 47477 3fabf352243e
parent 45706 418846ea4f99
equal deleted inserted replaced
47476:92d1c566ebbf 47477:3fabf352243e
       
     1 (*  Title:      HOL/Mirabelle/Actions/mirabelle_sledgehammer_filter.ML
       
     2     Author:     Jasmin Blanchette, TU Munich
       
     3 *)
       
     4 
       
     (* Mirabelle action that replays Sledgehammer's relevance filter at each
        recorded goal position and reports how many of the facts used by known
        proofs (read from a proof file) survive the filter.  Only what the
        MIRABELLE_ACTION signature exposes is visible outside this structure. *)
     structure Mirabelle_Sledgehammer_Filter : MIRABELLE_ACTION =
     struct

     (* Looks up the real-valued option "name" in the association list "args".
        Returns "default_value" when the option is absent.  A value that does
        not parse as a real is reported via "error" (naming the option) rather
        than surfacing as a bare Option exception. *)
     fun get_real args name default_value =
       case AList.lookup (op =) args name of
         NONE => default_value
       | SOME value =>
           (case Real.fromString value of
              SOME r => r
            | NONE =>
                error ("Bad value for option \"" ^ name ^ "\": \"" ^ value ^
                       "\""))

     (* Rebuilds a relevance-fudge record: every field is taken from the option
        of the same name in "args" if present, otherwise from the record passed
        as second argument. *)
     fun extract_relevance_fudge args
           {local_const_multiplier, worse_irrel_freq, higher_order_irrel_weight,
            abs_rel_weight, abs_irrel_weight, skolem_irrel_weight,
            theory_const_rel_weight, theory_const_irrel_weight,
            chained_const_irrel_weight, intro_bonus, elim_bonus, simp_bonus,
            local_bonus, assum_bonus, chained_bonus, max_imperfect,
            max_imperfect_exp, threshold_divisor, ridiculous_threshold} =
       let
         fun opt name default = get_real args name default
       in
         {local_const_multiplier =
            opt "local_const_multiplier" local_const_multiplier,
          worse_irrel_freq = opt "worse_irrel_freq" worse_irrel_freq,
          higher_order_irrel_weight =
            opt "higher_order_irrel_weight" higher_order_irrel_weight,
          abs_rel_weight = opt "abs_rel_weight" abs_rel_weight,
          abs_irrel_weight = opt "abs_irrel_weight" abs_irrel_weight,
          skolem_irrel_weight = opt "skolem_irrel_weight" skolem_irrel_weight,
          theory_const_rel_weight =
            opt "theory_const_rel_weight" theory_const_rel_weight,
          theory_const_irrel_weight =
            opt "theory_const_irrel_weight" theory_const_irrel_weight,
          chained_const_irrel_weight =
            opt "chained_const_irrel_weight" chained_const_irrel_weight,
          intro_bonus = opt "intro_bonus" intro_bonus,
          elim_bonus = opt "elim_bonus" elim_bonus,
          simp_bonus = opt "simp_bonus" simp_bonus,
          local_bonus = opt "local_bonus" local_bonus,
          assum_bonus = opt "assum_bonus" assum_bonus,
          chained_bonus = opt "chained_bonus" chained_bonus,
          max_imperfect = opt "max_imperfect" max_imperfect,
          max_imperfect_exp = opt "max_imperfect_exp" max_imperfect_exp,
          threshold_divisor = opt "threshold_divisor" threshold_divisor,
          ridiculous_threshold =
            opt "ridiculous_threshold" ridiculous_threshold}
       end

     (* Table keyed by (line number, offset) pairs. *)
     structure Prooftab =
       Table(type key = int * int val ord = prod_ord int_ord int_ord)

     (* Known proofs per position; each proof is a list of fact names.
        Filled in by "init". *)
     val proof_table =
       Unsynchronized.ref (Prooftab.empty: string list list Prooftab.table)

     (* Counters, each an association list from an id to a running total. *)
     val num_successes = Unsynchronized.ref ([] : (int * int) list)
     val num_failures = Unsynchronized.ref ([] : (int * int) list)
     val num_found_proofs = Unsynchronized.ref ([] : (int * int) list)
     val num_lost_proofs = Unsynchronized.ref ([] : (int * int) list)
     val num_found_facts = Unsynchronized.ref ([] : (int * int) list)
     val num_lost_facts = Unsynchronized.ref ([] : (int * int) list)

     (* Current total stored for "id" in counter "c"; 0 if there is none. *)
     fun get_count id c = the_default 0 (AList.lookup (op =) (!c) id)

     (* Increases the total stored for "id" in counter "c" by "n". *)
     fun add id c n = c := AList.update (op =) (id, get_count id c + n) (!c)

     (* Reads "proof_file" into "proof_table" and returns "thy" unchanged.
        A line of the form "<line>:<offset>:<facts separated by spaces>"
        contributes one proof for position (<line>, <offset>); lines that do
        not split into exactly three ":"-separated fields are skipped.  A
        three-field line whose line number or offset is not an integer is
        reported via "error". *)
     fun init proof_file _ thy =
       let
         fun parse_int what line s =
           case Int.fromString s of
             SOME i => i
           | NONE =>
               error ("Bad " ^ what ^ " in proof file line: \"" ^ line ^ "\"")
         fun do_line line =
           case space_explode ":" line of
             [line_num, offset, proof] =>
               SOME ((parse_int "line number" line line_num,
                      parse_int "offset" line offset),
                     (* consecutive spaces yield empty names; drop them *)
                     filter_out (curry (op =) "") (space_explode " " proof))
           | _ => NONE
         val proof_tab =
           File.read (Path.explode proof_file)
           |> space_explode "\n"
           |> map_filter do_line
           (* several lines may refer to the same position: group them *)
           |> AList.coalesce (op =)
           |> Prooftab.make
       in proof_table := proof_tab; thy end

     (* "a" as an integer percentage of "b" (rounded down); "N/A" if b = 0. *)
     fun percentage a b =
       if b = 0 then "N/A" else string_of_int (a * 100 div b)

     (* "a" as a percentage of the total "a + b". *)
     fun percentage_alt a b = percentage a (a + b)

     (* Logs the accumulated statistics for "id", provided that at least one
        success or failure was recorded for it. *)
     fun done id ({log, ...} : Mirabelle.done_args) =
       let
         fun count c = get_count id c
         fun log_count label c = log (label ^ string_of_int (count c))
         fun log_rate label found lost =
           log (label ^ percentage_alt (count found) (count lost) ^ "%")
       in
         if count num_successes + count num_failures > 0 then
           (log "";
            log_count "Number of overall successes: " num_successes;
            log_count "Number of overall failures: " num_failures;
            log_rate "Overall success rate: " num_successes num_failures;
            log_count "Number of found proofs: " num_found_proofs;
            log_count "Number of lost proofs: " num_lost_proofs;
            log_rate "Proof found rate: " num_found_proofs num_lost_proofs;
            log_count "Number of found facts: " num_found_facts;
            log_count "Number of lost facts: " num_lost_facts;
            log_rate "Fact found rate: " num_found_facts num_lost_facts)
         else
           ()
       end

     val default_prover = ATP_Systems.eN (* arbitrary ATP *)

     (* Renders an (index, name) pair as "name@index". *)
     fun with_index (i, s) = s ^ "@" ^ string_of_int i

     (* Runs the relevance filter on the first subgoal of the proof state
        "pre" and compares the selected facts with the proofs recorded in
        "proof_table" for the current position.  Updates the counters for "id"
        and logs the outcome.  Does nothing if the position has no line number
        or offset. *)
     fun action args id ({pre, pos, log, ...} : Mirabelle.run_args) =
       case (Position.line_of pos, Position.offset_of pos) of
         (SOME line_num, SOME offset) =>
         (case Prooftab.lookup (!proof_table) (line_num, offset) of
            NONE => log "No known proof"
          | SOME proofs =>
            let
              val {context = ctxt, facts = chained_ths, goal} = Proof.goal pre
              val prover =
                the_default default_prover (AList.lookup (op =) args "prover")
              val {relevance_thresholds, max_relevant, slice, ...} =
                Sledgehammer_Isar.default_params ctxt args
              val default_max_relevant =
                Sledgehammer_Provers.default_max_relevant_for_prover ctxt slice
                  prover
              (* NOTE(review): the next two predicates are built for
                 "default_prover", whereas the neighbouring values use the
                 "prover" selected from "args" -- confirm this is intended. *)
              val is_appropriate_prop =
                Sledgehammer_Provers.is_appropriate_prop_for_prover ctxt
                  default_prover
              val is_built_in_const =
                Sledgehammer_Provers.is_built_in_const_for_prover ctxt
                  default_prover
              val relevance_fudge =
                extract_relevance_fudge args
                  (Sledgehammer_Provers.relevance_fudge_for_prover ctxt prover)
              val relevance_override = {add = [], del = [], only = false}
              val subgoal = 1
              val (_, hyp_ts, concl_t) =
                ATP_Util.strip_subgoal ctxt goal subgoal
              val ho_atp = Sledgehammer_Provers.is_ho_atp ctxt prover
              (* names of the facts selected by the filter, in filter order *)
              val facts =
                Sledgehammer_Filter.nearly_all_facts ctxt ho_atp
                  relevance_override chained_ths hyp_ts concl_t
                |> filter (is_appropriate_prop o prop_of o snd)
                |> Sledgehammer_Filter.relevant_facts ctxt relevance_thresholds
                     (the_default default_max_relevant max_relevant)
                     is_built_in_const relevance_fudge relevance_override
                     chained_ths hyp_ts concl_t
                |> map (fst o fst)
              val num_facts = length facts
              (* Every distinct fact occurring in some known proof is paired
                 with its position in "facts" (negative if absent).  Found
                 facts are kept as (position, name) pairs sorted by position,
                 lost facts as plain names. *)
              val (found_facts, lost_facts) =
                sort_distinct string_ord (flat proofs)
                |> map (fn fact => (find_index (curry (op =) fact) facts, fact))
                |> List.partition (fn (i, _) => 0 <= i)
                |>> sort (prod_ord int_ord string_ord)
                ||> map snd
              (* a proof is found iff all of its facts were selected *)
              val found_proofs = filter (forall (member (op =) facts)) proofs
              val n = length found_proofs
              val _ =
                if n = 0 then
                  (add id num_failures 1; log "Failure")
                else
                  (add id num_successes 1;
                   add id num_found_proofs n;
                   log ("Success (" ^ string_of_int n ^ " of " ^
                        string_of_int (length proofs) ^ " proofs)"))
              val _ = add id num_lost_proofs (length proofs - n)
              val _ = add id num_found_facts (length found_facts)
              val _ = add id num_lost_facts (length lost_facts)
              val _ =
                if null found_facts then
                  ()
                else
                  let
                    (* weight = root mean square of the found facts'
                       positions, rounded up *)
                    val sum_sq =
                      fold (fn (i, _) => Integer.add (i * i)) found_facts 0
                    val mean_sq =
                      Real.fromInt sum_sq / Real.fromInt (length found_facts)
                    val found_weight = Real.ceil (Math.sqrt mean_sq)
                  in
                    log ("Found facts (among " ^ string_of_int num_facts ^
                         ", weight " ^ string_of_int found_weight ^ "): " ^
                         commas (map with_index found_facts))
                  end
              val _ =
                if null lost_facts then
                  ()
                else
                  log ("Lost facts (among " ^ string_of_int num_facts ^
                       "): " ^ commas lost_facts)
            in () end)
       | _ => ()

     val proof_fileK = "proof_file"

     (* Registers the action with Mirabelle.  The "proof_file" option is
        required (the first occurrence is used) and is removed from the
        arguments handed on to "action". *)
     fun invoke args =
       let
         val (pf_args, other_args) =
           List.partition (fn (key, _) => key = proof_fileK) args
         val proof_file =
           case pf_args of
             (_, s) :: _ => s
           | [] => error "No \"proof_file\" specified"
       in Mirabelle.register (init proof_file, action other_args, done) end

     end;
       
   196 
       
   197 (* Workaround to keep the "mirabelle.pl" script happy: makes the action structure available under a second name ending in lowercase "filter" *)
       
   198 structure Mirabelle_Sledgehammer_filter = Mirabelle_Sledgehammer_Filter;