001/** 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017 018package org.apache.activemq.transport.nio; 019 020import java.io.DataOutputStream; 021import java.io.EOFException; 022import java.io.IOException; 023import java.net.Socket; 024import java.net.URI; 025import java.net.UnknownHostException; 026import java.nio.ByteBuffer; 027import java.util.concurrent.atomic.AtomicInteger; 028 029import javax.net.SocketFactory; 030import javax.net.ssl.SSLContext; 031import javax.net.ssl.SSLEngine; 032import javax.net.ssl.SSLEngineResult; 033 034import org.apache.activemq.thread.TaskRunnerFactory; 035import org.apache.activemq.util.IOExceptionSupport; 036import org.apache.activemq.util.ServiceStopper; 037import org.apache.activemq.wireformat.WireFormat; 038 039/** 040 * This transport initializes the SSLEngine and reads the first command before 041 * handing off to the detected transport. 042 * 043 */ 044public class AutoInitNioSSLTransport extends NIOSSLTransport { 045 046 public AutoInitNioSSLTransport(WireFormat wireFormat, SocketFactory socketFactory, URI remoteLocation, URI localLocation) throws UnknownHostException, IOException { 047 super(wireFormat, socketFactory, remoteLocation, localLocation); 048 } 049 050 public AutoInitNioSSLTransport(WireFormat wireFormat, Socket socket) throws IOException { 051 super(wireFormat, socket, null, null, null); 052 } 053 054 @Override 055 public void setSslContext(SSLContext sslContext) { 056 this.sslContext = sslContext; 057 } 058 059 public ByteBuffer getInputBuffer() { 060 return this.inputBuffer; 061 } 062 063 @Override 064 protected void initializeStreams() throws IOException { 065 NIOOutputStream outputStream = null; 066 try { 067 channel = socket.getChannel(); 068 channel.configureBlocking(false); 069 070 if (sslContext == null) { 071 sslContext = SSLContext.getDefault(); 072 } 073 074 String remoteHost = null; 075 int remotePort = -1; 076 077 try { 078 URI remoteAddress = new URI(this.getRemoteAddress()); 079 remoteHost = remoteAddress.getHost(); 080 remotePort = remoteAddress.getPort(); 081 } catch (Exception e) { 082 } 083 084 // initialize engine, the initial sslSession we get will need to be 085 // updated once the ssl handshake process is completed. 086 if (remoteHost != null && remotePort != -1) { 087 sslEngine = sslContext.createSSLEngine(remoteHost, remotePort); 088 } else { 089 sslEngine = sslContext.createSSLEngine(); 090 } 091 092 sslEngine.setUseClientMode(false); 093 if (enabledCipherSuites != null) { 094 sslEngine.setEnabledCipherSuites(enabledCipherSuites); 095 } 096 097 if (enabledProtocols != null) { 098 sslEngine.setEnabledProtocols(enabledProtocols); 099 } 100 101 if (wantClientAuth) { 102 sslEngine.setWantClientAuth(wantClientAuth); 103 } 104 105 if (needClientAuth) { 106 sslEngine.setNeedClientAuth(needClientAuth); 107 } 108 109 sslSession = sslEngine.getSession(); 110 111 inputBuffer = ByteBuffer.allocate(sslSession.getPacketBufferSize()); 112 inputBuffer.clear(); 113 114 outputStream = new NIOOutputStream(channel); 115 outputStream.setEngine(sslEngine); 116 this.dataOut = new DataOutputStream(outputStream); 117 this.buffOut = outputStream; 118 sslEngine.beginHandshake(); 119 handshakeStatus = sslEngine.getHandshakeStatus(); 120 doHandshake(); 121 122 } catch (Exception e) { 123 try { 124 if(outputStream != null) { 125 outputStream.close(); 126 } 127 super.closeStreams(); 128 } catch (Exception ex) {} 129 throw new IOException(e); 130 } 131 } 132 133 @Override 134 protected void doOpenWireInit() throws Exception { 135 136 } 137 138 139 @Override 140 protected void finishHandshake() throws Exception { 141 if (handshakeInProgress) { 142 handshakeInProgress = false; 143 nextFrameSize = -1; 144 145 // Once handshake completes we need to ask for the now real sslSession 146 // otherwise the session would return 'SSL_NULL_WITH_NULL_NULL' for the 147 // cipher suite. 148 sslSession = sslEngine.getSession(); 149 150 } 151 } 152 153 public SSLEngine getSslSession() { 154 return this.sslEngine; 155 } 156 157 private volatile byte[] readData; 158 159 private final AtomicInteger readSize = new AtomicInteger(); 160 161 public byte[] getReadData() { 162 return readData != null ? readData : new byte[0]; 163 } 164 165 public AtomicInteger getReadSize() { 166 return readSize; 167 } 168 169 @Override 170 public void serviceRead() { 171 try { 172 if (handshakeInProgress) { 173 doHandshake(); 174 } 175 176 ByteBuffer plain = ByteBuffer.allocate(sslSession.getApplicationBufferSize()); 177 plain.position(plain.limit()); 178 179 while (true) { 180 if (!plain.hasRemaining()) { 181 int readCount = secureRead(plain); 182 183 // channel is closed, cleanup 184 if (readCount == -1) { 185 onException(new EOFException()); 186 break; 187 } 188 189 receiveCounter += readCount; 190 readSize.addAndGet(readCount); 191 } 192 193 if (status == SSLEngineResult.Status.OK && handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_UNWRAP) { 194 processCommand(plain); 195 //we have received enough bytes to detect the protocol 196 if (receiveCounter >= 8) { 197 break; 198 } 199 } 200 } 201 } catch (IOException e) { 202 onException(e); 203 } catch (Throwable e) { 204 onException(IOExceptionSupport.create(e)); 205 } 206 } 207 208 @Override 209 protected void processCommand(ByteBuffer plain) throws Exception { 210 ByteBuffer newBuffer = ByteBuffer.allocate(receiveCounter); 211 if (readData != null) { 212 newBuffer.put(readData); 213 } 214 newBuffer.put(plain); 215 newBuffer.flip(); 216 readData = newBuffer.array(); 217 } 218 219 220 @Override 221 public void doStart() throws Exception { 222 taskRunnerFactory = new TaskRunnerFactory("ActiveMQ NIOSSLTransport Task"); 223 // no need to init as we can delay that until demand (eg in doHandshake) 224 connect(); 225 } 226 227 228 @Override 229 protected void doStop(ServiceStopper stopper) throws Exception { 230 if (taskRunnerFactory != null) { 231 taskRunnerFactory.shutdownNow(); 232 taskRunnerFactory = null; 233 } 234 } 235 236 237}