001/*
002 * Copyright (c) 2013, 2014, Oracle and/or its affiliates. All rights reserved.
003 * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
004 *
005 * This code is free software; you can redistribute it and/or modify it
006 * under the terms of the GNU General Public License version 2 only, as
007 * published by the Free Software Foundation.
008 *
009 * This code is distributed in the hope that it will be useful, but WITHOUT
010 * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
011 * FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
012 * version 2 for more details (a copy is included in the LICENSE file that
013 * accompanied this code).
014 *
015 * You should have received a copy of the GNU General Public License version
016 * 2 along with this work; if not, write to the Free Software Foundation,
017 * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
018 *
019 * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
020 * or visit www.oracle.com if you need additional information or have any
021 * questions.
022 */
023package com.oracle.graal.hotspot.test;
024
025import static com.oracle.graal.graphbuilderconf.IntrinsicContext.CompilationContext.*;
026
027import java.io.*;
028import java.lang.reflect.*;
029import java.security.*;
030
031import javax.crypto.*;
032
033import jdk.internal.jvmci.code.*;
034import jdk.internal.jvmci.hotspot.*;
035import jdk.internal.jvmci.meta.*;
036
037import org.junit.*;
038
039import com.oracle.graal.graphbuilderconf.*;
040import com.oracle.graal.graphbuilderconf.GraphBuilderConfiguration.Plugins;
041import com.oracle.graal.hotspot.meta.*;
042import com.oracle.graal.java.*;
043import com.oracle.graal.nodes.*;
044import com.oracle.graal.nodes.StructuredGraph.AllowAssumptions;
045import com.oracle.graal.phases.*;
046
047/**
048 * Tests the intrinsification of certain crypto methods.
049 */
050public class HotSpotCryptoSubstitutionTest extends HotSpotGraalCompilerTest {
051
052    @Override
053    protected InstalledCode addMethod(ResolvedJavaMethod method, CompilationResult compResult) {
054        HotSpotResolvedJavaMethod hsMethod = (HotSpotResolvedJavaMethod) method;
055        HotSpotNmethod installedCode = new HotSpotNmethod(hsMethod, compResult.getName(), true);
056        HotSpotCompiledNmethod compiledNmethod = new HotSpotCompiledNmethod(hsMethod, compResult);
057        int result = runtime().getCompilerToVM().installCode(compiledNmethod, installedCode, null);
058        HotSpotVMConfig config = runtime().getConfig();
059        Assert.assertEquals("Error installing method " + method + ": " + config.getCodeInstallResultDescription(result), result, config.codeInstallResultOk);
060
061        // HotSpotRuntime hsRuntime = (HotSpotRuntime) getCodeCache();
062        // TTY.println(hsMethod.toString());
063        // TTY.println(hsRuntime.disassemble(installedCode));
064        return installedCode;
065    }
066
067    SecretKey aesKey;
068    SecretKey desKey;
069    byte[] input;
070    ByteArrayOutputStream aesExpected = new ByteArrayOutputStream();
071    ByteArrayOutputStream desExpected = new ByteArrayOutputStream();
072
073    public HotSpotCryptoSubstitutionTest() throws Exception {
074        byte[] seed = {0x4, 0x7, 0x1, 0x1};
075        SecureRandom random = new SecureRandom(seed);
076        KeyGenerator aesKeyGen = KeyGenerator.getInstance("AES");
077        KeyGenerator desKeyGen = KeyGenerator.getInstance("DESede");
078        aesKeyGen.init(128, random);
079        desKeyGen.init(168, random);
080        aesKey = aesKeyGen.generateKey();
081        desKey = desKeyGen.generateKey();
082        input = readClassfile16(getClass());
083
084        aesExpected.write(runEncryptDecrypt(aesKey, "AES/CBC/NoPadding"));
085        aesExpected.write(runEncryptDecrypt(aesKey, "AES/CBC/PKCS5Padding"));
086
087        desExpected.write(runEncryptDecrypt(desKey, "DESede/CBC/NoPadding"));
088        desExpected.write(runEncryptDecrypt(desKey, "DESede/CBC/PKCS5Padding"));
089    }
090
091    @Test
092    public void testAESCryptIntrinsics() throws Exception {
093        if (compileAndInstall("com.sun.crypto.provider.AESCrypt", "encryptBlock", "decryptBlock")) {
094            ByteArrayOutputStream actual = new ByteArrayOutputStream();
095            actual.write(runEncryptDecrypt(aesKey, "AES/CBC/NoPadding"));
096            actual.write(runEncryptDecrypt(aesKey, "AES/CBC/PKCS5Padding"));
097            Assert.assertArrayEquals(aesExpected.toByteArray(), actual.toByteArray());
098        }
099    }
100
101    @Test
102    public void testCipherBlockChainingIntrinsics() throws Exception {
103        if (compileAndInstall("com.sun.crypto.provider.CipherBlockChaining", "encrypt", "decrypt")) {
104            ByteArrayOutputStream actual = new ByteArrayOutputStream();
105            actual.write(runEncryptDecrypt(aesKey, "AES/CBC/NoPadding"));
106            actual.write(runEncryptDecrypt(aesKey, "AES/CBC/PKCS5Padding"));
107            Assert.assertArrayEquals(aesExpected.toByteArray(), actual.toByteArray());
108
109            actual.reset();
110            actual.write(runEncryptDecrypt(desKey, "DESede/CBC/NoPadding"));
111            actual.write(runEncryptDecrypt(desKey, "DESede/CBC/PKCS5Padding"));
112            Assert.assertArrayEquals(desExpected.toByteArray(), actual.toByteArray());
113        }
114    }
115
116    /**
117     * Compiles and installs the substitution for some specified methods. Once installed, the next
118     * execution of the methods will use the newly installed code.
119     *
120     * @param className the name of the class for which substitutions are available
121     * @param methodNames the names of the substituted methods
122     * @return true if at least one substitution was compiled and installed
123     */
124    private boolean compileAndInstall(String className, String... methodNames) {
125        boolean atLeastOneCompiled = false;
126        for (String methodName : methodNames) {
127            Method method = lookup(className, methodName);
128            if (method != null) {
129                ResolvedJavaMethod installedCodeOwner = getMetaAccess().lookupJavaMethod(method);
130                StructuredGraph subst = getReplacements().getSubstitution(installedCodeOwner, 0);
131                ResolvedJavaMethod substMethod = subst == null ? null : subst.method();
132                if (substMethod != null) {
133                    StructuredGraph graph = new StructuredGraph(substMethod, AllowAssumptions.YES);
134                    Plugins plugins = new Plugins(((HotSpotProviders) getProviders()).getGraphBuilderPlugins());
135                    GraphBuilderConfiguration config = GraphBuilderConfiguration.getSnippetDefault(plugins);
136                    IntrinsicContext initialReplacementContext = new IntrinsicContext(installedCodeOwner, substMethod, ROOT_COMPILATION);
137                    new GraphBuilderPhase.Instance(getMetaAccess(), getProviders().getStampProvider(), getConstantReflection(), config, OptimisticOptimizations.NONE, initialReplacementContext).apply(graph);
138                    Assert.assertNotNull(getCode(installedCodeOwner, graph, true));
139                    atLeastOneCompiled = true;
140                } else {
141                    Assert.assertFalse(runtime().getConfig().useAESIntrinsics);
142                }
143            }
144        }
145        return atLeastOneCompiled;
146    }
147
148    private static Method lookup(String className, String methodName) {
149        Class<?> c;
150        try {
151            c = Class.forName(className);
152            for (Method m : c.getDeclaredMethods()) {
153                if (m.getName().equals(methodName)) {
154                    return m;
155                }
156            }
157            // If the expected security provider exists, the specific method should also exist
158            throw new NoSuchMethodError(className + "." + methodName);
159        } catch (ClassNotFoundException e) {
160            // It's ok to not find the class - a different security provider
161            // may have been installed
162            return null;
163        }
164    }
165
166    AlgorithmParameters algorithmParameters;
167
168    private byte[] encrypt(byte[] indata, SecretKey key, String algorithm) throws Exception {
169
170        byte[] result = indata;
171
172        Cipher c = Cipher.getInstance(algorithm);
173        c.init(Cipher.ENCRYPT_MODE, key);
174        algorithmParameters = c.getParameters();
175
176        byte[] r1 = c.update(result);
177        byte[] r2 = c.doFinal();
178
179        result = new byte[r1.length + r2.length];
180        System.arraycopy(r1, 0, result, 0, r1.length);
181        System.arraycopy(r2, 0, result, r1.length, r2.length);
182
183        return result;
184    }
185
186    private byte[] decrypt(byte[] indata, SecretKey key, String algorithm) throws Exception {
187
188        byte[] result = indata;
189
190        Cipher c = Cipher.getInstance(algorithm);
191        c.init(Cipher.DECRYPT_MODE, key, algorithmParameters);
192
193        byte[] r1 = c.update(result);
194        byte[] r2 = c.doFinal();
195
196        result = new byte[r1.length + r2.length];
197        System.arraycopy(r1, 0, result, 0, r1.length);
198        System.arraycopy(r2, 0, result, r1.length, r2.length);
199        return result;
200    }
201
202    private static byte[] readClassfile16(Class<? extends HotSpotCryptoSubstitutionTest> c) throws IOException {
203        String classFilePath = "/" + c.getName().replace('.', '/') + ".class";
204        InputStream stream = c.getResourceAsStream(classFilePath);
205        int bytesToRead = stream.available();
206        bytesToRead -= bytesToRead % 16;
207        byte[] classFile = new byte[bytesToRead];
208        new DataInputStream(stream).readFully(classFile);
209        return classFile;
210    }
211
212    public byte[] runEncryptDecrypt(SecretKey key, String algorithm) throws Exception {
213        byte[] indata = input.clone();
214        byte[] cipher = encrypt(indata, key, algorithm);
215        byte[] plain = decrypt(cipher, key, algorithm);
216        Assert.assertArrayEquals(indata, plain);
217        return plain;
218    }
219}