src/HOLCF/Tools/Domain/domain_extender.ML
author haftmann
Tue Nov 24 17:28:25 2009 +0100 (2009-11-24)
changeset 33955 fff6f11b1f09
parent 33798 46cbbcbd4e68
child 33957 e9afca2118d4
permissions -rw-r--r--
curried take/drop
     1 (*  Title:      HOLCF/Tools/Domain/domain_extender.ML
     2     Author:     David von Oheimb
     3 
     4 Theory extender for domain command, including theory syntax.
     5 *)
     6 
     7 signature DOMAIN_EXTENDER =
     8 sig
         (* Every entry point below takes, in order: a string naming the whole
            group of domains, one specification per domain, and the theory to
            extend.  A domain specification consists of the type variables (each
            paired with an optional sort given as a string), the binding and
            mixfix of the new type, and its constructors.  A constructor consists
            of a binding, a list of arguments (lazy flag, optional selector
            binding, argument type) and a mixfix. *)

         (* Constructor argument types are given as strings. *)
     9   val add_domain_cmd:
    10       string ->
    11       ((string * string option) list * binding * mixfix *
    12        (binding * (bool * binding option * string) list * mixfix) list) list
    13       -> theory -> theory
    14 
         (* Same as add_domain_cmd, but constructor argument types are typ values. *)
    15   val add_domain:
    16       string ->
    17       ((string * string option) list * binding * mixfix *
    18        (binding * (bool * binding option * typ) list * mixfix) list) list
    19       -> theory -> theory
    20 
         (* String-typed entry point of the second variant, implemented by
            gen_add_new_domain in the structure below. *)
    21   val add_new_domain_cmd:
    22       string ->
    23       ((string * string option) list * binding * mixfix *
    24        (binding * (bool * binding option * string) list * mixfix) list) list
    25       -> theory -> theory
    26 
         (* Same as add_new_domain_cmd, but constructor argument types are typ
            values. *)
    27   val add_new_domain:
    28       string ->
    29       ((string * string option) list * binding * mixfix *
    30        (binding * (bool * binding option * typ) list * mixfix) list) list
    31       -> theory -> theory
    32 end;
    33 
    34 structure Domain_Extender :> DOMAIN_EXTENDER =
    35 struct
    36 
    37 open Domain_Library;
    38 
    39 (* ----- general testing and preprocessing of constructor list -------------- *)
    40 fun check_and_sort_domain
       (* Validates a group of domain equations and normalises the types that
          occur as constructor arguments.

          definitional: when true, recursion through arbitrary type constructors
                        is accepted; when false, only the constructors listed in
                        indirect_ok may enclose a recursive occurrence.
          dtnvs:        one (type name, type arguments) pair per domain; the
                        arguments are expected to be TFree values.
          cons'':       the constructors of each domain, in the same order as
                        dtnvs.
          thy:          theory used for the default sort, for printing types and
                        for the sort check on argument types.

          Returns the equations paired up with ListPair.map, i.e. in input order
          (no reordering is performed here, despite the name) and truncated to
          the shorter of the two lists.  Every violation is reported via error. *)
    41     (definitional : bool)
    42     (dtnvs : (string * typ list) list)
    43     (cons'' : (binding * (bool * binding option * typ) list * mixfix) list list)
    44     (thy : theory)
    45     : ((string * typ list) *
    46        (binding * (bool * binding option * typ) list * mixfix) list) list =
    47   let
    48     val defaultS = Sign.defaultS thy;
    49 
           (* The test_dupl bindings below are evaluated only for their effect:
              each one raises an error when duplicates exist, and the resulting
              boolean is never read. *)
    50     val test_dupl_typs =
    51       case duplicates (op =) (map fst dtnvs) of 
    52         [] => false | dups => error ("Duplicate types: " ^ commas_quote dups);
    53 
    54     val all_cons = map (Binding.name_of o first) (flat cons'');
    55     val test_dupl_cons =
    56       case duplicates (op =) all_cons of 
    57         [] => false | dups => error ("Duplicate constructors: " 
    58                                       ^ commas_quote dups);
           (* Selector names are collected across all constructors of all
              domains; arguments without a selector are skipped. *)
    59     val all_sels =
    60       (map Binding.name_of o map_filter second o maps second) (flat cons'');
    61     val test_dupl_sels =
    62       case duplicates (op =) all_sels of
    63         [] => false | dups => error("Duplicate selectors: "^commas_quote dups);
    64 
    65     fun test_dupl_tvars s =
    66       case duplicates (op =) (map(fst o dest_TFree)s) of
    67         [] => false | dups => error("Duplicate type arguments: " 
    68                                     ^commas_quote dups);
    69     val test_dupl_tvars' = exists test_dupl_tvars (map snd dtnvs);
    70 
    71     (* test for free type variables, illegal sort constraints on rhs,
    72        non-pcpo-types and invalid use of recursive type;
    73        replace sorts in type variables on rhs *)
    74     fun analyse_equation ((dname,typevars),cons') = 
    75       let
    76         val tvars = map dest_TFree typevars;
    77         val distinct_typevars = map TFree tvars;
               (* Erases all sort information so that argument lists can be
                  compared up to sorts. *)
    78         fun rm_sorts (TFree(s,_)) = TFree(s,[])
    79           | rm_sorts (Type(s,ts)) = Type(s,remove_sorts ts)
    80           | rm_sorts (TVar(s,_))  = TVar(s,[])
    81         and remove_sorts l = map rm_sorts l;
               (* Type constructors that may enclose a recursive occurrence when
                  definitional is false. *)
    82         val indirect_ok = ["*","Cfun.->","Ssum.++","Sprod.**","Up.u"]
               (* The flag indirect is true once the traversal is below a type
                  constructor that must not enclose a recursive occurrence, or
                  below the arguments of a domain type of this group. *)
    83         fun analyse indirect (TFree(v,s))  =
                   (* A variable must be declared on the lhs, and its sort must be
                      either the default sort or the declared one; the result
                      always carries the declared sort. *)
    84             (case AList.lookup (op =) tvars v of 
    85                NONE => error ("Free type variable " ^ quote v ^ " on rhs.")
    86              | SOME sort => if eq_set (op =) (s, defaultS) orelse
    87                                eq_set (op =) (s, sort)
    88                             then TFree(v,sort)
    89                             else error ("Inconsistent sort constraint" ^
    90                                         " for type variable " ^ quote v))
    91           | analyse indirect (t as Type(s,typl)) =
    92             (case AList.lookup (op =) dtnvs s of
    93                NONE =>
    94                  if definitional orelse s mem indirect_ok
    95                  then Type(s,map (analyse false) typl)
    96                  else Type(s,map (analyse true) typl)
    97              | SOME typevars =>
    98                  if indirect 
    99                  then error ("Indirect recursion of type " ^ 
   100                              quote (string_of_typ thy t))
   101                  else if dname <> s orelse
   102                          (** BUG OR FEATURE?:
   103                              mutual recursion may use different arguments **)
   104                          remove_sorts typevars = remove_sorts typl 
   105                  then Type(s,map (analyse true) typl)
   106                  else error ("Direct recursion of type " ^ 
   107                              quote (string_of_typ thy t) ^ 
   108                              " with different arguments"))
   109           | analyse indirect (TVar _) = Imposs "extender:analyse";
               (* Checks the argument type with cpo_type for lazy arguments and
                  with pcpo_type otherwise, and names the matching sort in the
                  error message.
                  NOTE(review): the sort names cpo and pcpo are inferred from the
                  names of the two predicates, which are defined outside this
                  file - confirm. *)
   110         fun check_pcpo lazy T =
   111             let val (ok, srt) = if lazy then (cpo_type, "cpo") else (pcpo_type, "pcpo")
   112             in if ok thy T then T
   113                else error ("Constructor argument type is not of sort " ^ srt ^ ": " ^
   114                            string_of_typ thy T)
   115             end;
   116         fun analyse_arg (lazy, sel, T) =
   117             (lazy, sel, check_pcpo lazy (analyse false T));
   118         fun analyse_con (b, args, mx) = (b, map analyse_arg args, mx);
   119       in ((dname,distinct_typevars), map analyse_con cons') end; 
   120   in ListPair.map analyse_equation (dtnvs,cons'')
   121   end; (* let *)
   122 
   123 (* ----- calls for building new thy and thms -------------------------------- *)
   124 
   125 fun gen_add_domain
       (* Extends thy''' with the group of domains described by eqs'''.

          prep_typ:  converts a constructor argument type from its input
                     representation into a typ, relative to a theory.
          comp_dnam: name of the whole group; its base name becomes the path
                     under which the collected rewrite rules are stored.
          eqs''':    one specification per domain: type variables with optional
                     sort strings, type binding, mixfix and constructors.
          thy''':    the theory to extend.

          The types are declared with Sign.add_types and their arities are
          axiomatized with result sort pcpoS.  The equations are then checked
          with check_and_sort_domain (definitional = false) before syntax,
          axioms and theorems are added.  Returns the resulting theory. *)
   126     (prep_typ : theory -> 'a -> typ)
   127     (comp_dnam : string)
   128     (eqs''' : ((string * string option) list * binding * mixfix *
   129                (binding * (bool * binding option * 'a) list * mixfix) list) list)
   130     (thy''' : theory) =
   131   let
           (* A missing sort annotation stands for the default sort of thy'''. *)
   132     fun readS (SOME s) = Syntax.read_sort_global thy''' s
   133       | readS NONE = Sign.defaultS thy''';
   134     fun readTFree (a, s) = TFree (a, readS s);
   135 
   136     val dtnvs = map (fn (vs,dname:binding,mx,_) => 
   137                         (dname, map readTFree vs, mx)) eqs''';
   138     val cons''' = map (fn (_,_,_,cons) => cons) eqs''';
   139     fun thy_type  (dname,tvars,mx) = (dname, length tvars, mx);
   140     fun thy_arity (dname,tvars,mx) =
   141         (Sign.full_name thy''' dname, map (snd o dest_TFree) tvars, pcpoS);
           (* The new type names must be known before the constructor argument
              types are prepared, because those may mention the new types. *)
   142     val thy'' =
   143       thy'''
   144       |> Sign.add_types (map thy_type dtnvs)
   145       |> fold (AxClass.axiomatize_arity o thy_arity) dtnvs;
   146     val cons'' =
   147       map (map (upd_second (map (upd_third (prep_typ thy''))))) cons''';
   148     val dtnvs' =
   149       map (fn (dname,vs,mx) => (Sign.full_name thy''' dname,vs)) dtnvs;
   150     val eqs' : ((string * typ list) *
   151         (binding * (bool * binding option * typ) list * mixfix) list) list =
   152       check_and_sort_domain false dtnvs' cons'' thy'';
   153     val thy' = thy'' |> Domain_Syntax.add_syntax false comp_dnam eqs';
   155     val new_dts = map (fn ((s,Ts),_) => (s, map (fst o dest_TFree) Ts)) eqs';
           (* Drops the symbols up to and including the first "'"; the list is
              left unchanged when find_index reports no match
              (NOTE(review): relies on find_index yielding ~1 in that case). *)
   156     fun strip ss = (uncurry drop) (find_index (fn s => s = "'") ss + 1, ss);
           (* Picks one symbol from a type; the results are handed to
              mk_var_names to name the constructor arguments. *)
   157     fun typid (Type  (id,_)) =
   158         let val c = hd (Symbol.explode (Long_Name.base_name id))
   159         in if Symbol.is_letter c then c else "t" end
   160       | typid (TFree (id,_)   ) = hd (strip (tl (Symbol.explode id)))
   161       | typid (TVar ((id,_),_)) = hd (tl (Symbol.explode id));
           (* Converts a checked constructor into the cons representation: the
              constant name derived from the mixfix and binding, plus one mk_arg
              entry per argument. *)
   162     fun one_con (con,args,mx) =
   163         ((Syntax.const_name mx (Binding.name_of con)),
   164          ListPair.map (fn ((lazy,sel,tp),vn) =>
   165            mk_arg ((lazy, DatatypeAux.dtyp_of_typ new_dts tp),
   166                    Option.map Binding.name_of sel,vn))
   167                       (args,(mk_var_names(map (typid o third) args)))
   168         ) : cons;
   169     val eqs : eq list =
   170         map (fn (dtnvs,cons') => (dtnvs, map one_con cons')) eqs';
   171     val thy = thy' |> Domain_Axioms.add_axioms false comp_dnam eqs;
   172     val ((rewss, take_rews), theorems_thy) =
   173         thy
   174           |> fold_map (fn eq => Domain_Theorems.theorems (eq, eqs)) eqs
   175           ||>> Domain_Theorems.comp_theorems (comp_dnam, eqs);
   176   in
           (* Store all rewrite rules as "rews" below the group's base name, then
              return to the parent path. *)
   177     theorems_thy
   178       |> Sign.add_path (Long_Name.base_name comp_dnam)
   179       |> PureThy.add_thmss
   180            [((Binding.name "rews", flat rewss @ take_rews), [])]
   181       |> snd
   182       |> Sign.parent_path
   183   end;
   184 
   185 fun gen_add_new_domain
       (* Extends thy''' with the group of domains described by eqs''', defining
          the types through Domain_Isomorphism.domain_isomorphism.

          prep_typ:  converts a constructor argument type from its input
                     representation into a typ, relative to a theory.
          comp_dnam: name of the whole group; its base name becomes the path
                     under which the collected rewrite rules are stored.
          eqs''':    one specification per domain: type variables with optional
                     sort strings, type binding, mixfix and constructors.
          thy''':    the theory to extend.

          Parsing and checking happen in a throw-away copy of the theory; the
          checked equations are then turned into right-hand side types and
          passed to domain_isomorphism on the original theory.  Syntax, axioms
          and theorems are added afterwards.  Returns the resulting theory. *)
   186     (prep_typ : theory -> 'a -> typ)
   187     (comp_dnam : string)
   188     (eqs''' : ((string * string option) list * binding * mixfix *
   189                (binding * (bool * binding option * 'a) list * mixfix) list) list)
   190     (thy''' : theory) =
   191   let
           (* A missing sort annotation stands for the default sort of thy'''. *)
   192     fun readS (SOME s) = Syntax.read_sort_global thy''' s
   193       | readS NONE = Sign.defaultS thy''';
   194     fun readTFree (a, s) = TFree (a, readS s);
   195 
   196     val dtnvs = map (fn (vs,dname:binding,mx,_) => 
   197                         (dname, map readTFree vs, mx)) eqs''';
   198     val cons''' = map (fn (_,_,_,cons) => cons) eqs''';
   199     fun thy_type  (dname,tvars,mx) = (dname, length tvars, mx);
   200     fun thy_arity (dname,tvars,mx) =
   201       (Sign.full_name thy''' dname, map (snd o dest_TFree) tvars, @{sort rep});
   202 
   203     (* this theory is used just for parsing and error checking *)
   204     val tmp_thy = thy'''
   205       |> Theory.copy
   206       |> Sign.add_types (map thy_type dtnvs)
   207       |> fold (AxClass.axiomatize_arity o thy_arity) dtnvs;
   208 
   209     val cons'' : (binding * (bool * binding option * typ) list * mixfix) list list =
   210       map (map (upd_second (map (upd_third (prep_typ tmp_thy))))) cons''';
   211     val dtnvs' : (string * typ list) list =
   212       map (fn (dname,vs,mx) => (Sign.full_name thy''' dname,vs)) dtnvs;
   213     val eqs' : ((string * typ list) *
   214         (binding * (bool * binding option * typ) list * mixfix) list) list =
   215       check_and_sort_domain true dtnvs' cons'' tmp_thy;
   216 
           (* Right-hand side type of one equation: lazy arguments are wrapped
              with mk_uT, the arguments of a constructor are combined with
              mk_sprodT (oneT when there are none), and the constructors are
              combined with mk_ssumT.  foldr1 is applied to the constructor list,
              so an equation is expected to have at least one constructor. *)
   217     fun mk_arg_typ (lazy, dest_opt, T) = if lazy then mk_uT T else T;
   218     fun mk_con_typ (bind, args, mx) =
   219         if null args then oneT else foldr1 mk_sprodT (map mk_arg_typ args);
   220     fun mk_eq_typ (_, cons) = foldr1 mk_ssumT (map mk_con_typ cons);
   221     
           (* The real type definitions are made in the original theory thy''',
              not in tmp_thy. *)
   222     val thy'' = thy''' |>
   223       Domain_Isomorphism.domain_isomorphism
   224         (map (fn ((vs, dname, mx, _), eq) =>
   225                  (map fst vs, dname, mx, mk_eq_typ eq))
   226              (eqs''' ~~ eqs'))
   227 
   228     val thy' = thy'' |> Domain_Syntax.add_syntax true comp_dnam eqs';
   230     val new_dts = map (fn ((s,Ts),_) => (s, map (fst o dest_TFree) Ts)) eqs';
           (* Drops the symbols up to and including the first "'"; the list is
              left unchanged when find_index reports no match
              (NOTE(review): relies on find_index yielding ~1 in that case). *)
   231     fun strip ss = (uncurry drop) (find_index (fn s => s = "'") ss + 1, ss);
           (* Picks one symbol from a type; the results are handed to
              mk_var_names to name the constructor arguments. *)
   232     fun typid (Type  (id,_)) =
   233         let val c = hd (Symbol.explode (Long_Name.base_name id))
   234         in if Symbol.is_letter c then c else "t" end
   235       | typid (TFree (id,_)   ) = hd (strip (tl (Symbol.explode id)))
   236       | typid (TVar ((id,_),_)) = hd (tl (Symbol.explode id));
           (* Converts a checked constructor into the cons representation: the
              constant name derived from the mixfix and binding, plus one mk_arg
              entry per argument. *)
   237     fun one_con (con,args,mx) =
   238         ((Syntax.const_name mx (Binding.name_of con)),
   239          ListPair.map (fn ((lazy,sel,tp),vn) =>
   240            mk_arg ((lazy, DatatypeAux.dtyp_of_typ new_dts tp),
   241                    Option.map Binding.name_of sel,vn))
   242                       (args,(mk_var_names(map (typid o third) args)))
   243         ) : cons;
   244     val eqs : eq list =
   245         map (fn (dtnvs,cons') => (dtnvs, map one_con cons')) eqs';
   246     val thy = thy' |> Domain_Axioms.add_axioms true comp_dnam eqs;
   247     val ((rewss, take_rews), theorems_thy) =
   248         thy
   249           |> fold_map (fn eq => Domain_Theorems.theorems (eq, eqs)) eqs
   250           ||>> Domain_Theorems.comp_theorems (comp_dnam, eqs);
   251   in
           (* Store all rewrite rules as "rews" below the group's base name, then
              return to the parent path. *)
   252     theorems_thy
   253       |> Sign.add_path (Long_Name.base_name comp_dnam)
   254       |> PureThy.add_thmss
   255            [((Binding.name "rews", flat rewss @ take_rews), [])]
   256       |> snd
   257       |> Sign.parent_path
   258   end;
   259 
   260 val add_domain = gen_add_domain Sign.certify_typ;
       (* The _cmd variants receive constructor argument types as strings and read
          them with Syntax.read_typ_global; the plain variants receive typ values
          and pass them through Sign.certify_typ. *)
   261 val add_domain_cmd = gen_add_domain Syntax.read_typ_global;
   262 
       (* The same pairing for the variant built on gen_add_new_domain. *)
   263 val add_new_domain = gen_add_new_domain Sign.certify_typ;
   264 val add_new_domain_cmd = gen_add_new_domain Syntax.read_typ_global;
   265 
   266 
   267 (** outer syntax **)
   268 
   269 local structure P = OuterParse and K = OuterKeyword in
   270 
   271 val _ = OuterKeyword.keyword "lazy";
       (* "lazy" is registered as a keyword before the parser below refers to it
          (presumably required for P.$$$ "lazy" to match - confirm). *)
   272 
   273 val dest_decl : (bool * binding option * string) parser =
         (* Parses one constructor argument as (lazy flag, optional selector,
            type string).  Three forms are accepted, tried in this order:
              "(" ["lazy"] selector "::" type ")"  -  selector given
              "(" "lazy" type ")"                  -  lazy, no selector
              type                                 -  strict, no selector
            NOTE(review): K true applies K as a function, so it has to denote a
            value named K and not the structure alias K introduced by the
            enclosing local declaration - confirm. *)
   274   P.$$$ "(" |-- Scan.optional (P.$$$ "lazy" >> K true) false --
   275     (P.binding >> SOME) -- (P.$$$ "::" |-- P.typ)  --| P.$$$ ")" >> P.triple1
   276     || P.$$$ "(" |-- P.$$$ "lazy" |-- P.typ --| P.$$$ ")"
   277     >> (fn t => (true,NONE,t))
   278     || P.typ >> (fn t => (false,NONE,t));
   279 
   280 val cons_decl =
         (* One constructor: its binding, any number of argument declarations
            (possibly none), and an optional mixfix annotation. *)
   281   P.binding -- Scan.repeat dest_decl -- P.opt_mixfix;
   282 
   283 val type_var' : (string * string option) parser =
         (* A type identifier, optionally followed by "::" and a sort; the sort
            is kept as an unparsed string. *)
   284   (P.type_ident -- Scan.option (P.$$$ "::" |-- P.!!! P.sort));
   285 
   286 val type_args' : (string * string option) list parser =
         (* The type arguments of a domain: a single type variable, a
            parenthesised list built with P.list1, or nothing at all, which
            yields the empty list. *)
   287   type_var' >> single
   288   || P.$$$ "(" |-- P.!!! (P.list1 type_var' --| P.$$$ ")")
   289   || Scan.succeed [];
   290 
   291 val domain_decl =
         (* One domain equation: type arguments, type binding and optional infix
            annotation, then "=" followed by one or more constructor
            declarations separated by "|". *)
   292   (type_args' -- P.binding -- P.opt_infix) --
   293     (P.$$$ "=" |-- P.enum1 "|" cons_decl);
   294 
   295 val domains_decl =
         (* A whole group of domains: an optional name in parentheses, followed
            by the domain equations as parsed by P.and_list1. *)
   296   Scan.option (P.$$$ "(" |-- P.name --| P.$$$ ")") --
   297     P.and_list1 domain_decl;
   298 
   299 fun mk_domain
       (* Turns the nested pairs delivered by domains_decl into the flat tuples
          expected by the domain extenders and returns the matching theory
          transformation.

          definitional: true selects add_new_domain_cmd, false add_domain_cmd.
          opt_name:     explicit name of the group; when absent, the names of
                        the domain types are joined with "_".
          doms:         the parsed domain equations. *)
   300     (definitional : bool)
   301     (opt_name : string option,
   302      doms : ((((string * string option) list * binding) * mixfix) *
   303              ((binding * (bool * binding option * string) list) * mixfix) list) list ) =
   304   let
   305     fun flat_con ((cname, dests), cmx) = (cname, dests, cmx);
   306     fun flat_dom (((tvars, tname), tmx), cdecls) =
   307         (tvars, tname, tmx, map flat_con cdecls);
   308     val specs : ((string * string option) list * binding * mixfix *
   309                  (binding * (bool * binding option * string) list * mixfix) list) list =
   310         map flat_dom doms;
   311     fun dom_name (((_, tname), _), _) = Binding.name_of tname;
   312     val comp_dnam =
   313         case opt_name of SOME given => given | NONE => space_implode "_" (map dom_name doms);
   314     val extend = if definitional then add_new_domain_cmd else add_domain_cmd;
   315   in extend comp_dnam specs
   316   end;
   317 
   318 val _ =
         (* Registers the command "domain": the parsed declaration is handed to
            mk_domain with definitional = false and run as a theory
            transformation. *)
   319   OuterSyntax.command "domain" "define recursive domains (HOLCF)"
   320     K.thy_decl (domains_decl >> (Toplevel.theory o mk_domain false));
   321 
   322 val _ =
         (* Registers the command "new_domain": same syntax and description, but
            mk_domain is called with definitional = true. *)
   323   OuterSyntax.command "new_domain" "define recursive domains (HOLCF)"
   324     K.thy_decl (domains_decl >> (Toplevel.theory o mk_domain true));
   325 
   326 end;
   327 
   328 end;