# HG changeset patch # User Benoit Daloze # Date 1425561563 -3600 # Node ID 678a3de139adbf091fd52e394c2a71557f60090b # Parent 6755624bf03d269ef1e1755d6dc4d4d4a01e8e3b Add test for concurrent DSL node specializations. diff -r 6755624bf03d -r 678a3de139ad graal/com.oracle.truffle.api.dsl.test/src/com/oracle/truffle/api/dsl/test/MergeSpecializationsTest.java --- /dev/null Thu Jan 01 00:00:00 1970 +0000 +++ b/graal/com.oracle.truffle.api.dsl.test/src/com/oracle/truffle/api/dsl/test/MergeSpecializationsTest.java Thu Mar 05 14:19:23 2015 +0100 @@ -0,0 +1,224 @@ +/* + * Copyright (c) 2015, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ +package com.oracle.truffle.api.dsl.test; + +import static com.oracle.truffle.api.dsl.test.TestHelper.*; +import static org.junit.Assert.*; + +import java.util.*; +import java.util.concurrent.*; + +import org.junit.*; + +import com.oracle.truffle.api.dsl.*; +import com.oracle.truffle.api.dsl.internal.*; +import com.oracle.truffle.api.dsl.test.MergeSpecializationsTestFactory.TestCachedNodeFactory; +import com.oracle.truffle.api.dsl.test.MergeSpecializationsTestFactory.TestNodeFactory; +import com.oracle.truffle.api.dsl.test.TypeSystemTest.TestRootNode; +import com.oracle.truffle.api.dsl.test.TypeSystemTest.ValueNode; +import com.oracle.truffle.api.nodes.*; + +public class MergeSpecializationsTest { + + private static final int THREADS = 8; + + @NodeChild + @SuppressWarnings("unused") + abstract static class TestNode extends ValueNode { + + @Specialization + int s1(int a) { + return 1; + } + + @Specialization + int s2(long a) { + return 2; + } + + @Specialization + int s3(double a) { + return 3; + } + } + + @NodeChild + @SuppressWarnings("unused") + abstract static class TestCachedNode extends ValueNode { + + @Specialization(guards = "a == cachedA", limit = "3") + int s1(int a, @Cached("a") int cachedA) { + return 1; + } + + @Specialization + int s2(long a) { + return 2; + } + + @Specialization + int s3(double a) { + return 3; + } + } + + @Test + public void testMultithreadedMergeInOrder() { + multithreadedMerge(TestNodeFactory.getInstance(), new Executions(1, 1L << 32, 1.0), 1, 2, 3); + } + + @Test + public void testMultithreadedMergeReverse() { + multithreadedMerge(TestNodeFactory.getInstance(), new Executions(1.0, 1L << 32, 1), 3, 2, 1); + } + + @Ignore + @Test + public void testMultithreadedMergeCachedInOrder() { + multithreadedMerge(TestCachedNodeFactory.getInstance(), new Executions(1, 1L << 32, 1.0), 1, 2, 3); + } + + @Ignore + @Test + public void testMultithreadedMergeCachedTwoEntries() { + multithreadedMerge(TestCachedNodeFactory.getInstance(), new Executions(1, 2, 1.0), 1, 1, 3); + } + + @Ignore + @Test + public void testMultithreadedMergeCachedThreeEntries() { + multithreadedMerge(TestCachedNodeFactory.getInstance(), new Executions(1, 2, 3), 1, 1, 1); + } + + private static void multithreadedMerge(NodeFactory factory, final Executions executions, int... order) { + assertEquals(3, order.length); + final TestRootNode node = createRoot(factory); + + final CountDownLatch threadsStarted = new CountDownLatch(THREADS); + + final CountDownLatch beforeFirst = new CountDownLatch(1); + final CountDownLatch executedFirst = new CountDownLatch(THREADS); + + final CountDownLatch beforeSecond = new CountDownLatch(1); + final CountDownLatch executedSecond = new CountDownLatch(THREADS); + + final CountDownLatch beforeThird = new CountDownLatch(1); + final CountDownLatch executedThird = new CountDownLatch(THREADS); + + Thread[] threads = new Thread[THREADS]; + for (int i = 0; i < threads.length; i++) { + threads[i] = new Thread(new Runnable() { + public void run() { + threadsStarted.countDown(); + + MergeSpecializationsTest.await(beforeFirst); + executeWith(node, executions.firstValue); + executedFirst.countDown(); + + MergeSpecializationsTest.await(beforeSecond); + executeWith(node, executions.secondValue); + executedSecond.countDown(); + + MergeSpecializationsTest.await(beforeThird); + executeWith(node, executions.thirdValue); + executedThird.countDown(); + } + }); + threads[i].start(); + } + + final SpecializedNode gen = (SpecializedNode) node.getNode(); + + final SpecializationNode start0 = gen.getSpecializationNode(); + assertEquals("UninitializedNode_", start0.getClass().getSimpleName()); + + await(threadsStarted); + beforeFirst.countDown(); + await(executedFirst); + + final SpecializationNode start1 = gen.getSpecializationNode(); + assertEquals("S" + order[0] + "Node_", start1.getClass().getSimpleName()); + assertEquals("UninitializedNode_", nthChild(1, start1).getClass().getSimpleName()); + + beforeSecond.countDown(); + await(executedSecond); + + final SpecializationNode start2 = gen.getSpecializationNode(); + Arrays.sort(order, 0, 2); + assertEquals("PolymorphicNode_", start2.getClass().getSimpleName()); + assertEquals("S" + order[0] + "Node_", nthChild(1, start2).getClass().getSimpleName()); + assertEquals("S" + order[1] + "Node_", nthChild(2, start2).getClass().getSimpleName()); + assertEquals("UninitializedNode_", nthChild(3, start2).getClass().getSimpleName()); + + beforeThird.countDown(); + await(executedThird); + + final SpecializationNode start3 = gen.getSpecializationNode(); + Arrays.sort(order); + assertEquals("PolymorphicNode_", start3.getClass().getSimpleName()); + assertEquals("S" + order[0] + "Node_", nthChild(1, start3).getClass().getSimpleName()); + assertEquals("S" + order[1] + "Node_", nthChild(2, start3).getClass().getSimpleName()); + assertEquals("S" + order[2] + "Node_", nthChild(3, start3).getClass().getSimpleName()); + assertEquals("UninitializedNode_", nthChild(4, start3).getClass().getSimpleName()); + + for (Thread thread : threads) { + try { + thread.join(); + } catch (InterruptedException e) { + fail("interrupted"); + } + } + } + + private static class Executions { + public final Object firstValue; + public final Object secondValue; + public final Object thirdValue; + + public Executions(Object firstValue, Object secondValue, Object thirdValue) { + this.firstValue = firstValue; + this.secondValue = secondValue; + this.thirdValue = thirdValue; + } + } + + private static void await(final CountDownLatch latch) { + try { + latch.await(); + } catch (InterruptedException e) { + fail("interrupted"); + } + } + + private static Node firstChild(Node node) { + return node.getChildren().iterator().next(); + } + + private static Node nthChild(int n, Node node) { + if (n == 0) { + return node; + } else { + return nthChild(n - 1, firstChild(node)); + } + } +}