src/HOL/Tools/Predicate_Compile/predicate_compile_aux.ML
changeset 55437 3fd63b92ea3b
parent 54742 7a86358a3c0b
child 55440 721b4561007a
equal deleted inserted replaced
55436:9781e17dcc23 55437:3fd63b92ea3b
    87   val mk_not : compilation_funs -> term -> term
    87   val mk_not : compilation_funs -> term -> term
    88   val mk_map : compilation_funs -> typ -> typ -> term -> term -> term
    88   val mk_map : compilation_funs -> typ -> typ -> term -> term -> term
    89   val funT_of : compilation_funs -> mode -> typ -> typ
    89   val funT_of : compilation_funs -> mode -> typ -> typ
    90   (* Different compilations *)
    90   (* Different compilations *)
    91   datatype compilation = Pred | Depth_Limited | Random | Depth_Limited_Random | DSeq | Annotated
    91   datatype compilation = Pred | Depth_Limited | Random | Depth_Limited_Random | DSeq | Annotated
    92     | Pos_Random_DSeq | Neg_Random_DSeq | New_Pos_Random_DSeq | New_Neg_Random_DSeq 
    92     | Pos_Random_DSeq | Neg_Random_DSeq | New_Pos_Random_DSeq | New_Neg_Random_DSeq
    93     | Pos_Generator_DSeq | Neg_Generator_DSeq | Pos_Generator_CPS | Neg_Generator_CPS
    93     | Pos_Generator_DSeq | Neg_Generator_DSeq | Pos_Generator_CPS | Neg_Generator_CPS
    94   val negative_compilation_of : compilation -> compilation
    94   val negative_compilation_of : compilation -> compilation
    95   val compilation_for_polarity : bool -> compilation -> compilation
    95   val compilation_for_polarity : bool -> compilation -> compilation
    96   val is_depth_limited_compilation : compilation -> bool 
    96   val is_depth_limited_compilation : compilation -> bool
    97   val string_of_compilation : compilation -> string
    97   val string_of_compilation : compilation -> string
    98   val compilation_names : (string * compilation) list
    98   val compilation_names : (string * compilation) list
    99   val non_random_compilations : compilation list
    99   val non_random_compilations : compilation list
   100   val random_compilations : compilation list
   100   val random_compilations : compilation list
   101   (* Different options for compiler *)
   101   (* Different options for compiler *)
   102   datatype options = Options of {  
   102   datatype options = Options of {
   103     expected_modes : (string * mode list) option,
   103     expected_modes : (string * mode list) option,
   104     proposed_modes : (string * mode list) list,
   104     proposed_modes : (string * mode list) list,
   105     proposed_names : ((string * mode) * string) list,
   105     proposed_names : ((string * mode) * string) list,
   106     show_steps : bool,
   106     show_steps : bool,
   107     show_proof_trace : bool,
   107     show_proof_trace : bool,
   159   val peephole_optimisation : theory -> thm -> thm option
   159   val peephole_optimisation : theory -> thm -> thm option
   160   (* auxillary *)
   160   (* auxillary *)
   161   val unify_consts : theory -> term list -> term list -> (term list * term list)
   161   val unify_consts : theory -> term list -> term list -> (term list * term list)
   162   val mk_casesrule : Proof.context -> term -> thm list -> term
   162   val mk_casesrule : Proof.context -> term -> thm list -> term
   163   val preprocess_intro : theory -> thm -> thm
   163   val preprocess_intro : theory -> thm -> thm
   164   
   164 
   165   val define_quickcheck_predicate :
   165   val define_quickcheck_predicate :
   166     term -> theory -> (((string * typ) * (string * typ) list) * thm) * theory
   166     term -> theory -> (((string * typ) * (string * typ) list) * thm) * theory
   167 end;
   167 end;
   168 
   168 
   169 structure Predicate_Compile_Aux : PREDICATE_COMPILE_AUX =
   169 structure Predicate_Compile_Aux : PREDICATE_COMPILE_AUX =
   209   | mode_ord (Input, Input) = EQUAL
   209   | mode_ord (Input, Input) = EQUAL
   210   | mode_ord (Output, Output) = EQUAL
   210   | mode_ord (Output, Output) = EQUAL
   211   | mode_ord (Bool, Bool) = EQUAL
   211   | mode_ord (Bool, Bool) = EQUAL
   212   | mode_ord (Pair (m1, m2), Pair (m3, m4)) = prod_ord mode_ord mode_ord ((m1, m2), (m3, m4))
   212   | mode_ord (Pair (m1, m2), Pair (m3, m4)) = prod_ord mode_ord mode_ord ((m1, m2), (m3, m4))
   213   | mode_ord (Fun (m1, m2), Fun (m3, m4)) = prod_ord mode_ord mode_ord ((m1, m2), (m3, m4))
   213   | mode_ord (Fun (m1, m2), Fun (m3, m4)) = prod_ord mode_ord mode_ord ((m1, m2), (m3, m4))
   214  
   214 
   215 fun list_fun_mode [] = Bool
   215 fun list_fun_mode [] = Bool
   216   | list_fun_mode (m :: ms) = Fun (m, list_fun_mode ms)
   216   | list_fun_mode (m :: ms) = Fun (m, list_fun_mode ms)
   217 
   217 
   218 (* name: binder_modes? *)
   218 (* name: binder_modes? *)
   219 fun strip_fun_mode (Fun (mode, mode')) = mode :: strip_fun_mode mode'
   219 fun strip_fun_mode (Fun (mode, mode')) = mode :: strip_fun_mode mode'
   225   | dest_fun_mode mode = [mode]
   225   | dest_fun_mode mode = [mode]
   226 
   226 
   227 fun dest_tuple_mode (Pair (mode, mode')) = mode :: dest_tuple_mode mode'
   227 fun dest_tuple_mode (Pair (mode, mode')) = mode :: dest_tuple_mode mode'
   228   | dest_tuple_mode _ = []
   228   | dest_tuple_mode _ = []
   229 
   229 
   230 fun all_modes_of_typ' (T as Type ("fun", _)) = 
   230 fun all_modes_of_typ' (T as Type ("fun", _)) =
   231   let
   231   let
   232     val (S, U) = strip_type T
   232     val (S, U) = strip_type T
   233   in
   233   in
   234     if U = HOLogic.boolT then
   234     if U = HOLogic.boolT then
   235       fold_rev (fn m1 => fn m2 => map_product (curry Fun) m1 m2)
   235       fold_rev (fn m1 => fn m2 => map_product (curry Fun) m1 m2)
   236         (map all_modes_of_typ' S) [Bool]
   236         (map all_modes_of_typ' S) [Bool]
   237     else
   237     else
   238       [Input, Output]
   238       [Input, Output]
   239   end
   239   end
   240   | all_modes_of_typ' (Type (@{type_name Product_Type.prod}, [T1, T2])) = 
   240   | all_modes_of_typ' (Type (@{type_name Product_Type.prod}, [T1, T2])) =
   241     map_product (curry Pair) (all_modes_of_typ' T1) (all_modes_of_typ' T2)
   241     map_product (curry Pair) (all_modes_of_typ' T1) (all_modes_of_typ' T2)
   242   | all_modes_of_typ' _ = [Input, Output]
   242   | all_modes_of_typ' _ = [Input, Output]
   243 
   243 
   244 fun all_modes_of_typ (T as Type ("fun", _)) =
   244 fun all_modes_of_typ (T as Type ("fun", _)) =
   245     let
   245     let
   256     raise Fail "Invocation of all_modes_of_typ with a non-predicate type"
   256     raise Fail "Invocation of all_modes_of_typ with a non-predicate type"
   257 
   257 
   258 fun all_smodes_of_typ (T as Type ("fun", _)) =
   258 fun all_smodes_of_typ (T as Type ("fun", _)) =
   259   let
   259   let
   260     val (S, U) = strip_type T
   260     val (S, U) = strip_type T
   261     fun all_smodes (Type (@{type_name Product_Type.prod}, [T1, T2])) = 
   261     fun all_smodes (Type (@{type_name Product_Type.prod}, [T1, T2])) =
   262       map_product (curry Pair) (all_smodes T1) (all_smodes T2)
   262       map_product (curry Pair) (all_smodes T1) (all_smodes T2)
   263       | all_smodes _ = [Input, Output]
   263       | all_smodes _ = [Input, Output]
   264   in
   264   in
   265     if U = HOLogic.boolT then
   265     if U = HOLogic.boolT then
   266       fold_rev (fn m1 => fn m2 => map_product (curry Fun) m1 m2) (map all_smodes S) [Bool]
   266       fold_rev (fn m1 => fn m2 => map_product (curry Fun) m1 m2) (map all_smodes S) [Bool]
   289     flat (map2_optional ho_arg (strip_fun_mode mode) ts)
   289     flat (map2_optional ho_arg (strip_fun_mode mode) ts)
   290   end
   290   end
   291 
   291 
   292 fun ho_args_of_typ T ts =
   292 fun ho_args_of_typ T ts =
   293   let
   293   let
   294     fun ho_arg (T as Type("fun", [_,_])) (SOME t) = if body_type T = @{typ bool} then [t] else []
   294     fun ho_arg (T as Type ("fun", [_, _])) (SOME t) =
   295       | ho_arg (Type("fun", [_,_])) NONE = raise Fail "mode and term do not match"
   295           if body_type T = @{typ bool} then [t] else []
       
   296       | ho_arg (Type ("fun", [_, _])) NONE = raise Fail "mode and term do not match"
   296       | ho_arg (Type(@{type_name "Product_Type.prod"}, [T1, T2]))
   297       | ho_arg (Type(@{type_name "Product_Type.prod"}, [T1, T2]))
   297          (SOME (Const (@{const_name Pair}, _) $ t1 $ t2)) =
   298          (SOME (Const (@{const_name Pair}, _) $ t1 $ t2)) =
   298           ho_arg T1 (SOME t1) @ ho_arg T2 (SOME t2)
   299           ho_arg T1 (SOME t1) @ ho_arg T2 (SOME t2)
   299       | ho_arg (Type(@{type_name "Product_Type.prod"}, [T1, T2])) NONE =
   300       | ho_arg (Type(@{type_name "Product_Type.prod"}, [T1, T2])) NONE =
   300           ho_arg T1 NONE @ ho_arg T2 NONE
   301           ho_arg T1 NONE @ ho_arg T2 NONE
   304   end
   305   end
   305 
   306 
   306 fun ho_argsT_of_typ Ts =
   307 fun ho_argsT_of_typ Ts =
   307   let
   308   let
   308     fun ho_arg (T as Type("fun", [_,_])) = if body_type T = @{typ bool} then [T] else []
   309     fun ho_arg (T as Type("fun", [_,_])) = if body_type T = @{typ bool} then [T] else []
   309       | ho_arg (Type(@{type_name "Product_Type.prod"}, [T1, T2])) =
   310       | ho_arg (Type (@{type_name "Product_Type.prod"}, [T1, T2])) =
   310           ho_arg T1 @ ho_arg T2
   311           ho_arg T1 @ ho_arg T2
   311       | ho_arg _ = []
   312       | ho_arg _ = []
   312   in
   313   in
   313     maps ho_arg Ts
   314     maps ho_arg Ts
   314   end
   315   end
   315   
   316 
   316 
   317 
   317 (* temporary function should be replaced by unsplit_input or so? *)
   318 (* temporary function should be replaced by unsplit_input or so? *)
   318 fun replace_ho_args mode hoargs ts =
   319 fun replace_ho_args mode hoargs ts =
   319   let
   320   let
   320     fun replace (Fun _, _) (arg' :: hoargs') = (arg', hoargs')
   321     fun replace (Fun _, _) (arg' :: hoargs') = (arg', hoargs')
   321       | replace (Pair (m1, m2), Const (@{const_name Pair}, T) $ t1 $ t2) hoargs =
   322       | replace (Pair (m1, m2), Const (@{const_name Pair}, T) $ t1 $ t2) hoargs =
   322         let
   323           let
   323           val (t1', hoargs') = replace (m1, t1) hoargs
   324             val (t1', hoargs') = replace (m1, t1) hoargs
   324           val (t2', hoargs'') = replace (m2, t2) hoargs'
   325             val (t2', hoargs'') = replace (m2, t2) hoargs'
   325         in
   326           in
   326           (Const (@{const_name Pair}, T) $ t1' $ t2', hoargs'')
   327             (Const (@{const_name Pair}, T) $ t1' $ t2', hoargs'')
   327         end
   328           end
   328       | replace (_, t) hoargs = (t, hoargs)
   329       | replace (_, t) hoargs = (t, hoargs)
   329   in
   330   in
   330     fst (fold_map replace (strip_fun_mode mode ~~ ts) hoargs)
   331     fst (fold_map replace (strip_fun_mode mode ~~ ts) hoargs)
   331   end
   332   end
   332 
   333 
   333 fun ho_argsT_of mode Ts =
   334 fun ho_argsT_of mode Ts =
   334   let
   335   let
   335     fun ho_arg (Fun _) T = [T]
   336     fun ho_arg (Fun _) T = [T]
   336       | ho_arg (Pair (m1, m2)) (Type (@{type_name Product_Type.prod}, [T1, T2])) = ho_arg m1 T1 @ ho_arg m2 T2
   337       | ho_arg (Pair (m1, m2)) (Type (@{type_name Product_Type.prod}, [T1, T2])) =
       
   338           ho_arg m1 T1 @ ho_arg m2 T2
   337       | ho_arg _ _ = []
   339       | ho_arg _ _ = []
   338   in
   340   in
   339     flat (map2 ho_arg (strip_fun_mode mode) Ts)
   341     flat (map2 ho_arg (strip_fun_mode mode) Ts)
   340   end
   342   end
   341 
   343 
   377   end
   379   end
   378 
   380 
   379 fun split_mode mode ts = split_map_mode (fn _ => fn _ => (NONE, NONE)) mode ts
   381 fun split_mode mode ts = split_map_mode (fn _ => fn _ => (NONE, NONE)) mode ts
   380 
   382 
   381 fun fold_map_aterms_prodT comb f (Type (@{type_name Product_Type.prod}, [T1, T2])) s =
   383 fun fold_map_aterms_prodT comb f (Type (@{type_name Product_Type.prod}, [T1, T2])) s =
   382   let
   384       let
   383     val (x1, s') = fold_map_aterms_prodT comb f T1 s
   385         val (x1, s') = fold_map_aterms_prodT comb f T1 s
   384     val (x2, s'') = fold_map_aterms_prodT comb f T2 s'
   386         val (x2, s'') = fold_map_aterms_prodT comb f T2 s'
   385   in
   387       in
   386     (comb x1 x2, s'')
   388         (comb x1 x2, s'')
   387   end
   389       end
   388   | fold_map_aterms_prodT comb f T s = f T s
   390   | fold_map_aterms_prodT _ f T s = f T s
   389 
   391 
   390 fun map_filter_prod f (Const (@{const_name Pair}, _) $ t1 $ t2) =
   392 fun map_filter_prod f (Const (@{const_name Pair}, _) $ t1 $ t2) =
   391   comb_option HOLogic.mk_prod (map_filter_prod f t1, map_filter_prod f t2)
   393       comb_option HOLogic.mk_prod (map_filter_prod f t1, map_filter_prod f t2)
   392   | map_filter_prod f t = f t
   394   | map_filter_prod f t = f t
   393   
   395 
   394 fun split_modeT mode Ts =
   396 fun split_modeT mode Ts =
   395   let
   397   let
   396     fun split_arg_mode (Fun _) _ = ([], [])
   398     fun split_arg_mode (Fun _) _ = ([], [])
   397       | split_arg_mode (Pair (m1, m2)) (Type (@{type_name Product_Type.prod}, [T1, T2])) =
   399       | split_arg_mode (Pair (m1, m2)) (Type (@{type_name Product_Type.prod}, [T1, T2])) =
   398         let
   400           let
   399           val (i1, o1) = split_arg_mode m1 T1
   401             val (i1, o1) = split_arg_mode m1 T1
   400           val (i2, o2) = split_arg_mode m2 T2
   402             val (i2, o2) = split_arg_mode m2 T2
   401         in
   403           in
   402           (i1 @ i2, o1 @ o2)
   404             (i1 @ i2, o1 @ o2)
   403         end
   405           end
   404       | split_arg_mode Input T = ([T], [])
   406       | split_arg_mode Input T = ([T], [])
   405       | split_arg_mode Output T = ([], [T])
   407       | split_arg_mode Output T = ([], [T])
   406       | split_arg_mode _ _ = raise Fail "split_modeT: mode and type do not match"
   408       | split_arg_mode _ _ = raise Fail "split_modeT: mode and type do not match"
   407   in
   409   in
   408     (pairself flat o split_list) (map2 split_arg_mode (strip_fun_mode mode) Ts)
   410     (pairself flat o split_list) (map2 split_arg_mode (strip_fun_mode mode) Ts)
   425     fun ascii_string_of_mode' Input = "i"
   427     fun ascii_string_of_mode' Input = "i"
   426       | ascii_string_of_mode' Output = "o"
   428       | ascii_string_of_mode' Output = "o"
   427       | ascii_string_of_mode' Bool = "b"
   429       | ascii_string_of_mode' Bool = "b"
   428       | ascii_string_of_mode' (Pair (m1, m2)) =
   430       | ascii_string_of_mode' (Pair (m1, m2)) =
   429           "P" ^ ascii_string_of_mode' m1 ^ ascii_string_of_mode'_Pair m2
   431           "P" ^ ascii_string_of_mode' m1 ^ ascii_string_of_mode'_Pair m2
   430       | ascii_string_of_mode' (Fun (m1, m2)) = 
   432       | ascii_string_of_mode' (Fun (m1, m2)) =
   431           "F" ^ ascii_string_of_mode' m1 ^ ascii_string_of_mode'_Fun m2 ^ "B"
   433           "F" ^ ascii_string_of_mode' m1 ^ ascii_string_of_mode'_Fun m2 ^ "B"
   432     and ascii_string_of_mode'_Fun (Fun (m1, m2)) =
   434     and ascii_string_of_mode'_Fun (Fun (m1, m2)) =
   433           ascii_string_of_mode' m1 ^ (if m2 = Bool then "" else "_" ^ ascii_string_of_mode'_Fun m2)
   435           ascii_string_of_mode' m1 ^ (if m2 = Bool then "" else "_" ^ ascii_string_of_mode'_Fun m2)
   434       | ascii_string_of_mode'_Fun Bool = "B"
   436       | ascii_string_of_mode'_Fun Bool = "B"
   435       | ascii_string_of_mode'_Fun m = ascii_string_of_mode' m
   437       | ascii_string_of_mode'_Fun m = ascii_string_of_mode' m
   436     and ascii_string_of_mode'_Pair (Pair (m1, m2)) =
   438     and ascii_string_of_mode'_Pair (Pair (m1, m2)) =
   437           ascii_string_of_mode' m1 ^ ascii_string_of_mode'_Pair m2
   439           ascii_string_of_mode' m1 ^ ascii_string_of_mode'_Pair m2
   438       | ascii_string_of_mode'_Pair m = ascii_string_of_mode' m
   440       | ascii_string_of_mode'_Pair m = ascii_string_of_mode' m
   439   in ascii_string_of_mode'_Fun mode' end
   441   in ascii_string_of_mode'_Fun mode' end
   440 
   442 
       
   443 
   441 (* premises *)
   444 (* premises *)
   442 
   445 
   443 datatype indprem = Prem of term | Negprem of term | Sidecond of term
   446 datatype indprem =
   444   | Generator of (string * typ);
   447   Prem of term | Negprem of term | Sidecond of term | Generator of (string * typ)
   445 
   448 
   446 fun dest_indprem (Prem t) = t
   449 fun dest_indprem (Prem t) = t
   447   | dest_indprem (Negprem t) = t
   450   | dest_indprem (Negprem t) = t
   448   | dest_indprem (Sidecond t) = t
   451   | dest_indprem (Sidecond t) = t
   449   | dest_indprem (Generator _) = raise Fail "cannot destruct generator"
   452   | dest_indprem (Generator _) = raise Fail "cannot destruct generator"
   451 fun map_indprem f (Prem t) = Prem (f t)
   454 fun map_indprem f (Prem t) = Prem (f t)
   452   | map_indprem f (Negprem t) = Negprem (f t)
   455   | map_indprem f (Negprem t) = Negprem (f t)
   453   | map_indprem f (Sidecond t) = Sidecond (f t)
   456   | map_indprem f (Sidecond t) = Sidecond (f t)
   454   | map_indprem f (Generator (v, T)) = Generator (dest_Free (f (Free (v, T))))
   457   | map_indprem f (Generator (v, T)) = Generator (dest_Free (f (Free (v, T))))
   455 
   458 
       
   459 
   456 (* general syntactic functions *)
   460 (* general syntactic functions *)
   457 
   461 
   458 fun is_equationlike_term (Const ("==", _) $ _ $ _) = true
   462 fun is_equationlike_term (Const ("==", _) $ _ $ _) = true
   459   | is_equationlike_term (Const (@{const_name Trueprop}, _) $ (Const (@{const_name HOL.eq}, _) $ _ $ _)) = true
   463   | is_equationlike_term
       
   464       (Const (@{const_name Trueprop}, _) $ (Const (@{const_name HOL.eq}, _) $ _ $ _)) = true
   460   | is_equationlike_term _ = false
   465   | is_equationlike_term _ = false
   461   
   466 
   462 val is_equationlike = is_equationlike_term o prop_of 
   467 val is_equationlike = is_equationlike_term o prop_of
   463 
   468 
   464 fun is_pred_equation_term (Const ("==", _) $ u $ v) =
   469 fun is_pred_equation_term (Const ("==", _) $ u $ v) =
   465   (fastype_of u = @{typ bool}) andalso (fastype_of v = @{typ bool})
   470       (fastype_of u = @{typ bool}) andalso (fastype_of v = @{typ bool})
   466   | is_pred_equation_term _ = false
   471   | is_pred_equation_term _ = false
   467   
   472 
   468 val is_pred_equation = is_pred_equation_term o prop_of 
   473 val is_pred_equation = is_pred_equation_term o prop_of
   469 
   474 
   470 fun is_intro_term constname t =
   475 fun is_intro_term constname t =
   471   the_default false (try (fn t => case fst (strip_comb (HOLogic.dest_Trueprop (Logic.strip_imp_concl t))) of
   476   the_default false (try (fn t =>
   472     Const (c, _) => c = constname
   477     case fst (strip_comb (HOLogic.dest_Trueprop (Logic.strip_imp_concl t))) of
   473   | _ => false) t)
   478       Const (c, _) => c = constname
   474   
   479     | _ => false) t)
       
   480 
   475 fun is_intro constname t = is_intro_term constname (prop_of t)
   481 fun is_intro constname t = is_intro_term constname (prop_of t)
   476 
   482 
   477 fun is_predT (T as Type("fun", [_, _])) = (body_type T = @{typ bool})
   483 fun is_predT (T as Type("fun", [_, _])) = (body_type T = @{typ bool})
   478   | is_predT _ = false
   484   | is_predT _ = false
   479 
   485 
   484 fun is_constrt thy =
   490 fun is_constrt thy =
   485   let
   491   let
   486     val cnstrs = flat (maps
   492     val cnstrs = flat (maps
   487       (map (fn (_, (Tname, _, cs)) => map (apsnd (rpair Tname o length)) cs) o #descr o snd)
   493       (map (fn (_, (Tname, _, cs)) => map (apsnd (rpair Tname o length)) cs) o #descr o snd)
   488       (Symtab.dest (Datatype.get_all thy)));
   494       (Symtab.dest (Datatype.get_all thy)));
   489     fun check t = (case strip_comb t of
   495     fun check t =
       
   496       (case strip_comb t of
   490         (Var _, []) => true
   497         (Var _, []) => true
   491       | (Free _, []) => true
   498       | (Free _, []) => true
   492       | (Const (s, T), ts) => (case (AList.lookup (op =) cnstrs s, body_type T) of
   499       | (Const (s, T), ts) =>
   493             (SOME (i, Tname), Type (Tname', _)) => length ts = i andalso Tname = Tname' andalso forall check ts
   500           (case (AList.lookup (op =) cnstrs s, body_type T) of
       
   501             (SOME (i, Tname), Type (Tname', _)) =>
       
   502               length ts = i andalso Tname = Tname' andalso forall check ts
   494           | _ => false)
   503           | _ => false)
   495       | _ => false)
   504       | _ => false)
   496   in check end;
   505   in check end
   497 
   506 
   498 (* returns true if t is an application of an datatype constructor *)
   507 (* returns true if t is an application of an datatype constructor *)
   499 (* which then consequently would be splitted *)
   508 (* which then consequently would be splitted *)
   500 (* else false *)
   509 (* else false *)
   501 (*
   510 (*
   510         Const (name, _) => name mem_string constr_consts
   519         Const (name, _) => name mem_string constr_consts
   511         | _ => false) end))
   520         | _ => false) end))
   512   else false
   521   else false
   513 *)
   522 *)
   514 
   523 
   515 val is_constr = Code.is_constr o Proof_Context.theory_of;
   524 val is_constr = Code.is_constr o Proof_Context.theory_of
   516 
   525 
   517 fun strip_all t = (Term.strip_all_vars t, Term.strip_all_body t)
   526 fun strip_all t = (Term.strip_all_vars t, Term.strip_all_body t)
   518 
   527 
   519 fun strip_ex (Const (@{const_name Ex}, _) $ Abs (x, T, t)) =
   528 fun strip_ex (Const (@{const_name Ex}, _) $ Abs (x, T, t)) =
   520   let
   529       let
   521     val (xTs, t') = strip_ex t
   530         val (xTs, t') = strip_ex t
   522   in
   531       in
   523     ((x, T) :: xTs, t')
   532         ((x, T) :: xTs, t')
   524   end
   533       end
   525   | strip_ex t = ([], t)
   534   | strip_ex t = ([], t)
   526 
   535 
   527 fun focus_ex t nctxt =
   536 fun focus_ex t nctxt =
   528   let
   537   let
   529     val ((xs, Ts), t') = apfst split_list (strip_ex t) 
   538     val ((xs, Ts), t') = apfst split_list (strip_ex t)
   530     val (xs', nctxt') = fold_map Name.variant xs nctxt;
   539     val (xs', nctxt') = fold_map Name.variant xs nctxt;
   531     val ps' = xs' ~~ Ts;
   540     val ps' = xs' ~~ Ts;
   532     val vs = map Free ps';
   541     val vs = map Free ps';
   533     val t'' = Term.subst_bounds (rev vs, t');
   542     val t'' = Term.subst_bounds (rev vs, t');
   534   in ((ps', t''), nctxt') end;
   543   in ((ps', t''), nctxt') end
   535 
   544 
   536 val strip_intro_concl = (strip_comb o HOLogic.dest_Trueprop o Logic.strip_imp_concl o prop_of)
   545 val strip_intro_concl = strip_comb o HOLogic.dest_Trueprop o Logic.strip_imp_concl o prop_of
   537   
   546 
       
   547 
   538 (* introduction rule combinators *)
   548 (* introduction rule combinators *)
   539 
   549 
   540 fun map_atoms f intro = 
   550 fun map_atoms f intro =
   541   let
   551   let
   542     val (literals, head) = Logic.strip_horn intro
   552     val (literals, head) = Logic.strip_horn intro
   543     fun appl t = (case t of
   553     fun appl t =
       
   554       (case t of
   544         (@{term Not} $ t') => HOLogic.mk_not (f t')
   555         (@{term Not} $ t') => HOLogic.mk_not (f t')
   545       | _ => f t)
   556       | _ => f t)
   546   in
   557   in
   547     Logic.list_implies
   558     Logic.list_implies
   548       (map (HOLogic.mk_Trueprop o appl o HOLogic.dest_Trueprop) literals, head)
   559       (map (HOLogic.mk_Trueprop o appl o HOLogic.dest_Trueprop) literals, head)
   549   end
   560   end
   550 
   561 
   551 fun fold_atoms f intro s =
   562 fun fold_atoms f intro s =
   552   let
   563   let
   553     val (literals, _) = Logic.strip_horn intro
   564     val (literals, _) = Logic.strip_horn intro
   554     fun appl t s = (case t of
   565     fun appl t s =
   555       (@{term Not} $ t') => f t' s
   566       (case t of
       
   567         (@{term Not} $ t') => f t' s
   556       | _ => f t s)
   568       | _ => f t s)
   557   in fold appl (map HOLogic.dest_Trueprop literals) s end
   569   in fold appl (map HOLogic.dest_Trueprop literals) s end
   558 
   570 
   559 fun fold_map_atoms f intro s =
   571 fun fold_map_atoms f intro s =
   560   let
   572   let
   561     val (literals, head) = Logic.strip_horn intro
   573     val (literals, head) = Logic.strip_horn intro
   562     fun appl t s = (case t of
   574     fun appl t s =
   563       (@{term Not} $ t') => apfst HOLogic.mk_not (f t' s)
   575       (case t of
       
   576         (@{term Not} $ t') => apfst HOLogic.mk_not (f t' s)
   564       | _ => f t s)
   577       | _ => f t s)
   565     val (literals', s') = fold_map appl (map HOLogic.dest_Trueprop literals) s
   578     val (literals', s') = fold_map appl (map HOLogic.dest_Trueprop literals) s
   566   in
   579   in
   567     (Logic.list_implies (map HOLogic.mk_Trueprop literals', head), s')
   580     (Logic.list_implies (map HOLogic.mk_Trueprop literals', head), s')
   568   end;
   581   end;
   585   let
   598   let
   586     val (premises, head) = Logic.strip_horn intro
   599     val (premises, head) = Logic.strip_horn intro
   587   in
   600   in
   588     Logic.list_implies (premises, f head)
   601     Logic.list_implies (premises, f head)
   589   end
   602   end
       
   603 
   590 
   604 
   591 (* combinators to apply a function to all basic parts of nested products *)
   605 (* combinators to apply a function to all basic parts of nested products *)
   592 
   606 
   593 fun map_products f (Const (@{const_name Pair}, T) $ t1 $ t2) =
   607 fun map_products f (Const (@{const_name Pair}, T) $ t1 $ t2) =
   594   Const (@{const_name Pair}, T) $ map_products f t1 $ map_products f t2
   608   Const (@{const_name Pair}, T) $ map_products f t1 $ map_products f t2
   595   | map_products f t = f t
   609   | map_products f t = f t
       
   610 
   596 
   611 
   597 (* split theorems of case expressions *)
   612 (* split theorems of case expressions *)
   598 
   613 
   599 fun prepare_split_thm ctxt split_thm =
   614 fun prepare_split_thm ctxt split_thm =
   600     (split_thm RS @{thm iffD2})
   615     (split_thm RS @{thm iffD2})
   601     |> Local_Defs.unfold ctxt [@{thm atomize_conjL[symmetric]},
   616     |> Local_Defs.unfold ctxt [@{thm atomize_conjL[symmetric]},
   602       @{thm atomize_all[symmetric]}, @{thm atomize_imp[symmetric]}]
   617       @{thm atomize_all[symmetric]}, @{thm atomize_imp[symmetric]}]
   603 
   618 
   604 fun find_split_thm thy (Const (name, _)) = Option.map #split (Datatype.info_of_case thy name)
   619 fun find_split_thm thy (Const (name, _)) = Option.map #split (Datatype.info_of_case thy name)
   605   | find_split_thm thy _ = NONE
   620   | find_split_thm _ _ = NONE
       
   621 
   606 
   622 
   607 (* lifting term operations to theorems *)
   623 (* lifting term operations to theorems *)
   608 
   624 
   609 fun map_term thy f th =
   625 fun map_term thy f th =
   610   Skip_Proof.make_thm thy (f (prop_of th))
   626   Skip_Proof.make_thm thy (f (prop_of th))
   611 
   627 
   612 (*
   628 (*
   613 fun equals_conv lhs_cv rhs_cv ct =
   629 fun equals_conv lhs_cv rhs_cv ct =
   614   case Thm.term_of ct of
   630   case Thm.term_of ct of
   615     Const ("==", _) $ _ $ _ => Conv.arg_conv cv ct  
   631     Const ("==", _) $ _ $ _ => Conv.arg_conv cv ct
   616   | _ => error "equals_conv"  
   632   | _ => error "equals_conv"
   617 *)
   633 *)
       
   634 
   618 
   635 
   619 (* Different compilations *)
   636 (* Different compilations *)
   620 
   637 
   621 datatype compilation = Pred | Depth_Limited | Random | Depth_Limited_Random | DSeq | Annotated
   638 datatype compilation = Pred | Depth_Limited | Random | Depth_Limited_Random | DSeq | Annotated
   622   | Pos_Random_DSeq | Neg_Random_DSeq | New_Pos_Random_DSeq | New_Neg_Random_DSeq |
   639   | Pos_Random_DSeq | Neg_Random_DSeq | New_Pos_Random_DSeq | New_Neg_Random_DSeq |
   627   | negative_compilation_of New_Pos_Random_DSeq = New_Neg_Random_DSeq
   644   | negative_compilation_of New_Pos_Random_DSeq = New_Neg_Random_DSeq
   628   | negative_compilation_of New_Neg_Random_DSeq = New_Pos_Random_DSeq
   645   | negative_compilation_of New_Neg_Random_DSeq = New_Pos_Random_DSeq
   629   | negative_compilation_of Pos_Generator_DSeq = Neg_Generator_DSeq
   646   | negative_compilation_of Pos_Generator_DSeq = Neg_Generator_DSeq
   630   | negative_compilation_of Neg_Generator_DSeq = Pos_Generator_DSeq
   647   | negative_compilation_of Neg_Generator_DSeq = Pos_Generator_DSeq
   631   | negative_compilation_of Pos_Generator_CPS = Neg_Generator_CPS
   648   | negative_compilation_of Pos_Generator_CPS = Neg_Generator_CPS
   632   | negative_compilation_of Neg_Generator_CPS = Pos_Generator_CPS  
   649   | negative_compilation_of Neg_Generator_CPS = Pos_Generator_CPS
   633   | negative_compilation_of c = c
   650   | negative_compilation_of c = c
   634   
   651 
   635 fun compilation_for_polarity false Pos_Random_DSeq = Neg_Random_DSeq
   652 fun compilation_for_polarity false Pos_Random_DSeq = Neg_Random_DSeq
   636   | compilation_for_polarity false New_Pos_Random_DSeq = New_Neg_Random_DSeq
   653   | compilation_for_polarity false New_Pos_Random_DSeq = New_Neg_Random_DSeq
   637   | compilation_for_polarity _ c = c
   654   | compilation_for_polarity _ c = c
   638 
   655 
   639 fun is_depth_limited_compilation c =
   656 fun is_depth_limited_compilation c =
   640   (c = New_Pos_Random_DSeq) orelse (c = New_Neg_Random_DSeq) orelse
   657   (c = New_Pos_Random_DSeq) orelse (c = New_Neg_Random_DSeq) orelse
   641   (c = Pos_Generator_DSeq) orelse (c = Pos_Generator_DSeq)
   658   (c = Pos_Generator_DSeq) orelse (c = Pos_Generator_DSeq)
   642 
   659 
   643 fun string_of_compilation c =
   660 fun string_of_compilation c =
   644   case c of
   661   (case c of
   645     Pred => ""
   662     Pred => ""
   646   | Random => "random"
   663   | Random => "random"
   647   | Depth_Limited => "depth limited"
   664   | Depth_Limited => "depth limited"
   648   | Depth_Limited_Random => "depth limited random"
   665   | Depth_Limited_Random => "depth limited random"
   649   | DSeq => "dseq"
   666   | DSeq => "dseq"
   653   | New_Pos_Random_DSeq => "new_pos_random dseq"
   670   | New_Pos_Random_DSeq => "new_pos_random dseq"
   654   | New_Neg_Random_DSeq => "new_neg_random_dseq"
   671   | New_Neg_Random_DSeq => "new_neg_random_dseq"
   655   | Pos_Generator_DSeq => "pos_generator_dseq"
   672   | Pos_Generator_DSeq => "pos_generator_dseq"
   656   | Neg_Generator_DSeq => "neg_generator_dseq"
   673   | Neg_Generator_DSeq => "neg_generator_dseq"
   657   | Pos_Generator_CPS => "pos_generator_cps"
   674   | Pos_Generator_CPS => "pos_generator_cps"
   658   | Neg_Generator_CPS => "neg_generator_cps"
   675   | Neg_Generator_CPS => "neg_generator_cps")
   659   
   676 
   660 val compilation_names = [("pred", Pred),
   677 val compilation_names =
       
   678  [("pred", Pred),
   661   ("random", Random),
   679   ("random", Random),
   662   ("depth_limited", Depth_Limited),
   680   ("depth_limited", Depth_Limited),
   663   ("depth_limited_random", Depth_Limited_Random),
   681   ("depth_limited_random", Depth_Limited_Random),
   664   (*("annotated", Annotated),*)
   682   (*("annotated", Annotated),*)
   665   ("dseq", DSeq),
   683   ("dseq", DSeq),
   672 
   690 
   673 
   691 
   (* The compilations grouped as "random" ones, in this fixed order.
      NOTE(review): presumably those whose generated code draws random
      values; confirm against the users of this list. *)
   val random_compilations =
     [Random,
      Depth_Limited_Random,
      Pos_Random_DSeq,
      Neg_Random_DSeq,
      New_Pos_Random_DSeq,
      New_Neg_Random_DSeq,
      Pos_Generator_CPS,
      Neg_Generator_CPS]
       
   695 
   677 
   696 
   678 (* datastructures and setup for generic compilation *)
   697 (* datastructures and setup for generic compilation *)
   679 
   698 
   680 datatype compilation_funs = CompilationFuns of {
   699 datatype compilation_funs = CompilationFuns of {
   681   mk_monadT : typ -> typ,
   700   mk_monadT : typ -> typ,
   686   mk_plus : term * term -> term,
   705   mk_plus : term * term -> term,
   687   mk_if : term -> term,
   706   mk_if : term -> term,
   688   mk_iterate_upto : typ -> term * term * term -> term,
   707   mk_iterate_upto : typ -> term * term * term -> term,
   689   mk_not : term -> term,
   708   mk_not : term -> term,
   690   mk_map : typ -> typ -> term -> term -> term
   709   mk_map : typ -> typ -> term -> term -> term
   691 };
   710 }
   692 
   711 
   (* Field selectors for compilation_funs: each one unwraps the
      CompilationFuns constructor and returns the named record field. *)
   fun mk_monadT (CompilationFuns {mk_monadT = f, ...}) = f
   fun dest_monadT (CompilationFuns {dest_monadT = f, ...}) = f
   fun mk_empty (CompilationFuns {mk_empty = f, ...}) = f
   fun mk_single (CompilationFuns {mk_single = f, ...}) = f
   (* More compilation_funs field selectors: unwrap the CompilationFuns
      constructor and return the named record field. *)
   fun mk_if (CompilationFuns {mk_if = f, ...}) = f
   fun mk_iterate_upto (CompilationFuns {mk_iterate_upto = f, ...}) = f
   fun mk_not (CompilationFuns {mk_not = f, ...}) = f
   fun mk_map (CompilationFuns {mk_map = f, ...}) = f
   703 
   722 
       
   723 
   704 (** function types and names of different compilations **)
   724 (** function types and names of different compilations **)
   705 
   725 
   (* Type of the compiled function for a predicate of type T in mode `mode`.
      The binder types of T are split by split_map_modeT into input and output
      types.  The callback handed to it recomputes funT_of for the arguments it
      receives together with their own mode.  The inputs are then curried onto
      the monad type (mk_monadT compfuns) of the tuple of the output types. *)
   fun funT_of compfuns mode T =
     let
       fun arg_type m U = (SOME (funT_of compfuns m U), NONE)
       val (ins, outs) = split_map_modeT arg_type mode (binder_types T)
       val result = mk_monadT compfuns (HOLogic.mk_tupleT outs)
     in
       ins ---> result
     end
       
   734 
   713 
   735 
   714 (* Different options for compiler *)
   736 (* Different options for compiler *)
   715 
   737 
   716 datatype options = Options of {  
   738 datatype options = Options of {
   717   expected_modes : (string * mode list) option,
   739   expected_modes : (string * mode list) option,
   718   proposed_modes : (string * mode list) list,
   740   proposed_modes : (string * mode list) list,
   719   proposed_names : ((string * mode) * string) list,
   741   proposed_names : ((string * mode) * string) list,
   720   show_steps : bool,
   742   show_steps : bool,
   721   show_proof_trace : bool,
   743   show_proof_trace : bool,
   733   no_higher_order_predicate : string list,
   755   no_higher_order_predicate : string list,
   734   inductify : bool,
   756   inductify : bool,
   735   detect_switches : bool,
   757   detect_switches : bool,
   736   smart_depth_limiting : bool,
   758   smart_depth_limiting : bool,
   737   compilation : compilation
   759   compilation : compilation
   738 };
   760 }
   739 
   761 
   740 fun expected_modes (Options opt) = #expected_modes opt
   762 fun expected_modes (Options opt) = #expected_modes opt
   741 fun proposed_modes (Options opt) = AList.lookup (op =) (#proposed_modes opt)
   763 fun proposed_modes (Options opt) = AList.lookup (op =) (#proposed_modes opt)
   742 fun proposed_names (Options opt) name mode = AList.lookup (eq_pair (op =) eq_mode)
   764 fun proposed_names (Options opt) name mode = AList.lookup (eq_pair (op =) eq_mode)
   743   (#proposed_names opt) (name, mode)
   765   (#proposed_names opt) (name, mode)
   796   "smart_depth_limiting"]
   818   "smart_depth_limiting"]
   797 
   819 
   (* Sends s to tracing when the show_steps option is enabled; otherwise does nothing. *)
   fun print_step options s =
     if not (show_steps options) then () else tracing s
   800 
   822 
       
   823 
   801 (* simple transformations *)
   824 (* simple transformations *)
   802 
   825 
   803 (** tuple processing **)
   826 (** tuple processing **)
   804 
   827 
   (* Expands tuple-typed arguments of an atom into explicit pairs.
      The threaded state is (pats, intro_t, ctxt).  Only arguments whose type
      is a tuple of at least two components (HOLogic.strip_tupleT) are
      touched; their right-nested spine is walked:
      - an explicit Pair application is stepped through to its second
        component;
      - any other term t of product type is replaced by a pair of fresh free
        variables (x, y), with names from Variable.variant_fixes.  This is done
        by rewriting intro_t and the not-yet-processed arguments with the
        pattern (t, (x, y)), and the walk continues on y.
      Each pattern is consed onto pats, and the extended context is threaded
      through. *)
   fun rewrite_args [] state = state
     | rewrite_args (arg :: rest) state =
         let
           fun expand (Const (@{const_name Pair}, _) $ _ $ snd_t,
                       Type (@{type_name Product_Type.prod}, [_, sndT])) acc =
                 expand (snd_t, sndT) acc
             | expand (t, Type (@{type_name Product_Type.prod}, [fstT, sndT]))
                 (others, (ps, goal, lctxt)) =
                 let
                   val thy = Proof_Context.theory_of lctxt
                   val ([a, b], lctxt') = Variable.variant_fixes ["x", "y"] lctxt
                   val rule = (t, HOLogic.mk_prod (Free (a, fstT), Free (b, sndT)))
                   val rewr = Pattern.rewrite_term thy [rule] []
                   val goal' = rewr goal
                   val others' = map rewr others
                 in
                   expand (Free (b, sndT), sndT) (others', (rule :: ps, goal', lctxt'))
                 end
             | expand _ acc = acc
           val argT = fastype_of arg
         in
           (case HOLogic.strip_tupleT argT of
             _ :: _ :: _ =>
               let
                 val (rest', state') = expand (arg, argT) (rest, state)
               in
                 rewrite_args rest' state'
               end
           | _ => rewrite_args rest state)
         end
   829 
   855 
   (* Runs rewrite_args on the arguments of atom, ignoring its head.  The
      result still expects the (pats, intro_t, ctxt) state of rewrite_args. *)
   fun rewrite_prem atom = rewrite_args (snd (strip_comb atom))
   834 
   860 
   (* Splits conjunctive premises of th into separate premises.
      The theorem is imported (Variable.import) into a local context.
      Premise positions are then scanned from 1: at position pos,
      Drule.RSN with conjI is tried.  On success the premise count grows by
      one and the same position is retried on the new theorem; on failure
      the scan moves on to pos + 1.  The result is exported back to ctxt. *)
   fun split_conjuncts_in_assms ctxt th =
     let
       val ((_, [imported]), inner_ctxt) = Variable.import false [th] ctxt
       fun go pos count thm =
         if pos > count then thm
         else
           (case try Drule.RSN (@{thm conjI}, (pos, thm)) of
             NONE => go (pos + 1) count thm
           | SOME thm' => go pos (count + 1) thm')
       val split = go 1 (Thm.nprems_of imported) imported
     in
       singleton (Variable.export inner_ctxt ctxt) split
     end
   847 
   874 
   (* Breaks a theorem whose (Trueprop) statement is a nested HOL conjunction
      into the list of its conjunct theorems, left part first, using
      conjunct1 / conjunct2.  Any other theorem yields [th]. *)
   fun dest_conjunct_prem th =
     let
       fun is_conj (Const (@{const_name HOL.conj}, _) $ _ $ _) = true
         | is_conj _ = false
     in
       if is_conj (HOLogic.dest_Trueprop (prop_of th)) then
         let
           val left = dest_conjunct_prem (th RS @{thm conjunct1})
           val right = dest_conjunct_prem (th RS @{thm conjunct2})
         in
           left @ right
         end
       else [th]
     end
   854 
   881 
   855 fun expand_tuples thy intro =
   882 fun expand_tuples thy intro =
   856   let
   883   let
   857     val ctxt = Proof_Context.init_global thy
   884     val ctxt = Proof_Context.init_global thy
   858     val (((T_insts, t_insts), [intro']), ctxt1) = Variable.import false [intro] ctxt
   885     val (((T_insts, t_insts), [intro']), ctxt1) = Variable.import false [intro] ctxt
   875     val intro''''' = split_conjuncts_in_assms ctxt intro''''
   902     val intro''''' = split_conjuncts_in_assms ctxt intro''''
   876   in
   903   in
   877     intro'''''
   904     intro'''''
   878   end
   905   end
   879 
   906 
       
   907 
   880 (** making case distributivity rules **)
   908 (** making case distributivity rules **)
   881 (*** this should be part of the datatype package ***)
   909 (*** this should be part of the datatype package ***)
   882 
   910 
   (* Looks up the datatype info registered for case_name
      (Datatype.info_of_case).  Returns the first component of the payload of
      every entry of its descr.  Raises Option (via `the`) when no info is
      found. *)
   fun datatype_names_of_case_name thy case_name =
     let
       val info = the (Datatype.info_of_case thy case_name)
     in
       map (#1 o snd) (#descr info)
     end
   886 fun make_case_distribs case_names descr thy =
   914 fun make_case_distribs case_names descr thy =
   887   let
   915   let
   888     val case_combs = Datatype_Prop.make_case_combs case_names descr thy "f";
   916     val case_combs = Datatype_Prop.make_case_combs case_names descr thy "f";
   889     fun make comb =
   917     fun make comb =
   890       let
   918       let
   891         val Type ("fun", [T, T']) = fastype_of comb;
   919         val Type ("fun", [T, T']) = fastype_of comb
   892         val (Const (case_name, _), fs) = strip_comb comb
   920         val (Const (case_name, _), fs) = strip_comb comb
   893         val used = Term.add_tfree_names comb []
   921         val used = Term.add_tfree_names comb []
   894         val U = TFree (singleton (Name.variant_list used) "'t", HOLogic.typeS)
   922         val U = TFree (singleton (Name.variant_list used) "'t", HOLogic.typeS)
   895         val x = Free ("x", T)
   923         val x = Free ("x", T)
   896         val f = Free ("f", T' --> U)
   924         val f = Free ("f", T' --> U)
   950   in
   978   in
   951     Raw_Simplifier.rewrite_term thy
   979     Raw_Simplifier.rewrite_term thy
   952       (map (fn th => th RS @{thm eq_reflection}) ths) [] t
   980       (map (fn th => th RS @{thm eq_reflection}) ths) [] t
   953   end
   981   end
   954 
   982 
       
   983 
   955 (*** conversions ***)
   984 (*** conversions ***)
   956 
   985 
   (* Conversion applying cv to every premise of a meta-implication chain
      A1 ==> ... ==> An ==> C while leaving the conclusion C unchanged. *)
   fun imp_prems_conv cv ct =
     let
       val is_imp =
         (case Thm.term_of ct of
           Const ("==>", _) $ _ $ _ => true
         | _ => false)
     in
       if is_imp then
         Conv.combination_conv (Conv.arg_conv cv) (imp_prems_conv cv) ct
       else Conv.all_conv ct
     end
       
   990 
   961 
   991 
   962 (** eta contract higher-order arguments **)
   992 (** eta contract higher-order arguments **)
   963 
   993 
   (* Eta-contracts (Envir.eta_contract, applied through map_products) the
      arguments of the conclusion and of every atom of intro, traversed with
      map_term / map_concl / map_atoms.  Heads are left untouched. *)
   fun eta_contract_ho_arguments thy intro =
     let
       fun contract_args args = map (map_products Envir.eta_contract) args
       fun contract_atom atom =
         let
           val (head, args) = strip_comb atom
         in
           list_comb (head, contract_args args)
         end
     in
       map_term thy (map_concl contract_atom o map_atoms contract_atom) intro
     end
       
  1000 
   970 
  1001 
   971 (** remove equalities **)
  1002 (** remove equalities **)
   972 
  1003 
   973 fun remove_equalities thy intro =
  1004 fun remove_equalities thy intro =
   974   let
  1005   let
   976       let
  1007       let
   977         val (prems, concl) = Logic.strip_horn intro_t
  1008         val (prems, concl) = Logic.strip_horn intro_t
   978         fun remove_eq (prems, concl) =
  1009         fun remove_eq (prems, concl) =
   979           let
  1010           let
   980             fun removable_eq prem =
  1011             fun removable_eq prem =
   981               case try (HOLogic.dest_eq o HOLogic.dest_Trueprop) prem of
  1012               (case try (HOLogic.dest_eq o HOLogic.dest_Trueprop) prem of
   982                 SOME (lhs, rhs) => (case lhs of
  1013                 SOME (lhs, rhs) =>
   983                   Var _ => true
  1014                   (case lhs of
       
  1015                     Var _ => true
   984                   | _ => (case rhs of Var _ => true | _ => false))
  1016                   | _ => (case rhs of Var _ => true | _ => false))
   985               | NONE => false
  1017               | NONE => false)
   986           in
  1018           in
   987             case find_first removable_eq prems of
  1019             (case find_first removable_eq prems of
   988               NONE => (prems, concl)
  1020               NONE => (prems, concl)
   989             | SOME eq =>
  1021             | SOME eq =>
   990               let
  1022                 let
   991                 val (lhs, rhs) = HOLogic.dest_eq (HOLogic.dest_Trueprop eq)
  1023                   val (lhs, rhs) = HOLogic.dest_eq (HOLogic.dest_Trueprop eq)
   992                 val prems' = remove (op =) eq prems
  1024                   val prems' = remove (op =) eq prems
   993                 val subst = (case lhs of
  1025                   val subst =
   994                   (v as Var _) =>
  1026                     (case lhs of
   995                     (fn t => if t = v then rhs else t)
  1027                       (v as Var _) =>
   996                 | _ => (case rhs of
  1028                         (fn t => if t = v then rhs else t)
   997                    (v as Var _) => (fn t => if t = v then lhs else t)))
  1029                     | _ => (case rhs of (v as Var _) => (fn t => if t = v then lhs else t)))
   998               in
  1030                 in
   999                 remove_eq (map (map_aterms subst) prems', map_aterms subst concl)
  1031                   remove_eq (map (map_aterms subst) prems', map_aterms subst concl)
  1000               end
  1032                 end)
  1001           end
  1033           end
  1002       in
  1034       in
  1003         Logic.list_implies (remove_eq (prems, concl))
  1035         Logic.list_implies (remove_eq (prems, concl))
  1004       end
  1036       end
  1005   in
  1037   in
  1006     map_term thy remove_eqs intro
  1038     map_term thy remove_eqs intro
  1007   end
  1039   end
  1008 
  1040 
       
  1041 
  1009 (* Some last processing *)
  1042 (* Some last processing *)
  1010 
  1043 
  (* Drops an introduction rule whose only premise is False and keeps any
     other rule: returns [] or [intro]. *)
  fun remove_pointless_clauses intro =
    (case Logic.strip_imp_prems (prop_of intro) of
      [prem] => if prem = @{prop "False"} then [] else [intro]
    | _ => [intro])
  1015 
  1048 
       
  1049 
  1016 (* some peephole optimisations *)
  1050 (* some peephole optimisations *)
  1017 
  1051 
  1018 fun peephole_optimisation thy intro =
  1052 fun peephole_optimisation thy intro =
  1019   let
  1053   let
  1020     val ctxt = Proof_Context.init_global thy  (* FIXME proper context!? *)
  1054     val ctxt = Proof_Context.init_global thy  (* FIXME proper context!? *)
  1021     val process =
  1055     val process =
  1022       rewrite_rule ctxt (Predicate_Compile_Simps.get ctxt)
  1056       rewrite_rule ctxt (Predicate_Compile_Simps.get ctxt)
  1023     fun process_False intro_t =
  1057     fun process_False intro_t =
  1024       if member (op =) (Logic.strip_imp_prems intro_t) @{prop "False"} then NONE else SOME intro_t
  1058       if member (op =) (Logic.strip_imp_prems intro_t) @{prop "False"}
       
  1059       then NONE else SOME intro_t
  1025     fun process_True intro_t =
  1060     fun process_True intro_t =
  1026       map_filter_premises (fn p => if p = @{prop True} then NONE else SOME p) intro_t
  1061       map_filter_premises (fn p => if p = @{prop True} then NONE else SOME p) intro_t
  1027   in
  1062   in
  1028     Option.map (Skip_Proof.make_thm thy)
  1063     Option.map (Skip_Proof.make_thm thy)
  1029       (process_False (process_True (prop_of (process intro))))
  1064       (process_False (process_True (prop_of (process intro))))
  1031 
  1066 
  1032 
  1067 
  1033 (* importing introduction rules *)
  1068 (* importing introduction rules *)
  1034 
  1069 
  (* Imports the introduction rules of a predicate into ctxt and aligns them.
     Result shape: (((pred, ho_params), imported_rules), ctxt).
     - No rules: inp_pred is imported with Variable.import_terms.  Fresh free
       variables named from "p1", "p2", ... act as parameters, one per type
       returned by ho_argsT_of_typ for the binder types.  The context
       extended by these fixes is NOT returned, only the one obtained from
       importing inp_pred.
     - Rules th :: ths: th is imported, and strip_intro_concl gives the
       predicate and its arguments; ho_args_of_typ selects the higher-order
       ones.  Every rule of ths must concern a constant of the same name
       (otherwise Fail).  Its type variables are instantiated by matching its
       predicate type against the first one, and the Vars found by
       ho_args_of_typ in its conclusion are instantiated with the first
       rule's higher-order arguments.  It is then imported.  inp_pred is
       instantiated by matching its type against the first rule's predicate
       type.  A failed match raises an error naming the predicate. *)
  fun import_intros inp_pred [] ctxt =
        let
          val ([pred0], ctxt') = Variable.import_terms true [inp_pred] ctxt
          val argTs = ho_argsT_of_typ (binder_types (fastype_of pred0))
          val names = map (fn k => "p" ^ string_of_int k) (1 upto length argTs)
          val (fixed, _) = Variable.variant_fixes names ctxt'
          val params = map2 (curry Free) fixed argTs
        in
          (((pred0, params), []), ctxt')
        end
    | import_intros inp_pred (th :: ths) ctxt =
        let
          val ((_, [first]), ctxt') = Variable.import true [th] ctxt
          val thy = Proof_Context.theory_of ctxt'
          val (pred, args) = strip_intro_concl first
          val predT = fastype_of pred
          val ho_args = ho_args_of_typ predT args
          (* type instantiation obtained by matching the type of src against tgt *)
          fun subst_of (src, tgt) =
            let
              val srcT = fastype_of src
              val tgtT = fastype_of tgt
              val tyenv =
                Sign.typ_match thy (srcT, tgtT) Vartab.empty
                  handle Type.TYPE_MATCH =>
                    error ("Type mismatch of predicate " ^ fst (dest_Const tgt) ^
                      " (trying to match " ^ Syntax.string_of_typ ctxt srcT ^
                      " and " ^ Syntax.string_of_typ ctxt tgtT ^ ")" ^
                      " in " ^ Display.string_of_thm ctxt th)
              fun entry (ixn, (sort, U)) = ((ixn, sort), U)
            in
              map entry (Vartab.dest tyenv)
            end
          fun instantiate_typ thm =
            let
              val (other, _) = strip_intro_concl thm
            in
              if fst (dest_Const pred) <> fst (dest_Const other)
              then raise Fail "Trying to instantiate another predicate"
              else Thm.certify_instantiate (subst_of (other, pred), []) thm
            end
          fun instantiate_ho_args thm =
            let
              val concl = HOLogic.dest_Trueprop (Logic.strip_imp_concl (prop_of thm))
              val (_, thm_args) = strip_comb concl
              val ho_vars = map dest_Var (ho_args_of_typ predT thm_args)
            in
              Thm.certify_instantiate ([], ho_vars ~~ ho_args) thm
            end
          val out_pred = Term_Subst.instantiate (subst_of (inp_pred, pred), []) inp_pred
          val aligned = map (instantiate_ho_args o instantiate_typ) ths
          val ((_, rest), ctxt'') = Variable.import false aligned ctxt'
        in
          (((out_pred, ho_args), first :: rest), ctxt'')
        end
       
  1118 
       
  1119 
  1080 (* generation of case rules from user-given introduction rules *)
  1120 (* generation of case rules from user-given introduction rules *)
  1081 
  1121 
  1082 fun mk_args2 (Type (@{type_name Product_Type.prod}, [T1, T2])) st =
  1122 fun mk_args2 (Type (@{type_name Product_Type.prod}, [T1, T2])) st =
  1083     let
  1123       let
  1084       val (t1, st') = mk_args2 T1 st
  1124         val (t1, st') = mk_args2 T1 st
  1085       val (t2, st'') = mk_args2 T2 st'
  1125         val (t2, st'') = mk_args2 T2 st'
  1086     in
  1126       in
  1087       (HOLogic.mk_prod (t1, t2), st'')
  1127         (HOLogic.mk_prod (t1, t2), st'')
  1088     end
  1128       end
  1089   (*| mk_args2 (T as Type ("fun", _)) (params, ctxt) = 
  1129   (*| mk_args2 (T as Type ("fun", _)) (params, ctxt) =
  1090     let
  1130     let
  1091       val (S, U) = strip_type T
  1131       val (S, U) = strip_type T
  1092     in
  1132     in
  1093       if U = HOLogic.boolT then
  1133       if U = HOLogic.boolT then
  1094         (hd params, (tl params, ctxt))
  1134         (hd params, (tl params, ctxt))
  1098         in
  1138         in
  1099           (Free (x, T), (params, ctxt'))
  1139           (Free (x, T), (params, ctxt'))
  1100         end
  1140         end
  1101     end*)
  1141     end*)
  1102   | mk_args2 T (params, ctxt) =
  1142   | mk_args2 T (params, ctxt) =
  1103     let
  1143       let
  1104       val ([x], ctxt') = Variable.variant_fixes ["x"] ctxt
  1144         val ([x], ctxt') = Variable.variant_fixes ["x"] ctxt
  1105     in
  1145       in
  1106       (Free (x, T), (params, ctxt'))
  1146         (Free (x, T), (params, ctxt'))
  1107     end
  1147       end
  1108 
  1148 
  1109 fun mk_casesrule ctxt pred introrules =
  1149 fun mk_casesrule ctxt pred introrules =
  1110   let
  1150   let
  1111     (* TODO: can be simplified if parameters are not treated specially ? *)
  1151     (* TODO: can be simplified if parameters are not treated specially ? *)
  1112     val (((pred, params), intros_th), ctxt1) = import_intros pred introrules ctxt
  1152     val (((pred, params), intros_th), ctxt1) = import_intros pred introrules ctxt
  1127         val frees = map Free (fold Term.add_frees (args @ prems) [])
  1167         val frees = map Free (fold Term.add_frees (args @ prems) [])
  1128       in fold Logic.all frees (Logic.list_implies (eqprems @ prems, prop)) end
  1168       in fold Logic.all frees (Logic.list_implies (eqprems @ prems, prop)) end
  1129     val assm = HOLogic.mk_Trueprop (list_comb (pred, argvs))
  1169     val assm = HOLogic.mk_Trueprop (list_comb (pred, argvs))
  1130     val cases = map mk_case intros
  1170     val cases = map mk_case intros
  1131   in Logic.list_implies (assm :: cases, prop) end;
  1171   in Logic.list_implies (assm :: cases, prop) end;
  1132   
  1172 
  1133 
  1173 
  1134 (* unifying constants to have the same type variables *)
  1174 (* unifying constants to have the same type variables *)
  1135 
  1175 
  (* Makes the types of the constants occurring in cs agree with the types of
     their occurrences in the terms intr_ts.
     Every term of cs and intr_ts is varified (Type.varify_global).  Its type
     variables are shifted (Logic.incr_tvar) past the largest index used so
     far; List.foldr processes right to left.  This keeps the type variables
     of different terms apart.  For each constant occurring in the varified
     cs, its type is unified (Sign.typ_unify) with the type of every
     occurrence of a constant of the same name in the varified intr_ts.  The
     resulting environment then normalises all types.
     Returns (cs', intr_ts').  On Type.TUNIFY a warning is printed and
     (cs, intr_ts) is returned unchanged.
     The running index is kept monotone with Int.max.  Previously it was
     reset to maxidx_of_term of the latest term, which is ~1 for a term
     without variables.  A later term could then reuse indices already given
     to an earlier one, spuriously sharing its type variables, and
     Sign.typ_unify was passed a maxidx below indices actually in use. *)
  fun unify_consts thy cs intr_ts =
    let
      val add_term_consts_2 = fold_aterms (fn Const c => insert (op =) c | _ => I)
      fun varify (t, (i, ts)) =
        let
          val t' = map_types (Logic.incr_tvar (i + 1)) (#2 (Type.varify_global [] t))
        in
          (Int.max (i, maxidx_of_term t'), t' :: ts)
        end
      val (i, cs') = List.foldr varify (~1, []) cs
      val (i', intr_ts') = List.foldr varify (i, []) intr_ts
      val rec_consts = fold add_term_consts_2 cs' []
      val intr_consts = fold add_term_consts_2 intr_ts' []
      fun unify (cname, cT) =
        let
          val consts = map snd (filter (fn c => fst c = cname) intr_consts)
        in
          fold (Sign.typ_unify thy) ((replicate (length consts) cT) ~~ consts)
        end
      val (env, _) = fold unify rec_consts (Vartab.empty, i')
      val subst = map_types (Envir.norm_type env)
    in
      (map subst cs', map subst intr_ts')
    end
    handle Type.TUNIFY =>
      (warning "Occurrences of recursive constant have non-unifiable types"; (cs, intr_ts))
       
  1194 
  1154 
  1195 
  1155 (* preprocessing rules *)
  1196 (* preprocessing rules *)
  1156 
  1197 
  1157 fun preprocess_equality thy rule =
  1198 fun preprocess_equality thy rule =
  1158   Conv.fconv_rule
  1199   Conv.fconv_rule
  1161         (Conv.try_conv (Conv.rewr_conv (Thm.symmetric @{thm Predicate.eq_is_eq})))))
  1202         (Conv.try_conv (Conv.rewr_conv (Thm.symmetric @{thm Predicate.eq_is_eq})))))
  1162     (Thm.transfer thy rule)
  1203     (Thm.transfer thy rule)
  1163 
  1204 
  (* Preprocesses an introduction rule: expand_tuples first, then preprocess_equality. *)
  fun preprocess_intro thy intro = preprocess_equality thy (expand_tuples thy intro)
  1165 
  1206 
       
  1207 
  1166 (* defining a quickcheck predicate *)
  1208 (* defining a quickcheck predicate *)
  1167 
  1209 
  (* Premises A1, ..., An (outermost first) of a right-nested object-level
     implication A1 --> ... --> An --> C built from HOL.implies.  Returns []
     when the term is not such an implication. *)
  fun strip_imp_prems t =
    let
      fun collect acc (Const (@{const_name HOL.implies}, _) $ prem $ rest) =
            collect (prem :: acc) rest
        | collect acc _ = rev acc
    in
      collect [] t
    end
  1170 
  1212 
  (* Final conclusion C of a right-nested object-level implication
     A1 --> ... --> An --> C built from HOL.implies; a non-implication is
     returned as is. *)
  fun strip_imp_concl t =
    (case t of
      Const (@{const_name HOL.implies}, _) $ _ $ rest => strip_imp_concl rest
    | _ => t)
  1173 
  1215 
  (* Splits an object-level (HOL.implies) implication into its premises and
     its conclusion. *)
  fun strip_horn t =
    let
      val prems = strip_imp_prems t
      val concl = strip_imp_concl t
    in
      (prems, concl)
    end
  1175 
  1217 
  1176 fun define_quickcheck_predicate t thy =
  1218 fun define_quickcheck_predicate t thy =
  1177   let
  1219   let
  1178     val (vs, t') = strip_abs t
  1220     val (vs, t') = strip_abs t
  1179     val vs' = Variable.variant_frees (Proof_Context.init_global thy) [] vs (* FIXME proper context!? *)
  1221     val vs' = Variable.variant_frees (Proof_Context.init_global thy) [] vs (* FIXME proper context!? *)
  1182     val constname = "quickcheck"
  1224     val constname = "quickcheck"
  1183     val full_constname = Sign.full_bname thy constname
  1225     val full_constname = Sign.full_bname thy constname
  1184     val constT = map snd vs' ---> @{typ bool}
  1226     val constT = map snd vs' ---> @{typ bool}
  1185     val thy1 = Sign.add_consts_i [(Binding.name constname, constT, NoSyn)] thy
  1227     val thy1 = Sign.add_consts_i [(Binding.name constname, constT, NoSyn)] thy
  1186     val const = Const (full_constname, constT)
  1228     val const = Const (full_constname, constT)
  1187     val t = Logic.list_implies
  1229     val t =
  1188       (map HOLogic.mk_Trueprop (prems @ [HOLogic.mk_not concl]),
  1230       Logic.list_implies
  1189        HOLogic.mk_Trueprop (list_comb (const, map Free vs')))
  1231         (map HOLogic.mk_Trueprop (prems @ [HOLogic.mk_not concl]),
       
  1232           HOLogic.mk_Trueprop (list_comb (const, map Free vs')))
  1190     val intro =
  1233     val intro =
  1191       Goal.prove (Proof_Context.init_global thy1) (map fst vs') [] t
  1234       Goal.prove (Proof_Context.init_global thy1) (map fst vs') [] t
  1192         (fn _ => ALLGOALS Skip_Proof.cheat_tac)
  1235         (fn _ => ALLGOALS Skip_Proof.cheat_tac)
  1193   in
  1236   in
  1194     ((((full_constname, constT), vs'), intro), thy1)
  1237     ((((full_constname, constT), vs'), intro), thy1)
  1195   end
  1238   end
  1196 
  1239 
  1197 end;
  1240 end