src/HOL/Tools/BNF/bnf_fp_rec_sugar_transfer.ML
changeset 59275 77cd4992edcd
child 59276 d207455817e8
equal deleted inserted replaced
59274:67afe7e6a516 59275:77cd4992edcd
       
     1 (*  Title:      HOL/Tools/BNF/bnf_fp_rec_sugar_transfer.ML
       
     2     Author:     Martin Desharnais, TU Muenchen
       
     3     Copyright   2014
       
     4 
       
     5 Parametricity of primitively (co)recursive functions.
       
     6 *)
       
     7 
       
     8 (* DO NOT FORGET TO DOCUMENT THIS NEW PLUGIN!!! *)
       
     9 
       
    signature BNF_FP_REC_SUGAR_TRANSFER =
    sig
      (* Plugin names for the primrec and primcorec transfer plugins
         (created with Plugin_Name.declare_setup by the implementing structure). *)
      val primrec_transfer_pluginN : string
      val primcorec_transfer_pluginN : string

      (* Given the sugar of a group of primitively (co)recursive functions, prove a
         parametricity theorem for each function whose transfer flag is set and note it,
         returning the extended context. *)
      val primrec_transfer_interpretation :
        BNF_FP_Rec_Sugar_Util.fp_rec_sugar -> Proof.context -> Proof.context
      val primcorec_transfer_interpretation :
        BNF_FP_Rec_Sugar_Util.fp_rec_sugar -> Proof.context -> Proof.context
    end;
       
    22 
       
    structure BNF_FP_Rec_Sugar_Transfer : BNF_FP_REC_SUGAR_TRANSFER =
    struct

    open BNF_Def
    open BNF_FP_Def_Sugar
    open BNF_FP_Rec_Sugar_Util
    open BNF_FP_Util
    open Ctr_Sugar_Tactics
    open Ctr_Sugar_Util

    val primrec_transfer_pluginN = Plugin_Name.declare_setup @{binding primrec_transfer};
    val primcorec_transfer_pluginN = Plugin_Name.declare_setup @{binding primcorec_transfer};

    (* Tactic for the parametricity goal of a primrec function: unfold the function's
       definition [def], then run the transfer prover on the first subgoal. *)
    fun mk_primrec_transfer_tac ctxt def =
      Ctr_Sugar_Tactics.unfold_thms_tac ctxt [def] THEN
      HEADGOAL (Transfer.transfer_prover_tac ctxt);

    (* Tactic for the parametricity goal of a primcorec function:
       1. unfold [f_def], [corec_def], split_beta and if_conn;
       2. simplify the first subgoal using only the instantiated [case_distribs],
          [disc_eq_cases], [cases], if_True and if_False, with [case_congs] added as
          congruence rules;
       3. if [apply_transfer] holds, run the transfer prover in a context extended with
          Abs_transfer instances (for the identity BNF's type definition and for each of
          [type_definitions]) and with [dtor_corec_transfers] after unfolding
          [rel_pre_defs]. *)
    fun mk_primcorec_transfer_tac apply_transfer ctxt f_def corec_def type_definitions
        dtor_corec_transfers rel_pre_defs disc_eq_cases cases case_distribs case_congs =
      let
        (* Expects [thm] to be an equation whose left-hand side is a schematic variable
           ?f applied to one argument; raises Bind otherwise.  Builds
           "%x. if x then ?t else ?e" at ?f's result type, with Var indices above the
           proposition's maxidx, and passes it as the first position of
           cterm_instantiate_pos (presumably instantiating ?f -- confirm). *)
        fun instantiate_with_lambda thm =
          let
            val prop = Thm.prop_of thm;
            val @{const Trueprop} $
              (Const (@{const_name HOL.eq}, _) $
                (Var (_, fT) $ _) $ _) = prop;
            val T = range_type fT;
            val idx = Term.maxidx_of_term prop + 1;
            val bool_expr = Var (("x", idx), HOLogic.boolT);
            val then_expr = Var (("t", idx), T);
            val else_expr = Var (("e", idx), T);
            val lambda = Term.lambda bool_expr (mk_If bool_expr then_expr else_expr);
          in
            cterm_instantiate_pos [SOME (certify ctxt lambda)] thm
          end;

        val transfer_rules =
          @{thm Abs_transfer[OF
            BNF_Composition.type_definition_id_bnf_UNIV
            BNF_Composition.type_definition_id_bnf_UNIV]} ::
          map (fn thm => @{thm Abs_transfer} OF [thm, thm]) type_definitions @
          map (Local_Defs.unfold ctxt rel_pre_defs) dtor_corec_transfers;
        val add_transfer_rule = Thm.attribute_declaration Transfer.transfer_add;
        (* [ctxt] extended with the rules above, for the final transfer_prover_tac call. *)
        val ctxt' = Context.proof_map (fold add_transfer_rule transfer_rules) ctxt;

        val case_distribs = map instantiate_with_lambda case_distribs;
        val simps = case_distribs @ disc_eq_cases @ cases @ @{thms if_True if_False};
        (* Simpset built from [simps] only, installed into [ctxt']. *)
        val simp_ctxt = put_simpset (simpset_of (ss_only simps ctxt)) ctxt';
      in
        unfold_thms_tac ctxt ([f_def, corec_def] @ @{thms split_beta if_conn}) THEN
        HEADGOAL (simp_tac (fold Simplifier.add_cong case_congs simp_ctxt)) THEN
        (if apply_transfer then HEADGOAL (Transfer.transfer_prover_tac ctxt') else all_tac)
      end;

    (* Turn (name, theorems, attribute function) triples into Local_Theory.notes input.
       Triples without theorems are dropped, each name is qualified by [base], and the
       i-th theorem gets the attributes [f_attrs i]. *)
    fun massage_simple_notes base =
      filter_out (null o #2)
      #> map (fn (thmN, thms, f_attrs) =>
        ((Binding.qualify true base (Binding.name thmN), []),
         map_index (fn (i, thm) => ([thm], f_attrs i)) thms));

    (* The fp_sugar registered for the type constructor of [bnf]'s type, if any.
       The case is exhaustive: a non-Type yields NONE instead of raising Match. *)
    fun fp_sugar_of_bnf ctxt bnf =
      (case T_of_bnf bnf of
        Type (s, _) => fp_sugar_of ctxt s
      | _ => NONE);

    (* Fold [f] over the BNFs occurring in [T].  A type constructor's BNF is visited
       before its argument types.  Descent stops, leaving the accumulator unchanged,
       at type constructors without a registered BNF and at non-Type leaves. *)
    fun bnf_depth_first_traverse ctxt f T z =
      (case T of
        Type (s, innerTs) =>
          (case bnf_of ctxt s of
            NONE => z
          | SOME bnf => fold (bnf_depth_first_traverse ctxt f) innerTs (f bnf z))
      | _ => z);

    (* [f bnfs] if every type in [Ts] is a type constructor with a registered BNF
       ([bnfs] listed in the order of [Ts]); [g] otherwise. *)
    fun if_all_bnfs ctxt Ts f g =
      let
        val bnfs = map_filter (fn Type (s, _) => BNF_Def.bnf_of ctxt s | _ => NONE) Ts;
      in
        if length bnfs = length Ts then f bnfs else g
      end;

    (* Parametricity goal for [f].  Each schematic type variable of [f]'s type is
       replaced by a fresh type free A_i in one copy and B_i in another (same sorts),
       and fresh variables R_i of type BNF_Util.mk_pred2T A_i B_i are created.  The goal
       is built by BNF_FP_Def_Sugar.mk_parametricity_goal.  Also returns the context
       that declares the fresh names, for exporting the proved theorem. *)
    fun mk_goal lthy f =
      let
        val schematicTs = Term.add_tvarsT (fastype_of f) [];
        val sorts = map snd schematicTs;

        val ((As, Bs), names_lthy) = lthy
          |> Ctr_Sugar_Util.mk_TFrees' sorts
          ||>> Ctr_Sugar_Util.mk_TFrees' sorts;

        val (Rs, names_lthy) =
          Ctr_Sugar_Util.mk_Frees "R" (map2 BNF_Util.mk_pred2T As Bs) names_lthy;

        val fA = Term.subst_TVars (map fst schematicTs ~~ As) f;
        val fB = Term.subst_TVars (map fst schematicTs ~~ Bs) f;
      in
        (BNF_FP_Def_Sugar.mk_parametricity_goal lthy Rs fA fB, names_lthy)
      end;

    (* For the n-th function of the sugar whose transfer flag is set, call
       [prove n bnfs f_name f def lthy], where [bnfs] are the BNFs of all of the sugar's
       [fpTs].  Fails with "Function is not parametric." if some fpT has no BNF.
       Functions whose flag is false leave the context unchanged. *)
    fun prove_parametricity_if_bnf prove {transfers, fun_names, funs, fun_defs, fpTs} lthy =
      fold_index (fn (n, (((transfer, f_name), f), def)) => fn lthy =>
          if not transfer then lthy
          else
            if_all_bnfs lthy fpTs
              (fn bnfs => fn () => prove n bnfs f_name f def lthy)
              (fn () => let val _ = error "Function is not parametric." in lthy end) ())
        (transfers ~~ fun_names ~~ funs ~~ fun_defs) lthy;

    (* Lift a parametricity prover to an fp_rec_sugar interpretation.  If
       [prove n bnfs f def lthy] fails (try returns NONE), raise
       "Failed to prove parametricity.".  Otherwise note the theorem as
       "<f_name>.transfer" with the [transfer_rule] attribute. *)
    fun prim_co_rec_transfer_interpretation prove =
      prove_parametricity_if_bnf (fn n => fn bnfs => fn f_name => fn f => fn def => fn lthy =>
        (case try (prove n bnfs f def) lthy of
          NONE => error "Failed to prove parametricity."
        | SOME thm =>
          let
            val notes =
              [("transfer", [thm], K @{attributes [transfer_rule]})]
              |> massage_simple_notes f_name;
          in
            snd (Local_Theory.notes notes lthy)
          end));

    (* primrec: prove the parametricity goal by unfolding [def] and running the transfer
       prover, then export out of the fresh-name context. *)
    val primrec_transfer_interpretation = prim_co_rec_transfer_interpretation
      (fn n => fn bnfs => fn f => fn def => fn lthy =>
         let
           val (goal, names_lthy) = mk_goal lthy f;
         in
           Goal.prove lthy [] [] goal (fn {context = ctxt, prems = _} =>
             mk_primrec_transfer_tac ctxt def)
           |> singleton (Proof_Context.export names_lthy lthy)
           |> Thm.close_derivation
         end);

    (* primcorec: look up the fp_sugars of the BNFs, collect the constructor-sugar
       theorems of every BNF in [f]'s type, then prove the goal with
       mk_primcorec_transfer_tac and export out of the fresh-name context. *)
    val primcorec_transfer_interpretation = prim_co_rec_transfer_interpretation
      (fn n => fn bnfs => fn f => fn def => fn lthy =>
         let
           val fp_sugars = map (the o fp_sugar_of_bnf lthy) bnfs;
           val (goal, names_lthy) = mk_goal lthy f;

           (* Duplicate-free (w.r.t. Thm.eq_thm) union of the disc_eq_cases, case_thms
              and case_distribs, plus the case_cong, of every BNF in [f]'s type that
              has an fp_sugar. *)
           val (disc_eq_cases, case_thms, case_distribs, case_congs) =
             bnf_depth_first_traverse lthy (fn bnf => fn acc =>
               let
                 fun add_thms (xs, ys, zs, ws) (fp_sugar : fp_sugar) =
                   let
                     val ctr_sugar = #ctr_sugar (#fp_ctr_sugar fp_sugar);
                     val union' = union Thm.eq_thm;
                     val insert' = insert Thm.eq_thm;
                   in
                     (union' (#disc_eq_cases ctr_sugar) xs,
                      union' (#case_thms ctr_sugar) ys,
                      union' (#case_distribs ctr_sugar) zs,
                      insert' (#case_cong ctr_sugar) ws)
                   end;
               in
                 (case fp_sugar_of_bnf lthy bnf of
                   NONE => acc
                 | SOME fp_sugar => add_thms acc fp_sugar)
               end) (fastype_of f) ([], [], [], []);
         in
           Goal.prove lthy [] [] goal (fn {context = ctxt, prems = _} =>
             mk_primcorec_transfer_tac true ctxt def
             (* NOTE(review): [n] is this function's index in the sugar's function list
                but indexes [fp_sugars] (one per fpT); this assumes the two lists line
                up -- confirm. *)
             (#co_rec_def (#fp_co_induct_sugar (nth fp_sugars n)))
             (map (#type_definition o #absT_info) fp_sugars)
             (maps (#xtor_co_rec_transfers o #fp_res) fp_sugars)
             (map (rel_def_of_bnf o #pre_bnf) fp_sugars)
             disc_eq_cases case_thms case_distribs case_congs)
           |> singleton (Proof_Context.export names_lthy lthy)
           |> Thm.close_derivation
         end);

    end