--- a/src/Pure/Syntax/ast.ML Fri Oct 11 15:17:37 2024 +0200
+++ b/src/Pure/Syntax/ast.ML Sat Oct 12 14:16:15 2024 +0200
@@ -161,33 +161,30 @@
(* match *)
-fun match ast pat =
- let
- exception NO_MATCH;
+local exception NO_MATCH in
- fun mtch (Constant a) (Constant b) env =
- if a = b then env else raise NO_MATCH
- | mtch (Variable a) (Constant b) env =
- if a = b then env else raise NO_MATCH
- | mtch ast (Variable x) env = Symtab.update (x, ast) env
- | mtch (Appl asts) (Appl pats) env = mtch_lst asts pats env
- | mtch _ _ _ = raise NO_MATCH
- and mtch_lst (ast :: asts) (pat :: pats) env =
- mtch_lst asts pats (mtch ast pat env)
- | mtch_lst [] [] env = env
- | mtch_lst _ _ _ = raise NO_MATCH;
+fun match obj pat =
+ let
+ fun match1 (Constant a) (Constant b) env = if a = b then env else raise NO_MATCH
+ | match1 (Variable a) (Constant b) env = if a = b then env else raise NO_MATCH
+ | match1 ast (Variable x) env = Symtab.update (x, ast) env
+ | match1 (Appl asts) (Appl pats) env = match2 asts pats env
+ | match1 _ _ _ = raise NO_MATCH
+ and match2 (ast :: asts) (pat :: pats) env = match1 ast pat env |> match2 asts pats
+ | match2 [] [] env = env
+ | match2 _ _ _ = raise NO_MATCH;
val (head, args) =
- (case (ast, pat) of
+ (case (obj, pat) of
(Appl asts, Appl pats) =>
let val a = length asts and p = length pats in
if a > p then (Appl (take p asts), drop p asts)
- else (ast, [])
+ else (obj, [])
end
- | _ => (ast, []));
- in
- SOME (Symtab.build (mtch head pat), args) handle NO_MATCH => NONE
- end;
+ | _ => (obj, []));
+ in SOME (Symtab.build (match1 head pat), args) handle NO_MATCH => NONE end;
+
+end;
(* normalize *)
@@ -195,9 +192,23 @@
val trace = Config.declare_bool ("syntax_ast_trace", \<^here>) (K false);
val stats = Config.declare_bool ("syntax_ast_stats", \<^here>) (K false);
+local
+
+fun subst _ (ast as Constant _) = ast
+ | subst env (Variable x) = the (Symtab.lookup env x)
+ | subst env (Appl asts) = Appl (map (subst env) asts);
+
+fun head_name (Constant a) = SOME a
+ | head_name (Variable a) = SOME a
+ | head_name (Appl (Constant a :: _)) = SOME a
+ | head_name (Appl (Variable a :: _)) = SOME a
+ | head_name _ = NONE;
+
fun message head body =
Pretty.string_of (Pretty.block [Pretty.str head, Pretty.brk 1, body]);
+in
+
(*the normalizer works yoyo-like: top-down, bottom-up, top-down, ...*)
fun normalize ctxt get_rules pre_ast =
let
@@ -208,61 +219,53 @@
val failed_matches = Unsynchronized.ref 0;
val changes = Unsynchronized.ref 0;
- fun subst _ (ast as Constant _) = ast
- | subst env (Variable x) = the (Symtab.lookup env x)
- | subst env (Appl asts) = Appl (map (subst env) asts);
-
- fun try_rules ((lhs, rhs) :: pats) ast =
+ fun rewrite1 ((lhs, rhs) :: pats) ast =
(case match ast lhs of
SOME (env, args) =>
(Unsynchronized.inc changes; SOME (mk_appl (subst env rhs) args))
- | NONE => (Unsynchronized.inc failed_matches; try_rules pats ast))
- | try_rules [] _ = NONE;
- val try_headless_rules = try_rules (get_rules "");
+ | NONE => (Unsynchronized.inc failed_matches; rewrite1 pats ast))
+ | rewrite1 [] _ = NONE;
- fun try ast a =
- (case try_rules (get_rules a) ast of
- NONE => try_headless_rules ast
- | some => some);
+ fun rewrite2 (SOME a) ast =
+ (case rewrite1 (get_rules a) ast of
+ NONE => rewrite2 NONE ast
+ | some => some)
+ | rewrite2 NONE ast = rewrite1 (get_rules "") ast;
- fun rewrite (ast as Constant a) = try ast a
- | rewrite (ast as Variable a) = try ast a
- | rewrite (ast as Appl (Constant a :: _)) = try ast a
- | rewrite (ast as Appl (Variable a :: _)) = try ast a
- | rewrite ast = try_headless_rules ast;
+ fun rewrite ast = rewrite2 (head_name ast) ast;
fun rewrote old_ast new_ast =
if trace then tracing (message "rewrote:" (pretty_rule (old_ast, new_ast)))
else ();
- fun norm_root ast =
+ fun norm1 ast =
(case rewrite ast of
- SOME new_ast => (rewrote ast new_ast; norm_root new_ast)
+ SOME new_ast => (rewrote ast new_ast; norm1 new_ast)
| NONE => ast);
- fun norm ast =
- (case norm_root ast of
+ fun norm2 ast =
+ (case norm1 ast of
Appl sub_asts =>
let
val old_changes = ! changes;
- val new_ast = Appl (map norm sub_asts);
+ val new_ast = Appl (map norm2 sub_asts);
in
- if old_changes = ! changes then new_ast else norm_root new_ast
+ if old_changes = ! changes then new_ast else norm1 new_ast
end
| atomic_ast => atomic_ast);
- fun normal ast =
+ fun norm ast =
let
val old_changes = ! changes;
- val new_ast = norm ast;
+ val new_ast = norm2 ast;
in
Unsynchronized.inc passes;
- if old_changes = ! changes then new_ast else normal new_ast
+ if old_changes = ! changes then new_ast else norm new_ast
end;
val _ = if trace then tracing (message "pre:" (pretty_ast pre_ast)) else ();
- val post_ast = normal pre_ast;
+ val post_ast = norm pre_ast;
val _ =
if trace orelse stats then
tracing (message "post:" (pretty_ast post_ast) ^ "\nnormalize: " ^
@@ -273,3 +276,5 @@
in post_ast end;
end;
+
+end;