src/HOL/Tools/Predicate_Compile/predicate_compile_data.ML
author wenzelm
Fri Mar 21 20:33:56 2014 +0100 (2014-03-21)
changeset 56245 84fc7dfa3cd4
parent 55440 721b4561007a
child 57962 0284a7d083be
permissions -rw-r--r--
more qualified names;
     1 (*  Title:      HOL/Tools/Predicate_Compile/predicate_compile_data.ML
     2     Author:     Lukas Bulwahn, TU Muenchen
     3 
     4 Book-keeping datastructure for the predicate compiler.
     5 *)
     6 
     7 signature PREDICATE_COMPILE_DATA =
     8 sig
     9   val ignore_consts : string list -> theory -> theory
    10   val keep_functions : string list -> theory -> theory
    11   val keep_function : theory -> string -> bool
    12   val processed_specs : theory -> string -> (string * thm list) list option
    13   val store_processed_specs : (string * (string * thm list) list) -> theory -> theory
    14 
    15   val get_specification : Predicate_Compile_Aux.options -> theory -> term -> thm list
    16   val obtain_specification_graph :
    17     Predicate_Compile_Aux.options -> theory -> term -> thm list Term_Graph.T
    18 
    19   val present_graph : thm list Term_Graph.T -> unit
    20   val normalize_equation : theory -> thm -> thm
    21 end;
    22 
    23 structure Predicate_Compile_Data : PREDICATE_COMPILE_DATA =
    24 struct
    25 
    26 open Predicate_Compile_Aux;
    27 
    28 structure Data = Theory_Data
    29 (
    30   type T =
    31     {ignore_consts : unit Symtab.table,
    32      keep_functions : unit Symtab.table,
    33      processed_specs : ((string * thm list) list) Symtab.table};
    34   val empty =
    35     {ignore_consts = Symtab.empty,
    36      keep_functions = Symtab.empty,
    37      processed_specs =  Symtab.empty};
    38   val extend = I;
    39   fun merge
    40     ({ignore_consts = c1, keep_functions = k1, processed_specs = s1},
    41      {ignore_consts = c2, keep_functions = k2, processed_specs = s2}) =
    42      {ignore_consts = Symtab.merge (K true) (c1, c2),
    43       keep_functions = Symtab.merge (K true) (k1, k2),
    44       processed_specs = Symtab.merge (K true) (s1, s2)}
    45 );
    46 
    47 
    48 
    49 fun mk_data (c, k, s) = {ignore_consts = c, keep_functions = k, processed_specs = s}
    50 fun map_data f {ignore_consts = c, keep_functions = k, processed_specs = s} = mk_data (f (c, k, s))
    51 
    52 fun ignore_consts cs = Data.map (map_data (apfst3 (fold (fn c => Symtab.insert (op =) (c, ())) cs)))
    53 
    54 fun keep_functions cs = Data.map (map_data (apsnd3 (fold (fn c => Symtab.insert (op =) (c, ())) cs)))
    55 
    56 fun keep_function thy = Symtab.defined (#keep_functions (Data.get thy))
    57 
    58 fun processed_specs thy = Symtab.lookup (#processed_specs (Data.get thy))
    59 
    60 fun store_processed_specs (constname, specs) =
    61   Data.map (map_data (aptrd3 (Symtab.update_new (constname, specs))))
    62 (* *)
    63 
    64 
    65 fun defining_term_of_introrule_term t =
    66   let
    67     val _ $ u = Logic.strip_imp_concl t
    68   in fst (strip_comb u) end
    69 (*
    70   in case pred of
    71     Const (c, T) => c
    72     | _ => raise TERM ("defining_const_of_introrule_term failed: Not a constant", [t])
    73   end
    74 *)
    75 val defining_term_of_introrule = defining_term_of_introrule_term o prop_of
    76 
    77 fun defining_const_of_introrule th =
    78   (case defining_term_of_introrule th of
    79     Const (c, _) => c
    80   | _ => raise TERM ("defining_const_of_introrule failed: Not a constant", [prop_of th]))
    81 
    82 (*TODO*)
    83 fun is_introlike_term _ = true
    84 
    85 val is_introlike = is_introlike_term o prop_of
    86 
    87 fun check_equation_format_term (t as (Const (@{const_name Pure.eq}, _) $ u $ _)) =
    88       (case strip_comb u of
    89         (Const (_, T), args) =>
    90           if (length (binder_types T) = length args) then
    91             true
    92           else
    93             raise TERM ("check_equation_format_term failed: Number of arguments mismatch", [t])
    94       | _ => raise TERM ("check_equation_format_term failed: Not a constant", [t]))
    95   | check_equation_format_term t =
    96       raise TERM ("check_equation_format_term failed: Not an equation", [t])
    97 
    98 val check_equation_format = check_equation_format_term o prop_of
    99 
   100 
   101 fun defining_term_of_equation_term (Const (@{const_name Pure.eq}, _) $ u $ _) = fst (strip_comb u)
   102   | defining_term_of_equation_term t =
   103       raise TERM ("defining_const_of_equation_term failed: Not an equation", [t])
   104 
   105 val defining_term_of_equation = defining_term_of_equation_term o prop_of
   106 
   107 fun defining_const_of_equation th =
   108   (case defining_term_of_equation th of
   109     Const (c, _) => c
   110   | _ => raise TERM ("defining_const_of_equation failed: Not a constant", [prop_of th]))
   111 
   112 
   113 
   114 
   115 (* Normalizing equations *)
   116 
   117 fun mk_meta_equation th =
   118   (case prop_of th of
   119     Const (@{const_name Trueprop}, _) $ (Const (@{const_name HOL.eq}, _) $ _ $ _) =>
   120       th RS @{thm eq_reflection}
   121   | _ => th)
   122 
   123 val meta_fun_cong = @{lemma "f == g ==> f x == g x" by simp}
   124 
   125 fun full_fun_cong_expand th =
   126   let
   127     val (f, args) = strip_comb (fst (Logic.dest_equals (prop_of th)))
   128     val i = length (binder_types (fastype_of f)) - length args
   129   in funpow i (fn th => th RS meta_fun_cong) th end;
   130 
   131 fun declare_names s xs ctxt =
   132   let
   133     val res = Name.invent_names ctxt s xs
   134   in (res, fold Name.declare (map fst res) ctxt) end
   135 
   136 fun split_all_pairs thy th =
   137   let
   138     val ctxt = Proof_Context.init_global thy  (* FIXME proper context!? *)
   139     val ((_, [th']), _) = Variable.import true [th] ctxt
   140     val t = prop_of th'
   141     val frees = Term.add_frees t []
   142     val freenames = Term.add_free_names t []
   143     val nctxt = Name.make_context freenames
   144     fun mk_tuple_rewrites (x, T) nctxt =
   145       let
   146         val Ts = HOLogic.flatten_tupleT T
   147         val (xTs, nctxt') = declare_names x Ts nctxt
   148         val paths = HOLogic.flat_tupleT_paths T
   149       in ((Free (x, T), HOLogic.mk_ptuple paths T (map Free xTs)), nctxt') end
   150     val (rewr, _) = fold_map mk_tuple_rewrites frees nctxt
   151     val t' = Pattern.rewrite_term thy rewr [] t
   152     val th'' =
   153       Goal.prove ctxt (Term.add_free_names t' []) [] t'
   154         (fn _ => ALLGOALS Skip_Proof.cheat_tac)
   155     val th''' = Local_Defs.unfold ctxt [@{thm split_conv}, @{thm fst_conv}, @{thm snd_conv}] th''
   156   in
   157     th'''
   158   end;
   159 
   160 
   161 fun inline_equations thy th =
   162   let
   163     val ctxt = Proof_Context.init_global thy
   164     val inline_defs = Predicate_Compile_Inline_Defs.get ctxt
   165     val th' = (Simplifier.full_simplify (put_simpset HOL_basic_ss ctxt addsimps inline_defs)) th
   166     (*val _ = print_step options
   167       ("Inlining " ^ (Syntax.string_of_term_global thy (prop_of th))
   168        ^ "with " ^ (commas (map ((Syntax.string_of_term_global thy) o prop_of) inline_defs))
   169        ^" to " ^ (Syntax.string_of_term_global thy (prop_of th')))*)
   170   in
   171     th'
   172   end
   173 
   174 fun normalize_equation thy th =
   175   mk_meta_equation th
   176   |> full_fun_cong_expand
   177   |> split_all_pairs thy
   178   |> tap check_equation_format
   179   |> inline_equations thy
   180 
   181 fun normalize_intros thy th =
   182   split_all_pairs thy th
   183   |> inline_equations thy
   184 
   185 fun normalize thy th =
   186   if is_equationlike th then
   187     normalize_equation thy th
   188   else
   189     normalize_intros thy th
   190 
   191 fun get_specification options thy t =
   192   let
   193     (*val (c, T) = dest_Const t
   194     val t = Const (Axclass.unoverload_const thy (c, T), T)*)
   195     val _ = if show_steps options then
   196         tracing ("getting specification of " ^ Syntax.string_of_term_global thy t ^
   197           " with type " ^ Syntax.string_of_typ_global thy (fastype_of t))
   198       else ()
   199     val ctxt = Proof_Context.init_global thy
   200     fun filtering th =
   201       if is_equationlike th andalso
   202         defining_const_of_equation (normalize_equation thy th) = fst (dest_Const t) then
   203         SOME (normalize_equation thy th)
   204       else
   205         if is_introlike th andalso defining_const_of_introrule th = fst (dest_Const t) then
   206           SOME th
   207         else
   208           NONE
   209     fun filter_defs ths = map_filter filtering (map (normalize thy o Thm.transfer thy) ths)
   210     val spec =
   211       (case filter_defs (Predicate_Compile_Alternative_Defs.get ctxt) of
   212         [] =>
   213           (case Spec_Rules.retrieve ctxt t of
   214             [] => error ("No specification for " ^ Syntax.string_of_term_global thy t)
   215           | ((_, (_, ths)) :: _) => filter_defs ths)
   216       | ths => rev ths)
   217     val _ =
   218       if show_intermediate_results options then
   219         tracing ("Specification for " ^ (Syntax.string_of_term_global thy t) ^ ":\n" ^
   220           commas (map (Display.string_of_thm_global thy) spec))
   221       else ()
   222   in
   223     spec
   224   end
   225 
   226 val logic_operator_names =
   227   [@{const_name Pure.eq},
   228    @{const_name Pure.imp},
   229    @{const_name Trueprop},
   230    @{const_name Not},
   231    @{const_name HOL.eq},
   232    @{const_name HOL.implies},
   233    @{const_name All},
   234    @{const_name Ex},
   235    @{const_name HOL.conj},
   236    @{const_name HOL.disj}]
   237 
   238 fun special_cases (c, _) =
   239   member (op =)
   240    [@{const_name Product_Type.Unity},
   241     @{const_name False},
   242     @{const_name Suc}, @{const_name Nat.zero_nat_inst.zero_nat},
   243     @{const_name Nat.one_nat_inst.one_nat},
   244     @{const_name Orderings.less}, @{const_name Orderings.less_eq},
   245     @{const_name Groups.zero},
   246     @{const_name Groups.one},  @{const_name Groups.plus},
   247     @{const_name Nat.ord_nat_inst.less_eq_nat},
   248     @{const_name Nat.ord_nat_inst.less_nat},
   249   (* FIXME
   250     @{const_name number_nat_inst.number_of_nat},
   251   *)
   252     @{const_name Num.Bit0},
   253     @{const_name Num.Bit1},
   254     @{const_name Num.One},
   255     @{const_name Int.zero_int_inst.zero_int},
   256     @{const_name List.filter},
   257     @{const_name HOL.If},
   258     @{const_name Groups.minus}] c
   259 
   260 
   261 fun obtain_specification_graph options thy t =
   262   let
   263     val ctxt = Proof_Context.init_global thy
   264     fun is_nondefining_const (c, _) = member (op =) logic_operator_names c
   265     fun has_code_pred_intros (c, _) = can (Core_Data.intros_of ctxt) c
   266     fun case_consts (c, _) = is_some (Ctr_Sugar.ctr_sugar_of_case ctxt c)
   267     fun is_datatype_constructor (x as (_, T)) =
   268       (case body_type T of
   269         Type (Tcon, _) => can (Ctr_Sugar.dest_ctr ctxt Tcon) (Const x)
   270       | _ => false)
   271     fun defiants_of specs =
   272       fold (Term.add_consts o prop_of) specs []
   273       |> filter_out is_datatype_constructor
   274       |> filter_out is_nondefining_const
   275       |> filter_out has_code_pred_intros
   276       |> filter_out case_consts
   277       |> filter_out special_cases
   278       |> filter_out (fn (c, _) => Symtab.defined (#ignore_consts (Data.get thy)) c)
   279       |> map (fn (c, _) => (c, Sign.the_const_constraint thy c))
   280       |> map Const
   281       (*
   282       |> filter is_defining_constname*)
   283     fun extend t gr =
   284       if can (Term_Graph.get_node gr) t then gr
   285       else
   286         let
   287           val specs = get_specification options thy t
   288           (*val _ = print_specification options thy constname specs*)
   289           val us = defiants_of specs
   290         in
   291           gr
   292           |> Term_Graph.new_node (t, specs)
   293           |> fold extend us
   294           |> fold (fn u => Term_Graph.add_edge (t, u)) us
   295         end
   296   in
   297     extend t Term_Graph.empty
   298   end;
   299 
   300 
   301 fun present_graph gr =
   302   let
   303     fun eq_cname (Const (c1, _), Const (c2, _)) = (c1 = c2)
   304     fun string_of_const (Const (c, _)) = c
   305       | string_of_const _ = error "string_of_const: unexpected term"
   306     val constss = Term_Graph.strong_conn gr;
   307     val mapping = Termtab.empty |> fold (fn consts => fold (fn const =>
   308       Termtab.update (const, consts)) consts) constss;
   309     fun succs consts = consts
   310       |> maps (Term_Graph.immediate_succs gr)
   311       |> subtract eq_cname consts
   312       |> map (the o Termtab.lookup mapping)
   313       |> distinct (eq_list eq_cname);
   314     val conn = [] |> fold (fn consts => cons (consts, succs consts)) constss;
   315 
   316     fun namify consts = map string_of_const consts
   317       |> commas;
   318     val prgr = map (fn (consts, constss) =>
   319       {name = namify consts, ID = namify consts, dir = "", unfold = true,
   320        path = "", parents = map namify constss, content = [] }) conn
   321   in Graph_Display.display_graph prgr end
   322 
   323 end