src/HOL/Tools/Predicate_Compile/predicate_compile_aux.ML
changeset 55440 721b4561007a
parent 55399 5c8e91f884af
parent 55437 3fd63b92ea3b
child 56239 17df7145a871
equal deleted inserted replaced
55428:0ab52bf7b5e6 55440:721b4561007a
    88   val mk_not : compilation_funs -> term -> term
    88   val mk_not : compilation_funs -> term -> term
    89   val mk_map : compilation_funs -> typ -> typ -> term -> term -> term
    89   val mk_map : compilation_funs -> typ -> typ -> term -> term -> term
    90   val funT_of : compilation_funs -> mode -> typ -> typ
    90   val funT_of : compilation_funs -> mode -> typ -> typ
    91   (* Different compilations *)
    91   (* Different compilations *)
    92   datatype compilation = Pred | Depth_Limited | Random | Depth_Limited_Random | DSeq | Annotated
    92   datatype compilation = Pred | Depth_Limited | Random | Depth_Limited_Random | DSeq | Annotated
    93     | Pos_Random_DSeq | Neg_Random_DSeq | New_Pos_Random_DSeq | New_Neg_Random_DSeq 
    93     | Pos_Random_DSeq | Neg_Random_DSeq | New_Pos_Random_DSeq | New_Neg_Random_DSeq
    94     | Pos_Generator_DSeq | Neg_Generator_DSeq | Pos_Generator_CPS | Neg_Generator_CPS
    94     | Pos_Generator_DSeq | Neg_Generator_DSeq | Pos_Generator_CPS | Neg_Generator_CPS
    95   val negative_compilation_of : compilation -> compilation
    95   val negative_compilation_of : compilation -> compilation
    96   val compilation_for_polarity : bool -> compilation -> compilation
    96   val compilation_for_polarity : bool -> compilation -> compilation
    97   val is_depth_limited_compilation : compilation -> bool 
    97   val is_depth_limited_compilation : compilation -> bool
    98   val string_of_compilation : compilation -> string
    98   val string_of_compilation : compilation -> string
    99   val compilation_names : (string * compilation) list
    99   val compilation_names : (string * compilation) list
   100   val non_random_compilations : compilation list
   100   val non_random_compilations : compilation list
   101   val random_compilations : compilation list
   101   val random_compilations : compilation list
   102   (* Different options for compiler *)
   102   (* Different options for compiler *)
   103   datatype options = Options of {  
   103   datatype options = Options of {
   104     expected_modes : (string * mode list) option,
   104     expected_modes : (string * mode list) option,
   105     proposed_modes : (string * mode list) list,
   105     proposed_modes : (string * mode list) list,
   106     proposed_names : ((string * mode) * string) list,
   106     proposed_names : ((string * mode) * string) list,
   107     show_steps : bool,
   107     show_steps : bool,
   108     show_proof_trace : bool,
   108     show_proof_trace : bool,
   160   val peephole_optimisation : theory -> thm -> thm option
   160   val peephole_optimisation : theory -> thm -> thm option
   161   (* auxillary *)
   161   (* auxillary *)
   162   val unify_consts : theory -> term list -> term list -> (term list * term list)
   162   val unify_consts : theory -> term list -> term list -> (term list * term list)
   163   val mk_casesrule : Proof.context -> term -> thm list -> term
   163   val mk_casesrule : Proof.context -> term -> thm list -> term
   164   val preprocess_intro : theory -> thm -> thm
   164   val preprocess_intro : theory -> thm -> thm
   165   
   165 
   166   val define_quickcheck_predicate :
   166   val define_quickcheck_predicate :
   167     term -> theory -> (((string * typ) * (string * typ) list) * thm) * theory
   167     term -> theory -> (((string * typ) * (string * typ) list) * thm) * theory
   168 end;
   168 end
   169 
   169 
   170 structure Predicate_Compile_Aux : PREDICATE_COMPILE_AUX =
   170 structure Predicate_Compile_Aux : PREDICATE_COMPILE_AUX =
   171 struct
   171 struct
   172 
   172 
   173 (* general functions *)
   173 (* general functions *)
   210   | mode_ord (Input, Input) = EQUAL
   210   | mode_ord (Input, Input) = EQUAL
   211   | mode_ord (Output, Output) = EQUAL
   211   | mode_ord (Output, Output) = EQUAL
   212   | mode_ord (Bool, Bool) = EQUAL
   212   | mode_ord (Bool, Bool) = EQUAL
   213   | mode_ord (Pair (m1, m2), Pair (m3, m4)) = prod_ord mode_ord mode_ord ((m1, m2), (m3, m4))
   213   | mode_ord (Pair (m1, m2), Pair (m3, m4)) = prod_ord mode_ord mode_ord ((m1, m2), (m3, m4))
   214   | mode_ord (Fun (m1, m2), Fun (m3, m4)) = prod_ord mode_ord mode_ord ((m1, m2), (m3, m4))
   214   | mode_ord (Fun (m1, m2), Fun (m3, m4)) = prod_ord mode_ord mode_ord ((m1, m2), (m3, m4))
   215  
   215 
   216 fun list_fun_mode [] = Bool
   216 fun list_fun_mode [] = Bool
   217   | list_fun_mode (m :: ms) = Fun (m, list_fun_mode ms)
   217   | list_fun_mode (m :: ms) = Fun (m, list_fun_mode ms)
   218 
   218 
   219 (* name: binder_modes? *)
   219 (* name: binder_modes? *)
   220 fun strip_fun_mode (Fun (mode, mode')) = mode :: strip_fun_mode mode'
   220 fun strip_fun_mode (Fun (mode, mode')) = mode :: strip_fun_mode mode'
   226   | dest_fun_mode mode = [mode]
   226   | dest_fun_mode mode = [mode]
   227 
   227 
   228 fun dest_tuple_mode (Pair (mode, mode')) = mode :: dest_tuple_mode mode'
   228 fun dest_tuple_mode (Pair (mode, mode')) = mode :: dest_tuple_mode mode'
   229   | dest_tuple_mode _ = []
   229   | dest_tuple_mode _ = []
   230 
   230 
   231 fun all_modes_of_typ' (T as Type ("fun", _)) = 
   231 fun all_modes_of_typ' (T as Type ("fun", _)) =
   232   let
   232   let
   233     val (S, U) = strip_type T
   233     val (S, U) = strip_type T
   234   in
   234   in
   235     if U = HOLogic.boolT then
   235     if U = HOLogic.boolT then
   236       fold_rev (fn m1 => fn m2 => map_product (curry Fun) m1 m2)
   236       fold_rev (fn m1 => fn m2 => map_product (curry Fun) m1 m2)
   237         (map all_modes_of_typ' S) [Bool]
   237         (map all_modes_of_typ' S) [Bool]
   238     else
   238     else
   239       [Input, Output]
   239       [Input, Output]
   240   end
   240   end
   241   | all_modes_of_typ' (Type (@{type_name Product_Type.prod}, [T1, T2])) = 
   241   | all_modes_of_typ' (Type (@{type_name Product_Type.prod}, [T1, T2])) =
   242     map_product (curry Pair) (all_modes_of_typ' T1) (all_modes_of_typ' T2)
   242     map_product (curry Pair) (all_modes_of_typ' T1) (all_modes_of_typ' T2)
   243   | all_modes_of_typ' _ = [Input, Output]
   243   | all_modes_of_typ' _ = [Input, Output]
   244 
   244 
   245 fun all_modes_of_typ (T as Type ("fun", _)) =
   245 fun all_modes_of_typ (T as Type ("fun", _)) =
   246     let
   246     let
   257     raise Fail "Invocation of all_modes_of_typ with a non-predicate type"
   257     raise Fail "Invocation of all_modes_of_typ with a non-predicate type"
   258 
   258 
   259 fun all_smodes_of_typ (T as Type ("fun", _)) =
   259 fun all_smodes_of_typ (T as Type ("fun", _)) =
   260   let
   260   let
   261     val (S, U) = strip_type T
   261     val (S, U) = strip_type T
   262     fun all_smodes (Type (@{type_name Product_Type.prod}, [T1, T2])) = 
   262     fun all_smodes (Type (@{type_name Product_Type.prod}, [T1, T2])) =
   263       map_product (curry Pair) (all_smodes T1) (all_smodes T2)
   263       map_product (curry Pair) (all_smodes T1) (all_smodes T2)
   264       | all_smodes _ = [Input, Output]
   264       | all_smodes _ = [Input, Output]
   265   in
   265   in
   266     if U = HOLogic.boolT then
   266     if U = HOLogic.boolT then
   267       fold_rev (fn m1 => fn m2 => map_product (curry Fun) m1 m2) (map all_smodes S) [Bool]
   267       fold_rev (fn m1 => fn m2 => map_product (curry Fun) m1 m2) (map all_smodes S) [Bool]
   290     flat (map2_optional ho_arg (strip_fun_mode mode) ts)
   290     flat (map2_optional ho_arg (strip_fun_mode mode) ts)
   291   end
   291   end
   292 
   292 
   293 fun ho_args_of_typ T ts =
   293 fun ho_args_of_typ T ts =
   294   let
   294   let
   295     fun ho_arg (T as Type("fun", [_,_])) (SOME t) = if body_type T = @{typ bool} then [t] else []
   295     fun ho_arg (T as Type ("fun", [_, _])) (SOME t) =
   296       | ho_arg (Type("fun", [_,_])) NONE = raise Fail "mode and term do not match"
   296           if body_type T = @{typ bool} then [t] else []
       
   297       | ho_arg (Type ("fun", [_, _])) NONE = raise Fail "mode and term do not match"
   297       | ho_arg (Type(@{type_name "Product_Type.prod"}, [T1, T2]))
   298       | ho_arg (Type(@{type_name "Product_Type.prod"}, [T1, T2]))
   298          (SOME (Const (@{const_name Pair}, _) $ t1 $ t2)) =
   299          (SOME (Const (@{const_name Pair}, _) $ t1 $ t2)) =
   299           ho_arg T1 (SOME t1) @ ho_arg T2 (SOME t2)
   300           ho_arg T1 (SOME t1) @ ho_arg T2 (SOME t2)
   300       | ho_arg (Type(@{type_name "Product_Type.prod"}, [T1, T2])) NONE =
   301       | ho_arg (Type(@{type_name "Product_Type.prod"}, [T1, T2])) NONE =
   301           ho_arg T1 NONE @ ho_arg T2 NONE
   302           ho_arg T1 NONE @ ho_arg T2 NONE
   305   end
   306   end
   306 
   307 
   307 fun ho_argsT_of_typ Ts =
   308 fun ho_argsT_of_typ Ts =
   308   let
   309   let
   309     fun ho_arg (T as Type("fun", [_,_])) = if body_type T = @{typ bool} then [T] else []
   310     fun ho_arg (T as Type("fun", [_,_])) = if body_type T = @{typ bool} then [T] else []
   310       | ho_arg (Type(@{type_name "Product_Type.prod"}, [T1, T2])) =
   311       | ho_arg (Type (@{type_name "Product_Type.prod"}, [T1, T2])) =
   311           ho_arg T1 @ ho_arg T2
   312           ho_arg T1 @ ho_arg T2
   312       | ho_arg _ = []
   313       | ho_arg _ = []
   313   in
   314   in
   314     maps ho_arg Ts
   315     maps ho_arg Ts
   315   end
   316   end
   316   
   317 
   317 
   318 
   318 (* temporary function should be replaced by unsplit_input or so? *)
   319 (* temporary function should be replaced by unsplit_input or so? *)
   319 fun replace_ho_args mode hoargs ts =
   320 fun replace_ho_args mode hoargs ts =
   320   let
   321   let
   321     fun replace (Fun _, _) (arg' :: hoargs') = (arg', hoargs')
   322     fun replace (Fun _, _) (arg' :: hoargs') = (arg', hoargs')
   322       | replace (Pair (m1, m2), Const (@{const_name Pair}, T) $ t1 $ t2) hoargs =
   323       | replace (Pair (m1, m2), Const (@{const_name Pair}, T) $ t1 $ t2) hoargs =
   323         let
   324           let
   324           val (t1', hoargs') = replace (m1, t1) hoargs
   325             val (t1', hoargs') = replace (m1, t1) hoargs
   325           val (t2', hoargs'') = replace (m2, t2) hoargs'
   326             val (t2', hoargs'') = replace (m2, t2) hoargs'
   326         in
   327           in
   327           (Const (@{const_name Pair}, T) $ t1' $ t2', hoargs'')
   328             (Const (@{const_name Pair}, T) $ t1' $ t2', hoargs'')
   328         end
   329           end
   329       | replace (_, t) hoargs = (t, hoargs)
   330       | replace (_, t) hoargs = (t, hoargs)
   330   in
   331   in
   331     fst (fold_map replace (strip_fun_mode mode ~~ ts) hoargs)
   332     fst (fold_map replace (strip_fun_mode mode ~~ ts) hoargs)
   332   end
   333   end
   333 
   334 
   334 fun ho_argsT_of mode Ts =
   335 fun ho_argsT_of mode Ts =
   335   let
   336   let
   336     fun ho_arg (Fun _) T = [T]
   337     fun ho_arg (Fun _) T = [T]
   337       | ho_arg (Pair (m1, m2)) (Type (@{type_name Product_Type.prod}, [T1, T2])) = ho_arg m1 T1 @ ho_arg m2 T2
   338       | ho_arg (Pair (m1, m2)) (Type (@{type_name Product_Type.prod}, [T1, T2])) =
       
   339           ho_arg m1 T1 @ ho_arg m2 T2
   338       | ho_arg _ _ = []
   340       | ho_arg _ _ = []
   339   in
   341   in
   340     flat (map2 ho_arg (strip_fun_mode mode) Ts)
   342     flat (map2 ho_arg (strip_fun_mode mode) Ts)
   341   end
   343   end
   342 
   344 
   378   end
   380   end
   379 
   381 
   380 fun split_mode mode ts = split_map_mode (fn _ => fn _ => (NONE, NONE)) mode ts
   382 fun split_mode mode ts = split_map_mode (fn _ => fn _ => (NONE, NONE)) mode ts
   381 
   383 
   382 fun fold_map_aterms_prodT comb f (Type (@{type_name Product_Type.prod}, [T1, T2])) s =
   384 fun fold_map_aterms_prodT comb f (Type (@{type_name Product_Type.prod}, [T1, T2])) s =
   383   let
   385       let
   384     val (x1, s') = fold_map_aterms_prodT comb f T1 s
   386         val (x1, s') = fold_map_aterms_prodT comb f T1 s
   385     val (x2, s'') = fold_map_aterms_prodT comb f T2 s'
   387         val (x2, s'') = fold_map_aterms_prodT comb f T2 s'
   386   in
   388       in
   387     (comb x1 x2, s'')
   389         (comb x1 x2, s'')
   388   end
   390       end
   389   | fold_map_aterms_prodT comb f T s = f T s
   391   | fold_map_aterms_prodT _ f T s = f T s
   390 
   392 
   391 fun map_filter_prod f (Const (@{const_name Pair}, _) $ t1 $ t2) =
   393 fun map_filter_prod f (Const (@{const_name Pair}, _) $ t1 $ t2) =
   392   comb_option HOLogic.mk_prod (map_filter_prod f t1, map_filter_prod f t2)
   394       comb_option HOLogic.mk_prod (map_filter_prod f t1, map_filter_prod f t2)
   393   | map_filter_prod f t = f t
   395   | map_filter_prod f t = f t
   394   
   396 
   395 fun split_modeT mode Ts =
   397 fun split_modeT mode Ts =
   396   let
   398   let
   397     fun split_arg_mode (Fun _) _ = ([], [])
   399     fun split_arg_mode (Fun _) _ = ([], [])
   398       | split_arg_mode (Pair (m1, m2)) (Type (@{type_name Product_Type.prod}, [T1, T2])) =
   400       | split_arg_mode (Pair (m1, m2)) (Type (@{type_name Product_Type.prod}, [T1, T2])) =
   399         let
   401           let
   400           val (i1, o1) = split_arg_mode m1 T1
   402             val (i1, o1) = split_arg_mode m1 T1
   401           val (i2, o2) = split_arg_mode m2 T2
   403             val (i2, o2) = split_arg_mode m2 T2
   402         in
   404           in
   403           (i1 @ i2, o1 @ o2)
   405             (i1 @ i2, o1 @ o2)
   404         end
   406           end
   405       | split_arg_mode Input T = ([T], [])
   407       | split_arg_mode Input T = ([T], [])
   406       | split_arg_mode Output T = ([], [T])
   408       | split_arg_mode Output T = ([], [T])
   407       | split_arg_mode _ _ = raise Fail "split_modeT: mode and type do not match"
   409       | split_arg_mode _ _ = raise Fail "split_modeT: mode and type do not match"
   408   in
   410   in
   409     (pairself flat o split_list) (map2 split_arg_mode (strip_fun_mode mode) Ts)
   411     (pairself flat o split_list) (map2 split_arg_mode (strip_fun_mode mode) Ts)
   426     fun ascii_string_of_mode' Input = "i"
   428     fun ascii_string_of_mode' Input = "i"
   427       | ascii_string_of_mode' Output = "o"
   429       | ascii_string_of_mode' Output = "o"
   428       | ascii_string_of_mode' Bool = "b"
   430       | ascii_string_of_mode' Bool = "b"
   429       | ascii_string_of_mode' (Pair (m1, m2)) =
   431       | ascii_string_of_mode' (Pair (m1, m2)) =
   430           "P" ^ ascii_string_of_mode' m1 ^ ascii_string_of_mode'_Pair m2
   432           "P" ^ ascii_string_of_mode' m1 ^ ascii_string_of_mode'_Pair m2
   431       | ascii_string_of_mode' (Fun (m1, m2)) = 
   433       | ascii_string_of_mode' (Fun (m1, m2)) =
   432           "F" ^ ascii_string_of_mode' m1 ^ ascii_string_of_mode'_Fun m2 ^ "B"
   434           "F" ^ ascii_string_of_mode' m1 ^ ascii_string_of_mode'_Fun m2 ^ "B"
   433     and ascii_string_of_mode'_Fun (Fun (m1, m2)) =
   435     and ascii_string_of_mode'_Fun (Fun (m1, m2)) =
   434           ascii_string_of_mode' m1 ^ (if m2 = Bool then "" else "_" ^ ascii_string_of_mode'_Fun m2)
   436           ascii_string_of_mode' m1 ^ (if m2 = Bool then "" else "_" ^ ascii_string_of_mode'_Fun m2)
   435       | ascii_string_of_mode'_Fun Bool = "B"
   437       | ascii_string_of_mode'_Fun Bool = "B"
   436       | ascii_string_of_mode'_Fun m = ascii_string_of_mode' m
   438       | ascii_string_of_mode'_Fun m = ascii_string_of_mode' m
   437     and ascii_string_of_mode'_Pair (Pair (m1, m2)) =
   439     and ascii_string_of_mode'_Pair (Pair (m1, m2)) =
   438           ascii_string_of_mode' m1 ^ ascii_string_of_mode'_Pair m2
   440           ascii_string_of_mode' m1 ^ ascii_string_of_mode'_Pair m2
   439       | ascii_string_of_mode'_Pair m = ascii_string_of_mode' m
   441       | ascii_string_of_mode'_Pair m = ascii_string_of_mode' m
   440   in ascii_string_of_mode'_Fun mode' end
   442   in ascii_string_of_mode'_Fun mode' end
   441 
   443 
       
   444 
   442 (* premises *)
   445 (* premises *)
   443 
   446 
   444 datatype indprem = Prem of term | Negprem of term | Sidecond of term
   447 datatype indprem =
   445   | Generator of (string * typ);
   448   Prem of term | Negprem of term | Sidecond of term | Generator of (string * typ)
   446 
   449 
   447 fun dest_indprem (Prem t) = t
   450 fun dest_indprem (Prem t) = t
   448   | dest_indprem (Negprem t) = t
   451   | dest_indprem (Negprem t) = t
   449   | dest_indprem (Sidecond t) = t
   452   | dest_indprem (Sidecond t) = t
   450   | dest_indprem (Generator _) = raise Fail "cannot destruct generator"
   453   | dest_indprem (Generator _) = raise Fail "cannot destruct generator"
   452 fun map_indprem f (Prem t) = Prem (f t)
   455 fun map_indprem f (Prem t) = Prem (f t)
   453   | map_indprem f (Negprem t) = Negprem (f t)
   456   | map_indprem f (Negprem t) = Negprem (f t)
   454   | map_indprem f (Sidecond t) = Sidecond (f t)
   457   | map_indprem f (Sidecond t) = Sidecond (f t)
   455   | map_indprem f (Generator (v, T)) = Generator (dest_Free (f (Free (v, T))))
   458   | map_indprem f (Generator (v, T)) = Generator (dest_Free (f (Free (v, T))))
   456 
   459 
       
   460 
   457 (* general syntactic functions *)
   461 (* general syntactic functions *)
   458 
   462 
   459 fun is_equationlike_term (Const ("==", _) $ _ $ _) = true
   463 fun is_equationlike_term (Const ("==", _) $ _ $ _) = true
   460   | is_equationlike_term (Const (@{const_name Trueprop}, _) $ (Const (@{const_name HOL.eq}, _) $ _ $ _)) = true
   464   | is_equationlike_term
       
   465       (Const (@{const_name Trueprop}, _) $ (Const (@{const_name HOL.eq}, _) $ _ $ _)) = true
   461   | is_equationlike_term _ = false
   466   | is_equationlike_term _ = false
   462   
   467 
   463 val is_equationlike = is_equationlike_term o prop_of 
   468 val is_equationlike = is_equationlike_term o prop_of
   464 
   469 
   465 fun is_pred_equation_term (Const ("==", _) $ u $ v) =
   470 fun is_pred_equation_term (Const ("==", _) $ u $ v) =
   466   (fastype_of u = @{typ bool}) andalso (fastype_of v = @{typ bool})
   471       (fastype_of u = @{typ bool}) andalso (fastype_of v = @{typ bool})
   467   | is_pred_equation_term _ = false
   472   | is_pred_equation_term _ = false
   468   
   473 
   469 val is_pred_equation = is_pred_equation_term o prop_of 
   474 val is_pred_equation = is_pred_equation_term o prop_of
   470 
   475 
   471 fun is_intro_term constname t =
   476 fun is_intro_term constname t =
   472   the_default false (try (fn t => case fst (strip_comb (HOLogic.dest_Trueprop (Logic.strip_imp_concl t))) of
   477   the_default false (try (fn t =>
   473     Const (c, _) => c = constname
   478     case fst (strip_comb (HOLogic.dest_Trueprop (Logic.strip_imp_concl t))) of
   474   | _ => false) t)
   479       Const (c, _) => c = constname
   475   
   480     | _ => false) t)
       
   481 
   476 fun is_intro constname t = is_intro_term constname (prop_of t)
   482 fun is_intro constname t = is_intro_term constname (prop_of t)
   477 
   483 
   478 fun is_predT (T as Type("fun", [_, _])) = (body_type T = @{typ bool})
   484 fun is_predT (T as Type("fun", [_, _])) = (body_type T = @{typ bool})
   479   | is_predT _ = false
   485   | is_predT _ = false
   480 
   486 
   492 (* FIXME: constructor terms are supposed to be seen in the way the code generator
   498 (* FIXME: constructor terms are supposed to be seen in the way the code generator
   493   sees constructors.*)
   499   sees constructors.*)
   494 fun is_constrt thy =
   500 fun is_constrt thy =
   495   let
   501   let
   496     val cnstrs = get_constrs thy
   502     val cnstrs = get_constrs thy
   497     fun check t = (case strip_comb t of
   503     fun check t =
       
   504       (case strip_comb t of
   498         (Var _, []) => true
   505         (Var _, []) => true
   499       | (Free _, []) => true
   506       | (Free _, []) => true
   500       | (Const (s, T), ts) => (case (AList.lookup (op =) cnstrs s, body_type T) of
   507       | (Const (s, T), ts) =>
   501             (SOME (i, Tname), Type (Tname', _)) => length ts = i andalso Tname = Tname' andalso forall check ts
   508           (case (AList.lookup (op =) cnstrs s, body_type T) of
       
   509             (SOME (i, Tname), Type (Tname', _)) =>
       
   510               length ts = i andalso Tname = Tname' andalso forall check ts
   502           | _ => false)
   511           | _ => false)
   503       | _ => false)
   512       | _ => false)
   504   in check end;
   513   in check end
   505 
   514 
   506 val is_constr = Code.is_constr o Proof_Context.theory_of;
   515 val is_constr = Code.is_constr o Proof_Context.theory_of
   507 
   516 
   508 fun strip_all t = (Term.strip_all_vars t, Term.strip_all_body t)
   517 fun strip_all t = (Term.strip_all_vars t, Term.strip_all_body t)
   509 
   518 
   510 fun strip_ex (Const (@{const_name Ex}, _) $ Abs (x, T, t)) =
   519 fun strip_ex (Const (@{const_name Ex}, _) $ Abs (x, T, t)) =
   511   let
   520       let
   512     val (xTs, t') = strip_ex t
   521         val (xTs, t') = strip_ex t
   513   in
   522       in
   514     ((x, T) :: xTs, t')
   523         ((x, T) :: xTs, t')
   515   end
   524       end
   516   | strip_ex t = ([], t)
   525   | strip_ex t = ([], t)
   517 
   526 
   518 fun focus_ex t nctxt =
   527 fun focus_ex t nctxt =
   519   let
   528   let
   520     val ((xs, Ts), t') = apfst split_list (strip_ex t) 
   529     val ((xs, Ts), t') = apfst split_list (strip_ex t)
   521     val (xs', nctxt') = fold_map Name.variant xs nctxt;
   530     val (xs', nctxt') = fold_map Name.variant xs nctxt;
   522     val ps' = xs' ~~ Ts;
   531     val ps' = xs' ~~ Ts;
   523     val vs = map Free ps';
   532     val vs = map Free ps';
   524     val t'' = Term.subst_bounds (rev vs, t');
   533     val t'' = Term.subst_bounds (rev vs, t');
   525   in ((ps', t''), nctxt') end;
   534   in ((ps', t''), nctxt') end
   526 
   535 
   527 val strip_intro_concl = (strip_comb o HOLogic.dest_Trueprop o Logic.strip_imp_concl o prop_of)
   536 val strip_intro_concl = strip_comb o HOLogic.dest_Trueprop o Logic.strip_imp_concl o prop_of
   528   
   537 
       
   538 
   529 (* introduction rule combinators *)
   539 (* introduction rule combinators *)
   530 
   540 
   531 fun map_atoms f intro = 
   541 fun map_atoms f intro =
   532   let
   542   let
   533     val (literals, head) = Logic.strip_horn intro
   543     val (literals, head) = Logic.strip_horn intro
   534     fun appl t = (case t of
   544     fun appl t =
       
   545       (case t of
   535         (@{term Not} $ t') => HOLogic.mk_not (f t')
   546         (@{term Not} $ t') => HOLogic.mk_not (f t')
   536       | _ => f t)
   547       | _ => f t)
   537   in
   548   in
   538     Logic.list_implies
   549     Logic.list_implies
   539       (map (HOLogic.mk_Trueprop o appl o HOLogic.dest_Trueprop) literals, head)
   550       (map (HOLogic.mk_Trueprop o appl o HOLogic.dest_Trueprop) literals, head)
   540   end
   551   end
   541 
   552 
   542 fun fold_atoms f intro s =
   553 fun fold_atoms f intro s =
   543   let
   554   let
   544     val (literals, _) = Logic.strip_horn intro
   555     val (literals, _) = Logic.strip_horn intro
   545     fun appl t s = (case t of
   556     fun appl t s =
   546       (@{term Not} $ t') => f t' s
   557       (case t of
       
   558         (@{term Not} $ t') => f t' s
   547       | _ => f t s)
   559       | _ => f t s)
   548   in fold appl (map HOLogic.dest_Trueprop literals) s end
   560   in fold appl (map HOLogic.dest_Trueprop literals) s end
   549 
   561 
   550 fun fold_map_atoms f intro s =
   562 fun fold_map_atoms f intro s =
   551   let
   563   let
   552     val (literals, head) = Logic.strip_horn intro
   564     val (literals, head) = Logic.strip_horn intro
   553     fun appl t s = (case t of
   565     fun appl t s =
   554       (@{term Not} $ t') => apfst HOLogic.mk_not (f t' s)
   566       (case t of
       
   567         (@{term Not} $ t') => apfst HOLogic.mk_not (f t' s)
   555       | _ => f t s)
   568       | _ => f t s)
   556     val (literals', s') = fold_map appl (map HOLogic.dest_Trueprop literals) s
   569     val (literals', s') = fold_map appl (map HOLogic.dest_Trueprop literals) s
   557   in
   570   in
   558     (Logic.list_implies (map HOLogic.mk_Trueprop literals', head), s')
   571     (Logic.list_implies (map HOLogic.mk_Trueprop literals', head), s')
   559   end;
   572   end;
   576   let
   589   let
   577     val (premises, head) = Logic.strip_horn intro
   590     val (premises, head) = Logic.strip_horn intro
   578   in
   591   in
   579     Logic.list_implies (premises, f head)
   592     Logic.list_implies (premises, f head)
   580   end
   593   end
       
   594 
   581 
   595 
   582 (* combinators to apply a function to all basic parts of nested products *)
   596 (* combinators to apply a function to all basic parts of nested products *)
   583 
   597 
   584 fun map_products f (Const (@{const_name Pair}, T) $ t1 $ t2) =
   598 fun map_products f (Const (@{const_name Pair}, T) $ t1 $ t2) =
   585   Const (@{const_name Pair}, T) $ map_products f t1 $ map_products f t2
   599   Const (@{const_name Pair}, T) $ map_products f t1 $ map_products f t2
   586   | map_products f t = f t
   600   | map_products f t = f t
       
   601 
   587 
   602 
   588 (* split theorems of case expressions *)
   603 (* split theorems of case expressions *)
   589 
   604 
   590 fun prepare_split_thm ctxt split_thm =
   605 fun prepare_split_thm ctxt split_thm =
   591     (split_thm RS @{thm iffD2})
   606     (split_thm RS @{thm iffD2})
   592     |> Local_Defs.unfold ctxt [@{thm atomize_conjL[symmetric]},
   607     |> Local_Defs.unfold ctxt [@{thm atomize_conjL[symmetric]},
   593       @{thm atomize_all[symmetric]}, @{thm atomize_imp[symmetric]}]
   608       @{thm atomize_all[symmetric]}, @{thm atomize_imp[symmetric]}]
   594 
   609 
   595 fun find_split_thm thy (Const (name, _)) =
   610 fun find_split_thm thy (Const (name, _)) =
   596     Option.map #split (Ctr_Sugar.ctr_sugar_of_case (Proof_Context.init_global thy) name)
   611     Option.map #split (Ctr_Sugar.ctr_sugar_of_case (Proof_Context.init_global thy) name)
   597   | find_split_thm thy _ = NONE
   612   | find_split_thm _ _ = NONE
       
   613 
   598 
   614 
   599 (* lifting term operations to theorems *)
   615 (* lifting term operations to theorems *)
   600 
   616 
   601 fun map_term thy f th =
   617 fun map_term thy f th =
   602   Skip_Proof.make_thm thy (f (prop_of th))
   618   Skip_Proof.make_thm thy (f (prop_of th))
   603 
   619 
   604 (*
   620 (*
   605 fun equals_conv lhs_cv rhs_cv ct =
   621 fun equals_conv lhs_cv rhs_cv ct =
   606   case Thm.term_of ct of
   622   case Thm.term_of ct of
   607     Const ("==", _) $ _ $ _ => Conv.arg_conv cv ct  
   623     Const ("==", _) $ _ $ _ => Conv.arg_conv cv ct
   608   | _ => error "equals_conv"  
   624   | _ => error "equals_conv"
   609 *)
   625 *)
       
   626 
   610 
   627 
   611 (* Different compilations *)
   628 (* Different compilations *)
   612 
   629 
   613 datatype compilation = Pred | Depth_Limited | Random | Depth_Limited_Random | DSeq | Annotated
   630 datatype compilation = Pred | Depth_Limited | Random | Depth_Limited_Random | DSeq | Annotated
   614   | Pos_Random_DSeq | Neg_Random_DSeq | New_Pos_Random_DSeq | New_Neg_Random_DSeq |
   631   | Pos_Random_DSeq | Neg_Random_DSeq | New_Pos_Random_DSeq | New_Neg_Random_DSeq |
   619   | negative_compilation_of New_Pos_Random_DSeq = New_Neg_Random_DSeq
   636   | negative_compilation_of New_Pos_Random_DSeq = New_Neg_Random_DSeq
   620   | negative_compilation_of New_Neg_Random_DSeq = New_Pos_Random_DSeq
   637   | negative_compilation_of New_Neg_Random_DSeq = New_Pos_Random_DSeq
   621   | negative_compilation_of Pos_Generator_DSeq = Neg_Generator_DSeq
   638   | negative_compilation_of Pos_Generator_DSeq = Neg_Generator_DSeq
   622   | negative_compilation_of Neg_Generator_DSeq = Pos_Generator_DSeq
   639   | negative_compilation_of Neg_Generator_DSeq = Pos_Generator_DSeq
   623   | negative_compilation_of Pos_Generator_CPS = Neg_Generator_CPS
   640   | negative_compilation_of Pos_Generator_CPS = Neg_Generator_CPS
   624   | negative_compilation_of Neg_Generator_CPS = Pos_Generator_CPS  
   641   | negative_compilation_of Neg_Generator_CPS = Pos_Generator_CPS
   625   | negative_compilation_of c = c
   642   | negative_compilation_of c = c
   626   
   643 
   627 fun compilation_for_polarity false Pos_Random_DSeq = Neg_Random_DSeq
   644 fun compilation_for_polarity false Pos_Random_DSeq = Neg_Random_DSeq
   628   | compilation_for_polarity false New_Pos_Random_DSeq = New_Neg_Random_DSeq
   645   | compilation_for_polarity false New_Pos_Random_DSeq = New_Neg_Random_DSeq
   629   | compilation_for_polarity _ c = c
   646   | compilation_for_polarity _ c = c
   630 
   647 
   631 fun is_depth_limited_compilation c =
   648 fun is_depth_limited_compilation c =
   632   (c = New_Pos_Random_DSeq) orelse (c = New_Neg_Random_DSeq) orelse
   649   (c = New_Pos_Random_DSeq) orelse (c = New_Neg_Random_DSeq) orelse
   633   (c = Pos_Generator_DSeq) orelse (c = Pos_Generator_DSeq)
   650   (c = Pos_Generator_DSeq) orelse (c = Pos_Generator_DSeq)
   634 
   651 
   635 fun string_of_compilation c =
   652 fun string_of_compilation c =
   636   case c of
   653   (case c of
   637     Pred => ""
   654     Pred => ""
   638   | Random => "random"
   655   | Random => "random"
   639   | Depth_Limited => "depth limited"
   656   | Depth_Limited => "depth limited"
   640   | Depth_Limited_Random => "depth limited random"
   657   | Depth_Limited_Random => "depth limited random"
   641   | DSeq => "dseq"
   658   | DSeq => "dseq"
   645   | New_Pos_Random_DSeq => "new_pos_random dseq"
   662   | New_Pos_Random_DSeq => "new_pos_random dseq"
   646   | New_Neg_Random_DSeq => "new_neg_random_dseq"
   663   | New_Neg_Random_DSeq => "new_neg_random_dseq"
   647   | Pos_Generator_DSeq => "pos_generator_dseq"
   664   | Pos_Generator_DSeq => "pos_generator_dseq"
   648   | Neg_Generator_DSeq => "neg_generator_dseq"
   665   | Neg_Generator_DSeq => "neg_generator_dseq"
   649   | Pos_Generator_CPS => "pos_generator_cps"
   666   | Pos_Generator_CPS => "pos_generator_cps"
   650   | Neg_Generator_CPS => "neg_generator_cps"
   667   | Neg_Generator_CPS => "neg_generator_cps")
   651   
   668 
   652 val compilation_names = [("pred", Pred),
   669 val compilation_names =
       
   670  [("pred", Pred),
   653   ("random", Random),
   671   ("random", Random),
   654   ("depth_limited", Depth_Limited),
   672   ("depth_limited", Depth_Limited),
   655   ("depth_limited_random", Depth_Limited_Random),
   673   ("depth_limited_random", Depth_Limited_Random),
   656   (*("annotated", Annotated),*)
   674   (*("annotated", Annotated),*)
   657   ("dseq", DSeq),
   675   ("dseq", DSeq),
   664 
   682 
   665 
   683 
   666 val random_compilations = [Random, Depth_Limited_Random,
   684 val random_compilations = [Random, Depth_Limited_Random,
   667   Pos_Random_DSeq, Neg_Random_DSeq, New_Pos_Random_DSeq, New_Neg_Random_DSeq,
   685   Pos_Random_DSeq, Neg_Random_DSeq, New_Pos_Random_DSeq, New_Neg_Random_DSeq,
   668   Pos_Generator_CPS, Neg_Generator_CPS]
   686   Pos_Generator_CPS, Neg_Generator_CPS]
       
   687 
   669 
   688 
   670 (* datastructures and setup for generic compilation *)
   689 (* datastructures and setup for generic compilation *)
   671 
   690 
   672 datatype compilation_funs = CompilationFuns of {
   691 datatype compilation_funs = CompilationFuns of {
   673   mk_monadT : typ -> typ,
   692   mk_monadT : typ -> typ,
   678   mk_plus : term * term -> term,
   697   mk_plus : term * term -> term,
   679   mk_if : term -> term,
   698   mk_if : term -> term,
   680   mk_iterate_upto : typ -> term * term * term -> term,
   699   mk_iterate_upto : typ -> term * term * term -> term,
   681   mk_not : term -> term,
   700   mk_not : term -> term,
   682   mk_map : typ -> typ -> term -> term -> term
   701   mk_map : typ -> typ -> term -> term -> term
   683 };
   702 }
   684 
   703 
   685 fun mk_monadT (CompilationFuns funs) = #mk_monadT funs
   704 fun mk_monadT (CompilationFuns funs) = #mk_monadT funs
   686 fun dest_monadT (CompilationFuns funs) = #dest_monadT funs
   705 fun dest_monadT (CompilationFuns funs) = #dest_monadT funs
   687 fun mk_empty (CompilationFuns funs) = #mk_empty funs
   706 fun mk_empty (CompilationFuns funs) = #mk_empty funs
   688 fun mk_single (CompilationFuns funs) = #mk_single funs
   707 fun mk_single (CompilationFuns funs) = #mk_single funs
   691 fun mk_if (CompilationFuns funs) = #mk_if funs
   710 fun mk_if (CompilationFuns funs) = #mk_if funs
   692 fun mk_iterate_upto (CompilationFuns funs) = #mk_iterate_upto funs
   711 fun mk_iterate_upto (CompilationFuns funs) = #mk_iterate_upto funs
   693 fun mk_not (CompilationFuns funs) = #mk_not funs
   712 fun mk_not (CompilationFuns funs) = #mk_not funs
   694 fun mk_map (CompilationFuns funs) = #mk_map funs
   713 fun mk_map (CompilationFuns funs) = #mk_map funs
   695 
   714 
       
   715 
   696 (** function types and names of different compilations **)
   716 (** function types and names of different compilations **)
   697 
   717 
   698 fun funT_of compfuns mode T =
   718 fun funT_of compfuns mode T =
   699   let
   719   let
   700     val Ts = binder_types T
   720     val Ts = binder_types T
   701     val (inTs, outTs) = split_map_modeT (fn m => fn T => (SOME (funT_of compfuns m T), NONE)) mode Ts
   721     val (inTs, outTs) =
       
   722       split_map_modeT (fn m => fn T => (SOME (funT_of compfuns m T), NONE)) mode Ts
   702   in
   723   in
   703     inTs ---> (mk_monadT compfuns (HOLogic.mk_tupleT outTs))
   724     inTs ---> (mk_monadT compfuns (HOLogic.mk_tupleT outTs))
   704   end;
   725   end
       
   726 
   705 
   727 
   706 (* Different options for compiler *)
   728 (* Different options for compiler *)
   707 
   729 
   708 datatype options = Options of {  
   730 datatype options = Options of {
   709   expected_modes : (string * mode list) option,
   731   expected_modes : (string * mode list) option,
   710   proposed_modes : (string * mode list) list,
   732   proposed_modes : (string * mode list) list,
   711   proposed_names : ((string * mode) * string) list,
   733   proposed_names : ((string * mode) * string) list,
   712   show_steps : bool,
   734   show_steps : bool,
   713   show_proof_trace : bool,
   735   show_proof_trace : bool,
   725   no_higher_order_predicate : string list,
   747   no_higher_order_predicate : string list,
   726   inductify : bool,
   748   inductify : bool,
   727   detect_switches : bool,
   749   detect_switches : bool,
   728   smart_depth_limiting : bool,
   750   smart_depth_limiting : bool,
   729   compilation : compilation
   751   compilation : compilation
   730 };
   752 }
   731 
   753 
   732 fun expected_modes (Options opt) = #expected_modes opt
   754 fun expected_modes (Options opt) = #expected_modes opt
   733 fun proposed_modes (Options opt) = AList.lookup (op =) (#proposed_modes opt)
   755 fun proposed_modes (Options opt) = AList.lookup (op =) (#proposed_modes opt)
   734 fun proposed_names (Options opt) name mode = AList.lookup (eq_pair (op =) eq_mode)
   756 fun proposed_names (Options opt) name mode = AList.lookup (eq_pair (op =) eq_mode)
   735   (#proposed_names opt) (name, mode)
   757   (#proposed_names opt) (name, mode)
   788   "smart_depth_limiting"]
   810   "smart_depth_limiting"]
   789 
   811 
   790 fun print_step options s =
   812 fun print_step options s =
   791   if show_steps options then tracing s else ()
   813   if show_steps options then tracing s else ()
   792 
   814 
       
   815 
   793 (* simple transformations *)
   816 (* simple transformations *)
   794 
   817 
   795 (** tuple processing **)
   818 (** tuple processing **)
   796 
   819 
   797 fun rewrite_args [] (pats, intro_t, ctxt) = (pats, intro_t, ctxt)
   820 fun rewrite_args [] (pats, intro_t, ctxt) = (pats, intro_t, ctxt)
   798   | rewrite_args (arg::args) (pats, intro_t, ctxt) = 
   821   | rewrite_args (arg::args) (pats, intro_t, ctxt) =
   799     (case HOLogic.strip_tupleT (fastype_of arg) of
   822       (case HOLogic.strip_tupleT (fastype_of arg) of
   800       (_ :: _ :: _) =>
   823         (_ :: _ :: _) =>
   801       let
   824         let
   802         fun rewrite_arg' (Const (@{const_name Pair}, _) $ _ $ t2, Type (@{type_name Product_Type.prod}, [_, T2]))
   825           fun rewrite_arg'
   803           (args, (pats, intro_t, ctxt)) = rewrite_arg' (t2, T2) (args, (pats, intro_t, ctxt))
   826                 (Const (@{const_name Pair}, _) $ _ $ t2, Type (@{type_name Product_Type.prod}, [_, T2]))
   804           | rewrite_arg' (t, Type (@{type_name Product_Type.prod}, [T1, T2])) (args, (pats, intro_t, ctxt)) =
   827                 (args, (pats, intro_t, ctxt)) =
   805             let
   828                 rewrite_arg' (t2, T2) (args, (pats, intro_t, ctxt))
   806               val thy = Proof_Context.theory_of ctxt
   829             | rewrite_arg'
   807               val ([x, y], ctxt') = Variable.variant_fixes ["x", "y"] ctxt
   830                 (t, Type (@{type_name Product_Type.prod}, [T1, T2])) (args, (pats, intro_t, ctxt)) =
   808               val pat = (t, HOLogic.mk_prod (Free (x, T1), Free (y, T2)))
   831                 let
   809               val intro_t' = Pattern.rewrite_term thy [pat] [] intro_t
   832                   val thy = Proof_Context.theory_of ctxt
   810               val args' = map (Pattern.rewrite_term thy [pat] []) args
   833                   val ([x, y], ctxt') = Variable.variant_fixes ["x", "y"] ctxt
   811             in
   834                   val pat = (t, HOLogic.mk_prod (Free (x, T1), Free (y, T2)))
   812               rewrite_arg' (Free (y, T2), T2) (args', (pat::pats, intro_t', ctxt'))
   835                   val intro_t' = Pattern.rewrite_term thy [pat] [] intro_t
   813             end
   836                   val args' = map (Pattern.rewrite_term thy [pat] []) args
   814           | rewrite_arg' _ (args, (pats, intro_t, ctxt)) = (args, (pats, intro_t, ctxt))
   837                 in
   815         val (args', (pats, intro_t', ctxt')) = rewrite_arg' (arg, fastype_of arg)
   838                   rewrite_arg' (Free (y, T2), T2) (args', (pat::pats, intro_t', ctxt'))
   816           (args, (pats, intro_t, ctxt))
   839                 end
   817       in
   840             | rewrite_arg' _ (args, (pats, intro_t, ctxt)) = (args, (pats, intro_t, ctxt))
   818         rewrite_args args' (pats, intro_t', ctxt')
   841           val (args', (pats, intro_t', ctxt')) =
   819       end
   842             rewrite_arg' (arg, fastype_of arg) (args, (pats, intro_t, ctxt))
       
   843         in
       
   844           rewrite_args args' (pats, intro_t', ctxt')
       
   845         end
   820   | _ => rewrite_args args (pats, intro_t, ctxt))
   846   | _ => rewrite_args args (pats, intro_t, ctxt))
   821 
   847 
   822 fun rewrite_prem atom =
   848 fun rewrite_prem atom =
   823   let
   849   let
   824     val (_, args) = strip_comb atom
   850     val (_, args) = strip_comb atom
   825   in rewrite_args args end
   851   in rewrite_args args end
   826 
   852 
   827 fun split_conjuncts_in_assms ctxt th =
   853 fun split_conjuncts_in_assms ctxt th =
   828   let
   854   let
   829     val ((_, [fixed_th]), ctxt') = Variable.import false [th] ctxt 
   855     val ((_, [fixed_th]), ctxt') = Variable.import false [th] ctxt
   830     fun split_conjs i nprems th =
   856     fun split_conjs i nprems th =
   831       if i > nprems then th
   857       if i > nprems then th
   832       else
   858       else
   833         case try Drule.RSN (@{thm conjI}, (i, th)) of
   859         (case try Drule.RSN (@{thm conjI}, (i, th)) of
   834           SOME th' => split_conjs i (nprems+1) th'
   860           SOME th' => split_conjs i (nprems + 1) th'
   835         | NONE => split_conjs (i+1) nprems th
   861         | NONE => split_conjs (i + 1) nprems th)
   836   in
   862   in
   837     singleton (Variable.export ctxt' ctxt) (split_conjs 1 (Thm.nprems_of fixed_th) fixed_th)
   863     singleton (Variable.export ctxt' ctxt)
       
   864       (split_conjs 1 (Thm.nprems_of fixed_th) fixed_th)
   838   end
   865   end
   839 
   866 
   840 fun dest_conjunct_prem th =
   867 fun dest_conjunct_prem th =
   841   case HOLogic.dest_Trueprop (prop_of th) of
   868   (case HOLogic.dest_Trueprop (prop_of th) of
   842     (Const (@{const_name HOL.conj}, _) $ _ $ _) =>
   869     (Const (@{const_name HOL.conj}, _) $ _ $ _) =>
   843       dest_conjunct_prem (th RS @{thm conjunct1})
   870       dest_conjunct_prem (th RS @{thm conjunct1})
   844         @ dest_conjunct_prem (th RS @{thm conjunct2})
   871         @ dest_conjunct_prem (th RS @{thm conjunct2})
   845     | _ => [th]
   872    | _ => [th])
   846 
   873 
   847 fun expand_tuples thy intro =
   874 fun expand_tuples thy intro =
   848   let
   875   let
   849     val ctxt = Proof_Context.init_global thy
   876     val ctxt = Proof_Context.init_global thy
   850     val (((T_insts, t_insts), [intro']), ctxt1) = Variable.import false [intro] ctxt
   877     val (((T_insts, t_insts), [intro']), ctxt1) = Variable.import false [intro] ctxt
   867     val intro''''' = split_conjuncts_in_assms ctxt intro''''
   894     val intro''''' = split_conjuncts_in_assms ctxt intro''''
   868   in
   895   in
   869     intro'''''
   896     intro'''''
   870   end
   897   end
   871 
   898 
       
   899 
   872 (** making case distributivity rules **)
   900 (** making case distributivity rules **)
   873 (*** this should be part of the datatype package ***)
   901 (*** this should be part of the datatype package ***)
   874 
   902 
   875 fun datatype_name_of_case_name thy =
   903 fun datatype_name_of_case_name thy =
   876   Ctr_Sugar.ctr_sugar_of_case (Proof_Context.init_global thy)
   904   Ctr_Sugar.ctr_sugar_of_case (Proof_Context.init_global thy)
   938     val th = instantiated_case_rewrite thy Tcon
   966     val th = instantiated_case_rewrite thy Tcon
   939   in
   967   in
   940     Raw_Simplifier.rewrite_term thy [th RS @{thm eq_reflection}] [] t
   968     Raw_Simplifier.rewrite_term thy [th RS @{thm eq_reflection}] [] t
   941   end
   969   end
   942 
   970 
       
   971 
   943 (*** conversions ***)
   972 (*** conversions ***)
   944 
   973 
   945 fun imp_prems_conv cv ct =
   974 fun imp_prems_conv cv ct =
   946   case Thm.term_of ct of
   975   (case Thm.term_of ct of
   947     Const ("==>", _) $ _ $ _ => Conv.combination_conv (Conv.arg_conv cv) (imp_prems_conv cv) ct
   976     Const ("==>", _) $ _ $ _ => Conv.combination_conv (Conv.arg_conv cv) (imp_prems_conv cv) ct
   948   | _ => Conv.all_conv ct
   977   | _ => Conv.all_conv ct)
       
   978 
   949 
   979 
   950 (** eta contract higher-order arguments **)
   980 (** eta contract higher-order arguments **)
   951 
   981 
   952 fun eta_contract_ho_arguments thy intro =
   982 fun eta_contract_ho_arguments thy intro =
   953   let
   983   let
   954     fun f atom = list_comb (apsnd ((map o map_products) Envir.eta_contract) (strip_comb atom))
   984     fun f atom = list_comb (apsnd ((map o map_products) Envir.eta_contract) (strip_comb atom))
   955   in
   985   in
   956     map_term thy (map_concl f o map_atoms f) intro
   986     map_term thy (map_concl f o map_atoms f) intro
   957   end
   987   end
       
   988 
   958 
   989 
   959 (** remove equalities **)
   990 (** remove equalities **)
   960 
   991 
   961 fun remove_equalities thy intro =
   992 fun remove_equalities thy intro =
   962   let
   993   let
   964       let
   995       let
   965         val (prems, concl) = Logic.strip_horn intro_t
   996         val (prems, concl) = Logic.strip_horn intro_t
   966         fun remove_eq (prems, concl) =
   997         fun remove_eq (prems, concl) =
   967           let
   998           let
   968             fun removable_eq prem =
   999             fun removable_eq prem =
   969               case try (HOLogic.dest_eq o HOLogic.dest_Trueprop) prem of
  1000               (case try (HOLogic.dest_eq o HOLogic.dest_Trueprop) prem of
   970                 SOME (lhs, rhs) => (case lhs of
  1001                 SOME (lhs, rhs) =>
   971                   Var _ => true
  1002                   (case lhs of
       
  1003                     Var _ => true
   972                   | _ => (case rhs of Var _ => true | _ => false))
  1004                   | _ => (case rhs of Var _ => true | _ => false))
   973               | NONE => false
  1005               | NONE => false)
   974           in
  1006           in
   975             case find_first removable_eq prems of
  1007             (case find_first removable_eq prems of
   976               NONE => (prems, concl)
  1008               NONE => (prems, concl)
   977             | SOME eq =>
  1009             | SOME eq =>
   978               let
  1010                 let
   979                 val (lhs, rhs) = HOLogic.dest_eq (HOLogic.dest_Trueprop eq)
  1011                   val (lhs, rhs) = HOLogic.dest_eq (HOLogic.dest_Trueprop eq)
   980                 val prems' = remove (op =) eq prems
  1012                   val prems' = remove (op =) eq prems
   981                 val subst = (case lhs of
  1013                   val subst =
   982                   (v as Var _) =>
  1014                     (case lhs of
   983                     (fn t => if t = v then rhs else t)
  1015                       (v as Var _) =>
   984                 | _ => (case rhs of
  1016                         (fn t => if t = v then rhs else t)
   985                    (v as Var _) => (fn t => if t = v then lhs else t)))
  1017                     | _ => (case rhs of (v as Var _) => (fn t => if t = v then lhs else t)))
   986               in
  1018                 in
   987                 remove_eq (map (map_aterms subst) prems', map_aterms subst concl)
  1019                   remove_eq (map (map_aterms subst) prems', map_aterms subst concl)
   988               end
  1020                 end)
   989           end
  1021           end
   990       in
  1022       in
   991         Logic.list_implies (remove_eq (prems, concl))
  1023         Logic.list_implies (remove_eq (prems, concl))
   992       end
  1024       end
   993   in
  1025   in
   994     map_term thy remove_eqs intro
  1026     map_term thy remove_eqs intro
   995   end
  1027   end
   996 
  1028 
       
  1029 
   997 (* Some last processing *)
  1030 (* Some last processing *)
   998 
  1031 
   999 fun remove_pointless_clauses intro =
  1032 fun remove_pointless_clauses intro =
  1000   if Logic.strip_imp_prems (prop_of intro) = [@{prop "False"}] then
  1033   if Logic.strip_imp_prems (prop_of intro) = [@{prop "False"}] then
  1001     []
  1034     []
  1002   else [intro]
  1035   else [intro]
  1003 
  1036 
       
  1037 
  1004 (* some peephole optimisations *)
  1038 (* some peephole optimisations *)
  1005 
  1039 
  1006 fun peephole_optimisation thy intro =
  1040 fun peephole_optimisation thy intro =
  1007   let
  1041   let
  1008     val ctxt = Proof_Context.init_global thy  (* FIXME proper context!? *)
  1042     val ctxt = Proof_Context.init_global thy  (* FIXME proper context!? *)
  1009     val process =
  1043     val process =
  1010       rewrite_rule ctxt (Predicate_Compile_Simps.get ctxt)
  1044       rewrite_rule ctxt (Predicate_Compile_Simps.get ctxt)
  1011     fun process_False intro_t =
  1045     fun process_False intro_t =
  1012       if member (op =) (Logic.strip_imp_prems intro_t) @{prop "False"} then NONE else SOME intro_t
  1046       if member (op =) (Logic.strip_imp_prems intro_t) @{prop "False"}
       
  1047       then NONE else SOME intro_t
  1013     fun process_True intro_t =
  1048     fun process_True intro_t =
  1014       map_filter_premises (fn p => if p = @{prop True} then NONE else SOME p) intro_t
  1049       map_filter_premises (fn p => if p = @{prop True} then NONE else SOME p) intro_t
  1015   in
  1050   in
  1016     Option.map (Skip_Proof.make_thm thy)
  1051     Option.map (Skip_Proof.make_thm thy)
  1017       (process_False (process_True (prop_of (process intro))))
  1052       (process_False (process_True (prop_of (process intro))))
  1019 
  1054 
  1020 
  1055 
  1021 (* importing introduction rules *)
  1056 (* importing introduction rules *)
  1022 
  1057 
  1023 fun import_intros inp_pred [] ctxt =
  1058 fun import_intros inp_pred [] ctxt =
  1024   let
  1059       let
  1025     val ([outp_pred], ctxt') = Variable.import_terms true [inp_pred] ctxt
  1060         val ([outp_pred], ctxt') = Variable.import_terms true [inp_pred] ctxt
  1026     val T = fastype_of outp_pred
  1061         val T = fastype_of outp_pred
  1027     val paramTs = ho_argsT_of_typ (binder_types T)
  1062         val paramTs = ho_argsT_of_typ (binder_types T)
  1028     val (param_names, _) = Variable.variant_fixes
  1063         val (param_names, _) = Variable.variant_fixes
  1029       (map (fn i => "p" ^ (string_of_int i)) (1 upto (length paramTs))) ctxt'
  1064           (map (fn i => "p" ^ (string_of_int i)) (1 upto (length paramTs))) ctxt'
  1030     val params = map2 (curry Free) param_names paramTs
  1065         val params = map2 (curry Free) param_names paramTs
  1031   in
  1066       in
  1032     (((outp_pred, params), []), ctxt')
  1067         (((outp_pred, params), []), ctxt')
  1033   end
  1068       end
  1034   | import_intros inp_pred (th :: ths) ctxt =
  1069   | import_intros inp_pred (th :: ths) ctxt =
  1035     let
  1070       let
  1036       val ((_, [th']), ctxt') = Variable.import true [th] ctxt
  1071         val ((_, [th']), ctxt') = Variable.import true [th] ctxt
  1037       val thy = Proof_Context.theory_of ctxt'
  1072         val thy = Proof_Context.theory_of ctxt'
  1038       val (pred, args) = strip_intro_concl th'
  1073         val (pred, args) = strip_intro_concl th'
  1039       val T = fastype_of pred
  1074         val T = fastype_of pred
  1040       val ho_args = ho_args_of_typ T args
  1075         val ho_args = ho_args_of_typ T args
  1041       fun subst_of (pred', pred) =
  1076         fun subst_of (pred', pred) =
  1042         let
  1077           let
  1043           val subst = Sign.typ_match thy (fastype_of pred', fastype_of pred) Vartab.empty
  1078             val subst = Sign.typ_match thy (fastype_of pred', fastype_of pred) Vartab.empty
  1044             handle Type.TYPE_MATCH => error ("Type mismatch of predicate " ^ fst (dest_Const pred)
  1079               handle Type.TYPE_MATCH =>
  1045             ^ " (trying to match " ^ Syntax.string_of_typ ctxt (fastype_of pred')
  1080                 error ("Type mismatch of predicate " ^ fst (dest_Const pred) ^
  1046             ^ " and " ^ Syntax.string_of_typ ctxt (fastype_of pred) ^ ")"
  1081                   " (trying to match " ^ Syntax.string_of_typ ctxt (fastype_of pred') ^
  1047             ^ " in " ^ Display.string_of_thm ctxt th)
  1082                   " and " ^ Syntax.string_of_typ ctxt (fastype_of pred) ^ ")" ^
  1048         in map (fn (indexname, (s, T)) => ((indexname, s), T)) (Vartab.dest subst) end
  1083                   " in " ^ Display.string_of_thm ctxt th)
  1049       fun instantiate_typ th =
  1084           in map (fn (indexname, (s, T)) => ((indexname, s), T)) (Vartab.dest subst) end
  1050         let
  1085         fun instantiate_typ th =
  1051           val (pred', _) = strip_intro_concl th
  1086           let
  1052           val _ = if not (fst (dest_Const pred) = fst (dest_Const pred')) then
  1087             val (pred', _) = strip_intro_concl th
  1053             raise Fail "Trying to instantiate another predicate" else ()
  1088             val _ =
  1054         in Thm.certify_instantiate (subst_of (pred', pred), []) th end;
  1089               if not (fst (dest_Const pred) = fst (dest_Const pred')) then
  1055       fun instantiate_ho_args th =
  1090                 raise Fail "Trying to instantiate another predicate"
  1056         let
  1091               else ()
  1057           val (_, args') = (strip_comb o HOLogic.dest_Trueprop o Logic.strip_imp_concl o prop_of) th
  1092           in Thm.certify_instantiate (subst_of (pred', pred), []) th end
  1058           val ho_args' = map dest_Var (ho_args_of_typ T args')
  1093         fun instantiate_ho_args th =
  1059         in Thm.certify_instantiate ([], ho_args' ~~ ho_args) th end
  1094           let
  1060       val outp_pred =
  1095             val (_, args') =
  1061         Term_Subst.instantiate (subst_of (inp_pred, pred), []) inp_pred
  1096               (strip_comb o HOLogic.dest_Trueprop o Logic.strip_imp_concl o prop_of) th
  1062       val ((_, ths'), ctxt1) =
  1097             val ho_args' = map dest_Var (ho_args_of_typ T args')
  1063         Variable.import false (map (instantiate_typ #> instantiate_ho_args) ths) ctxt'
  1098           in Thm.certify_instantiate ([], ho_args' ~~ ho_args) th end
  1064     in
  1099         val outp_pred =
  1065       (((outp_pred, ho_args), th' :: ths'), ctxt1)
  1100           Term_Subst.instantiate (subst_of (inp_pred, pred), []) inp_pred
  1066     end
  1101         val ((_, ths'), ctxt1) =
  1067   
  1102           Variable.import false (map (instantiate_typ #> instantiate_ho_args) ths) ctxt'
       
  1103       in
       
  1104         (((outp_pred, ho_args), th' :: ths'), ctxt1)
       
  1105       end
       
  1106 
       
  1107 
  1068 (* generation of case rules from user-given introduction rules *)
  1108 (* generation of case rules from user-given introduction rules *)
  1069 
  1109 
  1070 fun mk_args2 (Type (@{type_name Product_Type.prod}, [T1, T2])) st =
  1110 fun mk_args2 (Type (@{type_name Product_Type.prod}, [T1, T2])) st =
  1071     let
  1111       let
  1072       val (t1, st') = mk_args2 T1 st
  1112         val (t1, st') = mk_args2 T1 st
  1073       val (t2, st'') = mk_args2 T2 st'
  1113         val (t2, st'') = mk_args2 T2 st'
  1074     in
  1114       in
  1075       (HOLogic.mk_prod (t1, t2), st'')
  1115         (HOLogic.mk_prod (t1, t2), st'')
  1076     end
  1116       end
  1077   (*| mk_args2 (T as Type ("fun", _)) (params, ctxt) = 
  1117   (*| mk_args2 (T as Type ("fun", _)) (params, ctxt) =
  1078     let
  1118     let
  1079       val (S, U) = strip_type T
  1119       val (S, U) = strip_type T
  1080     in
  1120     in
  1081       if U = HOLogic.boolT then
  1121       if U = HOLogic.boolT then
  1082         (hd params, (tl params, ctxt))
  1122         (hd params, (tl params, ctxt))
  1086         in
  1126         in
  1087           (Free (x, T), (params, ctxt'))
  1127           (Free (x, T), (params, ctxt'))
  1088         end
  1128         end
  1089     end*)
  1129     end*)
  1090   | mk_args2 T (params, ctxt) =
  1130   | mk_args2 T (params, ctxt) =
  1091     let
  1131       let
  1092       val ([x], ctxt') = Variable.variant_fixes ["x"] ctxt
  1132         val ([x], ctxt') = Variable.variant_fixes ["x"] ctxt
  1093     in
  1133       in
  1094       (Free (x, T), (params, ctxt'))
  1134         (Free (x, T), (params, ctxt'))
  1095     end
  1135       end
  1096 
  1136 
  1097 fun mk_casesrule ctxt pred introrules =
  1137 fun mk_casesrule ctxt pred introrules =
  1098   let
  1138   let
  1099     (* TODO: can be simplified if parameters are not treated specially ? *)
  1139     (* TODO: can be simplified if parameters are not treated specially ? *)
  1100     val (((pred, params), intros_th), ctxt1) = import_intros pred introrules ctxt
  1140     val (((pred, params), intros_th), ctxt1) = import_intros pred introrules ctxt
  1115         val frees = map Free (fold Term.add_frees (args @ prems) [])
  1155         val frees = map Free (fold Term.add_frees (args @ prems) [])
  1116       in fold Logic.all frees (Logic.list_implies (eqprems @ prems, prop)) end
  1156       in fold Logic.all frees (Logic.list_implies (eqprems @ prems, prop)) end
  1117     val assm = HOLogic.mk_Trueprop (list_comb (pred, argvs))
  1157     val assm = HOLogic.mk_Trueprop (list_comb (pred, argvs))
  1118     val cases = map mk_case intros
  1158     val cases = map mk_case intros
  1119   in Logic.list_implies (assm :: cases, prop) end;
  1159   in Logic.list_implies (assm :: cases, prop) end;
  1120   
  1160 
  1121 
  1161 
  1122 (* unifying constants to have the same type variables *)
  1162 (* unifying constants to have the same type variables *)
  1123 
  1163 
  1124 fun unify_consts thy cs intr_ts =
  1164 fun unify_consts thy cs intr_ts =
  1125   (let
  1165   let
  1126      val add_term_consts_2 = fold_aterms (fn Const c => insert (op =) c | _ => I);
  1166      val add_term_consts_2 = fold_aterms (fn Const c => insert (op =) c | _ => I);
  1127      fun varify (t, (i, ts)) =
  1167      fun varify (t, (i, ts)) =
  1128        let val t' = map_types (Logic.incr_tvar (i + 1)) (#2 (Type.varify_global [] t))
  1168        let val t' = map_types (Logic.incr_tvar (i + 1)) (#2 (Type.varify_global [] t))
  1129        in (maxidx_of_term t', t'::ts) end;
  1169        in (maxidx_of_term t', t' :: ts) end
  1130      val (i, cs') = List.foldr varify (~1, []) cs;
  1170      val (i, cs') = List.foldr varify (~1, []) cs
  1131      val (i', intr_ts') = List.foldr varify (i, []) intr_ts;
  1171      val (i', intr_ts') = List.foldr varify (i, []) intr_ts
  1132      val rec_consts = fold add_term_consts_2 cs' [];
  1172      val rec_consts = fold add_term_consts_2 cs' []
  1133      val intr_consts = fold add_term_consts_2 intr_ts' [];
  1173      val intr_consts = fold add_term_consts_2 intr_ts' []
  1134      fun unify (cname, cT) =
  1174      fun unify (cname, cT) =
  1135        let val consts = map snd (filter (fn c => fst c = cname) intr_consts)
  1175        let val consts = map snd (filter (fn c => fst c = cname) intr_consts)
  1136        in fold (Sign.typ_unify thy) ((replicate (length consts) cT) ~~ consts) end;
  1176        in fold (Sign.typ_unify thy) ((replicate (length consts) cT) ~~ consts) end
  1137      val (env, _) = fold unify rec_consts (Vartab.empty, i');
  1177      val (env, _) = fold unify rec_consts (Vartab.empty, i')
  1138      val subst = map_types (Envir.norm_type env)
  1178      val subst = map_types (Envir.norm_type env)
  1139    in (map subst cs', map subst intr_ts')
  1179    in (map subst cs', map subst intr_ts')
  1140    end) handle Type.TUNIFY =>
  1180    end handle Type.TUNIFY =>
  1141      (warning "Occurrences of recursive constant have non-unifiable types"; (cs, intr_ts));
  1181      (warning "Occurrences of recursive constant have non-unifiable types"; (cs, intr_ts))
       
  1182 
  1142 
  1183 
  1143 (* preprocessing rules *)
  1184 (* preprocessing rules *)
  1144 
  1185 
  1145 fun preprocess_equality thy rule =
  1186 fun preprocess_equality thy rule =
  1146   Conv.fconv_rule
  1187   Conv.fconv_rule
  1149         (Conv.try_conv (Conv.rewr_conv (Thm.symmetric @{thm Predicate.eq_is_eq})))))
  1190         (Conv.try_conv (Conv.rewr_conv (Thm.symmetric @{thm Predicate.eq_is_eq})))))
  1150     (Thm.transfer thy rule)
  1191     (Thm.transfer thy rule)
  1151 
  1192 
  1152 fun preprocess_intro thy = expand_tuples thy #> preprocess_equality thy
  1193 fun preprocess_intro thy = expand_tuples thy #> preprocess_equality thy
  1153 
  1194 
       
  1195 
  1154 (* defining a quickcheck predicate *)
  1196 (* defining a quickcheck predicate *)
  1155 
  1197 
  1156 fun strip_imp_prems (Const(@{const_name HOL.implies}, _) $ A $ B) = A :: strip_imp_prems B
  1198 fun strip_imp_prems (Const(@{const_name HOL.implies}, _) $ A $ B) = A :: strip_imp_prems B
  1157   | strip_imp_prems _ = [];
  1199   | strip_imp_prems _ = [];
  1158 
  1200 
  1159 fun strip_imp_concl (Const(@{const_name HOL.implies}, _) $ _ $ B) = strip_imp_concl B
  1201 fun strip_imp_concl (Const(@{const_name HOL.implies}, _) $ _ $ B) = strip_imp_concl B
  1160   | strip_imp_concl A = A;
  1202   | strip_imp_concl A = A;
  1161 
  1203 
  1162 fun strip_horn A = (strip_imp_prems A, strip_imp_concl A);
  1204 fun strip_horn A = (strip_imp_prems A, strip_imp_concl A)
  1163 
  1205 
  1164 fun define_quickcheck_predicate t thy =
  1206 fun define_quickcheck_predicate t thy =
  1165   let
  1207   let
  1166     val (vs, t') = strip_abs t
  1208     val (vs, t') = strip_abs t
  1167     val vs' = Variable.variant_frees (Proof_Context.init_global thy) [] vs (* FIXME proper context!? *)
  1209     val vs' = Variable.variant_frees (Proof_Context.init_global thy) [] vs (* FIXME proper context!? *)
  1170     val constname = "quickcheck"
  1212     val constname = "quickcheck"
  1171     val full_constname = Sign.full_bname thy constname
  1213     val full_constname = Sign.full_bname thy constname
  1172     val constT = map snd vs' ---> @{typ bool}
  1214     val constT = map snd vs' ---> @{typ bool}
  1173     val thy1 = Sign.add_consts_i [(Binding.name constname, constT, NoSyn)] thy
  1215     val thy1 = Sign.add_consts_i [(Binding.name constname, constT, NoSyn)] thy
  1174     val const = Const (full_constname, constT)
  1216     val const = Const (full_constname, constT)
  1175     val t = Logic.list_implies
  1217     val t =
  1176       (map HOLogic.mk_Trueprop (prems @ [HOLogic.mk_not concl]),
  1218       Logic.list_implies
  1177        HOLogic.mk_Trueprop (list_comb (const, map Free vs')))
  1219         (map HOLogic.mk_Trueprop (prems @ [HOLogic.mk_not concl]),
       
  1220           HOLogic.mk_Trueprop (list_comb (const, map Free vs')))
  1178     val intro =
  1221     val intro =
  1179       Goal.prove (Proof_Context.init_global thy1) (map fst vs') [] t
  1222       Goal.prove (Proof_Context.init_global thy1) (map fst vs') [] t
  1180         (fn _ => ALLGOALS Skip_Proof.cheat_tac)
  1223         (fn _ => ALLGOALS Skip_Proof.cheat_tac)
  1181   in
  1224   in
  1182     ((((full_constname, constT), vs'), intro), thy1)
  1225     ((((full_constname, constT), vs'), intro), thy1)
  1183   end
  1226   end
  1184 
  1227 
  1185 end;
  1228 end