src/HOL/Tools/Quotient/quotient_term.ML
changeset 35990 5ceedb86aa9d
parent 35843 23908b4dbc2f
child 36692 54b64d4ad524
--- a/src/HOL/Tools/Quotient/quotient_term.ML	Sat Mar 27 02:10:00 2010 +0100
+++ b/src/HOL/Tools/Quotient/quotient_term.ML	Sat Mar 27 14:48:46 2010 +0100
@@ -26,8 +26,8 @@
   val inj_repabs_trm: Proof.context -> term * term -> term
   val inj_repabs_trm_chk: Proof.context -> term * term -> term
 
-  val quotient_lift_const: string * term -> local_theory -> term
-  val quotient_lift_all: Proof.context -> term -> term
+  val quotient_lift_const: typ list -> string * term -> local_theory -> term
+  val quotient_lift_all: typ list -> Proof.context -> term -> term
 end;
 
 structure Quotient_Term: QUOTIENT_TERM =
@@ -720,23 +720,28 @@
 (* prepares type and term substitution pairs to be used by above
    functions that let replace all raw constructs by appropriate
    lifted counterparts. *)
-fun get_ty_trm_substs ctxt =
+fun get_ty_trm_substs qtys ctxt =
 let
   val thy = ProofContext.theory_of ctxt
   val quot_infos  = Quotient_Info.quotdata_dest ctxt
   val const_infos = Quotient_Info.qconsts_dest ctxt
-  val ty_substs = map (fn ri => (#rtyp ri, #qtyp ri)) quot_infos
+  val all_ty_substs = map (fn ri => (#rtyp ri, #qtyp ri)) quot_infos
+  val ty_substs =
+    if qtys = [] then all_ty_substs else
+    filter (fn (_, qty) => qty mem qtys) all_ty_substs
   val const_substs = map (fn ci => (#rconst ci, #qconst ci)) const_infos
   fun rel_eq rel = HOLogic.eq_const (subst_tys thy ty_substs (domain_type (fastype_of rel)))
   val rel_substs = map (fn ri => (#equiv_rel ri, rel_eq (#equiv_rel ri))) quot_infos
+  fun valid_trm_subst (rt, qt) = (subst_tys thy ty_substs (fastype_of rt) = fastype_of qt)
+  val all_trm_substs = const_substs @ rel_substs
 in
-  (ty_substs, const_substs @ rel_substs)
+  (ty_substs, filter valid_trm_subst all_trm_substs)
 end
 
-fun quotient_lift_const (b, t) ctxt =
+fun quotient_lift_const qtys (b, t) ctxt =
 let
   val thy = ProofContext.theory_of ctxt
-  val (ty_substs, _) = get_ty_trm_substs ctxt;
+  val (ty_substs, _) = get_ty_trm_substs qtys ctxt;
   val (_, ty) = dest_Const t;
   val nty = subst_tys thy ty_substs ty;
 in
@@ -754,10 +759,10 @@
 
 *)
 
-fun quotient_lift_all ctxt t =
+fun quotient_lift_all qtys ctxt t =
 let
   val thy = ProofContext.theory_of ctxt
-  val (ty_substs, substs) = get_ty_trm_substs ctxt
+  val (ty_substs, substs) = get_ty_trm_substs qtys ctxt
   fun lift_aux t =
     case subst_trms thy substs t of
       SOME x => x