src/Pure/Syntax/ast.ML
author wenzelm
Sun, 16 Jul 2000 21:00:10 +0200
changeset 9372 7834e56e2277
parent 8997 da290d99d8b2
child 10913 57eb8c1d6f88
permissions -rw-r--r--
AST translation rules no longer require constant head on LHS;

(*  Title:      Pure/Syntax/ast.ML
    ID:         $Id$
    Author:     Markus Wenzel, TU Muenchen

Abstract syntax trees, translation rules, matching and normalization of asts.
*)

(*core of the ast interface: the datatype itself and the exception used to
  report system errors involving asts*)
signature AST0 =
sig
  datatype ast = Constant of string | Variable of string | Appl of ast list
  exception AST of string * ast list
end;

(*ast construction and printing, plus the global switches consulted by
  Ast.normalize_ast*)
signature AST1 =
sig
  include AST0
  (*construction*)
  val mk_appl: ast -> ast list -> ast
  (*printing*)
  val str_of_ast: ast -> string
  val pretty_ast: ast -> Pretty.T
  val pprint_ast: ast -> pprint_args -> unit
  val pretty_rule: ast * ast -> Pretty.T
  (*tracing / statistics switches for normalization*)
  val trace_ast: bool ref
  val stat_ast: bool ref
end;

(*full ast interface: rule checking, folding / unfolding of right-nested
  applications, and rule-based normalization*)
signature AST =
sig
  include AST1
  (*translation rules*)
  val rule_error: ast * ast -> string option
  val head_of_rule: ast * ast -> string
  (*right-nested applications of a binary constant*)
  val fold_ast: string -> ast list -> ast
  val unfold_ast: string -> ast -> ast list
  val fold_ast_p: string -> ast list * ast -> ast
  val unfold_ast_p: string -> ast -> ast list * ast
  (*normalization*)
  val normalize: bool -> bool -> (string -> (ast * ast) list) -> ast -> ast
  val normalize_ast: (string -> (ast * ast) list) -> ast -> ast
end;

structure Ast : AST =
struct

(** abstract syntax trees **)

(*asts come in two flavours:
   - ordinary asts representing terms and typs: Variables are (often) treated
     like Constants;
   - patterns used as lhs and rhs in rules: Variables are placeholders for
     proper asts*)

datatype ast =
  Constant of string |    (*"not", "_abs", "fun"*)
  Variable of string |    (*x, ?x, 'a, ?'a*)
  Appl of ast list;       (*(f x y z), ("fun" 'a 'b), ("_abs" x t)*)


(*the list of subasts of an Appl node has to contain at least 2 elements, i.e.
  there are no empty asts or nullary applications; use mk_appl for convenience:
  mk_appl f [] yields f itself instead of a degenerate Appl [f]*)

fun mk_appl f [] = f
  | mk_appl f args = Appl (f :: args);


(*exception for system errors involving asts*)

exception AST of string * ast list;



(** print asts in a LISP-like style **)

(* str_of_ast *)

(*Constants are printed via quote, Variables verbatim, Appl nodes as their
  subasts separated by single spaces inside parentheses*)

fun str_of_ast (Constant a) = quote a
  | str_of_ast (Variable x) = x
  | str_of_ast (Appl asts) = "(" ^ (space_implode " " (map str_of_ast asts)) ^ ")";


(* pretty_ast *)

(*same layout as str_of_ast, but with breakable spaces between subasts*)

fun pretty_ast (Constant a) = Pretty.str (quote a)
  | pretty_ast (Variable x) = Pretty.str x
  | pretty_ast (Appl asts) =
      Pretty.enclose "(" ")" (Pretty.breaks (map pretty_ast asts));


(* pprint_ast *)

val pprint_ast = Pretty.pprint o pretty_ast;


(* pretty_rule *)

(*rules are shown as "lhs  -> rhs"*)

fun pretty_rule (lhs, rhs) =
  Pretty.block [pretty_ast lhs, Pretty.str "  ->", Pretty.brk 2, pretty_ast rhs];


(* head_of_ast, head_of_rule *)

(*name of the Constant at the head of an ast, or "" if there is none; the
  head of a rule is that of its lhs, so rules without constant head get "" --
  normalize looks such headless rules up via get_rules ""*)

fun head_of_ast (Constant a) = a
  | head_of_ast (Appl (Constant a :: _)) = a
  | head_of_ast _ = "";

fun head_of_rule (lhs, _) = head_of_ast lhs;



(** check translation rules **)

(*a wellformed rule (lhs, rhs): (ast * ast) obeys the following conditions:
   - the lhs has unique vars,
   - vars of rhs is subset of vars of lhs
  rule_error returns Some message describing the first violated condition,
  or None for a wellformed rule*)

fun rule_error (rule as (lhs, rhs)) =
  let
    fun vars_of (Constant _) = []
      | vars_of (Variable x) = [x]
      | vars_of (Appl asts) = flat (map vars_of asts);

    fun unique (x :: xs) = not (x mem xs) andalso unique xs
      | unique [] = true;

    val lvars = vars_of lhs;
    val rvars = vars_of rhs;
  in
    if not (unique lvars) then Some "duplicate vars in lhs"
    else if not (rvars subset lvars) then Some "rhs contains extra variables"
    else None
  end;



(** ast translation utilities **)

(* fold asts *)

(*fold_ast c [x1, ..., xn] = (c x1 (c x2 ... xn)); a singleton list yields its
  element, the empty list raises Match*)

fun fold_ast _ [] = raise Match
  | fold_ast _ [y] = y
  | fold_ast c (x :: xs) = Appl [Constant c, x, fold_ast c xs];

(*fold_ast_p c ([x1, ..., xn], y) = (c x1 (c x2 ... (c xn y)))*)

fun fold_ast_p c = foldr (fn (x, xs) => Appl [Constant c, x, xs]);


(* unfold asts *)

(*inverse of fold_ast: split a right-nested chain of binary c applications
  into its elements; anything else is returned as a singleton list*)

fun unfold_ast c (y as Appl [Constant c', x, xs]) =
      if c = c' then x :: (unfold_ast c xs) else [y]
  | unfold_ast _ y = [y];

(*inverse of fold_ast_p: the chain's elements together with its final tail*)

fun unfold_ast_p c (y as Appl [Constant c', x, xs]) =
      if c = c' then apfst (cons x) (unfold_ast_p c xs)
      else ([], y)
  | unfold_ast_p _ y = ([], y);


(** normalization of asts **)

(* tracing options *)

val trace_ast = ref false;
val stat_ast = ref false;


(* match *)

(*match ast pat = Some (env, args) if pat matches ast, where env maps the
  pattern's Variables to subasts and args are the surplus arguments of an
  Appl ast that is longer than the Appl pattern (matching is against its
  prefix); None otherwise.  A Variable in ast matches a Constant of the same
  name in pat*)

fun match ast pat =
  let
    exception NO_MATCH;

    fun mtch (Constant a) (Constant b) env =
          if a = b then env else raise NO_MATCH
      | mtch (Variable a) (Constant b) env =
          if a = b then env else raise NO_MATCH
      | mtch ast (Variable x) env = Symtab.update ((x, ast), env)
      | mtch (Appl asts) (Appl pats) env = mtch_lst asts pats env
      | mtch _ _ _ = raise NO_MATCH
    and mtch_lst (ast :: asts) (pat :: pats) env =
          mtch_lst asts pats (mtch ast pat env)
      | mtch_lst [] [] env = env
      | mtch_lst _ _ _ = raise NO_MATCH;

    (*split off the arguments of ast beyond the length of an Appl pattern*)
    val (head, args) =
      (case (ast, pat) of
        (Appl asts, Appl pats) =>
          let val a = length asts and p = length pats in
            if a > p then (Appl (take (p, asts)), drop (p, asts))
            else (ast, [])
          end
      | _ => (ast, []));
  in
    Some (mtch head pat Symtab.empty, args) handle NO_MATCH => None
  end;


(* normalize *)

(*the normalizer works yoyo-like: top-down, bottom-up, top-down, ...

  get_rules a yields the rules whose lhs has head constant a; rules without
  constant head are filed under "" and are tried after the head-specific ones,
  so they can rewrite asts of any shape.  With trace, every rewrite step is
  printed; with trace or stat, the result and counters are printed at the end*)

fun normalize trace stat get_rules pre_ast =
  let
    val passes = ref 0;
    val lookups = ref 0;
    val failed_matches = ref 0;
    val changes = ref 0;

    (*instantiate the Variables of a rule rhs from env*)
    fun subst _ (ast as Constant _) = ast
      | subst env (Variable x) = the (Symtab.lookup (env, x))
      | subst env (Appl asts) = Appl (map (subst env) asts);

    (*apply the first rule whose lhs matches ast; surplus arguments of ast
      (see match) are re-attached to the instantiated rhs*)
    fun try_rules ast ((lhs, rhs) :: pats) =
          (case match ast lhs of
            Some (env, args) =>
              (inc changes; Some (mk_appl (subst env rhs) args))
          | None => (inc failed_matches; try_rules ast pats))
      | try_rules _ [] = None;

    fun lookup_rules a = (inc lookups; get_rules a);

    fun try_headless ast = try_rules ast (lookup_rules "");

    (*rules filed under the ast's head first, then the headless ones (unless
      the head already is "", in which case they have just been tried)*)
    fun try ast a =
      (case try_rules ast (lookup_rules a) of
        None => if a = "" then None else try_headless ast
      | result => result);

    fun rewrite (ast as Constant a) = try ast a
      | rewrite (ast as Variable a) = try ast a
      | rewrite (ast as Appl (Constant a :: _)) = try ast a
      | rewrite (ast as Appl (Variable a :: _)) = try ast a
      | rewrite ast = try_headless ast;

    fun rewrote old_ast new_ast =
      if trace then
        writeln ("rewrote: " ^ str_of_ast old_ast ^ "  ->  " ^ str_of_ast new_ast)
      else ();

    (*rewrite at the root until no rule applies*)
    fun norm_root ast =
      (case rewrite ast of
        Some new_ast => (rewrote ast new_ast; norm_root new_ast)
      | None => ast);

    (*normalize the root, then the subasts; if any subast changed, the root
      may admit new rewrites, so it is normalized once more*)
    fun norm ast =
      (case norm_root ast of
        Appl sub_asts =>
          let
            val old_changes = ! changes;
            val new_ast = Appl (map norm sub_asts);
          in
            if old_changes = ! changes then new_ast else norm_root new_ast
          end
      | atomic_ast => atomic_ast);

    (*repeat whole passes until one of them changes nothing*)
    fun normal ast =
      let
        val old_changes = ! changes;
        val new_ast = norm ast;
      in
        inc passes;
        if old_changes = ! changes then new_ast else normal new_ast
      end;


    val _ = if trace then writeln ("pre: " ^ str_of_ast pre_ast) else ();

    val post_ast = normal pre_ast;
  in
    if trace orelse stat then
      writeln ("post: " ^ str_of_ast post_ast ^ "\nnormalize: " ^
        string_of_int (! passes) ^ " passes, " ^
        string_of_int (! lookups) ^ " lookups, " ^
        string_of_int (! changes) ^ " changes, " ^
        string_of_int (! failed_matches) ^ " matches failed")
    else ();
    post_ast
  end;


(* normalize_ast *)

(*normalize with tracing / statistics controlled by trace_ast and stat_ast*)

fun normalize_ast get_rules ast =
  normalize (! trace_ast) (! stat_ast) get_rules ast;

end;