src/HOL/Tools/Predicate_Compile/predicate_compile_data.ML
author haftmann
Sat Aug 28 16:14:32 2010 +0200 (2010-08-28)
changeset 38864 4abe644fcea5
parent 38795 848be46708dc
child 40053 3fa49ea76cbb
permissions -rw-r--r--
formerly unnamed infix equality now named HOL.eq
     1 (*  Title:      HOL/Tools/Predicate_Compile/predicate_compile_data.ML
     2     Author:     Lukas Bulwahn, TU Muenchen
     3 
     4 Book-keeping data structure for the predicate compiler.
     5 *)
     6 
     7 signature PREDICATE_COMPILE_DATA =
     8 sig
         (* Configuration kept in the theory data: names of constants to ignore
            (consulted by obtain_specification_graph) and names of functions to keep. *)
     9   val ignore_consts : string list -> theory -> theory
    10   val keep_functions : string list -> theory -> theory
         (* Tests whether a name is present in the keep_functions table of the theory. *)
    11   val keep_function : theory -> string -> bool
         (* Lookup and storage of already processed specifications, indexed by a name. *)
    12   val processed_specs : theory -> string -> (string * thm list) list option
    13   val store_processed_specs : (string * (string * thm list) list) -> theory -> theory
    14   
         (* Specification theorems of a term, taken from the alternative definitions
            or, if none apply, from Spec_Rules. *)
    15   val get_specification : Predicate_Compile_Aux.options -> theory -> term -> thm list
         (* Graph whose nodes are terms labelled with their specifications and whose
            edges point to the constants occurring in those specifications. *)
    16   val obtain_specification_graph :
    17     Predicate_Compile_Aux.options -> theory -> term -> thm list Term_Graph.T
    18     
         (* Displays the strongly connected components of a specification graph. *)
    19   val present_graph : thm list Term_Graph.T -> unit
         (* Turns an equation into a meta-equation with all arguments supplied and
            tuple-typed free variables split, checks its format, and inlines definitions. *)
    20   val normalize_equation : theory -> thm -> thm
    21 end;
    22 
    23 structure Predicate_Compile_Data : PREDICATE_COMPILE_DATA =
    24 struct
    25 
    26 open Predicate_Compile_Aux;
    27 
    28 structure Data = Theory_Data
    29 (
         (* Per-theory state of the predicate compiler: two sets of names (tables with
            unit values) and a table of processed specifications indexed by name. *)
    30   type T =
    31     {ignore_consts : unit Symtab.table,
    32      keep_functions : unit Symtab.table,
    33      processed_specs : ((string * thm list) list) Symtab.table};
         (* Initially all three tables are empty. *)
    34   val empty =
    35     {ignore_consts = Symtab.empty,
    36      keep_functions = Symtab.empty,
    37      processed_specs =  Symtab.empty};
         (* Extending the data is the identity. *)
    38   val extend = I;
         (* Merges the two records table by table. The value-equality argument is K true,
            so values stored under the same key are always treated as equal; presumably
            Symtab.merge therefore never reports a conflict -- TODO confirm. *)
    39   fun merge
    40     ({ignore_consts = c1, keep_functions = k1, processed_specs = s1},
    41      {ignore_consts = c2, keep_functions = k2, processed_specs = s2}) =
    42      {ignore_consts = Symtab.merge (K true) (c1, c2),
    43       keep_functions = Symtab.merge (K true) (k1, k2),
    44       processed_specs = Symtab.merge (K true) (s1, s2)}
    45 );
    46 
    47 
    48 
    (* Builds the data record from the triple
       (ignore_consts, keep_functions, processed_specs). *)
    fun mk_data (ignored, kept, specs) =
      {ignore_consts = ignored, keep_functions = kept, processed_specs = specs}

    (* Views the record as a triple, applies f to it, and rebuilds the record. *)
    fun map_data f {ignore_consts, keep_functions, processed_specs} =
      mk_data (f (ignore_consts, keep_functions, processed_specs))
    51 
    (* Inserts every name in cs into a table of the theory data, selected via apfst3
       (presumably the first component of the triple, i.e. ignore_consts). *)
    fun ignore_consts cs =
      let
        fun add c = Symtab.insert (op =) (c, ())
      in
        Data.map (map_data (apfst3 (fold add cs)))
      end
    53 
    (* Inserts every name in cs into a table of the theory data, selected via apsnd3
       (presumably the second component of the triple, i.e. keep_functions). *)
    fun keep_functions cs =
      let
        fun remember names table =
          fold (fn c => Symtab.insert (op =) (c, ())) names table
      in
        Data.map (map_data (apsnd3 (remember cs)))
      end
    55 
    (* Membership test on the keep_functions table of the theory. *)
    fun keep_function thy =
      let
        val {keep_functions = kept, ...} = Data.get thy
      in
        Symtab.defined kept
      end
    57 
    (* Looks up the stored specifications for a name; NONE if nothing is stored. *)
    fun processed_specs thy =
      let
        val {processed_specs = table, ...} = Data.get thy
      in
        Symtab.lookup table
      end
    59 
    (* Stores specs under constname in the processed_specs table (reached via aptrd3).
       Uses Symtab.update_new, which presumably rejects a name that is already
       present -- TODO confirm. *)
    fun store_processed_specs (constname, specs) =
      let
        val add_entry = Symtab.update_new (constname, specs)
      in
        Data.map (map_data (aptrd3 add_entry))
      end
    62 (* *)
    63 
    64 
    (* Head of the conclusion of an introduction rule: strips the premises, drops the
       outermost operator of the conclusion and returns the head of its argument.
       Raises TERM (instead of an uninformative Bind) if the conclusion is not an
       application. *)
    fun defining_term_of_introrule_term t =
      (case Logic.strip_imp_concl t of
        _ $ u => fst (strip_comb u)
      | _ => raise TERM ("defining_term_of_introrule_term failed: Not an application", [t]))
    69 (*  
    70   in case pred of
    71     Const (c, T) => c
    72     | _ => raise TERM ("defining_const_of_introrule_term failed: Not a constant", [t])
    73   end
    74 *)
    (* Theorem-level variant: applies defining_term_of_introrule_term to the proposition. *)
    fun defining_term_of_introrule th = defining_term_of_introrule_term (prop_of th)
    76 
    (*TODO: no real check yet -- every term is accepted as introduction-rule-like*)
    fun is_introlike_term _ = true
    79 
    (* Theorem-level variant: applies is_introlike_term to the proposition. *)
    fun is_introlike th = is_introlike_term (prop_of th)
    81 
    (* Checks that t is a meta-equation whose left-hand side is a constant applied to
       as many arguments as its type has binders. Returns true on success and raises
       TERM otherwise. *)
    fun check_equation_format_term t =
      (case t of
        Const ("==", _) $ lhs $ _ =>
          (case strip_comb lhs of
            (Const (_, T), args) =>
              length (binder_types T) = length args orelse
                raise TERM ("check_equation_format_term failed: Number of arguments mismatch", [t])
          | _ => raise TERM ("check_equation_format_term failed: Not a constant", [t]))
      | _ => raise TERM ("check_equation_format_term failed: Not an equation", [t]))
    92 
    (* Theorem-level variant: applies check_equation_format_term to the proposition. *)
    fun check_equation_format th = check_equation_format_term (prop_of th)
    94 
    95 
    (* Head of the left-hand side of a meta-equation. Raises TERM if t is not a
       meta-equation; the message now carries this function's actual name. *)
    fun defining_term_of_equation_term t =
      (case t of
        Const ("==", _) $ lhs $ _ => fst (strip_comb lhs)
      | _ => raise TERM ("defining_term_of_equation_term failed: Not an equation", [t]))
    99 
   (* Theorem-level variant: applies defining_term_of_equation_term to the proposition. *)
   fun defining_term_of_equation th = defining_term_of_equation_term (prop_of th)
   101 
   (* Name of the constant at the head of the left-hand side of the equation th.
      Raises TERM if that head is not a constant. *)
   fun defining_const_of_equation th =
     let
       val head = defining_term_of_equation th
     in
       if is_Const head then fst (dest_Const head)
       else raise TERM ("defining_const_of_equation failed: Not a constant", [prop_of th])
     end
   106 
   107 
   108 
   109 
   110 (* Normalizing equations *)
   111 
   (* Converts a HOL equation (Trueprop applied to HOL.eq) into a meta-equation by
      resolving with eq_reflection; every other theorem is returned unchanged. *)
   fun mk_meta_equation th =
     let
       fun is_hol_equation
             (Const (@{const_name Trueprop}, _) $ (Const (@{const_name HOL.eq}, _) $ _ $ _)) = true
         | is_hol_equation _ = false
     in
       if is_hol_equation (prop_of th) then th RS @{thm eq_reflection} else th
     end
   116 
   (* Function congruence for meta-equality: from f == g conclude f x == g x.
      Used by full_fun_cong_expand to supply missing arguments one at a time. *)
   117 val meta_fun_cong = @{lemma "f == g ==> f x == g x" by simp}
   118 
   (* Applies meta_fun_cong to the meta-equation th once for every argument that the
      head of its left-hand side still lacks according to the head's type. *)
   fun full_fun_cong_expand th =
     let
       val (lhs, _) = Logic.dest_equals (prop_of th)
       val (head, args) = strip_comb lhs
       val missing = length (binder_types (fastype_of head)) - length args
       fun add_argument eq = eq RS meta_fun_cong
     in
       funpow missing add_argument th
     end;
   124 
   (* Generates names for xs from the base name s via Name.names, declares the
      generated names in the name context, and returns the named elements together
      with the extended context. *)
   fun declare_names s xs ctxt =
     let
       val named = Name.names ctxt s xs
       val ctxt' = fold (fn (name, _) => Name.declare name) named ctxt
     in
       (named, ctxt')
     end
   129   
   130 fun split_all_pairs thy th =
         (* Replaces every free variable of the (imported) theorem by a tuple of fresh
            free variables, one per component of its flattened tuple type, and then
            unfolds split_conv, fst_conv and snd_conv in the result. *)
   131   let
   132     val ctxt = ProofContext.init_global thy
           (* import th into the context; NOTE(review): ctxt' is bound but never used below *)
   133     val ((_, [th']), ctxt') = Variable.import true [th] ctxt
   134     val t = prop_of th'
   135     val frees = Term.add_frees t [] 
   136     val freenames = Term.add_free_names t []
           (* name context seeded with the names of all free variables of t *)
   137     val nctxt = Name.make_context freenames
           (* for the free variable (x, T): a rewrite pair from Free (x, T) to a tuple of
              frees whose names are generated from x, one per flattened component of T *)
   138     fun mk_tuple_rewrites (x, T) nctxt =
   139       let
   140         val Ts = HOLogic.flatten_tupleT T
   141         val (xTs, nctxt') = declare_names x Ts nctxt
   142         val paths = HOLogic.flat_tupleT_paths T
   143       in ((Free (x, T), HOLogic.mk_ptuple paths T (map Free xTs)), nctxt') end
   144     val (rewr, _) = fold_map mk_tuple_rewrites frees nctxt 
   145     val t' = Pattern.rewrite_term thy rewr [] t
           (* the rewritten proposition is established with Skip_Proof.cheat_tac;
              presumably this skips the actual proof -- TODO confirm *)
   146     val tac = Skip_Proof.cheat_tac thy
   147     val th'' = Goal.prove ctxt (Term.add_free_names t' []) [] t' (fn _ => tac)
   148     val th''' = Local_Defs.unfold ctxt [@{thm split_conv}, @{thm fst_conv}, @{thm snd_conv}] th''
   149   in
   150     th'''
   151   end;
   152 
   153 
   (* Fully simplifies th with HOL_basic_ss extended by the theorems obtained from
      Predicate_Compile_Inline_Defs for the global context of thy. *)
   fun inline_equations thy th =
     let
       val inline_defs = Predicate_Compile_Inline_Defs.get (ProofContext.init_global thy)
       val simpset = HOL_basic_ss addsimps inline_defs
     in
       Simplifier.full_simplify simpset th
     end
   165 
   (* Normalizes an equation: converts it to a meta-equation, supplies missing
      arguments, splits tuple-typed free variables, checks the resulting format
      (check_equation_format raises TERM on failure), and finally inlines definitions. *)
   fun normalize_equation thy th =
     let
       val expanded = full_fun_cong_expand (mk_meta_equation th)
       val flattened = split_all_pairs thy expanded
       val _ = check_equation_format flattened
     in
       inline_equations thy flattened
     end
   172 
   (* Normalizes an introduction rule: splits tuple-typed free variables, then
      inlines definitions. *)
   fun normalize_intros thy th = inline_equations thy (split_all_pairs thy th)
   176 
   (* Dispatches on is_equationlike: equations go through normalize_equation,
      everything else through normalize_intros. *)
   fun normalize thy th =
     (if is_equationlike th then normalize_equation else normalize_intros) thy th
   182 
   (* Returns the specification theorems for the term t.

      The theorems from Predicate_Compile_Alternative_Defs are tried first; if none
      of them specifies t, the theorems of the first entry returned by
      Spec_Rules.retrieve are used. An error is raised if Spec_Rules has no entry.
      Candidates are transferred to thy and normalized before filtering. An
      equation-like theorem is kept (in normalized form) if its defining constant
      is the constant t; any other theorem is kept if it is introduction-rule-like
      and its defining term is t.

      With show_steps / show_intermediate_results set in options, the lookup and
      its result are traced. *)
   fun get_specification options thy t =
     let
       (*val (c, T) = dest_Const t
       val t = Const (AxClass.unoverload_const thy (c, T), T)*)
       val _ =
         if show_steps options then
           tracing ("getting specification of " ^ Syntax.string_of_term_global thy t ^
             " with type " ^ Syntax.string_of_typ_global thy (fastype_of t))
         else ()
       val ctxt = ProofContext.init_global thy
       (* keeps th unchanged if it is an introduction rule for t *)
       fun as_introrule th =
         if is_introlike th andalso defining_term_of_introrule th = t then SOME th else NONE
       fun filtering th =
         if is_equationlike th then
           let
             (* normalize once and reuse the result for both the test and the answer *)
             val eq = normalize_equation thy th
           in
             if defining_const_of_equation eq = fst (dest_Const t) then SOME eq
             else as_introrule th
           end
         else as_introrule th
       fun filter_defs ths = map_filter filtering (map (normalize thy o Thm.transfer thy) ths)
       val spec =
         (case filter_defs (Predicate_Compile_Alternative_Defs.get ctxt) of
           [] =>
             (case Spec_Rules.retrieve ctxt t of
               [] => error ("No specification for " ^ (Syntax.string_of_term_global thy t))
             | ((_, (_, ths)) :: _) => filter_defs ths)
         | ths => rev ths)
       val _ =
         if show_intermediate_results options then
           Output.tracing (commas (map (Display.string_of_thm_global thy) spec))
         else ()
     in
       spec
     end
   214 
   (* Names of the logical connectives of Pure and HOL; obtain_specification_graph
      never follows constants with these names. *)
   val logic_operator_names =
     let
       val pure_connectives = [@{const_name "=="}, @{const_name "==>"}]
       val hol_connectives =
         [@{const_name Trueprop}, @{const_name Not},
          @{const_name HOL.eq}, @{const_name HOL.implies},
          @{const_name All}, @{const_name Ex},
          @{const_name HOL.conj}, @{const_name HOL.disj}]
     in
       pure_connectives @ hol_connectives
     end
   226 
   (* Tests whether the name of the constant (c, _) is one of a fixed list of
      constants that obtain_specification_graph does not follow. *)
   local
     val special_case_names =
       [@{const_name Product_Type.Unity},
        @{const_name False},
        @{const_name Suc},
        @{const_name Nat.zero_nat_inst.zero_nat},
        @{const_name Nat.one_nat_inst.one_nat},
        @{const_name Orderings.less},
        @{const_name Orderings.less_eq},
        @{const_name Groups.zero},
        @{const_name Groups.one},
        @{const_name Groups.plus},
        @{const_name Nat.ord_nat_inst.less_eq_nat},
        @{const_name Nat.ord_nat_inst.less_nat},
        @{const_name number_nat_inst.number_of_nat},
        @{const_name Int.Bit0},
        @{const_name Int.Bit1},
        @{const_name Int.Pls},
        @{const_name Int.zero_int_inst.zero_int},
        @{const_name List.filter},
        @{const_name HOL.If},
        @{const_name Groups.minus}]
   in
     fun special_cases (c, _) = member (op =) special_case_names c
   end
   246 
   247 
   (* Traces the specification theorems of constname, one per line, when
      show_intermediate_results is set in options; otherwise does nothing. *)
   fun print_specification options thy constname specs =
     if not (show_intermediate_results options) then ()
     else
       let
         val header = "Specification of " ^ constname ^ ":\n"
         val listing = cat_lines (map (Display.string_of_thm_global thy) specs)
       in
         tracing (header ^ listing)
       end
   253 
   (* Builds the specification graph reachable from t. Every node is a term labelled
      with its specification (get_specification); it has an edge to each constant
      occurring in that specification, except for datatype constructors, logical
      connectives, constants for which Predicate_Compile_Core.intros_of succeeds,
      case constants, the special_cases, and constants in the ignore_consts table. *)
   fun obtain_specification_graph options thy t =
     let
       val ctxt = ProofContext.init_global thy
       val ignored = #ignore_consts (Data.get thy)
       fun is_datatype_constructor cT = is_some (Datatype.info_of_constr thy cT)
       fun is_nondefining_const (c, _) = member (op =) logic_operator_names c
       fun has_code_pred_intros (c, _) = can (Predicate_Compile_Core.intros_of ctxt) c
       fun case_consts (c, _) = is_some (Datatype.info_of_case thy c)
       fun is_ignored (c, _) = Symtab.defined ignored c
       (* applied as successive filters, first to last *)
       val exclusions =
         [is_datatype_constructor, is_nondefining_const, has_code_pred_intros,
          case_consts, special_cases, is_ignored]
       (* constants of the specification that need a node of their own, each taken
          with its constraint type from the signature *)
       fun defiants_of specs =
         fold (Term.add_consts o prop_of) specs []
         |> fold filter_out exclusions
         |> map (fn (c, _) => Const (c, Sign.the_const_constraint thy c))
       fun extend node gr =
         if can (Term_Graph.get_node gr) node then gr
         else
           let
             val specs = get_specification options thy node
             val deps = defiants_of specs
             val add_edges = fold (fn dep => Term_Graph.add_edge (node, dep)) deps
           in
             gr
             |> Term_Graph.new_node (node, specs)
             |> fold extend deps
             |> add_edges
           end
     in
       extend t Term_Graph.empty
     end;
   289 
   290 
   (* Displays the graph of strongly connected components of gr: one entry per
      component, named by the comma-separated constant names of its members, with
      the components of its members' immediate successors as parents. *)
   fun present_graph gr =
     let
       fun eq_cname (Const (a, _), Const (b, _)) = (a = b)
       fun string_of_const (Const (c, _)) = c
         | string_of_const _ = error "string_of_const: unexpected term"
       val namify = commas o map string_of_const
       val components = Term_Graph.strong_conn gr
       (* maps every node to the component that contains it *)
       fun register component table =
         fold (fn node => Termtab.update (node, component)) component table
       val component_of = fold register components Termtab.empty
       (* components of the immediate successors, without the component's own members *)
       fun succs component =
         maps (Term_Graph.imm_succs gr) component
         |> subtract eq_cname component
         |> map (the o Termtab.lookup component_of)
         |> distinct (eq_list eq_cname)
       (* consing while folding yields the components in reverse order *)
       val conn = fold (fn component => cons (component, succs component)) components []
       fun entry_of (component, parents) =
         let
           val label = namify component
         in
           {name = label, ID = label, dir = "", unfold = true, path = "",
            parents = map namify parents}
         end
     in
       Present.display_graph (map entry_of conn)
     end;
   312 
   313 
   314 end;