src/HOL/Tools/datatype_realizer.ML
author blanchet
Tue Sep 09 20:51:36 2014 +0200 (2014-09-09)
changeset 58277 0dcd3a623a6e
parent 58275 280ede57a6a9
child 58354 04ac60da613e
permissions -rw-r--r--
made realizer more robust in the face of nesting through functions
     1 (*  Title:      HOL/Tools/datatype_realizer.ML
     2     Author:     Stefan Berghofer, TU Muenchen
     3 
     4 Program extraction from proofs involving datatypes:
     5 realizers for induction and case analysis.
     6 *)
     7 
     8 signature DATATYPE_REALIZER =
     9 sig
    10   val extraction_plugin: string
    11   val add_dt_realizers: Old_Datatype_Aux.config -> string list -> theory -> theory
    12 end;
    13 
    14 structure Datatype_Realizer : DATATYPE_REALIZER =
    15 struct
    16 
    17 val extraction_plugin = "extraction";
    18 
    19 fun subsets i j =
    20   if i <= j then
    21     let val is = subsets (i+1) j
    22     in map (fn ks => i::ks) is @ is end
    23   else [[]];
    24 
    25 fun is_unit t = body_type (fastype_of t) = HOLogic.unitT;
    26 
    27 fun tname_of (Type (s, _)) = s
    28   | tname_of _ = "";
    29 
    30 fun make_ind ({descr, rec_names, rec_rewrites, induct, ...} : Old_Datatype_Aux.info) is thy =
    31   let
    32     val ctxt = Proof_Context.init_global thy;
    33     val cert = cterm_of thy;
    34 
    35     val recTs = Old_Datatype_Aux.get_rec_types descr;
    36     val pnames =
    37       if length descr = 1 then ["P"]
    38       else map (fn i => "P" ^ string_of_int i) (1 upto length descr);
    39 
    40     val rec_result_Ts = map (fn ((i, _), P) =>
    41         if member (op =) is i then TFree ("'" ^ P, @{sort type}) else HOLogic.unitT)
    42       (descr ~~ pnames);
    43 
    44     fun make_pred i T U r x =
    45       if member (op =) is i then
    46         Free (nth pnames i, T --> U --> HOLogic.boolT) $ r $ x
    47       else Free (nth pnames i, U --> HOLogic.boolT) $ x;
    48 
    49     fun mk_all i s T t =
    50       if member (op =) is i then Logic.all (Free (s, T)) t else t;
    51 
    52     val (prems, rec_fns) = split_list (flat (fst (fold_map
    53       (fn ((i, (_, _, constrs)), T) => fold_map (fn (cname, cargs) => fn j =>
    54         let
    55           val Ts = map (Old_Datatype_Aux.typ_of_dtyp descr) cargs;
    56           val tnames = Name.variant_list pnames (Old_Datatype_Prop.make_tnames Ts);
    57           val recs = filter (Old_Datatype_Aux.is_rec_type o fst o fst) (cargs ~~ tnames ~~ Ts);
    58           val frees = tnames ~~ Ts;
    59 
    60           fun mk_prems vs [] =
    61                 let
    62                   val rT = nth (rec_result_Ts) i;
    63                   val vs' = filter_out is_unit vs;
    64                   val f = Old_Datatype_Aux.mk_Free "f" (map fastype_of vs' ---> rT) j;
    65                   val f' =
    66                     Envir.eta_contract (fold_rev (absfree o dest_Free) vs
    67                       (if member (op =) is i then list_comb (f, vs') else HOLogic.unit));
    68                 in
    69                   (HOLogic.mk_Trueprop (make_pred i rT T (list_comb (f, vs'))
    70                     (list_comb (Const (cname, Ts ---> T), map Free frees))), f')
    71                 end
    72             | mk_prems vs (((dt, s), T) :: ds) =
    73                 let
    74                   val k = Old_Datatype_Aux.body_index dt;
    75                   val (Us, U) = strip_type T;
    76                   val i = length Us;
    77                   val rT = nth (rec_result_Ts) k;
    78                   val r = Free ("r" ^ s, Us ---> rT);
    79                   val (p, f) = mk_prems (vs @ [r]) ds;
    80                 in
    81                   (mk_all k ("r" ^ s) (Us ---> rT) (Logic.mk_implies
    82                     (Logic.list_all (map (pair "x") Us, HOLogic.mk_Trueprop
    83                       (make_pred k rT U (Old_Datatype_Aux.app_bnds r i)
    84                         (Old_Datatype_Aux.app_bnds (Free (s, T)) i))), p)), f)
    85                 end;
    86         in (apfst (fold_rev (Logic.all o Free) frees) (mk_prems (map Free frees) recs), j + 1) end)
    87           constrs) (descr ~~ recTs) 1)));
    88 
    89     fun mk_proj _ [] t = t
    90       | mk_proj j (i :: is) t =
    91           if null is then t
    92           else if (j: int) = i then HOLogic.mk_fst t
    93           else mk_proj j is (HOLogic.mk_snd t);
    94 
    95     val tnames = Old_Datatype_Prop.make_tnames recTs;
    96     val fTs = map fastype_of rec_fns;
    97     val ps = map (fn ((((i, _), T), U), s) => Abs ("x", T, make_pred i U T
    98       (list_comb (Const (s, fTs ---> T --> U), rec_fns) $ Bound 0) (Bound 0)))
    99         (descr ~~ recTs ~~ rec_result_Ts ~~ rec_names);
   100     val r =
   101       if null is then Extraction.nullt
   102       else
   103         foldr1 HOLogic.mk_prod (map_filter (fn (((((i, _), T), U), s), tname) =>
   104           if member (op =) is i then SOME
   105             (list_comb (Const (s, fTs ---> T --> U), rec_fns) $ Free (tname, T))
   106           else NONE) (descr ~~ recTs ~~ rec_result_Ts ~~ rec_names ~~ tnames));
   107     val concl =
   108       HOLogic.mk_Trueprop (foldr1 (HOLogic.mk_binop @{const_name HOL.conj})
   109         (map (fn ((((i, _), T), U), tname) =>
   110           make_pred i U T (mk_proj i is r) (Free (tname, T)))
   111             (descr ~~ recTs ~~ rec_result_Ts ~~ tnames)));
   112     val inst = map (pairself cert) (map head_of (HOLogic.dest_conj
   113       (HOLogic.dest_Trueprop (concl_of induct))) ~~ ps);
   114 
   115     val thm =
   116       Goal.prove_internal ctxt (map cert prems) (cert concl)
   117         (fn prems =>
   118            EVERY [
   119             rewrite_goals_tac ctxt (map mk_meta_eq [@{thm fst_conv}, @{thm snd_conv}]),
   120             rtac (cterm_instantiate inst induct) 1,
   121             ALLGOALS (Object_Logic.atomize_prems_tac ctxt),
   122             rewrite_goals_tac ctxt (@{thm o_def} :: map mk_meta_eq rec_rewrites),
   123             REPEAT ((resolve_tac prems THEN_ALL_NEW (fn i =>
   124               REPEAT (etac allE i) THEN atac i)) 1)])
   125       |> Drule.export_without_context;
   126 
   127     val ind_name = Thm.derivation_name induct;
   128     val vs = map (nth pnames) is;
   129     val (thm', thy') = thy
   130       |> Sign.root_path
   131       |> Global_Theory.store_thm
   132         (Binding.qualified_name (space_implode "_" (ind_name :: vs @ ["correctness"])), thm)
   133       ||> Sign.restore_naming thy;
   134 
   135     val ivs = rev (Term.add_vars (Logic.varify_global (Old_Datatype_Prop.make_ind [descr])) []);
   136     val rvs = rev (Thm.fold_terms Term.add_vars thm' []);
   137     val ivs1 = map Var (filter_out (fn (_, T) => @{type_name bool} = tname_of (body_type T)) ivs);
   138     val ivs2 = map (fn (ixn, _) => Var (ixn, the (AList.lookup (op =) rvs ixn))) ivs;
   139 
   140     val prf =
   141       Extraction.abs_corr_shyps thy' induct vs ivs2
   142         (fold_rev (fn (f, p) => fn prf =>
   143             (case head_of (strip_abs_body f) of
   144               Free (s, T) =>
   145                 let val T' = Logic.varifyT_global T in
   146                   Abst (s, SOME T', Proofterm.prf_abstract_over
   147                     (Var ((s, 0), T')) (AbsP ("H", SOME p, prf)))
   148                 end
   149             | _ => AbsP ("H", SOME p, prf)))
   150           (rec_fns ~~ prems_of thm)
   151           (Proofterm.proof_combP
   152             (Reconstruct.proof_of thm', map PBound (length prems - 1 downto 0))));
   153 
   154     val r' =
   155       if null is then r
   156       else
   157         Logic.varify_global (fold_rev lambda
   158           (map Logic.unvarify_global ivs1 @ filter_out is_unit
   159               (map (head_of o strip_abs_body) rec_fns)) r);
   160 
   161   in Extraction.add_realizers_i [(ind_name, (vs, r', prf))] thy' end;
   162 
   163 fun make_casedists ({index, descr, case_name, case_rewrites, exhaust, ...} : Old_Datatype_Aux.info) thy =
   164   let
   165     val ctxt = Proof_Context.init_global thy;
   166     val cert = cterm_of thy;
   167     val rT = TFree ("'P", @{sort type});
   168     val rT' = TVar (("'P", 0), @{sort type});
   169 
   170     fun make_casedist_prem T (cname, cargs) =
   171       let
   172         val Ts = map (Old_Datatype_Aux.typ_of_dtyp descr) cargs;
   173         val frees = Name.variant_list ["P", "y"] (Old_Datatype_Prop.make_tnames Ts) ~~ Ts;
   174         val free_ts = map Free frees;
   175         val r = Free ("r" ^ Long_Name.base_name cname, Ts ---> rT)
   176       in
   177         (r, fold_rev Logic.all free_ts
   178           (Logic.mk_implies (HOLogic.mk_Trueprop
   179             (HOLogic.mk_eq (Free ("y", T), list_comb (Const (cname, Ts ---> T), free_ts))),
   180               HOLogic.mk_Trueprop (Free ("P", rT --> HOLogic.boolT) $
   181                 list_comb (r, free_ts)))))
   182       end;
   183 
   184     val SOME (_, _, constrs) = AList.lookup (op =) descr index;
   185     val T = nth (Old_Datatype_Aux.get_rec_types descr) index;
   186     val (rs, prems) = split_list (map (make_casedist_prem T) constrs);
   187     val r = Const (case_name, map fastype_of rs ---> T --> rT);
   188 
   189     val y = Var (("y", 0), Logic.varifyT_global T);
   190     val y' = Free ("y", T);
   191 
   192     val thm =
   193       Goal.prove_internal ctxt (map cert prems)
   194         (cert (HOLogic.mk_Trueprop (Free ("P", rT --> HOLogic.boolT) $ list_comb (r, rs @ [y']))))
   195         (fn prems =>
   196            EVERY [
   197             rtac (cterm_instantiate [(cert y, cert y')] exhaust) 1,
   198             ALLGOALS (EVERY'
   199               [asm_simp_tac (put_simpset HOL_basic_ss ctxt addsimps case_rewrites),
   200                resolve_tac prems, asm_simp_tac (put_simpset HOL_basic_ss ctxt)])])
   201       |> Drule.export_without_context;
   202 
   203     val exh_name = Thm.derivation_name exhaust;
   204     val (thm', thy') = thy
   205       |> Sign.root_path
   206       |> Global_Theory.store_thm (Binding.qualified_name (exh_name ^ "_P_correctness"), thm)
   207       ||> Sign.restore_naming thy;
   208 
   209     val P = Var (("P", 0), rT' --> HOLogic.boolT);
   210     val prf =
   211       Extraction.abs_corr_shyps thy' exhaust ["P"] [y, P]
   212         (fold_rev (fn (p, r) => fn prf =>
   213             Proofterm.forall_intr_proof' (Logic.varify_global r)
   214               (AbsP ("H", SOME (Logic.varify_global p), prf)))
   215           (prems ~~ rs)
   216           (Proofterm.proof_combP
   217             (Reconstruct.proof_of thm', map PBound (length prems - 1 downto 0))));
   218     val prf' =
   219       Extraction.abs_corr_shyps thy' exhaust []
   220         (map Var (Term.add_vars (prop_of exhaust) [])) (Reconstruct.proof_of exhaust);
   221     val r' =
   222       Logic.varify_global (Abs ("y", T,
   223         (fold_rev (Term.abs o dest_Free) rs
   224           (list_comb (r, map Bound ((length rs - 1 downto 0) @ [length rs]))))));
   225   in
   226     Extraction.add_realizers_i
   227       [(exh_name, (["P"], r', prf)),
   228        (exh_name, ([], Extraction.nullt, prf'))] thy'
   229   end;
   230 
   231 fun add_dt_realizers config names thy =
   232   if not (Proofterm.proofs_enabled ()) then thy
   233   else
   234     let
   235       val _ = Old_Datatype_Aux.message config "Adding realizers for induction and case analysis ...";
   236       val infos = map (BNF_LFP_Compat.the_info thy BNF_LFP_Compat.Unfold_Nesting) names;
   237       val info :: _ = infos;
   238     in
   239       thy
   240       |> fold_rev (perhaps o try o make_ind info) (subsets 0 (length (#descr info) - 1))
   241       |> fold_rev (perhaps o try o make_casedists) infos
   242     end;
   243 
   244 val _ = Theory.setup (BNF_LFP_Compat.interpretation extraction_plugin BNF_LFP_Compat.Unfold_Nesting
   245   add_dt_realizers);
   246 
   247 end;