src/HOL/Tools/datatype_realizer.ML
author wenzelm
Sat Mar 29 19:14:03 2008 +0100 (2008-03-29)
changeset 26481 92e901171cc8
parent 26359 6d437bde2f1d
child 26626 c6231d64d264
permissions -rw-r--r--
simplified PureThy.store_thm;
(*  Title:      HOL/Tools/datatype_realizer.ML
    ID:         $Id$
    Author:     Stefan Berghofer, TU Muenchen

Program extraction from proofs involving datatypes:
realizers for induction and case analysis.
*)
     8 
     9 signature DATATYPE_REALIZER =
    10 sig
    11   val add_dt_realizers: string list -> theory -> theory
    12   val setup: theory -> theory
    13 end;
    14 
    15 structure DatatypeRealizer : DATATYPE_REALIZER =
    16 struct
    17 
    18 open DatatypeAux;
    19 
    20 fun subsets i j = if i <= j then
    21        let val is = subsets (i+1) j
    22        in map (fn ks => i::ks) is @ is end
    23      else [[]];
    24 
    25 fun forall_intr_prf (t, prf) =
    26   let val (a, T) = (case t of Var ((a, _), T) => (a, T) | Free p => p)
    27   in Abst (a, SOME T, Proofterm.prf_abstract_over t prf) end;
    28 
    29 fun prf_of thm =
    30   let val {thy, prop, der = (_, prf), ...} = rep_thm thm
    31   in Reconstruct.reconstruct_proof thy prop prf end;
    32 
    33 fun prf_subst_vars inst =
    34   Proofterm.map_proof_terms (subst_vars ([], inst)) I;
    35 
    36 fun is_unit t = snd (strip_type (fastype_of t)) = HOLogic.unitT;
    37 
    38 fun tname_of (Type (s, _)) = s
    39   | tname_of _ = "";
    40 
    41 fun mk_realizes T = Const ("realizes", T --> HOLogic.boolT --> HOLogic.boolT);
    42 
    43 fun make_ind sorts ({descr, rec_names, rec_rewrites, induction, ...} : datatype_info) is thy =
    44   let
    45     val recTs = get_rec_types descr sorts;
    46     val pnames = if length descr = 1 then ["P"]
    47       else map (fn i => "P" ^ string_of_int i) (1 upto length descr);
    48 
    49     val rec_result_Ts = map (fn ((i, _), P) =>
    50       if i mem is then TFree ("'" ^ P, HOLogic.typeS) else HOLogic.unitT)
    51         (descr ~~ pnames);
    52 
    53     fun make_pred i T U r x =
    54       if i mem is then
    55         Free (List.nth (pnames, i), T --> U --> HOLogic.boolT) $ r $ x
    56       else Free (List.nth (pnames, i), U --> HOLogic.boolT) $ x;
    57 
    58     fun mk_all i s T t =
    59       if i mem is then list_all_free ([(s, T)], t) else t;
    60 
    61     val (prems, rec_fns) = split_list (List.concat (snd (foldl_map
    62       (fn (j, ((i, (_, _, constrs)), T)) => foldl_map (fn (j, (cname, cargs)) =>
    63         let
    64           val Ts = map (typ_of_dtyp descr sorts) cargs;
    65           val tnames = Name.variant_list pnames (DatatypeProp.make_tnames Ts);
    66           val recs = List.filter (is_rec_type o fst o fst) (cargs ~~ tnames ~~ Ts);
    67           val frees = tnames ~~ Ts;
    68 
    69           fun mk_prems vs [] = 
    70                 let
    71                   val rT = List.nth (rec_result_Ts, i);
    72                   val vs' = filter_out is_unit vs;
    73                   val f = mk_Free "f" (map fastype_of vs' ---> rT) j;
    74                   val f' = Envir.eta_contract (list_abs_free
    75                     (map dest_Free vs, if i mem is then list_comb (f, vs')
    76                       else HOLogic.unit));
    77                 in (HOLogic.mk_Trueprop (make_pred i rT T (list_comb (f, vs'))
    78                   (list_comb (Const (cname, Ts ---> T), map Free frees))), f')
    79                 end
    80             | mk_prems vs (((dt, s), T) :: ds) = 
    81                 let
    82                   val k = body_index dt;
    83                   val (Us, U) = strip_type T;
    84                   val i = length Us;
    85                   val rT = List.nth (rec_result_Ts, k);
    86                   val r = Free ("r" ^ s, Us ---> rT);
    87                   val (p, f) = mk_prems (vs @ [r]) ds
    88                 in (mk_all k ("r" ^ s) (Us ---> rT) (Logic.mk_implies
    89                   (list_all (map (pair "x") Us, HOLogic.mk_Trueprop
    90                     (make_pred k rT U (app_bnds r i)
    91                       (app_bnds (Free (s, T)) i))), p)), f)
    92                 end
    93 
    94         in (j + 1,
    95           apfst (curry list_all_free frees) (mk_prems (map Free frees) recs))
    96         end) (j, constrs)) (1, descr ~~ recTs))));
    97  
    98     fun mk_proj j [] t = t
    99       | mk_proj j (i :: is) t = if null is then t else
   100           if (j: int) = i then HOLogic.mk_fst t
   101           else mk_proj j is (HOLogic.mk_snd t);
   102 
   103     val tnames = DatatypeProp.make_tnames recTs;
   104     val fTs = map fastype_of rec_fns;
   105     val ps = map (fn ((((i, _), T), U), s) => Abs ("x", T, make_pred i U T
   106       (list_comb (Const (s, fTs ---> T --> U), rec_fns) $ Bound 0) (Bound 0)))
   107         (descr ~~ recTs ~~ rec_result_Ts ~~ rec_names);
   108     val r = if null is then Extraction.nullt else
   109       foldr1 HOLogic.mk_prod (List.mapPartial (fn (((((i, _), T), U), s), tname) =>
   110         if i mem is then SOME
   111           (list_comb (Const (s, fTs ---> T --> U), rec_fns) $ Free (tname, T))
   112         else NONE) (descr ~~ recTs ~~ rec_result_Ts ~~ rec_names ~~ tnames));
   113     val concl = HOLogic.mk_Trueprop (foldr1 (HOLogic.mk_binop "op &")
   114       (map (fn ((((i, _), T), U), tname) =>
   115         make_pred i U T (mk_proj i is r) (Free (tname, T)))
   116           (descr ~~ recTs ~~ rec_result_Ts ~~ tnames)));
   117     val cert = cterm_of thy;
   118     val inst = map (pairself cert) (map head_of (HOLogic.dest_conj
   119       (HOLogic.dest_Trueprop (concl_of induction))) ~~ ps);
   120 
   121     val thm = OldGoals.simple_prove_goal_cterm (cert (Logic.list_implies (prems, concl)))
   122       (fn prems =>
   123          [rewrite_goals_tac (map mk_meta_eq [fst_conv, snd_conv]),
   124           rtac (cterm_instantiate inst induction) 1,
   125           ALLGOALS ObjectLogic.atomize_prems_tac,
   126           rewrite_goals_tac (@{thm o_def} :: map mk_meta_eq rec_rewrites),
   127           REPEAT ((resolve_tac prems THEN_ALL_NEW (fn i =>
   128             REPEAT (etac allE i) THEN atac i)) 1)]);
   129 
   130     val ind_name = Thm.get_name induction;
   131     val vs = map (fn i => List.nth (pnames, i)) is;
   132     val (thm', thy') = thy
   133       |> Sign.absolute_path
   134       |> PureThy.store_thm (space_implode "_" (ind_name :: vs @ ["correctness"]), thm)
   135       ||> Sign.restore_naming thy;
   136 
   137     val ivs = rev (Term.add_vars (Logic.varify (DatatypeProp.make_ind [descr] sorts)) []);
   138     val rvs = rev (Thm.fold_terms Term.add_vars thm' []);
   139     val ivs1 = map Var (filter_out (fn (_, T) =>
   140       tname_of (body_type T) mem ["set", "bool"]) ivs);
   141     val ivs2 = map (fn (ixn, _) => Var (ixn, valOf (AList.lookup (op =) rvs ixn))) ivs;
   142 
   143     val prf = foldr forall_intr_prf
   144      (foldr (fn ((f, p), prf) =>
   145         (case head_of (strip_abs_body f) of
   146            Free (s, T) =>
   147              let val T' = Logic.varifyT T
   148              in Abst (s, SOME T', Proofterm.prf_abstract_over
   149                (Var ((s, 0), T')) (AbsP ("H", SOME p, prf)))
   150              end
   151          | _ => AbsP ("H", SOME p, prf)))
   152            (Proofterm.proof_combP
   153              (prf_of thm', map PBound (length prems - 1 downto 0))) (rec_fns ~~ prems_of thm)) ivs2;
   154 
   155     val r' = if null is then r else Logic.varify (foldr (uncurry lambda)
   156       r (map Logic.unvarify ivs1 @ filter_out is_unit
   157           (map (head_of o strip_abs_body) rec_fns)));
   158 
   159   in Extraction.add_realizers_i [(ind_name, (vs, r', prf))] thy' end;
   160 
   161 
   162 fun make_casedists sorts ({index, descr, case_name, case_rewrites, exhaustion, ...} : datatype_info) thy =
   163   let
   164     val cert = cterm_of thy;
   165     val rT = TFree ("'P", HOLogic.typeS);
   166     val rT' = TVar (("'P", 0), HOLogic.typeS);
   167 
   168     fun make_casedist_prem T (cname, cargs) =
   169       let
   170         val Ts = map (typ_of_dtyp descr sorts) cargs;
   171         val frees = Name.variant_list ["P", "y"] (DatatypeProp.make_tnames Ts) ~~ Ts;
   172         val free_ts = map Free frees;
   173         val r = Free ("r" ^ NameSpace.base cname, Ts ---> rT)
   174       in (r, list_all_free (frees, Logic.mk_implies (HOLogic.mk_Trueprop
   175         (HOLogic.mk_eq (Free ("y", T), list_comb (Const (cname, Ts ---> T), free_ts))),
   176           HOLogic.mk_Trueprop (Free ("P", rT --> HOLogic.boolT) $
   177             list_comb (r, free_ts)))))
   178       end;
   179 
   180     val SOME (_, _, constrs) = AList.lookup (op =) descr index;
   181     val T = List.nth (get_rec_types descr sorts, index);
   182     val (rs, prems) = split_list (map (make_casedist_prem T) constrs);
   183     val r = Const (case_name, map fastype_of rs ---> T --> rT);
   184 
   185     val y = Var (("y", 0), Logic.legacy_varifyT T);
   186     val y' = Free ("y", T);
   187 
   188     val thm = OldGoals.prove_goalw_cterm [] (cert (Logic.list_implies (prems,
   189       HOLogic.mk_Trueprop (Free ("P", rT --> HOLogic.boolT) $
   190         list_comb (r, rs @ [y'])))))
   191       (fn prems =>
   192          [rtac (cterm_instantiate [(cert y, cert y')] exhaustion) 1,
   193           ALLGOALS (EVERY'
   194             [asm_simp_tac (HOL_basic_ss addsimps case_rewrites),
   195              resolve_tac prems, asm_simp_tac HOL_basic_ss])]);
   196 
   197     val exh_name = Thm.get_name exhaustion;
   198     val (thm', thy') = thy
   199       |> Sign.absolute_path
   200       |> PureThy.store_thm (exh_name ^ "_P_correctness", thm)
   201       ||> Sign.restore_naming thy;
   202 
   203     val P = Var (("P", 0), rT' --> HOLogic.boolT);
   204     val prf = forall_intr_prf (y, forall_intr_prf (P,
   205       foldr (fn ((p, r), prf) =>
   206         forall_intr_prf (Logic.legacy_varify r, AbsP ("H", SOME (Logic.varify p),
   207           prf))) (Proofterm.proof_combP (prf_of thm',
   208             map PBound (length prems - 1 downto 0))) (prems ~~ rs)));
   209     val r' = Logic.legacy_varify (Abs ("y", Logic.legacy_varifyT T,
   210       list_abs (map dest_Free rs, list_comb (r,
   211         map Bound ((length rs - 1 downto 0) @ [length rs])))));
   212 
   213   in Extraction.add_realizers_i
   214     [(exh_name, (["P"], r', prf)),
   215      (exh_name, ([], Extraction.nullt, prf_of exhaustion))] thy'
   216   end;
   217 
   218 fun add_dt_realizers names thy =
   219   if ! Proofterm.proofs < 2 then thy
   220   else let
   221     val _ = message "Adding realizers for induction and case analysis ..."
   222     val infos = map (DatatypePackage.the_datatype thy) names;
   223     val info :: _ = infos;
   224   in
   225     thy
   226     |> fold_rev (make_ind (#sorts info) info) (subsets 0 (length (#descr info) - 1))
   227     |> fold_rev (make_casedists (#sorts info)) infos
   228   end;
   229 
   230 val setup = DatatypePackage.interpretation add_dt_realizers;
   231 
   232 end;