001/**
002 * Licensed to the Apache Software Foundation (ASF) under one or more
003 * contributor license agreements.  See the NOTICE file distributed with
004 * this work for additional information regarding copyright ownership.
005 * The ASF licenses this file to You under the Apache License, Version 2.0
006 * (the "License"); you may not use this file except in compliance with
007 * the License.  You may obtain a copy of the License at
008 *
009 *      http://www.apache.org/licenses/LICENSE-2.0
010 *
011 * Unless required by applicable law or agreed to in writing, software
012 * distributed under the License is distributed on an "AS IS" BASIS,
013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
014 * See the License for the specific language governing permissions and
015 * limitations under the License.
016 */
017
018package org.apache.activemq.transport.nio;
019
020import java.io.DataOutputStream;
021import java.io.EOFException;
022import java.io.IOException;
023import java.net.Socket;
024import java.net.URI;
025import java.net.UnknownHostException;
026import java.nio.ByteBuffer;
027import java.util.concurrent.atomic.AtomicInteger;
028
029import javax.net.SocketFactory;
030import javax.net.ssl.SSLContext;
031import javax.net.ssl.SSLEngine;
032import javax.net.ssl.SSLEngineResult;
033
034import org.apache.activemq.thread.TaskRunnerFactory;
035import org.apache.activemq.util.IOExceptionSupport;
036import org.apache.activemq.util.ServiceStopper;
037import org.apache.activemq.wireformat.WireFormat;
038
039/**
040 * This transport initializes the SSLEngine and reads the first command before
041 * handing off to the detected transport.
042 *
043 */
044public class AutoInitNioSSLTransport extends NIOSSLTransport {
045
046    public AutoInitNioSSLTransport(WireFormat wireFormat, SocketFactory socketFactory, URI remoteLocation, URI localLocation) throws UnknownHostException, IOException {
047        super(wireFormat, socketFactory, remoteLocation, localLocation);
048    }
049
050    public AutoInitNioSSLTransport(WireFormat wireFormat, Socket socket) throws IOException {
051        super(wireFormat, socket, null, null, null);
052    }
053
054    @Override
055    public void setSslContext(SSLContext sslContext) {
056        this.sslContext = sslContext;
057    }
058
059    public ByteBuffer getInputBuffer() {
060        return this.inputBuffer;
061    }
062
063    @Override
064    protected void initializeStreams() throws IOException {
065        NIOOutputStream outputStream = null;
066        try {
067            channel = socket.getChannel();
068            channel.configureBlocking(false);
069
070            if (sslContext == null) {
071                sslContext = SSLContext.getDefault();
072            }
073
074            String remoteHost = null;
075            int remotePort = -1;
076
077            try {
078                URI remoteAddress = new URI(this.getRemoteAddress());
079                remoteHost = remoteAddress.getHost();
080                remotePort = remoteAddress.getPort();
081            } catch (Exception e) {
082            }
083
084            // initialize engine, the initial sslSession we get will need to be
085            // updated once the ssl handshake process is completed.
086            if (remoteHost != null && remotePort != -1) {
087                sslEngine = sslContext.createSSLEngine(remoteHost, remotePort);
088            } else {
089                sslEngine = sslContext.createSSLEngine();
090            }
091
092            sslEngine.setUseClientMode(false);
093            if (enabledCipherSuites != null) {
094                sslEngine.setEnabledCipherSuites(enabledCipherSuites);
095            }
096
097            if (enabledProtocols != null) {
098                sslEngine.setEnabledProtocols(enabledProtocols);
099            }
100
101            if (wantClientAuth) {
102                sslEngine.setWantClientAuth(wantClientAuth);
103            }
104
105            if (needClientAuth) {
106                sslEngine.setNeedClientAuth(needClientAuth);
107            }
108
109            sslSession = sslEngine.getSession();
110
111            inputBuffer = ByteBuffer.allocate(sslSession.getPacketBufferSize());
112            inputBuffer.clear();
113
114            outputStream = new NIOOutputStream(channel);
115            outputStream.setEngine(sslEngine);
116            this.dataOut = new DataOutputStream(outputStream);
117            this.buffOut = outputStream;
118            sslEngine.beginHandshake();
119            handshakeStatus = sslEngine.getHandshakeStatus();
120            doHandshake();
121
122        } catch (Exception e) {
123            try {
124                if(outputStream != null) {
125                    outputStream.close();
126                }
127                super.closeStreams();
128            } catch (Exception ex) {}
129            throw new IOException(e);
130        }
131    }
132
133    @Override
134    protected void doOpenWireInit() throws Exception {
135
136    }
137
138
139    @Override
140    protected void finishHandshake() throws Exception {
141        if (handshakeInProgress) {
142            handshakeInProgress = false;
143            nextFrameSize = -1;
144
145            // Once handshake completes we need to ask for the now real sslSession
146            // otherwise the session would return 'SSL_NULL_WITH_NULL_NULL' for the
147            // cipher suite.
148            sslSession = sslEngine.getSession();
149
150        }
151    }
152
153    public SSLEngine getSslSession() {
154        return this.sslEngine;
155    }
156
157    private volatile byte[] readData;
158
159    private final AtomicInteger readSize = new AtomicInteger();
160
161    public byte[] getReadData() {
162        return readData != null ? readData : new byte[0];
163    }
164
165    public AtomicInteger getReadSize() {
166        return readSize;
167    }
168
169    @Override
170    public void serviceRead() {
171        try {
172            if (handshakeInProgress) {
173                doHandshake();
174            }
175
176            ByteBuffer plain = ByteBuffer.allocate(sslSession.getApplicationBufferSize());
177            plain.position(plain.limit());
178
179            while (true) {
180                if (!plain.hasRemaining()) {
181                    int readCount = secureRead(plain);
182
183                    // channel is closed, cleanup
184                    if (readCount == -1) {
185                        onException(new EOFException());
186                        break;
187                    }
188
189                    receiveCounter += readCount;
190                    readSize.addAndGet(readCount);
191                }
192
193                if (status == SSLEngineResult.Status.OK && handshakeStatus != SSLEngineResult.HandshakeStatus.NEED_UNWRAP) {
194                    processCommand(plain);
195                    //we have received enough bytes to detect the protocol
196                    if (receiveCounter >= 8) {
197                        break;
198                    }
199                }
200            }
201        } catch (IOException e) {
202            onException(e);
203        } catch (Throwable e) {
204            onException(IOExceptionSupport.create(e));
205        }
206    }
207
208    @Override
209    protected void processCommand(ByteBuffer plain) throws Exception {
210        ByteBuffer newBuffer = ByteBuffer.allocate(receiveCounter);
211        if (readData != null) {
212            newBuffer.put(readData);
213        }
214        newBuffer.put(plain);
215        newBuffer.flip();
216        readData = newBuffer.array();
217    }
218
219
220    @Override
221    public void doStart() throws Exception {
222        taskRunnerFactory = new TaskRunnerFactory("ActiveMQ NIOSSLTransport Task");
223        // no need to init as we can delay that until demand (eg in doHandshake)
224        connect();
225    }
226
227
228    @Override
229    protected void doStop(ServiceStopper stopper) throws Exception {
230        if (taskRunnerFactory != null) {
231            taskRunnerFactory.shutdownNow();
232            taskRunnerFactory = null;
233        }
234    }
235
236
237}