src/HOL/Codatatype/Tools/bnf_wrap.ML
changeset 49280 52413dc96326
parent 49278 718e4ad1517e
child 49281 3d87f4fd0d50
equal deleted inserted replaced
49279:2fcfc11374ed 49280:52413dc96326
     9 sig
     9 sig
    10   val no_binder: binding
    10   val no_binder: binding
    11   val mk_half_pairss: 'a list -> ('a * 'a) list list
    11   val mk_half_pairss: 'a list -> ('a * 'a) list list
    12   val mk_ctr: typ list -> term -> term
    12   val mk_ctr: typ list -> term -> term
    13   val wrap_datatype: ({prems: thm list, context: Proof.context} -> tactic) list list ->
    13   val wrap_datatype: ({prems: thm list, context: Proof.context} -> tactic) list list ->
    14     ((bool * term list) * term) * (binding list * binding list list) -> local_theory ->
    14     ((bool * term list) * term) *
       
    15       (binding list * (binding list list * (binding * term) list list)) -> local_theory ->
    15     (term list list * thm list * thm list list) * local_theory
    16     (term list list * thm list * thm list list) * local_theory
    16   val parse_wrap_options: bool parser
    17   val parse_wrap_options: bool parser
    17 end;
    18 end;
    18 
    19 
    19 structure BNF_Wrap : BNF_WRAP =
    20 structure BNF_Wrap : BNF_WRAP =
    54   | mk_half_pairss' indent (y :: ys) =
    55   | mk_half_pairss' indent (y :: ys) =
    55     indent @ fold_rev (cons o single o pair y) ys (mk_half_pairss' ([] :: indent) ys);
    56     indent @ fold_rev (cons o single o pair y) ys (mk_half_pairss' ([] :: indent) ys);
    56 
    57 
    57 fun mk_half_pairss ys = mk_half_pairss' [[]] ys;
    58 fun mk_half_pairss ys = mk_half_pairss' [[]] ys;
    58 
    59 
    59 (* TODO: provide a way to have a different default value, e.g. "tl Nil = Nil" *)
    60 fun mk_undefined T = Const (@{const_name undefined}, T);
    60 fun mk_undef T Ts = Const (@{const_name undefined}, Ts ---> T);
       
    61 
    61 
    62 fun mk_ctr Ts ctr =
    62 fun mk_ctr Ts ctr =
    63   let val Type (_, Ts0) = body_type (fastype_of ctr) in
    63   let val Type (_, Ts0) = body_type (fastype_of ctr) in
    64     Term.subst_atomic_types (Ts0 ~~ Ts) ctr
    64     Term.subst_atomic_types (Ts0 ~~ Ts) ctr
    65   end;
    65   end;
    66 
    66 
    67 fun eta_expand_case_arg xs f_xs = fold_rev Term.lambda xs f_xs;
    67 fun eta_expand_case_arg xs f_xs = fold_rev Term.lambda xs f_xs;
    68 
    68 
    69 fun name_of_ctr c =
    69 fun name_of_ctr c =
    70   case head_of c of
    70   (case head_of c of
    71     Const (s, _) => s
    71     Const (s, _) => s
    72   | Free (s, _) => s
    72   | Free (s, _) => s
    73   | _ => error "Cannot extract name of constructor";
    73   | _ => error "Cannot extract name of constructor");
    74 
    74 
    75 fun prepare_wrap_datatype prep_term (((no_dests, raw_ctrs), raw_case),
    75 fun prepare_wrap_datatype prep_term (((no_dests, raw_ctrs), raw_case),
    76     (raw_disc_binders, raw_sel_binderss)) no_defs_lthy =
    76     (raw_disc_binders, (raw_sel_binderss, raw_sel_defaultss))) no_defs_lthy =
    77   let
    77   let
    78     (* TODO: sanity checks on arguments *)
    78     (* TODO: sanity checks on arguments *)
    79     (* TODO: attributes (simp, case_names, etc.) *)
    79     (* TODO: attributes (simp, case_names, etc.) *)
    80     (* TODO: case syntax *)
    80     (* TODO: case syntax *)
    81     (* TODO: integration with function package ("size") *)
    81     (* TODO: integration with function package ("size") *)
    82 
    82 
       
    83     val n = length raw_ctrs;
       
    84     val ks = 1 upto n;
       
    85 
       
    86     val _ = if n > 0 then () else error "No constructors specified";
       
    87 
    83     val ctrs0 = map (prep_term no_defs_lthy) raw_ctrs;
    88     val ctrs0 = map (prep_term no_defs_lthy) raw_ctrs;
    84     val case0 = prep_term no_defs_lthy raw_case;
    89     val case0 = prep_term no_defs_lthy raw_case;
    85 
    90     val sel_defaultss =
    86     val n = length ctrs0;
    91       pad_list [] n (map (map (apsnd (prep_term no_defs_lthy))) raw_sel_defaultss);
    87     val ks = 1 upto n;
       
    88 
       
    89     val _ = if n > 0 then () else error "No constructors specified";
       
    90 
    92 
    91     val Type (fpT_name, As0) = body_type (fastype_of (hd ctrs0));
    93     val Type (fpT_name, As0) = body_type (fastype_of (hd ctrs0));
    92     val b = Binding.qualified_name fpT_name;
    94     val b = Binding.qualified_name fpT_name;
    93 
    95 
    94     val (As, B) =
    96     val (As, B) =
   192 
   194 
   193           fun disc_spec b exist_xs_v_eq_ctr = mk_Trueprop_eq (disc_free b $ v, exist_xs_v_eq_ctr);
   195           fun disc_spec b exist_xs_v_eq_ctr = mk_Trueprop_eq (disc_free b $ v, exist_xs_v_eq_ctr);
   194 
   196 
   195           fun alternate_disc k = Term.lambda v (alternate_disc_lhs (K o disc_free) (3 - k));
   197           fun alternate_disc k = Term.lambda v (alternate_disc_lhs (K o disc_free) (3 - k));
   196 
   198 
   197           fun mk_sel_case_args proto_sels T =
   199           fun mk_sel_case_args b proto_sels T =
   198             map2 (fn Ts => fn i =>
   200             map2 (fn Ts => fn k =>
   199               case AList.lookup (op =) proto_sels i of
   201               (case AList.lookup (op =) proto_sels k of
   200                 NONE => mk_undef T Ts
   202                 NONE =>
   201               | SOME (xs, x) => fold_rev Term.lambda xs x) ctr_Tss ks;
   203                 let val def_T = Ts ---> T in
       
   204                   (case AList.lookup Binding.eq_name (rev (nth sel_defaultss (k - 1))) b of
       
   205                     NONE => mk_undefined def_T
       
   206                   | SOME t => fold_rev (fn T => Term.lambda (Free (Name.uu, T))) Ts
       
   207                       (Term.subst_atomic_types [(fastype_of t, T)] t))
       
   208                 end
       
   209               | SOME (xs, x) => fold_rev Term.lambda xs x)) ctr_Tss ks;
   202 
   210 
   203           fun sel_spec b proto_sels =
   211           fun sel_spec b proto_sels =
   204             let
   212             let
   205               val _ =
   213               val _ =
   206                 (case duplicates (op =) (map fst proto_sels) of
   214                 (case duplicates (op =) (map fst proto_sels) of
   214                 | T :: T' :: _ => error ("Inconsistent range type for selector " ^
   222                 | T :: T' :: _ => error ("Inconsistent range type for selector " ^
   215                     quote (Binding.name_of b) ^ ": " ^ quote (Syntax.string_of_typ no_defs_lthy T) ^
   223                     quote (Binding.name_of b) ^ ": " ^ quote (Syntax.string_of_typ no_defs_lthy T) ^
   216                     " vs. " ^ quote (Syntax.string_of_typ no_defs_lthy T')));
   224                     " vs. " ^ quote (Syntax.string_of_typ no_defs_lthy T')));
   217             in
   225             in
   218               mk_Trueprop_eq (Free (Binding.name_of b, fpT --> T) $ v,
   226               mk_Trueprop_eq (Free (Binding.name_of b, fpT --> T) $ v,
   219                 Term.list_comb (mk_case As T, mk_sel_case_args proto_sels T) $ v)
   227                 Term.list_comb (mk_case As T, mk_sel_case_args b proto_sels T) $ v)
   220             end;
   228             end;
   221 
   229 
   222           val proto_selss = map3 (fn k => fn xs => map (fn x => (k, (xs, x)))) ks xss xss;
   230           val proto_selss = map3 (fn k => fn xs => map (fn x => (k, (xs, x)))) ks xss xss;
   223           val sel_bundles = AList.group Binding.eq_name (flat sel_binderss ~~ flat proto_selss);
   231           val sel_bundles = AList.group Binding.eq_name (flat sel_binderss ~~ flat proto_selss);
   224           val sel_binders = map fst sel_bundles;
   232           val sel_binders = map fst sel_bundles;
   535     (goalss, after_qed, lthy')
   543     (goalss, after_qed, lthy')
   536   end;
   544   end;
   537 
   545 
   538 fun wrap_datatype tacss = (fn (goalss, after_qed, lthy) =>
   546 fun wrap_datatype tacss = (fn (goalss, after_qed, lthy) =>
   539   map2 (map2 (Skip_Proof.prove lthy [] [])) goalss tacss
   547   map2 (map2 (Skip_Proof.prove lthy [] [])) goalss tacss
   540   |> (fn thms => after_qed thms lthy)) oo
   548   |> (fn thms => after_qed thms lthy)) oo prepare_wrap_datatype (K I);
   541   prepare_wrap_datatype (K I) (* FIXME? (singleton o Type_Infer_Context.infer_types) *)
   549 
   542 
   550 fun parse_bracket_list parser = @{keyword "["} |-- Parse.list parser --|  @{keyword "]"};
   543 val parse_bindings = @{keyword "["} |-- Parse.list Parse.binding --| @{keyword "]"};
   551 
   544 val parse_bindingss = @{keyword "["} |-- Parse.list parse_bindings --| @{keyword "]"};
   552 val parse_bindings = parse_bracket_list Parse.binding;
       
   553 val parse_bindingss = parse_bracket_list parse_bindings;
       
   554 
       
   555 val parse_bound_term = (Parse.binding --| @{keyword ":"}) -- Parse.term;
       
   556 val parse_bound_terms = parse_bracket_list parse_bound_term;
       
   557 val parse_bound_termss = parse_bracket_list parse_bound_terms;
   545 
   558 
   546 val wrap_datatype_cmd = (fn (goalss, after_qed, lthy) =>
   559 val wrap_datatype_cmd = (fn (goalss, after_qed, lthy) =>
   547   Proof.theorem NONE (snd oo after_qed) (map (map (rpair [])) goalss) lthy) oo
   560   Proof.theorem NONE (snd oo after_qed) (map (map (rpair [])) goalss) lthy) oo
   548   prepare_wrap_datatype Syntax.read_term;
   561   prepare_wrap_datatype Syntax.read_term;
   549 
   562 
   551   Scan.optional (@{keyword "("} |-- (@{keyword "no_dests"} >> K true) --| @{keyword ")"}) false;
   564   Scan.optional (@{keyword "("} |-- (@{keyword "no_dests"} >> K true) --| @{keyword ")"}) false;
   552 
   565 
   553 val _ =
   566 val _ =
   554   Outer_Syntax.local_theory_to_proof @{command_spec "wrap_data"} "wraps an existing datatype"
   567   Outer_Syntax.local_theory_to_proof @{command_spec "wrap_data"} "wraps an existing datatype"
   555     ((parse_wrap_options -- (@{keyword "["} |-- Parse.list Parse.term --| @{keyword "]"}) --
   568     ((parse_wrap_options -- (@{keyword "["} |-- Parse.list Parse.term --| @{keyword "]"}) --
   556       Parse.term -- Scan.optional (parse_bindings -- Scan.optional parse_bindingss []) ([], []))
   569       Parse.term -- Scan.optional (parse_bindings -- Scan.optional (parse_bindingss --
       
   570         Scan.optional parse_bound_termss []) ([], [])) ([], ([], [])))
   557      >> wrap_datatype_cmd);
   571      >> wrap_datatype_cmd);
   558 
   572 
   559 end;
   573 end;