src/Pure/Tools/codegen_theorems.ML
changeset 20404 1a29e6c3ab04
parent 20394 21227c43ba26
child 20456 42be3a46dcd8
equal deleted inserted replaced
20403:14d5f6ed5602 20404:1a29e6c3ab04
    19   val del_unfold: thm -> theory -> theory;
    19   val del_unfold: thm -> theory -> theory;
    20   val purge_defs: string * typ -> theory -> theory;
    20   val purge_defs: string * typ -> theory -> theory;
    21   val notify_dirty: theory -> theory;
    21   val notify_dirty: theory -> theory;
    22 
    22 
    23   val extr_typ: theory -> thm -> typ;
    23   val extr_typ: theory -> thm -> typ;
       
    24   val rewrite_fun: thm list -> thm -> thm;
    24   val common_typ: theory -> (thm -> typ) -> thm list -> thm list;
    25   val common_typ: theory -> (thm -> typ) -> thm list -> thm list;
    25   val preprocess: theory -> thm list -> thm list;
    26   val preprocess: theory -> thm list -> thm list;
    26 
    27 
    27   val prove_freeness: theory -> tactic -> string
    28   val prove_freeness: theory -> tactic -> string
    28     -> (string * sort) list * (string * typ list) list -> thm list;
    29     -> (string * sort) list * (string * typ list) list -> thm list;
    49 (** preliminaries **)
    50 (** preliminaries **)
    50 
    51 
    51 (* diagnostics *)
    52 (* diagnostics *)
    52 
    53 
    53 val debug = ref false;
    54 val debug = ref false;
    54 fun debug_msg f x = (if !debug then Output.debug (f x) else (); x);
    55 fun debug_msg f x = (if !debug then Output.tracing (f x) else (); x);
    55 
    56 
    56 
    57 
    57 (* auxiliary *)
    58 (* auxiliary *)
    58 
    59 
    59 fun getf_first [] _ = NONE
    60 fun getf_first [] _ = NONE
   182 (* theorem purification *)
   183 (* theorem purification *)
   183 
   184 
   184 fun err_thm msg thm =
   185 fun err_thm msg thm =
   185   error (msg ^ ": " ^ string_of_thm thm);
   186   error (msg ^ ": " ^ string_of_thm thm);
   186 
   187 
       
   188 val mk_rule =
       
   189   #mk o #mk_rews o snd o MetaSimplifier.rep_ss o Simplifier.simpset_of;
       
   190 
   187 fun abs_norm thy thm =
   191 fun abs_norm thy thm =
   188   let
   192   let
   189     fun expvars t =
   193     fun expvars t =
   190       let
   194       let
   191         val lhs = (fst o Logic.dest_equals) t;
   195         val lhs = (fst o Logic.dest_equals) t;
   253     fun drop eqs [] = eqs
   257     fun drop eqs [] = eqs
   254       | drop eqs (eq::eqs') =
   258       | drop eqs (eq::eqs') =
   255           drop (eq::eqs) (filter_out (matches eq) eqs')
   259           drop (eq::eqs) (filter_out (matches eq) eqs')
   256   in drop [] eqs end;
   260   in drop [] eqs end;
   257 
   261 
       
   262 fun drop_refl thy = filter_out (is_equal o Term.fast_term_ord o Logic.dest_equals
       
   263   o ObjectLogic.drop_judgment thy o Drule.plain_prop_of);
       
   264 
   258 fun make_eq thy =
   265 fun make_eq thy =
   259   let
   266   let
   260     val ((_, atomize), _) = get_obj thy;
   267     val ((_, atomize), _) = get_obj thy;
   261   in rewrite_rule [atomize] end;
   268   in rewrite_rule [atomize] end;
   262 
   269 
   399       val unfolds = (fn Preproc { unfolds, ... } => unfolds) preproc;
   406       val unfolds = (fn Preproc { unfolds, ... } => unfolds) preproc;
   400     in
   407     in
   401       (Pretty.writeln o Pretty.block o Pretty.fbreaks) ([
   408       (Pretty.writeln o Pretty.block o Pretty.fbreaks) ([
   402         Pretty.str "code generation theorems:",
   409         Pretty.str "code generation theorems:",
   403         Pretty.str "function theorems:" ] @
   410         Pretty.str "function theorems:" ] @
   404         (*Pretty.fbreaks ( *)
       
   405           map (fn (c, thms) =>
   411           map (fn (c, thms) =>
   406             (Pretty.block o Pretty.fbreaks) (
   412             (Pretty.block o Pretty.fbreaks) (
   407               (Pretty.str o CodegenConsts.string_of_const thy) c  :: map pretty_thm (rev thms)
   413               (Pretty.str o CodegenConsts.string_of_const thy) c  :: map pretty_thm (rev thms)
   408             )
   414             )
   409           ) funs
   415           ) funs @ [
   410         (*) *) @ [
       
   411         Pretty.fbrk,
       
   412         Pretty.block (
   416         Pretty.block (
   413           Pretty.str "inlined theorems:"
   417           Pretty.str "inlined theorems:"
   414           :: Pretty.fbrk
   418           :: Pretty.fbrk
   415           :: (Pretty.fbreaks o map pretty_thm) unfolds
   419           :: (Pretty.fbreaks o map pretty_thm) unfolds
   416       )])
   420       )])
   541 (* preprocessing *)
   545 (* preprocessing *)
   542 
   546 
   543 fun extr_typ thy thm = case dest_fun thy thm
   547 fun extr_typ thy thm = case dest_fun thy thm
   544  of (_, (ty, _)) => ty;
   548  of (_, (ty, _)) => ty;
   545 
   549 
   546 fun rewrite_rhs conv thm = (case (Drule.strip_comb o cprop_of) thm
   550 fun rewrite_fun rewrites thm =
   547  of (ct', [ct1, ct2]) => (case term_of ct'
   551   let
   548      of Const ("==", _) =>
   552     val rewrite = Tactic.rewrite true rewrites;
   549           Thm.equal_elim (combination (combination (reflexive ct') (reflexive ct1))
   553     val (ct_eq, [ct_lhs, ct_rhs]) = (Drule.strip_comb o cprop_of) thm;
   550             (conv ct2)) thm
   554     val Const ("==", _) = term_of ct_eq;
   551       | _ => raise ERROR "rewrite_rhs")
   555     val (ct_f, ct_args) = Drule.strip_comb ct_lhs;
   552   | _ => raise ERROR "rewrite_rhs");
   556     val rhs' = rewrite ct_rhs;
       
   557     val args' = map rewrite ct_args;
       
   558     val lhs' = Thm.symmetric (fold (fn th1 => fn th2 => Thm.combination th2 th1)
       
   559       args' (Thm.reflexive ct_f));
       
   560   in
       
   561     Thm.transitive (Thm.transitive lhs' thm) rhs'
       
   562   end handle Bind => raise ERROR "rewrite_fun"
   553 
   563 
   554 fun common_typ thy _ [] = []
   564 fun common_typ thy _ [] = []
   555   | common_typ thy _ [thm] = [thm]
   565   | common_typ thy _ [thm] = [thm]
   556   | common_typ thy extract_typ thms =
   566   | common_typ thy extract_typ thms =
   557       let
   567       let
   583           |> Conjunction.elim_list;
   593           |> Conjunction.elim_list;
   584     fun cmp_thms (thm1, thm2) =
   594     fun cmp_thms (thm1, thm2) =
   585       not (Sign.typ_instance thy (extr_typ thy thm1, extr_typ thy thm2));
   595       not (Sign.typ_instance thy (extr_typ thy thm1, extr_typ thy thm2));
   586     fun unvarify thms =
   596     fun unvarify thms =
   587       #2 (#1 (Variable.import true thms (ProofContext.init thy)));
   597       #2 (#1 (Variable.import true thms (ProofContext.init thy)));
   588     val unfold_thms = Tactic.rewrite true (map (make_eq thy) (the_unfolds thy));
   598     val unfold_thms = map (make_eq thy) (the_unfolds thy);
   589   in
   599   in
   590     thms
   600     thms
   591     |> map (make_eq thy)
   601     |> map (make_eq thy)
   592     |> map (Thm.transfer thy)
   602     |> map (Thm.transfer thy)
   593     |> fold (fn f => f thy) (the_preprocs thy)
   603     |> fold (fn f => f thy) (the_preprocs thy)
   594     |> map (rewrite_rhs unfold_thms)
   604     |> map (rewrite_fun unfold_thms)
   595     |> debug_msg (fn _ => "[cg_thm] sorting")
   605     |> debug_msg (fn _ => "[cg_thm] sorting")
   596     |> debug_msg (commas o map string_of_thm)
   606     |> debug_msg (commas o map string_of_thm)
   597     |> sort (make_ord cmp_thms)
   607     |> sort (make_ord cmp_thms)
   598     |> debug_msg (fn _ => "[cg_thm] common_typ")
   608     |> debug_msg (fn _ => "[cg_thm] common_typ")
   599     |> debug_msg (commas o map string_of_thm)
   609     |> debug_msg (commas o map string_of_thm)
   600     |> common_typ thy (extr_typ thy)
   610     |> common_typ thy (extr_typ thy)
   601     |> debug_msg (fn _ => "[cg_thm] abs_norm")
   611     |> debug_msg (fn _ => "[cg_thm] abs_norm")
   602     |> debug_msg (commas o map string_of_thm)
   612     |> debug_msg (commas o map string_of_thm)
   603     |> map (abs_norm thy)
   613     |> map (abs_norm thy)
       
   614     |> drop_refl thy
   604     |> burrow_thms (
   615     |> burrow_thms (
   605         debug_msg (fn _ => "[cg_thm] canonical tvars")
   616         debug_msg (fn _ => "[cg_thm] canonical tvars")
   606         #> debug_msg (string_of_thm)
   617         #> debug_msg (string_of_thm)
   607         #> canonical_tvars thy
   618         #> canonical_tvars thy
   608         #> debug_msg (fn _ => "[cg_thm] canonical vars")
   619         #> debug_msg (fn _ => "[cg_thm] canonical vars")
   682           val (_, lhs) = mk_lhs vs args;
   693           val (_, lhs) = mk_lhs vs args;
   683         in (inj, mk_func thy (lhs, fals) :: dist) end;
   694         in (inj, mk_func thy (lhs, fals) :: dist) end;
   684     fun mk_eqs (vs, cos) =
   695     fun mk_eqs (vs, cos) =
   685       let val cos' = rev cos
   696       let val cos' = rev cos
   686       in (op @) (fold (mk_eq vs) (product cos' cos') ([], [])) end;
   697       in (op @) (fold (mk_eq vs) (product cos' cos') ([], [])) end;
   687   in
   698     val ts = (map (ObjectLogic.ensure_propT thy) o mk_eqs) vs_cos;
   688     map (fn t => Goal.prove_global thy [] []
   699     fun prove t = if !quick_and_dirty then SkipProof.make_thm thy (Logic.varify t)
   689         (ObjectLogic.ensure_propT thy t) (K tac)) (mk_eqs vs_cos)
   700       else Goal.prove_global thy [] [] t (K tac);
   690   end;
   701   in map prove ts end;
   691 
   702 
   692 fun get_datatypes thy dtco =
   703 fun get_datatypes thy dtco =
   693   let
   704   let
   694     val _ = debug_msg (fn _ => "[cg_thm] datatype " ^ dtco) ()
   705     val _ = debug_msg (fn _ => "[cg_thm] datatype " ^ dtco) ()
   695   in
   706   in