src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML
changeset 33135 422cac7d6e31
parent 33134 88c9c3460fe7
child 33137 0d16c07f8d24
equal deleted inserted replaced
33134:88c9c3460fe7 33135:422cac7d6e31
    35   val set_nparams : string -> int -> theory -> theory
    35   val set_nparams : string -> int -> theory -> theory
    36   val print_stored_rules: theory -> unit
    36   val print_stored_rules: theory -> unit
    37   val print_all_modes: theory -> unit
    37   val print_all_modes: theory -> unit
    38   val do_proofs: bool Unsynchronized.ref
    38   val do_proofs: bool Unsynchronized.ref
    39   val mk_casesrule : Proof.context -> int -> thm list -> term
    39   val mk_casesrule : Proof.context -> int -> thm list -> term
    40   val analyze_compr: theory -> term -> term
    40   val analyze_compr: theory -> int option -> term -> term
    41   val eval_ref: (unit -> term Predicate.pred) option Unsynchronized.ref
    41   val eval_ref: (unit -> term Predicate.pred) option Unsynchronized.ref
    42   val add_equations : Predicate_Compile_Aux.options -> string list -> theory -> theory
    42   val add_equations : Predicate_Compile_Aux.options -> string list -> theory -> theory
    43   val code_pred_intros_attrib : attribute
    43   val code_pred_intros_attrib : attribute
    44   (* used by Quickcheck_Generator *) 
    44   (* used by Quickcheck_Generator *) 
    45   (*val funT_of : mode -> typ -> typ
    45   (*val funT_of : mode -> typ -> typ
  1276       (Name.variant_list names (replicate (length Ts) "x") ~~ Ts)
  1276       (Name.variant_list names (replicate (length Ts) "x") ~~ Ts)
  1277   in
  1277   in
  1278     fold_rev lambda vs (f (list_comb (t, vs)))
  1278     fold_rev lambda vs (f (list_comb (t, vs)))
  1279   end;
  1279   end;
  1280 
  1280 
  1281 fun compile_param depth thy compfuns (NONE, t) = t
  1281 fun compile_param depth_limited thy compfuns (NONE, t) = t
  1282   | compile_param depth thy compfuns (m as SOME (Mode ((iss, is'), is, ms)), t) =
  1282   | compile_param depth_limited thy compfuns (m as SOME (Mode ((iss, is'), is, ms)), t) =
  1283    let
  1283    let
  1284      val (f, args) = strip_comb (Envir.eta_contract t)
  1284      val (f, args) = strip_comb (Envir.eta_contract t)
  1285      val (params, args') = chop (length ms) args
  1285      val (params, args') = chop (length ms) args
  1286      val params' = map (compile_param depth thy compfuns) (ms ~~ params)
  1286      val params' = map (compile_param depth_limited thy compfuns) (ms ~~ params)
  1287      val mk_fun_of = case depth of NONE => mk_fun_of | SOME _ => mk_depth_limited_fun_of
  1287      val mk_fun_of = if depth_limited then mk_depth_limited_fun_of else mk_fun_of
  1288      val funT_of = case depth of NONE => funT_of | SOME _ => depth_limited_funT_of
  1288      val funT_of = if depth_limited then depth_limited_funT_of else funT_of
  1289      val f' =
  1289      val f' =
  1290        case f of
  1290        case f of
  1291          Const (name, T) => mk_fun_of compfuns thy (name, T) (iss, is')
  1291          Const (name, T) => mk_fun_of compfuns thy (name, T) (iss, is')
  1292        | Free (name, T) => Free (name, funT_of compfuns (iss, is') T)
  1292        | Free (name, T) => Free (name, funT_of compfuns (iss, is') T)
  1293        | _ => error ("PredicateCompiler: illegal parameter term")
  1293        | _ => error ("PredicateCompiler: illegal parameter term")
  1294    in
  1294    in
  1295      list_comb (f', params' @ args')
  1295      list_comb (f', params' @ args')
  1296    end
  1296    end
  1297 
  1297 
  1298 fun compile_expr depth thy ((Mode (mode, is, ms)), t) =
  1298 fun compile_expr depth_limited thy ((Mode (mode, is, ms)), t) =
  1299   case strip_comb t of
  1299   case strip_comb t of
  1300     (Const (name, T), params) =>
  1300     (Const (name, T), params) =>
  1301        let
  1301        let
  1302          val params' = map (compile_param depth thy PredicateCompFuns.compfuns) (ms ~~ params)
  1302          val params' = map (compile_param depth_limited thy PredicateCompFuns.compfuns) (ms ~~ params)
  1303          val mk_fun_of = case depth of NONE => mk_fun_of | SOME _ => mk_depth_limited_fun_of
  1303          val mk_fun_of = if depth_limited then mk_depth_limited_fun_of else mk_fun_of
  1304        in
  1304        in
  1305          list_comb (mk_fun_of PredicateCompFuns.compfuns thy (name, T) mode, params')
  1305          list_comb (mk_fun_of PredicateCompFuns.compfuns thy (name, T) mode, params')
  1306        end
  1306        end
  1307   | (Free (name, T), args) =>
  1307   | (Free (name, T), args) =>
  1308        let 
  1308        let 
  1309          val funT_of = case depth of NONE => funT_of | SOME _ => depth_limited_funT_of 
  1309          val funT_of = if depth_limited then depth_limited_funT_of else funT_of
  1310        in
  1310        in
  1311          list_comb (Free (name, funT_of PredicateCompFuns.compfuns ([], is) T), args)
  1311          list_comb (Free (name, funT_of PredicateCompFuns.compfuns ([], is) T), args)
  1312        end;
  1312        end;
  1313        
  1313        
  1314 fun compile_gen_expr depth thy compfuns ((Mode (mode, is, ms)), t) inargs =
  1314 fun compile_gen_expr depth thy compfuns ((Mode (mode, is, ms)), t) inargs =
  1429                    val in_ts = map (compile_arg depth thy param_vs iss) in_ts
  1429                    val in_ts = map (compile_arg depth thy param_vs iss) in_ts
  1430                    val args = case depth of
  1430                    val args = case depth of
  1431                      NONE => in_ts
  1431                      NONE => in_ts
  1432                    | SOME (polarity, depth_t) => in_ts @ [polarity, depth_t]
  1432                    | SOME (polarity, depth_t) => in_ts @ [polarity, depth_t]
  1433                    val u = lift_pred compfuns
  1433                    val u = lift_pred compfuns
  1434                      (list_comb (compile_expr depth thy (mode, t), args))
  1434                      (list_comb (compile_expr (is_some depth) thy (mode, t), args))
  1435                    val rest = compile_prems out_ts''' vs' names'' ps
  1435                    val rest = compile_prems out_ts''' vs' names'' ps
  1436                  in
  1436                  in
  1437                    (u, rest)
  1437                    (u, rest)
  1438                  end
  1438                  end
  1439              | Negprem (us, t) =>
  1439              | Negprem (us, t) =>
  1440                  let
  1440                  let
  1441                    val (in_ts, out_ts''') = split_smode is us
  1441                    val (in_ts, out_ts''') = split_smode is us
  1442                    val depth' = Option.map (apfst HOLogic.mk_not) depth
  1442                    val args = case depth of
  1443                    val args = case depth' of
       
  1444                      NONE => in_ts
  1443                      NONE => in_ts
  1445                    | SOME (polarity, depth_t) => in_ts @ [polarity, depth_t]
  1444                    | SOME (polarity, depth_t) => in_ts @ [HOLogic.mk_not polarity, depth_t]
  1446                    val u = lift_pred compfuns (mk_not PredicateCompFuns.compfuns
  1445                    val u = lift_pred compfuns (mk_not PredicateCompFuns.compfuns
  1447                      (list_comb (compile_expr depth' thy (mode, t), args)))
  1446                    (list_comb (compile_expr (is_some depth) thy (mode, t), args)))
  1448                    val rest = compile_prems out_ts''' vs' names'' ps
  1447                    val rest = compile_prems out_ts''' vs' names'' ps
  1449                  in
  1448                  in
  1450                    (u, rest)
  1449                    (u, rest)
  1451                  end
  1450                  end
  1452              | Sidecond t =>
  1451              | Sidecond t =>
  1459                  let
  1458                  let
  1460                    val (in_ts, out_ts''') = split_smode is us;
  1459                    val (in_ts, out_ts''') = split_smode is us;
  1461                    val args = case depth of
  1460                    val args = case depth of
  1462                      NONE => in_ts
  1461                      NONE => in_ts
  1463                      | SOME (polarity, depth_t) => in_ts @ [polarity, depth_t]
  1462                      | SOME (polarity, depth_t) => in_ts @ [polarity, depth_t]
  1464                    val u = compile_gen_expr depth thy compfuns (mode, t) args
  1463                    val u = compile_gen_expr (is_some depth) thy compfuns (mode, t) args
  1465                    val rest = compile_prems out_ts''' vs' names'' ps
  1464                    val rest = compile_prems out_ts''' vs' names'' ps
  1466                  in
  1465                  in
  1467                    (u, rest)
  1466                    (u, rest)
  1468                  end
  1467                  end
  1469              | Generator (v, T) =>
  1468              | Generator (v, T) =>
  2436 (* transformation for code generation *)
  2435 (* transformation for code generation *)
  2437 
  2436 
  2438 val eval_ref = Unsynchronized.ref (NONE : (unit -> term Predicate.pred) option);
  2437 val eval_ref = Unsynchronized.ref (NONE : (unit -> term Predicate.pred) option);
  2439 
  2438 
  2440 (*FIXME turn this into an LCF-guarded preprocessor for comprehensions*)
  2439 (*FIXME turn this into an LCF-guarded preprocessor for comprehensions*)
  2441 fun analyze_compr thy t_compr =
  2440 fun analyze_compr thy depth_limit t_compr =
  2442   let
  2441   let
  2443     val split = case t_compr of (Const (@{const_name Collect}, _) $ t) => t
  2442     val split = case t_compr of (Const (@{const_name Collect}, _) $ t) => t
  2444       | _ => error ("Not a set comprehension: " ^ Syntax.string_of_term_global thy t_compr);
  2443       | _ => error ("Not a set comprehension: " ^ Syntax.string_of_term_global thy t_compr);
  2445     val (body, Ts, fp) = HOLogic.strip_psplits split;
  2444     val (body, Ts, fp) = HOLogic.strip_psplits split;
  2446     val (pred as Const (name, T), all_args) = strip_comb body;
  2445     val (pred as Const (name, T), all_args) = strip_comb body;
  2456                 ^ Syntax.string_of_term_global thy t_compr)
  2455                 ^ Syntax.string_of_term_global thy t_compr)
  2457       | [m] => m
  2456       | [m] => m
  2458       | m :: _ :: _ => (warning ("Multiple modes possible for comprehension "
  2457       | m :: _ :: _ => (warning ("Multiple modes possible for comprehension "
  2459                 ^ Syntax.string_of_term_global thy t_compr); m);
  2458                 ^ Syntax.string_of_term_global thy t_compr); m);
  2460     val (inargs, outargs) = split_smode user_mode' args;
  2459     val (inargs, outargs) = split_smode user_mode' args;
  2461     val t_pred = list_comb (compile_expr NONE thy (m, list_comb (pred, params)), inargs);
  2460     val inargs' = case depth_limit of NONE => inargs
       
  2461       | SOME d => inargs @ [@{term "True"}, HOLogic.mk_number @{typ "code_numeral"} d]
       
  2462     val t_pred = list_comb (compile_expr (is_some depth_limit) thy
       
  2463       (m, list_comb (pred, params)), inargs');
  2462     val t_eval = if null outargs then t_pred else
  2464     val t_eval = if null outargs then t_pred else
  2463       let
  2465       let
  2464         val outargs_bounds = map (fn Bound i => i) outargs;
  2466         val outargs_bounds = map (fn Bound i => i) outargs;
  2465         val outargsTs = map (nth Ts) outargs_bounds;
  2467         val outargsTs = map (nth Ts) outargs_bounds;
  2466         val T_pred = HOLogic.mk_tupleT outargsTs;
  2468         val T_pred = HOLogic.mk_tupleT outargsTs;
  2472           (Term.list_abs (map (pair "") outargsTs,
  2474           (Term.list_abs (map (pair "") outargsTs,
  2473             HOLogic.mk_ptuple fp T_compr (map Bound arrange_bounds)))
  2475             HOLogic.mk_ptuple fp T_compr (map Bound arrange_bounds)))
  2474       in mk_map PredicateCompFuns.compfuns T_pred T_compr arrange t_pred end
  2476       in mk_map PredicateCompFuns.compfuns T_pred T_compr arrange t_pred end
  2475   in t_eval end;
  2477   in t_eval end;
  2476 
  2478 
  2477 fun eval thy t_compr =
  2479 fun eval thy depth_limit t_compr =
  2478   let
  2480   let
  2479     val t = analyze_compr thy t_compr;
  2481     val t = analyze_compr thy depth_limit t_compr;
  2480     val T = dest_predT PredicateCompFuns.compfuns (fastype_of t);
  2482     val T = dest_predT PredicateCompFuns.compfuns (fastype_of t);
  2481     val t' = mk_map PredicateCompFuns.compfuns T HOLogic.termT (HOLogic.term_of_const T) t;
  2483     val t' = mk_map PredicateCompFuns.compfuns T HOLogic.termT (HOLogic.term_of_const T) t;
  2482   in (T, Code_ML.eval NONE ("Predicate_Compile_Core.eval_ref", eval_ref) Predicate.map thy t' []) end;
  2484   in (T, Code_ML.eval NONE ("Predicate_Compile_Core.eval_ref", eval_ref) Predicate.map thy t' []) end;
  2483 
  2485 
  2484 fun values ctxt k t_compr =
  2486 fun values ctxt depth_limit k t_compr =
  2485   let
  2487   let
  2486     val thy = ProofContext.theory_of ctxt;
  2488     val thy = ProofContext.theory_of ctxt;
  2487     val (T, t) = eval thy t_compr;
  2489     val (T, t) = eval thy depth_limit t_compr;
  2488     val setT = HOLogic.mk_setT T;
  2490     val setT = HOLogic.mk_setT T;
  2489     val (ts, _) = Predicate.yieldn k t;
  2491     val (ts, _) = Predicate.yieldn k t;
  2490     val elemsT = HOLogic.mk_set T ts;
  2492     val elemsT = HOLogic.mk_set T ts;
  2491   in if k = ~1 orelse length ts < k then elemsT
  2493   in if k = ~1 orelse length ts < k then elemsT
  2492     else Const (@{const_name Set.union}, setT --> setT --> setT) $ elemsT $ t_compr
  2494     else Const (@{const_name Set.union}, setT --> setT --> setT) $ elemsT $ t_compr
  2497     val thy = ProofContext.theory_of ctxt
  2499     val thy = ProofContext.theory_of ctxt
  2498     val _ = 
  2500     val _ = 
  2499   in
  2501   in
  2500   end;
  2502   end;
  2501   *)
  2503   *)
  2502 fun values_cmd modes k raw_t state =
  2504 fun values_cmd modes depth_limit k raw_t state =
  2503   let
  2505   let
  2504     val ctxt = Toplevel.context_of state;
  2506     val ctxt = Toplevel.context_of state;
  2505     val t = Syntax.read_term ctxt raw_t;
  2507     val t = Syntax.read_term ctxt raw_t;
  2506     val t' = values ctxt k t;
  2508     val t' = values ctxt depth_limit k t;
  2507     val ty' = Term.type_of t';
  2509     val ty' = Term.type_of t';
  2508     val ctxt' = Variable.auto_fixes t' ctxt;
  2510     val ctxt' = Variable.auto_fixes t' ctxt;
  2509     val p = PrintMode.with_modes modes (fn () =>
  2511     val p = PrintMode.with_modes modes (fn () =>
  2510       Pretty.block [Pretty.quote (Syntax.pretty_term ctxt' t'), Pretty.fbrk,
  2512       Pretty.block [Pretty.quote (Syntax.pretty_term ctxt' t'), Pretty.fbrk,
  2511         Pretty.str "::", Pretty.brk 1, Pretty.quote (Syntax.pretty_typ ctxt' ty')]) ();
  2513         Pretty.str "::", Pretty.brk 1, Pretty.quote (Syntax.pretty_typ ctxt' ty')]) ();
  2513 
  2515 
  2514 local structure P = OuterParse in
  2516 local structure P = OuterParse in
  2515 
  2517 
  2516 val opt_modes = Scan.optional (P.$$$ "(" |-- P.!!! (Scan.repeat1 P.xname --| P.$$$ ")")) [];
  2518 val opt_modes = Scan.optional (P.$$$ "(" |-- P.!!! (Scan.repeat1 P.xname --| P.$$$ ")")) [];
  2517 
  2519 
       
  2520 val _ = List.app OuterKeyword.keyword ["depth_limit"]
       
  2521 
       
  2522 val opt_depth_limit =
       
  2523   Scan.optional (P.$$$ "[" |-- P.$$$ "depth_limit" |-- P.$$$ "=" |-- P.nat --| P.$$$ "]" >> SOME) NONE
       
  2524 
  2518 val _ = OuterSyntax.improper_command "values" "enumerate and print comprehensions" OuterKeyword.diag
  2525 val _ = OuterSyntax.improper_command "values" "enumerate and print comprehensions" OuterKeyword.diag
  2519   (opt_modes -- Scan.optional P.nat ~1 -- P.term
  2526   (opt_modes -- opt_depth_limit -- Scan.optional P.nat ~1 -- P.term
  2520     >> (fn ((modes, k), t) => Toplevel.no_timing o Toplevel.keep
  2527     >> (fn (((modes, depth_limit), k), t) => Toplevel.no_timing o Toplevel.keep
  2521         (values_cmd modes k t)));
  2528         (values_cmd modes depth_limit k t)));
  2522 
  2529 
  2523 end;
  2530 end;
  2524 
  2531 
  2525 end;
  2532 end;