src/HOLCF/Tools/Domain/domain_extender.ML
author huffman
Thu Oct 14 13:46:27 2010 -0700 (2010-10-14)
changeset 40017 575d3aa1f3c5
parent 40016 2eff1cbc1ccb
child 40019 05cda34d36e7
permissions -rw-r--r--
include iso_info as part of constr_info type
     1 (*  Title:      HOLCF/Tools/Domain/domain_extender.ML
     2     Author:     David von Oheimb
     3     Author:     Brian Huffman
     4 
     5 Theory extender for domain command, including theory syntax.
     6 *)
     7 
     8 signature DOMAIN_EXTENDER =
     9 sig
    10   val add_domain_cmd:
    11       binding ->
    12       ((string * string option) list * binding * mixfix *
    13        (binding * (bool * binding option * string) list * mixfix) list) list
    14       -> theory -> theory
    15 
    16   val add_domain:
    17       binding ->
    18       ((string * string option) list * binding * mixfix *
    19        (binding * (bool * binding option * typ) list * mixfix) list) list
    20       -> theory -> theory
    21 
    22   val add_new_domain_cmd:
    23       binding ->
    24       ((string * string option) list * binding * mixfix *
    25        (binding * (bool * binding option * string) list * mixfix) list) list
    26       -> theory -> theory
    27 
    28   val add_new_domain:
    29       binding ->
    30       ((string * string option) list * binding * mixfix *
    31        (binding * (bool * binding option * typ) list * mixfix) list) list
    32       -> theory -> theory
    33 end;
    34 
    35 structure Domain_Extender :> DOMAIN_EXTENDER =
    36 struct
    37 
    38 open Domain_Library;
    39 
    40 (* ----- general testing and preprocessing of constructor list -------------- *)
    41 fun check_and_sort_domain
    42     (arg_sort : bool -> sort)
    43     (dtnvs : (string * typ list) list)
    44     (cons'' : (binding * (bool * binding option * typ) list * mixfix) list list)
    45     (thy : theory)
    46     : ((string * typ list) *
    47        (binding * (bool * binding option * typ) list * mixfix) list) list =
    48   let
    49     val defaultS = Sign.defaultS thy;
    50 
    51     val test_dupl_typs =
    52       case duplicates (op =) (map fst dtnvs) of 
    53         [] => false | dups => error ("Duplicate types: " ^ commas_quote dups);
    54 
    55     val all_cons = map (Binding.name_of o first) (flat cons'');
    56     val test_dupl_cons =
    57       case duplicates (op =) all_cons of 
    58         [] => false | dups => error ("Duplicate constructors: " 
    59                                       ^ commas_quote dups);
    60     val all_sels =
    61       (map Binding.name_of o map_filter second o maps second) (flat cons'');
    62     val test_dupl_sels =
    63       case duplicates (op =) all_sels of
    64         [] => false | dups => error("Duplicate selectors: "^commas_quote dups);
    65 
    66     fun test_dupl_tvars s =
    67       case duplicates (op =) (map(fst o dest_TFree)s) of
    68         [] => false | dups => error("Duplicate type arguments: " 
    69                                     ^commas_quote dups);
    70     val test_dupl_tvars' = exists test_dupl_tvars (map snd dtnvs);
    71 
    72     (* test for free type variables, illegal sort constraints on rhs,
    73        non-pcpo-types and invalid use of recursive type;
    74        replace sorts in type variables on rhs *)
    75     fun analyse_equation ((dname,typevars),cons') = 
    76       let
    77         val tvars = map dest_TFree typevars;
    78         val distinct_typevars = map TFree tvars;
    79         fun rm_sorts (TFree(s,_)) = TFree(s,[])
    80           | rm_sorts (Type(s,ts)) = Type(s,remove_sorts ts)
    81           | rm_sorts (TVar(s,_))  = TVar(s,[])
    82         and remove_sorts l = map rm_sorts l;
    83         fun analyse indirect (TFree(v,s))  =
    84             (case AList.lookup (op =) tvars v of 
    85                NONE => error ("Free type variable " ^ quote v ^ " on rhs.")
    86              | SOME sort => if eq_set (op =) (s, defaultS) orelse
    87                                eq_set (op =) (s, sort)
    88                             then TFree(v,sort)
    89                             else error ("Inconsistent sort constraint" ^
    90                                         " for type variable " ^ quote v))
    91           | analyse indirect (t as Type(s,typl)) =
    92             (case AList.lookup (op =) dtnvs s of
    93                NONE => Type (s, map (analyse false) typl)
    94              | SOME typevars =>
    95                  if indirect 
    96                  then error ("Indirect recursion of type " ^ 
    97                              quote (string_of_typ thy t))
    98                  else if dname <> s orelse
    99                          (** BUG OR FEATURE?:
   100                              mutual recursion may use different arguments **)
   101                          remove_sorts typevars = remove_sorts typl 
   102                  then Type(s,map (analyse true) typl)
   103                  else error ("Direct recursion of type " ^ 
   104                              quote (string_of_typ thy t) ^ 
   105                              " with different arguments"))
   106           | analyse indirect (TVar _) = Imposs "extender:analyse";
   107         fun check_pcpo lazy T =
   108             let val sort = arg_sort lazy in
   109               if Sign.of_sort thy (T, sort) then T
   110               else error ("Constructor argument type is not of sort " ^
   111                           Syntax.string_of_sort_global thy sort ^ ": " ^
   112                           string_of_typ thy T)
   113             end;
   114         fun analyse_arg (lazy, sel, T) =
   115             (lazy, sel, check_pcpo lazy (analyse false T));
   116         fun analyse_con (b, args, mx) = (b, map analyse_arg args, mx);
   117       in ((dname,distinct_typevars), map analyse_con cons') end; 
   118   in ListPair.map analyse_equation (dtnvs,cons'')
   119   end; (* let *)
   120 
   121 (* ----- calls for building new thy and thms -------------------------------- *)
   122 
   123 type info =
   124      Domain_Take_Proofs.iso_info list * Domain_Take_Proofs.take_induct_info;
   125 
   126 fun gen_add_domain
   127     (prep_typ : theory -> 'a -> typ)
   128     (add_isos : (binding * mixfix * (typ * typ)) list -> theory -> info * theory)
   129     (arg_sort : bool -> sort)
   130     (comp_dbind : binding)
   131     (eqs''' : ((string * string option) list * binding * mixfix *
   132                (binding * (bool * binding option * 'a) list * mixfix) list) list)
   133     (thy : theory) =
   134   let
   135     val dtnvs : (binding * typ list * mixfix) list =
   136       let
   137         fun readS (SOME s) = Syntax.read_sort_global thy s
   138           | readS NONE = Sign.defaultS thy;
   139         fun readTFree (a, s) = TFree (a, readS s);
   140       in
   141         map (fn (vs,dname:binding,mx,_) =>
   142                 (dname, map readTFree vs, mx)) eqs'''
   143       end;
   144 
   145     fun thy_type  (dname,tvars,mx) = (dname, length tvars, mx);
   146     fun thy_arity (dname,tvars,mx) =
   147       (Sign.full_name thy dname, map (snd o dest_TFree) tvars, arg_sort false);
   148 
   149     (* this theory is used just for parsing and error checking *)
   150     val tmp_thy = thy
   151       |> Theory.copy
   152       |> Sign.add_types (map thy_type dtnvs)
   153       |> fold (AxClass.axiomatize_arity o thy_arity) dtnvs;
   154 
   155     val dbinds : binding list =
   156         map (fn (_,dbind,_,_) => dbind) eqs''';
   157     val cons''' :
   158         (binding * (bool * binding option * 'a) list * mixfix) list list =
   159         map (fn (_,_,_,cons) => cons) eqs''';
   160     val cons'' :
   161         (binding * (bool * binding option * typ) list * mixfix) list list =
   162         map (map (upd_second (map (upd_third (prep_typ tmp_thy))))) cons''';
   163     val dtnvs' : (string * typ list) list =
   164         map (fn (dname,vs,mx) => (Sign.full_name thy dname,vs)) dtnvs;
   165     val eqs' : ((string * typ list) *
   166         (binding * (bool * binding option * typ) list * mixfix) list) list =
   167         check_and_sort_domain arg_sort dtnvs' cons'' tmp_thy;
   168     val dts : typ list = map (Type o fst) eqs';
   169 
   170     fun mk_arg_typ (lazy, dest_opt, T) = if lazy then mk_uT T else T;
   171     fun mk_con_typ (bind, args, mx) =
   172         if null args then oneT else foldr1 mk_sprodT (map mk_arg_typ args);
   173     fun mk_eq_typ (_, cons) = foldr1 mk_ssumT (map mk_con_typ cons);
   174     val repTs : typ list = map mk_eq_typ eqs';
   175 
   176     val iso_spec : (binding * mixfix * (typ * typ)) list =
   177         map (fn ((dbind, _, mx), eq) => (dbind, mx, eq))
   178           (dtnvs ~~ (dts ~~ repTs));
   179 
   180     val ((iso_infos, take_info), thy) = add_isos iso_spec thy;
   181 
   182     val new_dts : (string * string list) list =
   183         map (fn ((s,Ts),_) => (s, map (fst o dest_TFree) Ts)) eqs';
   184     fun one_con (con,args,mx) : cons =
   185         (Binding.name_of con,  (* FIXME preverse binding (!?) *)
   186          ListPair.map (fn ((lazy,sel,tp),vn) =>
   187            mk_arg ((lazy, Datatype_Aux.dtyp_of_typ new_dts tp), vn))
   188                       (args, Datatype_Prop.make_tnames (map third args)));
   189     val eqs : eq list =
   190         map (fn (dtnvs,cons') => (dtnvs, map one_con cons')) eqs';
   191 
   192     val (constr_infos, thy) =
   193         thy
   194           |> fold_map (fn ((dbind, (_,cs)), info) =>
   195                 Domain_Constructors.add_domain_constructors dbind cs info)
   196              (dbinds ~~ eqs' ~~ iso_infos);
   197 
   198     val (take_rews, theorems_thy) =
   199         thy
   200           |> Domain_Theorems.comp_theorems (comp_dbind, eqs)
   201               (dbinds ~~ map snd eqs') take_info constr_infos;
   202   in
   203     theorems_thy
   204   end;
   205 
   206 fun define_isos (spec : (binding * mixfix * (typ * typ)) list) =
   207   let
   208     fun prep (dbind, mx, (lhsT, rhsT)) =
   209       let val (dname, vs) = dest_Type lhsT;
   210       in (map (fst o dest_TFree) vs, dbind, mx, rhsT, NONE) end;
   211   in
   212     Domain_Isomorphism.domain_isomorphism (map prep spec)
   213   end;
   214 
   215 fun pcpo_arg lazy = if lazy then @{sort cpo} else @{sort pcpo};
   216 fun rep_arg lazy = @{sort bifinite};
   217 
   218 val add_domain =
   219     gen_add_domain Sign.certify_typ Domain_Axioms.add_axioms pcpo_arg;
   220 
   221 val add_new_domain =
   222     gen_add_domain Sign.certify_typ define_isos rep_arg;
   223 
   224 val add_domain_cmd =
   225     gen_add_domain Syntax.read_typ_global Domain_Axioms.add_axioms pcpo_arg;
   226 
   227 val add_new_domain_cmd =
   228     gen_add_domain Syntax.read_typ_global define_isos rep_arg;
   229 
   230 
   231 (** outer syntax **)
   232 
   233 val _ = Keyword.keyword "lazy";
   234 
   235 val dest_decl : (bool * binding option * string) parser =
   236   Parse.$$$ "(" |-- Scan.optional (Parse.$$$ "lazy" >> K true) false --
   237     (Parse.binding >> SOME) -- (Parse.$$$ "::" |-- Parse.typ)  --| Parse.$$$ ")" >> Parse.triple1
   238     || Parse.$$$ "(" |-- Parse.$$$ "lazy" |-- Parse.typ --| Parse.$$$ ")"
   239     >> (fn t => (true,NONE,t))
   240     || Parse.typ >> (fn t => (false,NONE,t));
   241 
   242 val cons_decl =
   243   Parse.binding -- Scan.repeat dest_decl -- Parse.opt_mixfix;
   244 
   245 val domain_decl =
   246   (Parse.type_args_constrained -- Parse.binding -- Parse.opt_mixfix) --
   247     (Parse.$$$ "=" |-- Parse.enum1 "|" cons_decl);
   248 
   249 val domains_decl =
   250   Scan.option (Parse.$$$ "(" |-- Parse.binding --| Parse.$$$ ")") --
   251     Parse.and_list1 domain_decl;
   252 
   253 fun mk_domain
   254     (definitional : bool)
   255     (opt_name : binding option,
   256      doms : ((((string * string option) list * binding) * mixfix) *
   257              ((binding * (bool * binding option * string) list) * mixfix) list) list ) =
   258   let
   259     val names = map (fn (((_, t), _), _) => Binding.name_of t) doms;
   260     val specs : ((string * string option) list * binding * mixfix *
   261                  (binding * (bool * binding option * string) list * mixfix) list) list =
   262         map (fn (((vs, t), mx), cons) =>
   263                 (vs, t, mx, map (fn ((c, ds), mx) => (c, ds, mx)) cons)) doms;
   264     val comp_dbind =
   265         case opt_name of NONE => Binding.name (space_implode "_" names)
   266                        | SOME s => s;
   267   in
   268     if definitional 
   269     then add_new_domain_cmd comp_dbind specs
   270     else add_domain_cmd comp_dbind specs
   271   end;
   272 
   273 val _ =
   274   Outer_Syntax.command "domain" "define recursive domains (HOLCF)"
   275     Keyword.thy_decl (domains_decl >> (Toplevel.theory o mk_domain false));
   276 
   277 val _ =
   278   Outer_Syntax.command "new_domain" "define recursive domains (HOLCF)"
   279     Keyword.thy_decl (domains_decl >> (Toplevel.theory o mk_domain true));
   280 
   281 end;