comparison graal/com.oracle.truffle.api.dsl.test/src/com/oracle/truffle/api/dsl/test/MergeSpecializationsTest.java @ 19709:678a3de139ad

Add test for concurrent DSL node specializations.
author Benoit Daloze <benoit.daloze@jku.at>
date Thu, 05 Mar 2015 14:19:23 +0100
parents
children e8d2f3f95dcd
comparison
equal deleted inserted replaced
19708:6755624bf03d 19709:678a3de139ad
1 /*
2 * Copyright (c) 2015, Oracle and/or its affiliates. All rights reserved.
3 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
4 *
5 * This code is free software; you can redistribute it and/or modify it
6 * under the terms of the GNU General Public License version 2 only, as
7 * published by the Free Software Foundation.
8 *
9 * This code is distributed in the hope that it will be useful, but WITHOUT
10 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
11 * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
12 * version 2 for more details (a copy is included in the LICENSE file that
13 * accompanied this code).
14 *
15 * You should have received a copy of the GNU General Public License version
16 * 2 along with this work; if not, write to the Free Software Foundation,
17 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
18 *
19 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
20 * or visit www.oracle.com if you need additional information or have any
21 * questions.
22 */
23 package com.oracle.truffle.api.dsl.test;
24
25 import static com.oracle.truffle.api.dsl.test.TestHelper.*;
26 import static org.junit.Assert.*;
27
28 import java.util.*;
29 import java.util.concurrent.*;
30
31 import org.junit.*;
32
33 import com.oracle.truffle.api.dsl.*;
34 import com.oracle.truffle.api.dsl.internal.*;
35 import com.oracle.truffle.api.dsl.test.MergeSpecializationsTestFactory.TestCachedNodeFactory;
36 import com.oracle.truffle.api.dsl.test.MergeSpecializationsTestFactory.TestNodeFactory;
37 import com.oracle.truffle.api.dsl.test.TypeSystemTest.TestRootNode;
38 import com.oracle.truffle.api.dsl.test.TypeSystemTest.ValueNode;
39 import com.oracle.truffle.api.nodes.*;
40
41 public class MergeSpecializationsTest {
42
43 private static final int THREADS = 8;
44
45 @NodeChild
46 @SuppressWarnings("unused")
47 abstract static class TestNode extends ValueNode {
48
49 @Specialization
50 int s1(int a) {
51 return 1;
52 }
53
54 @Specialization
55 int s2(long a) {
56 return 2;
57 }
58
59 @Specialization
60 int s3(double a) {
61 return 3;
62 }
63 }
64
65 @NodeChild
66 @SuppressWarnings("unused")
67 abstract static class TestCachedNode extends ValueNode {
68
69 @Specialization(guards = "a == cachedA", limit = "3")
70 int s1(int a, @Cached("a") int cachedA) {
71 return 1;
72 }
73
74 @Specialization
75 int s2(long a) {
76 return 2;
77 }
78
79 @Specialization
80 int s3(double a) {
81 return 3;
82 }
83 }
84
85 @Test
86 public void testMultithreadedMergeInOrder() {
87 multithreadedMerge(TestNodeFactory.getInstance(), new Executions(1, 1L << 32, 1.0), 1, 2, 3);
88 }
89
90 @Test
91 public void testMultithreadedMergeReverse() {
92 multithreadedMerge(TestNodeFactory.getInstance(), new Executions(1.0, 1L << 32, 1), 3, 2, 1);
93 }
94
95 @Ignore
96 @Test
97 public void testMultithreadedMergeCachedInOrder() {
98 multithreadedMerge(TestCachedNodeFactory.getInstance(), new Executions(1, 1L << 32, 1.0), 1, 2, 3);
99 }
100
101 @Ignore
102 @Test
103 public void testMultithreadedMergeCachedTwoEntries() {
104 multithreadedMerge(TestCachedNodeFactory.getInstance(), new Executions(1, 2, 1.0), 1, 1, 3);
105 }
106
107 @Ignore
108 @Test
109 public void testMultithreadedMergeCachedThreeEntries() {
110 multithreadedMerge(TestCachedNodeFactory.getInstance(), new Executions(1, 2, 3), 1, 1, 1);
111 }
112
113 private static <T extends ValueNode> void multithreadedMerge(NodeFactory<T> factory, final Executions executions, int... order) {
114 assertEquals(3, order.length);
115 final TestRootNode<T> node = createRoot(factory);
116
117 final CountDownLatch threadsStarted = new CountDownLatch(THREADS);
118
119 final CountDownLatch beforeFirst = new CountDownLatch(1);
120 final CountDownLatch executedFirst = new CountDownLatch(THREADS);
121
122 final CountDownLatch beforeSecond = new CountDownLatch(1);
123 final CountDownLatch executedSecond = new CountDownLatch(THREADS);
124
125 final CountDownLatch beforeThird = new CountDownLatch(1);
126 final CountDownLatch executedThird = new CountDownLatch(THREADS);
127
128 Thread[] threads = new Thread[THREADS];
129 for (int i = 0; i < threads.length; i++) {
130 threads[i] = new Thread(new Runnable() {
131 public void run() {
132 threadsStarted.countDown();
133
134 MergeSpecializationsTest.await(beforeFirst);
135 executeWith(node, executions.firstValue);
136 executedFirst.countDown();
137
138 MergeSpecializationsTest.await(beforeSecond);
139 executeWith(node, executions.secondValue);
140 executedSecond.countDown();
141
142 MergeSpecializationsTest.await(beforeThird);
143 executeWith(node, executions.thirdValue);
144 executedThird.countDown();
145 }
146 });
147 threads[i].start();
148 }
149
150 final SpecializedNode gen = (SpecializedNode) node.getNode();
151
152 final SpecializationNode start0 = gen.getSpecializationNode();
153 assertEquals("UninitializedNode_", start0.getClass().getSimpleName());
154
155 await(threadsStarted);
156 beforeFirst.countDown();
157 await(executedFirst);
158
159 final SpecializationNode start1 = gen.getSpecializationNode();
160 assertEquals("S" + order[0] + "Node_", start1.getClass().getSimpleName());
161 assertEquals("UninitializedNode_", nthChild(1, start1).getClass().getSimpleName());
162
163 beforeSecond.countDown();
164 await(executedSecond);
165
166 final SpecializationNode start2 = gen.getSpecializationNode();
167 Arrays.sort(order, 0, 2);
168 assertEquals("PolymorphicNode_", start2.getClass().getSimpleName());
169 assertEquals("S" + order[0] + "Node_", nthChild(1, start2).getClass().getSimpleName());
170 assertEquals("S" + order[1] + "Node_", nthChild(2, start2).getClass().getSimpleName());
171 assertEquals("UninitializedNode_", nthChild(3, start2).getClass().getSimpleName());
172
173 beforeThird.countDown();
174 await(executedThird);
175
176 final SpecializationNode start3 = gen.getSpecializationNode();
177 Arrays.sort(order);
178 assertEquals("PolymorphicNode_", start3.getClass().getSimpleName());
179 assertEquals("S" + order[0] + "Node_", nthChild(1, start3).getClass().getSimpleName());
180 assertEquals("S" + order[1] + "Node_", nthChild(2, start3).getClass().getSimpleName());
181 assertEquals("S" + order[2] + "Node_", nthChild(3, start3).getClass().getSimpleName());
182 assertEquals("UninitializedNode_", nthChild(4, start3).getClass().getSimpleName());
183
184 for (Thread thread : threads) {
185 try {
186 thread.join();
187 } catch (InterruptedException e) {
188 fail("interrupted");
189 }
190 }
191 }
192
193 private static class Executions {
194 public final Object firstValue;
195 public final Object secondValue;
196 public final Object thirdValue;
197
198 public Executions(Object firstValue, Object secondValue, Object thirdValue) {
199 this.firstValue = firstValue;
200 this.secondValue = secondValue;
201 this.thirdValue = thirdValue;
202 }
203 }
204
205 private static void await(final CountDownLatch latch) {
206 try {
207 latch.await();
208 } catch (InterruptedException e) {
209 fail("interrupted");
210 }
211 }
212
213 private static Node firstChild(Node node) {
214 return node.getChildren().iterator().next();
215 }
216
217 private static Node nthChild(int n, Node node) {
218 if (n == 0) {
219 return node;
220 } else {
221 return nthChild(n - 1, firstChild(node));
222 }
223 }
224 }