src/HOL/Eisbach/match_method.ML
changeset 60285 b4f1a0a701ae
parent 60248 f7e4294216d2
child 60287 adde5ce1e0a7
equal deleted inserted replaced
60284:014b86186c49 60285:b4f1a0a701ae
    38 
    38 
    39 datatype match_kind =
    39 datatype match_kind =
    40     Match_Term of term Item_Net.T
    40     Match_Term of term Item_Net.T
    41   | Match_Fact of thm Item_Net.T
    41   | Match_Fact of thm Item_Net.T
    42   | Match_Concl
    42   | Match_Concl
    43   | Match_Prems;
    43   | Match_Prems of bool;
    44 
    44 
    45 
    45 
    46 val aconv_net = Item_Net.init (op aconv) single;
    46 val aconv_net = Item_Net.init (op aconv) single;
    47 
    47 
    48 val parse_match_kind =
    48 val parse_match_kind =
    49   Scan.lift @{keyword "conclusion"} >> K Match_Concl ||
    49   Scan.lift @{keyword "conclusion"} >> K Match_Concl ||
    50   Scan.lift @{keyword "premises"} >> K Match_Prems ||
    50   Scan.lift (@{keyword "premises"} |-- Args.mode "local") >> Match_Prems ||
    51   Scan.lift (@{keyword "("}) |-- Args.term --| Scan.lift (@{keyword ")"}) >>
    51   Scan.lift (@{keyword "("}) |-- Args.term --| Scan.lift (@{keyword ")"}) >>
    52     (fn t => Match_Term (Item_Net.update t aconv_net)) ||
    52     (fn t => Match_Term (Item_Net.update t aconv_net)) ||
    53   Attrib.thms >> (fn thms => Match_Fact (fold Item_Net.update thms Thm.full_rules));
    53   Attrib.thms >> (fn thms => Match_Fact (fold Item_Net.update thms Thm.full_rules));
    54 
    54 
    55 
    55 
    56 fun nameable_match m = (case m of Match_Fact _ => true | Match_Prems => true | _ => false);
    56 fun nameable_match m = (case m of Match_Fact _ => true | Match_Prems _ => true | _ => false);
    57 fun prop_match m = (case m of Match_Term _ => false | _ => true);
    57 fun prop_match m = (case m of Match_Term _ => false | _ => true);
    58 
    58 
    59 val bound_term : (term, binding) Parse_Tools.parse_val parser =
    59 val bound_term : (term, binding) Parse_Tools.parse_val parser =
    60   Parse_Tools.parse_term_val Parse.binding;
    60   Parse_Tools.parse_term_val Parse.binding;
    61 
    61 
    64     Scan.option (@{keyword "::"} |-- Parse.!!! Parse.typ)
    64     Scan.option (@{keyword "::"} |-- Parse.!!! Parse.typ)
    65     >> (fn (xs, T) => map (fn (nm, pos) => ((nm, T), pos)) xs)) >> flat;
    65     >> (fn (xs, T) => map (fn (nm, pos) => ((nm, T), pos)) xs)) >> flat;
    66 
    66 
    67 val for_fixes = Scan.optional (@{keyword "for"} |-- fixes) [];
    67 val for_fixes = Scan.optional (@{keyword "for"} |-- fixes) [];
    68 
    68 
    69 fun pos_of dyn =
    69 fun pos_of dyn = Parse_Tools.the_parse_val dyn |> Binding.pos_of;
    70   (case dyn of
       
    71     Parse_Tools.Parse_Val (b, _) => Binding.pos_of b
       
    72   | _ => raise Fail "Not a parse value");
       
    73 
       
    74 
    70 
    75 (*FIXME: Dynamic facts modify the background theory, so we have to resort
    71 (*FIXME: Dynamic facts modify the background theory, so we have to resort
    76   to token replacement for matched facts. *)
    72   to token replacement for matched facts. *)
    77 fun dynamic_fact ctxt =
    73 fun dynamic_fact ctxt =
    78   bound_term -- Args.opt_attribs (Attrib.check_name ctxt);
    74   bound_term -- Args.opt_attribs (Attrib.check_name ctxt);
    79 
    75 
    80 type match_args = {multi : bool, cut : bool};
    76 type match_args = {multi : bool, cut : int};
    81 
    77 
    82 val parse_match_args =
    78 val parse_match_args =
    83   Scan.optional (Args.parens (Parse.enum1 ","
    79   Scan.optional (Args.parens (Parse.enum1 ","
    84     (Args.$$$ "multi" || Args.$$$ "cut"))) [] >>
    80     (Args.$$$ "multi" -- Scan.succeed ~1 || Args.$$$ "cut" -- Scan.optional Parse.int 1))) [] >>
    85     (fn ss =>
    81     (fn ss =>
    86       fold (fn s => fn {multi, cut} =>
    82       fold (fn s => fn {multi, cut} =>
    87         (case s of
    83         (case s of
    88          "multi" => {multi = true, cut = cut}
    84           ("multi", _) => {multi = true, cut = cut}
    89         | "cut" => {multi = multi, cut = true}))
    85         | ("cut", n) => {multi = multi, cut = n}))
    90       ss {multi = false, cut = false});
    86       ss {multi = false, cut = ~1});
    91 
    87 
    92 fun parse_named_pats match_kind =
    88 fun parse_named_pats match_kind =
    93   Args.context :|-- (fn ctxt =>
    89   Args.context :|-- (fn ctxt =>
    94     Scan.lift (Parse.and_list1 (Scan.option (dynamic_fact ctxt --| Args.colon) :--
    90     Scan.lift (Parse.and_list1 (Scan.option (dynamic_fact ctxt --| Args.colon) :--
    95       (fn opt_dyn =>
    91       (fn opt_dyn =>
   124           fun parse_term term =
   120           fun parse_term term =
   125             if prop_match match_kind
   121             if prop_match match_kind
   126             then Syntax.parse_prop ctxt3 term
   122             then Syntax.parse_prop ctxt3 term
   127             else Syntax.parse_term ctxt3 term;
   123             else Syntax.parse_term ctxt3 term;
   128 
   124 
       
   125           fun drop_Trueprop_dummy t =
       
   126             (case t of
       
   127               Const (@{const_name Trueprop}, _) $
       
   128                 (Const (@{syntax_const "_type_constraint_"}, T) $
       
   129                   Const (@{const_name Pure.dummy_pattern}, _)) =>
       
   130                     Const (@{syntax_const "_type_constraint_"}, T) $
       
   131                       Const (@{const_name Pure.dummy_pattern}, propT)
       
   132             | t1 $ t2 => drop_Trueprop_dummy t1 $ drop_Trueprop_dummy t2
       
   133             | Abs (a, T, b) => Abs (a, T, drop_Trueprop_dummy b)
       
   134             | _ => t);
       
   135 
   129           val pats =
   136           val pats =
   130             map (fn (_, (term, _)) => parse_term (Parse_Tools.the_parse_val term)) ts
   137             map (fn (_, (term, _)) => parse_term (Parse_Tools.the_parse_val term)) ts
       
   138             |> map drop_Trueprop_dummy
       
   139             |> (fn ts => fold_map Term.replace_dummy_patterns ts (Variable.maxidx_of ctxt3 + 1))
       
   140             |> fst
   131             |> Syntax.check_terms ctxt3;
   141             |> Syntax.check_terms ctxt3;
   132 
   142 
   133           val pat_fixes = fold (Term.add_frees) pats [] |> map fst;
   143           val pat_fixes = fold (Term.add_frees) pats [] |> map fst;
   134 
   144 
   135           val _ =
   145           val _ =
   136             map2 (fn nm => fn (_, pos) =>
   146             map2 (fn nm => fn (_, pos) =>
   137                 member (op =) pat_fixes nm orelse
   147                 member (op =) pat_fixes nm orelse
   138                 error ("For-fixed variable must be bound in some pattern" ^ Position.here pos))
   148                 error ("For-fixed variable must be bound in some pattern" ^ Position.here pos))
   139               fix_nms fixes;
   149               fix_nms fixes;
   140 
   150 
   141           val _ = map (Term.map_types Type.no_tvars) pats
   151           val _ = map (Term.map_types Type.no_tvars) pats;
   142 
   152 
   143           val ctxt4 = fold Variable.declare_term pats ctxt3;
   153           val ctxt4 = fold Variable.declare_term pats ctxt3;
   144 
   154 
   145           val (Ts, ctxt5) = ctxt4 |> fold_map Proof_Context.inferred_param fix_nms;
   155           val (Ts, ctxt5) = ctxt4 |> fold_map Proof_Context.inferred_param fix_nms;
   146 
   156 
   198               (ctxt
   208               (ctxt
   199                 |> Token.declare_maxidx_src src
   209                 |> Token.declare_maxidx_src src
   200                 |> Variable.declare_maxidx (Variable.maxidx_of ctxt6));
   210                 |> Variable.declare_maxidx (Variable.maxidx_of ctxt6));
   201 
   211 
   202           val pats' = map (Term.map_types Type_Infer.paramify_vars #> Morphism.term morphism) pats;
   212           val pats' = map (Term.map_types Type_Infer.paramify_vars #> Morphism.term morphism) pats;
   203           val _ = ListPair.app (fn ((_, (Parse_Tools.Parse_Val (_, f), _)), t) => f t) (ts, pats');
   213           val _ = ListPair.app (fn ((_, (v, _)), t) => Parse_Tools.the_parse_fun v t) (ts, pats');
   204 
   214 
   205           fun close_src src =
   215           fun close_src src =
   206             let
   216             let
   207               val src' = Token.closure_src src |> Token.transform_src morphism;
   217               val src' = Token.closure_src src |> Token.transform_src morphism;
   208               val _ =
   218               val _ =
   209                 map2 (fn tok1 => fn tok2 =>
   219                 map2 (fn tok1 => fn tok2 =>
   210                   (case (Token.get_value tok2) of
   220                   (case Token.get_value tok2 of
   211                     SOME value => Token.assign (SOME value) tok1
   221                     SOME value => Token.assign (SOME value) tok1
   212                   | NONE => ()))
   222                   | NONE => ()))
   213                   (Token.args_of_src src)
   223                   (Token.args_of_src src)
   214                   (Token.args_of_src src');
   224                   (Token.args_of_src src');
   215             in src' end;
   225             in src' end;
   217           val binds' =
   227           val binds' =
   218             map (Option.map (fn (t, atts) => (Morphism.term morphism t, map close_src atts))) binds;
   228             map (Option.map (fn (t, atts) => (Morphism.term morphism t, map close_src atts))) binds;
   219 
   229 
   220           val _ =
   230           val _ =
   221             ListPair.app
   231             ListPair.app
   222               (fn ((SOME ((Parse_Tools.Parse_Val (_, f), _)), _), SOME (t, _)) => f t
   232               (fn ((SOME ((v, _)), _), SOME (t, _)) => Parse_Tools.the_parse_fun v t
   223                 | ((NONE, _), NONE) => ()
   233                 | ((NONE, _), NONE) => ()
   224                 | _ => error "Mismatch between real and parsed bound variables")
   234                 | _ => error "Mismatch between real and parsed bound variables")
   225               (ts, binds');
   235               (ts, binds');
   226 
   236 
   227           val real_fixes' = map (Morphism.term morphism) real_fixes;
   237           val real_fixes' = map (Morphism.term morphism) real_fixes;
   228           val _ =
   238           val _ =
   229             ListPair.app (fn (( (Parse_Tools.Parse_Val (_, f), _) , _), t) => f t)
   239             ListPair.app (fn (((v, _) , _), t) => Parse_Tools.the_parse_fun v t)
   230               (fixes, real_fixes');
   240               (fixes, real_fixes');
   231 
   241 
   232           val match_args = map (fn (_, (_, match_args)) => match_args) ts;
   242           val match_args = map (fn (_, (_, match_args)) => match_args) ts;
   233           val binds'' = (binds' ~~ match_args) ~~ pats';
   243           val binds'' = (binds' ~~ match_args) ~~ pats';
   234 
   244 
   253 
   263 
   254 fun inst_thm ctxt env ts params thm =
   264 fun inst_thm ctxt env ts params thm =
   255   let
   265   let
   256     val ts' = map (Envir.norm_term env) ts;
   266     val ts' = map (Envir.norm_term env) ts;
   257     val insts = map (Thm.cterm_of ctxt) ts' ~~ map (Thm.cterm_of ctxt) params;
   267     val insts = map (Thm.cterm_of ctxt) ts' ~~ map (Thm.cterm_of ctxt) params;
   258 
       
   259   in
   268   in
   260     Drule.cterm_instantiate insts thm
   269     Drule.cterm_instantiate insts thm
   261   end;
   270   end;
   262 
   271 
   263 fun do_inst fact_insts' env text ctxt =
   272 fun do_inst fact_insts' env text ctxt =
   307       Variable.export_morphism ctxt'
   316       Variable.export_morphism ctxt'
   308         (ctxt |> Variable.declare_maxidx (Variable.maxidx_of ctxt'));
   317         (ctxt |> Variable.declare_maxidx (Variable.maxidx_of ctxt'));
   309     val pat'' :: params'' = map (Morphism.term morphism) (pat' :: params');
   318     val pat'' :: params'' = map (Morphism.term morphism) (pat' :: params');
   310 
   319 
   311     fun prep_head (t, att) = (dest_internal_fact t, att);
   320     fun prep_head (t, att) = (dest_internal_fact t, att);
   312 
       
   313   in
   321   in
   314     ((((Option.map prep_head x, args), params''), pat''), ctxt')
   322     ((((Option.map prep_head x, args), params''), pat''), ctxt')
   315   end;
   323   end;
   316 
   324 
   317 fun recalculate_maxidx env =
   325 fun recalculate_maxidx env =
   324     Envir.Envir
   332     Envir.Envir
   325       {maxidx = Int.max (Int.max (max_tidx, max_Tidx), Envir.maxidx_of env),
   333       {maxidx = Int.max (Int.max (max_tidx, max_Tidx), Envir.maxidx_of env),
   326         tenv = tenv, tyenv = tyenv}
   334         tenv = tenv, tyenv = tyenv}
   327   end
   335   end
   328 
   336 
   329 fun match_filter_env inner_ctxt morphism pat_vars fixes (ts, params) thm inner_env =
   337 fun morphism_env morphism env =
   330   let
   338   let
   331     val tenv = Envir.term_env inner_env
   339     val tenv = Envir.term_env env
   332       |> Vartab.map (K (fn (T, t) => (Morphism.typ morphism T, Morphism.term morphism t)));
   340       |> Vartab.map (K (fn (T, t) => (Morphism.typ morphism T, Morphism.term morphism t)));
   333 
   341     val tyenv = Envir.type_env env
   334     val tyenv = Envir.type_env inner_env
       
   335       |> Vartab.map (K (fn (S, T) => (S, Morphism.typ morphism T)));
   342       |> Vartab.map (K (fn (S, T) => (S, Morphism.typ morphism T)));
   336 
   343    in Envir.Envir {maxidx = Envir.maxidx_of env, tenv = tenv, tyenv = tyenv} end;
   337     val outer_env = Envir.Envir {maxidx = Envir.maxidx_of inner_env, tenv = tenv, tyenv = tyenv};
   344 
   338 
   345 fun export_with_params ctxt morphism (SOME ts, params) thm env =
       
   346       let
       
   347         val outer_env = morphism_env morphism env;
       
   348         val thm' = Morphism.thm morphism thm;
       
   349       in inst_thm ctxt outer_env params ts thm' end
       
   350   | export_with_params _ morphism (NONE,_) thm _ = Morphism.thm morphism thm;
       
   351 
       
   352 fun match_filter_env is_newly_fixed pat_vars fixes params env =
       
   353   let
   339     val param_vars = map Term.dest_Var params;
   354     val param_vars = map Term.dest_Var params;
   340 
   355 
   341     val params' = map (fn (xi, _) => Vartab.lookup (Envir.term_env outer_env) xi) param_vars;
   356     val tenv = Envir.term_env env;
       
   357 
       
   358     val params' = map (fn (xi, _) => Vartab.lookup tenv xi) param_vars;
   342 
   359 
   343     val fixes_vars = map Term.dest_Var fixes;
   360     val fixes_vars = map Term.dest_Var fixes;
   344 
   361 
   345     val all_vars = Vartab.keys tenv;
   362     val all_vars = Vartab.keys tenv;
   346 
   363 
   347     val extra_vars = subtract (fn ((xi, _), xi') => xi = xi') fixes_vars all_vars;
   364     val extra_vars = subtract (fn ((xi, _), xi') => xi = xi') fixes_vars all_vars;
   348 
   365 
   349     val tenv' = tenv
   366     val tenv' = tenv |> fold (Vartab.delete_safe) extra_vars;
   350       |> fold (Vartab.delete_safe) extra_vars;
       
   351 
   367 
   352     val env' =
   368     val env' =
   353       Envir.Envir {maxidx = Envir.maxidx_of outer_env, tenv = tenv', tyenv = tyenv}
   369       Envir.Envir {maxidx = Envir.maxidx_of env, tenv = tenv', tyenv = Envir.type_env env}
   354       |> recalculate_maxidx;
   370 
   355 
   371     val all_params_bound = forall (fn SOME (_, Free (x,_)) => is_newly_fixed x | _ => false) params';
   356     val all_params_bound = forall (fn SOME (_, Var _) => true | _ => false) params';
   372 
       
   373     val all_params_distinct = not (has_duplicates (op =) params');
   357 
   374 
   358     val pat_fixes = inter (eq_fst (op =)) fixes_vars pat_vars;
   375     val pat_fixes = inter (eq_fst (op =)) fixes_vars pat_vars;
   359 
   376 
   360     val all_pat_fixes_bound = forall (fn (xi, _) => is_some (Vartab.lookup tenv' xi)) pat_fixes;
   377     val all_pat_fixes_bound = forall (fn (xi, _) => is_some (Vartab.lookup tenv' xi)) pat_fixes;
   361 
   378   in
   362     val thm' = Morphism.thm morphism thm;
   379     if all_params_bound andalso all_pat_fixes_bound andalso all_params_distinct
   363 
   380     then SOME env'
   364   in
       
   365     if all_params_bound andalso all_pat_fixes_bound then
       
   366       SOME (case ts of SOME ts => inst_thm inner_ctxt outer_env params ts thm' | _ => thm', env')
       
   367     else NONE
   381     else NONE
   368   end;
   382   end;
   369 
   383 
   370 
   384 
   371 (* Slightly hacky way of uniquely identifying focus premises *)
   385 (* Slightly hacky way of uniquely identifying focus premises *)
   372 val prem_idN = "premise_id";
   386 val prem_idN = "premise_id";
   373 
   387 
   374 fun prem_id_eq ((id, _ : thm), (id', _ : thm)) = id = id';
   388 fun prem_id_eq ((id, _ : thm), (id', _ : thm)) = id = id';
   375 
   389 
   376 val prem_rules : (int * thm) Item_Net.T =
   390 val prem_rules : (int * thm) Item_Net.T =
   377    Item_Net.init prem_id_eq (single o Thm.full_prop_of o snd);
   391   Item_Net.init prem_id_eq (single o Thm.full_prop_of o snd);
   378 
   392 
   379 fun raw_thm_to_id thm =
   393 fun raw_thm_to_id thm =
   380   (case Properties.get (Thm.get_tags thm) prem_idN of NONE => NONE | SOME id => Int.fromString id)
   394   (case Properties.get (Thm.get_tags thm) prem_idN of NONE => NONE | SOME id => Int.fromString id)
   381   |> the_default ~1;
   395   |> the_default ~1;
   382 
   396 
   392 
   406 
   393 (* focus prems *)
   407 (* focus prems *)
   394 
   408 
   395 val focus_prems = #1 o Focus_Data.get;
   409 val focus_prems = #1 o Focus_Data.get;
   396 
   410 
       
   411 fun hyp_from_premid ctxt (ident, prem) =
       
   412   let
       
   413     val ident = Thm.cterm_of ctxt (HOLogic.mk_number @{typ nat} ident |> Logic.mk_term);
       
   414     val hyp =
       
   415       (case #hyps (Thm.crep_thm prem) of
       
   416         [hyp] => hyp
       
   417       | _ => error "Prem should have exactly one hyp");  (* FIXME error vs. raise Fail !? *)
       
   418     val ct = Drule.mk_term (hyp) |> Thm.cprop_of;
       
   419   in Drule.protect (Conjunction.mk_conjunction (ident, ct)) end;
       
   420 
       
   421 fun hyp_from_ctermid ctxt (ident,cterm) =
       
   422   let
       
   423     val ident = Thm.cterm_of ctxt (HOLogic.mk_number @{typ nat} ident |> Logic.mk_term);
       
   424   in Drule.protect (Conjunction.mk_conjunction (ident, cterm)) end;
       
   425 
       
   426 fun add_premid_hyp premid ctxt =
       
   427   Thm.declare_hyps (hyp_from_premid ctxt premid) ctxt;
       
   428 
   397 fun add_focus_prem prem =
   429 fun add_focus_prem prem =
       
   430   `(Focus_Data.get #> #1 #> #1) ##>
   398   (Focus_Data.map o @{apply 3(1)}) (fn (next, net) =>
   431   (Focus_Data.map o @{apply 3(1)}) (fn (next, net) =>
   399     (next + 1, Item_Net.update (next, Thm.tag_rule (prem_idN, string_of_int next) prem) net));
   432     (next + 1, Item_Net.update (next, Thm.tag_rule (prem_idN, string_of_int next) prem) net));
   400 
   433 
   401 fun remove_focus_prem thm =
   434 fun remove_focus_prem' (ident, thm) =
   402   (Focus_Data.map o @{apply 3(1)} o apsnd)
   435   (Focus_Data.map o @{apply 3(1)} o apsnd)
   403     (Item_Net.remove (raw_thm_to_id thm, thm));
   436     (Item_Net.remove (ident, thm));
       
   437 
       
   438 fun remove_focus_prem thm = remove_focus_prem' (raw_thm_to_id thm, thm);
   404 
   439 
   405 (*TODO: Preliminary analysis to see if we're trying to clear in a non-focus match?*)
   440 (*TODO: Preliminary analysis to see if we're trying to clear in a non-focus match?*)
   406 val _ =
   441 val _ =
   407   Theory.setup
   442   Theory.setup
   408     (Attrib.setup @{binding "thin"}
   443     (Attrib.setup @{binding "thin"}
   427 
   462 
   428 fun add_focus_params params =
   463 fun add_focus_params params =
   429   (Focus_Data.map o @{apply 3(3)})
   464   (Focus_Data.map o @{apply 3(3)})
   430     (append (map (fn (_, ct) => Thm.term_of ct) params));
   465     (append (map (fn (_, ct) => Thm.term_of ct) params));
   431 
   466 
       
   467 fun solve_term ct = Thm.trivial ct OF [Drule.termI];
       
   468 
       
   469 fun get_thinned_prems goal =
       
   470   let
       
   471     val chyps = Thm.crep_thm goal |> #hyps;
       
   472 
       
   473     fun prem_from_hyp hyp goal =
       
   474     let
       
   475       val asm = Thm.assume hyp;
       
   476       val (identt,ct) = asm |> Goal.conclude |> Thm.cprop_of |> Conjunction.dest_conjunction;
       
   477       val ident = HOLogic.dest_number (Thm.term_of identt |> Logic.dest_term) |> snd;
       
   478       val thm = Conjunction.intr (solve_term identt) (solve_term ct) |> Goal.protect 0
       
   479       val goal' = Thm.implies_elim (Thm.implies_intr hyp goal) thm;
       
   480     in
       
   481       (SOME (ident,ct),goal')
       
   482     end handle TERM _ => (NONE,goal) | THM _ => (NONE,goal);
       
   483   in
       
   484     fold_map prem_from_hyp chyps goal
       
   485     |>> map_filter I
       
   486   end;
       
   487 
   432 
   488 
   433 (* Add focus elements as proof data *)
   489 (* Add focus elements as proof data *)
   434 fun augment_focus
   490 fun augment_focus (focus: Subgoal.focus) : (int list * Subgoal.focus) =
   435     ({context, params, prems, asms, concl, schematics} : Subgoal.focus) : Subgoal.focus =
   491   let
   436   let
   492     val {context, params, prems, asms, concl, schematics} = focus;
   437     val context' = context
   493 
       
   494     val (prem_ids,ctxt') = context
   438       |> add_focus_params params
   495       |> add_focus_params params
   439       |> add_focus_schematics (snd schematics)
   496       |> add_focus_schematics (snd schematics)
   440       |> fold add_focus_prem (rev prems);
   497       |> fold_map add_focus_prem (rev prems)
   441   in
   498 
   442     {context = context',
   499     val local_prems = map2 pair prem_ids (rev prems);
       
   500 
       
   501     val ctxt'' = fold add_premid_hyp local_prems ctxt';
       
   502   in
       
   503     (prem_ids,{context = ctxt'',
   443      params = params,
   504      params = params,
   444      prems = prems,
   505      prems = prems,
   445      concl = concl,
   506      concl = concl,
   446      schematics = schematics,
   507      schematics = schematics,
   447      asms = asms}
   508      asms = asms})
   448   end;
   509   end;
   449 
   510 
   450 
   511 
   451 (* Fix schematics in the goal *)
   512 (* Fix schematics in the goal *)
   452 fun focus_concl ctxt i goal =
   513 fun focus_concl ctxt i goal =
   465   in
   526   in
   466     ({context = context', concl = concl', params = params, prems = prems,
   527     ({context = context', concl = concl', params = params, prems = prems,
   467       schematics = schematics', asms = asms} : Subgoal.focus, goal'')
   528       schematics = schematics', asms = asms} : Subgoal.focus, goal'')
   468   end;
   529   end;
   469 
   530 
   470 exception MATCH_CUT;
       
   471 
       
   472 val raise_match : (thm * Envir.env) Seq.seq = Seq.make (fn () => raise MATCH_CUT);
       
   473 
       
   474 fun map_handle seq =
       
   475   Seq.make (fn () =>
       
   476     (case (Seq.pull seq handle MATCH_CUT => NONE) of
       
   477       SOME (x, seq') => SOME (x, map_handle seq')
       
   478     | NONE => NONE));
       
   479 
   531 
   480 fun deduplicate eq prev seq =
   532 fun deduplicate eq prev seq =
   481   Seq.make (fn () =>
   533   Seq.make (fn () =>
   482     (case (Seq.pull seq) of
   534     (case Seq.pull seq of
   483       SOME (x, seq') =>
   535       SOME (x, seq') =>
   484         if member eq prev x
   536         if member eq prev x
   485         then Seq.pull (deduplicate eq prev seq')
   537         then Seq.pull (deduplicate eq prev seq')
   486         else SOME (x, deduplicate eq (x :: prev) seq')
   538         else SOME (x, deduplicate eq (x :: prev) seq')
   487     | NONE => NONE));
   539     | NONE => NONE));
   488 
   540 
       
   541 
   489 fun consistent_env env =
   542 fun consistent_env env =
   490   let
   543   let
   491     val tenv = Envir.term_env env;
   544     val tenv = Envir.term_env env;
   492     val tyenv = Envir.type_env env;
   545     val tyenv = Envir.type_env env;
   493   in
   546   in
   494     forall (fn (_, (T, t)) => Envir.norm_type tyenv T = fastype_of t) (Vartab.dest tenv)
   547     forall (fn (_, (T, t)) => Envir.norm_type tyenv T = fastype_of t) (Vartab.dest tenv)
   495   end;
   548   end;
   496 
   549 
       
   550 fun term_eq_wrt (env1,env2) (t1,t2) =
       
   551   Envir.eta_contract (Envir.norm_term env1 t1) aconv
       
   552   Envir.eta_contract (Envir.norm_term env2 t2);
       
   553 
       
   554 fun type_eq_wrt (env1,env2) (T1,T2) =
       
   555   Envir.norm_type (Envir.type_env env1) T1 = Envir.norm_type (Envir.type_env env2) T2
       
   556 
       
   557 
   497 fun eq_env (env1, env2) =
   558 fun eq_env (env1, env2) =
   498   let
       
   499     val tyenv1 = Envir.type_env env1;
       
   500     val tyenv2 = Envir.type_env env2;
       
   501   in
       
   502     Envir.maxidx_of env1 = Envir.maxidx_of env1 andalso
   559     Envir.maxidx_of env1 = Envir.maxidx_of env1 andalso
   503     ListPair.allEq (fn ((var, (_, t)), (var', (_, t'))) =>
   560     ListPair.allEq (fn ((var, (_, t)), (var', (_, t'))) =>
   504         (var = var' andalso
   561         (var = var' andalso term_eq_wrt (env1,env2) (t,t')))
   505           Envir.eta_contract (Envir.norm_term env1 t) aconv
       
   506           Envir.eta_contract (Envir.norm_term env2 t')))
       
   507       (apply2 Vartab.dest (Envir.term_env env1, Envir.term_env env2))
   562       (apply2 Vartab.dest (Envir.term_env env1, Envir.term_env env2))
   508     andalso
   563     andalso
   509     ListPair.allEq (fn ((var, (_, T)), (var', (_, T'))) =>
   564     ListPair.allEq (fn ((var, (_, T)), (var', (_, T'))) =>
   510         var = var' andalso Envir.norm_type tyenv1 T = Envir.norm_type tyenv2 T')
   565         var = var' andalso type_eq_wrt (env1,env2) (T,T'))
   511       (apply2 Vartab.dest (Envir.type_env env1, Envir.type_env env2))
   566       (apply2 Vartab.dest (Envir.type_env env1, Envir.type_env env2));
   512   end;
   567 
   513 
   568 
       
   569 fun merge_env (env1,env2) =
       
   570   let
       
   571     val tenv =
       
   572       Vartab.merge (eq_snd (term_eq_wrt (env1, env2))) (Envir.term_env env1, Envir.term_env env2);
       
   573     val tyenv =
       
   574       Vartab.merge (eq_snd (type_eq_wrt (env1, env2)) andf eq_fst (op =))
       
   575         (Envir.type_env env1,Envir.type_env env2);
       
   576     val maxidx = Int.max (Envir.maxidx_of env1, Envir.maxidx_of env2);
       
   577   in Envir.Envir {maxidx = maxidx, tenv = tenv, tyenv = tyenv} end;
       
   578 
       
   579 
       
   580 fun import_with_tags thms ctxt =
       
   581   let
       
   582     val ((_, thms'), ctxt') = Variable.import false thms ctxt;
       
   583     val thms'' = map2 (fn thm => Thm.map_tags (K (Thm.get_tags thm))) thms thms';
       
   584   in (thms'', ctxt') end;
       
   585 
       
   586 
       
   587 fun try_merge (env, env') = SOME (merge_env (env, env')) handle Vartab.DUP _ => NONE
       
   588 
       
   589 
       
   590 fun Seq_retrieve seq f =
       
   591   let
       
   592     fun retrieve' (list, seq) f =
       
   593       (case Seq.pull seq of
       
   594         SOME (x, seq') =>
       
   595           if f x then (SOME x, (list, seq'))
       
   596           else retrieve' (list @ [x], seq') f
       
   597       | NONE => (NONE, (list, seq)));
       
   598 
       
   599     val (result, (list, seq)) = retrieve' ([], seq) f;
       
   600   in (result, Seq.append (Seq.of_list list) seq) end;
   514 
   601 
   515 fun match_facts ctxt fixes prop_pats get =
   602 fun match_facts ctxt fixes prop_pats get =
   516   let
   603   let
   517     fun is_multi (((_, x : match_args), _), _) = #multi x;
   604     fun is_multi (((_, x : match_args), _), _) = #multi x;
   518     fun is_cut (_, x : match_args) = #cut x;
   605     fun get_cut (((_, x : match_args), _), _) = #cut x;
   519 
   606     fun do_cut n = if n = ~1 then I else Seq.take n;
   520     fun match_thm (((x, params), pat), thm) env  =
   607 
       
   608     val raw_thmss = map (get o snd) prop_pats;
       
   609     val (thmss,ctxt') = fold_burrow import_with_tags raw_thmss ctxt;
       
   610 
       
   611     val newly_fixed = Variable.is_newly_fixed ctxt' ctxt;
       
   612 
       
   613     val morphism = Variable.export_morphism ctxt' ctxt;
       
   614 
       
   615     fun match_thm (((x, params), pat), thm)  =
   521       let
   616       let
   522         val pat_vars = Term.add_vars pat [];
   617         val pat_vars = Term.add_vars pat [];
   523 
   618 
   524         val pat' = pat |> Envir.norm_term env;
       
   525 
       
   526         val (((Tinsts', insts), [thm']), inner_ctxt) = Variable.import false [thm] ctxt;
       
   527 
       
   528         val item' = Thm.prop_of thm';
       
   529 
       
   530         val ts = Option.map (fst o fst) (fst x);
   619         val ts = Option.map (fst o fst) (fst x);
   531 
   620 
   532         val outer_ctxt = ctxt |> Variable.declare_maxidx (Envir.maxidx_of env);
   621         val item' = Thm.prop_of thm;
   533 
       
   534         val morphism = Variable.export_morphism inner_ctxt outer_ctxt;
       
   535 
   622 
   536         val matches =
   623         val matches =
   537           (Unify.matchers (Context.Proof ctxt) [(pat', item')])
   624           (Unify.matchers (Context.Proof ctxt) [(pat, item')])
   538           |> Seq.filter consistent_env
   625           |> Seq.filter consistent_env
   539           |> Seq.map_filter (fn env' =>
   626           |> Seq.map_filter (fn env' =>
   540               match_filter_env inner_ctxt morphism pat_vars fixes
   627               (case match_filter_env newly_fixed pat_vars fixes params env' of
   541                 (ts, params) thm' (Envir.merge (env, env')))
   628                 SOME env'' => SOME (export_with_params ctxt morphism (ts,params) thm env',env'')
       
   629               | NONE => NONE))
   542           |> Seq.map (apfst (Thm.map_tags (K (Thm.get_tags thm))))
   630           |> Seq.map (apfst (Thm.map_tags (K (Thm.get_tags thm))))
   543           |> deduplicate (eq_snd eq_env) []
   631           |> deduplicate (eq_pair Thm.eq_thm_prop eq_env) []
   544           |> is_cut x ? (fn t => Seq.make (fn () =>
   632       in matches end;
   545             Option.map (fn (x, _) => (x, raise_match)) (Seq.pull t)));
       
   546       in
       
   547         matches
       
   548       end;
       
   549 
   633 
   550     val all_matches =
   634     val all_matches =
   551       map (fn pat => (pat, get (snd pat))) prop_pats
   635       map2 pair prop_pats thmss
   552       |> map (fn (pat, matches) => (pat, map (fn thm => match_thm (pat, thm)) matches));
   636       |> map (fn (pat, matches) => (pat, map (fn thm => match_thm (pat, thm)) matches));
   553 
   637 
   554     fun proc_multi_match (pat, thmenvs) (pats, env) =
   638     fun proc_multi_match (pat, thmenvs) (pats, env) =
   555       if is_multi pat then
   639       do_cut (get_cut pat)
   556         let
   640         (if is_multi pat then
   557           val empty = ([], Envir.empty ~1);
   641           let
   558 
   642             fun maximal_set tail seq envthms =
   559           val thmenvs' =
   643               Seq.make (fn () =>
   560             Seq.EVERY (map (fn e => fn (thms, env) =>
   644                 (case Seq.pull seq of
   561               Seq.append (Seq.map (fn (thm, env') => (thm :: thms, env')) (e env))
   645                   SOME ((thm, env'), seq') =>
   562                 (Seq.single (thms, env))) thmenvs) empty;
   646                     let
   563         in
   647                       val (result, envthms') =
   564           Seq.map_filter (fn (fact, env') =>
   648                         Seq_retrieve envthms (fn (env, _) => eq_env (env, env'));
   565             if not (null fact) then SOME ((pat, fact) :: pats, env') else NONE) thmenvs'
   649                     in
   566         end
   650                       (case result of
   567       else
   651                         SOME (_,thms) => SOME ((env', thm :: thms), maximal_set tail seq' envthms')
   568         fold (fn e => Seq.append (Seq.map (fn (thm, env') =>
   652                       | NONE => Seq.pull (maximal_set (tail @ [(env', [thm])]) seq' envthms'))
   569           ((pat, [thm]) :: pats, env')) (e env))) thmenvs Seq.empty;
   653                     end
       
   654                  | NONE => Seq.pull (Seq.append envthms (Seq.of_list tail))));
       
   655 
       
   656             val maximal_sets = fold (maximal_set []) thmenvs Seq.empty;
       
   657           in
       
   658             maximal_sets
       
   659             |> Seq.map swap
       
   660             |> Seq.filter (fn (thms, _) => not (null thms))
       
   661             |> Seq.map_filter (fn (thms, env') =>
       
   662               (case try_merge (env, env') of
       
   663                 SOME env'' => SOME ((pat, thms) :: pats, env'')
       
   664               | NONE => NONE))
       
   665           end
       
   666         else
       
   667           let
       
   668             fun just_one (thm, env') =
       
   669               (case try_merge (env,env') of
       
   670                 SOME env'' => SOME ((pat,[thm]) :: pats, env'')
       
   671               | NONE => NONE);
       
   672           in fold (fn seq => Seq.append (Seq.map_filter just_one seq)) thmenvs Seq.empty end);
   570 
   673 
   571     val all_matches =
   674     val all_matches =
   572       Seq.EVERY (map proc_multi_match all_matches) ([], Envir.empty ~1)
   675       Seq.EVERY (map proc_multi_match all_matches) ([], Envir.empty ~1);
   573   in
   676   in
   574     map_handle all_matches
   677     all_matches
       
   678     |> Seq.map (apsnd (morphism_env morphism))
   575   end;
   679   end;
   576 
   680 
   577 fun real_match using ctxt fixes m text pats goal =
   681 fun real_match using ctxt fixes m text pats goal =
   578   let
   682   let
   579     fun make_fact_matches ctxt get =
   683     fun make_fact_matches ctxt get =
   609         if Thm.no_prems goal then Seq.empty
   713         if Thm.no_prems goal then Seq.empty
   610         else
   714         else
   611           let
   715           let
   612             fun focus_cases f g =
   716             fun focus_cases f g =
   613               (case match_kind of
   717               (case match_kind of
   614                 Match_Prems => f
   718                 Match_Prems b => f b
   615               | Match_Concl => g
   719               | Match_Concl => g
   616               | _ => raise Fail "Match kind fell through");
   720               | _ => raise Fail "Match kind fell through");
   617 
   721 
   618             val ({context = focus_ctxt, params, asms, concl, ...}, focused_goal) =
   722             val (goal_thins,goal) = get_thinned_prems goal;
   619               focus_cases (Subgoal.focus_prems) (focus_concl) ctxt 1 goal
   723 
       
   724             val ((local_premids,{context = focus_ctxt, params, asms, concl, ...}), focused_goal) =
       
   725               focus_cases (K Subgoal.focus_prems) (focus_concl) ctxt 1 goal
   620               |>> augment_focus;
   726               |>> augment_focus;
   621 
   727 
   622             val texts =
   728             val texts =
   623               focus_cases
   729               focus_cases
   624                 (fn _ =>
   730                 (fn is_local => fn _ =>
   625                   make_fact_matches focus_ctxt
   731                   make_fact_matches focus_ctxt
   626                     (Item_Net.retrieve (focus_prems focus_ctxt |> snd) #>
   732                     (Item_Net.retrieve (focus_prems focus_ctxt |> snd)
   627                   order_list))
   733                      #> filter_out (member (eq_fst (op =)) goal_thins)
       
   734                      #> is_local ? filter (fn (p,_) => exists (fn id' => id' = p) local_premids)
       
   735                      #> order_list))
   628                 (fn _ =>
   736                 (fn _ =>
   629                   make_term_matches focus_ctxt (fn _ => [Logic.strip_imp_concl (Thm.term_of concl)]))
   737                   make_term_matches focus_ctxt (fn _ => [Logic.strip_imp_concl (Thm.term_of concl)]))
   630                 ();
   738                 ();
   631 
   739 
   632             (*TODO: How to handle cases? *)
   740             (*TODO: How to handle cases? *)
   633 
   741 
   634             fun do_retrofit inner_ctxt goal' =
   742             fun do_retrofit inner_ctxt goal' =
   635               let
   743               let
   636                 val cleared_prems =
   744                 val (goal'_thins,goal') = get_thinned_prems goal';
   637                   subtract (eq_fst (op =))
   745 
       
   746                 val thinned_prems =
       
   747                   ((subtract (eq_fst (op =))
   638                     (focus_prems inner_ctxt |> snd |> Item_Net.content)
   748                     (focus_prems inner_ctxt |> snd |> Item_Net.content)
   639                     (focus_prems focus_ctxt |> snd |> Item_Net.content)
   749                     (focus_prems focus_ctxt |> snd |> Item_Net.content))
   640                   |> map (fn (_, thm) =>
   750                     |> map (fn (id, thm) =>
   641                     Thm.hyps_of thm
   751                         #hyps (Thm.crep_thm thm)
   642                     |> (fn [hyp] => hyp | _ => error "Prem should have only one hyp"));
   752                         |> (fn [chyp] => (id, (SOME chyp, NONE))
       
   753                              | _ => error "Prem should have only one hyp")));
       
   754 
       
   755                 val all_thinned_prems =
       
   756                   thinned_prems @
       
   757                   map (fn (id, prem) => (id, (NONE, SOME prem))) (goal'_thins @ goal_thins);
       
   758 
       
   759                 val (thinned_local_prems,thinned_extra_prems) =
       
   760                   List.partition (fn (id, _) => member (op =) local_premids id) all_thinned_prems;
       
   761 
       
   762                 val local_thins =
       
   763                   thinned_local_prems
       
   764                   |> map (fn (_, (SOME t, _)) => Thm.term_of t
       
   765                            | (_, (_, SOME pt)) => Thm.term_of pt |> Logic.dest_term);
       
   766 
       
   767                 val extra_thins =
       
   768                   thinned_extra_prems
       
   769                   |> map (fn (id, (SOME ct, _)) => (id, Drule.mk_term ct |> Thm.cprop_of)
       
   770                            | (id, (_, SOME pt)) => (id, pt))
       
   771                   |> map (hyp_from_ctermid inner_ctxt);
   643 
   772 
   644                 val n_subgoals = Thm.nprems_of goal';
   773                 val n_subgoals = Thm.nprems_of goal';
   645                 fun prep_filter t =
   774                 fun prep_filter t =
   646                   Term.subst_bounds (map (Thm.term_of o snd) params |> rev, Term.strip_all_body t);
   775                   Term.subst_bounds (map (Thm.term_of o snd) params |> rev, Term.strip_all_body t);
   647                 fun filter_test prems t =
   776                 fun filter_test prems t =
   648                   if member (op =) prems t then SOME (remove1 (op aconv) t prems) else NONE;
   777                   if member (op =) prems t then SOME (remove1 (op aconv) t prems) else NONE;
   649               in
   778               in
   650                 Subgoal.retrofit inner_ctxt ctxt params asms 1 goal' goal |>
   779                 Subgoal.retrofit inner_ctxt ctxt params asms 1 goal' goal |>
   651                 (if n_subgoals = 0 orelse null cleared_prems then I
   780                 (if n_subgoals = 0 orelse null local_thins then I
   652                  else
   781                  else
   653                   Seq.map (Goal.restrict 1 n_subgoals)
   782                   Seq.map (Goal.restrict 1 n_subgoals)
   654                   #> Seq.maps (ALLGOALS (fn i =>
   783                   #> Seq.maps (ALLGOALS (fn i =>
   655                       DETERM (filter_prems_tac' ctxt prep_filter filter_test cleared_prems i)))
   784                       DETERM (filter_prems_tac' ctxt prep_filter filter_test local_thins i)))
   656                   #> Seq.map (Goal.unrestrict 1))
   785                   #> Seq.map (Goal.unrestrict 1))
       
   786                   |> Seq.map (fold Thm.weaken extra_thins)
   657               end;
   787               end;
   658 
   788 
   659             fun apply_text (text, ctxt') =
   789             fun apply_text (text, ctxt') =
   660               let
   790               let
   661                 val goal' =
   791                 val goal' =
   662                   DROP_CASES (Method_Closure.method_evaluate text ctxt' using) focused_goal
   792                   DROP_CASES (Method_Closure.method_evaluate text ctxt' using) focused_goal
   663                   |> Seq.maps (DETERM (do_retrofit ctxt'))
   793                   |> Seq.maps (DETERM (do_retrofit ctxt'))
   664                   |> Seq.map (fn goal => ([]: cases, goal))
   794                   |> Seq.map (fn goal => ([]: cases, goal));
   665               in goal' end;
   795               in goal' end;
   666           in
   796           in
   667             Seq.map apply_text texts
   797             Seq.map apply_text texts
   668           end)
   798           end)
   669   end;
   799   end;