src/HOL/Tools/Datatype/datatype_realizer.ML
author haftmann
Fri Aug 27 10:56:46 2010 +0200 (2010-08-27)
changeset 38795 848be46708dc
parent 37236 739d8b9c59da
child 38977 e71e2c56479c
permissions -rw-r--r--
formerly unnamed infix conjunction and disjunction now named HOL.conj and HOL.disj
haftmann@33968
     1
(*  Title:      HOL/Tools/Datatype/datatype_realizer.ML
berghofe@13467
     2
    Author:     Stefan Berghofer, TU Muenchen
berghofe@13467
     3
haftmann@33968
     4
Program extraction from proofs involving datatypes:
haftmann@33968
     5
realizers for induction and case analysis.
berghofe@13467
     6
*)
berghofe@13467
     7
berghofe@13467
     8
(* Public interface of the datatype realizer module. *)
signature DATATYPE_REALIZER =
sig
  (* theory transformer: adds realizers for the induction and case analysis
     rules of the datatypes with the given names *)
  val add_dt_realizers:
    Datatype.config -> string list -> theory -> theory
  (* theory setup function of this module *)
  val setup:
    theory -> theory
end;
berghofe@13467
    13
haftmann@33968
    14
structure Datatype_Realizer : DATATYPE_REALIZER =
berghofe@13467
    15
struct
berghofe@13467
    16
haftmann@33968
    17
open Datatype_Aux;
berghofe@13467
    18
wenzelm@35364
    19
(* All subsets of the integer interval i..j, each as an ascending list.
   Subsets containing i are listed first, followed by those without i;
   an empty interval (i > j) yields just the empty subset. *)
fun subsets i j =
  if i > j then [[]]
  else
    let
      (* subsets of the remaining interval i+1..j *)
      val rest = subsets (i + 1) j;
    in
      (* prepend i to every remaining subset, then append the remaining subsets *)
      List.foldr (fn (ks, acc) => (i :: ks) :: acc) rest rest
    end;
berghofe@13467
    24
berghofe@13467
    25
(* Reconstructed proof term of a theorem: feeds the theorem's theory,
   proposition and proof to Reconstruct.reconstruct_proof. *)
fun prf_of thm =
  let
    val thy = Thm.theory_of_thm thm;
    val prop = Thm.prop_of thm;
    val prf = Thm.proof_of thm;
  in Reconstruct.reconstruct_proof thy prop prf end;
berghofe@13467
    27
berghofe@13467
    28
(* Tests whether the result type of term t (after stripping all argument
   types with strip_type) is the HOL unit type. *)
fun is_unit t =
  let val (_, resT) = strip_type (fastype_of t)
  in resT = HOLogic.unitT end;
berghofe@13467
    29
berghofe@13725
    30
(* Name of the head type constructor of a type; the empty string for
   anything that is not a Type application. *)
fun tname_of T =
  (case T of
    Type (name, _) => name
  | _ => "");
berghofe@13725
    32
haftmann@32727
    33
(* make_ind sorts info is thy

   Proves a correctness theorem for the induction rule `induct` of the datatype
   description `descr` and registers a realizer for that rule via
   Extraction.add_realizers_i.

   sorts - passed on to get_rec_types, typ_of_dtyp and Datatype_Prop.make_ind
   info  - datatype info record; only descr, rec_names, rec_rewrites and
           induct are used here
   is    - list of descr indices (also used as zero-based positions in pnames);
           predicates at these indices take a realizer argument, all other
           predicates are realized by unit
   thy   - theory to extend

   Result: the theory containing the stored correctness theorem and the
   realizer (vs, r', prf) for ind_name. *)
fun make_ind sorts ({descr, rec_names, rec_rewrites, induct, ...} : info) is thy =
berghofe@13467
    34
  let
berghofe@13467
    35
    val recTs = get_rec_types descr sorts;
berghofe@13467
    36
    (* predicate names: "P" for a single descr entry, otherwise "P1", "P2", ... *)
    val pnames = if length descr = 1 then ["P"]
berghofe@13467
    37
      else map (fn i => "P" ^ string_of_int i) (1 upto length descr);
berghofe@13467
    38
berghofe@13467
    39
    (* result type per descr entry: a type variable named after the predicate
       for indices in `is`, the unit type otherwise *)
    val rec_result_Ts = map (fn ((i, _), P) =>
haftmann@36692
    40
      if member (op =) is i then TFree ("'" ^ P, HOLogic.typeS) else HOLogic.unitT)
berghofe@13467
    41
        (descr ~~ pnames);
berghofe@13467
    42
berghofe@13467
    43
    (* application of predicate number i: for indices in `is` the predicate
       takes the realizer r as additional first argument *)
    fun make_pred i T U r x =
haftmann@36692
    44
      if member (op =) is i then
skalberg@15570
    45
        Free (List.nth (pnames, i), T --> U --> HOLogic.boolT) $ r $ x
skalberg@15570
    46
      else Free (List.nth (pnames, i), U --> HOLogic.boolT) $ x;
berghofe@13467
    47
berghofe@13467
    48
    (* quantify t over the free variable (s, T) only for indices in `is` *)
    fun mk_all i s T t =
haftmann@36692
    49
      if member (op =) is i then list_all_free ([(s, T)], t) else t;
berghofe@13467
    50
haftmann@31458
    51
    (* for every constructor of every descr entry: one premise of the
       correctness theorem paired with one function term; the counter j is
       threaded through fold_map, starts at 1 and is incremented per constructor *)
    val (prems, rec_fns) = split_list (flat (fst (fold_map
haftmann@31458
    52
      (fn ((i, (_, _, constrs)), T) => fold_map (fn (cname, cargs) => fn j =>
berghofe@13467
    53
        let
berghofe@13467
    54
          val Ts = map (typ_of_dtyp descr sorts) cargs;
haftmann@33968
    55
          val tnames = Name.variant_list pnames (Datatype_Prop.make_tnames Ts);
haftmann@31458
    56
          val recs = filter (is_rec_type o fst o fst) (cargs ~~ tnames ~~ Ts);
berghofe@13467
    57
          val frees = tnames ~~ Ts;
berghofe@13467
    58
berghofe@13467
    59
          (* base case: all recursive arguments consumed; vs holds the
             constructor arguments followed by the r variables added below.
             f is applied only to the non-unit elements of vs; f' abstracts
             over all of vs and has body HOLogic.unit if i is not in `is` *)
          fun mk_prems vs [] = 
berghofe@13467
    60
                let
haftmann@31458
    61
                  val rT = nth (rec_result_Ts) i;
berghofe@13467
    62
                  val vs' = filter_out is_unit vs;
berghofe@13467
    63
                  val f = mk_Free "f" (map fastype_of vs' ---> rT) j;
wenzelm@18929
    64
                  val f' = Envir.eta_contract (list_abs_free
haftmann@36692
    65
                    (map dest_Free vs, if member (op =) is i then list_comb (f, vs')
berghofe@13467
    66
                      else HOLogic.unit));
berghofe@13467
    67
                in (HOLogic.mk_Trueprop (make_pred i rT T (list_comb (f, vs'))
berghofe@13467
    68
                  (list_comb (Const (cname, Ts ---> T), map Free frees))), f')
berghofe@13467
    69
                end
berghofe@13641
    70
            (* recursive argument s of type T: adds variable "r" ^ s and an
               induction hypothesis about predicate number k in front of the
               remaining premise *)
            | mk_prems vs (((dt, s), T) :: ds) = 
berghofe@13467
    71
                let
berghofe@13641
    72
                  val k = body_index dt;
berghofe@13641
    73
                  val (Us, U) = strip_type T;
berghofe@13641
    74
                  (* NB: this i shadows the descr index i bound above; within
                     this clause it is the number of argument types of T *)
                  val i = length Us;
haftmann@31458
    75
                  val rT = nth (rec_result_Ts) k;
berghofe@13641
    76
                  val r = Free ("r" ^ s, Us ---> rT);
berghofe@13467
    77
                  val (p, f) = mk_prems (vs @ [r]) ds
berghofe@13641
    78
                in (mk_all k ("r" ^ s) (Us ---> rT) (Logic.mk_implies
berghofe@13641
    79
                  (list_all (map (pair "x") Us, HOLogic.mk_Trueprop
berghofe@13641
    80
                    (make_pred k rT U (app_bnds r i)
berghofe@13641
    81
                      (app_bnds (Free (s, T)) i))), p)), f)
berghofe@13467
    82
                end
berghofe@13467
    83
haftmann@31458
    84
        in (apfst (curry list_all_free frees) (mk_prems (map Free frees) recs), j + 1) end)
haftmann@31458
    85
          constrs) (descr ~~ recTs) 1)));
berghofe@13467
    86
 
berghofe@13467
    87
    (* projection for index j out of a tuple with one component per element of
       the index list: fst at the matching position, snd to move on; a tuple
       for a one-element list is returned as is *)
    fun mk_proj j [] t = t
berghofe@13467
    88
      | mk_proj j (i :: is) t = if null is then t else
wenzelm@23577
    89
          if (j: int) = i then HOLogic.mk_fst t
berghofe@13467
    90
          else mk_proj j is (HOLogic.mk_snd t);
berghofe@13467
    91
haftmann@33968
    92
    val tnames = Datatype_Prop.make_tnames recTs;
berghofe@13467
    93
    val fTs = map fastype_of rec_fns;
berghofe@13467
    94
    (* terms substituted for the predicate variables of the induction rule:
       each predicate applied to the recursion combinator on rec_fns *)
    val ps = map (fn ((((i, _), T), U), s) => Abs ("x", T, make_pred i U T
berghofe@13467
    95
      (list_comb (Const (s, fTs ---> T --> U), rec_fns) $ Bound 0) (Bound 0)))
berghofe@13467
    96
        (descr ~~ recTs ~~ rec_result_Ts ~~ rec_names);
berghofe@13467
    97
    (* realizer body: tuple (built with foldr1 HOLogic.mk_prod) of recursion
       combinator applications for the indices in `is`; Extraction.nullt if
       `is` is empty *)
    val r = if null is then Extraction.nullt else
wenzelm@32952
    98
      foldr1 HOLogic.mk_prod (map_filter (fn (((((i, _), T), U), s), tname) =>
haftmann@36692
    99
        if member (op =) is i then SOME
berghofe@13467
   100
          (list_comb (Const (s, fTs ---> T --> U), rec_fns) $ Free (tname, T))
skalberg@15531
   101
        else NONE) (descr ~~ recTs ~~ rec_result_Ts ~~ rec_names ~~ tnames));
haftmann@38795
   102
    (* conclusion: conjunction of all predicates, each applied to its
       projection of r *)
    val concl = HOLogic.mk_Trueprop (foldr1 (HOLogic.mk_binop @{const_name HOL.conj})
berghofe@13467
   103
      (map (fn ((((i, _), T), U), tname) =>
berghofe@13467
   104
        make_pred i U T (mk_proj i is r) (Free (tname, T)))
berghofe@13467
   105
          (descr ~~ recTs ~~ rec_result_Ts ~~ tnames)));
wenzelm@22578
   106
    val cert = cterm_of thy;
berghofe@13467
   107
    (* pairs the heads of the conjuncts in the conclusion of induct with ps *)
    val inst = map (pairself cert) (map head_of (HOLogic.dest_conj
haftmann@32712
   108
      (HOLogic.dest_Trueprop (concl_of induct))) ~~ ps);
berghofe@13467
   109
wenzelm@17959
   110
    (* correctness theorem prems ==> concl, proved with the instantiated
       induction rule after rewriting with fst_conv / snd_conv, then with
       o_def and rec_rewrites *)
    val thm = OldGoals.simple_prove_goal_cterm (cert (Logic.list_implies (prems, concl)))
berghofe@13467
   111
      (fn prems =>
haftmann@37136
   112
         [rewrite_goals_tac (map mk_meta_eq [@{thm fst_conv}, @{thm snd_conv}]),
haftmann@32712
   113
          rtac (cterm_instantiate inst induct) 1,
wenzelm@35625
   114
          ALLGOALS Object_Logic.atomize_prems_tac,
haftmann@26359
   115
          rewrite_goals_tac (@{thm o_def} :: map mk_meta_eq rec_rewrites),
berghofe@13467
   116
          REPEAT ((resolve_tac prems THEN_ALL_NEW (fn i =>
berghofe@13467
   117
            REPEAT (etac allE i) THEN atac i)) 1)]);
berghofe@13467
   118
wenzelm@36744
   119
    val ind_name = Thm.derivation_name induct;
skalberg@15570
   120
    (* names of the predicates selected by `is` *)
    val vs = map (fn i => List.nth (pnames, i)) is;
haftmann@18358
   121
    (* store the theorem under ind_name, the names in vs and "correctness",
       joined by "_", relative to the root path; afterwards the naming of the
       original thy is restored *)
    val (thm', thy') = thy
wenzelm@30435
   122
      |> Sign.root_path
wenzelm@30435
   123
      |> PureThy.store_thm
wenzelm@30435
   124
        (Binding.qualified_name (space_implode "_" (ind_name :: vs @ ["correctness"])), thm)
wenzelm@24712
   125
      ||> Sign.restore_naming thy;
berghofe@13467
   126
wenzelm@35845
   127
    (* ivs: schematic variables of the generic induction statement;
       rvs: schematic variables of the stored theorem;
       ivs1: ivs without those whose body type is bool;
       ivs2: ivs retyped according to rvs *)
    val ivs = rev (Term.add_vars (Logic.varify_global (Datatype_Prop.make_ind [descr] sorts)) []);
wenzelm@22691
   128
    val rvs = rev (Thm.fold_terms Term.add_vars thm' []);
berghofe@37233
   129
    val ivs1 = map Var (filter_out (fn (_, T) =>
berghofe@37233
   130
      @{type_name bool} = tname_of (body_type T)) ivs);
wenzelm@33035
   131
    val ivs2 = map (fn (ixn, _) => Var (ixn, the (AList.lookup (op =) rvs ixn))) ivs;
berghofe@13467
   132
wenzelm@33338
   133
    (* correctness proof term: the proof of thm' applied to its premises,
       abstracted over each premise "H" and, where the head of the function
       term is a Free, also over that function variable *)
    val prf =
berghofe@37233
   134
      Extraction.abs_corr_shyps thy' induct vs ivs2
wenzelm@33338
   135
        (fold_rev (fn (f, p) => fn prf =>
wenzelm@33338
   136
          (case head_of (strip_abs_body f) of
wenzelm@33338
   137
             Free (s, T) =>
wenzelm@35845
   138
               let val T' = Logic.varifyT_global T
wenzelm@33338
   139
               in Abst (s, SOME T', Proofterm.prf_abstract_over
wenzelm@33338
   140
                 (Var ((s, 0), T')) (AbsP ("H", SOME p, prf)))
wenzelm@33338
   141
               end
wenzelm@33338
   142
          | _ => AbsP ("H", SOME p, prf)))
wenzelm@33338
   143
            (rec_fns ~~ prems_of thm)
berghofe@37233
   144
            (Proofterm.proof_combP (prf_of thm', map PBound (length prems - 1 downto 0))));
berghofe@13467
   145
wenzelm@33338
   146
    (* final realizer term: r abstracted over ivs1 and the non-unit heads of
       rec_fns, then varified; left as r if `is` is empty *)
    val r' =
wenzelm@33338
   147
      if null is then r
wenzelm@35845
   148
      else Logic.varify_global (fold_rev lambda
wenzelm@35845
   149
        (map Logic.unvarify_global ivs1 @ filter_out is_unit
wenzelm@33338
   150
            (map (head_of o strip_abs_body) rec_fns)) r);
berghofe@13467
   151
berghofe@13467
   152
  in Extraction.add_realizers_i [(ind_name, (vs, r', prf))] thy' end;
berghofe@13467
   153
berghofe@13467
   154
haftmann@32712
   155
(* make_casedists sorts info thy

   Proves a correctness theorem for the exhaustion rule `exhaust` of the
   datatype at position `index` of `descr`, stores it under the name of the
   rule extended by "_P_correctness", and registers two realizers for the rule
   via Extraction.add_realizers_i: one for realizer variable "P", built from
   the case combinator `case_name`, and one without realizer variables
   (Extraction.nullt).

   sorts - passed on to typ_of_dtyp and get_rec_types
   info  - datatype info record; only index, descr, case_name, case_rewrites
           and exhaust are used here
   thy   - theory to extend *)
fun make_casedists sorts ({index, descr, case_name, case_rewrites, exhaust, ...} : info) thy =
berghofe@13467
   156
  let
wenzelm@19806
   157
    val cert = cterm_of thy;
berghofe@13467
   158
    (* result type of the realizer, as fixed (rT) and schematic (rT') type variable *)
    val rT = TFree ("'P", HOLogic.typeS);
berghofe@13467
   159
    val rT' = TVar (("'P", 0), HOLogic.typeS);
berghofe@13467
   160
berghofe@13467
   161
    (* for constructor cname: a free variable named "r" ^ base name of cname,
       and the premise stating, for all constructor arguments, that y equal to
       the constructor application implies P of r applied to the arguments *)
    fun make_casedist_prem T (cname, cargs) =
berghofe@13467
   162
      let
berghofe@13467
   163
        val Ts = map (typ_of_dtyp descr sorts) cargs;
haftmann@33968
   164
        val frees = Name.variant_list ["P", "y"] (Datatype_Prop.make_tnames Ts) ~~ Ts;
berghofe@13467
   165
        val free_ts = map Free frees;
wenzelm@30364
   166
        val r = Free ("r" ^ Long_Name.base_name cname, Ts ---> rT)
berghofe@13467
   167
      in (r, list_all_free (frees, Logic.mk_implies (HOLogic.mk_Trueprop
berghofe@13467
   168
        (HOLogic.mk_eq (Free ("y", T), list_comb (Const (cname, Ts ---> T), free_ts))),
berghofe@13467
   169
          HOLogic.mk_Trueprop (Free ("P", rT --> HOLogic.boolT) $
berghofe@13467
   170
            list_comb (r, free_ts)))))
berghofe@13467
   171
      end;
berghofe@13467
   172
haftmann@17521
   173
    (* NB: non-exhaustive binding; raises Bind if index has no entry in descr *)
    val SOME (_, _, constrs) = AList.lookup (op =) descr index;
skalberg@15570
   174
    val T = List.nth (get_rec_types descr sorts, index);
berghofe@13467
   175
    val (rs, prems) = split_list (map (make_casedist_prem T) constrs);
berghofe@13467
   176
    (* case combinator constant taking one function per constructor *)
    val r = Const (case_name, map fastype_of rs ---> T --> rT);
berghofe@13467
   177
wenzelm@35845
   178
    (* y as schematic variable (used to instantiate exhaust) and as free variable *)
    val y = Var (("y", 0), Logic.varifyT_global T);
berghofe@13467
   179
    val y' = Free ("y", T);
berghofe@13467
   180
wenzelm@17959
   181
    (* correctness theorem: prems ==> P (case_name rs y), proved by applying
       exhaust with y instantiated to y', then simplifying with case_rewrites
       and resolving with the premises *)
    val thm = OldGoals.prove_goalw_cterm [] (cert (Logic.list_implies (prems,
berghofe@13467
   182
      HOLogic.mk_Trueprop (Free ("P", rT --> HOLogic.boolT) $
berghofe@13467
   183
        list_comb (r, rs @ [y'])))))
berghofe@13467
   184
      (fn prems =>
haftmann@32712
   185
         [rtac (cterm_instantiate [(cert y, cert y')] exhaust) 1,
berghofe@13467
   186
          ALLGOALS (EVERY'
berghofe@13467
   187
            [asm_simp_tac (HOL_basic_ss addsimps case_rewrites),
berghofe@13467
   188
             resolve_tac prems, asm_simp_tac HOL_basic_ss])]);
berghofe@13467
   189
wenzelm@36744
   190
    val exh_name = Thm.derivation_name exhaust;
haftmann@18358
   191
    (* store the theorem relative to the root path, then restore the naming
       of the original thy *)
    val (thm', thy') = thy
wenzelm@30435
   192
      |> Sign.root_path
wenzelm@30435
   193
      |> PureThy.store_thm (Binding.qualified_name (exh_name ^ "_P_correctness"), thm)
wenzelm@24712
   194
      ||> Sign.restore_naming thy;
berghofe@13467
   195
berghofe@13725
   196
    val P = Var (("P", 0), rT' --> HOLogic.boolT);
berghofe@37233
   197
    (* correctness proof for realizer variable "P": the proof of thm' applied
       to its premises, abstracted over each premise "H" and each variable in rs *)
    val prf = Extraction.abs_corr_shyps thy' exhaust ["P"] [y, P]
berghofe@37233
   198
      (fold_rev (fn (p, r) => fn prf =>
berghofe@37233
   199
          Proofterm.forall_intr_proof' (Logic.varify_global r)
berghofe@37233
   200
            (AbsP ("H", SOME (Logic.varify_global p), prf)))
wenzelm@33338
   201
        (prems ~~ rs)
berghofe@37233
   202
        (Proofterm.proof_combP (prf_of thm', map PBound (length prems - 1 downto 0))));
berghofe@37233
   203
    (* correctness proof for the realizer without variables: the proof of
       exhaust itself *)
    val prf' = Extraction.abs_corr_shyps thy' exhaust []
berghofe@37233
   204
      (map Var (Term.add_vars (prop_of exhaust) [])) (prf_of exhaust);
wenzelm@35845
   205
    (* realizer term: abstraction over y and then over rs, whose body applies
       the case combinator to the bound rs followed by the bound y *)
    val r' = Logic.varify_global (Abs ("y", T,
berghofe@13725
   206
      list_abs (map dest_Free rs, list_comb (r,
berghofe@13725
   207
        map Bound ((length rs - 1 downto 0) @ [length rs])))));
berghofe@13467
   208
berghofe@13467
   209
  in Extraction.add_realizers_i
berghofe@13467
   210
    [(exh_name, (["P"], r', prf)),
berghofe@37233
   211
     (exh_name, ([], Extraction.nullt, prf'))] thy'
berghofe@13467
   212
  end;
berghofe@13467
   213
haftmann@31668
   214
(* add_dt_realizers config names thy

   Adds realizers for the induction and case analysis rules of the datatypes
   called `names` to the theory.

   config - passed to `message` for the progress report
   names  - names of the datatypes; their info records are looked up with
            Datatype.the_info
   thy    - theory to extend

   The theory is returned unchanged if the value of Proofterm.proofs is
   below 2, or if `names` is empty.  (The original code bound the first info
   record with the non-exhaustive pattern `info :: _`, which raises Bind for
   an empty list; an empty list now simply means there is nothing to add.)

   Otherwise make_casedists is applied for every datatype, and make_ind is
   applied for every subset of the descr indices 0 .. length descr - 1 of the
   first info record.  fold_rev processes the lists from the right, and the
   induction realizers are added before the case analysis realizers. *)
fun add_dt_realizers config names thy =
  if ! Proofterm.proofs < 2 then thy
  else
    let
      val _ = message config "Adding realizers for induction and case analysis ...";
      val infos = map (Datatype.the_info thy) names;
    in
      (case infos of
        [] => thy   (* no datatypes given: nothing to realize *)
      | info :: _ =>
          thy
          |> fold_rev (make_ind (#sorts info) info) (subsets 0 (length (#descr info) - 1))
          |> fold_rev (make_casedists (#sorts info)) infos)
    end;
berghofe@13467
   225
haftmann@31723
   226
(* Theory setup: hands add_dt_realizers to Datatype.interpretation.
   NOTE(review): presumably this makes add_dt_realizers run for every group of
   datatypes defined in the theory -- confirm against Datatype.interpretation. *)
val setup = Datatype.interpretation add_dt_realizers;
berghofe@13467
   227
berghofe@13467
   228
end;