src/HOL/Tools/datatype_realizer.ML
changeset 58275 280ede57a6a9
parent 58274 4a84e94e58a2
child 58277 0dcd3a623a6e
     1.1 --- /dev/null	Thu Jan 01 00:00:00 1970 +0000
     1.2 +++ b/src/HOL/Tools/datatype_realizer.ML	Tue Sep 09 20:51:36 2014 +0200
     1.3 @@ -0,0 +1,250 @@
     1.4 +(*  Title:      HOL/Tools/datatype_realizer.ML
     1.5 +    Author:     Stefan Berghofer, TU Muenchen
     1.6 +
     1.7 +Program extraction from proofs involving datatypes:
     1.8 +realizers for induction and case analysis.
     1.9 +*)
    1.10 +
    1.11 +signature DATATYPE_REALIZER =
    1.12 +sig
    1.13 +  val realizer_plugin: string
    1.14 +  val add_dt_realizers: Old_Datatype_Aux.config -> string list -> theory -> theory
    1.15 +end;
    1.16 +
    1.17 +structure Datatype_Realizer : DATATYPE_REALIZER =
    1.18 +struct
    1.19 +
    1.20 +val realizer_plugin = "realizer";
    1.21 +
    1.22 +fun subsets i j =
    1.23 +  if i <= j then
    1.24 +    let val is = subsets (i+1) j
    1.25 +    in map (fn ks => i::ks) is @ is end
    1.26 +  else [[]];
    1.27 +
    1.28 +fun is_unit t = body_type (fastype_of t) = HOLogic.unitT;
    1.29 +
    1.30 +fun tname_of (Type (s, _)) = s
    1.31 +  | tname_of _ = "";
    1.32 +
    1.33 +fun make_ind ({descr, rec_names, rec_rewrites, induct, ...} : Old_Datatype_Aux.info) is thy =
    1.34 +  let
    1.35 +    val ctxt = Proof_Context.init_global thy;
    1.36 +    val cert = cterm_of thy;
    1.37 +
    1.38 +    val recTs = Old_Datatype_Aux.get_rec_types descr;
    1.39 +    val pnames =
    1.40 +      if length descr = 1 then ["P"]
    1.41 +      else map (fn i => "P" ^ string_of_int i) (1 upto length descr);
    1.42 +
    1.43 +    val rec_result_Ts = map (fn ((i, _), P) =>
    1.44 +        if member (op =) is i then TFree ("'" ^ P, @{sort type}) else HOLogic.unitT)
    1.45 +      (descr ~~ pnames);
    1.46 +
    1.47 +    fun make_pred i T U r x =
    1.48 +      if member (op =) is i then
    1.49 +        Free (nth pnames i, T --> U --> HOLogic.boolT) $ r $ x
    1.50 +      else Free (nth pnames i, U --> HOLogic.boolT) $ x;
    1.51 +
    1.52 +    fun mk_all i s T t =
    1.53 +      if member (op =) is i then Logic.all (Free (s, T)) t else t;
    1.54 +
    1.55 +    val (prems, rec_fns) = split_list (flat (fst (fold_map
    1.56 +      (fn ((i, (_, _, constrs)), T) => fold_map (fn (cname, cargs) => fn j =>
    1.57 +        let
    1.58 +          val Ts = map (Old_Datatype_Aux.typ_of_dtyp descr) cargs;
    1.59 +          val tnames = Name.variant_list pnames (Old_Datatype_Prop.make_tnames Ts);
    1.60 +          val recs = filter (Old_Datatype_Aux.is_rec_type o fst o fst) (cargs ~~ tnames ~~ Ts);
    1.61 +          val frees = tnames ~~ Ts;
    1.62 +
    1.63 +          fun mk_prems vs [] =
    1.64 +                let
    1.65 +                  val rT = nth (rec_result_Ts) i;
    1.66 +                  val vs' = filter_out is_unit vs;
    1.67 +                  val f = Old_Datatype_Aux.mk_Free "f" (map fastype_of vs' ---> rT) j;
    1.68 +                  val f' =
    1.69 +                    Envir.eta_contract (fold_rev (absfree o dest_Free) vs
    1.70 +                      (if member (op =) is i then list_comb (f, vs') else HOLogic.unit));
    1.71 +                in
    1.72 +                  (HOLogic.mk_Trueprop (make_pred i rT T (list_comb (f, vs'))
    1.73 +                    (list_comb (Const (cname, Ts ---> T), map Free frees))), f')
    1.74 +                end
    1.75 +            | mk_prems vs (((dt, s), T) :: ds) =
    1.76 +                let
    1.77 +                  val k = Old_Datatype_Aux.body_index dt;
    1.78 +                  val (Us, U) = strip_type T;
    1.79 +                  val i = length Us;
    1.80 +                  val rT = nth (rec_result_Ts) k;
    1.81 +                  val r = Free ("r" ^ s, Us ---> rT);
    1.82 +                  val (p, f) = mk_prems (vs @ [r]) ds;
    1.83 +                in
    1.84 +                  (mk_all k ("r" ^ s) (Us ---> rT) (Logic.mk_implies
    1.85 +                    (Logic.list_all (map (pair "x") Us, HOLogic.mk_Trueprop
    1.86 +                      (make_pred k rT U (Old_Datatype_Aux.app_bnds r i)
    1.87 +                        (Old_Datatype_Aux.app_bnds (Free (s, T)) i))), p)), f)
    1.88 +                end;
    1.89 +        in (apfst (fold_rev (Logic.all o Free) frees) (mk_prems (map Free frees) recs), j + 1) end)
    1.90 +          constrs) (descr ~~ recTs) 1)));
    1.91 +
    1.92 +    fun mk_proj _ [] t = t
    1.93 +      | mk_proj j (i :: is) t =
    1.94 +          if null is then t
    1.95 +          else if (j: int) = i then HOLogic.mk_fst t
    1.96 +          else mk_proj j is (HOLogic.mk_snd t);
    1.97 +
    1.98 +    val tnames = Old_Datatype_Prop.make_tnames recTs;
    1.99 +    val fTs = map fastype_of rec_fns;
   1.100 +    val ps = map (fn ((((i, _), T), U), s) => Abs ("x", T, make_pred i U T
   1.101 +      (list_comb (Const (s, fTs ---> T --> U), rec_fns) $ Bound 0) (Bound 0)))
   1.102 +        (descr ~~ recTs ~~ rec_result_Ts ~~ rec_names);
   1.103 +    val r =
   1.104 +      if null is then Extraction.nullt
   1.105 +      else
   1.106 +        foldr1 HOLogic.mk_prod (map_filter (fn (((((i, _), T), U), s), tname) =>
   1.107 +          if member (op =) is i then SOME
   1.108 +            (list_comb (Const (s, fTs ---> T --> U), rec_fns) $ Free (tname, T))
   1.109 +          else NONE) (descr ~~ recTs ~~ rec_result_Ts ~~ rec_names ~~ tnames));
   1.110 +    val concl =
   1.111 +      HOLogic.mk_Trueprop (foldr1 (HOLogic.mk_binop @{const_name HOL.conj})
   1.112 +        (map (fn ((((i, _), T), U), tname) =>
   1.113 +          make_pred i U T (mk_proj i is r) (Free (tname, T)))
   1.114 +            (descr ~~ recTs ~~ rec_result_Ts ~~ tnames)));
   1.115 +    val inst = map (pairself cert) (map head_of (HOLogic.dest_conj
   1.116 +      (HOLogic.dest_Trueprop (concl_of induct))) ~~ ps);
   1.117 +
   1.118 +    val thm =
   1.119 +      Goal.prove_internal ctxt (map cert prems) (cert concl)
   1.120 +        (fn prems =>
   1.121 +           EVERY [
   1.122 +            rewrite_goals_tac ctxt (map mk_meta_eq [@{thm fst_conv}, @{thm snd_conv}]),
   1.123 +            rtac (cterm_instantiate inst induct) 1,
   1.124 +            ALLGOALS (Object_Logic.atomize_prems_tac ctxt),
   1.125 +            rewrite_goals_tac ctxt (@{thm o_def} :: map mk_meta_eq rec_rewrites),
   1.126 +            REPEAT ((resolve_tac prems THEN_ALL_NEW (fn i =>
   1.127 +              REPEAT (etac allE i) THEN atac i)) 1)])
   1.128 +      |> Drule.export_without_context;
   1.129 +
   1.130 +    val ind_name = Thm.derivation_name induct;
   1.131 +    val vs = map (nth pnames) is;
   1.132 +    val (thm', thy') = thy
   1.133 +      |> Sign.root_path
   1.134 +      |> Global_Theory.store_thm
   1.135 +        (Binding.qualified_name (space_implode "_" (ind_name :: vs @ ["correctness"])), thm)
   1.136 +      ||> Sign.restore_naming thy;
   1.137 +
   1.138 +    val ivs = rev (Term.add_vars (Logic.varify_global (Old_Datatype_Prop.make_ind [descr])) []);
   1.139 +    val rvs = rev (Thm.fold_terms Term.add_vars thm' []);
   1.140 +    val ivs1 = map Var (filter_out (fn (_, T) => @{type_name bool} = tname_of (body_type T)) ivs);
   1.141 +    val ivs2 = map (fn (ixn, _) => Var (ixn, the (AList.lookup (op =) rvs ixn))) ivs;
   1.142 +
   1.143 +    val prf =
   1.144 +      Extraction.abs_corr_shyps thy' induct vs ivs2
   1.145 +        (fold_rev (fn (f, p) => fn prf =>
   1.146 +            (case head_of (strip_abs_body f) of
   1.147 +              Free (s, T) =>
   1.148 +                let val T' = Logic.varifyT_global T in
   1.149 +                  Abst (s, SOME T', Proofterm.prf_abstract_over
   1.150 +                    (Var ((s, 0), T')) (AbsP ("H", SOME p, prf)))
   1.151 +                end
   1.152 +            | _ => AbsP ("H", SOME p, prf)))
   1.153 +          (rec_fns ~~ prems_of thm)
   1.154 +          (Proofterm.proof_combP
   1.155 +            (Reconstruct.proof_of thm', map PBound (length prems - 1 downto 0))));
   1.156 +
   1.157 +    val r' =
   1.158 +      if null is then r
   1.159 +      else
   1.160 +        Logic.varify_global (fold_rev lambda
   1.161 +          (map Logic.unvarify_global ivs1 @ filter_out is_unit
   1.162 +              (map (head_of o strip_abs_body) rec_fns)) r);
   1.163 +
   1.164 +  in Extraction.add_realizers_i [(ind_name, (vs, r', prf))] thy' end
   1.165 +  (* Nested new-style datatypes are not supported (unless they are registered via
   1.166 +     "datatype_compat"). *)
   1.167 +  handle Old_Datatype_Aux.Datatype => thy;
   1.168 +
   1.169 +fun make_casedists ({index, descr, case_name, case_rewrites, exhaust, ...} : Old_Datatype_Aux.info) thy =
   1.170 +  let
   1.171 +    val ctxt = Proof_Context.init_global thy;
   1.172 +    val cert = cterm_of thy;
   1.173 +    val rT = TFree ("'P", @{sort type});
   1.174 +    val rT' = TVar (("'P", 0), @{sort type});
   1.175 +
   1.176 +    fun make_casedist_prem T (cname, cargs) =
   1.177 +      let
   1.178 +        val Ts = map (Old_Datatype_Aux.typ_of_dtyp descr) cargs;
   1.179 +        val frees = Name.variant_list ["P", "y"] (Old_Datatype_Prop.make_tnames Ts) ~~ Ts;
   1.180 +        val free_ts = map Free frees;
   1.181 +        val r = Free ("r" ^ Long_Name.base_name cname, Ts ---> rT)
   1.182 +      in
   1.183 +        (r, fold_rev Logic.all free_ts
   1.184 +          (Logic.mk_implies (HOLogic.mk_Trueprop
   1.185 +            (HOLogic.mk_eq (Free ("y", T), list_comb (Const (cname, Ts ---> T), free_ts))),
   1.186 +              HOLogic.mk_Trueprop (Free ("P", rT --> HOLogic.boolT) $
   1.187 +                list_comb (r, free_ts)))))
   1.188 +      end;
   1.189 +
   1.190 +    val SOME (_, _, constrs) = AList.lookup (op =) descr index;
   1.191 +    val T = nth (Old_Datatype_Aux.get_rec_types descr) index;
   1.192 +    val (rs, prems) = split_list (map (make_casedist_prem T) constrs);
   1.193 +    val r = Const (case_name, map fastype_of rs ---> T --> rT);
   1.194 +
   1.195 +    val y = Var (("y", 0), Logic.varifyT_global T);
   1.196 +    val y' = Free ("y", T);
   1.197 +
   1.198 +    val thm =
   1.199 +      Goal.prove_internal ctxt (map cert prems)
   1.200 +        (cert (HOLogic.mk_Trueprop (Free ("P", rT --> HOLogic.boolT) $ list_comb (r, rs @ [y']))))
   1.201 +        (fn prems =>
   1.202 +           EVERY [
   1.203 +            rtac (cterm_instantiate [(cert y, cert y')] exhaust) 1,
   1.204 +            ALLGOALS (EVERY'
   1.205 +              [asm_simp_tac (put_simpset HOL_basic_ss ctxt addsimps case_rewrites),
   1.206 +               resolve_tac prems, asm_simp_tac (put_simpset HOL_basic_ss ctxt)])])
   1.207 +      |> Drule.export_without_context;
   1.208 +
   1.209 +    val exh_name = Thm.derivation_name exhaust;
   1.210 +    val (thm', thy') = thy
   1.211 +      |> Sign.root_path
   1.212 +      |> Global_Theory.store_thm (Binding.qualified_name (exh_name ^ "_P_correctness"), thm)
   1.213 +      ||> Sign.restore_naming thy;
   1.214 +
   1.215 +    val P = Var (("P", 0), rT' --> HOLogic.boolT);
   1.216 +    val prf =
   1.217 +      Extraction.abs_corr_shyps thy' exhaust ["P"] [y, P]
   1.218 +        (fold_rev (fn (p, r) => fn prf =>
   1.219 +            Proofterm.forall_intr_proof' (Logic.varify_global r)
   1.220 +              (AbsP ("H", SOME (Logic.varify_global p), prf)))
   1.221 +          (prems ~~ rs)
   1.222 +          (Proofterm.proof_combP
   1.223 +            (Reconstruct.proof_of thm', map PBound (length prems - 1 downto 0))));
   1.224 +    val prf' =
   1.225 +      Extraction.abs_corr_shyps thy' exhaust []
   1.226 +        (map Var (Term.add_vars (prop_of exhaust) [])) (Reconstruct.proof_of exhaust);
   1.227 +    val r' =
   1.228 +      Logic.varify_global (Abs ("y", T,
   1.229 +        (fold_rev (Term.abs o dest_Free) rs
   1.230 +          (list_comb (r, map Bound ((length rs - 1 downto 0) @ [length rs]))))));
   1.231 +  in
   1.232 +    Extraction.add_realizers_i
   1.233 +      [(exh_name, (["P"], r', prf)),
   1.234 +       (exh_name, ([], Extraction.nullt, prf'))] thy'
   1.235 +  end;
   1.236 +
   1.237 +fun add_dt_realizers config names thy =
   1.238 +  if not (Proofterm.proofs_enabled ()) then thy
   1.239 +  else
   1.240 +    let
   1.241 +      val _ = Old_Datatype_Aux.message config "Adding realizers for induction and case analysis ...";
   1.242 +      val infos = map (BNF_LFP_Compat.the_info thy BNF_LFP_Compat.Unfold_Nesting) names;
   1.243 +      val info :: _ = infos;
   1.244 +    in
   1.245 +      thy
   1.246 +      |> fold_rev (make_ind info) (subsets 0 (length (#descr info) - 1))
   1.247 +      |> fold_rev make_casedists infos
   1.248 +    end;
   1.249 +
   1.250 +val _ = Theory.setup (BNF_LFP_Compat.interpretation realizer_plugin BNF_LFP_Compat.Unfold_Nesting
   1.251 +  add_dt_realizers);
   1.252 +
   1.253 +end;