src/HOL/Codatatype/Tools/bnf_fp_util.ML
author blanchet
Sat Sep 08 21:04:26 2012 +0200 (2012-09-08)
changeset 49218 d01a5c918298
parent 49207 4634c217b77b
child 49222 cbe8c859817c
permissions -rw-r--r--
renamed xxxBNF to pre_xxx
     1 (*  Title:      HOL/Codatatype/Tools/bnf_fp_util.ML
     2     Author:     Dmitriy Traytel, TU Muenchen
     3     Copyright   2012
     4 
     5 Shared library for the datatype and the codatatype construction.
     6 *)
     7 
     (* Interface of the utilities shared by the datatype and the codatatype construction. *)
     signature BNF_FP_UTIL =
     sig
       (* timing *)
       val time: Timer.real_timer -> string -> Timer.real_timer

       (* names of internal constants and types *)
       val IITN: string
       val LevN: string
       val algN: string
       val behN: string
       val bisN: string
       val carTN: string
       val isNodeN: string
       val lsbisN: string
       val min_algN: string
       val morN: string
       val rvN: string
       val strTN: string
       val str_initN: string
       val sum_bdN: string
       val sum_bdTN: string

       (* basic name fragments *)
       val coN: string
       val fldN: string
       val unfN: string
       val uniqueN: string
       val uptoN: string

       (* fld/unf and their basic properties *)
       val exhaustN: string
       val fld_exhaustN: string
       val fld_injectN: string
       val fld_unfN: string
       val injectN: string
       val nchotomyN: string
       val unf_exhaustN: string
       val unf_fldN: string
       val unf_injectN: string

       (* (co)iterators and (co)recursors *)
       val coiterN: string
       val corecN: string
       val fld_iterN: string
       val fld_iter_uniqueN: string
       val fld_recN: string
       val fld_unf_coiterN: string
       val iterN: string
       val recN: string
       val unf_coiterN: string
       val unf_coiter_uniqueN: string
       val unf_corecN: string

       (* (co)induction *)
       val coinductN: string
       val fld_induct2N: string
       val fld_inductN: string
       val inductN: string
       val pred_coinductN: string
       val pred_coinduct_uptoN: string
       val rel_coinductN: string
       val rel_coinduct_uptoN: string
       val unf_coinductN: string
       val unf_coinduct_uptoN: string

       (* map and set *)
       val hsetN: string
       val hset_recN: string
       val map_simpsN: string
       val map_uniqueN: string
       val set_inclN: string
       val set_set_inclN: string

       (* parameterized names *)
       val mk_exhaustN: string -> string
       val mk_injectN: string -> string
       val mk_nchotomyN: string -> string
       val mk_set_inductN: int -> string
       val mk_set_minimalN: int -> string
       val mk_set_simpsN: int -> string

       (* type definitions *)
       val typedef: bool -> binding option -> binding * (string * sort) list * mixfix -> term ->
         (binding * binding) option -> tactic -> local_theory -> (string * Typedef.info) * local_theory

       (* conjunctions in theorems *)
       val split_conj_thm: thm -> thm list
       val split_conj_prems: int -> thm -> thm

       (* sums and tuples: terms and types *)
       val Inl_const: typ -> typ -> term
       val Inr_const: typ -> typ -> term
       val mk_Inl: term -> typ -> term
       val mk_Inr: term -> typ -> term
       val mk_InN: typ list -> term -> int -> term
       val mk_sum_case: term -> term -> term
       val mk_sum_caseN: term list -> term
       val dest_sumTN: int -> typ -> typ list
       val dest_tupleT: int -> typ -> typ list

       (* relations and sets *)
       val mk_Field: term -> term
       val mk_union: term * term -> term

       (* sums: theorems *)
       val mk_sumEN: int -> thm
       val mk_sum_casesN: int -> int -> thm

       (* miscellaneous *)
       val mk_tactics: 'a -> 'a -> 'a -> 'a list -> 'a -> 'a -> 'a list -> 'a -> 'a -> 'a list
       val fixpoint: ('a * 'a -> bool) -> ('a list -> 'a list) -> 'a list -> 'a list

       (* fixed-point construction drivers *)
       val fp_bnf: (mixfix list -> (string * sort) list option -> binding list ->
         typ list * typ list list -> BNF_Def.BNF list -> local_theory -> 'a) ->
         binding list -> mixfix list -> (string * sort) list -> ((string * sort) * typ) list ->
         local_theory -> thm list * 'a
       val fp_bnf_cmd: (mixfix list -> (string * sort) list option -> binding list ->
         typ list * typ list list -> BNF_Def.BNF list -> local_theory -> 'a) ->
         binding list * (string list * string list) -> local_theory -> 'a
     end;
   112 
   113 structure BNF_FP_Util : BNF_FP_UTIL =
   114 struct
   115 
   116 open BNF_Comp
   117 open BNF_Def
   118 open BNF_Util
   119 
   (*global switch: when true, "time" reports elapsed times via "warning"*)
   val timing = true;

   (*Reports (if timing is enabled) the time elapsed on the given timer, labelled with msg,
     and returns a freshly started timer.*)
   fun time timer msg =
     let
       val _ =
         if timing then
           warning (msg ^ ": " ^ ATP_Util.string_from_time (Timer.checkRealTimer timer))
         else ();
     in
       Timer.startRealTimer ()
     end;
   124 
   (*Wrapper around Typedef.add_typedef: after the definition the local theory is passed
     through Local_Theory.restore, and the resulting typedef info is transformed with the
     export morphism from the pre-restore to the restored context.*)
   (*TODO: is this really different from Typedef.add_typedef_global?*)
   fun typedef def opt_name typ set opt_morphs tac lthy =
     let
       val ((name, info), lthy_old) =
         Typedef.add_typedef def opt_name typ set opt_morphs tac lthy;
       val lthy' = Local_Theory.restore lthy_old;
       val phi = Proof_Context.export_morphism lthy_old lthy';
       val info' = Typedef.transform_info phi info;
     in
       ((name, info'), lthy')
     end;
   136 
   (** name fragments used by the fixed-point constructions **)

   (*joins two name components with an underscore*)
   fun underscore s1 s2 = s1 ^ "_" ^ s2;

   (* binding prefixes: raw_N for the BNFs produced by bnf_of_typ, pre_N for the sealed BNFs *)

   val pre_N = "pre_";
   val raw_N = "raw_";

   (* basic fragments *)

   val coN = "co";
   val fldN = "fld";
   val unfN = "unf";
   val uniqueN = "_unique";
   val uptoN = "upto";

   (* internal constants and types *)

   val algN = "alg";
   val IITN = "IITN";
   val min_algN = "min_alg";
   val morN = "mor";
   val bisN = "bis";
   val lsbisN = "lsbis";
   val sum_bdTN = "sbdT";
   val sum_bdN = "sbd";
   val carTN = "carT";
   val strTN = "strT";
   val isNodeN = "isNode";
   val LevN = "Lev";
   val rvN = "recover";
   val behN = "beh";
   val str_initN = "str_init";

   (* (co)iterators *)

   val iterN = "iter";
   val coiterN = coN ^ iterN;
   val fld_iterN = underscore fldN iterN;
   val unf_coiterN = underscore unfN coiterN;
   val fld_iter_uniqueN = fld_iterN ^ uniqueN;
   val unf_coiter_uniqueN = unf_coiterN ^ uniqueN;
   val fld_unf_coiterN = underscore fldN unf_coiterN;

   (* (co)recursors *)

   val recN = "rec";
   val corecN = coN ^ recN;
   val fld_recN = underscore fldN recN;
   val unf_corecN = underscore unfN corecN;

   (* map and set *)

   val map_simpsN = mapN ^ "_simps";
   val map_uniqueN = mapN ^ uniqueN;
   fun mk_set_simpsN i = mk_setN i ^ "_simps";
   fun mk_set_minimalN i = mk_setN i ^ "_minimal";
   fun mk_set_inductN i = mk_setN i ^ "_induct";
   val hsetN = "Hset";
   val hset_recN = hsetN ^ "_rec";
   val set_inclN = "set_incl";
   val set_set_inclN = "set_set_incl";

   (* fld/unf and their basic properties *)

   val fld_unfN = underscore fldN unfN;
   val unf_fldN = underscore unfN fldN;
   val nchotomyN = "nchotomy";
   fun mk_nchotomyN s = underscore s nchotomyN;
   val injectN = "inject";
   fun mk_injectN s = underscore s injectN;
   val exhaustN = "exhaust";
   fun mk_exhaustN s = underscore s exhaustN;
   val fld_injectN = mk_injectN fldN;
   val fld_exhaustN = mk_exhaustN fldN;
   val unf_injectN = mk_injectN unfN;
   val unf_exhaustN = mk_exhaustN unfN;

   (* (co)induction *)

   val inductN = "induct";
   val coinductN = coN ^ inductN;
   val fld_inductN = underscore fldN inductN;
   val fld_induct2N = fld_inductN ^ "2";
   val unf_coinductN = underscore unfN coinductN;
   val rel_coinductN = underscore relN coinductN;
   val pred_coinductN = underscore predN coinductN;
   val unf_coinduct_uptoN = underscore unf_coinductN uptoN;
   val rel_coinduct_uptoN = underscore rel_coinductN uptoN;
   val pred_coinduct_uptoN = underscore pred_coinductN uptoN;
   204 
   (*the constant Inl at type LT => LT + RT*)
   fun Inl_const LT RT =
     let val sumT = mk_sumT (LT, RT)
     in Const (@{const_name Inl}, LT --> sumT) end;

   (*left injection of t into the sum of its own type and RT*)
   fun mk_Inl t RT =
     let val LT = fastype_of t
     in Inl_const LT RT $ t end;
   207 
   (*the constant Inr at type RT => LT + RT*)
   fun Inr_const LT RT =
     let val sumT = mk_sumT (LT, RT)
     in Const (@{const_name Inr}, RT --> sumT) end;

   (*right injection of t into the sum of LT and its own type*)
   fun mk_Inr t LT =
     let val RT = fastype_of t
     in Inr_const LT RT $ t end;
   210 
   (*Injects t into the m-th (1-based) summand of the right-nested sum over the types Ts,
     by wrapping it in m - 1 applications of Inr around one Inl (no Inl for the last
     summand). Raises TYPE if the type list is exhausted before position m is reached.*)
   fun mk_InN Ts t m =
     (case (Ts, m) of
       ([_], 1) => t
     | (_ :: Ts', 1) => mk_Inl t (mk_sumTN Ts')
     | (LT :: Ts', _) => mk_Inr (mk_InN Ts' t (m - 1)) LT
     | _ => raise (TYPE ("mk_InN", Ts, [t])));
   215 
   (*Applies the constant sum_case to f and g. Its type is computed from the types of f and g:
     the argument is the sum of their domain types, the result is the range type of f.*)
   fun mk_sum_case f g =
     let
       val (fT, gT) = (fastype_of f, fastype_of g);
       val sumT = mk_sumT (domain_type fT, domain_type gT);
       val caseT = fT --> gT --> sumT --> range_type fT;
     in
       Const (@{const_name sum_case}, caseT) $ f $ g
     end;
   224 
   (*Right-nested iteration of mk_sum_case over a non-empty list of functions;
     a singleton list yields its element unchanged. Not defined on the empty list.*)
   fun mk_sum_caseN fs =
     (case fs of
       [f] => f
     | f :: fs' => mk_sum_case f (mk_sum_caseN fs'));
   227 
   (*Splits a right-nested sum type into its n summands; for n = 1 the type itself is the
     only summand. Not defined if the type has fewer than n nested summands.*)
   fun dest_sumTN n T =
     (case (n, T) of
       (1, _) => [T]
     | (_, Type (@{type_name sum}, [T1, T2])) => T1 :: dest_sumTN (n - 1) T2);
   230 
   (*Splits a right-nested product type into its n components; the unit type stands for the
     0-tuple and any type is a 1-tuple. Not defined if the type has fewer than n components.*)
   (* TODO: move something like this to "HOLogic"? *)
   fun dest_tupleT n T =
     (case (n, T) of
       (0, @{typ unit}) => []
     | (1, _) => [T]
     | (_, Type (@{type_name prod}, [T1, T2])) => T1 :: dest_tupleT (n - 1) T2);
   235 
   (*Applies the constant Field to the relation term r; the element type is taken from the
     first component of r's relation type.*)
   fun mk_Field r =
     let
       val (T, _) = dest_relT (fastype_of r);
       val FieldT = mk_relT (T, T) --> HOLogic.mk_setT T;
     in
       Const (@{const_name Field}, FieldT) $ r
     end;
   239 
   (*binary application of the constant sup to a pair of terms*)
   fun mk_union tu = HOLogic.mk_binop @{const_name sup} tu;
   241 
   (*Iterates f, starting from X, until an application of f yields nothing new, i.e. until
     f X is a subset of X w.r.t. the equality eq; that X is returned.
     Dangerous; use with monotonic, converging functions only! Otherwise this need not
     terminate.*)
   fun fixpoint eq f X =
     let
       (*evaluate f only once per round: the result serves both the test and the next round*)
       val Y = f X;
     in
       if subset eq (Y, X) then X else fixpoint eq f Y
     end;
   244 
   (*Splits a theorem whose conclusion is a right-nested conjunction into the list of its
     conjuncts; a theorem on which the conjunct rules fail (THM) is returned as a singleton.*)
   (* stolen from "~~/src/HOL/Tools/Datatype/datatype_aux.ML" *)
   fun split_conj_thm th =
     let
       val th1 = th RS conjunct1;
       val ths = split_conj_thm (th RS conjunct2);
     in
       th1 :: ths
     end
     handle THM _ => [th];
   248 
   (*Resolves conjI into premises 1, 2, ... of th in turn, stopping when the premise index
     reaches limit or when a resolution step fails (THM); the theorem obtained so far is
     returned.*)
   fun split_conj_prems limit th =
     let
       fun go i th' =
         if i = limit then th'
         else (go (i + 1) (conjI RSN (i, th')) handle THM _ => th');
     in
       go 1 th
     end;
   254 
   (*Elimination rule indexed by n: obj_sum_base for n = 1, sumE for n = 2; for larger n
     it is assembled from obj_sumE and n - 2 chained instances of obj_sum_step, and then
     discharged with n copies of impI RS allI.*)
   local
     (*n instances of obj_sum_step, each resolved into premise 2 of the next*)
     fun mk_sumEN' n =
       if n = 1 then @{thm obj_sum_step}
       else mk_sumEN' (n - 1) RSN (2, @{thm obj_sum_step});
   in
     fun mk_sumEN n =
       (case n of
         1 => @{thm obj_sum_base}
       | 2 => @{thm sumE}
       | _ => (mk_sumEN' (n - 2) RSN (2, @{thm obj_sumE})) OF replicate n (impI RS allI));
   end;
   263 
   (*Theorem selected by the pair (n, m): refl for (1, 1), sum.cases(1) for m = 1,
     sum.cases(2) for (2, 2); otherwise sum_case_step(2) is chained by transitivity with
     the theorem for (n - 1, m - 1).*)
   fun mk_sum_casesN n m =
     (case (n, m) of
       (1, 1) => @{thm refl}
     | (_, 1) => @{thm sum.cases(1)}
     | (2, 2) => @{thm sum.cases(2)}
     | _ => trans OF [@{thm sum_case_step(2)}, mk_sum_casesN (n - 1) (m - 1)]);
   268 
   (*Arranges the given items into a single list in the fixed order
     mid, mcomp, mcong, snat..., bdco, bdinf, sbd..., inbd, wpull.*)
   fun mk_tactics mid mcomp mcong snat bdco bdinf sbd inbd wpull =
     List.concat [[mid, mcomp, mcong], snat, [bdco, bdinf], sbd, [inbd, wpull]];
   271 
   (*Orders type variables for the fixed-point construction; the result always ends with lhss.
     Without resBs: all variables occurring in Ass except those in lhss, sorted by
     Term_Ord.typ_ord on the corresponding TFrees.
     With resBs: those resBs that occur in some element of Ass and not in lhss, in the
     order given by resBs.*)
   (* FIXME: because of "@ lhss", the output could contain type variables that are not in the input;
      also, "fp_sort" should put the "resBs" first and in the order in which they appear *)
   fun fp_sort lhss opt_resBs Ass =
     let
       fun occurs T = exists (fn Ts => member (op =) Ts T) Ass;
       val candidates =
         (case opt_resBs of
           NONE =>
             Library.sort (Term_Ord.typ_ord o pairself TFree)
               (subtract (op =) lhss (fold (fold (insert (op =))) Ass []))
         | SOME resBs => subtract (op =) lhss (filter occurs resBs));
     in
       candidates @ lhss
     end;
   278 
   (*Common back end of fp_bnf and fp_bnf_cmd: rejects recursion through dead type variables,
     normalizes and seals the given BNFs (sealed bindings get the pre_N prefix), and then
     runs the supplied construction on them. Returns the map definitions of the sealed
     BNFs together with the result of the construction. Timing messages are emitted
     through the given timer function after each phase.*)
   fun mk_fp_bnf timer construct resBs bs sort lhss bnfs deadss livess unfold lthy =
     let
       (*concatenation of all binding names; used to qualify intermediate bindings*)
       val name = String.concat (map Binding.name_of bs);
       fun qualify i bind =
         let val namei = name ^ (if i > 0 then string_of_int i else "") in
           if member (op =) (#2 (Binding.dest bind)) (namei, true) then bind
           else Binding.prefix_name namei bind
         end;

       val Ass = map (map dest_TFree) livess;
       val resDs =
         (case resBs of
           NONE => []
         | SOME Ts => fold (subtract (op =)) Ass Ts);
       val Ds = fold (fold Term.add_tfreesT) deadss [];

       (*the left-hand sides must not occur among the dead type variables*)
       val _ =
         (case Library.inter (op =) Ds lhss of
           [] => ()
         | A :: _ =>
             error ("Nonadmissible type recursion (cannot take fixed point of dead type \
               \variable " ^ quote (Syntax.string_of_typ lthy (TFree A)) ^ ")"));

       val timer = time (timer "Construction of BNFs");

       val ((kill_poss, _), (normal_bnfs, (unfold', lthy'))) =
         normalize_bnfs qualify Ass Ds sort bnfs unfold lthy;

       (*per BNF: the live types at the killed positions, followed by the original dead types*)
       val Dss =
         map3 (fn lives => fn kill_pos => fn deads => map (nth lives) kill_pos @ deads)
           livess kill_poss deadss;

       val pre_bs = map (Binding.prefix_name pre_N) bs;
       val ((sealed_bnfs, sealed_deadss), lthy'') =
         fold_map3 (seal_bnf unfold') pre_bs Dss normal_bnfs lthy'
         |>> split_list;

       val pre_map_defs = map map_def_of_bnf sealed_bnfs;

       val timer = time (timer "Normalization & sealing of BNFs");

       val res = construct resBs bs (map TFree resDs, sealed_deadss) sealed_bnfs lthy'';

       (*only the timing message matters here; no further timer is needed*)
       val _ = timer "FP construction in total";
     in
       (pre_map_defs, res)
     end;
   318 
   (*Fixed-point construction from already parsed equations: each right-hand side is turned
     into a BNF via bnf_of_typ (under a raw_N-prefixed binding), and the resulting BNFs
     are handed to mk_fp_bnf together with the construction, specialized to the mixfixes.*)
   fun fp_bnf construct bs mixfixes resBs eqs lthy =
     let
       val timer = time (Timer.startRealTimer ());
       val (lhss, rhss) = split_list eqs;
       val sort = fp_sort lhss (SOME resBs);
       fun raw_bnf b = bnf_of_typ Smart_Inline (Binding.prefix_name raw_N b) I sort;
       val ((bnfs, (Dss, Ass)), (unfold, lthy')) =
         (empty_unfold, lthy)
         |> fold_map2 raw_bnf bs rhss
         |>> (apsnd split_list o split_list);
     in
       mk_fp_bnf timer (construct mixfixes) (SOME resBs) bs sort lhss bnfs Dss Ass unfold lthy'
     end;
   330 
   (*Command-level variant of fp_bnf: the left-hand sides and the BNF types are given as
     strings and read with Syntax.read_typ in the initial context, no result type variables
     are prescribed, and every mixfix is NoSyn. Only the result of the construction is
     returned.*)
   fun fp_bnf_cmd construct (bs, (raw_lhss, raw_bnfs)) lthy =
     let
       val timer = time (Timer.startRealTimer ());
       val lhss = map (dest_TFree o Syntax.read_typ lthy) raw_lhss;
       val sort = fp_sort lhss NONE;
       fun raw_bnf b rawT =
         bnf_of_typ Smart_Inline (Binding.prefix_name raw_N b) I sort (Syntax.read_typ lthy rawT);
       val ((bnfs, (Dss, Ass)), (unfold, lthy')) =
         (empty_unfold, lthy)
         |> fold_map2 raw_bnf bs raw_bnfs
         |>> (apsnd split_list o split_list);
       val no_mixfixes = map (K NoSyn) bs;
     in
       snd (mk_fp_bnf timer (construct no_mixfixes) NONE bs sort lhss bnfs Dss Ass unfold lthy')
     end;
   343 
   344 end;