src/Pure/Syntax/ast.ML
author wenzelm
Sat Dec 30 16:08:00 2006 +0100 (2006-12-30)
changeset 21962 279b129498b6
parent 19486 e04e20b1253a
child 29565 3f8b24fcfbd6
permissions -rw-r--r--
removed conditional combinator;
     1 (*  Title:      Pure/Syntax/ast.ML
     2     ID:         $Id$
     3     Author:     Markus Wenzel, TU Muenchen
     4 
     5 Abstract syntax trees, translation rules, matching and normalization of asts.
     6 *)
     7 
     8 signature AST0 =
     9 sig
    10   datatype ast =
    11     Constant of string |
    12     Variable of string |
    13     Appl of ast list
    14   exception AST of string * ast list
    15 end;
    16 
    17 signature AST1 =
    18 sig
    19   include AST0
    20   val mk_appl: ast -> ast list -> ast
    21   val str_of_ast: ast -> string
    22   val pretty_ast: ast -> Pretty.T
    23   val pretty_rule: ast * ast -> Pretty.T
    24   val pprint_ast: ast -> pprint_args -> unit
    25   val fold_ast: string -> ast list -> ast
    26   val fold_ast_p: string -> ast list * ast -> ast
    27   val unfold_ast: string -> ast -> ast list
    28   val unfold_ast_p: string -> ast -> ast list * ast
    29   val trace_ast: bool ref
    30   val stat_ast: bool ref
    31 end;
    32 
    33 signature AST =
    34 sig
    35   include AST1
    36   val head_of_rule: ast * ast -> string
    37   val rule_error: ast * ast -> string option
    38   val normalize: bool -> bool -> (string -> (ast * ast) list) -> ast -> ast
    39   val normalize_ast: (string -> (ast * ast) list) -> ast -> ast
    40 end;
    41 
    42 structure Ast : AST =
    43 struct
    44 
    45 (** abstract syntax trees **)
    46 
    47 (*asts come in two flavours:
    48    - ordinary asts representing terms and typs: Variables are (often) treated
    49      like Constants;
    50    - patterns used as lhs and rhs in rules: Variables are placeholders for
    51      proper asts*)
    52 
    53 datatype ast =
    54   Constant of string |    (*"not", "_abs", "fun"*)
    55   Variable of string |    (*x, ?x, 'a, ?'a*)
    56   Appl of ast list;       (*(f x y z), ("fun" 'a 'b), ("_abs" x t)*)
    57 
    58 
    59 (*the list of subasts of an Appl node has to contain at least 2 elements, i.e.
    60   there are no empty asts or nullary applications; use mk_appl for convenience*)
    61 
    62 fun mk_appl f [] = f
    63   | mk_appl f args = Appl (f :: args);
    64 
    65 
    66 (*exception for system errors involving asts*)
    67 
    68 exception AST of string * ast list;
    69 
    70 
    71 
    72 (** print asts in a LISP-like style **)
    73 
    74 fun str_of_ast (Constant a) = quote a
    75   | str_of_ast (Variable x) = x
    76   | str_of_ast (Appl asts) = "(" ^ (space_implode " " (map str_of_ast asts)) ^ ")";
    77 
    78 fun pretty_ast (Constant a) = Pretty.quote (Pretty.str a)
    79   | pretty_ast (Variable x) = Pretty.str x
    80   | pretty_ast (Appl asts) =
    81       Pretty.enclose "(" ")" (Pretty.breaks (map pretty_ast asts));
    82 
    83 val pprint_ast = Pretty.pprint o pretty_ast;
    84 
    85 fun pretty_rule (lhs, rhs) =
    86   Pretty.block [pretty_ast lhs, Pretty.str "  ->", Pretty.brk 2, pretty_ast rhs];
    87 
    88 
    89 (* head_of_ast, head_of_rule *)
    90 
    91 fun head_of_ast (Constant a) = a
    92   | head_of_ast (Appl (Constant a :: _)) = a
    93   | head_of_ast _ = "";
    94 
    95 fun head_of_rule (lhs, _) = head_of_ast lhs;
    96 
    97 
    98 
    99 (** check translation rules **)
   100 
   101 fun rule_error (rule as (lhs, rhs)) =
   102   let
   103     fun add_vars (Constant _) = I
   104       | add_vars (Variable x) = cons x
   105       | add_vars (Appl asts) = fold add_vars asts;
   106 
   107     val lvars = add_vars lhs [];
   108     val rvars = add_vars rhs [];
   109   in
   110     if has_duplicates (op =) lvars then SOME "duplicate vars in lhs"
   111     else if not (rvars subset lvars) then SOME "rhs contains extra variables"
   112     else NONE
   113   end;
   114 
   115 
   116 
   117 (** ast translation utilities **)
   118 
   119 (* fold asts *)
   120 
   121 fun fold_ast _ [] = raise Match
   122   | fold_ast _ [y] = y
   123   | fold_ast c (x :: xs) = Appl [Constant c, x, fold_ast c xs];
   124 
   125 fun fold_ast_p c = uncurry (fold_rev (fn x => fn xs => Appl [Constant c, x, xs]));
   126 
   127 
   128 (* unfold asts *)
   129 
   130 fun unfold_ast c (y as Appl [Constant c', x, xs]) =
   131       if c = c' then x :: unfold_ast c xs else [y]
   132   | unfold_ast _ y = [y];
   133 
   134 fun unfold_ast_p c (y as Appl [Constant c', x, xs]) =
   135       if c = c' then apfst (cons x) (unfold_ast_p c xs)
   136       else ([], y)
   137   | unfold_ast_p _ y = ([], y);
   138 
   139 
   140 
   141 (** normalization of asts **)
   142 
   143 (* match *)
   144 
   145 fun match ast pat =
   146   let
   147     exception NO_MATCH;
   148 
   149     fun mtch (Constant a) (Constant b) env =
   150           if a = b then env else raise NO_MATCH
   151       | mtch (Variable a) (Constant b) env =
   152           if a = b then env else raise NO_MATCH
   153       | mtch ast (Variable x) env = Symtab.update (x, ast) env
   154       | mtch (Appl asts) (Appl pats) env = mtch_lst asts pats env
   155       | mtch _ _ _ = raise NO_MATCH
   156     and mtch_lst (ast :: asts) (pat :: pats) env =
   157           mtch_lst asts pats (mtch ast pat env)
   158       | mtch_lst [] [] env = env
   159       | mtch_lst _ _ _ = raise NO_MATCH;
   160 
   161     val (head, args) =
   162       (case (ast, pat) of
   163         (Appl asts, Appl pats) =>
   164           let val a = length asts and p = length pats in
   165             if a > p then (Appl (Library.take (p, asts)), Library.drop (p, asts))
   166             else (ast, [])
   167           end
   168       | _ => (ast, []));
   169   in
   170     SOME (mtch head pat Symtab.empty, args) handle NO_MATCH => NONE
   171   end;
   172 
   173 
   174 (* normalize *)
   175 
   176 (*the normalizer works yoyo-like: top-down, bottom-up, top-down, ...*)
   177 
   178 fun normalize trace stat get_rules pre_ast =
   179   let
   180     val passes = ref 0;
   181     val failed_matches = ref 0;
   182     val changes = ref 0;
   183 
   184     fun subst _ (ast as Constant _) = ast
   185       | subst env (Variable x) = the (Symtab.lookup env x)
   186       | subst env (Appl asts) = Appl (map (subst env) asts);
   187 
   188     fun try_rules ((lhs, rhs) :: pats) ast =
   189           (case match ast lhs of
   190             SOME (env, args) =>
   191               (inc changes; SOME (mk_appl (subst env rhs) args))
   192           | NONE => (inc failed_matches; try_rules pats ast))
   193       | try_rules [] _ = NONE;
   194     val try_headless_rules = try_rules (get_rules "");
   195 
   196     fun try ast a =
   197       (case try_rules (get_rules a) ast of
   198         NONE => try_headless_rules ast
   199       | some => some);
   200 
   201     fun rewrite (ast as Constant a) = try ast a
   202       | rewrite (ast as Variable a) = try ast a
   203       | rewrite (ast as Appl (Constant a :: _)) = try ast a
   204       | rewrite (ast as Appl (Variable a :: _)) = try ast a
   205       | rewrite ast = try_headless_rules ast;
   206 
   207     fun rewrote old_ast new_ast =
   208       if trace then
   209         tracing ("rewrote: " ^ str_of_ast old_ast ^ "  ->  " ^ str_of_ast new_ast)
   210       else ();
   211 
   212     fun norm_root ast =
   213       (case rewrite ast of
   214         SOME new_ast => (rewrote ast new_ast; norm_root new_ast)
   215       | NONE => ast);
   216 
   217     fun norm ast =
   218       (case norm_root ast of
   219         Appl sub_asts =>
   220           let
   221             val old_changes = ! changes;
   222             val new_ast = Appl (map norm sub_asts);
   223           in
   224             if old_changes = ! changes then new_ast else norm_root new_ast
   225           end
   226       | atomic_ast => atomic_ast);
   227 
   228     fun normal ast =
   229       let
   230         val old_changes = ! changes;
   231         val new_ast = norm ast;
   232       in
   233         inc passes;
   234         if old_changes = ! changes then new_ast else normal new_ast
   235       end;
   236 
   237 
   238     val _ = if trace then tracing ("pre: " ^ str_of_ast pre_ast) else ();
   239     val post_ast = normal pre_ast;
   240     val _ =
   241       if trace orelse stat then
   242         tracing ("post: " ^ str_of_ast post_ast ^ "\nnormalize: " ^
   243           string_of_int (! passes) ^ " passes, " ^
   244           string_of_int (! changes) ^ " changes, " ^
   245           string_of_int (! failed_matches) ^ " matches failed")
   246       else ();
   247   in post_ast end;
   248 
   249 
   250 (* normalize_ast *)
   251 
   252 val trace_ast = ref false;
   253 val stat_ast = ref false;
   254 
   255 fun normalize_ast get_rules ast =
   256   normalize (! trace_ast) (! stat_ast) get_rules ast;
   257 
   258 end;