src/HOL/Tools/datatype_realizer.ML
author wenzelm
Sat Nov 15 21:31:17 2008 +0100 (2008-11-15)
changeset 28799 ee65e7d043fc
parent 26626 c6231d64d264
child 28814 463c9e9111ae
permissions -rw-r--r--
Thm.proof_of returns proof_body;
berghofe@13467
     1
(*  Title:      HOL/Tools/datatype_realizer.ML
berghofe@13467
     2
    ID:         $Id$
berghofe@13467
     3
    Author:     Stefan Berghofer, TU Muenchen
berghofe@13467
     4
berghofe@13467
     5
Program extraction from proofs involving datatypes:
berghofe@13467
     6
Realizers for induction and case analysis
berghofe@13467
     7
*)
berghofe@13467
     8
berghofe@13467
     9
(*Realizers for program extraction from proofs about datatypes:
  induction and case analysis.*)
signature DATATYPE_REALIZER =
sig
  (*theory setup: registers add_dt_realizers with the datatype package*)
  val setup: theory -> theory
  (*add induction and case-analysis realizers for the named datatypes*)
  val add_dt_realizers: string list -> theory -> theory
end;
berghofe@13467
    14
berghofe@13467
    15
structure DatatypeRealizer : DATATYPE_REALIZER =
berghofe@13467
    16
struct
berghofe@13467
    17
berghofe@13467
    18
open DatatypeAux;
berghofe@13467
    19
berghofe@13467
    20
(*All sublists of [i, i+1, ..., j], each kept in increasing order.
  Sublists containing the smallest element come first (recursively), so
  the full list is first and [] is last.  Yields [[]] when i > j.*)
fun subsets i j =
  let
    (*extend acc, the sublists of [k+1 .. j], to the sublists of [k .. j]*)
    fun grow k acc =
      if k < i then acc
      else grow (k - 1) (map (fn ks => k :: ks) acc @ acc)
  in grow j [[]] end;
berghofe@13467
    24
berghofe@13467
    25
(*Proof-level universal abstraction of prf over the variable t.
  t must be a Var or a Free; any other term raises Match.*)
fun forall_intr_prf (t, prf) =
  let
    fun name_and_type (Var ((name, _), ty)) = (name, ty)
      | name_and_type (Free (name, ty)) = (name, ty);
    val (name, ty) = name_and_type t;
  in Abst (name, SOME ty, Proofterm.prf_abstract_over t prf) end;
berghofe@13467
    28
berghofe@13467
    29
(*Full proof term of thm, reconstructed from the proof recorded in its
  proof body.*)
fun prf_of thm =
  let
    val thy = Thm.theory_of_thm thm;
    val prop = Thm.prop_of thm;
    val body = Thm.proof_of thm;
  in Reconstruct.reconstruct_proof thy prop (Proofterm.proof_of body) end;
berghofe@13467
    32
berghofe@13467
    33
(*Apply the term-variable instantiation inst to every term inside a proof;
  types are mapped by the identity.*)
fun prf_subst_vars inst =
  let val subst_tm = subst_vars ([], inst)
  in Proofterm.map_proof_terms subst_tm (fn T => T) end;
berghofe@13467
    35
berghofe@13467
    36
(*true iff the result type of t, after all argument types, is unit*)
fun is_unit t = (body_type (fastype_of t) = HOLogic.unitT);
berghofe@13467
    37
berghofe@13725
    38
(*name of the type constructor of T; the empty string for any other type*)
fun tname_of T =
  (case T of
    Type (name, _) => name
  | _ => "");
berghofe@13725
    40
berghofe@13467
    41
(*the "realizes" constant for realizers of type T: T => bool => bool*)
fun mk_realizes T = Const ("realizes", [T, HOLogic.boolT] ---> HOLogic.boolT);
berghofe@13467
    42
haftmann@24699
    43
(*Realizer for the induction rule of the (possibly mutual) datatype
  described by descr.  is: the indices into descr of those datatypes whose
  induction predicate gets a realizer argument; for all other indices the
  realizer type is unit.  Proves a correctness theorem for the realizer,
  stores it as <ind_name>_<predicate names>_correctness, and registers the
  realizer with Extraction.add_realizers_i.*)
fun make_ind sorts ({descr, rec_names, rec_rewrites, induction, ...} : datatype_info) is thy =
berghofe@13467
    44
  let
berghofe@13467
    45
    (*one type per entry of descr*)
    val recTs = get_rec_types descr sorts;
berghofe@13467
    46
    (*predicate names: "P" for a single datatype, "P1" ... "Pn" otherwise*)
    val pnames = if length descr = 1 then ["P"]
berghofe@13467
    47
      else map (fn i => "P" ^ string_of_int i) (1 upto length descr);
berghofe@13467
    48
berghofe@13467
    49
    (*realizer result type per datatype: type variable '<Pname> if its index
      is in is, unit otherwise*)
    val rec_result_Ts = map (fn ((i, _), P) =>
berghofe@13467
    50
      if i mem is then TFree ("'" ^ P, HOLogic.typeS) else HOLogic.unitT)
berghofe@13467
    51
        (descr ~~ pnames);
berghofe@13467
    52
berghofe@13467
    53
    (*predicate application: P_i r x if i is in is (r : T, x : U),
      otherwise just P_i x*)
    fun make_pred i T U r x =
berghofe@13467
    54
      if i mem is then
skalberg@15570
    55
        Free (List.nth (pnames, i), T --> U --> HOLogic.boolT) $ r $ x
skalberg@15570
    56
      else Free (List.nth (pnames, i), U --> HOLogic.boolT) $ x;
berghofe@13467
    57
berghofe@13467
    58
    (*quantify t over s : T only if i is in is*)
    fun mk_all i s T t =
berghofe@13467
    59
      if i mem is then list_all_free ([(s, T)], t) else t;
berghofe@13467
    60
skalberg@15570
    61
    (*For every constructor of every datatype: a premise of the correctness
      statement, paired with the step function f' passed to the recursors.
      j counts constructors across all datatypes, starting from 1, and is the
      index given to mk_Free "f".*)
    val (prems, rec_fns) = split_list (List.concat (snd (foldl_map
berghofe@13467
    62
      (fn (j, ((i, (_, _, constrs)), T)) => foldl_map (fn (j, (cname, cargs)) =>
berghofe@13467
    63
        let
berghofe@13467
    64
          val Ts = map (typ_of_dtyp descr sorts) cargs;
wenzelm@20071
    65
          val tnames = Name.variant_list pnames (DatatypeProp.make_tnames Ts);
skalberg@15570
    66
          (*recursive constructor arguments: (((dtyp, name), type)) triples*)
          val recs = List.filter (is_rec_type o fst o fst) (cargs ~~ tnames ~~ Ts);
berghofe@13467
    67
          val frees = tnames ~~ Ts;
berghofe@13467
    68
berghofe@13467
    69
          (*mk_prems vs recs: vs accumulates the constructor arguments followed
            by one realizer variable r<s> per recursive argument; unit-typed
            entries of vs are not passed to f*)
          fun mk_prems vs [] = 
berghofe@13467
    70
                let
skalberg@15570
    71
                  val rT = List.nth (rec_result_Ts, i);
berghofe@13467
    72
                  val vs' = filter_out is_unit vs;
berghofe@13467
    73
                  val f = mk_Free "f" (map fastype_of vs' ---> rT) j;
wenzelm@18929
    74
                  (*step function: abstracts over all of vs; body is f vs' if
                    i is in is, otherwise the unit value*)
                  val f' = Envir.eta_contract (list_abs_free
berghofe@13467
    75
                    (map dest_Free vs, if i mem is then list_comb (f, vs')
berghofe@13467
    76
                      else HOLogic.unit));
berghofe@13467
    77
                in (HOLogic.mk_Trueprop (make_pred i rT T (list_comb (f, vs'))
berghofe@13467
    78
                  (list_comb (Const (cname, Ts ---> T), map Free frees))), f')
berghofe@13467
    79
                end
berghofe@13641
    80
            (*recursive argument s : T with body datatype index k: quantify
              over its realizer r<s> (if k is in is) and assume
              !!x.. P_k (r<s> x..) (s x..) as a premise*)
            | mk_prems vs (((dt, s), T) :: ds) = 
berghofe@13467
    81
                let
berghofe@13641
    82
                  val k = body_index dt;
berghofe@13641
    83
                  val (Us, U) = strip_type T;
berghofe@13641
    84
                  val i = length Us;
skalberg@15570
    85
                  val rT = List.nth (rec_result_Ts, k);
berghofe@13641
    86
                  val r = Free ("r" ^ s, Us ---> rT);
berghofe@13467
    87
                  val (p, f) = mk_prems (vs @ [r]) ds
berghofe@13641
    88
                in (mk_all k ("r" ^ s) (Us ---> rT) (Logic.mk_implies
berghofe@13641
    89
                  (list_all (map (pair "x") Us, HOLogic.mk_Trueprop
berghofe@13641
    90
                    (make_pred k rT U (app_bnds r i)
berghofe@13641
    91
                      (app_bnds (Free (s, T)) i))), p)), f)
berghofe@13467
    92
                end
berghofe@13467
    93
berghofe@13467
    94
        in (j + 1,
berghofe@13467
    95
          apfst (curry list_all_free frees) (mk_prems (map Free frees) recs))
berghofe@13467
    96
        end) (j, constrs)) (1, descr ~~ recTs))));
berghofe@13467
    97
 
berghofe@13467
    98
    (*select the component for datatype j from the right-nested tuple r
      built (via foldr1 HOLogic.mk_prod) over the indices in is*)
    fun mk_proj j [] t = t
berghofe@13467
    99
      | mk_proj j (i :: is) t = if null is then t else
wenzelm@23577
   100
          if (j: int) = i then HOLogic.mk_fst t
berghofe@13467
   101
          else mk_proj j is (HOLogic.mk_snd t);
berghofe@13467
   102
berghofe@13467
   103
    val tnames = DatatypeProp.make_tnames recTs;
berghofe@13467
   104
    val fTs = map fastype_of rec_fns;
berghofe@13467
   105
    (*instances for the induction predicates: %x. make_pred i U T
      (rec_i f.. x) x, one per datatype*)
    val ps = map (fn ((((i, _), T), U), s) => Abs ("x", T, make_pred i U T
berghofe@13467
   106
      (list_comb (Const (s, fTs ---> T --> U), rec_fns) $ Bound 0) (Bound 0)))
berghofe@13467
   107
        (descr ~~ recTs ~~ rec_result_Ts ~~ rec_names);
berghofe@13467
   108
    (*realizer: tuple of the recursor applications for the datatypes in is,
      or Extraction.nullt if is is empty*)
    val r = if null is then Extraction.nullt else
skalberg@15570
   109
      foldr1 HOLogic.mk_prod (List.mapPartial (fn (((((i, _), T), U), s), tname) =>
skalberg@15531
   110
        if i mem is then SOME
berghofe@13467
   111
          (list_comb (Const (s, fTs ---> T --> U), rec_fns) $ Free (tname, T))
skalberg@15531
   112
        else NONE) (descr ~~ recTs ~~ rec_result_Ts ~~ rec_names ~~ tnames));
berghofe@13467
   113
    (*conclusion: conjunction over all datatypes of P_i (component of r) x_i*)
    val concl = HOLogic.mk_Trueprop (foldr1 (HOLogic.mk_binop "op &")
berghofe@13467
   114
      (map (fn ((((i, _), T), U), tname) =>
berghofe@13467
   115
        make_pred i U T (mk_proj i is r) (Free (tname, T)))
berghofe@13467
   116
          (descr ~~ recTs ~~ rec_result_Ts ~~ tnames)));
wenzelm@22578
   117
    val cert = cterm_of thy;
berghofe@13467
   118
    (*pair the predicate heads of the induction rule's conclusion with ps*)
    val inst = map (pairself cert) (map head_of (HOLogic.dest_conj
berghofe@13467
   119
      (HOLogic.dest_Trueprop (concl_of induction))) ~~ ps);
berghofe@13467
   120
wenzelm@17959
   121
    (*prove prems ==> concl using the instantiated induction rule, unfolding
      fst/snd, o_def and the recursor equations rec_rewrites*)
    val thm = OldGoals.simple_prove_goal_cterm (cert (Logic.list_implies (prems, concl)))
berghofe@13467
   122
      (fn prems =>
berghofe@13467
   123
         [rewrite_goals_tac (map mk_meta_eq [fst_conv, snd_conv]),
berghofe@13467
   124
          rtac (cterm_instantiate inst induction) 1,
wenzelm@23590
   125
          ALLGOALS ObjectLogic.atomize_prems_tac,
haftmann@26359
   126
          rewrite_goals_tac (@{thm o_def} :: map mk_meta_eq rec_rewrites),
berghofe@13467
   127
          REPEAT ((resolve_tac prems THEN_ALL_NEW (fn i =>
berghofe@13467
   128
            REPEAT (etac allE i) THEN atac i)) 1)]);
berghofe@13467
   129
wenzelm@21646
   130
    val ind_name = Thm.get_name induction;
skalberg@15570
   131
    (*names of the predicates that get a realizer argument*)
    val vs = map (fn i => List.nth (pnames, i)) is;
haftmann@18358
   132
    (*store the theorem as <ind_name>_<vs>_correctness under an absolute
      name path, then restore thy's naming*)
    val (thm', thy') = thy
wenzelm@24712
   133
      |> Sign.absolute_path
wenzelm@26481
   134
      |> PureThy.store_thm (space_implode "_" (ind_name :: vs @ ["correctness"]), thm)
wenzelm@24712
   135
      ||> Sign.restore_naming thy;
berghofe@13467
   136
wenzelm@20286
   137
    (*schematic variables of the varified generic induction rule (ivs) and
      of the stored theorem (rvs)*)
    val ivs = rev (Term.add_vars (Logic.varify (DatatypeProp.make_ind [descr] sorts)) []);
wenzelm@22691
   138
    val rvs = rev (Thm.fold_terms Term.add_vars thm' []);
berghofe@13725
   139
    (*vars of ivs whose body type is not "set" or "bool" -- presumably this
      excludes the induction predicates; they become parameters of r'*)
    val ivs1 = map Var (filter_out (fn (_, T) =>
berghofe@13725
   140
      tname_of (body_type T) mem ["set", "bool"]) ivs);
haftmann@17521
   141
    (*every var of ivs, retyped as it occurs in thm' (Option if absent)*)
    val ivs2 = map (fn (ixn, _) => Var (ixn, valOf (AList.lookup (op =) rvs ixn))) ivs;
berghofe@13467
   142
skalberg@15574
   143
    (*correctness proof term: thm' applied to the premise hypotheses,
      abstracted over each premise (and over f<j> as a Var when the step
      function's head is a Free), then over ivs2*)
    val prf = foldr forall_intr_prf
skalberg@15574
   144
     (foldr (fn ((f, p), prf) =>
berghofe@13467
   145
        (case head_of (strip_abs_body f) of
berghofe@13467
   146
           Free (s, T) =>
wenzelm@19806
   147
             let val T' = Logic.varifyT T
skalberg@15531
   148
             in Abst (s, SOME T', Proofterm.prf_abstract_over
skalberg@15531
   149
               (Var ((s, 0), T')) (AbsP ("H", SOME p, prf)))
berghofe@13467
   150
             end
skalberg@15531
   151
         | _ => AbsP ("H", SOME p, prf)))
skalberg@15574
   152
           (Proofterm.proof_combP
skalberg@15574
   153
             (prf_of thm', map PBound (length prems - 1 downto 0))) (rec_fns ~~ prems_of thm)) ivs2;
berghofe@13467
   154
skalberg@15574
   155
    (*generalize the realizer over ivs1 and the non-unit step-function
      heads, then varify; r is kept as is when is is empty*)
    val r' = if null is then r else Logic.varify (foldr (uncurry lambda)
skalberg@15574
   156
      r (map Logic.unvarify ivs1 @ filter_out is_unit
skalberg@15574
   157
          (map (head_of o strip_abs_body) rec_fns)));
berghofe@13467
   158
berghofe@13467
   159
  in Extraction.add_realizers_i [(ind_name, (vs, r', prf))] thy' end;
berghofe@13467
   160
berghofe@13467
   161
haftmann@24699
   162
(*Realizer for the case-analysis (exhaustion) rule of the datatype at
  position index of descr.  The realizer maps y to the case combinator
  case_name applied to one realizer r<C> per constructor C.  Proves and
  stores <exh_name>_P_correctness and registers two realizers for
  exh_name: one for predicate "P", and Extraction.nullt with an empty
  predicate list.*)
fun make_casedists sorts ({index, descr, case_name, case_rewrites, exhaustion, ...} : datatype_info) thy =
berghofe@13467
   163
  let
wenzelm@19806
   164
    val cert = cterm_of thy;
berghofe@13467
   165
    (*realizer result type: free 'P while proving, schematic ?'P in the
      proof term*)
    val rT = TFree ("'P", HOLogic.typeS);
berghofe@13467
   166
    val rT' = TVar (("'P", 0), HOLogic.typeS);
berghofe@13467
   167
berghofe@13467
   168
    (*for constructor cname: the realizer variable r<C> and the premise
      !!xs. y = cname xs ==> P (r<C> xs)*)
    fun make_casedist_prem T (cname, cargs) =
berghofe@13467
   169
      let
berghofe@13467
   170
        val Ts = map (typ_of_dtyp descr sorts) cargs;
wenzelm@20071
   171
        val frees = Name.variant_list ["P", "y"] (DatatypeProp.make_tnames Ts) ~~ Ts;
berghofe@13467
   172
        val free_ts = map Free frees;
berghofe@13467
   173
        val r = Free ("r" ^ NameSpace.base cname, Ts ---> rT)
berghofe@13467
   174
      in (r, list_all_free (frees, Logic.mk_implies (HOLogic.mk_Trueprop
berghofe@13467
   175
        (HOLogic.mk_eq (Free ("y", T), list_comb (Const (cname, Ts ---> T), free_ts))),
berghofe@13467
   176
          HOLogic.mk_Trueprop (Free ("P", rT --> HOLogic.boolT) $
berghofe@13467
   177
            list_comb (r, free_ts)))))
berghofe@13467
   178
      end;
berghofe@13467
   179
haftmann@17521
   180
    (*constructors of the datatype at position index (Bind if absent)*)
    val SOME (_, _, constrs) = AList.lookup (op =) descr index;
skalberg@15570
   181
    val T = List.nth (get_rec_types descr sorts, index);
berghofe@13467
   182
    val (rs, prems) = split_list (map (make_casedist_prem T) constrs);
berghofe@13467
   183
    (*case combinator, taking the constructor realizers and then y*)
    val r = Const (case_name, map fastype_of rs ---> T --> rT);
berghofe@13467
   184
wenzelm@19806
   185
    val y = Var (("y", 0), Logic.legacy_varifyT T);
berghofe@13467
   186
    val y' = Free ("y", T);
berghofe@13467
   187
wenzelm@17959
   188
    (*prove prems ==> P (case rs y) by the exhaustion rule instantiated at y,
      simplifying with the case equations case_rewrites*)
    val thm = OldGoals.prove_goalw_cterm [] (cert (Logic.list_implies (prems,
berghofe@13467
   189
      HOLogic.mk_Trueprop (Free ("P", rT --> HOLogic.boolT) $
berghofe@13467
   190
        list_comb (r, rs @ [y'])))))
berghofe@13467
   191
      (fn prems =>
berghofe@13467
   192
         [rtac (cterm_instantiate [(cert y, cert y')] exhaustion) 1,
berghofe@13467
   193
          ALLGOALS (EVERY'
berghofe@13467
   194
            [asm_simp_tac (HOL_basic_ss addsimps case_rewrites),
berghofe@13467
   195
             resolve_tac prems, asm_simp_tac HOL_basic_ss])]);
berghofe@13467
   196
wenzelm@21646
   197
    val exh_name = Thm.get_name exhaustion;
haftmann@18358
   198
    (*store as <exh_name>_P_correctness under an absolute name path, then
      restore thy's naming*)
    val (thm', thy') = thy
wenzelm@24712
   199
      |> Sign.absolute_path
wenzelm@26481
   200
      |> PureThy.store_thm (exh_name ^ "_P_correctness", thm)
wenzelm@24712
   201
      ||> Sign.restore_naming thy;
berghofe@13467
   202
berghofe@13725
   203
    val P = Var (("P", 0), rT' --> HOLogic.boolT);
berghofe@13467
   204
    (*correctness proof term: thm' applied to the premise hypotheses,
      abstracted over each (varified) r<C> and its premise, then over P
      and y*)
    val prf = forall_intr_prf (y, forall_intr_prf (P,
skalberg@15574
   205
      foldr (fn ((p, r), prf) =>
wenzelm@19806
   206
        forall_intr_prf (Logic.legacy_varify r, AbsP ("H", SOME (Logic.varify p),
skalberg@15574
   207
          prf))) (Proofterm.proof_combP (prf_of thm',
skalberg@15574
   208
            map PBound (length prems - 1 downto 0))) (prems ~~ rs)));
wenzelm@19806
   209
    (*realizer: %y rs. case_name rs y, varified*)
    val r' = Logic.legacy_varify (Abs ("y", Logic.legacy_varifyT T,
berghofe@13725
   210
      list_abs (map dest_Free rs, list_comb (r,
berghofe@13725
   211
        map Bound ((length rs - 1 downto 0) @ [length rs])))));
berghofe@13467
   212
berghofe@13467
   213
  in Extraction.add_realizers_i
berghofe@13467
   214
    [(exh_name, (["P"], r', prf)),
berghofe@13725
   215
     (exh_name, ([], Extraction.nullt, prf_of exhaustion))] thy'
berghofe@13467
   216
  end;
berghofe@13467
   217
haftmann@24699
   218
(*Add realizers for induction and case analysis of the datatypes in names.
  Returns thy unchanged unless proof terms are recorded at level 2 or
  higher (!Proofterm.proofs), or when names is empty.  The sorts and descr
  of the first datatype are used for the induction realizers, which are
  built for every subset of the indices 0 .. length descr - 1; a
  case-analysis realizer is added for each datatype.*)
fun add_dt_realizers names thy =
  if ! Proofterm.proofs < 2 orelse null names then thy
  else
    let
      val _ = message "Adding realizers for induction and case analysis ...";
      val infos = map (DatatypePackage.the_datatype thy) names;
      (*safe: names, and hence infos, is non-empty here*)
      val info = hd infos;
    in
      thy
      |> fold_rev (make_ind (#sorts info) info) (subsets 0 (length (#descr info) - 1))
      |> fold_rev (make_casedists (#sorts info)) infos
    end;
berghofe@13467
   229
wenzelm@24711
   230
(*Theory setup: registers add_dt_realizers via DatatypePackage.interpretation.
  NOTE(review): presumably it is then invoked with the names of each newly
  defined datatype group -- confirm against DatatypePackage.*)
val setup = DatatypePackage.interpretation add_dt_realizers;
berghofe@13467
   231
berghofe@13467
   232
end;