src/HOL/Tools/datatype_realizer.ML
changeset 31664 ee3c9e31e029
parent 31663 5eb82f064630
parent 31653 b013d4340a32
child 31673 6cc4c63cc990
child 31706 1db0c8f235fb
equal deleted inserted replaced
31663:5eb82f064630 31664:ee3c9e31e029
     1 (*  Title:      HOL/Tools/datatype_realizer.ML
       
     2     Author:     Stefan Berghofer, TU Muenchen
       
     3 
       
     4 Program extraction from proofs involving datatypes:
       
     5 Realizers for induction and case analysis
       
     6 *)
       
     7 
       
     8 signature DATATYPE_REALIZER =
       
     9 sig
       
    10   val add_dt_realizers: string list -> theory -> theory
       
    11   val setup: theory -> theory
       
    12 end;
       
    13 
       
    14 structure DatatypeRealizer : DATATYPE_REALIZER =
       
    15 struct
       
    16 
       
    17 open DatatypeAux;
       
    18 
       
    19 fun subsets i j = if i <= j then
       
    20        let val is = subsets (i+1) j
       
    21        in map (fn ks => i::ks) is @ is end
       
    22      else [[]];
       
    23 
       
    24 fun forall_intr_prf (t, prf) =
       
    25   let val (a, T) = (case t of Var ((a, _), T) => (a, T) | Free p => p)
       
    26   in Abst (a, SOME T, Proofterm.prf_abstract_over t prf) end;
       
    27 
       
    28 fun prf_of thm =
       
    29   Reconstruct.reconstruct_proof (Thm.theory_of_thm thm) (Thm.prop_of thm) (Thm.proof_of thm);
       
    30 
       
    31 fun prf_subst_vars inst =
       
    32   Proofterm.map_proof_terms (subst_vars ([], inst)) I;
       
    33 
       
    34 fun is_unit t = snd (strip_type (fastype_of t)) = HOLogic.unitT;
       
    35 
       
    36 fun tname_of (Type (s, _)) = s
       
    37   | tname_of _ = "";
       
    38 
       
    39 fun mk_realizes T = Const ("realizes", T --> HOLogic.boolT --> HOLogic.boolT);
       
    40 
       
(*Build and register a realizer for the (mutual) induction rule of a datatype.
  is: indices into descr of the predicates that get computational content;
      for the other indices the realizer component has type unit.
  Proves a correctness theorem, stores it (at the root path) under the name
  <ind_name>_<P names in is>_correctness, and registers (vs, r', prf) as a
  realizer for the induction rule via Extraction.add_realizers_i.*)
fun make_ind sorts ({descr, rec_names, rec_rewrites, induction, ...} : datatype_info) is thy =
  let
    val recTs = get_rec_types descr sorts;
    (*predicate names: "P" for a single datatype, "P1", "P2", ... otherwise*)
    val pnames = if length descr = 1 then ["P"]
      else map (fn i => "P" ^ string_of_int i) (1 upto length descr);

    (*realizer result type per datatype: a type variable 'P.. if its index
      is in is, unit otherwise*)
    val rec_result_Ts = map (fn ((i, _), P) =>
      if i mem is then TFree ("'" ^ P, HOLogic.typeS) else HOLogic.unitT)
        (descr ~~ pnames);

    (*predicate i applied to realizer r and value x; the realizer argument
      is omitted when i is not in is*)
    fun make_pred i T U r x =
      if i mem is then
        Free (List.nth (pnames, i), T --> U --> HOLogic.boolT) $ r $ x
      else Free (List.nth (pnames, i), U --> HOLogic.boolT) $ x;

    (*universally quantify t over (s, T) only when i is in is*)
    fun mk_all i s T t =
      if i mem is then list_all_free ([(s, T)], t) else t;

    (*one premise and one realizer function per constructor, over all
      datatypes in descr; j numbers the constructors globally, starting at 1*)
    val (prems, rec_fns) = split_list (flat (fst (fold_map
      (fn ((i, (_, _, constrs)), T) => fold_map (fn (cname, cargs) => fn j =>
        let
          val Ts = map (typ_of_dtyp descr sorts) cargs;
          val tnames = Name.variant_list pnames (DatatypeProp.make_tnames Ts);
          val recs = filter (is_rec_type o fst o fst) (cargs ~~ tnames ~~ Ts);
          val frees = tnames ~~ Ts;

          (*vs: the constructor's argument frees followed by the realizer
            variables r<s> of the recursive arguments processed so far.
            Returns the premise for this constructor and its realizer
            function f' (eta-contracted; the unit constant when i is not in is).*)
          fun mk_prems vs [] = 
                let
                  val rT = nth (rec_result_Ts) i;
                  (*unit-valued realizers are not passed to f*)
                  val vs' = filter_out is_unit vs;
                  val f = mk_Free "f" (map fastype_of vs' ---> rT) j;
                  val f' = Envir.eta_contract (list_abs_free
                    (map dest_Free vs, if i mem is then list_comb (f, vs')
                      else HOLogic.unit));
                in (HOLogic.mk_Trueprop (make_pred i rT T (list_comb (f, vs'))
                  (list_comb (Const (cname, Ts ---> T), map Free frees))), f')
                end
            | mk_prems vs (((dt, s), T) :: ds) = 
                let
                  val k = body_index dt;
                  val (Us, U) = strip_type T;
                  (*shadows the datatype index i: number of arguments Us*)
                  val i = length Us;
                  val rT = nth (rec_result_Ts) k;
                  val r = Free ("r" ^ s, Us ---> rT);
                  val (p, f) = mk_prems (vs @ [r]) ds
                in (mk_all k ("r" ^ s) (Us ---> rT) (Logic.mk_implies
                  (list_all (map (pair "x") Us, HOLogic.mk_Trueprop
                    (make_pred k rT U (app_bnds r i)
                      (app_bnds (Free (s, T)) i))), p)), f)
                end

        in (apfst (curry list_all_free frees) (mk_prems (map Free frees) recs), j + 1) end)
          constrs) (descr ~~ recTs) 1)));

    (*select the component for datatype j from a right-nested pair t whose
      components correspond, in order, to the indices in is*)
    fun mk_proj j [] t = t
      | mk_proj j (i :: is) t = if null is then t else
          if (j: int) = i then HOLogic.mk_fst t
          else mk_proj j is (HOLogic.mk_snd t);

    val tnames = DatatypeProp.make_tnames recTs;
    val fTs = map fastype_of rec_fns;
    (*instantiations for the induction predicates: for each datatype,
      %x. pred (recursion combinator applied to rec_fns and x) x*)
    val ps = map (fn ((((i, _), T), U), s) => Abs ("x", T, make_pred i U T
      (list_comb (Const (s, fTs ---> T --> U), rec_fns) $ Bound 0) (Bound 0)))
        (descr ~~ recTs ~~ rec_result_Ts ~~ rec_names);
    (*realizer body: tuple of recursion-combinator applications for the
      datatypes in is, or Extraction.nullt when is is empty*)
    val r = if null is then Extraction.nullt else
      foldr1 HOLogic.mk_prod (List.mapPartial (fn (((((i, _), T), U), s), tname) =>
        if i mem is then SOME
          (list_comb (Const (s, fTs ---> T --> U), rec_fns) $ Free (tname, T))
        else NONE) (descr ~~ recTs ~~ rec_result_Ts ~~ rec_names ~~ tnames));
    (*conjunction of the predicates, each applied to its projection of r*)
    val concl = HOLogic.mk_Trueprop (foldr1 (HOLogic.mk_binop "op &")
      (map (fn ((((i, _), T), U), tname) =>
        make_pred i U T (mk_proj i is r) (Free (tname, T)))
          (descr ~~ recTs ~~ rec_result_Ts ~~ tnames)));
    val cert = cterm_of thy;
    (*pair the head predicates of the induction rule's conclusion with ps*)
    val inst = map (pairself cert) (map head_of (HOLogic.dest_conj
      (HOLogic.dest_Trueprop (concl_of induction))) ~~ ps);

    (*prove  prems ==> concl  using the instantiated induction rule and the
      recursion equations*)
    val thm = OldGoals.simple_prove_goal_cterm (cert (Logic.list_implies (prems, concl)))
      (fn prems =>
         [rewrite_goals_tac (map mk_meta_eq [fst_conv, snd_conv]),
          rtac (cterm_instantiate inst induction) 1,
          ALLGOALS ObjectLogic.atomize_prems_tac,
          rewrite_goals_tac (@{thm o_def} :: map mk_meta_eq rec_rewrites),
          REPEAT ((resolve_tac prems THEN_ALL_NEW (fn i =>
            REPEAT (etac allE i) THEN atac i)) 1)]);

    val ind_name = Thm.get_name induction;
    val vs = map (fn i => List.nth (pnames, i)) is;
    (*store the correctness theorem at the root path, then restore naming*)
    val (thm', thy') = thy
      |> Sign.root_path
      |> PureThy.store_thm
        (Binding.qualified_name (space_implode "_" (ind_name :: vs @ ["correctness"])), thm)
      ||> Sign.restore_naming thy;

    val ivs = rev (Term.add_vars (Logic.varify (DatatypeProp.make_ind [descr] sorts)) []);
    val rvs = rev (Thm.fold_terms Term.add_vars thm' []);
    (*variables of the induction statement whose body type is not set/bool*)
    val ivs1 = map Var (filter_out (fn (_, T) =>
      tname_of (body_type T) mem ["set", "bool"]) ivs);
    (*the same variables as ivs, retyped as they occur in thm'*)
    val ivs2 = map (fn (ixn, _) => Var (ixn, valOf (AList.lookup (op =) rvs ixn))) ivs;

    (*correctness proof: proof of thm' applied to its hypotheses, abstracted
      over each premise (and over the head Free of its realizer function,
      if there is one), then over ivs2*)
    val prf = List.foldr forall_intr_prf
     (List.foldr (fn ((f, p), prf) =>
        (case head_of (strip_abs_body f) of
           Free (s, T) =>
             let val T' = Logic.varifyT T
             in Abst (s, SOME T', Proofterm.prf_abstract_over
               (Var ((s, 0), T')) (AbsP ("H", SOME p, prf)))
             end
         | _ => AbsP ("H", SOME p, prf)))
           (Proofterm.proof_combP
             (prf_of thm', map PBound (length prems - 1 downto 0))) (rec_fns ~~ prems_of thm)) ivs2;

    (*realizer: r abstracted over ivs1 and the non-unit realizer functions*)
    val r' = if null is then r else Logic.varify (List.foldr (uncurry lambda)
      r (map Logic.unvarify ivs1 @ filter_out is_unit
          (map (head_of o strip_abs_body) rec_fns)));

  in Extraction.add_realizers_i [(ind_name, (vs, r', prf))] thy' end;
       
   158 
       
   159 
       
(*Build and register two realizers for the case analysis (exhaustion) rule
  of the datatype at position index in descr:
  - one for predicate P, built from the case combinator case_name, justified
    by a correctness theorem stored as <exh_name>_P_correctness;
  - one without computational content (Extraction.nullt), whose proof is
    the proof of the exhaustion rule itself.*)
fun make_casedists sorts ({index, descr, case_name, case_rewrites, exhaustion, ...} : datatype_info) thy =
  let
    val cert = cterm_of thy;
    (*realizer result type, as a free and as a schematic type variable*)
    val rT = TFree ("'P", HOLogic.typeS);
    val rT' = TVar (("'P", 0), HOLogic.typeS);

    (*for constructor cname with argument types cargs: a realizer variable
      r<cname> and the premise   !!xs. y = cname xs ==> P (r<cname> xs)*)
    fun make_casedist_prem T (cname, cargs) =
      let
        val Ts = map (typ_of_dtyp descr sorts) cargs;
        val frees = Name.variant_list ["P", "y"] (DatatypeProp.make_tnames Ts) ~~ Ts;
        val free_ts = map Free frees;
        val r = Free ("r" ^ Long_Name.base_name cname, Ts ---> rT)
      in (r, list_all_free (frees, Logic.mk_implies (HOLogic.mk_Trueprop
        (HOLogic.mk_eq (Free ("y", T), list_comb (Const (cname, Ts ---> T), free_ts))),
          HOLogic.mk_Trueprop (Free ("P", rT --> HOLogic.boolT) $
            list_comb (r, free_ts)))))
      end;

    val SOME (_, _, constrs) = AList.lookup (op =) descr index;
    val T = List.nth (get_rec_types descr sorts, index);
    val (rs, prems) = split_list (map (make_casedist_prem T) constrs);
    val r = Const (case_name, map fastype_of rs ---> T --> rT);

    val y = Var (("y", 0), Logic.legacy_varifyT T);
    val y' = Free ("y", T);

    (*prove  prems ==> P (case_name rs y)  by the exhaustion rule on y, then in
      every goal: simplify with case_rewrites, resolve with a premise, simplify*)
    val thm = OldGoals.prove_goalw_cterm [] (cert (Logic.list_implies (prems,
      HOLogic.mk_Trueprop (Free ("P", rT --> HOLogic.boolT) $
        list_comb (r, rs @ [y'])))))
      (fn prems =>
         [rtac (cterm_instantiate [(cert y, cert y')] exhaustion) 1,
          ALLGOALS (EVERY'
            [asm_simp_tac (HOL_basic_ss addsimps case_rewrites),
             resolve_tac prems, asm_simp_tac HOL_basic_ss])]);

    val exh_name = Thm.get_name exhaustion;
    (*store the correctness theorem at the root path, then restore naming*)
    val (thm', thy') = thy
      |> Sign.root_path
      |> PureThy.store_thm (Binding.qualified_name (exh_name ^ "_P_correctness"), thm)
      ||> Sign.restore_naming thy;

    val P = Var (("P", 0), rT' --> HOLogic.boolT);
    (*correctness proof: proof of thm' applied to its hypotheses, abstracted
      over each premise and its realizer variable, then over P and y*)
    val prf = forall_intr_prf (y, forall_intr_prf (P,
      List.foldr (fn ((p, r), prf) =>
        forall_intr_prf (Logic.legacy_varify r, AbsP ("H", SOME (Logic.varify p),
          prf))) (Proofterm.proof_combP (prf_of thm',
            map PBound (length prems - 1 downto 0))) (prems ~~ rs)));
    (*realizer:  %y r1 ... rn. case_name r1 ... rn y*)
    val r' = Logic.legacy_varify (Abs ("y", Logic.legacy_varifyT T,
      list_abs (map dest_Free rs, list_comb (r,
        map Bound ((length rs - 1 downto 0) @ [length rs])))));

  in Extraction.add_realizers_i
    [(exh_name, (["P"], r', prf)),
     (exh_name, ([], Extraction.nullt, prf_of exhaustion))] thy'
  end;
       
   215 
       
   216 fun add_dt_realizers names thy =
       
   217   if ! Proofterm.proofs < 2 then thy
       
   218   else let
       
   219     val _ = message "Adding realizers for induction and case analysis ..."
       
   220     val infos = map (DatatypePackage.the_datatype thy) names;
       
   221     val info :: _ = infos;
       
   222   in
       
   223     thy
       
   224     |> fold_rev (make_ind (#sorts info) info) (subsets 0 (length (#descr info) - 1))
       
   225     |> fold_rev (make_casedists (#sorts info)) infos
       
   226   end;
       
   227 
       
(*Theory setup: register add_dt_realizers with DatatypePackage.interpretation.
  NOTE(review): kept as a val rather than an eta-expanded fun, so the partial
  application is evaluated exactly once, in case it has effects.*)
val setup = DatatypePackage.interpretation add_dt_realizers;
       
   229 
       
   230 end;