src/HOL/Codatatype/Tools/bnf_fp_sugar.ML
author blanchet
Tue Sep 04 16:17:22 2012 +0200 (2012-09-04)
changeset 49126 1bbd7a37fc29
parent 49125 5fc5211cf104
child 49127 f7326a0d7f19
permissions -rw-r--r--
implemented "mk_inject_tac"
     1 (*  Title:      HOL/Codatatype/Tools/bnf_fp_sugar.ML
     2     Author:     Jasmin Blanchette, TU Muenchen
     3     Copyright   2012
     4 
     5 Sugar for constructing LFPs and GFPs.
     6 *)
     7 
     8 signature BNF_FP_SUGAR =
     9 sig
    10 end;
    11 
    12 structure BNF_FP_Sugar : BNF_FP_SUGAR =
    13 struct
    14 
    15 open BNF_Util
    16 open BNF_Wrap
    17 open BNF_FP_Util
    18 open BNF_LFP
    19 open BNF_GFP
    20 open BNF_FP_Sugar_Tactics
    21 
    22 fun cannot_merge_types () = error "Mutually recursive types must have the same type parameters";
    23 
    24 fun merge_type_arg_constrained ctxt (T, c) (T', c') =
    25   if T = T' then
    26     (case (c, c') of
    27       (_, NONE) => (T, c)
    28     | (NONE, _) => (T, c')
    29     | _ =>
    30       if c = c' then
    31         (T, c)
    32       else
    33         error ("Inconsistent sort constraints for type variable " ^
    34           quote (Syntax.string_of_typ ctxt T)))
    35   else
    36     cannot_merge_types ();
    37 
    38 fun merge_type_args_constrained ctxt (cAs, cAs') =
    39   if length cAs = length cAs' then map2 (merge_type_arg_constrained ctxt) cAs cAs'
    40   else cannot_merge_types ();
    41 
    42 fun type_args_constrained_of (((cAs, _), _), _) = cAs;
    43 val type_args_of = map fst o type_args_constrained_of;
    44 fun type_name_of (((_, b), _), _) = b;
    45 fun mixfix_of_typ ((_, mx), _) = mx;
    46 fun ctr_specs_of (_, ctr_specs) = ctr_specs;
    47 
    48 fun disc_of (((disc, _), _), _) = disc;
    49 fun ctr_of (((_, ctr), _), _) = ctr;
    50 fun args_of ((_, args), _) = args;
    51 fun mixfix_of_ctr (_, mx) = mx;
    52 
    53 val lfp_info = bnf_lfp;
    54 val gfp_info = bnf_gfp;
    55 
    56 fun prepare_data prepare_typ construct specs fake_lthy lthy =
    57   let
    58     val constrained_As =
    59       map (map (apfst (prepare_typ fake_lthy)) o type_args_constrained_of) specs
    60       |> Library.foldr1 (merge_type_args_constrained lthy);
    61     val As = map fst constrained_As;
    62 
    63     val _ = (case duplicates (op =) As of [] => ()
    64       | T :: _ => error ("Duplicate type parameter " ^ quote (Syntax.string_of_typ lthy T)));
    65 
    66     (* TODO: check that no type variables occur in the rhss that's not in the lhss *)
    67     (* TODO: use sort constraints on type args *)
    68 
    69     val N = length specs;
    70 
    71     fun mk_T b =
    72       Type (fst (Term.dest_Type (Proof_Context.read_type_name fake_lthy true (Binding.name_of b))),
    73         As);
    74 
    75     val bs = map type_name_of specs;
    76     val Ts = map mk_T bs;
    77 
    78     val mixfixes = map mixfix_of_typ specs;
    79 
    80     val _ = (case duplicates Binding.eq_name bs of [] => ()
    81       | b :: _ => error ("Duplicate type name declaration " ^ quote (Binding.name_of b)));
    82 
    83     val ctr_specss = map ctr_specs_of specs;
    84 
    85     val disc_namess = map (map disc_of) ctr_specss;
    86     val ctr_namess = map (map ctr_of) ctr_specss;
    87     val ctr_argsss = map (map args_of) ctr_specss;
    88     val ctr_mixfixess = map (map mixfix_of_ctr) ctr_specss;
    89 
    90     val sel_namesss = map (map (map fst)) ctr_argsss;
    91     val ctr_Tsss = map (map (map (prepare_typ fake_lthy o snd))) ctr_argsss;
    92 
    93     val (Bs, C) =
    94       lthy
    95       |> fold (fold (fn s => Variable.declare_typ (TFree (s, dummyS))) o type_args_of) specs
    96       |> mk_TFrees N
    97       ||> the_single o fst o mk_TFrees 1;
    98 
    99     fun freeze_rec (T as Type (s, Ts')) =
   100         (case find_index (curry (op =) T) Ts of
   101           ~1 => Type (s, map freeze_rec Ts')
   102         | i => nth Bs i)
   103       | freeze_rec T = T;
   104 
   105     val ctr_TsssBs = map (map (map freeze_rec)) ctr_Tsss;
   106     val sum_prod_TsBs = map (mk_sumTN o map HOLogic.mk_tupleT) ctr_TsssBs;
   107 
   108     val eqs = map dest_TFree Bs ~~ sum_prod_TsBs;
   109 
   110     val ((raw_unfs, raw_flds, unf_flds, fld_unfs, fld_injects), lthy') =
   111       fp_bnf construct bs eqs lthy;
   112 
   113     fun mk_unf_or_fld get_foldedT Ts t =
   114       let val Type (_, Ts0) = get_foldedT (fastype_of t) in
   115         Term.subst_atomic_types (Ts0 ~~ Ts) t
   116       end;
   117 
   118     val mk_unf = mk_unf_or_fld domain_type;
   119     val mk_fld = mk_unf_or_fld range_type;
   120 
   121     val unfs = map (mk_unf As) raw_unfs;
   122     val flds = map (mk_fld As) raw_flds;
   123 
   124     fun wrap_type (((((((((T, fld), unf), fld_unf), unf_fld), fld_inject), ctr_names), ctr_Tss),
   125         disc_names), sel_namess) no_defs_lthy =
   126       let
   127         val n = length ctr_names;
   128         val ks = 1 upto n;
   129         val ms = map length ctr_Tss;
   130 
   131         val unf_T = domain_type (fastype_of fld);
   132 
   133         val prod_Ts = map HOLogic.mk_tupleT ctr_Tss;
   134 
   135         val (((u, v), xss), _) =
   136           lthy
   137           |> yield_singleton (mk_Frees "u") unf_T
   138           ||>> yield_singleton (mk_Frees "v") T
   139           ||>> mk_Freess "x" ctr_Tss;
   140 
   141         val rhss =
   142           map2 (fn k => fn xs =>
   143             fold_rev Term.lambda xs (fld $ mk_InN prod_Ts (HOLogic.mk_tuple xs) k)) ks xss;
   144 
   145         val ((raw_ctrs, raw_ctr_defs), (lthy', lthy)) = no_defs_lthy
   146           |> apfst split_list o fold_map2 (fn b => fn rhs =>
   147                Local_Theory.define ((b, NoSyn), ((Thm.def_binding b, []), rhs)) #>> apsnd snd)
   148              ctr_names rhss
   149           ||> `Local_Theory.restore;
   150 
   151         val raw_caseof =
   152           Const (@{const_name undefined}, map (fn Ts => Ts ---> C) ctr_Tss ---> T --> C);
   153 
   154         (*transforms defined frees into consts (and more)*)
   155         val phi = Proof_Context.export_morphism lthy lthy';
   156 
   157         val ctr_defs = map (Morphism.thm phi) raw_ctr_defs;
   158         val ctrs = map (Morphism.term phi) raw_ctrs;
   159         val caseof = Morphism.term phi raw_caseof;
   160 
   161         val fld_iff_unf_thm =
   162           let
   163             val goal =
   164               fold_rev Logic.all [u, v]
   165                 (mk_Trueprop_eq (HOLogic.mk_eq (v, fld $ u), HOLogic.mk_eq (unf $ v, u)));
   166           in
   167             Skip_Proof.prove lthy [] [] goal (fn {context = ctxt, ...} =>
   168               mk_fld_iff_unf_tac ctxt (map (SOME o certifyT lthy) [unf_T, T]) (certify lthy fld)
   169                 (certify lthy unf) fld_unf unf_fld)
   170             |> Thm.close_derivation
   171           end;
   172 
   173         val sumEN_thm = mk_sumEN n;
   174         val sumEN_thm' =
   175           let val cTs = map (SOME o certifyT lthy) prod_Ts in
   176             Local_Defs.unfold lthy @{thms all_unit_eq} (Drule.instantiate' cTs [] sumEN_thm)
   177           end;
   178 
   179         fun exhaust_tac {context = ctxt, ...} =
   180           mk_exhaust_tac ctxt n ms ctr_defs fld_iff_unf_thm sumEN_thm';
   181 
   182         val inject_tacss =
   183           map2 (fn 0 => K []
   184                  | _ => fn ctr_def => [fn {context = ctxt, ...} =>
   185                      mk_inject_tac ctxt ctr_def fld_inject])
   186             ms ctr_defs;
   187 
   188         (*###*)
   189         fun cheat_tac {context = ctxt, ...} = Skip_Proof.cheat_tac (Proof_Context.theory_of ctxt);
   190 
   191         val half_distinct_tacss = map (map (K cheat_tac)) (mk_half_pairss ks);
   192 
   193         val case_tacs = map (K cheat_tac) ks;
   194 
   195         val tacss = [exhaust_tac] :: inject_tacss @ half_distinct_tacss @ [case_tacs];
   196       in
   197         wrap_data tacss ((ctrs, caseof), (disc_names, sel_namess)) lthy'
   198       end;
   199   in
   200     lthy'
   201     |> fold wrap_type (Ts ~~ flds ~~ unfs ~~ fld_unfs ~~ unf_flds ~~ fld_injects ~~ ctr_namess ~~
   202       ctr_Tsss ~~ disc_namess ~~ sel_namesss)
   203   end;
   204 
   205 fun data_cmd info specs lthy =
   206   let
   207     val fake_lthy =
   208       Proof_Context.theory_of lthy
   209       |> Theory.copy
   210       |> Sign.add_types_global (map (fn spec =>
   211         (type_name_of spec, length (type_args_constrained_of spec), mixfix_of_typ spec)) specs)
   212       |> Proof_Context.init_global
   213   in
   214     prepare_data Syntax.read_typ info specs fake_lthy lthy
   215   end;
   216 
   217 val parse_opt_binding_colon = Scan.optional (Parse.binding --| Parse.$$$ ":") no_name
   218 
   219 val parse_ctr_arg =
   220   Parse.$$$ "(" |-- parse_opt_binding_colon -- Parse.typ --| Parse.$$$ ")" ||
   221   (Parse.typ >> pair no_name);
   222 
   223 val parse_single_spec =
   224   Parse.type_args_constrained -- Parse.binding -- Parse.opt_mixfix --
   225   (@{keyword "="} |-- Parse.enum1 "|" (parse_opt_binding_colon -- Parse.binding --
   226     Scan.repeat parse_ctr_arg -- Parse.opt_mixfix));
   227 
   228 val _ =
   229   Outer_Syntax.local_theory @{command_spec "data"} "define BNF-based inductive datatypes"
   230     (Parse.and_list1 parse_single_spec >> data_cmd lfp_info);
   231 
   232 val _ =
   233   Outer_Syntax.local_theory @{command_spec "codata"} "define BNF-based coinductive datatypes"
   234     (Parse.and_list1 parse_single_spec >> data_cmd gfp_info);
   235 
   236 end;