src/HOL/Tools/Old_Datatype/old_datatype_realizer.ML
author blanchet
Tue Sep 09 20:51:36 2014 +0200 (2014-09-09)
changeset 58274 4a84e94e58a2
parent 58182 82478e6c60cb
permissions -rw-r--r--
made datatype realizer plugin work for new-style datatypes with no nesting
blanchet@58112
     1
(*  Title:      HOL/Tools/Old_Datatype/old_datatype_realizer.ML
berghofe@13467
     2
    Author:     Stefan Berghofer, TU Muenchen
berghofe@13467
     3
haftmann@33968
     4
Program extraction from proofs involving datatypes:
haftmann@33968
     5
realizers for induction and case analysis.
berghofe@13467
     6
*)
berghofe@13467
     7
blanchet@58112
     8
signature OLD_DATATYPE_REALIZER =
sig
  (* Adds realizers for the induction and case-analysis rules of the named
     datatypes to the theory; does nothing unless proof terms are enabled. *)
  val add_dt_realizers: Old_Datatype_Aux.config -> string list -> theory -> theory

  (* Plugin name under which add_dt_realizers is registered as a datatype
     interpretation. *)
  val realizer_plugin: string
end;
berghofe@13467
    13
blanchet@58112
    14
structure Old_Datatype_Realizer : OLD_DATATYPE_REALIZER =
berghofe@13467
    15
struct
berghofe@13467
    16
blanchet@58274
    17
(* Plugin name passed to BNF_LFP_Compat.interpretation when registering add_dt_realizers. *)
val realizer_plugin : string = "realizer";
blanchet@58274
    18
wenzelm@35364
    19
(* All sublists of the integer range [i, j], each in increasing order; for an
   empty range (i > j) the result is [[]].  Sublists containing the smallest
   index come first, recursively, e.g. subsets 0 1 = [[0, 1], [0], [1], []]. *)
fun subsets i j =
  let
    (* grow k acc: for every index from k down to i, double acc into
       "with k prepended" followed by "without k" *)
    fun grow k acc =
      if k < i then acc
      else grow (k - 1) (map (fn ks => k :: ks) acc @ acc)
  in grow j [[]] end;
berghofe@13467
    24
wenzelm@40844
    25
(* True iff the body type of t's type (as computed by body_type) is HOL's unit
   type; used to drop unit-typed realizer variables. *)
fun is_unit t =
  let val ty = body_type (fastype_of t)
  in ty = HOLogic.unitT end;
berghofe@13467
    26
berghofe@13725
    27
(* Name of the outermost type constructor of a type, or "" when the type is not
   built with Type (e.g. a TFree or TVar). *)
fun tname_of ty =
  (case ty of
    Type (name, _) => name
  | _ => "");
berghofe@13725
    29
blanchet@58112
    30
(* Build and register a realizer for the induction rule "induct" of the datatype
   family described by descr.  "is" lists the indices (into descr) whose
   predicates receive a computationally relevant realizer; every other index gets
   unit as its realizer type.  The correctness theorem is stored as
   "<derivation name of induct>_<P names in is>_correctness", and the realizer
   (vs, r', prf) is added with Extraction.add_realizers_i.  If
   Old_Datatype_Aux.Datatype is raised anywhere in the body, thy is returned
   unchanged (see the handler at the end). *)
fun make_ind ({descr, rec_names, rec_rewrites, induct, ...} : Old_Datatype_Aux.info) is thy =
berghofe@13467
    31
  let
wenzelm@54742
    32
    val ctxt = Proof_Context.init_global thy;
wenzelm@54742
    33
    val cert = cterm_of thy;
wenzelm@54742
    34
blanchet@58112
    35
    val recTs = Old_Datatype_Aux.get_rec_types descr;
    (* Predicate names: "P" for a single datatype, otherwise "P1", ..., "Pn". *)
wenzelm@45700
    36
    val pnames =
wenzelm@45700
    37
      if length descr = 1 then ["P"]
berghofe@13467
    38
      else map (fn i => "P" ^ string_of_int i) (1 upto length descr);
berghofe@13467
    39
    (* Realizer result type per datatype: the type variable "'" ^ P when its
       index is in is, unit otherwise. *)
berghofe@13467
    40
    val rec_result_Ts = map (fn ((i, _), P) =>
wenzelm@56254
    41
        if member (op =) is i then TFree ("'" ^ P, @{sort type}) else HOLogic.unitT)
wenzelm@45700
    42
      (descr ~~ pnames);
berghofe@13467
    43
    (* Realizability predicate of index i applied to realizer r (type T) and
       value x (type U); r is omitted when i is not in is. *)
berghofe@13467
    44
    fun make_pred i T U r x =
haftmann@36692
    45
      if member (op =) is i then
wenzelm@42364
    46
        Free (nth pnames i, T --> U --> HOLogic.boolT) $ r $ x
wenzelm@42364
    47
      else Free (nth pnames i, U --> HOLogic.boolT) $ x;
berghofe@13467
    48
    (* Quantify t over Free (s, T) only when index i is in is. *)
berghofe@13467
    49
    fun mk_all i s T t =
wenzelm@46215
    50
      if member (op =) is i then Logic.all (Free (s, T)) t else t;
berghofe@13467
    51
    (* For every constructor of every datatype (numbered by j from 1 across the
       whole family): a premise stating realizability of the corresponding
       induction case, paired with the recursor argument f' built for it. *)
haftmann@31458
    52
    val (prems, rec_fns) = split_list (flat (fst (fold_map
haftmann@31458
    53
      (fn ((i, (_, _, constrs)), T) => fold_map (fn (cname, cargs) => fn j =>
berghofe@13467
    54
        let
blanchet@58112
    55
          val Ts = map (Old_Datatype_Aux.typ_of_dtyp descr) cargs;
blanchet@58112
    56
          val tnames = Name.variant_list pnames (Old_Datatype_Prop.make_tnames Ts);
blanchet@58112
    57
          val recs = filter (Old_Datatype_Aux.is_rec_type o fst o fst) (cargs ~~ tnames ~~ Ts);
berghofe@13467
    58
          val frees = tnames ~~ Ts;
berghofe@13467
    59
          (* mk_prems vs recs: vs accumulates the constructor arguments plus one
             realizer variable per recursive argument.  The base case builds the
             premise's conclusion and f', leaving unit-typed variables out of
             f's arguments (is_unit). *)
wenzelm@38977
    60
          fun mk_prems vs [] =
berghofe@13467
    61
                let
haftmann@31458
    62
                  val rT = nth (rec_result_Ts) i;
berghofe@13467
    63
                  val vs' = filter_out is_unit vs;
blanchet@58112
    64
                  val f = Old_Datatype_Aux.mk_Free "f" (map fastype_of vs' ---> rT) j;
wenzelm@45700
    65
                  val f' =
wenzelm@45700
    66
                    Envir.eta_contract (fold_rev (absfree o dest_Free) vs
wenzelm@45700
    67
                      (if member (op =) is i then list_comb (f, vs') else HOLogic.unit));
wenzelm@45700
    68
                in
wenzelm@45700
    69
                  (HOLogic.mk_Trueprop (make_pred i rT T (list_comb (f, vs'))
wenzelm@45700
    70
                    (list_comb (Const (cname, Ts ---> T), map Free frees))), f')
berghofe@13467
    71
                end
            (* Recursive argument s of type Us ---> U: assume realizability for all
               Us via a realizer "r" ^ s, quantified only if its body index k is
               in is. *)
wenzelm@38977
    72
            | mk_prems vs (((dt, s), T) :: ds) =
berghofe@13467
    73
                let
blanchet@58112
    74
                  val k = Old_Datatype_Aux.body_index dt;
berghofe@13641
    75
                  val (Us, U) = strip_type T;
berghofe@13641
    76
                  val i = length Us;
haftmann@31458
    77
                  val rT = nth (rec_result_Ts) k;
berghofe@13641
    78
                  val r = Free ("r" ^ s, Us ---> rT);
wenzelm@45700
    79
                  val (p, f) = mk_prems (vs @ [r]) ds;
wenzelm@45700
    80
                in
wenzelm@45700
    81
                  (mk_all k ("r" ^ s) (Us ---> rT) (Logic.mk_implies
wenzelm@46218
    82
                    (Logic.list_all (map (pair "x") Us, HOLogic.mk_Trueprop
blanchet@58112
    83
                      (make_pred k rT U (Old_Datatype_Aux.app_bnds r i)
blanchet@58112
    84
                        (Old_Datatype_Aux.app_bnds (Free (s, T)) i))), p)), f)
wenzelm@45700
    85
                end;
wenzelm@46215
    86
        in (apfst (fold_rev (Logic.all o Free) frees) (mk_prems (map Free frees) recs), j + 1) end)
haftmann@31458
    87
          constrs) (descr ~~ recTs) 1)));
wenzelm@38977
    88
    (* Select the component for index j from the right-nested tuple r built
       below; the last index in is takes the remaining term. *)
blanchet@58111
    89
    fun mk_proj _ [] t = t
wenzelm@45700
    90
      | mk_proj j (i :: is) t =
wenzelm@45700
    91
          if null is then t
wenzelm@45700
    92
          else if (j: int) = i then HOLogic.mk_fst t
berghofe@13467
    93
          else mk_proj j is (HOLogic.mk_snd t);
berghofe@13467
    94
blanchet@58112
    95
    val tnames = Old_Datatype_Prop.make_tnames recTs;
    (* Instantiation for each induction predicate: an abstraction over x of the
       realizability predicate for the recursor (rec_names applied to rec_fns)
       at x. *)
berghofe@13467
    96
    val fTs = map fastype_of rec_fns;
berghofe@13467
    97
    val ps = map (fn ((((i, _), T), U), s) => Abs ("x", T, make_pred i U T
berghofe@13467
    98
      (list_comb (Const (s, fTs ---> T --> U), rec_fns) $ Bound 0) (Bound 0)))
berghofe@13467
    99
        (descr ~~ recTs ~~ rec_result_Ts ~~ rec_names);
    (* The extracted program: Extraction.nullt if is is empty, otherwise the
       tuple (HOLogic.mk_prod) of recursor applications for the indices in is. *)
wenzelm@45700
   100
    val r =
wenzelm@45700
   101
      if null is then Extraction.nullt
wenzelm@45700
   102
      else
wenzelm@45700
   103
        foldr1 HOLogic.mk_prod (map_filter (fn (((((i, _), T), U), s), tname) =>
wenzelm@45700
   104
          if member (op =) is i then SOME
wenzelm@45700
   105
            (list_comb (Const (s, fTs ---> T --> U), rec_fns) $ Free (tname, T))
wenzelm@45700
   106
          else NONE) (descr ~~ recTs ~~ rec_result_Ts ~~ rec_names ~~ tnames));
    (* Goal: conjunction over all datatypes of the realizability predicate for
       the matching projection of r. *)
wenzelm@45700
   107
    val concl =
wenzelm@45700
   108
      HOLogic.mk_Trueprop (foldr1 (HOLogic.mk_binop @{const_name HOL.conj})
wenzelm@45700
   109
        (map (fn ((((i, _), T), U), tname) =>
wenzelm@45700
   110
          make_pred i U T (mk_proj i is r) (Free (tname, T)))
wenzelm@45700
   111
            (descr ~~ recTs ~~ rec_result_Ts ~~ tnames)));
    (* Pair the predicate variables in the conclusion of induct with ps. *)
berghofe@13467
   112
    val inst = map (pairself cert) (map head_of (HOLogic.dest_conj
haftmann@32712
   113
      (HOLogic.dest_Trueprop (concl_of induct))) ~~ ps);
berghofe@13467
   114
    (* Prove concl from prems: unfold fst/snd, apply the instantiated induction
       rule, then unfold o_def and rec_rewrites and close each case with a
       premise. *)
wenzelm@45700
   115
    val thm =
wenzelm@54883
   116
      Goal.prove_internal ctxt (map cert prems) (cert concl)
wenzelm@45700
   117
        (fn prems =>
wenzelm@45700
   118
           EVERY [
wenzelm@54742
   119
            rewrite_goals_tac ctxt (map mk_meta_eq [@{thm fst_conv}, @{thm snd_conv}]),
wenzelm@45700
   120
            rtac (cterm_instantiate inst induct) 1,
wenzelm@54742
   121
            ALLGOALS (Object_Logic.atomize_prems_tac ctxt),
wenzelm@54742
   122
            rewrite_goals_tac ctxt (@{thm o_def} :: map mk_meta_eq rec_rewrites),
wenzelm@45700
   123
            REPEAT ((resolve_tac prems THEN_ALL_NEW (fn i =>
wenzelm@45700
   124
              REPEAT (etac allE i) THEN atac i)) 1)])
wenzelm@38977
   125
      |> Drule.export_without_context;
berghofe@13467
   126
wenzelm@36744
   127
    val ind_name = Thm.derivation_name induct;
wenzelm@42364
   128
    val vs = map (nth pnames) is;
    (* Store the theorem under the root path as <ind_name>_<vs>_correctness. *)
haftmann@18358
   129
    val (thm', thy') = thy
wenzelm@30435
   130
      |> Sign.root_path
wenzelm@39557
   131
      |> Global_Theory.store_thm
wenzelm@30435
   132
        (Binding.qualified_name (space_implode "_" (ind_name :: vs @ ["correctness"])), thm)
wenzelm@24712
   133
      ||> Sign.restore_naming thy;
berghofe@13467
   134
    (* Variables of the generic induction statement (ivs) and of the stored
       theorem (rvs).  ivs1 drops those whose body type is bool; ivs2 keeps all
       of ivs but takes their types from thm'. *)
blanchet@58112
   135
    val ivs = rev (Term.add_vars (Logic.varify_global (Old_Datatype_Prop.make_ind [descr])) []);
wenzelm@22691
   136
    val rvs = rev (Thm.fold_terms Term.add_vars thm' []);
wenzelm@45700
   137
    val ivs1 = map Var (filter_out (fn (_, T) => @{type_name bool} = tname_of (body_type T)) ivs);
wenzelm@33035
   138
    val ivs2 = map (fn (ixn, _) => Var (ixn, the (AList.lookup (op =) rvs ixn))) ivs;
berghofe@13467
   139
    (* Correctness proof: abstract over every premise of thm (and over the
       realizer variable when the head of the corresponding rec_fn body is a
       Free), applied to the proof of thm'. *)
wenzelm@33338
   140
    val prf =
berghofe@37233
   141
      Extraction.abs_corr_shyps thy' induct vs ivs2
wenzelm@33338
   142
        (fold_rev (fn (f, p) => fn prf =>
wenzelm@45700
   143
            (case head_of (strip_abs_body f) of
wenzelm@45700
   144
              Free (s, T) =>
wenzelm@45700
   145
                let val T' = Logic.varifyT_global T in
wenzelm@45700
   146
                  Abst (s, SOME T', Proofterm.prf_abstract_over
wenzelm@45700
   147
                    (Var ((s, 0), T')) (AbsP ("H", SOME p, prf)))
wenzelm@45700
   148
                end
wenzelm@45700
   149
            | _ => AbsP ("H", SOME p, prf)))
wenzelm@45700
   150
          (rec_fns ~~ prems_of thm)
wenzelm@45700
   151
          (Proofterm.proof_combP
wenzelm@45700
   152
            (Reconstruct.proof_of thm', map PBound (length prems - 1 downto 0))));
berghofe@13467
   153
    (* Realizer term: r itself if is is empty, otherwise r abstracted over the
       non-bool induction variables and the non-unit heads of rec_fns. *)
wenzelm@33338
   154
    val r' =
wenzelm@33338
   155
      if null is then r
wenzelm@45700
   156
      else
wenzelm@45700
   157
        Logic.varify_global (fold_rev lambda
wenzelm@45700
   158
          (map Logic.unvarify_global ivs1 @ filter_out is_unit
wenzelm@45700
   159
              (map (head_of o strip_abs_body) rec_fns)) r);
berghofe@13467
   160
blanchet@58274
   161
  in Extraction.add_realizers_i [(ind_name, (vs, r', prf))] thy' end
blanchet@58274
   162
  (* Nested new-style datatypes are not supported (unless they are registered via
blanchet@58274
   163
     "datatype_compat"). *)
blanchet@58274
   164
  handle Old_Datatype_Aux.Datatype => thy;
berghofe@13467
   165
blanchet@58182
   166
(* Build and register the two realizers for the case-analysis rule "exhaust" of
   the datatype at position index in descr.  The first is the realizer for
   predicate "P", built from the case combinator case_name.  The second has no
   computational content (Extraction.nullt).  The correctness theorem is stored
   as "<derivation name of exhaust>_P_correctness". *)
fun make_casedists ({index, descr, case_name, case_rewrites, exhaust, ...} : Old_Datatype_Aux.info) thy =
berghofe@13467
   167
  let
wenzelm@54883
   168
    val ctxt = Proof_Context.init_global thy;
wenzelm@19806
   169
    val cert = cterm_of thy;
    (* Realizer result type 'P, as a free type and as a schematic type variable. *)
wenzelm@56254
   170
    val rT = TFree ("'P", @{sort type});
wenzelm@56254
   171
    val rT' = TVar (("'P", 0), @{sort type});
berghofe@13467
   172
    (* For constructor cname with argument types Ts: a fresh realizer variable
       "r" ^ base name of cname, and the premise
       !!xs. y = cname xs ==> P (r xs). *)
berghofe@13467
   173
    fun make_casedist_prem T (cname, cargs) =
berghofe@13467
   174
      let
blanchet@58112
   175
        val Ts = map (Old_Datatype_Aux.typ_of_dtyp descr) cargs;
blanchet@58112
   176
        val frees = Name.variant_list ["P", "y"] (Old_Datatype_Prop.make_tnames Ts) ~~ Ts;
berghofe@13467
   177
        val free_ts = map Free frees;
wenzelm@30364
   178
        val r = Free ("r" ^ Long_Name.base_name cname, Ts ---> rT)
wenzelm@45700
   179
      in
wenzelm@46215
   180
        (r, fold_rev Logic.all free_ts
wenzelm@46215
   181
          (Logic.mk_implies (HOLogic.mk_Trueprop
wenzelm@46215
   182
            (HOLogic.mk_eq (Free ("y", T), list_comb (Const (cname, Ts ---> T), free_ts))),
wenzelm@46215
   183
              HOLogic.mk_Trueprop (Free ("P", rT --> HOLogic.boolT) $
wenzelm@46215
   184
                list_comb (r, free_ts)))))
berghofe@13467
   185
      end;
berghofe@13467
   186
    (* Constructors of this datatype.  The binding is partial: it raises Bind if
       index is not a key of descr. *)
haftmann@17521
   187
    val SOME (_, _, constrs) = AList.lookup (op =) descr index;
blanchet@58112
   188
    val T = nth (Old_Datatype_Aux.get_rec_types descr) index;
berghofe@13467
   189
    val (rs, prems) = split_list (map (make_casedist_prem T) constrs);
    (* The case combinator: one realizer per constructor, then the datatype value. *)
berghofe@13467
   190
    val r = Const (case_name, map fastype_of rs ---> T --> rT);
berghofe@13467
   191
wenzelm@35845
   192
    val y = Var (("y", 0), Logic.varifyT_global T);
berghofe@13467
   193
    val y' = Free ("y", T);
berghofe@13467
   194
    (* Prove P (case_name rs y) from the premises: case analysis on y via exhaust
       instantiated to y', then simplification with case_rewrites and the
       premises. *)
wenzelm@45700
   195
    val thm =
wenzelm@54883
   196
      Goal.prove_internal ctxt (map cert prems)
wenzelm@45700
   197
        (cert (HOLogic.mk_Trueprop (Free ("P", rT --> HOLogic.boolT) $ list_comb (r, rs @ [y']))))
wenzelm@45700
   198
        (fn prems =>
wenzelm@45700
   199
           EVERY [
wenzelm@45700
   200
            rtac (cterm_instantiate [(cert y, cert y')] exhaust) 1,
wenzelm@45700
   201
            ALLGOALS (EVERY'
wenzelm@54895
   202
              [asm_simp_tac (put_simpset HOL_basic_ss ctxt addsimps case_rewrites),
wenzelm@54895
   203
               resolve_tac prems, asm_simp_tac (put_simpset HOL_basic_ss ctxt)])])
wenzelm@38977
   204
      |> Drule.export_without_context;
berghofe@13467
   205
    (* Store the theorem under the root path as <exh_name>_P_correctness. *)
wenzelm@36744
   206
    val exh_name = Thm.derivation_name exhaust;
haftmann@18358
   207
    val (thm', thy') = thy
wenzelm@30435
   208
      |> Sign.root_path
wenzelm@39557
   209
      |> Global_Theory.store_thm (Binding.qualified_name (exh_name ^ "_P_correctness"), thm)
wenzelm@24712
   210
      ||> Sign.restore_naming thy;
berghofe@13467
   211
    (* Correctness proof for the "P" realizer: abstract over each premise and its
       realizer variable, applied to the proof of thm'. *)
berghofe@13725
   212
    val P = Var (("P", 0), rT' --> HOLogic.boolT);
wenzelm@45700
   213
    val prf =
wenzelm@45700
   214
      Extraction.abs_corr_shyps thy' exhaust ["P"] [y, P]
wenzelm@45700
   215
        (fold_rev (fn (p, r) => fn prf =>
wenzelm@45700
   216
            Proofterm.forall_intr_proof' (Logic.varify_global r)
wenzelm@45700
   217
              (AbsP ("H", SOME (Logic.varify_global p), prf)))
wenzelm@45700
   218
          (prems ~~ rs)
wenzelm@45700
   219
          (Proofterm.proof_combP
wenzelm@45700
   220
            (Reconstruct.proof_of thm', map PBound (length prems - 1 downto 0))));
    (* Correctness proof for the realizer without computational content, built
       from the proof of exhaust. *)
wenzelm@45700
   221
    val prf' =
wenzelm@45700
   222
      Extraction.abs_corr_shyps thy' exhaust []
wenzelm@45700
   223
        (map Var (Term.add_vars (prop_of exhaust) [])) (Reconstruct.proof_of exhaust);
    (* Realizer term: %y r1 ... rn. case_name r1 ... rn y. *)
wenzelm@45700
   224
    val r' =
wenzelm@45700
   225
      Logic.varify_global (Abs ("y", T,
wenzelm@46219
   226
        (fold_rev (Term.abs o dest_Free) rs
wenzelm@46219
   227
          (list_comb (r, map Bound ((length rs - 1 downto 0) @ [length rs]))))));
wenzelm@45700
   228
  in
wenzelm@45700
   229
    Extraction.add_realizers_i
wenzelm@45700
   230
      [(exh_name, (["P"], r', prf)),
wenzelm@45700
   231
       (exh_name, ([], Extraction.nullt, prf'))] thy'
berghofe@13467
   232
  end;
berghofe@13467
   233
haftmann@31668
   234
(* Add realizers for the induction and case-analysis rules of the datatypes named
   in names.  Nothing happens when proof terms are disabled or when names is
   empty.  make_ind is applied, with the info of the first datatype, to every
   subset of the indices 0 .. length descr - 1 produced by subsets.
   make_casedists is applied to the info of every named datatype. *)
fun add_dt_realizers config names thy =
  if not (Proofterm.proofs_enabled ()) orelse null names then thy
  else
    let
      val _ = Old_Datatype_Aux.message config "Adding realizers for induction and case analysis ...";
      val infos = map (BNF_LFP_Compat.the_info thy BNF_LFP_Compat.Unfold_Nesting) names;
      (* names is non-empty here, hence so is infos; this replaces the former
         non-exhaustive binding that raised Bind on an empty list *)
      val info = hd infos;
    in
      thy
      |> fold_rev (make_ind info) (subsets 0 (length (#descr info) - 1))
      |> fold_rev make_casedists infos
    end;
berghofe@13467
   246
blanchet@58274
   247
(* Register add_dt_realizers as a datatype interpretation under realizer_plugin. *)
val _ =
  add_dt_realizers
  |> BNF_LFP_Compat.interpretation realizer_plugin BNF_LFP_Compat.Unfold_Nesting
  |> Theory.setup;
berghofe@13467
   249
berghofe@13467
   250
end;