src/Pure/Syntax/syn_ext.ML
author wenzelm
Fri Dec 13 17:30:28 1996 +0100 (1996-12-13)
changeset 2382 e7c2bce815ba
parent 2364 821f44a0abba
child 2694 b98365c6e869
permissions -rw-r--r--
added fix_tr', syn_ext_trfunsT;
changed syn_ext_trfuns (fix_tr');
     1 (*  Title:      Pure/Syntax/syn_ext.ML
     2     ID:         $Id$
     3     Author:     Markus Wenzel and Carsten Clasohm, TU Muenchen
     4 
     5 Syntax extension (internal interface).
     6 *)
     7 
     8 signature SYN_EXT0 =
     9   sig
    10   val typeT: typ
    11   val constrainC: string
    12   end;
    13 
    14 signature SYN_EXT =
    15   sig
    16   include SYN_EXT0
    17   val logic: string
    18   val args: string
    19   val cargs: string
    20   val any: string
    21   val sprop: string
    22   val typ_to_nonterm: typ -> string
    23   datatype xsymb =
    24     Delim of string |
    25     Argument of string * int |
    26     Space of string |
    27     Bg of int | Brk of int | En
    28   datatype xprod = XProd of string * xsymb list * string * int
    29   val max_pri: int
    30   val chain_pri: int
    31   val delims_of: xprod list -> string list
    32   datatype mfix = Mfix of string * typ * string * int list * int
    33   datatype syn_ext =
    34     SynExt of {
    35       logtypes: string list,
    36       xprods: xprod list,
    37       consts: string list,
    38       parse_ast_translation: (string * (Ast.ast list -> Ast.ast)) list,
    39       parse_rules: (Ast.ast * Ast.ast) list,
    40       parse_translation: (string * (term list -> term)) list,
    41       print_translation: (string * (typ -> term list -> term)) list,
    42       print_rules: (Ast.ast * Ast.ast) list,
    43       print_ast_translation: (string * (Ast.ast list -> Ast.ast)) list}
    44   val mk_syn_ext: bool -> string list -> mfix list ->
    45     string list -> (string * (Ast.ast list -> Ast.ast)) list *
    46     (string * (term list -> term)) list *
    47     (string * (typ -> term list -> term)) list * (string * (Ast.ast list -> Ast.ast)) list
    48     -> (Ast.ast * Ast.ast) list * (Ast.ast * Ast.ast) list -> syn_ext
    49   val syn_ext: string list -> mfix list -> string list ->
    50     (string * (Ast.ast list -> Ast.ast)) list * (string * (term list -> term)) list *
    51     (string * (typ -> term list -> term)) list * (string * (Ast.ast list -> Ast.ast)) list
    52     -> (Ast.ast * Ast.ast) list * (Ast.ast * Ast.ast) list -> syn_ext
    53   val syn_ext_logtypes: string list -> syn_ext
    54   val syn_ext_const_names: string list -> string list -> syn_ext
    55   val syn_ext_rules: string list -> (Ast.ast * Ast.ast) list * (Ast.ast * Ast.ast) list -> syn_ext
    56   val fix_tr': (term list -> term) -> typ -> term list -> term
    57   val syn_ext_trfuns: string list ->
    58     (string * (Ast.ast list -> Ast.ast)) list * (string * (term list -> term)) list *
    59     (string * (term list -> term)) list * (string * (Ast.ast list -> Ast.ast)) list
    60     -> syn_ext
    61   val syn_ext_trfunsT: string list -> (string * (typ -> term list -> term)) list -> syn_ext
    62   val pure_ext: syn_ext
    63   end;
    64 
    65 structure SynExt : SYN_EXT =
    66 struct
    67 
    68 open Lexicon Ast;
    69 
    70 (** misc definitions **)
    71 
    72 (* syntactic categories *)
    73 
    74 val logic = "logic";
    75 val logicT = Type (logic, []);
    76 
    77 val args = "args";
    78 val cargs = "cargs";
    79 
    80 val typeT = Type ("type", []);
    81 
    82 val sprop = "#prop";
    83 val spropT = Type (sprop, []);
    84 
    85 val any = "any";
    86 val anyT = Type (any, []);
    87 
    88 
    89 (* constants *)
    90 
    91 val constrainC = "_constrain";
    92 
    93 
    94 
    95 (** datatype xprod **)
    96 
    97 (*Delim s: delimiter s
    98   Argument (s, p): nonterminal s requiring priority >= p, or valued token
    99   Space s: some white space for printing
   100   Bg, Brk, En: blocks and breaks for pretty printing*)
   101 
   102 datatype xsymb =
   103   Delim of string |
   104   Argument of string * int |
   105   Space of string |
   106   Bg of int | Brk of int | En;
   107 
   108 
   109 (*XProd (lhs, syms, c, p):
   110     lhs: name of nonterminal on the lhs of the production
   111     syms: list of symbols on the rhs of the production
   112     c: head of parse tree
   113     p: priority of this production*)
   114 
   115 datatype xprod = XProd of string * xsymb list * string * int;
   116 
   117 val max_pri = 1000;   (*maximum legal priority*)
   118 val chain_pri = ~1;   (*dummy for chain productions*)
   119 
   120 
   121 (* delims_of *)
   122 
   123 fun delims_of xprods =
   124   let
   125     fun del_of (Delim s) = Some s
   126       | del_of _ = None;
   127 
   128     fun dels_of (XProd (_, xsymbs, _, _)) =
   129       mapfilter del_of xsymbs;
   130   in
   131     distinct (flat (map dels_of xprods))
   132   end;
   133 
   134 
   135 
   136 (** datatype mfix **)
   137 
   138 (*Mfix (sy, ty, c, ps, p):
   139     sy: rhs of production as symbolic string
   140     ty: type description of production
   141     c: head of parse tree
   142     ps: priorities of arguments in sy
   143     p: priority of production*)
   144 
   145 datatype mfix = Mfix of string * typ * string * int list * int;
   146 
   147 
   148 (* typ_to_nonterm *)
   149 
   150 fun typ_to_nt _ (Type (c, _)) = c
   151   | typ_to_nt default _ = default;
   152 
   153 (*get nonterminal for rhs*)
   154 val typ_to_nonterm = typ_to_nt any;
   155 
   156 (*get nonterminal for lhs*)
   157 val typ_to_nonterm1 = typ_to_nt logic;
   158 
   159 
   160 (* mfix_to_xprod *)
   161 
   162 fun mfix_to_xprod convert logtypes (Mfix (sy, typ, const, pris, pri)) =
   163   let
   164     fun err msg =
   165       (writeln ("Error in mixfix annotation " ^ quote sy ^ " for " ^ quote const);
   166         error msg);
   167     fun post_err () = error ("The error(s) above occurred in mixfix annotation " ^
   168       quote sy ^ " for " ^ quote const);
   169 
   170     fun check_pri p =
   171       if p >= 0 andalso p <= max_pri then ()
   172       else err ("precedence out of range: " ^ string_of_int p);
   173 
   174     fun blocks_ok [] 0 = true
   175       | blocks_ok [] _ = false
   176       | blocks_ok (Bg _ :: syms) n = blocks_ok syms (n + 1)
   177       | blocks_ok (En :: _) 0 = false
   178       | blocks_ok (En :: syms) n = blocks_ok syms (n - 1)
   179       | blocks_ok (_ :: syms) n = blocks_ok syms n;
   180 
   181     fun check_blocks syms =
   182       if blocks_ok syms 0 then ()
   183       else err "unbalanced block parentheses";
   184 
   185 
   186     local
   187       fun is_meta c = c mem ["(", ")", "/", "_"];
   188 
   189       fun scan_delim_char ("'" :: c :: cs) =
   190             if is_blank c then raise LEXICAL_ERROR else (c, cs)
   191         | scan_delim_char ["'"] = err "trailing escape character"
   192         | scan_delim_char (chs as c :: cs) =
   193             if is_blank c orelse is_meta c then raise LEXICAL_ERROR else (c, cs)
   194         | scan_delim_char [] = raise LEXICAL_ERROR;
   195 
   196       val scan_sym =
   197         $$ "_" >> K (Argument ("", 0)) ||
   198         $$ "(" -- scan_int >> (Bg o #2) ||
   199         $$ ")" >> K En ||
   200         $$ "/" -- $$ "/" >> K (Brk ~1) ||
   201         $$ "/" -- scan_any is_blank >> (Brk o length o #2) ||
   202         scan_any1 is_blank >> (Space o implode) ||
   203         repeat1 scan_delim_char >> (Delim o implode);
   204 
   205       val scan_symb =
   206         scan_sym >> Some ||
   207         $$ "'" -- scan_one is_blank >> K None;
   208     in
   209       val scan_symbs = mapfilter I o #1 o repeat scan_symb;
   210     end;
   211 
   212 
   213     val cons_fst = apfst o cons;
   214 
   215     fun add_args [] ty [] = ([], typ_to_nonterm1 ty)
   216       | add_args [] _ _ = err "too many precedences"
   217       | add_args (Argument _ :: syms) (Type ("fun", [ty, tys])) [] =
   218           cons_fst (Argument (typ_to_nonterm ty, 0)) (add_args syms tys [])
   219       | add_args (Argument _ :: syms) (Type ("fun", [ty, tys])) (p :: ps) =
   220           cons_fst (Argument (typ_to_nonterm ty, p)) (add_args syms tys ps)
   221       | add_args (Argument _ :: _) _ _ =
   222           err "more arguments than in corresponding type"
   223       | add_args (sym :: syms) ty ps = cons_fst sym (add_args syms ty ps);
   224 
   225 
   226     fun is_arg (Argument _) = true
   227       | is_arg _ = false;
   228 
   229     fun is_term (Delim _) = true
   230       | is_term (Argument (s, _)) = is_terminal s
   231       | is_term _ = false;
   232 
   233     fun rem_pri (Argument (s, _)) = Argument (s, chain_pri)
   234       | rem_pri sym = sym;
   235 
   236     fun is_delim (Delim _) = true
   237       | is_delim _ = false;
   238 
   239     (*replace logical types on rhs by "logic"*)
   240     fun unify_logtypes copy_prod (a as (Argument (s, p))) =
   241           if s mem logtypes then Argument (logic, p)
   242           else a
   243       | unify_logtypes _ a = a;
   244 
   245 
   246     val sy_chars =
   247       SymbolFont.read_charnames (explode sy) handle ERROR => post_err ();
   248     val raw_symbs = scan_symbs sy_chars;
   249     val (symbs, lhs) = add_args raw_symbs typ pris;
   250     val copy_prod =
   251       lhs mem ["prop", "logic"]
   252         andalso const <> ""
   253         andalso not (null symbs)
   254         andalso not (exists is_delim symbs);
   255     val lhs' =
   256       if convert andalso not copy_prod then
   257        (if lhs mem logtypes then logic
   258         else if lhs = "prop" then sprop else lhs)
   259       else lhs;
   260     val symbs' = map (unify_logtypes copy_prod) symbs;
   261     val xprod = XProd (lhs', symbs', const, pri);
   262   in
   263     seq check_pri pris;
   264     check_pri pri;
   265     check_blocks symbs';
   266 
   267     if is_terminal lhs' then err ("illegal lhs: " ^ lhs')
   268     else if const <> "" then xprod
   269     else if length (filter is_arg symbs') <> 1 then
   270       err "copy production must have exactly one argument"
   271     else if exists is_term symbs' then xprod
   272     else XProd (lhs', map rem_pri symbs', "", chain_pri)
   273   end;
   274 
   275 
   276 (** datatype syn_ext **)
   277 
   278 datatype syn_ext =
   279   SynExt of {
   280     logtypes: string list,
   281     xprods: xprod list,
   282     consts: string list,
   283     parse_ast_translation: (string * (Ast.ast list -> Ast.ast)) list,
   284     parse_rules: (Ast.ast * Ast.ast) list,
   285     parse_translation: (string * (term list -> term)) list,
   286     print_translation: (string * (typ -> term list -> term)) list,
   287     print_rules: (Ast.ast * Ast.ast) list,
   288     print_ast_translation: (string * (Ast.ast list -> Ast.ast)) list};
   289 
   290 
   291 (* syn_ext *)
   292 
   293 fun mk_syn_ext convert logtypes mfixes consts trfuns rules =
   294   let
   295     val (parse_ast_translation, parse_translation, print_translation,
   296       print_ast_translation) = trfuns;
   297     val (parse_rules, print_rules) = rules;
   298     val logtypes' = logtypes \ "prop";
   299 
   300     val mfix_consts = distinct (map (fn (Mfix (_, _, c, _, _)) => c) mfixes);
   301     val xprods = map (mfix_to_xprod convert logtypes') mfixes;
   302   in
   303     SynExt {
   304       logtypes = logtypes',
   305       xprods = xprods,
   306       consts = filter is_xid (consts union mfix_consts),
   307       parse_ast_translation = parse_ast_translation,
   308       parse_rules = parse_rules,
   309       parse_translation = parse_translation,
   310       print_translation = print_translation,
   311       print_rules = print_rules,
   312       print_ast_translation = print_ast_translation}
   313   end;
   314 
   315 
   316 val syn_ext = mk_syn_ext true;
   317 
   318 fun syn_ext_logtypes logtypes =
   319   syn_ext logtypes [] [] ([], [], [], []) ([], []);
   320 
   321 fun syn_ext_const_names logtypes cs =
   322   syn_ext logtypes [] cs ([], [], [], []) ([], []);
   323 
   324 fun syn_ext_rules logtypes rules =
   325   syn_ext logtypes [] [] ([], [], [], []) rules;
   326 
   327 fun fix_tr' f _ args = f args;
   328 
   329 fun syn_ext_trfuns logtypes (atrs, trs, tr's, atr's) =
   330   syn_ext logtypes [] [] (atrs, trs, map (apsnd fix_tr') tr's, atr's) ([], []);
   331 
   332 fun syn_ext_trfunsT logtypes tr's =
   333   syn_ext logtypes [] [] ([], [], tr's, []) ([], []);
   334 
   335 
   336 (* pure_ext *)
   337 
   338 val pure_ext = mk_syn_ext false []
   339   [Mfix ("_", spropT --> propT, "", [0], 0),
   340    Mfix ("_", logicT --> anyT, "", [0], 0),
   341    Mfix ("_", spropT --> anyT, "", [0], 0),
   342    Mfix ("'(_')", logicT --> logicT, "", [0], max_pri),
   343    Mfix ("'(_')", spropT --> spropT, "", [0], max_pri),
   344    Mfix ("_::_",  [logicT, typeT] ---> logicT, "_constrain", [4, 0], 3),
   345    Mfix ("_::_",  [spropT, typeT] ---> spropT, "_constrain", [4, 0], 3)]
   346   []
   347   ([], [], [], [])
   348   ([], []);
   349 
   350 end;