src/HOL/Tools/Predicate_Compile/predicate_compile_aux.ML
changeset 40101 f7fc517e21c6
parent 40054 cd7b1fa20bce
child 40139 6a53d57fa902
--- a/src/HOL/Tools/Predicate_Compile/predicate_compile_aux.ML	Fri Oct 22 18:38:59 2010 +0200
+++ b/src/HOL/Tools/Predicate_Compile/predicate_compile_aux.ML	Fri Oct 22 18:38:59 2010 +0200
@@ -53,6 +53,7 @@
   val is_constr : Proof.context -> string -> bool
   val focus_ex : term -> Name.context -> ((string * typ) list * term) * Name.context
   val strip_all : term -> (string * typ) list * term
+  val strip_intro_concl : thm -> term * term list
   (* introduction rule combinators *)
   val map_atoms : (term -> term) -> term -> term
   val fold_atoms : (term -> 'a -> 'a) -> term -> 'a -> 'a
@@ -157,6 +158,11 @@
   val remove_equalities : theory -> thm -> thm
   val remove_pointless_clauses : thm -> thm list
   val peephole_optimisation : theory -> thm -> thm option
+  (* auxillary *)
+  val unify_consts : theory -> term list -> term list -> (term list * term list)
+  val mk_casesrule : Proof.context -> term -> thm list -> term
+  val preprocess_intro : theory -> thm -> thm
+  
   val define_quickcheck_predicate :
     term -> theory -> (((string * typ) * (string * typ) list) * thm) * theory
 end;
@@ -546,6 +552,8 @@
     val t'' = Term.subst_bounds (rev vs, t');
   in ((ps', t''), nctxt') end;
 
+val strip_intro_concl = (strip_comb o HOLogic.dest_Trueprop o Logic.strip_imp_concl o prop_of)
+  
 (* introduction rule combinators *)
 
 fun map_atoms f intro = 
@@ -1048,6 +1056,144 @@
       (process_False (process_True (prop_of (process intro))))
   end
 
+
+(* importing introduction rules *)
+
+fun import_intros inp_pred [] ctxt =
+  let
+    val ([outp_pred], ctxt') = Variable.import_terms true [inp_pred] ctxt
+    val T = fastype_of outp_pred
+    val paramTs = ho_argsT_of_typ (binder_types T)
+    val (param_names, ctxt'') = Variable.variant_fixes
+      (map (fn i => "p" ^ (string_of_int i)) (1 upto (length paramTs))) ctxt'
+    val params = map2 (curry Free) param_names paramTs
+  in
+    (((outp_pred, params), []), ctxt')
+  end
+  | import_intros inp_pred (th :: ths) ctxt =
+    let
+      val ((_, [th']), ctxt') = Variable.import true [th] ctxt
+      val thy = ProofContext.theory_of ctxt'
+      val (pred, args) = strip_intro_concl th'
+      val T = fastype_of pred
+      val ho_args = ho_args_of_typ T args
+      fun subst_of (pred', pred) =
+        let
+          val subst = Sign.typ_match thy (fastype_of pred', fastype_of pred) Vartab.empty
+            handle Type.TYPE_MATCH => error ("Type mismatch of predicate " ^ fst (dest_Const pred)
+            ^ " (trying to match " ^ Syntax.string_of_typ ctxt (fastype_of pred')
+            ^ " and " ^ Syntax.string_of_typ ctxt (fastype_of pred) ^ ")"
+            ^ " in " ^ Display.string_of_thm ctxt th)
+        in map (fn (indexname, (s, T)) => ((indexname, s), T)) (Vartab.dest subst) end
+      fun instantiate_typ th =
+        let
+          val (pred', _) = strip_intro_concl th
+          val _ = if not (fst (dest_Const pred) = fst (dest_Const pred')) then
+            raise Fail "Trying to instantiate another predicate" else ()
+        in Thm.certify_instantiate (subst_of (pred', pred), []) th end;
+      fun instantiate_ho_args th =
+        let
+          val (_, args') = (strip_comb o HOLogic.dest_Trueprop o Logic.strip_imp_concl o prop_of) th
+          val ho_args' = map dest_Var (ho_args_of_typ T args')
+        in Thm.certify_instantiate ([], ho_args' ~~ ho_args) th end
+      val outp_pred =
+        Term_Subst.instantiate (subst_of (inp_pred, pred), []) inp_pred
+      val ((_, ths'), ctxt1) =
+        Variable.import false (map (instantiate_typ #> instantiate_ho_args) ths) ctxt'
+    in
+      (((outp_pred, ho_args), th' :: ths'), ctxt1)
+    end
+  
+(* generation of case rules from user-given introduction rules *)
+
+fun mk_args2 (Type (@{type_name Product_Type.prod}, [T1, T2])) st =
+    let
+      val (t1, st') = mk_args2 T1 st
+      val (t2, st'') = mk_args2 T2 st'
+    in
+      (HOLogic.mk_prod (t1, t2), st'')
+    end
+  (*| mk_args2 (T as Type ("fun", _)) (params, ctxt) = 
+    let
+      val (S, U) = strip_type T
+    in
+      if U = HOLogic.boolT then
+        (hd params, (tl params, ctxt))
+      else
+        let
+          val ([x], ctxt') = Variable.variant_fixes ["x"] ctxt
+        in
+          (Free (x, T), (params, ctxt'))
+        end
+    end*)
+  | mk_args2 T (params, ctxt) =
+    let
+      val ([x], ctxt') = Variable.variant_fixes ["x"] ctxt
+    in
+      (Free (x, T), (params, ctxt'))
+    end
+
+fun mk_casesrule ctxt pred introrules =
+  let
+    (* TODO: can be simplified if parameters are not treated specially ? *)
+    val (((pred, params), intros_th), ctxt1) = import_intros pred introrules ctxt
+    (* TODO: distinct required ? -- test case with more than one parameter! *)
+    val params = distinct (op aconv) params
+    val intros = map prop_of intros_th
+    val ([propname], ctxt2) = Variable.variant_fixes ["thesis"] ctxt1
+    val prop = HOLogic.mk_Trueprop (Free (propname, HOLogic.boolT))
+    val argsT = binder_types (fastype_of pred)
+    (* TODO: can be simplified if parameters are not treated specially ? <-- see uncommented code! *)
+    val (argvs, _) = fold_map mk_args2 argsT (params, ctxt2)
+    fun mk_case intro =
+      let
+        val (_, args) = (strip_comb o HOLogic.dest_Trueprop o Logic.strip_imp_concl) intro
+        val prems = Logic.strip_imp_prems intro
+        val eqprems =
+          map2 (HOLogic.mk_Trueprop oo (curry HOLogic.mk_eq)) argvs args
+        val frees = map Free (fold Term.add_frees (args @ prems) [])
+      in fold Logic.all frees (Logic.list_implies (eqprems @ prems, prop)) end
+    val assm = HOLogic.mk_Trueprop (list_comb (pred, argvs))
+    val cases = map mk_case intros
+  in Logic.list_implies (assm :: cases, prop) end;
+  
+
+(* unifying constants to have the same type variables *)
+
+fun unify_consts thy cs intr_ts =
+  (let
+     val add_term_consts_2 = fold_aterms (fn Const c => insert (op =) c | _ => I);
+     fun varify (t, (i, ts)) =
+       let val t' = map_types (Logic.incr_tvar (i + 1)) (#2 (Type.varify_global [] t))
+       in (maxidx_of_term t', t'::ts) end;
+     val (i, cs') = List.foldr varify (~1, []) cs;
+     val (i', intr_ts') = List.foldr varify (i, []) intr_ts;
+     val rec_consts = fold add_term_consts_2 cs' [];
+     val intr_consts = fold add_term_consts_2 intr_ts' [];
+     fun unify (cname, cT) =
+       let val consts = map snd (filter (fn c => fst c = cname) intr_consts)
+       in fold (Sign.typ_unify thy) ((replicate (length consts) cT) ~~ consts) end;
+     val (env, _) = fold unify rec_consts (Vartab.empty, i');
+     val subst = map_types (Envir.norm_type env)
+   in (map subst cs', map subst intr_ts')
+   end) handle Type.TUNIFY =>
+     (warning "Occurrences of recursive constant have non-unifiable types"; (cs, intr_ts));
+
+(* preprocessing rules *)
+
+fun Trueprop_conv cv ct =
+  case Thm.term_of ct of
+    Const (@{const_name Trueprop}, _) $ _ => Conv.arg_conv cv ct  
+  | _ => raise Fail "Trueprop_conv"
+
+fun preprocess_equality thy rule =
+  Conv.fconv_rule
+    (imp_prems_conv
+      (Trueprop_conv (Conv.try_conv (Conv.rewr_conv (Thm.symmetric @{thm Predicate.eq_is_eq})))))
+    (Thm.transfer thy rule)
+
+fun preprocess_intro thy = expand_tuples thy #> preprocess_equality thy
+
 (* defining a quickcheck predicate *)
 
 fun strip_imp_prems (Const(@{const_name HOL.implies}, _) $ A $ B) = A :: strip_imp_prems B