src/HOL/Tools/Predicate_Compile/predicate_compile_aux.ML
changeset 46662 4e258158be38
parent 46638 fc315796794e
child 48221 e0ed7fab0d09
equal deleted inserted replaced
46661:d2ac78ba805e 46662:4e258158be38
   184 fun map2_optional f (x :: xs) (y :: ys) = f x (SOME y) :: (map2_optional f xs ys)
   184 fun map2_optional f (x :: xs) (y :: ys) = f x (SOME y) :: (map2_optional f xs ys)
   185   | map2_optional f (x :: xs) [] = (f x NONE) :: (map2_optional f xs [])
   185   | map2_optional f (x :: xs) [] = (f x NONE) :: (map2_optional f xs [])
   186   | map2_optional f [] [] = []
   186   | map2_optional f [] [] = []
   187 
   187 
   188 fun find_indices f xs =
   188 fun find_indices f xs =
   189   map_filter (fn (i, true) => SOME i | (i, false) => NONE) (map_index (apsnd f) xs)
   189   map_filter (fn (i, true) => SOME i | (_, false) => NONE) (map_index (apsnd f) xs)
   190 
   190 
   191 (* mode *)
   191 (* mode *)
   192 
   192 
   193 datatype mode = Bool | Input | Output | Pair of mode * mode | Fun of mode * mode
   193 datatype mode = Bool | Input | Output | Pair of mode * mode | Fun of mode * mode
   194 
   194 
   251           (map all_modes_of_typ' S) [Bool]
   251           (map all_modes_of_typ' S) [Bool]
   252       else
   252       else
   253         raise Fail "Invocation of all_modes_of_typ with a non-predicate type"
   253         raise Fail "Invocation of all_modes_of_typ with a non-predicate type"
   254     end
   254     end
   255   | all_modes_of_typ @{typ bool} = [Bool]
   255   | all_modes_of_typ @{typ bool} = [Bool]
   256   | all_modes_of_typ T =
   256   | all_modes_of_typ _ =
   257     raise Fail "Invocation of all_modes_of_typ with a non-predicate type"
   257     raise Fail "Invocation of all_modes_of_typ with a non-predicate type"
   258 
   258 
   259 fun all_smodes_of_typ (T as Type ("fun", _)) =
   259 fun all_smodes_of_typ (T as Type ("fun", _)) =
   260   let
   260   let
   261     val (S, U) = strip_type T
   261     val (S, U) = strip_type T
   392   comb_option HOLogic.mk_prod (map_filter_prod f t1, map_filter_prod f t2)
   392   comb_option HOLogic.mk_prod (map_filter_prod f t1, map_filter_prod f t2)
   393   | map_filter_prod f t = f t
   393   | map_filter_prod f t = f t
   394   
   394   
   395 fun split_modeT mode Ts =
   395 fun split_modeT mode Ts =
   396   let
   396   let
   397     fun split_arg_mode (Fun _) T = ([], [])
   397     fun split_arg_mode (Fun _) _ = ([], [])
   398       | split_arg_mode (Pair (m1, m2)) (Type (@{type_name Product_Type.prod}, [T1, T2])) =
   398       | split_arg_mode (Pair (m1, m2)) (Type (@{type_name Product_Type.prod}, [T1, T2])) =
   399         let
   399         let
   400           val (i1, o1) = split_arg_mode m1 T1
   400           val (i1, o1) = split_arg_mode m1 T1
   401           val (i2, o2) = split_arg_mode m2 T2
   401           val (i2, o2) = split_arg_mode m2 T2
   402         in
   402         in
   479     Const (c, _) => c = constname
   479     Const (c, _) => c = constname
   480   | _ => false) t)
   480   | _ => false) t)
   481   
   481   
   482 fun is_intro constname t = is_intro_term constname (prop_of t)
   482 fun is_intro constname t = is_intro_term constname (prop_of t)
   483 
   483 
   484 fun is_pred thy constname = (body_type (Sign.the_const_type thy constname) = HOLogic.boolT);
       
   485 
       
   486 fun is_predT (T as Type("fun", [_, _])) = (body_type T = @{typ bool})
   484 fun is_predT (T as Type("fun", [_, _])) = (body_type T = @{typ bool})
   487   | is_predT _ = false
   485   | is_predT _ = false
   488 
   486 
   489 (*** check if a term contains only constructor functions ***)
   487 (*** check if a term contains only constructor functions ***)
   490 (* TODO: another copy in the core! *)
   488 (* TODO: another copy in the core! *)
   501       | (Const (s, T), ts) => (case (AList.lookup (op =) cnstrs s, body_type T) of
   499       | (Const (s, T), ts) => (case (AList.lookup (op =) cnstrs s, body_type T) of
   502             (SOME (i, Tname), Type (Tname', _)) => length ts = i andalso Tname = Tname' andalso forall check ts
   500             (SOME (i, Tname), Type (Tname', _)) => length ts = i andalso Tname = Tname' andalso forall check ts
   503           | _ => false)
   501           | _ => false)
   504       | _ => false)
   502       | _ => false)
   505   in check end;
   503   in check end;
   506 
       
   507 fun is_funtype (Type ("fun", [_, _])) = true
       
   508   | is_funtype _ = false;
       
   509 
       
   510 fun is_Type (Type _) = true
       
   511   | is_Type _ = false
       
   512 
   504 
   513 (* returns true if t is an application of an datatype constructor *)
   505 (* returns true if t is an application of an datatype constructor *)
   514 (* which then consequently would be splitted *)
   506 (* which then consequently would be splitted *)
   515 (* else false *)
   507 (* else false *)
   516 (*
   508 (*
   563       (map (HOLogic.mk_Trueprop o appl o HOLogic.dest_Trueprop) literals, head)
   555       (map (HOLogic.mk_Trueprop o appl o HOLogic.dest_Trueprop) literals, head)
   564   end
   556   end
   565 
   557 
   566 fun fold_atoms f intro s =
   558 fun fold_atoms f intro s =
   567   let
   559   let
   568     val (literals, head) = Logic.strip_horn intro
   560     val (literals, _) = Logic.strip_horn intro
   569     fun appl t s = (case t of
   561     fun appl t s = (case t of
   570       (@{term Not} $ t') => f t' s
   562       (@{term Not} $ t') => f t' s
   571       | _ => f t s)
   563       | _ => f t s)
   572   in fold appl (map HOLogic.dest_Trueprop literals) s end
   564   in fold appl (map HOLogic.dest_Trueprop literals) s end
   573 
   565 
   580     val (literals', s') = fold_map appl (map HOLogic.dest_Trueprop literals) s
   572     val (literals', s') = fold_map appl (map HOLogic.dest_Trueprop literals) s
   581   in
   573   in
   582     (Logic.list_implies (map HOLogic.mk_Trueprop literals', head), s')
   574     (Logic.list_implies (map HOLogic.mk_Trueprop literals', head), s')
   583   end;
   575   end;
   584 
   576 
   585 fun map_premises f intro =
       
   586   let
       
   587     val (premises, head) = Logic.strip_horn intro
       
   588   in
       
   589     Logic.list_implies (map f premises, head)
       
   590   end
       
   591 
       
   592 fun map_filter_premises f intro =
   577 fun map_filter_premises f intro =
   593   let
   578   let
   594     val (premises, head) = Logic.strip_horn intro
   579     val (premises, head) = Logic.strip_horn intro
   595   in
   580   in
   596     Logic.list_implies (map_filter f premises, head)
   581     Logic.list_implies (map_filter f premises, head)
   621 fun prepare_split_thm ctxt split_thm =
   606 fun prepare_split_thm ctxt split_thm =
   622     (split_thm RS @{thm iffD2})
   607     (split_thm RS @{thm iffD2})
   623     |> Local_Defs.unfold ctxt [@{thm atomize_conjL[symmetric]},
   608     |> Local_Defs.unfold ctxt [@{thm atomize_conjL[symmetric]},
   624       @{thm atomize_all[symmetric]}, @{thm atomize_imp[symmetric]}]
   609       @{thm atomize_all[symmetric]}, @{thm atomize_imp[symmetric]}]
   625 
   610 
   626 fun find_split_thm thy (Const (name, T)) = Option.map #split (Datatype.info_of_case thy name)
   611 fun find_split_thm thy (Const (name, _)) = Option.map #split (Datatype.info_of_case thy name)
   627   | find_split_thm thy _ = NONE
   612   | find_split_thm thy _ = NONE
   628 
   613 
   629 (* lifting term operations to theorems *)
   614 (* lifting term operations to theorems *)
   630 
   615 
   631 fun map_term thy f th =
   616 fun map_term thy f th =
   824 (** tuple processing **)
   809 (** tuple processing **)
   825 
   810 
   826 fun rewrite_args [] (pats, intro_t, ctxt) = (pats, intro_t, ctxt)
   811 fun rewrite_args [] (pats, intro_t, ctxt) = (pats, intro_t, ctxt)
   827   | rewrite_args (arg::args) (pats, intro_t, ctxt) = 
   812   | rewrite_args (arg::args) (pats, intro_t, ctxt) = 
   828     (case HOLogic.strip_tupleT (fastype_of arg) of
   813     (case HOLogic.strip_tupleT (fastype_of arg) of
   829       (Ts as _ :: _ :: _) =>
   814       (_ :: _ :: _) =>
   830       let
   815       let
   831         fun rewrite_arg' (Const (@{const_name Pair}, _) $ _ $ t2, Type (@{type_name Product_Type.prod}, [_, T2]))
   816         fun rewrite_arg' (Const (@{const_name Pair}, _) $ _ $ t2, Type (@{type_name Product_Type.prod}, [_, T2]))
   832           (args, (pats, intro_t, ctxt)) = rewrite_arg' (t2, T2) (args, (pats, intro_t, ctxt))
   817           (args, (pats, intro_t, ctxt)) = rewrite_arg' (t2, T2) (args, (pats, intro_t, ctxt))
   833           | rewrite_arg' (t, Type (@{type_name Product_Type.prod}, [T1, T2])) (args, (pats, intro_t, ctxt)) =
   818           | rewrite_arg' (t, Type (@{type_name Product_Type.prod}, [T1, T2])) (args, (pats, intro_t, ctxt)) =
   834             let
   819             let
   866     singleton (Variable.export ctxt' ctxt) (split_conjs 1 (Thm.nprems_of fixed_th) fixed_th)
   851     singleton (Variable.export ctxt' ctxt) (split_conjs 1 (Thm.nprems_of fixed_th) fixed_th)
   867   end
   852   end
   868 
   853 
   869 fun dest_conjunct_prem th =
   854 fun dest_conjunct_prem th =
   870   case HOLogic.dest_Trueprop (prop_of th) of
   855   case HOLogic.dest_Trueprop (prop_of th) of
   871     (Const (@{const_name HOL.conj}, _) $ t $ t') =>
   856     (Const (@{const_name HOL.conj}, _) $ _ $ _) =>
   872       dest_conjunct_prem (th RS @{thm conjunct1})
   857       dest_conjunct_prem (th RS @{thm conjunct1})
   873         @ dest_conjunct_prem (th RS @{thm conjunct2})
   858         @ dest_conjunct_prem (th RS @{thm conjunct2})
   874     | _ => [th]
   859     | _ => [th]
   875 
   860 
   876 fun expand_tuples thy intro =
   861 fun expand_tuples thy intro =
   877   let
   862   let
   878     val ctxt = Proof_Context.init_global thy
   863     val ctxt = Proof_Context.init_global thy
   879     val (((T_insts, t_insts), [intro']), ctxt1) = Variable.import false [intro] ctxt
   864     val (((T_insts, t_insts), [intro']), ctxt1) = Variable.import false [intro] ctxt
   880     val intro_t = prop_of intro'
   865     val intro_t = prop_of intro'
   881     val concl = Logic.strip_imp_concl intro_t
   866     val concl = Logic.strip_imp_concl intro_t
   882     val (p, args) = strip_comb (HOLogic.dest_Trueprop concl)
   867     val (_, args) = strip_comb (HOLogic.dest_Trueprop concl)
   883     val (pats', intro_t', ctxt2) = rewrite_args args ([], intro_t, ctxt1)
   868     val (pats', intro_t', ctxt2) = rewrite_args args ([], intro_t, ctxt1)
   884     val (pats', intro_t', ctxt3) = 
   869     val (pats', _, ctxt3) = fold_atoms rewrite_prem intro_t' (pats', intro_t', ctxt2)
   885       fold_atoms rewrite_prem intro_t' (pats', intro_t', ctxt2)
       
   886     fun rewrite_pat (ct1, ct2) =
   870     fun rewrite_pat (ct1, ct2) =
   887       (ct1, cterm_of thy (Pattern.rewrite_term thy pats' [] (term_of ct2)))
   871       (ct1, cterm_of thy (Pattern.rewrite_term thy pats' [] (term_of ct2)))
   888     val t_insts' = map rewrite_pat t_insts
   872     val t_insts' = map rewrite_pat t_insts
   889     val intro'' = Thm.instantiate (T_insts, t_insts') intro
   873     val intro'' = Thm.instantiate (T_insts, t_insts') intro
   890     val [intro'''] = Variable.export ctxt3 ctxt [intro'']
   874     val [intro'''] = Variable.export ctxt3 ctxt [intro'']
   945     fun instantiate th =
   929     fun instantiate th =
   946     let
   930     let
   947       val f = (fst (strip_comb (fst (HOLogic.dest_eq (HOLogic.dest_Trueprop (prop_of th))))))
   931       val f = (fst (strip_comb (fst (HOLogic.dest_eq (HOLogic.dest_Trueprop (prop_of th))))))
   948       val Type ("fun", [uninst_T, uninst_T']) = fastype_of f
   932       val Type ("fun", [uninst_T, uninst_T']) = fastype_of f
   949       val ([tname, tname', uname, yname], ctxt') = Variable.add_fixes ["'t", "'t'", "'u", "y"] ctxt
   933       val ([tname, tname', uname, yname], ctxt') = Variable.add_fixes ["'t", "'t'", "'u", "y"] ctxt
   950       val T = TFree (tname, HOLogic.typeS)
       
   951       val T' = TFree (tname', HOLogic.typeS)
   934       val T' = TFree (tname', HOLogic.typeS)
   952       val U = TFree (uname, HOLogic.typeS)
   935       val U = TFree (uname, HOLogic.typeS)
   953       val y = Free (yname, U)
   936       val y = Free (yname, U)
   954       val f' = absdummy (U --> T') (Bound 0 $ y)
   937       val f' = absdummy (U --> T') (Bound 0 $ y)
   955       val th' = Thm.certify_instantiate
   938       val th' = Thm.certify_instantiate
   978 fun imp_prems_conv cv ct =
   961 fun imp_prems_conv cv ct =
   979   case Thm.term_of ct of
   962   case Thm.term_of ct of
   980     Const ("==>", _) $ _ $ _ => Conv.combination_conv (Conv.arg_conv cv) (imp_prems_conv cv) ct
   963     Const ("==>", _) $ _ $ _ => Conv.combination_conv (Conv.arg_conv cv) (imp_prems_conv cv) ct
   981   | _ => Conv.all_conv ct
   964   | _ => Conv.all_conv ct
   982 
   965 
   983 fun all_params_conv cv ctxt ct =
       
   984   if Logic.is_all (Thm.term_of ct)
       
   985   then Conv.arg_conv (Conv.abs_conv (all_params_conv cv o #2) ctxt) ct
       
   986   else cv ctxt ct;
       
   987   
       
   988 (** eta contract higher-order arguments **)
   966 (** eta contract higher-order arguments **)
   989 
   967 
   990 fun eta_contract_ho_arguments thy intro =
   968 fun eta_contract_ho_arguments thy intro =
   991   let
   969   let
   992     fun f atom = list_comb (apsnd ((map o map_products) Envir.eta_contract) (strip_comb atom))
   970     fun f atom = list_comb (apsnd ((map o map_products) Envir.eta_contract) (strip_comb atom))
  1060 fun import_intros inp_pred [] ctxt =
  1038 fun import_intros inp_pred [] ctxt =
  1061   let
  1039   let
  1062     val ([outp_pred], ctxt') = Variable.import_terms true [inp_pred] ctxt
  1040     val ([outp_pred], ctxt') = Variable.import_terms true [inp_pred] ctxt
  1063     val T = fastype_of outp_pred
  1041     val T = fastype_of outp_pred
  1064     val paramTs = ho_argsT_of_typ (binder_types T)
  1042     val paramTs = ho_argsT_of_typ (binder_types T)
  1065     val (param_names, ctxt'') = Variable.variant_fixes
  1043     val (param_names, _) = Variable.variant_fixes
  1066       (map (fn i => "p" ^ (string_of_int i)) (1 upto (length paramTs))) ctxt'
  1044       (map (fn i => "p" ^ (string_of_int i)) (1 upto (length paramTs))) ctxt'
  1067     val params = map2 (curry Free) param_names paramTs
  1045     val params = map2 (curry Free) param_names paramTs
  1068   in
  1046   in
  1069     (((outp_pred, params), []), ctxt')
  1047     (((outp_pred, params), []), ctxt')
  1070   end
  1048   end
  1195 (* defining a quickcheck predicate *)
  1173 (* defining a quickcheck predicate *)
  1196 
  1174 
  1197 fun strip_imp_prems (Const(@{const_name HOL.implies}, _) $ A $ B) = A :: strip_imp_prems B
  1175 fun strip_imp_prems (Const(@{const_name HOL.implies}, _) $ A $ B) = A :: strip_imp_prems B
  1198   | strip_imp_prems _ = [];
  1176   | strip_imp_prems _ = [];
  1199 
  1177 
  1200 fun strip_imp_concl (Const(@{const_name HOL.implies}, _) $ A $ B) = strip_imp_concl B
  1178 fun strip_imp_concl (Const(@{const_name HOL.implies}, _) $ _ $ B) = strip_imp_concl B
  1201   | strip_imp_concl A = A : term;
  1179   | strip_imp_concl A = A;
  1202 
  1180 
  1203 fun strip_horn A = (strip_imp_prems A, strip_imp_concl A);
  1181 fun strip_horn A = (strip_imp_prems A, strip_imp_concl A);
  1204 
  1182 
  1205 fun define_quickcheck_predicate t thy =
  1183 fun define_quickcheck_predicate t thy =
  1206   let
  1184   let