src/Pure/Syntax/ast.ML
author haftmann
Tue Nov 24 17:28:25 2009 +0100 (2009-11-24)
changeset 33955 fff6f11b1f09
parent 33038 8f9594c31de4
child 33957 e9afca2118d4
permissions -rw-r--r--
curried take/drop
     1 (*  Title:      Pure/Syntax/ast.ML
     2     Author:     Markus Wenzel, TU Muenchen
     3 
     4 Abstract syntax trees, translation rules, matching and normalization of asts.
     5 *)
     6 
(*the ast datatype and its exception; included by the signatures below*)
signature AST0 =
sig
  datatype ast =
    Constant of string |    (*named constant, e.g. "fun", "_abs"*)
    Variable of string |    (*variable, e.g. x, ?x, 'a*)
    Appl of ast list        (*application: head followed by arguments*)
  (*exception for system errors involving asts*)
  exception AST of string * ast list
end;
    15 
(*ast construction, printing and fold/unfold utilities*)
signature AST1 =
sig
  include AST0
  (*mk_appl f args: Appl (f :: args), or just f if args is empty*)
  val mk_appl: ast -> ast list -> ast
  (*LISP-like rendering: Constants quoted, applications parenthesized*)
  val str_of_ast: ast -> string
  (*pretty-printing counterpart of str_of_ast*)
  val pretty_ast: ast -> Pretty.T
  (*a rule (lhs, rhs) shown as "lhs  -> rhs"*)
  val pretty_rule: ast * ast -> Pretty.T
  (*fold_ast c [x1, ..., xn]: right-nested applications of Constant c to the
    elements; a singleton yields its element; raises Match on []*)
  val fold_ast: string -> ast list -> ast
  (*fold_ast_p c (xs, y): right-nested applications of Constant c to the
    elements of xs, with y as the innermost last argument*)
  val fold_ast_p: string -> ast list * ast -> ast
  (*counterparts of fold_ast and fold_ast_p: strip right-nested applications
    of Constant c*)
  val unfold_ast: string -> ast -> ast list
  val unfold_ast_p: string -> ast -> ast list * ast
  (*flags read by normalize_ast: trace rewrite steps / print statistics*)
  val trace_ast: bool Unsynchronized.ref
  val stat_ast: bool Unsynchronized.ref
end;
    30 
(*full interface of structure Ast: adds translation rule checking and
  normalization of asts*)
signature AST =
sig
  include AST1
  (*name of the head constant of the rule's lhs, "" if it has none*)
  val head_of_rule: ast * ast -> string
  (*SOME message if the rule's lhs has duplicate variables or its rhs has
    variables not occurring in the lhs; NONE otherwise*)
  val rule_error: ast * ast -> string option
  (*normalize trace stat get_rules ast: rewrite ast with the rules returned
    by get_rules for each head name ("" for headless rules) until no rule
    applies; trace / stat enable tracing output*)
  val normalize: bool -> bool -> (string -> (ast * ast) list) -> ast -> ast
  (*normalize with the current values of trace_ast and stat_ast*)
  val normalize_ast: (string -> (ast * ast) list) -> ast -> ast
end;
    39 
    40 structure Ast : AST =
    41 struct
    42 
    43 (** abstract syntax trees **)
    44 
    45 (*asts come in two flavours:
    46    - ordinary asts representing terms and typs: Variables are (often) treated
    47      like Constants;
    48    - patterns used as lhs and rhs in rules: Variables are placeholders for
    49      proper asts*)
    50 
    51 datatype ast =
    52   Constant of string |    (*"not", "_abs", "fun"*)
    53   Variable of string |    (*x, ?x, 'a, ?'a*)
    54   Appl of ast list;       (*(f x y z), ("fun" 'a 'b), ("_abs" x t)*)
    55 
    56 
    57 (*the list of subasts of an Appl node has to contain at least 2 elements, i.e.
    58   there are no empty asts or nullary applications; use mk_appl for convenience*)
    59 
    60 fun mk_appl f [] = f
    61   | mk_appl f args = Appl (f :: args);
    62 
    63 
    64 (*exception for system errors involving asts*)
    65 
    66 exception AST of string * ast list;
    67 
    68 
    69 
    70 (** print asts in a LISP-like style **)
    71 
    72 fun str_of_ast (Constant a) = quote a
    73   | str_of_ast (Variable x) = x
    74   | str_of_ast (Appl asts) = "(" ^ (space_implode " " (map str_of_ast asts)) ^ ")";
    75 
    76 fun pretty_ast (Constant a) = Pretty.quote (Pretty.str a)
    77   | pretty_ast (Variable x) = Pretty.str x
    78   | pretty_ast (Appl asts) =
    79       Pretty.enclose "(" ")" (Pretty.breaks (map pretty_ast asts));
    80 
    81 fun pretty_rule (lhs, rhs) =
    82   Pretty.block [pretty_ast lhs, Pretty.str "  ->", Pretty.brk 2, pretty_ast rhs];
    83 
    84 
    85 (* head_of_ast, head_of_rule *)
    86 
    87 fun head_of_ast (Constant a) = a
    88   | head_of_ast (Appl (Constant a :: _)) = a
    89   | head_of_ast _ = "";
    90 
    91 fun head_of_rule (lhs, _) = head_of_ast lhs;
    92 
    93 
    94 
    95 (** check translation rules **)
    96 
    97 fun rule_error (lhs, rhs) =
    98   let
    99     fun add_vars (Constant _) = I
   100       | add_vars (Variable x) = cons x
   101       | add_vars (Appl asts) = fold add_vars asts;
   102 
   103     val lvars = add_vars lhs [];
   104     val rvars = add_vars rhs [];
   105   in
   106     if has_duplicates (op =) lvars then SOME "duplicate vars in lhs"
   107     else if not (subset (op =) (rvars, lvars)) then SOME "rhs contains extra variables"
   108     else NONE
   109   end;
   110 
   111 
   112 
   113 (** ast translation utilities **)
   114 
   115 (* fold asts *)
   116 
   117 fun fold_ast _ [] = raise Match
   118   | fold_ast _ [y] = y
   119   | fold_ast c (x :: xs) = Appl [Constant c, x, fold_ast c xs];
   120 
   121 fun fold_ast_p c = uncurry (fold_rev (fn x => fn xs => Appl [Constant c, x, xs]));
   122 
   123 
   124 (* unfold asts *)
   125 
   126 fun unfold_ast c (y as Appl [Constant c', x, xs]) =
   127       if c = c' then x :: unfold_ast c xs else [y]
   128   | unfold_ast _ y = [y];
   129 
   130 fun unfold_ast_p c (y as Appl [Constant c', x, xs]) =
   131       if c = c' then apfst (cons x) (unfold_ast_p c xs)
   132       else ([], y)
   133   | unfold_ast_p _ y = ([], y);
   134 
   135 
   136 
   137 (** normalization of asts **)
   138 
   139 (* match *)
   140 
   141 fun match ast pat =
   142   let
   143     exception NO_MATCH;
   144 
   145     fun mtch (Constant a) (Constant b) env =
   146           if a = b then env else raise NO_MATCH
   147       | mtch (Variable a) (Constant b) env =
   148           if a = b then env else raise NO_MATCH
   149       | mtch ast (Variable x) env = Symtab.update (x, ast) env
   150       | mtch (Appl asts) (Appl pats) env = mtch_lst asts pats env
   151       | mtch _ _ _ = raise NO_MATCH
   152     and mtch_lst (ast :: asts) (pat :: pats) env =
   153           mtch_lst asts pats (mtch ast pat env)
   154       | mtch_lst [] [] env = env
   155       | mtch_lst _ _ _ = raise NO_MATCH;
   156 
   157     val (head, args) =
   158       (case (ast, pat) of
   159         (Appl asts, Appl pats) =>
   160           let val a = length asts and p = length pats in
   161             if a > p then (Appl ((uncurry take) (p, asts)), (uncurry drop) (p, asts))
   162             else (ast, [])
   163           end
   164       | _ => (ast, []));
   165   in
   166     SOME (mtch head pat Symtab.empty, args) handle NO_MATCH => NONE
   167   end;
   168 
   169 
   170 (* normalize *)
   171 
   172 (*the normalizer works yoyo-like: top-down, bottom-up, top-down, ...*)
   173 
   174 fun normalize trace stat get_rules pre_ast =
   175   let
   176     val passes = Unsynchronized.ref 0;
   177     val failed_matches = Unsynchronized.ref 0;
   178     val changes = Unsynchronized.ref 0;
   179 
   180     fun subst _ (ast as Constant _) = ast
   181       | subst env (Variable x) = the (Symtab.lookup env x)
   182       | subst env (Appl asts) = Appl (map (subst env) asts);
   183 
   184     fun try_rules ((lhs, rhs) :: pats) ast =
   185           (case match ast lhs of
   186             SOME (env, args) =>
   187               (Unsynchronized.inc changes; SOME (mk_appl (subst env rhs) args))
   188           | NONE => (Unsynchronized.inc failed_matches; try_rules pats ast))
   189       | try_rules [] _ = NONE;
   190     val try_headless_rules = try_rules (get_rules "");
   191 
   192     fun try ast a =
   193       (case try_rules (get_rules a) ast of
   194         NONE => try_headless_rules ast
   195       | some => some);
   196 
   197     fun rewrite (ast as Constant a) = try ast a
   198       | rewrite (ast as Variable a) = try ast a
   199       | rewrite (ast as Appl (Constant a :: _)) = try ast a
   200       | rewrite (ast as Appl (Variable a :: _)) = try ast a
   201       | rewrite ast = try_headless_rules ast;
   202 
   203     fun rewrote old_ast new_ast =
   204       if trace then
   205         tracing ("rewrote: " ^ str_of_ast old_ast ^ "  ->  " ^ str_of_ast new_ast)
   206       else ();
   207 
   208     fun norm_root ast =
   209       (case rewrite ast of
   210         SOME new_ast => (rewrote ast new_ast; norm_root new_ast)
   211       | NONE => ast);
   212 
   213     fun norm ast =
   214       (case norm_root ast of
   215         Appl sub_asts =>
   216           let
   217             val old_changes = ! changes;
   218             val new_ast = Appl (map norm sub_asts);
   219           in
   220             if old_changes = ! changes then new_ast else norm_root new_ast
   221           end
   222       | atomic_ast => atomic_ast);
   223 
   224     fun normal ast =
   225       let
   226         val old_changes = ! changes;
   227         val new_ast = norm ast;
   228       in
   229         Unsynchronized.inc passes;
   230         if old_changes = ! changes then new_ast else normal new_ast
   231       end;
   232 
   233 
   234     val _ = if trace then tracing ("pre: " ^ str_of_ast pre_ast) else ();
   235     val post_ast = normal pre_ast;
   236     val _ =
   237       if trace orelse stat then
   238         tracing ("post: " ^ str_of_ast post_ast ^ "\nnormalize: " ^
   239           string_of_int (! passes) ^ " passes, " ^
   240           string_of_int (! changes) ^ " changes, " ^
   241           string_of_int (! failed_matches) ^ " matches failed")
   242       else ();
   243   in post_ast end;
   244 
   245 
   246 (* normalize_ast *)
   247 
   248 val trace_ast = Unsynchronized.ref false;
   249 val stat_ast = Unsynchronized.ref false;
   250 
   251 fun normalize_ast get_rules ast =
   252   normalize (! trace_ast) (! stat_ast) get_rules ast;
   253 
   254 end;