src/HOL/Tools/Predicate_Compile/predicate_compile_pred.ML
author wenzelm
Sat Dec 14 17:28:05 2013 +0100 (2013-12-14)
changeset 54742 7a86358a3c0b
parent 52230 1105b3b5aa77
child 54895 515630483010
permissions -rw-r--r--
proper context for basic Simplifier operations: rewrite_rule, rewrite_goals_rule, rewrite_goals_tac etc.;
clarified tool context in some boundary cases;
wenzelm@33265
     1
(*  Title:      HOL/Tools/Predicate_Compile/predicate_compile_pred.ML
wenzelm@33265
     2
    Author:     Lukas Bulwahn, TU Muenchen
bulwahn@33250
     3
wenzelm@33265
     4
Preprocessing definitions of predicates to introduction rules.
bulwahn@33250
     5
*)
bulwahn@33250
     6
bulwahn@33250
     7
signature PREDICATE_COMPILE_PRED =
sig
  (* Turns the specification of a constant -- either its defining equations or
     its introduction rules -- into introduction rules.  Auxiliary constants
     defined along the way are preprocessed as well; the result lists the
     (constant name, intro rules) pairs for the constant and all auxiliaries,
     together with the extended theory. *)
  val preprocess :
    Predicate_Compile_Aux.options -> string * thm list -> theory ->
      (string * thm list) list * theory
  (* Replaces lambda-abstraction arguments of the atoms in the given intro rules
     by applications of fresh constants.  Returns the rewritten specifications,
     the (constant name, definitions) pairs of the new constants, and the
     extended theory. *)
  val flat_higher_order_arguments :
    (string * thm list) list * theory ->
      (string * thm list) list * ((string * thm list) list * theory)
end;
wenzelm@33265
    15
bulwahn@35324
    16
bulwahn@35324
    17
structure Predicate_Compile_Pred : PREDICATE_COMPILE_PRED =
struct

open Predicate_Compile_Aux

(* Rewrite rule pushing an argument into both branches of an if-then-else;
   shared by flatten and preprocess. *)
val if_beta = @{lemma "(if c then x else y) z = (if c then x z else y z)" by simp}

(* Base name for a new auxiliary constant derived from constname: the base name
   of constname followed by suffix, renamed if necessary so that it differs from
   the base names of the constants already recorded in defs. *)
fun fresh_aux_name suffix defs constname =
  singleton (Name.variant_list (map (Long_Name.base_name o fst) defs))
    (Long_Name.base_name constname ^ suffix)

(* Atoms that must be replaced by an auxiliary definition: existential
   quantifications and disjunctions.  Negations and conjunctions are expected
   to have been eliminated by earlier preprocessing and raise an error. *)
fun is_compound ((Const (@{const_name Not}, _)) $ _) =
    error "is_compound: Negation should not occur; preprocessing is defect"
  | is_compound ((Const (@{const_name Ex}, _)) $ _) = true
  | is_compound ((Const (@{const_name HOL.disj}, _)) $ _ $ _) = true
  | is_compound ((Const (@{const_name HOL.conj}, _)) $ _ $ _) =
    error "is_compound: Conjunction should not occur; preprocessing is defect"
  | is_compound _ = false

(* Tries to split atom with the split rule that find_split_thm returns for its
   head symbol.
   names: names to avoid when introducing fresh variables.
   Returns NONE if there is no split rule; otherwise SOME (atom', srs) where
   atom' is case_betapply thy atom and each element (eqs, rhs) of srs describes
   one case: eqs substitutes free variables using the case premises of the form
   "x = r", and rhs is the conjunction of the remaining premises with the case
   body.  Case expressions in rhs are split recursively. *)
fun try_destruct_case thy names atom =
  case find_split_thm thy (fst (strip_comb atom)) of
    NONE => NONE
  | SOME raw_split_thm =>
    let
      val split_thm = prepare_split_thm (Proof_Context.init_global thy) raw_split_thm
      (* TODO: contextify things - this line is to unvarify the split_thm *)
      (*val ((_, [isplit_thm]), _) =
        Variable.import true [split_thm] (Proof_Context.init_global thy)*)
      val (assms, concl) = Logic.strip_horn (prop_of split_thm)
      val (_, [split_t]) = strip_comb (HOLogic.dest_Trueprop concl)
      val atom' = case_betapply thy atom
      val subst = Pattern.match thy (split_t, atom') (Vartab.empty, Vartab.empty)
      val names' = Term.add_free_names atom' names
      (* A premise "x = r" with x free becomes a substitution; every other
         premise -- including ones that are not equations at all, which
         previously made dest_eq raise TERM -- is kept as a premise. *)
      fun classify_prem prem =
        case try (HOLogic.dest_eq o HOLogic.dest_Trueprop) prem of
          SOME (Free (x, T), r) => (NONE, SOME ((x, T), r))
        | _ => (SOME prem, NONE)
      fun mk_subst_rhs assm =
        let
          val (vTs, assm') = strip_all (Envir.beta_norm (Envir.subst_term subst assm))
          val var_names = Name.variant_list names' (map fst vTs)
          val vars = map Free (var_names ~~ (map snd vTs))
          val (prems', pre_res) = Logic.strip_horn (subst_bounds (rev vars, assm'))
          (* order of premises and substitutions is preserved *)
          val classified = map classify_prem prems'
          val prems'' = map_filter fst classified
          val eqs = map_filter snd classified
          val (_, [inner_t]) = strip_comb (HOLogic.dest_Trueprop pre_res)
          val pre_rhs =
            fold (curry HOLogic.mk_conj) (map HOLogic.dest_Trueprop prems'') inner_t
          val rhs = Envir.expand_term_frees eqs pre_rhs
        in
          case try_destruct_case thy (var_names @ names') rhs of
            NONE => [(eqs, rhs)]
          | SOME (_, srs) => map (fn (eqs', rhs') => (eqs @ eqs', rhs')) srs
        end
    in SOME (atom', maps mk_subst_rhs assms) end

(* Flattens a single atom of an intro rule.
   - Compound atoms (see is_compound) are replaced by a fresh auxiliary
     predicate applied to the atom's free variables, defined by an equation.
   - If-then-else atoms are rewritten into disjunctions and flattened again.
   - Case-expression atoms are replaced by a fresh auxiliary predicate given
     by one axiom per case (see try_destruct_case).
   - Any other atom is returned unchanged.
   Returns (replacement, (defs', thy')) where defs' prepends the new
   (full constant name, definitions) pair, if any, to defs. *)
fun flatten constname atom (defs, thy) =
  if is_compound atom then
    let
      val atom = Envir.beta_norm (Envir.eta_long [] atom)
      val constname = fresh_aux_name "_aux" defs constname
      val full_constname = Sign.full_bname thy constname
      (* predicate-typed frees become the leading parameters *)
      val (params, args) = List.partition (is_predT o fastype_of)
        (map Free (Term.add_frees atom []))
      val constT = map fastype_of (params @ args) ---> HOLogic.boolT
      val lhs = list_comb (Const (full_constname, constT), params @ args)
      val def = Logic.mk_equals (lhs, atom)
      val ([definition], thy') = thy
        |> Sign.add_consts_i [(Binding.name constname, constT, NoSyn)]
        |> Global_Theory.add_defs false [((Binding.name (Thm.def_name constname), def), [])]
    in
      (lhs, ((full_constname, [definition]) :: defs, thy'))
    end
  else
    case (fst (strip_comb atom)) of
      (Const (@{const_name If}, _)) =>
        let
          val atom' = Raw_Simplifier.rewrite_term thy
            (map (fn th => th RS @{thm eq_reflection}) [@{thm if_bool_eq_disj}, if_beta]) [] atom
          (* the rewrite must make progress, otherwise flatten would loop *)
          val _ = @{assert} (not (atom = atom'))
        in
          flatten constname atom' (defs, thy)
        end
    | _ =>
      case try_destruct_case thy [] atom of
        NONE => (atom, (defs, thy))
      | SOME (atom', srs) =>
        let
          val frees = map Free (Term.add_frees atom' [])
          val constname = fresh_aux_name "_aux" defs constname
          val full_constname = Sign.full_bname thy constname
          val constT = map fastype_of frees ---> HOLogic.boolT
          val lhs = list_comb (Const (full_constname, constT), frees)
          (* one defining axiom per case, with that case's substitution applied to lhs *)
          fun mk_def (subst, rhs) =
            Logic.mk_equals (fold Envir.expand_term_frees (map single subst) lhs, rhs)
          val new_defs = map mk_def srs
          val (definition, thy') = thy
            |> Sign.add_consts_i [(Binding.name constname, constT, NoSyn)]
            |> fold_map Specification.axiom (map_index
                (fn (i, t) => ((Binding.name (constname ^ "_def" ^ string_of_int i), []), t)) new_defs)
        in
          (lhs, ((full_constname, map Drule.export_without_context definition) :: defs, thy'))
        end

(* Applies flatten to every atom of the given intro rules.  The flattened
   rules are turned into theorems via Skip_Proof.cheat_tac (no real proof).
   Returns (flattened intros, (auxiliary definitions, extended theory)). *)
fun flatten_intros constname intros thy =
  let
    val ctxt = Proof_Context.init_global thy  (* FIXME proper context!? *)
    val ((_, intros), ctxt') = Variable.import true intros ctxt
    val (intros', (local_defs, thy')) = (fold_map o fold_map_atoms)
      (flatten constname) (map prop_of intros) ([], thy)
    val ctxt'' = Proof_Context.transfer thy' ctxt'
    val intros'' =
      map (fn t => Goal.prove ctxt'' [] [] t (fn _ => ALLGOALS Skip_Proof.cheat_tac)) intros'
      |> Variable.export ctxt'' ctxt
  in
    (intros'', (local_defs, thy'))
  end

(* Tactics selecting disjunct i (1-based) of an n-ary right-nested disjunction:
   i - 1 applications of disjI2, followed by disjI1 unless i is the last disjunct. *)
(* TODO: same function occurs in inductive package *)
fun select_disj 1 1 = []
  | select_disj _ 1 = [rtac @{thm disjI1}]
  | select_disj n i = (rtac @{thm disjI2})::(select_disj (n - 1) (i - 1));

(* Turns each equation "lhs == D1 | ... | Dn" in ths into n introduction rules
   "conjuncts of Di ==> lhs", with the existentially quantified variables of Di
   as free variables, and proves each rule from the equation. *)
fun introrulify thy ths =
  let
    val ctxt = Proof_Context.init_global thy
    val ((_, ths'), ctxt') = Variable.import true ths ctxt
    fun introrulify' th =
      let
        val (lhs, rhs) = Logic.dest_equals (prop_of th)
        (* Names the witnesses of the existentials must avoid.  The frees of lhs
           are included because an argument of the predicate need not occur in
           rhs; otherwise a witness could be given the same name as such an
           argument, yielding a less general (incomplete) intro rule. *)
        val frees = Term.add_free_names lhs (Term.add_free_names rhs [])
        val disjuncts = HOLogic.dest_disj rhs
        val nctxt = Name.make_context frees
        fun mk_introrule t =
          let
            val ((ps, t'), _) = focus_ex t nctxt
            val prems = map HOLogic.mk_Trueprop (HOLogic.dest_conj t')
          in
            (ps, Logic.list_implies (prems, HOLogic.mk_Trueprop lhs))
          end
        (* the variable instantiated by exI *)
        val x = ((cterm_of thy) o the_single o snd o strip_comb o HOLogic.dest_Trueprop o fst o
          Logic.dest_implies o prop_of) @{thm exI}
        fun prove_introrule (index, (ps, introrule)) =
          let
            val tac = Simplifier.simp_tac (put_simpset HOL_basic_ss ctxt' addsimps [th]) 1
              THEN EVERY1 (select_disj (length disjuncts) (index + 1))
              THEN (EVERY (map (fn y =>
                rtac (Drule.cterm_instantiate [(x, cterm_of thy (Free y))] @{thm exI}) 1) ps))
              THEN REPEAT_DETERM (rtac @{thm conjI} 1 THEN atac 1)
              THEN TRY (atac 1)
          in
            Goal.prove ctxt' (map fst ps) [] introrule (fn _ => tac)
          end
      in
        map_index prove_introrule (map mk_introrule disjuncts)
      end
  in maps introrulify' ths' |> Variable.export ctxt' ctxt end

(* Normalizes a defining equation: unfolds bounded quantifiers, rewrites with
   all_not_ex, converts to negation normal form, and distributes existentials
   over disjunctions. *)
fun rewrite ctxt =
  Simplifier.simplify (put_simpset HOL_basic_ss ctxt addsimps [@{thm Ball_def}, @{thm Bex_def}])
  #> Simplifier.simplify (put_simpset HOL_basic_ss ctxt addsimps [@{thm all_not_ex}])
  #> Conv.fconv_rule (nnf_conv ctxt)
  #> Simplifier.simplify (put_simpset HOL_basic_ss ctxt addsimps [@{thm ex_disj_distrib}])

(* Normalizes an intro rule: rewrites with all_not_ex, then with bool_simps
   (except its first rule) and nnf_simps, and splits conjunctions in the
   assumptions. *)
fun rewrite_intros thy =
  Simplifier.full_simplify (Simplifier.global_context thy HOL_basic_ss addsimps [@{thm all_not_ex}])
  #> Simplifier.full_simplify
    (Simplifier.global_context thy HOL_basic_ss
      addsimps (tl @{thms bool_simps}) addsimps @{thms nnf_simps})
  #> split_conjuncts_in_assms (Proof_Context.init_global thy)

(* Traces msg and the theorems ths if intermediate results are requested. *)
fun print_specs options thy msg ths =
  if show_intermediate_results options then
    (tracing (msg); tracing (commas (map (Display.string_of_thm_global thy) ths)))
  else
    ()

(* Turns the specification of constname into intro rules (see the signature),
   flattening compound atoms into auxiliary constants and preprocessing those
   auxiliaries recursively. *)
fun preprocess options (constname, specs) thy =
(*  case Predicate_Compile_Data.processed_specs thy constname of
    SOME specss => (specss, thy)
  | NONE =>*)
    let
      val ctxt = Proof_Context.init_global thy  (* FIXME proper context!? *)
      val intros =
        if forall is_pred_equation specs then
          map (split_conjuncts_in_assms ctxt) (introrulify thy (map (rewrite ctxt) specs))
        else if forall (is_intro constname) specs then
          map (rewrite_intros thy) specs
        else
          error ("unexpected specification for constant " ^ quote constname ^ ":\n"
            ^ commas (map (quote o Display.string_of_thm_global thy) specs))
      val intros = map (rewrite_rule ctxt [if_beta RS @{thm eq_reflection}]) intros
      val _ = print_specs options thy "normalized intros" intros
      (*val intros = maps (split_cases thy) intros*)
      val (intros', (local_defs, thy')) = flatten_intros constname intros thy
      val (intross, thy'') = fold_map (preprocess options) local_defs thy'
      val full_spec = (constname, intros') :: flat intross
      (*val thy''' = Predicate_Compile_Data.store_processed_specs (constname, full_spec) thy''*)
    in
      (full_spec, thy'')
    end;

(* Replaces lambda-abstraction arguments of atoms (also inside pairs and
   predicate-typed arguments) by applications of fresh "_hoaux" constants
   defined to be those abstractions.  Returns the rewritten specifications
   together with (new definitions, extended theory). *)
fun flat_higher_order_arguments (intross, thy) =
  let
    fun process constname atom (new_defs, thy) =
      let
        val (pred, args) = strip_comb atom
        fun replace_abs_arg (abs_arg as Abs _ ) (new_defs, thy) =
          let
            val vars = map Var (Term.add_vars abs_arg [])
            val abs_arg' = Logic.unvarify_global abs_arg
            val frees = map Free (Term.add_frees abs_arg' [])
            val constname = fresh_aux_name "_hoaux" new_defs constname
            val full_constname = Sign.full_bname thy constname
            val constT = map fastype_of frees ---> (fastype_of abs_arg')
            val const = Const (full_constname, constT)
            val lhs = list_comb (const, frees)
            val def = Logic.mk_equals (lhs, abs_arg')
            val ([definition], thy') = thy
              |> Sign.add_consts_i [(Binding.name constname, constT, NoSyn)]
              |> Global_Theory.add_defs false [((Binding.name (Thm.def_name constname), def), [])]
          in
            (list_comb (Logic.varify_global const, vars),
              ((full_constname, [definition])::new_defs, thy'))
          end
          | replace_abs_arg arg (new_defs, thy) =
          if is_some (try HOLogic.dest_prodT (fastype_of arg)) then
            (case try HOLogic.dest_prod arg of
              SOME (t1, t2) =>
                (new_defs, thy)
                |> process constname t1
                ||>> process constname t2
                |>> HOLogic.mk_prod
            | NONE => (warning ("Replacing higher order arguments " ^
              "is not applied in an undestructable product type"); (arg, (new_defs, thy))))
          else if (is_predT (fastype_of arg)) then
            process constname arg (new_defs, thy)
          else
            (arg, (new_defs, thy))

        val (args', (new_defs', thy')) = fold_map replace_abs_arg
          (map Envir.beta_eta_contract args) (new_defs, thy)
      in
        (list_comb (pred, args'), (new_defs', thy'))
      end
    fun flat_intro intro (new_defs, thy) =
      let
        val constname = fst (dest_Const (fst (strip_comb
          (HOLogic.dest_Trueprop (Logic.strip_imp_concl (prop_of intro))))))
        val (intro_t, (new_defs, thy)) =
          fold_map_atoms (process constname) (prop_of intro) (new_defs, thy)
        val th = Skip_Proof.make_thm thy intro_t
      in
        (th, (new_defs, thy))
      end
    (* thread the state through all specifications, keeping each constant name *)
    val (intross', (new_defs, thy')) =
      fold_map (fn (c, ths) => fold_map flat_intro ths #>> pair c) intross ([], thy)
  in
    (intross', (new_defs, thy'))
  end

end;