src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML
changeset 33116 b379ee2cddb1
parent 33115 f765c3234059
child 33117 1413c62db675
--- a/src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML	Sat Oct 24 16:55:42 2009 +0200
+++ b/src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML	Sat Oct 24 16:55:42 2009 +0200
@@ -258,9 +258,30 @@
 
 (* generation of case rules from user-given introduction rules *)
 
+fun import_intros [] ctxt = ([], ctxt)
+  | import_intros (th :: ths) ctxt =
+    let
+      val ((_, [th']), ctxt') = Variable.import false [th] ctxt
+      val (pred, _) = strip_intro_concl 0 (prop_of th')
+      fun instantiate_typ th =
+        let
+          val (pred', _) = strip_intro_concl 0 (prop_of th)
+          val subst = Sign.typ_match (ProofContext.theory_of ctxt')
+            (fastype_of pred', fastype_of pred) Vartab.empty
+          val _ = Output.tracing (commas (map (fn ((x, i), (s, T)) => x ^ " instantiate to " ^ (Syntax.string_of_typ ctxt' T))
+          (Vartab.dest subst)))
+          val subst' = map (fn (indexname, (s, T)) => ((indexname, s), T))
+            (Vartab.dest subst)
+        in Thm.certify_instantiate (subst', []) th end;
+      val ((_, ths'), ctxt1) =
+        Variable.import false (map instantiate_typ ths) ctxt'
+    in
+      (th' :: ths', ctxt1)
+    end
+
 fun mk_casesrule ctxt nparams introrules =
   let
-    val ((_, intros_th), ctxt1) = Variable.import false introrules ctxt
+    val (intros_th, ctxt1) = import_intros introrules ctxt
     val intros = map prop_of intros_th
     val (pred, (params, args)) = strip_intro_concl nparams (hd intros)
     val ([propname], ctxt2) = Variable.variant_fixes ["thesis"] ctxt1
@@ -633,7 +654,7 @@
    fun cons_intro gr =
      case try (Graph.get_node gr) name of
        SOME pred_data => Graph.map_node name (map_pred_data
-         (apfst (fn (intro, elim, nparams) => (thm::intro, elim, nparams)))) gr
+         (apfst (fn (intros, elim, nparams) => (thm::intros, elim, nparams)))) gr
      | NONE =>
        let
          val nparams = the_default (guess_nparams T)  (try (#nparams o rep_pred_data o (fetch_pred_data thy)) name)
@@ -965,7 +986,7 @@
           val prfx = map (rpair NONE) (1 upto k)
         in
           if not (is_prefix op = prfx is) then [] else
-          let val is' = List.drop (is, k)
+          let val is' = map (fn (i, t) => (i - k, t)) (List.drop (is, k))
           in map (fn x => Mode (m, is', x)) (cprods (map
             (fn (NONE, _) => [NONE]
               | (SOME js, arg) => map SOME (filter
@@ -2143,7 +2164,7 @@
     
 fun prepare_intrs thy prednames intros =
   let
-    val ((_, intrs), _) = Variable.import false intros (ProofContext.init thy)
+    val (intrs, _) =  import_intros intros (ProofContext.init thy)
     val intrs = map prop_of intrs
     val nparams = nparams_of thy (hd prednames)
     val extra_modes = all_modes_of thy |> filter_out (fn (name, _) => member (op =) prednames name)