src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML
changeset 45461 130c90bb80b4
parent 45452 414732ebf891
child 45506 4cc83e901acf
equal deleted inserted replaced
45452:414732ebf891 45461:130c90bb80b4
   310   wrap_compilation =
   310   wrap_compilation =
   311     fn compfuns => fn s => fn T => fn mode => fn additional_arguments => fn compilation =>
   311     fn compfuns => fn s => fn T => fn mode => fn additional_arguments => fn compilation =>
   312     let
   312     let
   313       val [depth] = additional_arguments
   313       val [depth] = additional_arguments
   314       val (_, Ts) = split_modeT mode (binder_types T)
   314       val (_, Ts) = split_modeT mode (binder_types T)
   315       val T' = mk_predT compfuns (HOLogic.mk_tupleT Ts)
   315       val T' = mk_monadT compfuns (HOLogic.mk_tupleT Ts)
   316       val if_const = Const (@{const_name "If"}, @{typ bool} --> T' --> T' --> T')
   316       val if_const = Const (@{const_name "If"}, @{typ bool} --> T' --> T' --> T')
   317     in
   317     in
   318       if_const $ HOLogic.mk_eq (depth, @{term "0 :: code_numeral"})
   318       if_const $ HOLogic.mk_eq (depth, @{term "0 :: code_numeral"})
   319         $ mk_bot compfuns (dest_predT compfuns T')
   319         $ mk_empty compfuns (dest_monadT compfuns T')
   320         $ compilation
   320         $ compilation
   321     end,
   321     end,
   322   transform_additional_arguments =
   322   transform_additional_arguments =
   323     fn prem => fn additional_arguments =>
   323     fn prem => fn additional_arguments =>
   324     let
   324     let
   335   function_name_prefix = "random_",
   335   function_name_prefix = "random_",
   336   compfuns = Predicate_Comp_Funs.compfuns,
   336   compfuns = Predicate_Comp_Funs.compfuns,
   337   mk_random = (fn T => fn additional_arguments =>
   337   mk_random = (fn T => fn additional_arguments =>
   338   list_comb (Const(@{const_name Quickcheck.iter},
   338   list_comb (Const(@{const_name Quickcheck.iter},
   339   [@{typ code_numeral}, @{typ code_numeral}, @{typ Random.seed}] ---> 
   339   [@{typ code_numeral}, @{typ code_numeral}, @{typ Random.seed}] ---> 
   340     Predicate_Comp_Funs.mk_predT T), additional_arguments)),
   340     Predicate_Comp_Funs.mk_monadT T), additional_arguments)),
   341   modify_funT = (fn T =>
   341   modify_funT = (fn T =>
   342     let
   342     let
   343       val (Ts, U) = strip_type T
   343       val (Ts, U) = strip_type T
   344       val Ts' = [@{typ code_numeral}, @{typ code_numeral}, @{typ "code_numeral * code_numeral"}]
   344       val Ts' = [@{typ code_numeral}, @{typ code_numeral}, @{typ "code_numeral * code_numeral"}]
   345     in (Ts @ Ts') ---> U end),
   345     in (Ts @ Ts') ---> U end),
   361   function_name_prefix = "depth_limited_random_",
   361   function_name_prefix = "depth_limited_random_",
   362   compfuns = Predicate_Comp_Funs.compfuns,
   362   compfuns = Predicate_Comp_Funs.compfuns,
   363   mk_random = (fn T => fn additional_arguments =>
   363   mk_random = (fn T => fn additional_arguments =>
   364   list_comb (Const(@{const_name Quickcheck.iter},
   364   list_comb (Const(@{const_name Quickcheck.iter},
   365   [@{typ code_numeral}, @{typ code_numeral}, @{typ Random.seed}] ---> 
   365   [@{typ code_numeral}, @{typ code_numeral}, @{typ Random.seed}] ---> 
   366     Predicate_Comp_Funs.mk_predT T), tl additional_arguments)),
   366     Predicate_Comp_Funs.mk_monadT T), tl additional_arguments)),
   367   modify_funT = (fn T =>
   367   modify_funT = (fn T =>
   368     let
   368     let
   369       val (Ts, U) = strip_type T
   369       val (Ts, U) = strip_type T
   370       val Ts' = [@{typ code_numeral}, @{typ code_numeral}, @{typ code_numeral},
   370       val Ts' = [@{typ code_numeral}, @{typ code_numeral}, @{typ code_numeral},
   371         @{typ "code_numeral * code_numeral"}]
   371         @{typ "code_numeral * code_numeral"}]
   381   fn compfuns => fn s => fn T => fn mode => fn additional_arguments => fn compilation =>
   381   fn compfuns => fn s => fn T => fn mode => fn additional_arguments => fn compilation =>
   382     let
   382     let
   383       val depth = hd (additional_arguments)
   383       val depth = hd (additional_arguments)
   384       val (_, Ts) = split_map_modeT (fn m => fn T => (SOME (funT_of compfuns m T), NONE))
   384       val (_, Ts) = split_map_modeT (fn m => fn T => (SOME (funT_of compfuns m T), NONE))
   385         mode (binder_types T)
   385         mode (binder_types T)
   386       val T' = mk_predT compfuns (HOLogic.mk_tupleT Ts)
   386       val T' = mk_monadT compfuns (HOLogic.mk_tupleT Ts)
   387       val if_const = Const (@{const_name "If"}, @{typ bool} --> T' --> T' --> T')
   387       val if_const = Const (@{const_name "If"}, @{typ bool} --> T' --> T' --> T')
   388     in
   388     in
   389       if_const $ HOLogic.mk_eq (depth, @{term "0 :: code_numeral"})
   389       if_const $ HOLogic.mk_eq (depth, @{term "0 :: code_numeral"})
   390         $ mk_bot compfuns (dest_predT compfuns T')
   390         $ mk_empty compfuns (dest_monadT compfuns T')
   391         $ compilation
   391         $ compilation
   392     end,
   392     end,
   393   transform_additional_arguments =
   393   transform_additional_arguments =
   394     fn prem => fn additional_arguments =>
   394     fn prem => fn additional_arguments =>
   395     let
   395     let
   656     val names = fold Term.add_free_names (success_t :: eqs'' @ out_ts) [];
   656     val names = fold Term.add_free_names (success_t :: eqs'' @ out_ts) [];
   657     val name = singleton (Name.variant_list names) "x";
   657     val name = singleton (Name.variant_list names) "x";
   658     val name' = singleton (Name.variant_list (name :: names)) "y";
   658     val name' = singleton (Name.variant_list (name :: names)) "y";
   659     val T = HOLogic.mk_tupleT (map fastype_of out_ts);
   659     val T = HOLogic.mk_tupleT (map fastype_of out_ts);
   660     val U = fastype_of success_t;
   660     val U = fastype_of success_t;
   661     val U' = dest_predT compfuns U;
   661     val U' = dest_monadT compfuns U;
   662     val v = Free (name, T);
   662     val v = Free (name, T);
   663     val v' = Free (name', T);
   663     val v' = Free (name', T);
   664   in
   664   in
   665     lambda v (Datatype.make_case ctxt Datatype_Case.Quiet [] v
   665     lambda v (Datatype.make_case ctxt Datatype_Case.Quiet [] v
   666       [(HOLogic.mk_tuple out_ts,
   666       [(HOLogic.mk_tuple out_ts,
   667         if null eqs'' then success_t
   667         if null eqs'' then success_t
   668         else Const (@{const_name HOL.If}, HOLogic.boolT --> U --> U --> U) $
   668         else Const (@{const_name HOL.If}, HOLogic.boolT --> U --> U --> U) $
   669           foldr1 HOLogic.mk_conj eqs'' $ success_t $
   669           foldr1 HOLogic.mk_conj eqs'' $ success_t $
   670             mk_bot compfuns U'),
   670             mk_empty compfuns U'),
   671        (v', mk_bot compfuns U')])
   671        (v', mk_empty compfuns U')])
   672   end;
   672   end;
   673 
   673 
   674 fun string_of_tderiv ctxt (t, deriv) = 
   674 fun string_of_tderiv ctxt (t, deriv) = 
   675   (case (t, deriv) of
   675   (case (t, deriv) of
   676     (t1 $ t2, Mode_App (deriv1, deriv2)) =>
   676     (t1 $ t2, Mode_App (deriv1, deriv2)) =>
   926               val out_ts' = map (Pattern.rewrite_term thy (map swap fsubst) []) out_ts
   926               val out_ts' = map (Pattern.rewrite_term thy (map swap fsubst) []) out_ts
   927             in
   927             in
   928               compile_clause compilation_modifiers ctxt all_vs param_modes additional_arguments
   928               compile_clause compilation_modifiers ctxt all_vs param_modes additional_arguments
   929                 inp (in_ts', out_ts') moded_ps'
   929                 inp (in_ts', out_ts') moded_ps'
   930             end
   930             end
   931         in SOME (foldr1 (mk_sup compfuns) (map compile_clause' moded_clauses)) end
   931         in SOME (foldr1 (mk_plus compfuns) (map compile_clause' moded_clauses)) end
   932     | compile_switch_tree all_vs ctxt_eqs (Node ((position, switched_clauses), left_clauses)) =
   932     | compile_switch_tree all_vs ctxt_eqs (Node ((position, switched_clauses), left_clauses)) =
   933       let
   933       let
   934         val (i, is) = argument_position_of mode position
   934         val (i, is) = argument_position_of mode position
   935         val inp_var = nth_pair is (nth in_ts' i)
   935         val inp_var = nth_pair is (nth in_ts' i)
   936         val x = singleton (Name.variant_list all_vs) "x"
   936         val x = singleton (Name.variant_list all_vs) "x"
   941             val argnames = Name.variant_list (x :: all_vs)
   941             val argnames = Name.variant_list (x :: all_vs)
   942               (map (fn i => "c" ^ string_of_int i) (1 upto length Ts))
   942               (map (fn i => "c" ^ string_of_int i) (1 upto length Ts))
   943             val args = map2 (curry Free) argnames Ts
   943             val args = map2 (curry Free) argnames Ts
   944             val pattern = list_comb (Const (c, T), args)
   944             val pattern = list_comb (Const (c, T), args)
   945             val ctxt_eqs' = (inp_var, pattern) :: ctxt_eqs
   945             val ctxt_eqs' = (inp_var, pattern) :: ctxt_eqs
   946             val compilation = the_default (mk_bot compfuns (HOLogic.mk_tupleT outTs))
   946             val compilation = the_default (mk_empty compfuns (HOLogic.mk_tupleT outTs))
   947               (compile_switch_tree (argnames @ x :: all_vs) ctxt_eqs' switched)
   947               (compile_switch_tree (argnames @ x :: all_vs) ctxt_eqs' switched)
   948         in
   948         in
   949           (pattern, compilation)
   949           (pattern, compilation)
   950         end
   950         end
   951         val switch = Datatype.make_case ctxt Datatype_Case.Quiet [] inp_var
   951         val switch = Datatype.make_case ctxt Datatype_Case.Quiet [] inp_var
   952           ((map compile_single_case switched_clauses) @
   952           ((map compile_single_case switched_clauses) @
   953             [(xt, mk_bot compfuns (HOLogic.mk_tupleT outTs))])
   953             [(xt, mk_empty compfuns (HOLogic.mk_tupleT outTs))])
   954       in
   954       in
   955         case compile_switch_tree all_vs ctxt_eqs left_clauses of
   955         case compile_switch_tree all_vs ctxt_eqs left_clauses of
   956           NONE => SOME switch
   956           NONE => SOME switch
   957         | SOME left_comp => SOME (mk_sup compfuns (switch, left_comp))
   957         | SOME left_comp => SOME (mk_plus compfuns (switch, left_comp))
   958       end
   958       end
   959   in
   959   in
   960     compile_switch_tree all_vs [] switch_tree
   960     compile_switch_tree all_vs [] switch_tree
   961   end
   961   end
   962 
   962 
   976          else I)
   976          else I)
   977     val additional_arguments = Comp_Mod.additional_arguments compilation_modifiers
   977     val additional_arguments = Comp_Mod.additional_arguments compilation_modifiers
   978       (all_vs @ param_vs)
   978       (all_vs @ param_vs)
   979     val compfuns = Comp_Mod.compfuns compilation_modifiers
   979     val compfuns = Comp_Mod.compfuns compilation_modifiers
   980     fun is_param_type (T as Type ("fun",[_ , T'])) =
   980     fun is_param_type (T as Type ("fun",[_ , T'])) =
   981       is_some (try (dest_predT compfuns) T) orelse is_param_type T'
   981       is_some (try (dest_monadT compfuns) T) orelse is_param_type T'
   982       | is_param_type T = is_some (try (dest_predT compfuns) T)
   982       | is_param_type T = is_some (try (dest_monadT compfuns) T)
   983     val (inpTs, outTs) = split_map_modeT (fn m => fn T => (SOME (funT_of compfuns m T), NONE)) mode
   983     val (inpTs, outTs) = split_map_modeT (fn m => fn T => (SOME (funT_of compfuns m T), NONE)) mode
   984       (binder_types T)
   984       (binder_types T)
   985     val predT = mk_predT compfuns (HOLogic.mk_tupleT outTs)
   985     val predT = mk_monadT compfuns (HOLogic.mk_tupleT outTs)
   986     val funT = Comp_Mod.funT_of compilation_modifiers mode T
   986     val funT = Comp_Mod.funT_of compilation_modifiers mode T
   987     val (in_ts, _) = fold_map (fold_map_aterms_prodT (curry HOLogic.mk_prod)
   987     val (in_ts, _) = fold_map (fold_map_aterms_prodT (curry HOLogic.mk_prod)
   988       (fn T => fn (param_vs, names) =>
   988       (fn T => fn (param_vs, names) =>
   989         if is_param_type T then
   989         if is_param_type T then
   990           (Free (hd param_vs, T), (tl param_vs, names))
   990           (Free (hd param_vs, T), (tl param_vs, names))
   996     val in_ts' = map_filter (map_filter_prod
   996     val in_ts' = map_filter (map_filter_prod
   997       (fn t as Free (x, _) => if member (op =) param_vs x then NONE else SOME t | t => SOME t)) in_ts
   997       (fn t as Free (x, _) => if member (op =) param_vs x then NONE else SOME t | t => SOME t)) in_ts
   998     val param_modes = param_vs ~~ ho_arg_modes_of mode
   998     val param_modes = param_vs ~~ ho_arg_modes_of mode
   999     val compilation =
   999     val compilation =
  1000       if detect_switches options then
  1000       if detect_switches options then
  1001         the_default (mk_bot compfuns (HOLogic.mk_tupleT outTs))
  1001         the_default (mk_empty compfuns (HOLogic.mk_tupleT outTs))
  1002           (compile_switch compilation_modifiers ctxt all_vs param_modes additional_arguments mode
  1002           (compile_switch compilation_modifiers ctxt all_vs param_modes additional_arguments mode
  1003             in_ts' outTs (mk_switch_tree ctxt mode moded_cls))
  1003             in_ts' outTs (mk_switch_tree ctxt mode moded_cls))
  1004       else
  1004       else
  1005         let
  1005         let
  1006           val cl_ts =
  1006           val cl_ts =
  1008               compile_clause compilation_modifiers ctxt all_vs param_modes additional_arguments
  1008               compile_clause compilation_modifiers ctxt all_vs param_modes additional_arguments
  1009                 (HOLogic.mk_tuple in_ts') (split_mode mode ts) moded_prems) moded_cls;
  1009                 (HOLogic.mk_tuple in_ts') (split_mode mode ts) moded_prems) moded_cls;
  1010         in
  1010         in
  1011           Comp_Mod.wrap_compilation compilation_modifiers compfuns s T mode additional_arguments
  1011           Comp_Mod.wrap_compilation compilation_modifiers compfuns s T mode additional_arguments
  1012             (if null cl_ts then
  1012             (if null cl_ts then
  1013               mk_bot compfuns (HOLogic.mk_tupleT outTs)
  1013               mk_empty compfuns (HOLogic.mk_tupleT outTs)
  1014             else
  1014             else
  1015               foldr1 (mk_sup compfuns) cl_ts)
  1015               foldr1 (mk_plus compfuns) cl_ts)
  1016         end
  1016         end
  1017     val fun_const =
  1017     val fun_const =
  1018       Const (function_name_of (Comp_Mod.compilation compilation_modifiers)
  1018       Const (function_name_of (Comp_Mod.compilation compilation_modifiers)
  1019       ctxt s mode, funT)
  1019       ctxt s mode, funT)
  1020   in
  1020   in
  1339             val Ts = binder_types T
  1339             val Ts = binder_types T
  1340             val arg_names = Name.variant_list []
  1340             val arg_names = Name.variant_list []
  1341               (map (fn i => "x" ^ string_of_int i) (1 upto length Ts))
  1341               (map (fn i => "x" ^ string_of_int i) (1 upto length Ts))
  1342             val args = map2 (curry Free) arg_names Ts
  1342             val args = map2 (curry Free) arg_names Ts
  1343             val predfun = Const (function_name_of Pred ctxt predname full_mode,
  1343             val predfun = Const (function_name_of Pred ctxt predname full_mode,
  1344               Ts ---> Predicate_Comp_Funs.mk_predT @{typ unit})
  1344               Ts ---> Predicate_Comp_Funs.mk_monadT @{typ unit})
  1345             val rhs = @{term Predicate.holds} $ (list_comb (predfun, args))
  1345             val rhs = @{term Predicate.holds} $ (list_comb (predfun, args))
  1346             val eq_term = HOLogic.mk_Trueprop
  1346             val eq_term = HOLogic.mk_Trueprop
  1347               (HOLogic.mk_eq (list_comb (Const (predname, T), args), rhs))
  1347               (HOLogic.mk_eq (list_comb (Const (predname, T), args), rhs))
  1348             val def = predfun_definition_of ctxt predname full_mode
  1348             val def = predfun_definition_of ctxt predname full_mode
  1349             val tac = fn _ => Simplifier.simp_tac
  1349             val tac = fn _ => Simplifier.simp_tac
  1831           | d :: _ :: _ => (warning ("Multiple modes possible for comprehension "
  1831           | d :: _ :: _ => (warning ("Multiple modes possible for comprehension "
  1832                     ^ Syntax.string_of_term ctxt t_compr); d);
  1832                     ^ Syntax.string_of_term ctxt t_compr); d);
  1833         val (_, outargs) = split_mode (head_mode_of deriv) all_args
  1833         val (_, outargs) = split_mode (head_mode_of deriv) all_args
  1834         val t_pred = compile_expr comp_modifiers ctxt
  1834         val t_pred = compile_expr comp_modifiers ctxt
  1835           (body, deriv) [] additional_arguments;
  1835           (body, deriv) [] additional_arguments;
  1836         val T_pred = dest_predT compfuns (fastype_of t_pred)
  1836         val T_pred = dest_monadT compfuns (fastype_of t_pred)
  1837         val arrange = HOLogic.tupled_lambda (HOLogic.mk_tuple outargs) output
  1837         val arrange = HOLogic.tupled_lambda (HOLogic.mk_tuple outargs) output
  1838       in
  1838       in
  1839         if null outargs then t_pred else mk_map compfuns T_pred T_compr arrange t_pred
  1839         if null outargs then t_pred else mk_map compfuns T_pred T_compr arrange t_pred
  1840       end
  1840       end
  1841     else
  1841     else
  1874       | DSeq => []
  1874       | DSeq => []
  1875       | Pos_Random_DSeq => []
  1875       | Pos_Random_DSeq => []
  1876       | New_Pos_Random_DSeq => []
  1876       | New_Pos_Random_DSeq => []
  1877       | Pos_Generator_DSeq => []
  1877       | Pos_Generator_DSeq => []
  1878     val t = analyze_compr ctxt (comp_modifiers, additional_arguments) param_user_modes options t_compr;
  1878     val t = analyze_compr ctxt (comp_modifiers, additional_arguments) param_user_modes options t_compr;
  1879     val T = dest_predT compfuns (fastype_of t);
  1879     val T = dest_monadT compfuns (fastype_of t);
  1880     val t' =
  1880     val t' =
  1881       if stats andalso compilation = New_Pos_Random_DSeq then
  1881       if stats andalso compilation = New_Pos_Random_DSeq then
  1882         mk_map compfuns T (HOLogic.mk_prodT (HOLogic.termT, @{typ code_numeral}))
  1882         mk_map compfuns T (HOLogic.mk_prodT (HOLogic.termT, @{typ code_numeral}))
  1883           (absdummy T (HOLogic.mk_prod (HOLogic.term_of_const T $ Bound 0,
  1883           (absdummy T (HOLogic.mk_prod (HOLogic.term_of_const T $ Bound 0,
  1884             @{term Code_Numeral.of_nat} $ (HOLogic.size_const T $ Bound 0)))) t
  1884             @{term Code_Numeral.of_nat} $ (HOLogic.size_const T $ Bound 0)))) t