src/HOL/Eisbach/match_method.ML
changeset 60209 022ca2799c73
parent 60119 54bea620e54f
child 60248 f7e4294216d2
--- a/src/HOL/Eisbach/match_method.ML	Thu Apr 30 17:00:50 2015 +0200
+++ b/src/HOL/Eisbach/match_method.ML	Thu Apr 30 17:02:57 2015 +0200
@@ -46,8 +46,8 @@
 val aconv_net = Item_Net.init (op aconv) single;
 
 val parse_match_kind =
-  Scan.lift @{keyword "concl"} >> K Match_Concl ||
-  Scan.lift @{keyword "prems"} >> K Match_Prems ||
+  Scan.lift @{keyword "conclusion"} >> K Match_Concl ||
+  Scan.lift @{keyword "premises"} >> K Match_Prems ||
   Scan.lift (@{keyword "("}) |-- Args.term --| Scan.lift (@{keyword ")"}) >>
     (fn t => Match_Term (Item_Net.update t aconv_net)) ||
   Attrib.thms >> (fn thms => Match_Fact (fold Item_Net.update thms Thm.full_rules));
@@ -60,8 +60,8 @@
   Parse_Tools.parse_term_val Parse.binding;
 
 val fixes =
-  Parse.and_list1 (Scan.repeat1 bound_term --
-    Scan.option (@{keyword "::"} |-- Parse.!!! Parse.typ) >> (fn (xs, T) => map (rpair T) xs))
+  Parse.and_list1 (Scan.repeat1  (Parse.position bound_term) --
+    Scan.option (@{keyword "::"} |-- Parse.!!! Parse.typ) >> (fn (xs, T) => map (fn (nm,pos) => ((nm,T),pos)) xs))
   >> flat;
 
 val for_fixes = Scan.optional (@{keyword "for"} |-- fixes) [];
@@ -77,18 +77,17 @@
 fun dynamic_fact ctxt =
   bound_term -- Args.opt_attribs (Attrib.check_name ctxt);
 
-type match_args = {unify : bool, multi : bool, cut : bool};
+type match_args = {multi : bool, cut : bool};
 
 val parse_match_args =
   Scan.optional (Args.parens (Parse.enum1 ","
-    (Args.$$$ "unify" || Args.$$$ "multi" || Args.$$$ "cut"))) [] >>
+    (Args.$$$ "multi" || Args.$$$ "cut"))) [] >>
     (fn ss =>
-      fold (fn s => fn {unify, multi, cut} =>
+      fold (fn s => fn {multi, cut} =>
         (case s of
-          "unify" => {unify = true, multi = multi, cut = cut}
-        | "multi" => {unify = unify, multi = true, cut = cut}
-        | "cut" => {unify = unify, multi = multi, cut = true}))
-      ss {unify = false, multi = false, cut = false});
+         "multi" => {multi = true, cut = cut}
+        | "cut" => {multi = multi, cut = true}))
+      ss {multi = false, cut = false});
 
 (*TODO: Shape holes in thms *)
 fun parse_named_pats match_kind =
@@ -114,12 +113,12 @@
                   (Parse_Tools.the_real_val b,
                     map (Attrib.attribute ctxt) att)) b, match_args), v)
                 | _ => raise Fail "Expected closed term") ts
-          val fixes' = map (fn (p, _) => Parse_Tools.the_real_val p) fixes
+          val fixes' = map (fn ((p, _),_) => Parse_Tools.the_real_val p) fixes
         in (ts', fixes', text) end
     | SOME _ => error "Unexpected token value in match cartouche"
     | NONE =>
         let
-          val fixes' = map (fn (pb, otyp) => (Parse_Tools.the_parse_val pb, otyp, NoSyn)) fixes;
+          val fixes' = map (fn ((pb, otyp),_) => (Parse_Tools.the_parse_val pb, otyp, NoSyn)) fixes;
           val (fixes'', ctxt1) = Proof_Context.read_vars fixes' ctxt;
           val (fix_nms, ctxt2) = Proof_Context.add_fixes fixes'' ctxt1;
 
@@ -134,6 +133,14 @@
             map (fn (_, (term, _)) => parse_term (Parse_Tools.the_parse_val term)) ts
             |> Syntax.check_terms ctxt3;
 
+          val pat_fixes = fold (Term.add_frees) pats [] |> map fst;
+
+          val _ = map2 (fn nm => fn (_,pos) => member (op =) pat_fixes nm orelse
+            error ("For-fixed variable must be bound in some pattern." ^ Position.here pos))
+            fix_nms fixes;
+
+          val _ = map (Term.map_types Type.no_tvars) pats
+
           val ctxt4 = fold Variable.declare_term pats ctxt3;
 
           val (Ts, ctxt5) = ctxt4 |> fold_map Proof_Context.inferred_param fix_nms;
@@ -146,12 +153,6 @@
             | reject_extra_free _ () = ();
           val _ = (fold o fold_aterms) reject_extra_free pats ();
 
-          (*fun test_multi_bind {multi = multi, ...} pat = multi andalso
-           not (null (inter (op =) (map Free (Term.add_frees pat [])) real_fixes)) andalso
-           error "Cannot fix terms in multi-match. Use a schematic instead."
-
-          val _ = map2 (fn pat => fn (_, (_, match_args)) => test_multi_bind match_args pat) pats ts*)
-
           val binds =
             map (fn (b, _) => Option.map (fn (b, att) => (Parse_Tools.the_parse_val b, att)) b) ts;
 
@@ -184,7 +185,7 @@
             |> (fn ctxt => fold2 upd_ctxt binds pats ([], ctxt) |> apfst rev)
             ||> Proof_Context.restore_mode ctxt;
 
-          val (src, text) = Method_Closure.read_text_closure ctxt6 (Token.input_of cartouche);
+          val (src, text) = Method_Closure.read_inner_text_closure ctxt6 (Token.input_of cartouche);
 
           val morphism =
             Variable.export_morphism ctxt6
@@ -206,7 +207,7 @@
 
           val real_fixes' = map (Morphism.term morphism) real_fixes;
           val _ =
-            ListPair.app (fn ((Parse_Tools.Parse_Val (_, f), _), t) => f t) (fixes, real_fixes');
+            ListPair.app (fn (( (Parse_Tools.Parse_Val (_, f),_) , _), t) => f t) (fixes, real_fixes');
 
           val match_args = map (fn (_, (_, match_args)) => match_args) ts;
           val binds'' = (binds' ~~ match_args) ~~ pats';
@@ -234,19 +235,9 @@
   let
     val ts' = map (Envir.norm_term env) ts;
     val insts = map (Thm.cterm_of ctxt) ts' ~~ map (Thm.cterm_of ctxt) params;
-    val tags = Thm.get_tags thm;
 
-   (*
-    val Tinsts = Type.raw_matches ((map (fastype_of) params), (map (fastype_of) ts')) Vartab.empty
-    |> Vartab.dest
-    |> map (fn (xi, (S, typ)) => (certT (TVar (xi, S)), certT typ))
-   *)
-
-    val thm' = Drule.cterm_instantiate insts thm
-    (*|> Thm.instantiate (Tinsts, [])*)
-      |> Thm.map_tags (K tags);
   in
-    thm'
+    Drule.cterm_instantiate insts thm
   end;
 
 fun do_inst fact_insts' env text ctxt =
@@ -282,6 +273,7 @@
 
     val morphism =
       Morphism.term_morphism "do_inst.term" (Envir.norm_term env) $>
+      Morphism.typ_morphism "do_inst.type" (Envir.norm_type (Envir.type_env env)) $>
       Morphism.fact_morphism "do_inst.fact" (maps expand_fact);
 
     val text' = Method.map_source (Token.transform_src morphism) text;
@@ -303,32 +295,61 @@
     val pat'' :: params'' = map (Morphism.term morphism) (pat' :: params');
 
     fun prep_head (t, att) = (dest_internal_fact t, att);
+
   in
     ((((Option.map prep_head x, args), params''), pat''), ctxt')
   end;
 
-fun match_filter_env ctxt fixes (ts, params) thm env =
+fun recalculate_maxidx env =
   let
+    val tenv = Envir.term_env env;
+    val tyenv = Envir.type_env env;
+    val max_tidx = Vartab.fold (fn (_,(_,t)) => curry Int.max (maxidx_of_term t)) tenv ~1;
+    val max_Tidx = Vartab.fold (fn (_,(_,T)) => curry Int.max (maxidx_of_typ T)) tyenv ~1;
+  in
+    Envir.Envir
+      {maxidx = Int.max (Int.max (max_tidx,max_Tidx),Envir.maxidx_of env),
+        tenv = tenv, tyenv = tyenv}
+  end
+
+fun match_filter_env inner_ctxt morphism pat_vars fixes (ts, params) thm inner_env =
+  let
+    val tenv = Envir.term_env inner_env
+      |> Vartab.map (K (fn (T,t) => (Morphism.typ morphism T,Morphism.term morphism t)));
+
+    val tyenv = Envir.type_env inner_env
+      |> Vartab.map (K (fn (S,T) => (S,Morphism.typ morphism T)));
+
+    val outer_env = Envir.Envir {maxidx = Envir.maxidx_of inner_env, tenv = tenv, tyenv = tyenv};
+
     val param_vars = map Term.dest_Var params;
-    val params' = map (Envir.lookup env) param_vars;
+
+    val params' = map (fn (xi,_) => Vartab.lookup (Envir.term_env outer_env) xi) param_vars;
 
     val fixes_vars = map Term.dest_Var fixes;
 
-    val tenv = Envir.term_env env;
     val all_vars = Vartab.keys tenv;
 
     val extra_vars = subtract (fn ((xi, _), xi') => xi = xi') fixes_vars all_vars;
 
-    val tenv' = Envir.term_env env
+    val tenv' = tenv
       |> fold (Vartab.delete_safe) extra_vars;
 
     val env' =
-      Envir.Envir {maxidx = Envir.maxidx_of env, tenv = tenv', tyenv = Envir.type_env env};
+      Envir.Envir {maxidx = Envir.maxidx_of outer_env, tenv = tenv', tyenv = tyenv}
+      |> recalculate_maxidx;
+
+    val all_params_bound = forall (fn SOME (_,(Var _)) => true | _ => false) params';
+
+    val pat_fixes = inter (eq_fst (op =)) fixes_vars pat_vars;
 
-    val all_params_bound = forall (fn SOME (Var _) => true | _ => false) params';
+    val all_pat_fixes_bound = forall (fn (xi,_) => is_some (Vartab.lookup tenv' xi)) pat_fixes;
+
+    val thm' = Morphism.thm morphism thm;
+
   in
-    if all_params_bound
-    then SOME (case ts of SOME ts => inst_thm ctxt env params ts thm | _ => thm, env')
+    if all_params_bound andalso all_pat_fixes_bound then
+      SOME (case ts of SOME ts => inst_thm inner_ctxt outer_env params ts thm' | _ => thm', env')
     else NONE
   end;
 
@@ -436,28 +457,78 @@
 
 val raise_match : (thm * Envir.env) Seq.seq = Seq.make (fn () => raise MATCH_CUT);
 
+fun map_handle seq =
+  Seq.make (fn () =>
+    (case (Seq.pull seq handle MATCH_CUT => NONE) of
+      SOME (x, seq') => SOME (x, map_handle seq')
+    | NONE => NONE));
+
+fun deduplicate eq prev seq =
+  Seq.make (fn () =>
+    (case (Seq.pull seq) of
+      SOME (x, seq') =>
+        if member eq prev x
+        then Seq.pull (deduplicate eq prev seq')
+        else SOME (x,deduplicate eq (x :: prev) seq')
+    | NONE => NONE));
+
+fun consistent_env env =
+  let
+    val tenv = Envir.term_env env;
+    val tyenv = Envir.type_env env;
+  in
+    forall (fn (_, (T, t)) => Envir.norm_type tyenv T = fastype_of t) (Vartab.dest tenv)
+  end;
+
+fun eq_env (env1,env2) =
+  let
+    val tyenv1 = Envir.type_env env1;
+    val tyenv2 = Envir.type_env env2;
+  in
+    Envir.maxidx_of env1 = Envir.maxidx_of env1 andalso
+    ListPair.allEq (fn ((var, (_, t)), (var', (_, t'))) =>
+        (var = var' andalso
+          Envir.eta_contract (Envir.norm_term env1 t) aconv
+          Envir.eta_contract (Envir.norm_term env2 t')))
+      (apply2 Vartab.dest (Envir.term_env env1, Envir.term_env env2))
+    andalso
+    ListPair.allEq (fn ((var, (_, T)), (var', (_,T'))) =>
+        var = var' andalso Envir.norm_type tyenv1 T = Envir.norm_type tyenv2 T')
+      (apply2 Vartab.dest (Envir.type_env env1, Envir.type_env env2))
+  end;
+
+
 fun match_facts ctxt fixes prop_pats get =
   let
     fun is_multi (((_, x : match_args), _), _) = #multi x;
-    fun is_unify (_, x : match_args) = #unify x;
     fun is_cut (_, x : match_args) = #cut x;
 
     fun match_thm (((x, params), pat), thm) env  =
       let
         fun try_dest_term term = the_default term (try Logic.dest_term term);
 
+        val pat_vars = Term.add_vars pat [];
+
         val pat' = pat |> Envir.norm_term env |> try_dest_term;
 
-        val item' = Thm.prop_of thm |> try_dest_term;
+        val (((Tinsts', insts),[thm']), inner_ctxt) = Variable.import false [thm] ctxt
+
+        val item' = Thm.prop_of thm' |> try_dest_term;
+
         val ts = Option.map (fst o fst) (fst x);
-        (*FIXME: Do we need to move one of these patterns above the other?*)
+
+        val outer_ctxt = ctxt |> Variable.declare_maxidx (Envir.maxidx_of env);
+
+        val morphism = Variable.export_morphism inner_ctxt outer_ctxt;
 
         val matches =
-          (if is_unify x
-           then Unify.smash_unifiers (Context.Proof ctxt) [(pat', item') ] env
-           else Unify.matchers (Context.Proof ctxt) [(pat', item')])
+          (Unify.matchers (Context.Proof ctxt) [(pat', item')])
+          |> Seq.filter consistent_env
           |> Seq.map_filter (fn env' =>
-              match_filter_env ctxt fixes (ts, params) thm (Envir.merge (env, env')))
+              match_filter_env inner_ctxt morphism pat_vars fixes
+                (ts, params) thm' (Envir.merge (env, env')))
+          |> Seq.map (apfst (Thm.map_tags (K (Thm.get_tags thm))))
+          |> deduplicate (eq_snd eq_env) []
           |> is_cut x ? (fn t => Seq.make (fn () =>
             Option.map (fn (x, _) => (x, raise_match)) (Seq.pull t)));
       in
@@ -487,12 +558,6 @@
 
     val all_matches =
       Seq.EVERY (map proc_multi_match all_matches) ([], Envir.empty ~1)
-      |> Seq.filter (fn (_, e) => forall (is_some o Envir.lookup e o Term.dest_Var) fixes);
-
-    fun map_handle seq = Seq.make (fn () =>
-      (case (Seq.pull seq handle MATCH_CUT => NONE) of
-        SOME (x, seq') => SOME (x, map_handle seq')
-      | NONE => NONE));
   in
     map_handle all_matches
   end;