diff graal/com.oracle.truffle.api/src/com/oracle/truffle/api/nodes/NodeUtil.java @ 20131:4b12d5355811

Truffle: do not use iterators for visitors.
author Christian Humer <christian.humer@gmail.com>
date Thu, 02 Apr 2015 01:27:27 +0200
parents 8dc73c226c63
children 286aef83a9a7
line wrap: on
line diff
--- a/graal/com.oracle.truffle.api/src/com/oracle/truffle/api/nodes/NodeUtil.java	Thu Apr 02 01:26:31 2015 +0200
+++ b/graal/com.oracle.truffle.api/src/com/oracle/truffle/api/nodes/NodeUtil.java	Thu Apr 02 01:27:27 2015 +0200
@@ -331,6 +331,44 @@
         return true;
     }
 
+    static boolean forEachChildRecursive(Node parent, NodeVisitor visitor) {
+        NodeClass parentNodeClass = parent.getNodeClass();
+
+        for (NodeFieldAccessor field : parentNodeClass.getChildFields()) {
+            if (!visitChild((Node) field.getObject(parent), visitor)) {
+                return false;
+            }
+        }
+
+        for (NodeFieldAccessor field : parentNodeClass.getChildrenFields()) {
+            Object arrayObject = field.getObject(parent);
+            if (arrayObject == null) {
+                continue;
+            }
+            Object[] array = (Object[]) arrayObject;
+            for (int i = 0; i < array.length; i++) {
+                if (!visitChild((Node) array[i], visitor)) {
+                    return false;
+                }
+            }
+        }
+
+        return true;
+    }
+
+    private static boolean visitChild(Node child, NodeVisitor visitor) {
+        if (child == null) {
+            return true;
+        }
+        if (!visitor.visit(child)) {
+            return false;
+        }
+        if (!forEachChildRecursive(child, visitor)) {
+            return false;
+        }
+        return true;
+    }
+
     /** Returns all declared fields in the class hierarchy. */
     static Field[] getAllFields(Class<? extends Object> clazz) {
         Field[] declaredFields = clazz.getDeclaredFields();
@@ -460,29 +498,24 @@
     }
 
     public static int countNodes(Node root) {
-        Iterator<Node> nodeIterator = makeRecursiveIterator(root);
-        int count = 0;
-        while (nodeIterator.hasNext()) {
-            nodeIterator.next();
-            count++;
-        }
-        return count;
+        return countNodes(root, NodeCountFilter.NO_FILTER);
     }
 
     public static int countNodes(Node root, NodeCountFilter filter) {
-        Iterator<Node> nodeIterator = makeRecursiveIterator(root);
-        int count = 0;
-        while (nodeIterator.hasNext()) {
-            Node node = nodeIterator.next();
-            if (node != null && filter.isCounted(node)) {
-                count++;
-            }
-        }
-        return count;
+        NodeCounter counter = new NodeCounter(filter);
+        root.accept(counter);
+        return counter.count;
     }
 
     public interface NodeCountFilter {
 
+        NodeCountFilter NO_FILTER = new NodeCountFilter() {
+
+            public boolean isCounted(Node node) {
+                return true;
+            }
+        };
+
         boolean isCounted(Node node);
 
     }
@@ -801,4 +834,22 @@
         }
         return true;
     }
+
+    private static final class NodeCounter implements NodeVisitor {
+
+        public int count;
+        private final NodeCountFilter filter;
+
+        public NodeCounter(NodeCountFilter filter) {
+            this.filter = filter;
+        }
+
+        public boolean visit(Node node) {
+            if (filter.isCounted(node)) {
+                count++;
+            }
+            return true;
+        }
+
+    }
 }