001/** 002 * Licensed to the Apache Software Foundation (ASF) under one or more 003 * contributor license agreements. See the NOTICE file distributed with 004 * this work for additional information regarding copyright ownership. 005 * The ASF licenses this file to You under the Apache License, Version 2.0 006 * (the "License"); you may not use this file except in compliance with 007 * the License. You may obtain a copy of the License at 008 * 009 * http://www.apache.org/licenses/LICENSE-2.0 010 * 011 * Unless required by applicable law or agreed to in writing, software 012 * distributed under the License is distributed on an "AS IS" BASIS, 013 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 014 * See the License for the specific language governing permissions and 015 * limitations under the License. 016 */ 017package org.apache.activemq.transport.auto; 018 019import java.io.IOException; 020import java.io.InputStream; 021import java.net.Socket; 022import java.net.URI; 023import java.net.URISyntaxException; 024import java.nio.ByteBuffer; 025import java.util.HashMap; 026import java.util.Map; 027import java.util.Set; 028import java.util.concurrent.ConcurrentHashMap; 029import java.util.concurrent.ConcurrentMap; 030import java.util.concurrent.ExecutorService; 031import java.util.concurrent.Executors; 032import java.util.concurrent.Future; 033import java.util.concurrent.LinkedBlockingQueue; 034import java.util.concurrent.ThreadPoolExecutor; 035import java.util.concurrent.TimeUnit; 036import java.util.concurrent.TimeoutException; 037import java.util.concurrent.atomic.AtomicInteger; 038 039import javax.net.ServerSocketFactory; 040 041import org.apache.activemq.broker.BrokerService; 042import org.apache.activemq.broker.BrokerServiceAware; 043import org.apache.activemq.openwire.OpenWireFormatFactory; 044import org.apache.activemq.transport.InactivityIOException; 045import org.apache.activemq.transport.Transport; 046import org.apache.activemq.transport.TransportFactory; 047import org.apache.activemq.transport.TransportServer; 048import org.apache.activemq.transport.protocol.AmqpProtocolVerifier; 049import org.apache.activemq.transport.protocol.MqttProtocolVerifier; 050import org.apache.activemq.transport.protocol.OpenWireProtocolVerifier; 051import org.apache.activemq.transport.protocol.ProtocolVerifier; 052import org.apache.activemq.transport.protocol.StompProtocolVerifier; 053import org.apache.activemq.transport.tcp.TcpTransport; 054import org.apache.activemq.transport.tcp.TcpTransport.InitBuffer; 055import org.apache.activemq.transport.tcp.TcpTransportFactory; 056import org.apache.activemq.transport.tcp.TcpTransportServer; 057import org.apache.activemq.util.FactoryFinder; 058import org.apache.activemq.util.IOExceptionSupport; 059import org.apache.activemq.util.IntrospectionSupport; 060import org.apache.activemq.util.ServiceStopper; 061import org.apache.activemq.wireformat.WireFormat; 062import org.apache.activemq.wireformat.WireFormatFactory; 063import org.slf4j.Logger; 064import org.slf4j.LoggerFactory; 065 066/** 067 * A TCP based implementation of {@link TransportServer} 068 */ 069public class AutoTcpTransportServer extends TcpTransportServer { 070 071 private static final Logger LOG = LoggerFactory.getLogger(AutoTcpTransportServer.class); 072 073 protected Map<String, Map<String, Object>> wireFormatOptions; 074 protected Map<String, Object> autoTransportOptions; 075 protected Set<String> enabledProtocols; 076 protected final Map<String, ProtocolVerifier> protocolVerifiers = new ConcurrentHashMap<String, ProtocolVerifier>(); 077 078 protected BrokerService brokerService; 079 080 protected int maxConnectionThreadPoolSize = Integer.MAX_VALUE; 081 protected int protocolDetectionTimeOut = 30000; 082 083 private static final FactoryFinder TRANSPORT_FACTORY_FINDER = new FactoryFinder("META-INF/services/org/apache/activemq/transport/"); 084 private final ConcurrentMap<String, TransportFactory> transportFactories = new ConcurrentHashMap<String, TransportFactory>(); 085 086 private static final FactoryFinder WIREFORMAT_FACTORY_FINDER = new FactoryFinder("META-INF/services/org/apache/activemq/wireformat/"); 087 088 public WireFormatFactory findWireFormatFactory(String scheme, Map<String, Map<String, Object>> options) throws IOException { 089 WireFormatFactory wff = null; 090 try { 091 wff = (WireFormatFactory)WIREFORMAT_FACTORY_FINDER.newInstance(scheme); 092 if (options != null) { 093 final Map<String, Object> wfOptions = new HashMap<>(); 094 if (options.get(AutoTransportUtils.ALL) != null) { 095 wfOptions.putAll(options.get(AutoTransportUtils.ALL)); 096 } 097 if (options.get(scheme) != null) { 098 wfOptions.putAll(options.get(scheme)); 099 } 100 IntrospectionSupport.setProperties(wff, wfOptions); 101 } 102 if (wff instanceof OpenWireFormatFactory) { 103 protocolVerifiers.put(AutoTransportUtils.OPENWIRE, new OpenWireProtocolVerifier((OpenWireFormatFactory) wff)); 104 } 105 return wff; 106 } catch (Throwable e) { 107 throw IOExceptionSupport.create("Could not create wire format factory for: " + scheme + ", reason: " + e, e); 108 } 109 } 110 111 public TransportFactory findTransportFactory(String scheme, Map<String, ?> options) throws IOException { 112 scheme = append(scheme, "nio"); 113 scheme = append(scheme, "ssl"); 114 115 if (scheme.isEmpty()) { 116 scheme = "tcp"; 117 } 118 119 TransportFactory tf = transportFactories.get(scheme); 120 if (tf == null) { 121 // Try to load if from a META-INF property. 122 try { 123 tf = (TransportFactory)TRANSPORT_FACTORY_FINDER.newInstance(scheme); 124 if (options != null) { 125 IntrospectionSupport.setProperties(tf, options); 126 } 127 transportFactories.put(scheme, tf); 128 } catch (Throwable e) { 129 throw IOExceptionSupport.create("Transport scheme NOT recognized: [" + scheme + "]", e); 130 } 131 } 132 return tf; 133 } 134 135 protected String append(String currentScheme, String scheme) { 136 if (this.getBindLocation().getScheme().contains(scheme)) { 137 if (!currentScheme.isEmpty()) { 138 currentScheme += "+"; 139 } 140 currentScheme += scheme; 141 } 142 return currentScheme; 143 } 144 145 /** 146 * @param transportFactory 147 * @param location 148 * @param serverSocketFactory 149 * @throws IOException 150 * @throws URISyntaxException 151 */ 152 public AutoTcpTransportServer(TcpTransportFactory transportFactory, 153 URI location, ServerSocketFactory serverSocketFactory, BrokerService brokerService, 154 Set<String> enabledProtocols) 155 throws IOException, URISyntaxException { 156 super(transportFactory, location, serverSocketFactory); 157 158 //Use an executor service here to handle new connections. Setting the max number 159 //of threads to the maximum number of connections the thread count isn't unbounded 160 service = new ThreadPoolExecutor(maxConnectionThreadPoolSize, 161 maxConnectionThreadPoolSize, 162 30L, TimeUnit.SECONDS, 163 new LinkedBlockingQueue<Runnable>()); 164 //allow the thread pool to shrink if the max number of threads isn't needed 165 service.allowCoreThreadTimeOut(true); 166 167 this.brokerService = brokerService; 168 this.enabledProtocols = enabledProtocols; 169 initProtocolVerifiers(); 170 } 171 172 public int getMaxConnectionThreadPoolSize() { 173 return maxConnectionThreadPoolSize; 174 } 175 176 public void setMaxConnectionThreadPoolSize(int maxConnectionThreadPoolSize) { 177 this.maxConnectionThreadPoolSize = maxConnectionThreadPoolSize; 178 service.setCorePoolSize(maxConnectionThreadPoolSize); 179 service.setMaximumPoolSize(maxConnectionThreadPoolSize); 180 } 181 182 public void setProtocolDetectionTimeOut(int protocolDetectionTimeOut) { 183 this.protocolDetectionTimeOut = protocolDetectionTimeOut; 184 } 185 186 @Override 187 public void setWireFormatFactory(WireFormatFactory factory) { 188 super.setWireFormatFactory(factory); 189 initOpenWireProtocolVerifier(); 190 } 191 192 protected void initProtocolVerifiers() { 193 initOpenWireProtocolVerifier(); 194 195 if (isAllProtocols() || enabledProtocols.contains(AutoTransportUtils.AMQP)) { 196 protocolVerifiers.put(AutoTransportUtils.AMQP, new AmqpProtocolVerifier()); 197 } 198 if (isAllProtocols() || enabledProtocols.contains(AutoTransportUtils.STOMP)) { 199 protocolVerifiers.put(AutoTransportUtils.STOMP, new StompProtocolVerifier()); 200 } 201 if (isAllProtocols()|| enabledProtocols.contains(AutoTransportUtils.MQTT)) { 202 protocolVerifiers.put(AutoTransportUtils.MQTT, new MqttProtocolVerifier()); 203 } 204 } 205 206 protected void initOpenWireProtocolVerifier() { 207 if (isAllProtocols() || enabledProtocols.contains(AutoTransportUtils.OPENWIRE)) { 208 OpenWireProtocolVerifier owpv; 209 if (wireFormatFactory instanceof OpenWireFormatFactory) { 210 owpv = new OpenWireProtocolVerifier((OpenWireFormatFactory) wireFormatFactory); 211 } else { 212 owpv = new OpenWireProtocolVerifier(new OpenWireFormatFactory()); 213 } 214 protocolVerifiers.put(AutoTransportUtils.OPENWIRE, owpv); 215 } 216 } 217 218 protected boolean isAllProtocols() { 219 return enabledProtocols == null || enabledProtocols.isEmpty(); 220 } 221 222 223 protected final ThreadPoolExecutor service; 224 225 226 /** 227 * This holds the initial buffer that has been read to detect the protocol. 228 */ 229 public InitBuffer initBuffer; 230 231 @Override 232 protected void handleSocket(final Socket socket) { 233 final AutoTcpTransportServer server = this; 234 //This needs to be done in a new thread because 235 //the socket might be waiting on the client to send bytes 236 //doHandleSocket can't complete until the protocol can be detected 237 service.submit(new Runnable() { 238 @Override 239 public void run() { 240 server.doHandleSocket(socket); 241 } 242 }); 243 } 244 245 @Override 246 protected TransportInfo configureTransport(final TcpTransportServer server, final Socket socket) throws Exception { 247 final InputStream is = socket.getInputStream(); 248 ExecutorService executor = Executors.newSingleThreadExecutor(); 249 250 final AtomicInteger readBytes = new AtomicInteger(0); 251 final ByteBuffer data = ByteBuffer.allocate(8); 252 // We need to peak at the first 8 bytes of the buffer to detect the protocol 253 Future<?> future = executor.submit(new Runnable() { 254 @Override 255 public void run() { 256 try { 257 do { 258 int read = is.read(); 259 if (read == -1) { 260 throw new IOException("Connection failed, stream is closed."); 261 } 262 data.put((byte) read); 263 readBytes.incrementAndGet(); 264 } while (readBytes.get() < 8); 265 } catch (Exception e) { 266 throw new IllegalStateException(e); 267 } 268 } 269 }); 270 271 waitForProtocolDetectionFinish(future, readBytes); 272 data.flip(); 273 ProtocolInfo protocolInfo = detectProtocol(data.array()); 274 275 initBuffer = new InitBuffer(readBytes.get(), ByteBuffer.allocate(readBytes.get())); 276 initBuffer.buffer.put(data.array()); 277 278 if (protocolInfo.detectedTransportFactory instanceof BrokerServiceAware) { 279 ((BrokerServiceAware) protocolInfo.detectedTransportFactory).setBrokerService(brokerService); 280 } 281 282 WireFormat format = protocolInfo.detectedWireFormatFactory.createWireFormat(); 283 Transport transport = createTransport(socket, format, protocolInfo.detectedTransportFactory); 284 285 return new TransportInfo(format, transport, protocolInfo.detectedTransportFactory); 286 } 287 288 protected void waitForProtocolDetectionFinish(final Future<?> future, final AtomicInteger readBytes) throws Exception { 289 try { 290 //Wait for protocolDetectionTimeOut if defined 291 if (protocolDetectionTimeOut > 0) { 292 future.get(protocolDetectionTimeOut, TimeUnit.MILLISECONDS); 293 } else { 294 future.get(); 295 } 296 } catch (TimeoutException e) { 297 throw new InactivityIOException("Client timed out before wire format could be detected. " + 298 " 8 bytes are required to detect the protocol but only: " + readBytes.get() + " byte(s) were sent."); 299 } 300 } 301 302 @Override 303 protected TcpTransport createTransport(Socket socket, WireFormat format) throws IOException { 304 return new TcpTransport(format, socket, this.initBuffer); 305 } 306 307 /** 308 * @param socket 309 * @param format 310 * @param detectedTransportFactory 311 * @return 312 */ 313 protected TcpTransport createTransport(Socket socket, WireFormat format, 314 TcpTransportFactory detectedTransportFactory) throws IOException { 315 return createTransport(socket, format); 316 } 317 318 public void setWireFormatOptions(Map<String, Map<String, Object>> wireFormatOptions) { 319 this.wireFormatOptions = wireFormatOptions; 320 } 321 322 public void setEnabledProtocols(Set<String> enabledProtocols) { 323 this.enabledProtocols = enabledProtocols; 324 } 325 326 public void setAutoTransportOptions(Map<String, Object> autoTransportOptions) { 327 this.autoTransportOptions = autoTransportOptions; 328 if (autoTransportOptions.get("protocols") != null) { 329 this.enabledProtocols = AutoTransportUtils.parseProtocols((String) autoTransportOptions.get("protocols")); 330 } 331 } 332 @Override 333 protected void doStop(ServiceStopper stopper) throws Exception { 334 if (service != null) { 335 service.shutdown(); 336 } 337 super.doStop(stopper); 338 } 339 340 protected ProtocolInfo detectProtocol(byte[] buffer) throws IOException { 341 TcpTransportFactory detectedTransportFactory = transportFactory; 342 WireFormatFactory detectedWireFormatFactory = wireFormatFactory; 343 344 boolean found = false; 345 for (String scheme : protocolVerifiers.keySet()) { 346 if (protocolVerifiers.get(scheme).isProtocol(buffer)) { 347 LOG.debug("Detected protocol " + scheme); 348 detectedWireFormatFactory = findWireFormatFactory(scheme, wireFormatOptions); 349 350 if (scheme.equals("default")) { 351 scheme = ""; 352 } 353 354 detectedTransportFactory = (TcpTransportFactory) findTransportFactory(scheme, transportOptions); 355 found = true; 356 break; 357 } 358 } 359 360 if (!found) { 361 throw new IllegalStateException("Could not detect the wire format"); 362 } 363 364 return new ProtocolInfo(detectedTransportFactory, detectedWireFormatFactory); 365 366 } 367 368 protected class ProtocolInfo { 369 public final TcpTransportFactory detectedTransportFactory; 370 public final WireFormatFactory detectedWireFormatFactory; 371 372 public ProtocolInfo(TcpTransportFactory detectedTransportFactory, 373 WireFormatFactory detectedWireFormatFactory) { 374 super(); 375 this.detectedTransportFactory = detectedTransportFactory; 376 this.detectedWireFormatFactory = detectedWireFormatFactory; 377 } 378 } 379 380}