src/HOL/Eisbach/match_method.ML
changeset 60285 b4f1a0a701ae
parent 60248 f7e4294216d2
child 60287 adde5ce1e0a7
--- a/src/HOL/Eisbach/match_method.ML	Wed May 13 19:12:59 2015 +0200
+++ b/src/HOL/Eisbach/match_method.ML	Sat May 16 12:05:52 2015 +0200
@@ -40,20 +40,20 @@
     Match_Term of term Item_Net.T
   | Match_Fact of thm Item_Net.T
   | Match_Concl
-  | Match_Prems;
+  | Match_Prems of bool;
 
 
 val aconv_net = Item_Net.init (op aconv) single;
 
 val parse_match_kind =
   Scan.lift @{keyword "conclusion"} >> K Match_Concl ||
-  Scan.lift @{keyword "premises"} >> K Match_Prems ||
+  Scan.lift (@{keyword "premises"} |-- Args.mode "local") >> Match_Prems ||
   Scan.lift (@{keyword "("}) |-- Args.term --| Scan.lift (@{keyword ")"}) >>
     (fn t => Match_Term (Item_Net.update t aconv_net)) ||
   Attrib.thms >> (fn thms => Match_Fact (fold Item_Net.update thms Thm.full_rules));
 
 
-fun nameable_match m = (case m of Match_Fact _ => true | Match_Prems => true | _ => false);
+fun nameable_match m = (case m of Match_Fact _ => true | Match_Prems _ => true | _ => false);
 fun prop_match m = (case m of Match_Term _ => false | _ => true);
 
 val bound_term : (term, binding) Parse_Tools.parse_val parser =
@@ -66,28 +66,24 @@
 
 val for_fixes = Scan.optional (@{keyword "for"} |-- fixes) [];
 
-fun pos_of dyn =
-  (case dyn of
-    Parse_Tools.Parse_Val (b, _) => Binding.pos_of b
-  | _ => raise Fail "Not a parse value");
-
+fun pos_of dyn = Parse_Tools.the_parse_val dyn |> Binding.pos_of;
 
 (*FIXME: Dynamic facts modify the background theory, so we have to resort
   to token replacement for matched facts. *)
 fun dynamic_fact ctxt =
   bound_term -- Args.opt_attribs (Attrib.check_name ctxt);
 
-type match_args = {multi : bool, cut : bool};
+type match_args = {multi : bool, cut : int};
 
 val parse_match_args =
   Scan.optional (Args.parens (Parse.enum1 ","
-    (Args.$$$ "multi" || Args.$$$ "cut"))) [] >>
+    (Args.$$$ "multi" -- Scan.succeed ~1 || Args.$$$ "cut" -- Scan.optional Parse.int 1))) [] >>
     (fn ss =>
       fold (fn s => fn {multi, cut} =>
         (case s of
-         "multi" => {multi = true, cut = cut}
-        | "cut" => {multi = multi, cut = true}))
-      ss {multi = false, cut = false});
+          ("multi", _) => {multi = true, cut = cut}
+        | ("cut", n) => {multi = multi, cut = n}))
+      ss {multi = false, cut = ~1});
 
 fun parse_named_pats match_kind =
   Args.context :|-- (fn ctxt =>
@@ -126,8 +122,22 @@
             then Syntax.parse_prop ctxt3 term
             else Syntax.parse_term ctxt3 term;
 
+          fun drop_Trueprop_dummy t =
+            (case t of
+              Const (@{const_name Trueprop}, _) $
+                (Const (@{syntax_const "_type_constraint_"}, T) $
+                  Const (@{const_name Pure.dummy_pattern}, _)) =>
+                    Const (@{syntax_const "_type_constraint_"}, T) $
+                      Const (@{const_name Pure.dummy_pattern}, propT)
+            | t1 $ t2 => drop_Trueprop_dummy t1 $ drop_Trueprop_dummy t2
+            | Abs (a, T, b) => Abs (a, T, drop_Trueprop_dummy b)
+            | _ => t);
+
           val pats =
             map (fn (_, (term, _)) => parse_term (Parse_Tools.the_parse_val term)) ts
+            |> map drop_Trueprop_dummy
+            |> (fn ts => fold_map Term.replace_dummy_patterns ts (Variable.maxidx_of ctxt3 + 1))
+            |> fst
             |> Syntax.check_terms ctxt3;
 
           val pat_fixes = fold (Term.add_frees) pats [] |> map fst;
@@ -138,7 +148,7 @@
                 error ("For-fixed variable must be bound in some pattern" ^ Position.here pos))
               fix_nms fixes;
 
-          val _ = map (Term.map_types Type.no_tvars) pats
+          val _ = map (Term.map_types Type.no_tvars) pats;
 
           val ctxt4 = fold Variable.declare_term pats ctxt3;
 
@@ -200,14 +210,14 @@
                 |> Variable.declare_maxidx (Variable.maxidx_of ctxt6));
 
           val pats' = map (Term.map_types Type_Infer.paramify_vars #> Morphism.term morphism) pats;
-          val _ = ListPair.app (fn ((_, (Parse_Tools.Parse_Val (_, f), _)), t) => f t) (ts, pats');
+          val _ = ListPair.app (fn ((_, (v, _)), t) => Parse_Tools.the_parse_fun v t) (ts, pats');
 
           fun close_src src =
             let
               val src' = Token.closure_src src |> Token.transform_src morphism;
               val _ =
                 map2 (fn tok1 => fn tok2 =>
-                  (case (Token.get_value tok2) of
+                  (case Token.get_value tok2 of
                     SOME value => Token.assign (SOME value) tok1
                   | NONE => ()))
                   (Token.args_of_src src)
@@ -219,14 +229,14 @@
 
           val _ =
             ListPair.app
-              (fn ((SOME ((Parse_Tools.Parse_Val (_, f), _)), _), SOME (t, _)) => f t
+              (fn ((SOME ((v, _)), _), SOME (t, _)) => Parse_Tools.the_parse_fun v t
                 | ((NONE, _), NONE) => ()
                 | _ => error "Mismatch between real and parsed bound variables")
               (ts, binds');
 
           val real_fixes' = map (Morphism.term morphism) real_fixes;
           val _ =
-            ListPair.app (fn (( (Parse_Tools.Parse_Val (_, f), _) , _), t) => f t)
+            ListPair.app (fn (((v, _) , _), t) => Parse_Tools.the_parse_fun v t)
               (fixes, real_fixes');
 
           val match_args = map (fn (_, (_, match_args)) => match_args) ts;
@@ -255,7 +265,6 @@
   let
     val ts' = map (Envir.norm_term env) ts;
     val insts = map (Thm.cterm_of ctxt) ts' ~~ map (Thm.cterm_of ctxt) params;
-
   in
     Drule.cterm_instantiate insts thm
   end;
@@ -309,7 +318,6 @@
     val pat'' :: params'' = map (Morphism.term morphism) (pat' :: params');
 
     fun prep_head (t, att) = (dest_internal_fact t, att);
-
   in
     ((((Option.map prep_head x, args), params''), pat''), ctxt')
   end;
@@ -326,19 +334,28 @@
         tenv = tenv, tyenv = tyenv}
   end
 
-fun match_filter_env inner_ctxt morphism pat_vars fixes (ts, params) thm inner_env =
+fun morphism_env morphism env =
   let
-    val tenv = Envir.term_env inner_env
+    val tenv = Envir.term_env env
       |> Vartab.map (K (fn (T, t) => (Morphism.typ morphism T, Morphism.term morphism t)));
+    val tyenv = Envir.type_env env
+      |> Vartab.map (K (fn (S, T) => (S, Morphism.typ morphism T)));
+   in Envir.Envir {maxidx = Envir.maxidx_of env, tenv = tenv, tyenv = tyenv} end;
 
-    val tyenv = Envir.type_env inner_env
-      |> Vartab.map (K (fn (S, T) => (S, Morphism.typ morphism T)));
+fun export_with_params ctxt morphism (SOME ts, params) thm env =
+      let
+        val outer_env = morphism_env morphism env;
+        val thm' = Morphism.thm morphism thm;
+      in inst_thm ctxt outer_env params ts thm' end
+  | export_with_params _ morphism (NONE,_) thm _ = Morphism.thm morphism thm;
 
-    val outer_env = Envir.Envir {maxidx = Envir.maxidx_of inner_env, tenv = tenv, tyenv = tyenv};
-
+fun match_filter_env is_newly_fixed pat_vars fixes params env =
+  let
     val param_vars = map Term.dest_Var params;
 
-    val params' = map (fn (xi, _) => Vartab.lookup (Envir.term_env outer_env) xi) param_vars;
+    val tenv = Envir.term_env env;
+
+    val params' = map (fn (xi, _) => Vartab.lookup tenv xi) param_vars;
 
     val fixes_vars = map Term.dest_Var fixes;
 
@@ -346,24 +363,21 @@
 
     val extra_vars = subtract (fn ((xi, _), xi') => xi = xi') fixes_vars all_vars;
 
-    val tenv' = tenv
-      |> fold (Vartab.delete_safe) extra_vars;
+    val tenv' = tenv |> fold (Vartab.delete_safe) extra_vars;
 
     val env' =
-      Envir.Envir {maxidx = Envir.maxidx_of outer_env, tenv = tenv', tyenv = tyenv}
-      |> recalculate_maxidx;
+      Envir.Envir {maxidx = Envir.maxidx_of env, tenv = tenv', tyenv = Envir.type_env env}
 
-    val all_params_bound = forall (fn SOME (_, Var _) => true | _ => false) params';
+    val all_params_bound = forall (fn SOME (_, Free (x,_)) => is_newly_fixed x | _ => false) params';
+
+    val all_params_distinct = not (has_duplicates (op =) params');
 
     val pat_fixes = inter (eq_fst (op =)) fixes_vars pat_vars;
 
     val all_pat_fixes_bound = forall (fn (xi, _) => is_some (Vartab.lookup tenv' xi)) pat_fixes;
-
-    val thm' = Morphism.thm morphism thm;
-
   in
-    if all_params_bound andalso all_pat_fixes_bound then
-      SOME (case ts of SOME ts => inst_thm inner_ctxt outer_env params ts thm' | _ => thm', env')
+    if all_params_bound andalso all_pat_fixes_bound andalso all_params_distinct
+    then SOME env'
     else NONE
   end;
 
@@ -374,7 +388,7 @@
 fun prem_id_eq ((id, _ : thm), (id', _ : thm)) = id = id';
 
 val prem_rules : (int * thm) Item_Net.T =
-   Item_Net.init prem_id_eq (single o Thm.full_prop_of o snd);
+  Item_Net.init prem_id_eq (single o Thm.full_prop_of o snd);
 
 fun raw_thm_to_id thm =
   (case Properties.get (Thm.get_tags thm) prem_idN of NONE => NONE | SOME id => Int.fromString id)
@@ -394,13 +408,34 @@
 
 val focus_prems = #1 o Focus_Data.get;
 
+fun hyp_from_premid ctxt (ident, prem) =
+  let
+    val ident = Thm.cterm_of ctxt (HOLogic.mk_number @{typ nat} ident |> Logic.mk_term);
+    val hyp =
+      (case #hyps (Thm.crep_thm prem) of
+        [hyp] => hyp
+      | _ => error "Prem should have exactly one hyp");  (* FIXME error vs. raise Fail !? *)
+    val ct = Drule.mk_term (hyp) |> Thm.cprop_of;
+  in Drule.protect (Conjunction.mk_conjunction (ident, ct)) end;
+
+fun hyp_from_ctermid ctxt (ident,cterm) =
+  let
+    val ident = Thm.cterm_of ctxt (HOLogic.mk_number @{typ nat} ident |> Logic.mk_term);
+  in Drule.protect (Conjunction.mk_conjunction (ident, cterm)) end;
+
+fun add_premid_hyp premid ctxt =
+  Thm.declare_hyps (hyp_from_premid ctxt premid) ctxt;
+
 fun add_focus_prem prem =
+  `(Focus_Data.get #> #1 #> #1) ##>
   (Focus_Data.map o @{apply 3(1)}) (fn (next, net) =>
     (next + 1, Item_Net.update (next, Thm.tag_rule (prem_idN, string_of_int next) prem) net));
 
-fun remove_focus_prem thm =
+fun remove_focus_prem' (ident, thm) =
   (Focus_Data.map o @{apply 3(1)} o apsnd)
-    (Item_Net.remove (raw_thm_to_id thm, thm));
+    (Item_Net.remove (ident, thm));
+
+fun remove_focus_prem thm = remove_focus_prem' (raw_thm_to_id thm, thm);
 
 (*TODO: Preliminary analysis to see if we're trying to clear in a non-focus match?*)
 val _ =
@@ -429,22 +464,48 @@
   (Focus_Data.map o @{apply 3(3)})
     (append (map (fn (_, ct) => Thm.term_of ct) params));
 
+fun solve_term ct = Thm.trivial ct OF [Drule.termI];
+
+fun get_thinned_prems goal =
+  let
+    val chyps = Thm.crep_thm goal |> #hyps;
+
+    fun prem_from_hyp hyp goal =
+    let
+      val asm = Thm.assume hyp;
+      val (identt,ct) = asm |> Goal.conclude |> Thm.cprop_of |> Conjunction.dest_conjunction;
+      val ident = HOLogic.dest_number (Thm.term_of identt |> Logic.dest_term) |> snd;
+      val thm = Conjunction.intr (solve_term identt) (solve_term ct) |> Goal.protect 0
+      val goal' = Thm.implies_elim (Thm.implies_intr hyp goal) thm;
+    in
+      (SOME (ident,ct),goal')
+    end handle TERM _ => (NONE,goal) | THM _ => (NONE,goal);
+  in
+    fold_map prem_from_hyp chyps goal
+    |>> map_filter I
+  end;
+
 
 (* Add focus elements as proof data *)
-fun augment_focus
-    ({context, params, prems, asms, concl, schematics} : Subgoal.focus) : Subgoal.focus =
+fun augment_focus (focus: Subgoal.focus) : (int list * Subgoal.focus) =
   let
-    val context' = context
+    val {context, params, prems, asms, concl, schematics} = focus;
+
+    val (prem_ids,ctxt') = context
       |> add_focus_params params
       |> add_focus_schematics (snd schematics)
-      |> fold add_focus_prem (rev prems);
+      |> fold_map add_focus_prem (rev prems)
+
+    val local_prems = map2 pair prem_ids (rev prems);
+
+    val ctxt'' = fold add_premid_hyp local_prems ctxt';
   in
-    {context = context',
+    (prem_ids,{context = ctxt'',
      params = params,
      prems = prems,
      concl = concl,
      schematics = schematics,
-     asms = asms}
+     asms = asms})
   end;
 
 
@@ -467,25 +528,17 @@
       schematics = schematics', asms = asms} : Subgoal.focus, goal'')
   end;
 
-exception MATCH_CUT;
-
-val raise_match : (thm * Envir.env) Seq.seq = Seq.make (fn () => raise MATCH_CUT);
-
-fun map_handle seq =
-  Seq.make (fn () =>
-    (case (Seq.pull seq handle MATCH_CUT => NONE) of
-      SOME (x, seq') => SOME (x, map_handle seq')
-    | NONE => NONE));
 
 fun deduplicate eq prev seq =
   Seq.make (fn () =>
-    (case (Seq.pull seq) of
+    (case Seq.pull seq of
       SOME (x, seq') =>
         if member eq prev x
         then Seq.pull (deduplicate eq prev seq')
         else SOME (x, deduplicate eq (x :: prev) seq')
     | NONE => NONE));
 
+
 fun consistent_env env =
   let
     val tenv = Envir.term_env env;
@@ -494,84 +547,135 @@
     forall (fn (_, (T, t)) => Envir.norm_type tyenv T = fastype_of t) (Vartab.dest tenv)
   end;
 
+fun term_eq_wrt (env1,env2) (t1,t2) =
+  Envir.eta_contract (Envir.norm_term env1 t1) aconv
+  Envir.eta_contract (Envir.norm_term env2 t2);
+
+fun type_eq_wrt (env1,env2) (T1,T2) =
+  Envir.norm_type (Envir.type_env env1) T1 = Envir.norm_type (Envir.type_env env2) T2
+
+
 fun eq_env (env1, env2) =
-  let
-    val tyenv1 = Envir.type_env env1;
-    val tyenv2 = Envir.type_env env2;
-  in
     Envir.maxidx_of env1 = Envir.maxidx_of env1 andalso
     ListPair.allEq (fn ((var, (_, t)), (var', (_, t'))) =>
-        (var = var' andalso
-          Envir.eta_contract (Envir.norm_term env1 t) aconv
-          Envir.eta_contract (Envir.norm_term env2 t')))
+        (var = var' andalso term_eq_wrt (env1,env2) (t,t')))
       (apply2 Vartab.dest (Envir.term_env env1, Envir.term_env env2))
     andalso
     ListPair.allEq (fn ((var, (_, T)), (var', (_, T'))) =>
-        var = var' andalso Envir.norm_type tyenv1 T = Envir.norm_type tyenv2 T')
-      (apply2 Vartab.dest (Envir.type_env env1, Envir.type_env env2))
-  end;
+        var = var' andalso type_eq_wrt (env1,env2) (T,T'))
+      (apply2 Vartab.dest (Envir.type_env env1, Envir.type_env env2));
+
+
+fun merge_env (env1,env2) =
+  let
+    val tenv =
+      Vartab.merge (eq_snd (term_eq_wrt (env1, env2))) (Envir.term_env env1, Envir.term_env env2);
+    val tyenv =
+      Vartab.merge (eq_snd (type_eq_wrt (env1, env2)) andf eq_fst (op =))
+        (Envir.type_env env1,Envir.type_env env2);
+    val maxidx = Int.max (Envir.maxidx_of env1, Envir.maxidx_of env2);
+  in Envir.Envir {maxidx = maxidx, tenv = tenv, tyenv = tyenv} end;
+
 
+fun import_with_tags thms ctxt =
+  let
+    val ((_, thms'), ctxt') = Variable.import false thms ctxt;
+    val thms'' = map2 (fn thm => Thm.map_tags (K (Thm.get_tags thm))) thms thms';
+  in (thms'', ctxt') end;
+
+
+fun try_merge (env, env') = SOME (merge_env (env, env')) handle Vartab.DUP _ => NONE
+
+
+fun Seq_retrieve seq f =
+  let
+    fun retrieve' (list, seq) f =
+      (case Seq.pull seq of
+        SOME (x, seq') =>
+          if f x then (SOME x, (list, seq'))
+          else retrieve' (list @ [x], seq') f
+      | NONE => (NONE, (list, seq)));
+
+    val (result, (list, seq)) = retrieve' ([], seq) f;
+  in (result, Seq.append (Seq.of_list list) seq) end;
 
 fun match_facts ctxt fixes prop_pats get =
   let
     fun is_multi (((_, x : match_args), _), _) = #multi x;
-    fun is_cut (_, x : match_args) = #cut x;
+    fun get_cut (((_, x : match_args), _), _) = #cut x;
+    fun do_cut n = if n = ~1 then I else Seq.take n;
+
+    val raw_thmss = map (get o snd) prop_pats;
+    val (thmss,ctxt') = fold_burrow import_with_tags raw_thmss ctxt;
 
-    fun match_thm (((x, params), pat), thm) env  =
+    val newly_fixed = Variable.is_newly_fixed ctxt' ctxt;
+
+    val morphism = Variable.export_morphism ctxt' ctxt;
+
+    fun match_thm (((x, params), pat), thm)  =
       let
         val pat_vars = Term.add_vars pat [];
 
-        val pat' = pat |> Envir.norm_term env;
-
-        val (((Tinsts', insts), [thm']), inner_ctxt) = Variable.import false [thm] ctxt;
-
-        val item' = Thm.prop_of thm';
-
         val ts = Option.map (fst o fst) (fst x);
 
-        val outer_ctxt = ctxt |> Variable.declare_maxidx (Envir.maxidx_of env);
-
-        val morphism = Variable.export_morphism inner_ctxt outer_ctxt;
+        val item' = Thm.prop_of thm;
 
         val matches =
-          (Unify.matchers (Context.Proof ctxt) [(pat', item')])
+          (Unify.matchers (Context.Proof ctxt) [(pat, item')])
           |> Seq.filter consistent_env
           |> Seq.map_filter (fn env' =>
-              match_filter_env inner_ctxt morphism pat_vars fixes
-                (ts, params) thm' (Envir.merge (env, env')))
+              (case match_filter_env newly_fixed pat_vars fixes params env' of
+                SOME env'' => SOME (export_with_params ctxt morphism (ts,params) thm env',env'')
+              | NONE => NONE))
           |> Seq.map (apfst (Thm.map_tags (K (Thm.get_tags thm))))
-          |> deduplicate (eq_snd eq_env) []
-          |> is_cut x ? (fn t => Seq.make (fn () =>
-            Option.map (fn (x, _) => (x, raise_match)) (Seq.pull t)));
-      in
-        matches
-      end;
+          |> deduplicate (eq_pair Thm.eq_thm_prop eq_env) []
+      in matches end;
 
     val all_matches =
-      map (fn pat => (pat, get (snd pat))) prop_pats
+      map2 pair prop_pats thmss
       |> map (fn (pat, matches) => (pat, map (fn thm => match_thm (pat, thm)) matches));
 
     fun proc_multi_match (pat, thmenvs) (pats, env) =
-      if is_multi pat then
-        let
-          val empty = ([], Envir.empty ~1);
+      do_cut (get_cut pat)
+        (if is_multi pat then
+          let
+            fun maximal_set tail seq envthms =
+              Seq.make (fn () =>
+                (case Seq.pull seq of
+                  SOME ((thm, env'), seq') =>
+                    let
+                      val (result, envthms') =
+                        Seq_retrieve envthms (fn (env, _) => eq_env (env, env'));
+                    in
+                      (case result of
+                        SOME (_,thms) => SOME ((env', thm :: thms), maximal_set tail seq' envthms')
+                      | NONE => Seq.pull (maximal_set (tail @ [(env', [thm])]) seq' envthms'))
+                    end
+                 | NONE => Seq.pull (Seq.append envthms (Seq.of_list tail))));
 
-          val thmenvs' =
-            Seq.EVERY (map (fn e => fn (thms, env) =>
-              Seq.append (Seq.map (fn (thm, env') => (thm :: thms, env')) (e env))
-                (Seq.single (thms, env))) thmenvs) empty;
-        in
-          Seq.map_filter (fn (fact, env') =>
-            if not (null fact) then SOME ((pat, fact) :: pats, env') else NONE) thmenvs'
-        end
-      else
-        fold (fn e => Seq.append (Seq.map (fn (thm, env') =>
-          ((pat, [thm]) :: pats, env')) (e env))) thmenvs Seq.empty;
+            val maximal_sets = fold (maximal_set []) thmenvs Seq.empty;
+          in
+            maximal_sets
+            |> Seq.map swap
+            |> Seq.filter (fn (thms, _) => not (null thms))
+            |> Seq.map_filter (fn (thms, env') =>
+              (case try_merge (env, env') of
+                SOME env'' => SOME ((pat, thms) :: pats, env'')
+              | NONE => NONE))
+          end
+        else
+          let
+            fun just_one (thm, env') =
+              (case try_merge (env,env') of
+                SOME env'' => SOME ((pat,[thm]) :: pats, env'')
+              | NONE => NONE);
+          in fold (fn seq => Seq.append (Seq.map_filter just_one seq)) thmenvs Seq.empty end);
 
     val all_matches =
-      Seq.EVERY (map proc_multi_match all_matches) ([], Envir.empty ~1)
+      Seq.EVERY (map proc_multi_match all_matches) ([], Envir.empty ~1);
   in
-    map_handle all_matches
+    all_matches
+    |> Seq.map (apsnd (morphism_env morphism))
   end;
 
 fun real_match using ctxt fixes m text pats goal =
@@ -611,20 +715,24 @@
           let
             fun focus_cases f g =
               (case match_kind of
-                Match_Prems => f
+                Match_Prems b => f b
               | Match_Concl => g
               | _ => raise Fail "Match kind fell through");
 
-            val ({context = focus_ctxt, params, asms, concl, ...}, focused_goal) =
-              focus_cases (Subgoal.focus_prems) (focus_concl) ctxt 1 goal
+            val (goal_thins,goal) = get_thinned_prems goal;
+
+            val ((local_premids,{context = focus_ctxt, params, asms, concl, ...}), focused_goal) =
+              focus_cases (K Subgoal.focus_prems) (focus_concl) ctxt 1 goal
               |>> augment_focus;
 
             val texts =
               focus_cases
-                (fn _ =>
+                (fn is_local => fn _ =>
                   make_fact_matches focus_ctxt
-                    (Item_Net.retrieve (focus_prems focus_ctxt |> snd) #>
-                  order_list))
+                    (Item_Net.retrieve (focus_prems focus_ctxt |> snd)
+                     #> filter_out (member (eq_fst (op =)) goal_thins)
+                     #> is_local ? filter (fn (p,_) => exists (fn id' => id' = p) local_premids)
+                     #> order_list))
                 (fn _ =>
                   make_term_matches focus_ctxt (fn _ => [Logic.strip_imp_concl (Thm.term_of concl)]))
                 ();
@@ -633,13 +741,34 @@
 
             fun do_retrofit inner_ctxt goal' =
               let
-                val cleared_prems =
-                  subtract (eq_fst (op =))
+                val (goal'_thins,goal') = get_thinned_prems goal';
+
+                val thinned_prems =
+                  ((subtract (eq_fst (op =))
                     (focus_prems inner_ctxt |> snd |> Item_Net.content)
-                    (focus_prems focus_ctxt |> snd |> Item_Net.content)
-                  |> map (fn (_, thm) =>
-                    Thm.hyps_of thm
-                    |> (fn [hyp] => hyp | _ => error "Prem should have only one hyp"));
+                    (focus_prems focus_ctxt |> snd |> Item_Net.content))
+                    |> map (fn (id, thm) =>
+                        #hyps (Thm.crep_thm thm)
+                        |> (fn [chyp] => (id, (SOME chyp, NONE))
+                             | _ => error "Prem should have only one hyp")));
+
+                val all_thinned_prems =
+                  thinned_prems @
+                  map (fn (id, prem) => (id, (NONE, SOME prem))) (goal'_thins @ goal_thins);
+
+                val (thinned_local_prems,thinned_extra_prems) =
+                  List.partition (fn (id, _) => member (op =) local_premids id) all_thinned_prems;
+
+                val local_thins =
+                  thinned_local_prems
+                  |> map (fn (_, (SOME t, _)) => Thm.term_of t
+                           | (_, (_, SOME pt)) => Thm.term_of pt |> Logic.dest_term);
+
+                val extra_thins =
+                  thinned_extra_prems
+                  |> map (fn (id, (SOME ct, _)) => (id, Drule.mk_term ct |> Thm.cprop_of)
+                           | (id, (_, SOME pt)) => (id, pt))
+                  |> map (hyp_from_ctermid inner_ctxt);
 
                 val n_subgoals = Thm.nprems_of goal';
                 fun prep_filter t =
@@ -648,12 +777,13 @@
                   if member (op =) prems t then SOME (remove1 (op aconv) t prems) else NONE;
               in
                 Subgoal.retrofit inner_ctxt ctxt params asms 1 goal' goal |>
-                (if n_subgoals = 0 orelse null cleared_prems then I
+                (if n_subgoals = 0 orelse null local_thins then I
                  else
                   Seq.map (Goal.restrict 1 n_subgoals)
                   #> Seq.maps (ALLGOALS (fn i =>
-                      DETERM (filter_prems_tac' ctxt prep_filter filter_test cleared_prems i)))
+                      DETERM (filter_prems_tac' ctxt prep_filter filter_test local_thins i)))
                   #> Seq.map (Goal.unrestrict 1))
+                  |> Seq.map (fold Thm.weaken extra_thins)
               end;
 
             fun apply_text (text, ctxt') =
@@ -661,7 +791,7 @@
                 val goal' =
                   DROP_CASES (Method_Closure.method_evaluate text ctxt' using) focused_goal
                   |> Seq.maps (DETERM (do_retrofit ctxt'))
-                  |> Seq.map (fn goal => ([]: cases, goal))
+                  |> Seq.map (fn goal => ([]: cases, goal));
               in goal' end;
           in
             Seq.map apply_text texts