src/HOL/Eisbach/match_method.ML
changeset 60248 f7e4294216d2
parent 60209 022ca2799c73
child 60285 b4f1a0a701ae
--- a/src/HOL/Eisbach/match_method.ML	Sun May 03 18:45:58 2015 +0200
+++ b/src/HOL/Eisbach/match_method.ML	Sun May 03 18:51:26 2015 +0200
@@ -1,4 +1,4 @@
-(*  Title:      match_method.ML
+(*  Title:      HOL/Eisbach/match_method.ML
     Author:     Daniel Matichuk, NICTA/UNSW
 
 Setup for "match" proof method. It provides basic fact/term matching in
@@ -61,8 +61,8 @@
 
 val fixes =
   Parse.and_list1 (Scan.repeat1  (Parse.position bound_term) --
-    Scan.option (@{keyword "::"} |-- Parse.!!! Parse.typ) >> (fn (xs, T) => map (fn (nm,pos) => ((nm,T),pos)) xs))
-  >> flat;
+    Scan.option (@{keyword "::"} |-- Parse.!!! Parse.typ)
+    >> (fn (xs, T) => map (fn (nm, pos) => ((nm, T), pos)) xs)) >> flat;
 
 val for_fixes = Scan.optional (@{keyword "for"} |-- fixes) [];
 
@@ -89,7 +89,6 @@
         | "cut" => {multi = multi, cut = true}))
       ss {multi = false, cut = false});
 
-(*TODO: Shape holes in thms *)
 fun parse_named_pats match_kind =
   Args.context :|-- (fn ctxt =>
     Scan.lift (Parse.and_list1 (Scan.option (dynamic_fact ctxt --| Args.colon) :--
@@ -105,20 +104,18 @@
       SOME (Token.Source src) =>
         let
           val text = Method_Closure.read_inner_method ctxt src
-          (*TODO: Correct parse context for attributes?*)
           val ts' =
             map
               (fn (b, (Parse_Tools.Real_Val v, match_args)) =>
                 ((Option.map (fn (b, att) =>
-                  (Parse_Tools.the_real_val b,
-                    map (Attrib.attribute ctxt) att)) b, match_args), v)
+                  (Parse_Tools.the_real_val b, att)) b, match_args), v)
                 | _ => raise Fail "Expected closed term") ts
-          val fixes' = map (fn ((p, _),_) => Parse_Tools.the_real_val p) fixes
+          val fixes' = map (fn ((p, _), _) => Parse_Tools.the_real_val p) fixes
         in (ts', fixes', text) end
     | SOME _ => error "Unexpected token value in match cartouche"
     | NONE =>
         let
-          val fixes' = map (fn ((pb, otyp),_) => (Parse_Tools.the_parse_val pb, otyp, NoSyn)) fixes;
+          val fixes' = map (fn ((pb, otyp), _) => (Parse_Tools.the_parse_val pb, otyp, NoSyn)) fixes;
           val (fixes'', ctxt1) = Proof_Context.read_vars fixes' ctxt;
           val (fix_nms, ctxt2) = Proof_Context.add_fixes fixes'' ctxt1;
 
@@ -135,9 +132,11 @@
 
           val pat_fixes = fold (Term.add_frees) pats [] |> map fst;
 
-          val _ = map2 (fn nm => fn (_,pos) => member (op =) pat_fixes nm orelse
-            error ("For-fixed variable must be bound in some pattern." ^ Position.here pos))
-            fix_nms fixes;
+          val _ =
+            map2 (fn nm => fn (_, pos) =>
+                member (op =) pat_fixes nm orelse
+                error ("For-fixed variable must be bound in some pattern" ^ Position.here pos))
+              fix_nms fixes;
 
           val _ = map (Term.map_types Type.no_tvars) pats
 
@@ -164,20 +163,27 @@
 
                   val param_thm = map (Drule.mk_term o Thm.cterm_of ctxt' o Free) abs_nms
                     |> Conjunction.intr_balanced
-                    |> Drule.generalize ([], map fst abs_nms);
+                    |> Drule.generalize ([], map fst abs_nms)
+                    |> Method_Closure.tag_free_thm;
 
-                  val thm =
+                  val atts = map (Attrib.attribute ctxt') att;
+                  val (param_thm', ctxt'') = Thm.proof_attributes atts param_thm ctxt';
+
+                  fun label_thm thm =
                     Thm.cterm_of ctxt' (Free (nm, propT))
                     |> Drule.mk_term
-                    |> not (null abs_nms) ? Conjunction.intr param_thm
-                    |> Drule.zero_var_indexes
-                    |> Method_Closure.tag_free_thm;
+                    |> not (null abs_nms) ? Conjunction.intr thm
+
+                  val [head_thm, body_thm] =
+                    Drule.zero_var_indexes_list (map label_thm [param_thm, param_thm'])
+                    |> map Method_Closure.tag_free_thm;
 
-                  (*TODO: Preprocess attributes here?*)
-
-                  val (_, ctxt'') = Proof_Context.note_thmss "" [((b, []), [([thm], [])])] ctxt';
+                  val ctxt''' =
+                    Attrib.local_notes "" [((b, []), [([body_thm], [])])] ctxt''
+                    |> snd
+                    |> Variable.declare_maxidx (Thm.maxidx_of head_thm);
                 in
-                  (SOME (Thm.prop_of thm, map (Attrib.attribute ctxt) att) :: tms, ctxt'')
+                  (SOME (Thm.prop_of head_thm, att) :: tms, ctxt''')
                 end
             | upd_ctxt NONE _ (tms, ctxt) = (NONE :: tms, ctxt);
 
@@ -196,7 +202,20 @@
           val pats' = map (Term.map_types Type_Infer.paramify_vars #> Morphism.term morphism) pats;
           val _ = ListPair.app (fn ((_, (Parse_Tools.Parse_Val (_, f), _)), t) => f t) (ts, pats');
 
-          val binds' = map (Option.map (fn (t, atts) => (Morphism.term morphism t, atts))) binds;
+          fun close_src src =
+            let
+              val src' = Token.closure_src src |> Token.transform_src morphism;
+              val _ =
+                map2 (fn tok1 => fn tok2 =>
+                  (case (Token.get_value tok2) of
+                    SOME value => Token.assign (SOME value) tok1
+                  | NONE => ()))
+                  (Token.args_of_src src)
+                  (Token.args_of_src src');
+            in src' end;
+
+          val binds' =
+            map (Option.map (fn (t, atts) => (Morphism.term morphism t, map close_src atts))) binds;
 
           val _ =
             ListPair.app
@@ -207,7 +226,8 @@
 
           val real_fixes' = map (Morphism.term morphism) real_fixes;
           val _ =
-            ListPair.app (fn (( (Parse_Tools.Parse_Val (_, f),_) , _), t) => f t) (fixes, real_fixes');
+            ListPair.app (fn (( (Parse_Tools.Parse_Val (_, f), _) , _), t) => f t)
+              (fixes, real_fixes');
 
           val match_args = map (fn (_, (_, match_args)) => match_args) ts;
           val binds'' = (binds' ~~ match_args) ~~ pats';
@@ -247,36 +267,30 @@
         (fn ((((SOME ((_, head), att), _), _), _), thms) => SOME (head, (thms, att))
           | _ => NONE) fact_insts';
 
-    fun apply_attribute thm att ctxt =
-      let
-        val (opt_context', thm') = att (Context.Proof ctxt, thm)
-      in
-        (case thm' of
-          SOME _ => error "Rule attributes cannot be applied here"
-        | _ => the_default ctxt (Option.map Context.proof_of opt_context'))
-      end;
-
-    fun apply_attributes atts thm = fold (apply_attribute thm) atts;
-
-     (*TODO: What to do about attributes that raise errors?*)
-    val (fact_insts, ctxt') =
-      fold_map (fn (head, (thms, atts : attribute list)) => fn ctxt =>
-        ((head, thms), fold (apply_attributes atts) thms ctxt)) fact_insts ctxt;
-
     fun try_dest_term thm = try (Thm.prop_of #> dest_internal_fact #> snd) thm;
 
-    fun expand_fact thm =
+    fun expand_fact fact_insts thm =
       the_default [thm]
         (case try_dest_term thm of
           SOME t_ident => AList.lookup (op aconv) fact_insts t_ident
         | NONE => NONE);
 
-    val morphism =
+    fun fact_morphism fact_insts =
       Morphism.term_morphism "do_inst.term" (Envir.norm_term env) $>
       Morphism.typ_morphism "do_inst.type" (Envir.norm_type (Envir.type_env env)) $>
-      Morphism.fact_morphism "do_inst.fact" (maps expand_fact);
+      Morphism.fact_morphism "do_inst.fact" (maps (expand_fact fact_insts));
 
-    val text' = Method.map_source (Token.transform_src morphism) text;
+    fun apply_attribute (head, (fact, atts)) (fact_insts, ctxt) =
+      let
+        val morphism = fact_morphism fact_insts;
+        val atts' = map (Attrib.attribute ctxt o Token.transform_src morphism) atts;
+        val (fact'', ctxt') = fold_map (Thm.proof_attributes atts') fact ctxt;
+      in ((head, fact'') :: fact_insts, ctxt') end;
+
+     (*TODO: What to do about attributes that raise errors?*)
+    val (fact_insts', ctxt') = fold_rev (apply_attribute) fact_insts ([], ctxt);
+
+    val text' = Method.map_source (Token.transform_src (fact_morphism fact_insts')) text;
   in
     (text', ctxt')
   end;
@@ -304,27 +318,27 @@
   let
     val tenv = Envir.term_env env;
     val tyenv = Envir.type_env env;
-    val max_tidx = Vartab.fold (fn (_,(_,t)) => curry Int.max (maxidx_of_term t)) tenv ~1;
-    val max_Tidx = Vartab.fold (fn (_,(_,T)) => curry Int.max (maxidx_of_typ T)) tyenv ~1;
+    val max_tidx = Vartab.fold (fn (_, (_, t)) => curry Int.max (maxidx_of_term t)) tenv ~1;
+    val max_Tidx = Vartab.fold (fn (_, (_, T)) => curry Int.max (maxidx_of_typ T)) tyenv ~1;
   in
     Envir.Envir
-      {maxidx = Int.max (Int.max (max_tidx,max_Tidx),Envir.maxidx_of env),
+      {maxidx = Int.max (Int.max (max_tidx, max_Tidx), Envir.maxidx_of env),
         tenv = tenv, tyenv = tyenv}
   end
 
 fun match_filter_env inner_ctxt morphism pat_vars fixes (ts, params) thm inner_env =
   let
     val tenv = Envir.term_env inner_env
-      |> Vartab.map (K (fn (T,t) => (Morphism.typ morphism T,Morphism.term morphism t)));
+      |> Vartab.map (K (fn (T, t) => (Morphism.typ morphism T, Morphism.term morphism t)));
 
     val tyenv = Envir.type_env inner_env
-      |> Vartab.map (K (fn (S,T) => (S,Morphism.typ morphism T)));
+      |> Vartab.map (K (fn (S, T) => (S, Morphism.typ morphism T)));
 
     val outer_env = Envir.Envir {maxidx = Envir.maxidx_of inner_env, tenv = tenv, tyenv = tyenv};
 
     val param_vars = map Term.dest_Var params;
 
-    val params' = map (fn (xi,_) => Vartab.lookup (Envir.term_env outer_env) xi) param_vars;
+    val params' = map (fn (xi, _) => Vartab.lookup (Envir.term_env outer_env) xi) param_vars;
 
     val fixes_vars = map Term.dest_Var fixes;
 
@@ -339,11 +353,11 @@
       Envir.Envir {maxidx = Envir.maxidx_of outer_env, tenv = tenv', tyenv = tyenv}
       |> recalculate_maxidx;
 
-    val all_params_bound = forall (fn SOME (_,(Var _)) => true | _ => false) params';
+    val all_params_bound = forall (fn SOME (_, Var _) => true | _ => false) params';
 
     val pat_fixes = inter (eq_fst (op =)) fixes_vars pat_vars;
 
-    val all_pat_fixes_bound = forall (fn (xi,_) => is_some (Vartab.lookup tenv' xi)) pat_fixes;
+    val all_pat_fixes_bound = forall (fn (xi, _) => is_some (Vartab.lookup tenv' xi)) pat_fixes;
 
     val thm' = Morphism.thm morphism thm;
 
@@ -469,7 +483,7 @@
       SOME (x, seq') =>
         if member eq prev x
         then Seq.pull (deduplicate eq prev seq')
-        else SOME (x,deduplicate eq (x :: prev) seq')
+        else SOME (x, deduplicate eq (x :: prev) seq')
     | NONE => NONE));
 
 fun consistent_env env =
@@ -480,7 +494,7 @@
     forall (fn (_, (T, t)) => Envir.norm_type tyenv T = fastype_of t) (Vartab.dest tenv)
   end;
 
-fun eq_env (env1,env2) =
+fun eq_env (env1, env2) =
   let
     val tyenv1 = Envir.type_env env1;
     val tyenv2 = Envir.type_env env2;
@@ -492,7 +506,7 @@
           Envir.eta_contract (Envir.norm_term env2 t')))
       (apply2 Vartab.dest (Envir.term_env env1, Envir.term_env env2))
     andalso
-    ListPair.allEq (fn ((var, (_, T)), (var', (_,T'))) =>
+    ListPair.allEq (fn ((var, (_, T)), (var', (_, T'))) =>
         var = var' andalso Envir.norm_type tyenv1 T = Envir.norm_type tyenv2 T')
       (apply2 Vartab.dest (Envir.type_env env1, Envir.type_env env2))
   end;
@@ -505,15 +519,13 @@
 
     fun match_thm (((x, params), pat), thm) env  =
       let
-        fun try_dest_term term = the_default term (try Logic.dest_term term);
-
         val pat_vars = Term.add_vars pat [];
 
-        val pat' = pat |> Envir.norm_term env |> try_dest_term;
+        val pat' = pat |> Envir.norm_term env;
 
-        val (((Tinsts', insts),[thm']), inner_ctxt) = Variable.import false [thm] ctxt
+        val (((Tinsts', insts), [thm']), inner_ctxt) = Variable.import false [thm] ctxt;
 
-        val item' = Thm.prop_of thm' |> try_dest_term;
+        val item' = Thm.prop_of thm';
 
         val ts = Option.map (fst o fst) (fst x);
 
@@ -572,7 +584,6 @@
         |> Seq.map (fn (fact_insts, env) => do_inst fact_insts env text ctxt')
       end;
 
-    (*TODO: Slightly hacky re-use of fact match implementation in plain term matching *)
     fun make_term_matches ctxt get =
       let
         val pats' =