src/HOL/Tools/datatype_realizer.ML
author wenzelm
Fri Mar 06 23:56:43 2015 +0100 (2015-03-06)
changeset 59642 929984c529d3
parent 59621 291934bac95e
child 60752 b48830b670a1
permissions -rw-r--r--
clarified context;
     1 (*  Title:      HOL/Tools/datatype_realizer.ML
     2     Author:     Stefan Berghofer, TU Muenchen
     3 
     4 Program extraction from proofs involving datatypes:
     5 realizers for induction and case analysis.
     6 *)
     7 
(* Interface: add_dt_realizers config names thy adds program-extraction
   realizers for the induction and case analysis rules of the datatypes
   named in names and returns the extended theory. *)
signature DATATYPE_REALIZER =
sig
  val add_dt_realizers: Old_Datatype_Aux.config -> string list -> theory -> theory
end;
    12 
    13 structure Datatype_Realizer : DATATYPE_REALIZER =
    14 struct
    15 
    16 fun subsets i j =
    17   if i <= j then
    18     let val is = subsets (i+1) j
    19     in map (fn ks => i::ks) is @ is end
    20   else [[]];
    21 
    22 fun is_unit t = body_type (fastype_of t) = HOLogic.unitT;
    23 
    24 fun tname_of (Type (s, _)) = s
    25   | tname_of _ = "";
    26 
    27 fun make_ind ({descr, rec_names, rec_rewrites, induct, ...} : Old_Datatype_Aux.info) is thy =
    28   let
    29     val ctxt = Proof_Context.init_global thy;
    30 
    31     val recTs = Old_Datatype_Aux.get_rec_types descr;
    32     val pnames =
    33       if length descr = 1 then ["P"]
    34       else map (fn i => "P" ^ string_of_int i) (1 upto length descr);
    35 
    36     val rec_result_Ts = map (fn ((i, _), P) =>
    37         if member (op =) is i then TFree ("'" ^ P, @{sort type}) else HOLogic.unitT)
    38       (descr ~~ pnames);
    39 
    40     fun make_pred i T U r x =
    41       if member (op =) is i then
    42         Free (nth pnames i, T --> U --> HOLogic.boolT) $ r $ x
    43       else Free (nth pnames i, U --> HOLogic.boolT) $ x;
    44 
    45     fun mk_all i s T t =
    46       if member (op =) is i then Logic.all (Free (s, T)) t else t;
    47 
    48     val (prems, rec_fns) = split_list (flat (fst (fold_map
    49       (fn ((i, (_, _, constrs)), T) => fold_map (fn (cname, cargs) => fn j =>
    50         let
    51           val Ts = map (Old_Datatype_Aux.typ_of_dtyp descr) cargs;
    52           val tnames = Name.variant_list pnames (Old_Datatype_Prop.make_tnames Ts);
    53           val recs = filter (Old_Datatype_Aux.is_rec_type o fst o fst) (cargs ~~ tnames ~~ Ts);
    54           val frees = tnames ~~ Ts;
    55 
    56           fun mk_prems vs [] =
    57                 let
    58                   val rT = nth (rec_result_Ts) i;
    59                   val vs' = filter_out is_unit vs;
    60                   val f = Old_Datatype_Aux.mk_Free "f" (map fastype_of vs' ---> rT) j;
    61                   val f' =
    62                     Envir.eta_contract (fold_rev (absfree o dest_Free) vs
    63                       (if member (op =) is i then list_comb (f, vs') else HOLogic.unit));
    64                 in
    65                   (HOLogic.mk_Trueprop (make_pred i rT T (list_comb (f, vs'))
    66                     (list_comb (Const (cname, Ts ---> T), map Free frees))), f')
    67                 end
    68             | mk_prems vs (((dt, s), T) :: ds) =
    69                 let
    70                   val k = Old_Datatype_Aux.body_index dt;
    71                   val (Us, U) = strip_type T;
    72                   val i = length Us;
    73                   val rT = nth (rec_result_Ts) k;
    74                   val r = Free ("r" ^ s, Us ---> rT);
    75                   val (p, f) = mk_prems (vs @ [r]) ds;
    76                 in
    77                   (mk_all k ("r" ^ s) (Us ---> rT) (Logic.mk_implies
    78                     (Logic.list_all (map (pair "x") Us, HOLogic.mk_Trueprop
    79                       (make_pred k rT U (Old_Datatype_Aux.app_bnds r i)
    80                         (Old_Datatype_Aux.app_bnds (Free (s, T)) i))), p)), f)
    81                 end;
    82         in (apfst (fold_rev (Logic.all o Free) frees) (mk_prems (map Free frees) recs), j + 1) end)
    83           constrs) (descr ~~ recTs) 1)));
    84 
    85     fun mk_proj _ [] t = t
    86       | mk_proj j (i :: is) t =
    87           if null is then t
    88           else if (j: int) = i then HOLogic.mk_fst t
    89           else mk_proj j is (HOLogic.mk_snd t);
    90 
    91     val tnames = Old_Datatype_Prop.make_tnames recTs;
    92     val fTs = map fastype_of rec_fns;
    93     val ps = map (fn ((((i, _), T), U), s) => Abs ("x", T, make_pred i U T
    94       (list_comb (Const (s, fTs ---> T --> U), rec_fns) $ Bound 0) (Bound 0)))
    95         (descr ~~ recTs ~~ rec_result_Ts ~~ rec_names);
    96     val r =
    97       if null is then Extraction.nullt
    98       else
    99         foldr1 HOLogic.mk_prod (map_filter (fn (((((i, _), T), U), s), tname) =>
   100           if member (op =) is i then SOME
   101             (list_comb (Const (s, fTs ---> T --> U), rec_fns) $ Free (tname, T))
   102           else NONE) (descr ~~ recTs ~~ rec_result_Ts ~~ rec_names ~~ tnames));
   103     val concl =
   104       HOLogic.mk_Trueprop (foldr1 (HOLogic.mk_binop @{const_name HOL.conj})
   105         (map (fn ((((i, _), T), U), tname) =>
   106           make_pred i U T (mk_proj i is r) (Free (tname, T)))
   107             (descr ~~ recTs ~~ rec_result_Ts ~~ tnames)));
   108     val inst = map (apply2 (Thm.cterm_of ctxt)) (map head_of (HOLogic.dest_conj
   109       (HOLogic.dest_Trueprop (Thm.concl_of induct))) ~~ ps);
   110 
   111     val thm =
   112       Goal.prove_internal ctxt (map (Thm.cterm_of ctxt) prems) (Thm.cterm_of ctxt concl)
   113         (fn prems =>
   114            EVERY [
   115             rewrite_goals_tac ctxt (map mk_meta_eq [@{thm fst_conv}, @{thm snd_conv}]),
   116             rtac (cterm_instantiate inst induct) 1,
   117             ALLGOALS (Object_Logic.atomize_prems_tac ctxt),
   118             rewrite_goals_tac ctxt (@{thm o_def} :: map mk_meta_eq rec_rewrites),
   119             REPEAT ((resolve_tac ctxt prems THEN_ALL_NEW (fn i =>
   120               REPEAT (etac allE i) THEN assume_tac ctxt i)) 1)])
   121       |> Drule.export_without_context;
   122 
   123     val ind_name = Thm.derivation_name induct;
   124     val vs = map (nth pnames) is;
   125     val (thm', thy') = thy
   126       |> Sign.root_path
   127       |> Global_Theory.store_thm
   128         (Binding.qualified_name (space_implode "_" (ind_name :: vs @ ["correctness"])), thm)
   129       ||> Sign.restore_naming thy;
   130 
   131     val ivs = rev (Term.add_vars (Logic.varify_global (Old_Datatype_Prop.make_ind [descr])) []);
   132     val rvs = rev (Thm.fold_terms Term.add_vars thm' []);
   133     val ivs1 = map Var (filter_out (fn (_, T) => @{type_name bool} = tname_of (body_type T)) ivs);
   134     val ivs2 = map (fn (ixn, _) => Var (ixn, the (AList.lookup (op =) rvs ixn))) ivs;
   135 
   136     val prf =
   137       Extraction.abs_corr_shyps thy' induct vs ivs2
   138         (fold_rev (fn (f, p) => fn prf =>
   139             (case head_of (strip_abs_body f) of
   140               Free (s, T) =>
   141                 let val T' = Logic.varifyT_global T in
   142                   Abst (s, SOME T', Proofterm.prf_abstract_over
   143                     (Var ((s, 0), T')) (AbsP ("H", SOME p, prf)))
   144                 end
   145             | _ => AbsP ("H", SOME p, prf)))
   146           (rec_fns ~~ Thm.prems_of thm)
   147           (Proofterm.proof_combP
   148             (Reconstruct.proof_of thm', map PBound (length prems - 1 downto 0))));
   149 
   150     val r' =
   151       if null is then r
   152       else
   153         Logic.varify_global (fold_rev lambda
   154           (map Logic.unvarify_global ivs1 @ filter_out is_unit
   155               (map (head_of o strip_abs_body) rec_fns)) r);
   156 
   157   in Extraction.add_realizers_i [(ind_name, (vs, r', prf))] thy' end;
   158 
   159 fun make_casedists ({index, descr, case_name, case_rewrites, exhaust, ...} : Old_Datatype_Aux.info) thy =
   160   let
   161     val ctxt = Proof_Context.init_global thy;
   162     val rT = TFree ("'P", @{sort type});
   163     val rT' = TVar (("'P", 0), @{sort type});
   164 
   165     fun make_casedist_prem T (cname, cargs) =
   166       let
   167         val Ts = map (Old_Datatype_Aux.typ_of_dtyp descr) cargs;
   168         val frees = Name.variant_list ["P", "y"] (Old_Datatype_Prop.make_tnames Ts) ~~ Ts;
   169         val free_ts = map Free frees;
   170         val r = Free ("r" ^ Long_Name.base_name cname, Ts ---> rT)
   171       in
   172         (r, fold_rev Logic.all free_ts
   173           (Logic.mk_implies (HOLogic.mk_Trueprop
   174             (HOLogic.mk_eq (Free ("y", T), list_comb (Const (cname, Ts ---> T), free_ts))),
   175               HOLogic.mk_Trueprop (Free ("P", rT --> HOLogic.boolT) $
   176                 list_comb (r, free_ts)))))
   177       end;
   178 
   179     val SOME (_, _, constrs) = AList.lookup (op =) descr index;
   180     val T = nth (Old_Datatype_Aux.get_rec_types descr) index;
   181     val (rs, prems) = split_list (map (make_casedist_prem T) constrs);
   182     val r = Const (case_name, map fastype_of rs ---> T --> rT);
   183 
   184     val y = Var (("y", 0), Logic.varifyT_global T);
   185     val y' = Free ("y", T);
   186 
   187     val thm =
   188       Goal.prove_internal ctxt (map (Thm.cterm_of ctxt) prems)
   189         (Thm.cterm_of ctxt
   190           (HOLogic.mk_Trueprop (Free ("P", rT --> HOLogic.boolT) $ list_comb (r, rs @ [y']))))
   191         (fn prems =>
   192            EVERY [
   193             rtac (cterm_instantiate [apply2 (Thm.cterm_of ctxt) (y, y')] exhaust) 1,
   194             ALLGOALS (EVERY'
   195               [asm_simp_tac (put_simpset HOL_basic_ss ctxt addsimps case_rewrites),
   196                resolve_tac ctxt prems, asm_simp_tac (put_simpset HOL_basic_ss ctxt)])])
   197       |> Drule.export_without_context;
   198 
   199     val exh_name = Thm.derivation_name exhaust;
   200     val (thm', thy') = thy
   201       |> Sign.root_path
   202       |> Global_Theory.store_thm (Binding.qualified_name (exh_name ^ "_P_correctness"), thm)
   203       ||> Sign.restore_naming thy;
   204 
   205     val P = Var (("P", 0), rT' --> HOLogic.boolT);
   206     val prf =
   207       Extraction.abs_corr_shyps thy' exhaust ["P"] [y, P]
   208         (fold_rev (fn (p, r) => fn prf =>
   209             Proofterm.forall_intr_proof' (Logic.varify_global r)
   210               (AbsP ("H", SOME (Logic.varify_global p), prf)))
   211           (prems ~~ rs)
   212           (Proofterm.proof_combP
   213             (Reconstruct.proof_of thm', map PBound (length prems - 1 downto 0))));
   214     val prf' =
   215       Extraction.abs_corr_shyps thy' exhaust []
   216         (map Var (Term.add_vars (Thm.prop_of exhaust) [])) (Reconstruct.proof_of exhaust);
   217     val r' =
   218       Logic.varify_global (Abs ("y", T,
   219         (fold_rev (Term.abs o dest_Free) rs
   220           (list_comb (r, map Bound ((length rs - 1 downto 0) @ [length rs]))))));
   221   in
   222     Extraction.add_realizers_i
   223       [(exh_name, (["P"], r', prf)),
   224        (exh_name, ([], Extraction.nullt, prf'))] thy'
   225   end;
   226 
   227 fun add_dt_realizers config names thy =
   228   if not (Proofterm.proofs_enabled ()) then thy
   229   else
   230     let
   231       val _ = Old_Datatype_Aux.message config "Adding realizers for induction and case analysis ...";
   232       val infos = map (BNF_LFP_Compat.the_info thy []) names;
   233       val info :: _ = infos;
   234     in
   235       thy
   236       |> fold_rev (perhaps o try o make_ind info) (subsets 0 (length (#descr info) - 1))
   237       |> fold_rev (perhaps o try o make_casedists) infos
   238     end;
   239 
   240 val _ = Theory.setup (BNF_LFP_Compat.interpretation @{plugin extraction} [] add_dt_realizers);
   241 
   242 end;