src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML
changeset 33126 bb8806eb5da7
parent 33124 5378e61add1a
child 33127 eb91ec1ef6f0
--- a/src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML	Sat Oct 24 16:55:42 2009 +0200
+++ b/src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML	Sat Oct 24 16:55:42 2009 +0200
@@ -256,7 +256,24 @@
   (if null param_modes then "" else
     "; " ^ "params: " ^ commas (map (the_default "NONE" o Option.map string_of_tmode) param_modes))
 
-(* generation of case rules from user-given introduction rules *)
+fun unify_consts thy cs intr_ts =
+  (let
+     val add_term_consts_2 = fold_aterms (fn Const c => insert (op =) c | _ => I);
+     fun varify (t, (i, ts)) =
+       let val t' = map_types (Logic.incr_tvar (i + 1)) (#2 (Type.varify [] t))
+       in (maxidx_of_term t', t'::ts) end;
+     val (i, cs') = foldr varify (~1, []) cs;
+     val (i', intr_ts') = foldr varify (i, []) intr_ts;
+     val rec_consts = fold add_term_consts_2 cs' [];
+     val intr_consts = fold add_term_consts_2 intr_ts' [];
+     fun unify (cname, cT) =
+       let val consts = map snd (List.filter (fn c => fst c = cname) intr_consts)
+       in fold (Sign.typ_unify thy) ((replicate (length consts) cT) ~~ consts) end;
+     val (env, _) = fold unify rec_consts (Vartab.empty, i');
+     val subst = map_types (Envir.norm_type env)
+   in (map subst cs', map subst intr_ts')
+   end) handle Type.TUNIFY =>
+     (warning "Occurrences of recursive constant have non-unifiable types"; (cs, intr_ts));
 
 (* how to detect polymorphic type dependencies in mutual recursive inductive predicates? *)
 fun import_intros [] ctxt = ([], ctxt)
@@ -289,9 +306,12 @@
       (th' :: ths', ctxt1)
     end
 
+
+(* generation of case rules from user-given introduction rules *)
+
 fun mk_casesrule ctxt nparams introrules =
   let
-    val ((_, intros_th), ctxt1) = Variable.import false introrules ctxt
+    val (intros_th, ctxt1) = import_intros introrules ctxt
     val intros = map prop_of intros_th
     val (pred, (params, args)) = strip_intro_concl nparams (hd intros)
     val ([propname], ctxt2) = Variable.variant_fixes ["thesis"] ctxt1
@@ -2178,11 +2198,13 @@
     
 fun prepare_intrs thy prednames intros =
   let
-    val ((_, intrs), _) = Variable.import false intros (ProofContext.init thy)
-    val intrs = map prop_of intrs
+    val intrs = map prop_of intros
     val nparams = nparams_of thy (hd prednames)
+    val preds = distinct (fn ((c1, _), (c2, _)) => c1 = c2) (map (dest_Const o fst o (strip_intro_concl nparams)) intrs)
+    val (preds, intrs) = unify_consts thy (map Const preds) intrs
+    val ([preds, intrs], _) = fold_burrow (Variable.import_terms false) [preds, intrs] (ProofContext.init thy)
+    val preds = map dest_Const preds
     val extra_modes = all_modes_of thy |> filter_out (fn (name, _) => member (op =) prednames name)
-    val preds = distinct (op =) (map (dest_Const o fst o (strip_intro_concl nparams)) intrs)
     val _ $ u = Logic.strip_imp_concl (hd intrs);
     val params = List.take (snd (strip_comb u), nparams);
     val param_vs = maps term_vs params