src/HOL/Tools/datatype_realizer.ML
author berghofe
Wed Nov 13 15:28:41 2002 +0100 (2002-11-13)
changeset 13708 a3a410782c95
parent 13656 58bb243dbafb
child 13725 12404b452034
permissions -rw-r--r--
prove_goal' -> Goal.simple_prove_goal_cterm
     1 (*  Title:      HOL/Tools/datatype_realizer.ML
     2     ID:         $Id$
     3     Author:     Stefan Berghofer, TU Muenchen
     4     License:    GPL (GNU GENERAL PUBLIC LICENSE)
     5 
     6 Program extraction from proofs involving datatypes:
     7 Realizers for induction and case analysis
     8 *)
     9 
    10 signature DATATYPE_REALIZER =
    11 sig
         (* add_dt_realizers sorts infos thy: adds program-extraction
            realizers for the induction and case analysis (exhaustion)
            rules of the datatypes described by infos; thy is returned
            unchanged unless !proofs >= 2 *)
    12   val add_dt_realizers: (string * sort) list ->
    13     DatatypeAux.datatype_info list -> theory -> theory
    14 end;
    15 
    16 structure DatatypeRealizer : DATATYPE_REALIZER =
    17 struct
    18 
    19 open DatatypeAux;
    20 
    (* subsets i j: every sublist of [i, i+1, ..., j], each kept in
       increasing order.  Sublists containing i are listed before those
       without it.  When i > j the only sublist is the empty one, so the
       result is [[]]. *)
    fun subsets i j =
      if i > j then [[]]
      else
        let val rest = subsets (i + 1) j
        in map (fn ks => i :: ks) rest @ rest end;
    25 
    (* forall_intr_prf (t, prf): abstract the proof prf over the variable t,
       producing a proof-level abstraction Abst named after the variable and
       annotated with its type.  t must be a schematic Var or a Free; any
       other term is rejected with TERM instead of an anonymous Match
       raised by a non-exhaustive case. *)
    fun forall_intr_prf (t, prf) =
      let
        val (a, T) =
          (case t of
            Var ((a, _), T) => (a, T)
          | Free p => p
          | _ => raise TERM ("forall_intr_prf: Var or Free expected", [t]))
      in Abst (a, Some T, Proofterm.prf_abstract_over t prf) end;
    29 
    (* prf_of thm: the proof term of thm, obtained by passing the proof held
       in the second component of its derivation to
       Reconstruct.reconstruct_proof together with thm's signature and
       proposition. *)
    fun prf_of thm =
      let val rep = rep_thm thm
      in Reconstruct.reconstruct_proof (#sign rep) (#prop rep) (#2 (#der rep)) end;
    33 
    (* prf_subst_vars inst prf: apply the term-variable instantiation inst
       (indexname * term pairs, via subst_vars) to the terms occurring in
       prf; the type-mapping argument of map_proof_terms is the identity. *)
    fun prf_subst_vars inst prf =
      Proofterm.map_proof_terms (subst_vars ([], inst)) (fn T => T) prf;
    36 
    (* is_unit t: true iff the body (final result) type of t's type is
       unit, i.e. t is of type unit or is a function returning unit. *)
    fun is_unit t =
      let val (_, bodyT) = strip_type (fastype_of t)
      in bodyT = HOLogic.unitT end;
    38 
    (* mk_realizes T: the constant "realizes" at type T => bool => bool,
       taking a realizer of type T and a bool. *)
    fun mk_realizes T =
      let val boolT = HOLogic.boolT
      in Const ("realizes", [T, boolT] ---> boolT) end;
    40 
    41 fun make_ind sorts ({descr, rec_names, rec_rewrites, induction, ...} : datatype_info) (is, thy) =
       (* Build, prove and register one realizer for the (mutual) induction
          rule `induction` of the datatypes in descr.  `is` lists indices
          (into descr) of the datatypes whose induction predicate gets
          computational content: for i in is the predicate P_i takes a
          realizer and a value; for the other indices it only takes the
          value and its realizer type is unit.  The correctness theorem is
          stored in thy, the realizer is added with
          Extraction.add_realizers_i, and the resulting theory is returned. *)
    42   let
    43     val sg = sign_of thy;
    44     val recTs = get_rec_types descr sorts;
           (* predicate names: "P" for a single datatype, "P1" ... "Pn" otherwise *)
    45     val pnames = if length descr = 1 then ["P"]
    46       else map (fn i => "P" ^ string_of_int i) (1 upto length descr);
    47 
           (* realizer type per datatype: a free type variable named after the
              predicate for indices in is, unit for the others *)
    48     val rec_result_Ts = map (fn ((i, _), P) =>
    49       if i mem is then TFree ("'" ^ P, HOLogic.typeS) else HOLogic.unitT)
    50         (descr ~~ pnames);
    51 
           (* make_pred i T U r x: predicate P_i applied to realizer r (of type
              T) and value x (of type U) when i is in is; otherwise r is
              ignored and P_i is applied to x only *)
    52     fun make_pred i T U r x =
    53       if i mem is then
    54         Free (nth_elem (i, pnames), T --> U --> HOLogic.boolT) $ r $ x
    55       else Free (nth_elem (i, pnames), U --> HOLogic.boolT) $ x;
    56 
           (* quantify over the realizer variable s only when datatype i is in is *)
    57     fun mk_all i s T t =
    58       if i mem is then list_all_free ([(s, T)], t) else t;
    59 
           (* For every constructor of every datatype: a premise of the realizer
              induction rule (prems) and a matching function term (rec_fns)
              that is later passed to the rec constants (rec_names).  The outer
              fold threads j, a running constructor counter starting at 1, used
              to name the function variable f_j. *)
    60     val (prems, rec_fns) = split_list (flat (snd (foldl_map
    61       (fn (j, ((i, (_, _, constrs)), T)) => foldl_map (fn (j, (cname, cargs)) =>
    62         let
    63           val Ts = map (typ_of_dtyp descr sorts) cargs;
    64           val tnames = variantlist (DatatypeProp.make_tnames Ts, pnames);
    65           val recs = filter (is_rec_type o fst o fst) (cargs ~~ tnames ~~ Ts);
    66           val frees = tnames ~~ Ts;
    67 
                 (* mk_prems vs recs: vs holds the constructor arguments followed
                    by one realizer variable r<s> per recursive argument already
                    processed; each element of recs adds one hypothesis *)
    68           fun mk_prems vs [] = 
    69                 let
    70                   val rT = nth_elem (i, rec_result_Ts);
                         (* variables whose body type is unit are not passed to f *)
    71                   val vs' = filter_out is_unit vs;
    72                   val f = mk_Free "f" (map fastype_of vs' ---> rT) j;
                         (* f' abstracts over all of vs; its body is unit when
                            datatype i is not in is *)
    73                   val f' = Pattern.eta_contract (list_abs_free
    74                     (map dest_Free vs, if i mem is then list_comb (f, vs')
    75                       else HOLogic.unit));
    76                 in (HOLogic.mk_Trueprop (make_pred i rT T (list_comb (f, vs'))
    77                   (list_comb (Const (cname, Ts ---> T), map Free frees))), f')
    78                 end
    79             | mk_prems vs (((dt, s), T) :: ds) = 
    80                 let
    81                   val k = body_index dt;
    82                   val (Us, U) = strip_type T;
                         (* note: this i is the number of argument types of the
                            recursive argument and shadows the outer datatype
                            index i; it is only used by app_bnds below *)
    83                   val i = length Us;
    84                   val rT = nth_elem (k, rec_result_Ts);
    85                   val r = Free ("r" ^ s, Us ---> rT);
    86                   val (p, f) = mk_prems (vs @ [r]) ds
                       (* hypothesis for the recursive argument s, implying the
                          premise built for the remaining arguments *)
    87                 in (mk_all k ("r" ^ s) (Us ---> rT) (Logic.mk_implies
    88                   (list_all (map (pair "x") Us, HOLogic.mk_Trueprop
    89                     (make_pred k rT U (app_bnds r i)
    90                       (app_bnds (Free (s, T)) i))), p)), f)
    91                 end
    92 
               (* close the premise over the constructor arguments *)
    93         in (j + 1,
    94           apfst (curry list_all_free frees) (mk_prems (map Free frees) recs))
    95         end) (j, constrs)) (1, descr ~~ recTs))));
    96  
           (* mk_proj j is t: the component of the right-nested tuple t (built
              by foldr1 HOLogic.mk_prod below) that belongs to datatype j;
              with a single component t itself is returned.  The pattern
              variable is shadows the outer is. *)
    97     fun mk_proj j [] t = t
    98       | mk_proj j (i :: is) t = if null is then t else
    99           if j = i then HOLogic.mk_fst t
   100           else mk_proj j is (HOLogic.mk_snd t);
   101 
   102     val tnames = DatatypeProp.make_tnames recTs;
   103     val fTs = map fastype_of rec_fns;
           (* predicate instances for the original induction rule: for each
              datatype i, the abstraction over x of make_pred applied to
              (rec_i rec_fns x) and x *)
   104     val ps = map (fn ((((i, _), T), U), s) => Abs ("x", T, make_pred i U T
   105       (list_comb (Const (s, fTs ---> T --> U), rec_fns) $ Bound 0) (Bound 0)))
   106         (descr ~~ recTs ~~ rec_result_Ts ~~ rec_names);
           (* the realizer: a tuple of rec applications, one per datatype in
              is, or Extraction.nullt when is is empty *)
   107     val r = if null is then Extraction.nullt else
   108       foldr1 HOLogic.mk_prod (mapfilter (fn (((((i, _), T), U), s), tname) =>
   109         if i mem is then Some
   110           (list_comb (Const (s, fTs ---> T --> U), rec_fns) $ Free (tname, T))
   111         else None) (descr ~~ recTs ~~ rec_result_Ts ~~ rec_names ~~ tnames));
           (* conclusion: conjunction over all datatypes of the predicate applied
              to the matching projection of r and the datatype variable *)
   112     val concl = HOLogic.mk_Trueprop (foldr1 (HOLogic.mk_binop "op &")
   113       (map (fn ((((i, _), T), U), tname) =>
   114         make_pred i U T (mk_proj i is r) (Free (tname, T)))
   115           (descr ~~ recTs ~~ rec_result_Ts ~~ tnames)));
   116     val cert = cterm_of sg;
           (* pair the predicate variables of `induction` (heads of the conjuncts
              of its conclusion) with ps *)
   117     val inst = map (pairself cert) (map head_of (HOLogic.dest_conj
   118       (HOLogic.dest_Trueprop (concl_of induction))) ~~ ps);
   119 
           (* prove prems ==> concl: unfold fst/snd, apply the instantiated
              induction rule, atomize, unfold o_def and the rec equations, then
              close the cases with the premises *)
   120     val thm = simple_prove_goal_cterm (cert (Logic.list_implies (prems, concl)))
   121       (fn prems =>
   122          [rewrite_goals_tac (map mk_meta_eq [fst_conv, snd_conv]),
   123           rtac (cterm_instantiate inst induction) 1,
   124           ALLGOALS ObjectLogic.atomize_tac,
   125           rewrite_goals_tac (o_def :: map mk_meta_eq rec_rewrites),
   126           REPEAT ((resolve_tac prems THEN_ALL_NEW (fn i =>
   127             REPEAT (etac allE i) THEN atac i)) 1)]);
   128 
           (* store the theorem as <ind_name>_<names of P_i, i in is>_correctness;
              the theory path is made absolute for the store and the previous
              path is restored afterwards (presumably so the stored name is not
              prefixed by the current path) *)
   129     val {path, ...} = Sign.rep_sg sg;
   130     val ind_name = Thm.name_of_thm induction;
   131     val vs = map (fn i => nth_elem (i, pnames)) is;
   132     val (thy', thm') = thy
   133       |> Theory.absolute_path
   134       |> PureThy.store_thm
   135         ((space_implode "_" (ind_name :: vs @ ["correctness"]), thm), [])
   136       |>> Theory.add_path (NameSpace.pack (if_none path []));
   137 
           (* substitution (on varified types) replacing each predicate variable
              by a realizes predicate; datatypes not in is use
              Extraction.nullT / Extraction.nullt as realizer *)
   138     val inst = map (fn ((((i, _), s), T), U) => ((s, 0), if i mem is then
   139         Abs ("r", U, Abs ("x", T, mk_realizes U $ Bound 1 $
   140           (Var ((s, 0), T --> HOLogic.boolT) $ Bound 0)))
   141       else Abs ("x", T, mk_realizes Extraction.nullT $ Extraction.nullt $
   142         (Var ((s, 0), T --> HOLogic.boolT) $
   143           Bound 0)))) (descr ~~ pnames ~~ map Type.varifyT recTs ~~
   144             map Type.varifyT rec_result_Ts);
   145 
           (* schematic variables of the varified induction statement built by
              DatatypeProp.make_ind *)
   146     val ivs = map Var (Drule.vars_of_terms
   147       [Logic.varify (DatatypeProp.make_ind [descr] sorts)]);
   148 
           (* correctness proof term: the stored proof applied to one PBound per
              premise, abstracted over each premise (and over its function
              variable when the rec_fns entry has a Free head), with inst
              applied and finally closed over ivs *)
   149     val prf = foldr forall_intr_prf (ivs,
   150       prf_subst_vars inst (foldr (fn ((f, p), prf) =>
   151         (case head_of (strip_abs_body f) of
   152            Free (s, T) =>
   153              let val T' = Type.varifyT T
   154              in Abst (s, Some T', Proofterm.prf_abstract_over
   155                (Var ((s, 0), T')) (AbsP ("H", Some p, prf)))
   156              end
   157          | _ => AbsP ("H", Some p, prf)))
   158            (rec_fns ~~ prems_of thm, Proofterm.proof_combP
   159              (prf_of thm', map PBound (length prems - 1 downto 0)))));
   160 
           (* realizer term abstracted over ivs and the non-unit function
              variables, then varified; r is used unchanged when is is empty *)
   161     val r' = if null is then r else Logic.varify (foldr (uncurry lambda)
   162       (map Logic.unvarify ivs @ filter_out is_unit
   163         (map (head_of o strip_abs_body) rec_fns), r));
   164 
   165   in Extraction.add_realizers_i [(ind_name, (vs, r', prf))] thy' end;
   166 
   167 
   168 fun make_casedists sorts ({index, descr, case_name, case_rewrites, exhaustion, ...} : datatype_info, thy) =
       (* Build, prove and register realizers for the case analysis rule
          `exhaustion` of the datatype at position index in descr.  Two
          realizers are added under exh_name: one for predicate P whose
          realizer is built from the case_name constant, and one without
          computational content (Extraction.nullt). *)
   169   let
   170     val sg = sign_of thy;
           (* NOTE(review): this shadows the sorts argument, which is therefore
              unused; sorts are recomputed from the type variables occurring in
              descr, each with sort HOLogic.typeS *)
   171     val sorts = map (rpair HOLogic.typeS) (distinct (flat (map
   172       (fn (_, (_, ds, _)) => mapfilter (try dest_DtTFree) ds) descr)));
   173     val cert = cterm_of sg;
           (* realizer type 'P, as a free type variable and as its schematic form *)
   174     val rT = TFree ("'P", HOLogic.typeS);
   175     val rT' = TVar (("'P", 0), HOLogic.typeS);
   176 
           (* for constructor cname: a realizer variable r<base name of cname>
              and the premise  !!args. y = cname args ==> P (r args) *)
   177     fun make_casedist_prem T (cname, cargs) =
   178       let
   179         val Ts = map (typ_of_dtyp descr sorts) cargs;
   180         val frees = variantlist
   181           (DatatypeProp.make_tnames Ts, ["P", "y"]) ~~ Ts;
   182         val free_ts = map Free frees;
   183         val r = Free ("r" ^ NameSpace.base cname, Ts ---> rT)
   184       in (r, list_all_free (frees, Logic.mk_implies (HOLogic.mk_Trueprop
   185         (HOLogic.mk_eq (Free ("y", T), list_comb (Const (cname, Ts ---> T), free_ts))),
   186           HOLogic.mk_Trueprop (Free ("P", rT --> HOLogic.boolT) $
   187             list_comb (r, free_ts)))))
   188       end;
   189 
           (* non-exhaustive binding: raises Bind if index is not a key of descr *)
   190     val Some (_, _, constrs) = assoc (descr, index);
   191     val T = nth_elem (index, get_rec_types descr sorts);
   192     val (rs, prems) = split_list (map (make_casedist_prem T) constrs);
   193     val r = Const (case_name, map fastype_of rs ---> T --> rT);
   194 
   195     val y = Var (("y", 0), Type.varifyT T);
   196     val y' = Free ("y", T);
   197 
           (* prove prems ==> P (case_name r_1 ... r_n y): exhaustion on y, then
              in every case simplify with case_rewrites, apply a premise and
              simplify again *)
   198     val thm = prove_goalw_cterm [] (cert (Logic.list_implies (prems,
   199       HOLogic.mk_Trueprop (Free ("P", rT --> HOLogic.boolT) $
   200         list_comb (r, rs @ [y'])))))
   201       (fn prems =>
   202          [rtac (cterm_instantiate [(cert y, cert y')] exhaustion) 1,
   203           ALLGOALS (EVERY'
   204             [asm_simp_tac (HOL_basic_ss addsimps case_rewrites),
   205              resolve_tac prems, asm_simp_tac HOL_basic_ss])]);
   206 
           (* store the theorem as <exh_name>_P_correctness; the theory path is
              made absolute for the store and the previous path restored *)
   207     val {path, ...} = Sign.rep_sg sg;
   208     val exh_name = Thm.name_of_thm exhaustion;
   209     val (thy', thm') = thy
   210       |> Theory.absolute_path
   211       |> PureThy.store_thm ((exh_name ^ "_P_correctness", thm), [])
   212       |>> Theory.add_path (NameSpace.pack (if_none path []));
   213 
           (* the schematic predicate variable P, of type bool *)
   214     val P = Var (("P", 0), HOLogic.boolT);
           (* correctness proof for the realizer with content: the stored proof
              applied to one PBound per premise, abstracted over each premise and
              realizer variable, with P replaced by a realizes predicate, and
              closed over P and y *)
   215     val prf = forall_intr_prf (y, forall_intr_prf (P,
   216       prf_subst_vars [(("P", 0), Abs ("r", rT',
   217         mk_realizes rT' $ Bound 0 $ P))] (foldr (fn ((p, r), prf) =>
   218           forall_intr_prf (Logic.varify r, AbsP ("H", Some (Logic.varify p),
   219             prf))) (prems ~~ rs, Proofterm.proof_combP (prf_of thm',
   220               map PBound (length prems - 1 downto 0))))));
           (* realizer: abstraction over y, P and r_1 ... r_n of
              case_name r_1 ... r_n y.  The P abstraction is not referenced:
              Bound (length rs + 1) skips it and denotes y. *)
   221     val r' = Logic.varify (Abs ("y", Type.varifyT T,
   222       Abs ("P", HOLogic.boolT, list_abs (map dest_Free rs, list_comb (r,
   223         map Bound ((length rs - 1 downto 0) @ [length rs + 1]))))));
   224 
           (* proof for the realizer without content: exhaustion's own proof
              with P replaced by  realizes nullT nullt P  *)
   225     val prf' = forall_intr_prf (y, forall_intr_prf (P, prf_subst_vars
   226       [(("P", 0), mk_realizes Extraction.nullT $ Extraction.nullt $ P)]
   227         (prf_of exhaustion)));
   228 
   229   in Extraction.add_realizers_i
   230     [(exh_name, (["P"], r', prf)),
   231      (exh_name, ([], Extraction.nullt, prf'))] thy'
   232   end;
   233 
   234 
   (* add_dt_realizers sorts infos thy: when !proofs is at least 2
      (presumably full proof-term recording), print a progress message and
      register induction realizers for every subset of datatype indices of
      the first info's descr, then case-analysis realizers for every info.
      Otherwise thy is returned unchanged. *)
   fun add_dt_realizers sorts infos thy =
     if !proofs < 2 then thy
     else
       let
         val () = message "Adding realizers for induction and case analysis ...";
         val info : datatype_info = hd infos;
         val thy_ind =
           foldr (make_ind sorts info) (subsets 0 (length (#descr info) - 1), thy)
       in foldr (make_casedists sorts) (infos, thy_ind) end;
   240 
   241 end;