src/HOL/Tools/Datatype/datatype_realizer.ML
author wenzelm
Sun Mar 07 12:19:47 2010 +0100 (2010-03-07)
changeset 35625 9c818cab0dd0
parent 35364 b8c62d60195c
child 35845 e5980f0ad025
permissions -rw-r--r--
modernized structure Object_Logic;
     1 (*  Title:      HOL/Tools/Datatype/datatype_realizer.ML
     2     Author:     Stefan Berghofer, TU Muenchen
     3 
     4 Program extraction from proofs involving datatypes:
     5 realizers for induction and case analysis.
     6 *)
     7 
     8 signature DATATYPE_REALIZER =
     9 sig
    10   val add_dt_realizers: Datatype.config -> string list -> theory -> theory
    11   val setup: theory -> theory
    12 end;
    13 
    14 structure Datatype_Realizer : DATATYPE_REALIZER =
    15 struct
    16 
    17 open Datatype_Aux;
    18 
    19 fun subsets i j =
    20   if i <= j then
    21     let val is = subsets (i+1) j
    22     in map (fn ks => i::ks) is @ is end
    23   else [[]];
    24 
    25 fun forall_intr_prf (t, prf) =
    26   let val (a, T) = (case t of Var ((a, _), T) => (a, T) | Free p => p)
    27   in Abst (a, SOME T, Proofterm.prf_abstract_over t prf) end;
    28 
    29 fun prf_of thm =
    30   Reconstruct.reconstruct_proof (Thm.theory_of_thm thm) (Thm.prop_of thm) (Thm.proof_of thm);
    31 
    32 fun is_unit t = snd (strip_type (fastype_of t)) = HOLogic.unitT;
    33 
    34 fun tname_of (Type (s, _)) = s
    35   | tname_of _ = "";
    36 
    37 fun make_ind sorts ({descr, rec_names, rec_rewrites, induct, ...} : info) is thy =
    38   let
    39     val recTs = get_rec_types descr sorts;
    40     val pnames = if length descr = 1 then ["P"]
    41       else map (fn i => "P" ^ string_of_int i) (1 upto length descr);
    42 
    43     val rec_result_Ts = map (fn ((i, _), P) =>
    44       if i mem is then TFree ("'" ^ P, HOLogic.typeS) else HOLogic.unitT)
    45         (descr ~~ pnames);
    46 
    47     fun make_pred i T U r x =
    48       if i mem is then
    49         Free (List.nth (pnames, i), T --> U --> HOLogic.boolT) $ r $ x
    50       else Free (List.nth (pnames, i), U --> HOLogic.boolT) $ x;
    51 
    52     fun mk_all i s T t =
    53       if i mem is then list_all_free ([(s, T)], t) else t;
    54 
    55     val (prems, rec_fns) = split_list (flat (fst (fold_map
    56       (fn ((i, (_, _, constrs)), T) => fold_map (fn (cname, cargs) => fn j =>
    57         let
    58           val Ts = map (typ_of_dtyp descr sorts) cargs;
    59           val tnames = Name.variant_list pnames (Datatype_Prop.make_tnames Ts);
    60           val recs = filter (is_rec_type o fst o fst) (cargs ~~ tnames ~~ Ts);
    61           val frees = tnames ~~ Ts;
    62 
    63           fun mk_prems vs [] = 
    64                 let
    65                   val rT = nth (rec_result_Ts) i;
    66                   val vs' = filter_out is_unit vs;
    67                   val f = mk_Free "f" (map fastype_of vs' ---> rT) j;
    68                   val f' = Envir.eta_contract (list_abs_free
    69                     (map dest_Free vs, if i mem is then list_comb (f, vs')
    70                       else HOLogic.unit));
    71                 in (HOLogic.mk_Trueprop (make_pred i rT T (list_comb (f, vs'))
    72                   (list_comb (Const (cname, Ts ---> T), map Free frees))), f')
    73                 end
    74             | mk_prems vs (((dt, s), T) :: ds) = 
    75                 let
    76                   val k = body_index dt;
    77                   val (Us, U) = strip_type T;
    78                   val i = length Us;
    79                   val rT = nth (rec_result_Ts) k;
    80                   val r = Free ("r" ^ s, Us ---> rT);
    81                   val (p, f) = mk_prems (vs @ [r]) ds
    82                 in (mk_all k ("r" ^ s) (Us ---> rT) (Logic.mk_implies
    83                   (list_all (map (pair "x") Us, HOLogic.mk_Trueprop
    84                     (make_pred k rT U (app_bnds r i)
    85                       (app_bnds (Free (s, T)) i))), p)), f)
    86                 end
    87 
    88         in (apfst (curry list_all_free frees) (mk_prems (map Free frees) recs), j + 1) end)
    89           constrs) (descr ~~ recTs) 1)));
    90  
    91     fun mk_proj j [] t = t
    92       | mk_proj j (i :: is) t = if null is then t else
    93           if (j: int) = i then HOLogic.mk_fst t
    94           else mk_proj j is (HOLogic.mk_snd t);
    95 
    96     val tnames = Datatype_Prop.make_tnames recTs;
    97     val fTs = map fastype_of rec_fns;
    98     val ps = map (fn ((((i, _), T), U), s) => Abs ("x", T, make_pred i U T
    99       (list_comb (Const (s, fTs ---> T --> U), rec_fns) $ Bound 0) (Bound 0)))
   100         (descr ~~ recTs ~~ rec_result_Ts ~~ rec_names);
   101     val r = if null is then Extraction.nullt else
   102       foldr1 HOLogic.mk_prod (map_filter (fn (((((i, _), T), U), s), tname) =>
   103         if i mem is then SOME
   104           (list_comb (Const (s, fTs ---> T --> U), rec_fns) $ Free (tname, T))
   105         else NONE) (descr ~~ recTs ~~ rec_result_Ts ~~ rec_names ~~ tnames));
   106     val concl = HOLogic.mk_Trueprop (foldr1 (HOLogic.mk_binop @{const_name "op &"})
   107       (map (fn ((((i, _), T), U), tname) =>
   108         make_pred i U T (mk_proj i is r) (Free (tname, T)))
   109           (descr ~~ recTs ~~ rec_result_Ts ~~ tnames)));
   110     val cert = cterm_of thy;
   111     val inst = map (pairself cert) (map head_of (HOLogic.dest_conj
   112       (HOLogic.dest_Trueprop (concl_of induct))) ~~ ps);
   113 
   114     val thm = OldGoals.simple_prove_goal_cterm (cert (Logic.list_implies (prems, concl)))
   115       (fn prems =>
   116          [rewrite_goals_tac (map mk_meta_eq [fst_conv, snd_conv]),
   117           rtac (cterm_instantiate inst induct) 1,
   118           ALLGOALS Object_Logic.atomize_prems_tac,
   119           rewrite_goals_tac (@{thm o_def} :: map mk_meta_eq rec_rewrites),
   120           REPEAT ((resolve_tac prems THEN_ALL_NEW (fn i =>
   121             REPEAT (etac allE i) THEN atac i)) 1)]);
   122 
   123     val ind_name = Thm.get_name induct;
   124     val vs = map (fn i => List.nth (pnames, i)) is;
   125     val (thm', thy') = thy
   126       |> Sign.root_path
   127       |> PureThy.store_thm
   128         (Binding.qualified_name (space_implode "_" (ind_name :: vs @ ["correctness"])), thm)
   129       ||> Sign.restore_naming thy;
   130 
   131     val ivs = rev (Term.add_vars (Logic.varify (Datatype_Prop.make_ind [descr] sorts)) []);
   132     val rvs = rev (Thm.fold_terms Term.add_vars thm' []);
   133     val ivs1 = map Var (filter_out (fn (_, T) =>  (* FIXME set (!??) *)
   134       tname_of (body_type T) mem [@{type_abbrev set}, @{type_name bool}]) ivs);
   135     val ivs2 = map (fn (ixn, _) => Var (ixn, the (AList.lookup (op =) rvs ixn))) ivs;
   136 
   137     val prf =
   138       List.foldr forall_intr_prf
   139         (fold_rev (fn (f, p) => fn prf =>
   140           (case head_of (strip_abs_body f) of
   141              Free (s, T) =>
   142                let val T' = Logic.varifyT T
   143                in Abst (s, SOME T', Proofterm.prf_abstract_over
   144                  (Var ((s, 0), T')) (AbsP ("H", SOME p, prf)))
   145                end
   146           | _ => AbsP ("H", SOME p, prf)))
   147             (rec_fns ~~ prems_of thm)
   148             (Proofterm.proof_combP (prf_of thm', map PBound (length prems - 1 downto 0)))) ivs2;
   149 
   150     val r' =
   151       if null is then r
   152       else Logic.varify (fold_rev lambda
   153         (map Logic.unvarify ivs1 @ filter_out is_unit
   154             (map (head_of o strip_abs_body) rec_fns)) r);
   155 
   156   in Extraction.add_realizers_i [(ind_name, (vs, r', prf))] thy' end;
   157 
   158 
   159 fun make_casedists sorts ({index, descr, case_name, case_rewrites, exhaust, ...} : info) thy =
   160   let
   161     val cert = cterm_of thy;
   162     val rT = TFree ("'P", HOLogic.typeS);
   163     val rT' = TVar (("'P", 0), HOLogic.typeS);
   164 
   165     fun make_casedist_prem T (cname, cargs) =
   166       let
   167         val Ts = map (typ_of_dtyp descr sorts) cargs;
   168         val frees = Name.variant_list ["P", "y"] (Datatype_Prop.make_tnames Ts) ~~ Ts;
   169         val free_ts = map Free frees;
   170         val r = Free ("r" ^ Long_Name.base_name cname, Ts ---> rT)
   171       in (r, list_all_free (frees, Logic.mk_implies (HOLogic.mk_Trueprop
   172         (HOLogic.mk_eq (Free ("y", T), list_comb (Const (cname, Ts ---> T), free_ts))),
   173           HOLogic.mk_Trueprop (Free ("P", rT --> HOLogic.boolT) $
   174             list_comb (r, free_ts)))))
   175       end;
   176 
   177     val SOME (_, _, constrs) = AList.lookup (op =) descr index;
   178     val T = List.nth (get_rec_types descr sorts, index);
   179     val (rs, prems) = split_list (map (make_casedist_prem T) constrs);
   180     val r = Const (case_name, map fastype_of rs ---> T --> rT);
   181 
   182     val y = Var (("y", 0), Logic.varifyT T);
   183     val y' = Free ("y", T);
   184 
   185     val thm = OldGoals.prove_goalw_cterm [] (cert (Logic.list_implies (prems,
   186       HOLogic.mk_Trueprop (Free ("P", rT --> HOLogic.boolT) $
   187         list_comb (r, rs @ [y'])))))
   188       (fn prems =>
   189          [rtac (cterm_instantiate [(cert y, cert y')] exhaust) 1,
   190           ALLGOALS (EVERY'
   191             [asm_simp_tac (HOL_basic_ss addsimps case_rewrites),
   192              resolve_tac prems, asm_simp_tac HOL_basic_ss])]);
   193 
   194     val exh_name = Thm.get_name exhaust;
   195     val (thm', thy') = thy
   196       |> Sign.root_path
   197       |> PureThy.store_thm (Binding.qualified_name (exh_name ^ "_P_correctness"), thm)
   198       ||> Sign.restore_naming thy;
   199 
   200     val P = Var (("P", 0), rT' --> HOLogic.boolT);
   201     val prf = forall_intr_prf (y, forall_intr_prf (P,
   202       fold_rev (fn (p, r) => fn prf =>
   203           forall_intr_prf (Logic.varify r, AbsP ("H", SOME (Logic.varify p), prf)))
   204         (prems ~~ rs)
   205         (Proofterm.proof_combP (prf_of thm', map PBound (length prems - 1 downto 0)))));
   206     val r' = Logic.varify (Abs ("y", T,
   207       list_abs (map dest_Free rs, list_comb (r,
   208         map Bound ((length rs - 1 downto 0) @ [length rs])))));
   209 
   210   in Extraction.add_realizers_i
   211     [(exh_name, (["P"], r', prf)),
   212      (exh_name, ([], Extraction.nullt, prf_of exhaust))] thy'
   213   end;
   214 
   215 fun add_dt_realizers config names thy =
   216   if ! Proofterm.proofs < 2 then thy
   217   else let
   218     val _ = message config "Adding realizers for induction and case analysis ..."
   219     val infos = map (Datatype.the_info thy) names;
   220     val info :: _ = infos;
   221   in
   222     thy
   223     |> fold_rev (make_ind (#sorts info) info) (subsets 0 (length (#descr info) - 1))
   224     |> fold_rev (make_casedists (#sorts info)) infos
   225   end;
   226 
   227 val setup = Datatype.interpretation add_dt_realizers;
   228 
   229 end;