src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML
changeset 33132 07efd452a698
parent 33131 cef39362ce56
child 33133 2eb7dfcf3bc3
equal deleted inserted replaced
33131:cef39362ce56 33132:07efd452a698
     5 *)
     5 *)
     6 
     6 
     7 signature PREDICATE_COMPILE_CORE =
     7 signature PREDICATE_COMPILE_CORE =
     8 sig
     8 sig
     9   val setup: theory -> theory
     9   val setup: theory -> theory
    10   val code_pred: Predicate_Compile_Aux.options -> int list list option -> string -> Proof.context -> Proof.state
    10   val code_pred: Predicate_Compile_Aux.options -> string -> Proof.context -> Proof.state
    11   val code_pred_cmd: Predicate_Compile_Aux.options -> int list list option -> string -> Proof.context -> Proof.state
    11   val code_pred_cmd: Predicate_Compile_Aux.options -> string -> Proof.context -> Proof.state
    12   type smode = (int * int list option) list
    12   type smode = (int * int list option) list
    13   type mode = smode option list * smode
    13   type mode = smode option list * smode
    14   datatype tmode = Mode of mode * smode * tmode option list;
    14   datatype tmode = Mode of mode * smode * tmode option list;
    15   (*val add_equations_of: bool -> string list -> theory -> theory *)
    15   (*val add_equations_of: bool -> string list -> theory -> theory *)
    16   val register_predicate : (thm list * thm * int) -> theory -> theory
    16   val register_predicate : (thm list * thm * int) -> theory -> theory
    37   val print_all_modes: theory -> unit
    37   val print_all_modes: theory -> unit
    38   val do_proofs: bool Unsynchronized.ref
    38   val do_proofs: bool Unsynchronized.ref
    39   val mk_casesrule : Proof.context -> int -> thm list -> term
    39   val mk_casesrule : Proof.context -> int -> thm list -> term
    40   val analyze_compr: theory -> term -> term
    40   val analyze_compr: theory -> term -> term
    41   val eval_ref: (unit -> term Predicate.pred) option Unsynchronized.ref
    41   val eval_ref: (unit -> term Predicate.pred) option Unsynchronized.ref
    42   val add_equations : Predicate_Compile_Aux.options -> int list list option -> string list -> theory -> theory
    42   val add_equations : Predicate_Compile_Aux.options -> string list -> theory -> theory
    43   val code_pred_intros_attrib : attribute
    43   val code_pred_intros_attrib : attribute
    44   (* used by Quickcheck_Generator *) 
    44   (* used by Quickcheck_Generator *) 
    45   (*val funT_of : mode -> typ -> typ
    45   (*val funT_of : mode -> typ -> typ
    46   val mk_if_pred : term -> term
    46   val mk_if_pred : term -> term
    47   val mk_Eval : term * term -> term*)
    47   val mk_Eval : term * term -> term*)
    88   (*val rpred_prove_preds : theory -> term pred_mode_table -> thm pred_mode_table*)
    88   (*val rpred_prove_preds : theory -> term pred_mode_table -> thm pred_mode_table*)
    89   val pred_compfuns : compilation_funs
    89   val pred_compfuns : compilation_funs
    90   val rpred_compfuns : compilation_funs
    90   val rpred_compfuns : compilation_funs
    91   val dest_funT : typ -> typ * typ
    91   val dest_funT : typ -> typ * typ
    92  (* val depending_preds_of : theory -> thm list -> string list *)
    92  (* val depending_preds_of : theory -> thm list -> string list *)
    93   val add_quickcheck_equations : Predicate_Compile_Aux.options -> int list list option -> string list -> theory -> theory
    93   val add_quickcheck_equations : Predicate_Compile_Aux.options -> string list -> theory -> theory
    94   val add_sizelim_equations : Predicate_Compile_Aux.options -> int list list option -> string list -> theory -> theory
    94   val add_sizelim_equations : Predicate_Compile_Aux.options -> string list -> theory -> theory
    95   val is_inductive_predicate : theory -> string -> bool
    95   val is_inductive_predicate : theory -> string -> bool
    96   val terms_vs : term list -> string list
    96   val terms_vs : term list -> string list
    97   val subsets : int -> int -> int list list
    97   val subsets : int -> int -> int list list
    98   val check_mode_clause : bool -> theory -> string list ->
    98   val check_mode_clause : bool -> theory -> string list ->
    99     (string * mode list) list -> (string * mode list) list -> mode -> (term list * indprem list)
    99     (string * mode list) list -> (string * mode list) list -> mode -> (term list * indprem list)
   396 
   396 
   397 (*val generator_modes_of = (map fst) o #generators oo the_pred_data*)
   397 (*val generator_modes_of = (map fst) o #generators oo the_pred_data*)
   398      
   398      
   399 (* diagnostic display functions *)
   399 (* diagnostic display functions *)
   400 
   400 
   401 fun print_modes modes = tracing ("Inferred modes:\n" ^
   401 fun print_modes modes =
   402   cat_lines (map (fn (s, ms) => s ^ ": " ^ commas (map
   402   tracing ("Inferred modes:\n" ^
   403     string_of_mode ms)) modes));
   403     cat_lines (map (fn (s, ms) => s ^ ": " ^ commas (map
       
   404       string_of_mode ms)) modes));
   404 
   405 
   405 fun print_pred_mode_table string_of_entry thy pred_mode_table =
   406 fun print_pred_mode_table string_of_entry thy pred_mode_table =
   406   let
   407   let
   407     fun print_mode pred (mode, entry) =  "mode : " ^ (string_of_mode mode)
   408     fun print_mode pred (mode, entry) =  "mode : " ^ (string_of_mode mode)
   408       ^ (string_of_entry pred mode entry)  
   409       ^ (string_of_entry pred mode entry)  
   479         val _ = writeln ("modes: " ^ (commas (map string_of_mode modes)))
   480         val _ = writeln ("modes: " ^ (commas (map string_of_mode modes)))
   480       in u end  
   481       in u end  
   481   in
   482   in
   482     fold print (all_modes_of thy) ()
   483     fold print (all_modes_of thy) ()
   483   end
   484   end
       
   485 
       
   486 (* validity checks *)
       
   487 
       
   488 fun check_expected_modes (options : Predicate_Compile_Aux.options) modes =
       
   489   case expected_modes options of
       
   490     SOME (s, ms) => (case AList.lookup (op =) modes s of
       
   491       SOME modes =>
       
   492         if not (eq_set (map (map (rpair NONE)) ms, map snd modes)) then
       
   493           error ("expected modes were not inferred:"
       
   494             ^ "infered modes for " ^ s ^ ": " ^ commas (map (string_of_smode o snd) modes))
       
   495         else ()
       
   496       | NONE => ())
       
   497   | NONE => ()
   484 
   498 
   485 (* importing introduction rules *)   
   499 (* importing introduction rules *)   
   486 
   500 
   487 fun unify_consts thy cs intr_ts =
   501 fun unify_consts thy cs intr_ts =
   488   (let
   502   (let
  1468                    val in_ts = map (compile_arg size thy param_vs iss) in_ts
  1482                    val in_ts = map (compile_arg size thy param_vs iss) in_ts
  1469                    val args = case size of
  1483                    val args = case size of
  1470                      NONE => in_ts
  1484                      NONE => in_ts
  1471                    | SOME size_t => in_ts @ [size_t]
  1485                    | SOME size_t => in_ts @ [size_t]
  1472                    val u = lift_pred compfuns
  1486                    val u = lift_pred compfuns
  1473                      (list_comb (compile_expr NONE size thy (mode, t), args))                     
  1487                      (list_comb (compile_expr NONE size thy (mode, t), args))
  1474                    val rest = compile_prems out_ts''' vs' names'' ps
  1488                    val rest = compile_prems out_ts''' vs' names'' ps
  1475                  in
  1489                  in
  1476                    (u, rest)
  1490                    (u, rest)
  1477                  end
  1491                  end
  1478              | Negprem (us, t) =>
  1492              | Negprem (us, t) =>
  2303 *)
  2317 *)
  2304 
  2318 
  2305 
  2319 
  2306 (** main function of predicate compiler **)
  2320 (** main function of predicate compiler **)
  2307 
  2321 
  2308 fun add_equations_of steps options expected_modes prednames thy =
  2322 fun add_equations_of steps options prednames thy =
  2309   let
  2323   let
  2310     val _ = print_step options ("Starting predicate compiler for predicates " ^ commas prednames ^ "...")
  2324     val _ = print_step options ("Starting predicate compiler for predicates " ^ commas prednames ^ "...")
  2311     val _ = Output.tracing (commas (map (Display.string_of_thm_global thy) (maps (intros_of thy) prednames)))
  2325     val _ = tracing (commas (map (Display.string_of_thm_global thy) (maps (intros_of thy) prednames)))
  2312       (*val _ = check_intros_elim_match thy prednames*)
  2326       (*val _ = check_intros_elim_match thy prednames*)
  2313       (*val _ = map (check_format_of_intro_rule thy) (maps (intros_of thy) prednames)*)
  2327       (*val _ = map (check_format_of_intro_rule thy) (maps (intros_of thy) prednames)*)
  2314     val (preds, nparams, all_vs, param_vs, extra_modes, clauses, all_modes) =
  2328     val (preds, nparams, all_vs, param_vs, extra_modes, clauses, all_modes) =
  2315       prepare_intrs thy prednames (maps (intros_of thy) prednames)
  2329       prepare_intrs thy prednames (maps (intros_of thy) prednames)
  2316     val _ = print_step options "Infering modes..."
  2330     val _ = print_step options "Infering modes..."
  2317     val moded_clauses = #infer_modes steps options thy extra_modes all_modes param_vs clauses 
  2331     val moded_clauses = #infer_modes steps options thy extra_modes all_modes param_vs clauses 
  2318     val modes : (string * ((int * int list option) list option list * (int * int list option) list) list) list = map (fn (p, mps) => (p, map fst mps)) moded_clauses
  2332     val modes = map (fn (p, mps) => (p, map fst mps)) moded_clauses
  2319     val all_smodes : (((int * int list option) list) list) list = map (map snd) (map snd modes)
  2333     val _ = check_expected_modes options modes
  2320     val _ = case expected_modes of
       
  2321       SOME ms => if not (forall (fn smodes => eq_set (map (map (rpair NONE)) ms, smodes)) all_smodes) then
       
  2322       error ("expected modes were not inferred - " ^ commas (map string_of_smode (flat all_smodes))) else ()
       
  2323       | NONE => ()
       
  2324     val _ = print_modes modes
  2334     val _ = print_modes modes
  2325     val _ = print_moded_clauses thy moded_clauses
  2335     val _ = print_moded_clauses thy moded_clauses
  2326     val _ = print_step options "Defining executable functions..."
  2336     val _ = print_step options "Defining executable functions..."
  2327     val thy' = fold (#create_definitions steps preds) modes thy
  2337     val thy' = fold (#create_definitions steps preds) modes thy
  2328       |> Theory.checkpoint
  2338       |> Theory.checkpoint
  2357     (fold (Graph.add_edge o (pair key)) (edges_of (key, v)) G'', visited')
  2367     (fold (Graph.add_edge o (pair key)) (edges_of (key, v)) G'', visited')
  2358   end;
  2368   end;
  2359 
  2369 
  2360 fun extend value_of edges_of key G = fst (extend' value_of edges_of key (G, [])) 
  2370 fun extend value_of edges_of key G = fst (extend' value_of edges_of key (G, [])) 
  2361   
  2371   
  2362 fun gen_add_equations steps options expected_modes names thy =
  2372 fun gen_add_equations steps options names thy =
  2363   let
  2373   let
  2364     val thy' = PredData.map (fold (extend (fetch_pred_data thy) (depending_preds_of thy)) names) thy
  2374     val thy' = PredData.map (fold (extend (fetch_pred_data thy) (depending_preds_of thy)) names) thy
  2365       |> Theory.checkpoint;
  2375       |> Theory.checkpoint;
  2366     fun strong_conn_of gr keys =
  2376     fun strong_conn_of gr keys =
  2367       Graph.strong_conn (Graph.subgraph (member (op =) (Graph.all_succs gr keys)) gr)
  2377       Graph.strong_conn (Graph.subgraph (member (op =) (Graph.all_succs gr keys)) gr)
  2368     val scc = strong_conn_of (PredData.get thy') names
  2378     val scc = strong_conn_of (PredData.get thy') names
  2369     val thy'' = fold_rev
  2379     val thy'' = fold_rev
  2370       (fn preds => fn thy =>
  2380       (fn preds => fn thy =>
  2371         if #are_not_defined steps thy preds then
  2381         if #are_not_defined steps thy preds then
  2372           add_equations_of steps options expected_modes preds thy else thy)
  2382           add_equations_of steps options preds thy else thy)
  2373       scc thy' |> Theory.checkpoint
  2383       scc thy' |> Theory.checkpoint
  2374   in thy'' end
  2384   in thy'' end
  2375 
  2385 
  2376 (* different instantiantions of the predicate compiler *)
  2386 (* different instantiantions of the predicate compiler *)
  2377 
  2387 
  2415 *)
  2425 *)
  2416 
  2426 
  2417 val setup = PredData.put (Graph.empty) #>
  2427 val setup = PredData.put (Graph.empty) #>
  2418   Attrib.setup @{binding code_pred_intros} (Scan.succeed (attrib add_intro))
  2428   Attrib.setup @{binding code_pred_intros} (Scan.succeed (attrib add_intro))
  2419     "adding alternative introduction rules for code generation of inductive predicates"
  2429     "adding alternative introduction rules for code generation of inductive predicates"
  2420 (*  Attrib.setup @{binding code_ind_cases} (Scan.succeed add_elim_attrib)
       
  2421     "adding alternative elimination rules for code generation of inductive predicates";
       
  2422     *)
       
  2423   (*FIXME name discrepancy in attribs and ML code*)
  2430   (*FIXME name discrepancy in attribs and ML code*)
  2424   (*FIXME intros should be better named intro*)
  2431   (*FIXME intros should be better named intro*)
  2425   (*FIXME why distinguished attribute for cases?*)
       
  2426 
  2432 
  2427 (* TODO: make TheoryDataFun to GenericDataFun & remove duplication of local theory and theory *)
  2433 (* TODO: make TheoryDataFun to GenericDataFun & remove duplication of local theory and theory *)
  2428 fun generic_code_pred prep_const options modes raw_const lthy =
  2434 fun generic_code_pred prep_const options raw_const lthy =
  2429   let
  2435   let
  2430     val thy = ProofContext.theory_of lthy
  2436     val thy = ProofContext.theory_of lthy
  2431     val const = prep_const thy raw_const
  2437     val const = prep_const thy raw_const
  2432     val lthy' = LocalTheory.theory (PredData.map
  2438     val lthy' = LocalTheory.theory (PredData.map
  2433         (extend (fetch_pred_data thy) (depending_preds_of thy) const)) lthy
  2439         (extend (fetch_pred_data thy) (depending_preds_of thy) const)) lthy
  2453         val global_thms = ProofContext.export goal_ctxt
  2459         val global_thms = ProofContext.export goal_ctxt
  2454           (ProofContext.init (ProofContext.theory_of goal_ctxt)) (map the_single thms)
  2460           (ProofContext.init (ProofContext.theory_of goal_ctxt)) (map the_single thms)
  2455       in
  2461       in
  2456         goal_ctxt |> LocalTheory.theory (fold set_elim global_thms #>
  2462         goal_ctxt |> LocalTheory.theory (fold set_elim global_thms #>
  2457           (if is_rpred options then
  2463           (if is_rpred options then
  2458             (add_equations options NONE [const] #>
  2464             (add_equations options [const] #>
  2459              add_sizelim_equations options NONE [const] #> add_quickcheck_equations options NONE [const])
  2465              add_sizelim_equations options [const] #> add_quickcheck_equations options [const])
  2460            else add_equations options modes [const]))
  2466            else if is_sizelim options then
       
  2467              add_sizelim_equations options [const]
       
  2468            else
       
  2469              add_equations options [const]))
  2461       end  
  2470       end  
  2462   in
  2471   in
  2463     Proof.theorem_i NONE after_qed (map (single o (rpair [])) cases_rules) lthy''
  2472     Proof.theorem_i NONE after_qed (map (single o (rpair [])) cases_rules) lthy''
  2464   end;
  2473   end;
  2465 
  2474