Add HTTP tunnel support

Add server and client components to support tunneling of binary TCP
protocols over HTTP. Primarily designed to support Java's remote
debug protocol (JDWP).

See gh-3087
pull/3077/merge
Phillip Webb 10 years ago
parent c27b63b354
commit 2123b267aa

@ -0,0 +1,216 @@
/*
* Copyright 2012-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.boot.developertools.tunnel.client;
import java.io.Closeable;
import java.io.IOException;
import java.net.MalformedURLException;
import java.net.URI;
import java.net.URISyntaxException;
import java.net.URL;
import java.nio.ByteBuffer;
import java.nio.channels.WritableByteChannel;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.atomic.AtomicLong;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.boot.developertools.tunnel.payload.HttpTunnelPayload;
import org.springframework.boot.developertools.tunnel.payload.HttpTunnelPayloadForwarder;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus;
import org.springframework.http.client.ClientHttpRequest;
import org.springframework.http.client.ClientHttpRequestFactory;
import org.springframework.http.client.ClientHttpResponse;
import org.springframework.util.Assert;
/**
* {@link TunnelConnection} implementation that uses HTTP to transfer data.
*
* @author Phillip Webb
* @author Rob Winch
* @since 1.3.0
* @see TunnelClient
* @see org.springframework.boot.developertools.tunnel.server.HttpTunnelServer
*/
public class HttpTunnelConnection implements TunnelConnection {
private static Log logger = LogFactory.getLog(HttpTunnelConnection.class);
private final URI uri;
private final ClientHttpRequestFactory requestFactory;
private final Executor executor;
/**
* Create a new {@link HttpTunnelConnection} instance.
* @param url the URL to connect to
* @param requestFactory the HTTP request factory
*/
public HttpTunnelConnection(String url, ClientHttpRequestFactory requestFactory) {
this(url, requestFactory, null);
}
/**
* Create a new {@link HttpTunnelConnection} instance.
* @param url the URL to connect to
* @param requestFactory the HTTP request factory
* @param executor the executor used to handle connections
*/
protected HttpTunnelConnection(String url, ClientHttpRequestFactory requestFactory,
Executor executor) {
Assert.hasLength(url, "URL must not be empty");
Assert.notNull(requestFactory, "RequestFactory must not be null");
try {
this.uri = new URL(url).toURI();
}
catch (URISyntaxException ex) {
throw new IllegalArgumentException("Malformed URL '" + url + "'");
}
catch (MalformedURLException ex) {
throw new IllegalArgumentException("Malformed URL '" + url + "'");
}
this.requestFactory = requestFactory;
this.executor = (executor == null ? Executors
.newCachedThreadPool(new TunnelThreadFactory()) : executor);
}
@Override
public TunnelChannel open(WritableByteChannel incomingChannel, Closeable closeable)
throws Exception {
logger.trace("Opening HTTP tunnel to " + this.uri);
return new TunnelChannel(incomingChannel, closeable);
}
protected final ClientHttpRequest createRequest(boolean hasPayload)
throws IOException {
HttpMethod method = (hasPayload ? HttpMethod.POST : HttpMethod.GET);
return this.requestFactory.createRequest(this.uri, method);
}
/**
* A {@link WritableByteChannel} used to transfer traffic.
*/
protected class TunnelChannel implements WritableByteChannel {
private final HttpTunnelPayloadForwarder forwarder;
private final Closeable closeable;
private boolean open = true;
private AtomicLong requestSeq = new AtomicLong();
public TunnelChannel(WritableByteChannel incomingChannel, Closeable closeable) {
this.forwarder = new HttpTunnelPayloadForwarder(incomingChannel);
this.closeable = closeable;
openNewConnection(null);
}
@Override
public boolean isOpen() {
return this.open;
}
@Override
public void close() throws IOException {
if (this.open) {
this.open = false;
this.closeable.close();
}
}
@Override
public int write(ByteBuffer src) throws IOException {
int size = src.remaining();
if (size > 0) {
openNewConnection(new HttpTunnelPayload(
this.requestSeq.incrementAndGet(), src));
}
return size;
}
private synchronized void openNewConnection(final HttpTunnelPayload payload) {
HttpTunnelConnection.this.executor.execute(new Runnable() {
@Override
public void run() {
try {
sendAndReceive(payload);
}
catch (IOException ex) {
logger.trace("Unexpected connection error", ex);
closeQuitely();
}
}
private void closeQuitely() {
try {
close();
}
catch (IOException ex) {
}
}
});
}
private void sendAndReceive(HttpTunnelPayload payload) throws IOException {
ClientHttpRequest request = createRequest(payload != null);
if (payload != null) {
payload.logIncoming();
payload.assignTo(request);
}
handleResponse(request.execute());
}
private void handleResponse(ClientHttpResponse response) throws IOException {
if (response.getStatusCode() == HttpStatus.GONE) {
close();
return;
}
if (response.getStatusCode() == HttpStatus.OK) {
HttpTunnelPayload payload = HttpTunnelPayload.get(response);
if (payload != null) {
this.forwarder.forward(payload);
}
}
if (response.getStatusCode() != HttpStatus.TOO_MANY_REQUESTS) {
openNewConnection(null);
}
}
}
/**
* {@link ThreadFactory} used to create the tunnel thread.
*/
private static class TunnelThreadFactory implements ThreadFactory {
@Override
public Thread newThread(Runnable runnable) {
Thread thread = new Thread(runnable, "HTTP Tunnel Connection");
thread.setDaemon(true);
return thread;
}
}
}

@ -0,0 +1,207 @@
/*
* Copyright 2012-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.boot.developertools.tunnel.client;
import java.io.Closeable;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.ServerSocket;
import java.nio.ByteBuffer;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.nio.channels.WritableByteChannel;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.beans.factory.SmartInitializingSingleton;
import org.springframework.util.Assert;
/**
* The client side component of a socket tunnel. Starts a {@link ServerSocket} of the
* specified port for local clients to connect to.
*
* @author Phillip Webb
* @since 1.3.0
*/
public class TunnelClient implements SmartInitializingSingleton {
private static final int BUFFER_SIZE = 1024 * 100;
private static final Log logger = LogFactory.getLog(TunnelClient.class);
private final int listenPort;
private final TunnelConnection tunnelConnection;
private TunnelClientListeners listeners = new TunnelClientListeners();
private ServerThread serverThread;
public TunnelClient(int listenPort, TunnelConnection tunnelConnection) {
Assert.isTrue(listenPort > 0, "ListenPort must be positive");
Assert.notNull(tunnelConnection, "TunnelConnection must not be null");
this.listenPort = listenPort;
this.tunnelConnection = tunnelConnection;
}
@Override
public void afterSingletonsInstantiated() {
if (this.serverThread == null) {
try {
start();
}
catch (IOException ex) {
throw new IllegalStateException(ex);
}
}
}
/**
* Start the client and accept incoming connections on the port.
* @throws IOException
*/
public synchronized void start() throws IOException {
Assert.state(this.serverThread == null, "Server already started");
ServerSocketChannel serverSocketChannel = ServerSocketChannel.open();
serverSocketChannel.socket().bind(new InetSocketAddress(this.listenPort));
logger.trace("Listening for TCP traffic to tunnel on port " + this.listenPort);
this.serverThread = new ServerThread(serverSocketChannel);
this.serverThread.start();
}
/**
* Stop the client, disconnecting any servers.
* @throws IOException
*/
public synchronized void stop() throws IOException {
if (this.serverThread != null) {
logger.trace("Closing tunnel client on port " + this.listenPort);
this.serverThread.close();
try {
this.serverThread.join(2000);
}
catch (InterruptedException ex) {
}
this.serverThread = null;
}
}
protected final ServerThread getServerThread() {
return this.serverThread;
}
public void addListener(TunnelClientListener listener) {
this.listeners.addListener(listener);
}
public void removeListener(TunnelClientListener listener) {
this.listeners.removeListener(listener);
}
/**
* The main server thread.
*/
protected class ServerThread extends Thread {
private final ServerSocketChannel serverSocketChannel;
private boolean acceptConnections = true;
public ServerThread(ServerSocketChannel serverSocketChannel) {
this.serverSocketChannel = serverSocketChannel;
setName("Tunnel Server");
setDaemon(true);
}
public void close() throws IOException {
this.serverSocketChannel.close();
this.acceptConnections = false;
interrupt();
}
@Override
public void run() {
try {
while (this.acceptConnections) {
SocketChannel socket = this.serverSocketChannel.accept();
try {
handleConnection(socket);
}
finally {
socket.close();
}
}
}
catch (Exception ex) {
logger.trace("Unexpected exception from tunnel client", ex);
}
}
private void handleConnection(SocketChannel socketChannel) throws Exception {
Closeable closeable = new SocketCloseable(socketChannel);
WritableByteChannel outputChannel = TunnelClient.this.tunnelConnection.open(
socketChannel, closeable);
TunnelClient.this.listeners.fireOpenEvent(socketChannel);
try {
logger.trace("Accepted connection to tunnel client from "
+ socketChannel.socket().getRemoteSocketAddress());
while (true) {
ByteBuffer buffer = ByteBuffer.allocate(BUFFER_SIZE);
int amountRead = socketChannel.read(buffer);
if (amountRead == -1) {
outputChannel.close();
return;
}
if (amountRead > 0) {
buffer.flip();
outputChannel.write(buffer);
}
}
}
finally {
outputChannel.close();
}
}
protected void stopAcceptingConnections() {
this.acceptConnections = false;
}
}
/**
* {@link Closeable} used to close a {@link SocketChannel} and fire an event.
*/
private class SocketCloseable implements Closeable {
private final SocketChannel socketChannel;
private boolean closed = false;
public SocketCloseable(SocketChannel socketChannel) {
this.socketChannel = socketChannel;
}
@Override
public void close() throws IOException {
if (!this.closed) {
this.socketChannel.close();
TunnelClient.this.listeners.fireCloseEvent(this.socketChannel);
this.closed = true;
}
}
}
}

@ -0,0 +1,41 @@
/*
* Copyright 2012-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.boot.developertools.tunnel.client;
import java.nio.channels.SocketChannel;
/**
* Listener that can be used to receive {@link TunnelClient} events.
*
* @author Phillip Webb
* @since 1.3.0
*/
public interface TunnelClientListener {
/**
* Called when a socket channel is opened.
* @param socket the socket channel
*/
void onOpen(SocketChannel socket);
/**
* Called when a socket channel is closed.
* @param socket the socket channel
*/
void onClose(SocketChannel socket);
}

@ -0,0 +1,56 @@
/*
* Copyright 2012-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.boot.developertools.tunnel.client;
import java.nio.channels.SocketChannel;
import java.util.ArrayList;
import java.util.List;
import org.springframework.util.Assert;
/**
* A collection of {@link TunnelClientListener}.
*
* @author Phillip Webb
*/
class TunnelClientListeners {
private final List<TunnelClientListener> listeners = new ArrayList<TunnelClientListener>();
public void addListener(TunnelClientListener listener) {
Assert.notNull(listener, "Listener must not be null");
this.listeners.add(listener);
}
public void removeListener(TunnelClientListener listener) {
Assert.notNull(listener, "Listener must not be null");
this.listeners.remove(listener);
}
public void fireOpenEvent(SocketChannel socket) {
for (TunnelClientListener listener : this.listeners) {
listener.onOpen(socket);
}
}
public void fireCloseEvent(SocketChannel socket) {
for (TunnelClientListener listener : this.listeners) {
listener.onClose(socket);
}
}
}

@ -0,0 +1,42 @@
/*
* Copyright 2012-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.boot.developertools.tunnel.client;
import java.io.Closeable;
import java.nio.channels.WritableByteChannel;
/**
* Interface used to manage socket tunnel connections.
*
* @author Phillip Webb
* @since 1.3.0
*/
public interface TunnelConnection {
/**
* Open the tunnel connection.
* @param incomingChannel A {@link WritableByteChannel} that should be used to write
* any incoming data received from the remote server.
* @param closeable
* @return A {@link WritableByteChannel} that should be used to send any outgoing data
* destined for the remote server
* @throws Exception
*/
WritableByteChannel open(WritableByteChannel incomingChannel, Closeable closeable)
throws Exception;
}

@ -0,0 +1,21 @@
/*
* Copyright 2012-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* Client side TCP tunnel support.
*/
package org.springframework.boot.developertools.tunnel.client;

@ -0,0 +1,23 @@
/*
* Copyright 2012-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* Provides support for tunneling TCP traffic over HTTP. Tunneling is primarily designed
* for the Java Debug Wire Protocol (JDWP) and as such only expects a single connection
* and isn't particularly worried about resource usage.
*/
package org.springframework.boot.developertools.tunnel;

@ -0,0 +1,185 @@
/*
* Copyright 2012-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.boot.developertools.tunnel.payload;
import java.io.IOException;
import java.io.InterruptedIOException;
import java.nio.ByteBuffer;
import java.nio.channels.Channels;
import java.nio.channels.ReadableByteChannel;
import java.nio.channels.WritableByteChannel;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpInputMessage;
import org.springframework.http.HttpOutputMessage;
import org.springframework.http.MediaType;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
/**
* Encapsulates a payload data sent via a HTTP tunnel.
*
* @author Phillip Webb
* @since 1.3.0
*/
public class HttpTunnelPayload {
private static final String SEQ_HEADER = "x-seq";
private static final int BUFFER_SIZE = 1024 * 100;
final protected static char[] HEX_CHARS = "0123456789ABCDEF".toCharArray();
private static final Log logger = LogFactory.getLog(HttpTunnelPayload.class);
private final long sequence;
private final ByteBuffer data;
/**
* Create a new {@link HttpTunnelPayload} instance.
* @param sequence the sequence number of the payload
* @param data the payload data
*/
public HttpTunnelPayload(long sequence, ByteBuffer data) {
Assert.isTrue(sequence > 0, "Sequence must be positive");
Assert.notNull(data, "Data must not be null");
this.sequence = sequence;
this.data = data;
}
/**
* Return the sequence number of the payload.
* @return the sequence
*/
public long getSequence() {
return this.sequence;
}
/**
* Assign this payload to the given {@link HttpOutputMessage}.
* @param message the message to assign this payload to
* @throws IOException
*/
public void assignTo(HttpOutputMessage message) throws IOException {
Assert.notNull(message, "Message must not be null");
HttpHeaders headers = message.getHeaders();
headers.setContentLength(this.data.remaining());
headers.add(SEQ_HEADER, Long.toString(getSequence()));
headers.setContentType(MediaType.APPLICATION_OCTET_STREAM);
WritableByteChannel body = Channels.newChannel(message.getBody());
while (this.data.hasRemaining()) {
body.write(this.data);
}
body.close();
}
/**
* Write the content of this payload to the given target channel.
* @param channel the channel to write to
* @throws IOException
*/
public void writeTo(WritableByteChannel channel) throws IOException {
Assert.notNull(channel, "Channel must not be null");
while (this.data.hasRemaining()) {
channel.write(this.data);
}
}
/**
* Return the {@link HttpTunnelPayload} for the given message or {@code null} if there
* is no payload.
* @param message the HTTP message
* @return the payload or {@code null}
* @throws IOException
*/
public static HttpTunnelPayload get(HttpInputMessage message) throws IOException {
long length = message.getHeaders().getContentLength();
if (length <= 0) {
return null;
}
String seqHeader = message.getHeaders().getFirst(SEQ_HEADER);
Assert.state(StringUtils.hasLength(seqHeader), "Missing sequence header");
ReadableByteChannel body = Channels.newChannel(message.getBody());
ByteBuffer payload = ByteBuffer.allocate((int) length);
while (payload.hasRemaining()) {
body.read(payload);
}
body.close();
payload.flip();
return new HttpTunnelPayload(Long.valueOf(seqHeader), payload);
}
/**
* Return the payload data for the given source {@link ReadableByteChannel} or null if
* the channel timed out whilst reading.
* @param channel the source channel
* @return payload data or {@code null}
* @throws IOException
*/
public static ByteBuffer getPayloadData(ReadableByteChannel channel)
throws IOException {
ByteBuffer buffer = ByteBuffer.allocate(BUFFER_SIZE);
try {
int amountRead = channel.read(buffer);
Assert.state(amountRead != -1, "Target server connection closed");
buffer.flip();
return buffer;
}
catch (InterruptedIOException ex) {
return null;
}
}
/**
* Log incoming payload information at trace level to aid diagnostics.
*/
public void logIncoming() {
log("< ");
}
/**
* Log incoming payload information at trace level to aid diagnostics.
*/
public void logOutgoing() {
log("> ");
}
private void log(String prefix) {
if (logger.isTraceEnabled()) {
logger.trace(prefix + toHexString());
}
}
/**
* Return the payload as a hexadecimal string.
* @return the payload as a hex string
*/
public String toHexString() {
byte[] bytes = this.data.array();
char[] hex = new char[this.data.remaining() * 2];
for (int i = this.data.position(); i < this.data.remaining(); i++) {
int b = bytes[i] & 0xFF;
hex[i * 2] = HEX_CHARS[b >>> 4];
hex[i * 2 + 1] = HEX_CHARS[b & 0x0F];
}
return new String(hex);
}
}

@ -0,0 +1,69 @@
/*
* Copyright 2012-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.boot.developertools.tunnel.payload;
import java.io.IOException;
import java.nio.channels.WritableByteChannel;
import java.util.HashMap;
import java.util.Map;
import org.springframework.util.Assert;
/**
* Utility class that forwards {@link HttpTunnelPayload} instances to a destination
* channel, respecting sequence order.
*
* @author Phillip Webb
* @since 1.3.0
*/
public class HttpTunnelPayloadForwarder {
private static final int MAXIMUM_QUEUE_SIZE = 100;
private final WritableByteChannel targetChannel;
private long lastRequestSeq = 0;
private final Map<Long, HttpTunnelPayload> queue = new HashMap<Long, HttpTunnelPayload>();
/**
* Create a new {@link HttpTunnelPayloadForwarder} instance.
* @param targetChannel the target channel
*/
public HttpTunnelPayloadForwarder(WritableByteChannel targetChannel) {
Assert.notNull(targetChannel, "TargetChannel must not be null");
this.targetChannel = targetChannel;
}
public synchronized void forward(HttpTunnelPayload payload) throws IOException {
long seq = payload.getSequence();
if (this.lastRequestSeq != seq - 1) {
Assert.state(this.queue.size() < MAXIMUM_QUEUE_SIZE,
"Too many messages queued");
this.queue.put(seq, payload);
return;
}
payload.logOutgoing();
payload.writeTo(this.targetChannel);
this.lastRequestSeq = seq;
HttpTunnelPayload queuedItem = this.queue.get(seq + 1);
if (queuedItem != null) {
forward(queuedItem);
}
}
}

@ -0,0 +1,21 @@
/*
* Copyright 2012-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* Classes to deal with payloads sent over a HTTP tunnel.
*/
package org.springframework.boot.developertools.tunnel.payload;

@ -0,0 +1,486 @@
/*
* Copyright 2012-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.boot.developertools.tunnel.server;
import java.io.IOException;
import java.net.ConnectException;
import java.nio.ByteBuffer;
import java.nio.channels.ByteChannel;
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.Iterator;
import java.util.concurrent.atomic.AtomicLong;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.boot.developertools.tunnel.payload.HttpTunnelPayload;
import org.springframework.boot.developertools.tunnel.payload.HttpTunnelPayloadForwarder;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.server.ServerHttpAsyncRequestControl;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.util.Assert;
/**
* A server that can be used to tunnel TCP traffic over HTTP. Similar in design to the <a
* href="http://xmpp.org/extensions/xep-0124.html">Bidirectional-streams Over Synchronous
* HTTP (BOSH)</a> XMPP extension protocol, the server uses long polling with HTTP
* requests held open until a response is available. A typical traffic pattern would be as
* follows:
*
* <pre>
* [ CLIENT ] [ SERVER ]
* | (a) Initial empty request |
* |------------------------------}|
* | (b) Data I |
* --}|------------------------------}|---}
* | Response I (a) |
* {--|<------------------------------|{---
* | |
* | (c) Data II |
* --}|------------------------------}|---}
* | Response II (b) |
* {--|{------------------------------|{---
* . .
* . .
* </pre>
*
* Each incoming request is held open to be used to carry the next available response. The
* server will hold at most two connections open at any given time.
* <p>
* Requests should be made using HTTP GET or POST (depending if there is a payload), with
* any payload contained in the body. The following response codes can be returned from
* the server:
* <p>
* <table>
* <tr>
* <th>Status</th>
* <th>Meaning</th>
* </tr>
* <tr>
* <td>200 (OK)</td>
* <td>Data payload response.</td>
* </tr>
* <tr>
* <td>204 (No Content)</td>
* <td>The long poll has timed out and the client should start a new request.</td>
* </tr>
* <tr>
* <td>429 (Too many requests)</td>
* <td>There are already enough connections open, this one can be dropped.</td>
* </tr>
* <tr>
* <td>410 (Gone)</td>
* <td>The target server has disconnected.</td>
* </tr>
* </table>
* <p>
* Requests and responses that contain payloads include a {@code x-seq} header that
* contains a running sequence number (used to ensure data is applied in the correct
* order). The first request containing a payload should have a {@code x-seq} value of
* {@code 1}.
*
* @author Phillip Webb
* @since 1.3.0
* @see org.springframework.boot.developertools.tunnel.client.HttpTunnelConnection
*/
public class HttpTunnelServer {
private static final int SECONDS = 1000;
private static final int DEFAULT_LONG_POLL_TIMEOUT = 10 * SECONDS;
private static final long DEFAULT_DISCONNECT_TIMEOUT = 30 * SECONDS;
private static final MediaType DISCONNECT_MEDIA_TYPE = new MediaType("application",
"x-disconnect");
private static final Log logger = LogFactory.getLog(HttpTunnelServer.class);
private final TargetServerConnection serverConnection;
private int longPollTimeout = DEFAULT_LONG_POLL_TIMEOUT;
private long disconnectTimeout = DEFAULT_DISCONNECT_TIMEOUT;
private volatile ServerThread serverThread;
/**
* Creates a new {@link HttpTunnelServer} instance.
* @param serverConnection the connection to the target server
*/
public HttpTunnelServer(TargetServerConnection serverConnection) {
Assert.notNull(serverConnection, "ServerConnection must not be null");
this.serverConnection = serverConnection;
}
/**
* Handle an incoming HTTP connection.
* @param request the HTTP request
* @param response the HTTP response
* @throws IOException
*/
public void handle(ServerHttpRequest request, ServerHttpResponse response)
throws IOException {
handle(new HttpConnection(request, response));
}
/**
* Handle an incoming HTTP connection.
* @param httpConnection the HTTP connection
* @throws IOException
*/
protected void handle(HttpConnection httpConnection) throws IOException {
try {
getServerThread().handleIncomingHttp(httpConnection);
httpConnection.waitForResponse();
}
catch (ConnectException ex) {
httpConnection.respond(HttpStatus.GONE);
}
}
/**
* Returns the active server thread, creating and starting it if necessary.
* @return the {@code ServerThread} (never {@code null})
* @throws IOException
*/
protected ServerThread getServerThread() throws IOException {
synchronized (this) {
if (this.serverThread == null) {
ByteChannel channel = this.serverConnection.open(this.longPollTimeout);
this.serverThread = new ServerThread(channel);
this.serverThread.start();
}
return this.serverThread;
}
}
/**
* Called when the server thread exits.
*/
void clearServerThread() {
synchronized (this) {
this.serverThread = null;
}
}
/**
* Set the long poll timeout for the server.
* @param longPollTimeout the long poll timeout in milliseconds
*/
public void setLongPollTimeout(int longPollTimeout) {
Assert.isTrue(longPollTimeout > 0, "LongPollTimeout must be a positive value");
this.longPollTimeout = longPollTimeout;
}
/**
* Set the maximum amount of time to wait for a client before closing the connection.
* @param disconnectTimeout the disconnect timeout in milliseconds
*/
public void setDisconnectTimeout(long disconnectTimeout) {
Assert.isTrue(disconnectTimeout > 0, "DisconnectTimeout must be a positive value");
this.disconnectTimeout = disconnectTimeout;
}
/**
* The main server thread used to transfer tunnel traffic.
*/
protected class ServerThread extends Thread {
private final ByteChannel targetServer;
private final Deque<HttpConnection> httpConnections;
private final HttpTunnelPayloadForwarder payloadForwarder;
private boolean closed;
private AtomicLong responseSeq = new AtomicLong();
private long lastHttpRequestTime;
public ServerThread(ByteChannel targetServer) {
Assert.notNull(targetServer, "TargetServer must not be null");
this.targetServer = targetServer;
this.httpConnections = new ArrayDeque<HttpConnection>(2);
this.payloadForwarder = new HttpTunnelPayloadForwarder(targetServer);
}
@Override
public void run() {
try {
try {
readAndForwardTargetServerData();
}
catch (Exception ex) {
logger.trace("Unexpected exception from tunnel server", ex);
}
}
finally {
this.closed = true;
closeHttpConnections();
closeTargetServer();
HttpTunnelServer.this.clearServerThread();
}
}
private void readAndForwardTargetServerData() throws IOException {
while (this.targetServer.isOpen()) {
closeStaleHttpConnections();
ByteBuffer data = HttpTunnelPayload.getPayloadData(this.targetServer);
synchronized (this.httpConnections) {
if (data != null) {
HttpTunnelPayload payload = new HttpTunnelPayload(
this.responseSeq.incrementAndGet(), data);
payload.logIncoming();
HttpConnection connection = getOrWaitForHttpConnection();
connection.respond(payload);
}
}
}
}
private HttpConnection getOrWaitForHttpConnection() {
synchronized (this.httpConnections) {
HttpConnection httpConnection = this.httpConnections.pollFirst();
while (httpConnection == null) {
try {
this.httpConnections.wait(HttpTunnelServer.this.longPollTimeout);
}
catch (InterruptedException ex) {
closeHttpConnections();
}
httpConnection = this.httpConnections.pollFirst();
}
return httpConnection;
}
}
private void closeStaleHttpConnections() throws IOException {
checkNotDisconnected();
synchronized (this.httpConnections) {
Iterator<HttpConnection> iterator = this.httpConnections.iterator();
while (iterator.hasNext()) {
HttpConnection httpConnection = iterator.next();
if (httpConnection.isOlderThan(HttpTunnelServer.this.longPollTimeout)) {
httpConnection.respond(HttpStatus.NO_CONTENT);
iterator.remove();
}
}
}
}
private void checkNotDisconnected() {
long timeout = HttpTunnelServer.this.disconnectTimeout;
long duration = System.currentTimeMillis() - this.lastHttpRequestTime;
Assert.state(duration < timeout, "Disconnect timeout");
}
private void closeHttpConnections() {
synchronized (this.httpConnections) {
while (!this.httpConnections.isEmpty()) {
try {
this.httpConnections.removeFirst().respond(HttpStatus.GONE);
}
catch (Exception ex) {
logger.trace("Unable to close remote HTTP connection");
}
}
}
}
private void closeTargetServer() {
try {
this.targetServer.close();
}
catch (IOException ex) {
logger.trace("Unable to target server connection");
}
}
/**
* Handle an incoming {@link HttpConnection}.
* @param httpConnection the connection to handle.
* @throws IOException
*/
public void handleIncomingHttp(HttpConnection httpConnection) throws IOException {
if (this.closed) {
httpConnection.respond(HttpStatus.GONE);
}
synchronized (this.httpConnections) {
while (this.httpConnections.size() > 1) {
this.httpConnections.removeFirst().respond(
HttpStatus.TOO_MANY_REQUESTS);
}
this.lastHttpRequestTime = System.currentTimeMillis();
this.httpConnections.addLast(httpConnection);
this.httpConnections.notify();
}
forwardToTargetServer(httpConnection);
}
private void forwardToTargetServer(HttpConnection httpConnection)
throws IOException {
if (httpConnection.isDisconnectRequest()) {
this.targetServer.close();
interrupt();
}
ServerHttpRequest request = httpConnection.getRequest();
HttpTunnelPayload payload = HttpTunnelPayload.get(request);
if (payload != null) {
this.payloadForwarder.forward(payload);
}
}
}
/**
* Encapsulates a HTTP request/response pair.
*/
protected static class HttpConnection {
private final long createTime;
private final ServerHttpRequest request;
private final ServerHttpResponse response;
private ServerHttpAsyncRequestControl async;
private volatile boolean complete = false;
public HttpConnection(ServerHttpRequest request, ServerHttpResponse response) {
this.createTime = System.currentTimeMillis();
this.request = request;
this.response = response;
this.async = startAsync();
}
/**
* Start asynchronous support or if unavailble return {@code null} to cause
* {@link #waitForResponse()} to block.
* @return the async request control
*/
protected ServerHttpAsyncRequestControl startAsync() {
try {
// Try to use async to save blocking
ServerHttpAsyncRequestControl async = this.request
.getAsyncRequestControl(this.response);
async.start();
return async;
}
catch (Exception ex) {
return null;
}
}
/**
* Return the underlying request.
* @return the request
*/
public final ServerHttpRequest getRequest() {
return this.request;
}
/**
* Return the underlying response.
* @return the response
*/
protected final ServerHttpResponse getResponse() {
return this.response;
}
/**
* Determine if a connection is older than the specified time.
* @param time the time to check
* @return {@code true} if the request is older than the time
*/
public boolean isOlderThan(int time) {
long runningTime = System.currentTimeMillis() - this.createTime;
return (runningTime > time);
}
/**
* Cause the request to block or use asynchronous methods to wait until a response
* is available.
*/
public void waitForResponse() {
if (this.async == null) {
while (!this.complete) {
try {
synchronized (this) {
wait(1000);
}
}
catch (InterruptedException ex) {
}
}
}
}
/**
* Detect if the request is actually a signal to disconnect.
* @return if the request is a signal to disconnect
*/
public boolean isDisconnectRequest() {
return DISCONNECT_MEDIA_TYPE.equals(this.request.getHeaders()
.getContentType());
}
/**
* Send a HTTP status response.
* @param status the status to send
* @throws IOException
*/
public void respond(HttpStatus status) throws IOException {
Assert.notNull(status, "Status must not be null");
this.response.setStatusCode(status);
complete();
}
/**
* Send a payload response.
* @param payload the payload to send
* @throws IOException
*/
public void respond(HttpTunnelPayload payload) throws IOException {
Assert.notNull(payload, "Payload must not be null");
this.response.setStatusCode(HttpStatus.OK);
payload.assignTo(this.response);
complete();
}
/**
* Called when a request is complete.
*/
protected void complete() {
if (this.async != null) {
this.async.complete();
}
else {
synchronized (this) {
this.complete = true;
notifyAll();
}
}
}
}
}

@ -0,0 +1,51 @@
/*
* Copyright 2012-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.boot.developertools.tunnel.server;
import java.io.IOException;
import org.springframework.boot.developertools.remote.server.Handler;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.util.Assert;
/**
* Adapts a {@link HttpTunnelServer} to a {@link Handler}.
*
* @author Phillip Webb
* @since 1.3.0
*/
public class HttpTunnelServerHandler implements Handler {
private HttpTunnelServer server;
/**
* Create a new {@link HttpTunnelServerHandler} instance.
* @param server the server to adapt
*/
public HttpTunnelServerHandler(HttpTunnelServer server) {
Assert.notNull(server, "Server must not be null");
this.server = server;
}
@Override
public void handle(ServerHttpRequest request, ServerHttpResponse response)
throws IOException {
this.server.handle(request, response);
}
}

@ -0,0 +1,34 @@
/*
* Copyright 2012-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.boot.developertools.tunnel.server;
/**
* Strategy interface to provide access to a port (which may change if an existing
* connection is closed).
*
* @author Phillip Webb
* @since 1.3.0
*/
public interface PortProvider {
/**
* Return the port number
* @return the port number
*/
int getPort();
}

@ -0,0 +1,61 @@
/*
* Copyright 2012-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.boot.developertools.tunnel.server;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.boot.lang.UsesUnsafeJava;
import org.springframework.util.Assert;
/**
* {@link PortProvider} that provides the port being used by the Java remote debugging.
*
* @author Phillip Webb
*/
public class RemoteDebugPortProvider implements PortProvider {
private static final String JDWP_ADDRESS_PROPERTY = "sun.jdwp.listenerAddress";
private static final Log logger = LogFactory.getLog(RemoteDebugPortProvider.class);
@Override
public int getPort() {
Assert.state(isRemoteDebugRunning(), "Remote debug is not running");
return getRemoteDebugPort();
}
public static boolean isRemoteDebugRunning() {
return getRemoteDebugPort() != -1;
}
@UsesUnsafeJava
@SuppressWarnings("restriction")
private static int getRemoteDebugPort() {
String property = sun.misc.VMSupport.getAgentProperties().getProperty(
JDWP_ADDRESS_PROPERTY);
try {
if (property != null && property.contains(":")) {
return Integer.valueOf(property.split(":")[1]);
}
}
catch (Exception ex) {
logger.trace("Unable to get JDWP port from property value '" + property + "'");
}
return -1;
}
}

@ -0,0 +1,101 @@
/*
* Copyright 2012-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.boot.developertools.tunnel.server;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.ByteChannel;
import java.nio.channels.Channels;
import java.nio.channels.ReadableByteChannel;
import java.nio.channels.SocketChannel;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.util.Assert;
/**
* Socket based {@link TargetServerConnection}.
*
* @author Phillip Webb
* @since 1.3.0
*/
public class SocketTargetServerConnection implements TargetServerConnection {
private static final Log logger = LogFactory
.getLog(SocketTargetServerConnection.class);
private final PortProvider portProvider;
/**
* Create a new {@link SocketTargetServerConnection}.
* @param portProvider the port provider
*/
public SocketTargetServerConnection(PortProvider portProvider) {
Assert.notNull(portProvider, "PortProvider must not be null");
this.portProvider = portProvider;
}
@Override
public ByteChannel open(int socketTimeout) throws IOException {
SocketAddress address = new InetSocketAddress(this.portProvider.getPort());
logger.trace("Opening tunnel connection to target server on " + address);
SocketChannel channel = SocketChannel.open(address);
channel.socket().setSoTimeout(socketTimeout);
return new TimeoutAwareChannel(channel);
}
/**
* Wrapper to expose the {@link SocketChannel} in such a way that
* {@code SocketTimeoutExceptions} are still thrown from read methods.
*/
private static class TimeoutAwareChannel implements ByteChannel {
private final SocketChannel socketChannel;
private final ReadableByteChannel readChannel;
public TimeoutAwareChannel(SocketChannel socketChannel) throws IOException {
this.socketChannel = socketChannel;
this.readChannel = Channels.newChannel(socketChannel.socket()
.getInputStream());
}
@Override
public int read(ByteBuffer dst) throws IOException {
return this.readChannel.read(dst);
}
@Override
public int write(ByteBuffer src) throws IOException {
return this.socketChannel.write(src);
}
@Override
public boolean isOpen() {
return this.socketChannel.isOpen();
}
@Override
public void close() throws IOException {
this.socketChannel.close();
}
}
}

@ -0,0 +1,41 @@
/*
* Copyright 2012-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.boot.developertools.tunnel.server;
import org.springframework.util.Assert;
/**
* {@link PortProvider} for a static port that won't change.
*
* @author Phillip Webb
* @since 1.3.0
*/
public class StaticPortProvider implements PortProvider {
private final int port;
public StaticPortProvider(int port) {
Assert.isTrue(port > 0, "Port must be positive");
this.port = port;
}
@Override
public int getPort() {
return this.port;
}
}

@ -0,0 +1,38 @@
/*
* Copyright 2012-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.boot.developertools.tunnel.server;
import java.io.IOException;
import java.nio.channels.ByteChannel;
/**
* Manages the connection to the ultimate tunnel target server.
*
* @author Phillip Webb
* @since 1.3.0
*/
public interface TargetServerConnection {
/**
* Open a connection to the target server with the specified timeout.
* @param timeout the read timeout
* @return a {@link ByteChannel} providing read/write access to the server
* @throws IOException
*/
ByteChannel open(int timeout) throws IOException;
}

@ -0,0 +1,21 @@
/*
* Copyright 2012-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* Server side TCP tunnel support.
*/
package org.springframework.boot.developertools.tunnel.server;

@ -0,0 +1,166 @@
/*
* Copyright 2012-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.boot.developertools.tunnel.client;
import java.io.ByteArrayOutputStream;
import java.io.Closeable;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.Channels;
import java.nio.channels.WritableByteChannel;
import java.util.concurrent.Executor;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.springframework.boot.developertools.test.MockClientHttpRequestFactory;
import org.springframework.boot.developertools.tunnel.client.HttpTunnelConnection.TunnelChannel;
import org.springframework.http.HttpStatus;
import org.springframework.util.SocketUtils;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThan;
import static org.junit.Assert.assertThat;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
/**
* Tests for {@link HttpTunnelConnection}.
*
* @author Phillip Webb
* @author Rob Winch
*/
public class HttpTunnelConnectionTests {
@Rule
public ExpectedException thrown = ExpectedException.none();
private int port = SocketUtils.findAvailableTcpPort();
private String url;
private ByteArrayOutputStream incommingData;
private WritableByteChannel incomingChannel;
@Mock
private Closeable closeable;
private MockClientHttpRequestFactory requestFactory = new MockClientHttpRequestFactory();
@Before
public void setup() {
MockitoAnnotations.initMocks(this);
this.url = "http://localhost:" + this.port;
this.incommingData = new ByteArrayOutputStream();
this.incomingChannel = Channels.newChannel(this.incommingData);
}
@Test
public void urlMustNotBeNull() throws Exception {
this.thrown.expect(IllegalArgumentException.class);
this.thrown.expectMessage("URL must not be empty");
new HttpTunnelConnection(null, this.requestFactory);
}
@Test
public void urlMustNotBeEmpty() throws Exception {
this.thrown.expect(IllegalArgumentException.class);
this.thrown.expectMessage("URL must not be empty");
new HttpTunnelConnection("", this.requestFactory);
}
@Test
public void urlMustNotBeMalformed() throws Exception {
this.thrown.expect(IllegalArgumentException.class);
this.thrown.expectMessage("Malformed URL 'htttttp:///ttest'");
new HttpTunnelConnection("htttttp:///ttest", this.requestFactory);
}
@Test
public void requestFactoryMustNotBeNull() {
this.thrown.expect(IllegalArgumentException.class);
this.thrown.expectMessage("RequestFactory must not be null");
new HttpTunnelConnection(this.url, null);
}
@Test
public void closeTunnelChangesIsOpen() throws Exception {
this.requestFactory.willRespondAfterDelay(1000, HttpStatus.GONE);
WritableByteChannel channel = openTunnel(false);
assertThat(channel.isOpen(), equalTo(true));
channel.close();
assertThat(channel.isOpen(), equalTo(false));
}
@Test
public void closeTunnelCallsCloseableOnce() throws Exception {
this.requestFactory.willRespondAfterDelay(1000, HttpStatus.GONE);
WritableByteChannel channel = openTunnel(false);
verify(this.closeable, never()).close();
channel.close();
channel.close();
verify(this.closeable, times(1)).close();
}
@Test
public void typicalTraffic() throws Exception {
this.requestFactory.willRespond("hi", "=2", "=3");
TunnelChannel channel = openTunnel(true);
write(channel, "hello");
write(channel, "1+1");
write(channel, "1+2");
assertThat(this.incommingData.toString(), equalTo("hi=2=3"));
}
@Test
public void trafficWithLongPollTimeouts() throws Exception {
for (int i = 0; i < 10; i++) {
this.requestFactory.willRespond(HttpStatus.NO_CONTENT);
}
this.requestFactory.willRespond("hi");
TunnelChannel channel = openTunnel(true);
write(channel, "hello");
assertThat(this.incommingData.toString(), equalTo("hi"));
assertThat(this.requestFactory.getExecutedRequests().size(), greaterThan(10));
}
private void write(TunnelChannel channel, String string) throws IOException {
channel.write(ByteBuffer.wrap(string.getBytes()));
}
private TunnelChannel openTunnel(boolean singleThreaded) throws Exception {
HttpTunnelConnection connection = new HttpTunnelConnection(this.url,
this.requestFactory,
(singleThreaded ? new CurrentThreadExecutor() : null));
return connection.open(this.incomingChannel, this.closeable);
}
private static class CurrentThreadExecutor implements Executor {
@Override
public void execute(Runnable command) {
command.run();
}
}
}

@ -0,0 +1,199 @@
/*
* Copyright 2012-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.boot.developertools.tunnel.client;
import java.io.ByteArrayOutputStream;
import java.io.Closeable;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.Channels;
import java.nio.channels.SocketChannel;
import java.nio.channels.WritableByteChannel;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.springframework.util.SocketUtils;
import static org.hamcrest.Matchers.equalTo;
import static org.junit.Assert.assertThat;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
/**
* Tests for {@link TunnelClient}.
*
* @author Phillip Webb
*/
public class TunnelClientTests {
@Rule
public ExpectedException thrown = ExpectedException.none();
private int listenPort = SocketUtils.findAvailableTcpPort();
private MockTunnelConnection tunnelConnection = new MockTunnelConnection();
@Test
public void listenPortMustBePositive() throws Exception {
this.thrown.expect(IllegalArgumentException.class);
this.thrown.expectMessage("ListenPort must be positive");
new TunnelClient(0, this.tunnelConnection);
}
@Test
public void tunnelConnectionMustNotBeNull() throws Exception {
this.thrown.expect(IllegalArgumentException.class);
this.thrown.expectMessage("TunnelConnection must not be null");
new TunnelClient(1, null);
}
@Test
public void typicalTraffic() throws Exception {
TunnelClient client = new TunnelClient(this.listenPort, this.tunnelConnection);
client.start();
SocketChannel channel = SocketChannel
.open(new InetSocketAddress(this.listenPort));
channel.write(ByteBuffer.wrap("hello".getBytes()));
ByteBuffer buffer = ByteBuffer.allocate(5);
channel.read(buffer);
channel.close();
this.tunnelConnection.verifyWritten("hello");
assertThat(new String(buffer.array()), equalTo("olleh"));
}
@Test
public void socketChannelClosedTriggersTunnelClose() throws Exception {
TunnelClient client = new TunnelClient(this.listenPort, this.tunnelConnection);
client.start();
SocketChannel channel = SocketChannel
.open(new InetSocketAddress(this.listenPort));
channel.close();
client.getServerThread().stopAcceptingConnections();
client.getServerThread().join(2000);
assertThat(this.tunnelConnection.getOpenedTimes(), equalTo(1));
assertThat(this.tunnelConnection.isOpen(), equalTo(false));
}
@Test
public void stopTriggersTunnelClose() throws Exception {
TunnelClient client = new TunnelClient(this.listenPort, this.tunnelConnection);
client.start();
SocketChannel channel = SocketChannel
.open(new InetSocketAddress(this.listenPort));
client.stop();
assertThat(this.tunnelConnection.getOpenedTimes(), equalTo(1));
assertThat(this.tunnelConnection.isOpen(), equalTo(false));
assertThat(channel.read(ByteBuffer.allocate(1)), equalTo(-1));
}
@Test
public void addListener() throws Exception {
TunnelClient client = new TunnelClient(this.listenPort, this.tunnelConnection);
TunnelClientListener listener = mock(TunnelClientListener.class);
client.addListener(listener);
client.start();
SocketChannel channel = SocketChannel
.open(new InetSocketAddress(this.listenPort));
channel.close();
client.getServerThread().stopAcceptingConnections();
client.getServerThread().join(2000);
verify(listener).onOpen(any(SocketChannel.class));
verify(listener).onClose(any(SocketChannel.class));
}
private static class MockTunnelConnection implements TunnelConnection {
private final ByteArrayOutputStream written = new ByteArrayOutputStream();
private boolean open;
private int openedTimes;
@Override
public WritableByteChannel open(WritableByteChannel incomingChannel,
Closeable closeable) throws Exception {
this.openedTimes++;
this.open = true;
return new TunnelChannel(incomingChannel, closeable);
}
public void verifyWritten(String expected) {
verifyWritten(expected.getBytes());
}
public void verifyWritten(byte[] expected) {
synchronized (this.written) {
assertThat(this.written.toByteArray(), equalTo(expected));
this.written.reset();
}
}
public boolean isOpen() {
return this.open;
}
public int getOpenedTimes() {
return this.openedTimes;
}
private class TunnelChannel implements WritableByteChannel {
private final WritableByteChannel incomingChannel;
private final Closeable closeable;
public TunnelChannel(WritableByteChannel incomingChannel, Closeable closeable) {
this.incomingChannel = incomingChannel;
this.closeable = closeable;
}
@Override
public boolean isOpen() {
return MockTunnelConnection.this.open;
}
@Override
public void close() throws IOException {
MockTunnelConnection.this.open = false;
this.closeable.close();
}
@Override
public int write(ByteBuffer src) throws IOException {
int remaining = src.remaining();
ByteArrayOutputStream stream = new ByteArrayOutputStream();
Channels.newChannel(stream).write(src);
byte[] bytes = stream.toByteArray();
synchronized (MockTunnelConnection.this.written) {
MockTunnelConnection.this.written.write(bytes);
}
byte[] reversed = new byte[bytes.length];
for (int i = 0; i < reversed.length; i++) {
reversed[i] = bytes[bytes.length - 1 - i];
}
this.incomingChannel.write(ByteBuffer.wrap(reversed));
return remaining;
}
}
}
}

@ -0,0 +1,85 @@
/*
* Copyright 2012-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.boot.developertools.tunnel.payload;
import java.io.ByteArrayOutputStream;
import java.nio.ByteBuffer;
import java.nio.channels.Channels;
import java.nio.channels.WritableByteChannel;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import static org.hamcrest.Matchers.equalTo;
import static org.junit.Assert.assertThat;
/**
* Tests for {@link HttpTunnelPayloadForwarder}.
*
* @author Phillip Webb
*/
public class HttpTunnelPayloadForwarderTests {
@Rule
public ExpectedException thrown = ExpectedException.none();
@Test
public void targetChannelMustNoBeNull() throws Exception {
this.thrown.expect(IllegalArgumentException.class);
this.thrown.expectMessage("TargetChannel must not be null");
new HttpTunnelPayloadForwarder(null);
}
@Test
public void forwardInSequence() throws Exception {
ByteArrayOutputStream out = new ByteArrayOutputStream();
WritableByteChannel channel = Channels.newChannel(out);
HttpTunnelPayloadForwarder forwarder = new HttpTunnelPayloadForwarder(channel);
forwarder.forward(payload(1, "he"));
forwarder.forward(payload(2, "ll"));
forwarder.forward(payload(3, "o"));
assertThat(out.toByteArray(), equalTo("hello".getBytes()));
}
@Test
public void forwardOutOfSequence() throws Exception {
ByteArrayOutputStream out = new ByteArrayOutputStream();
WritableByteChannel channel = Channels.newChannel(out);
HttpTunnelPayloadForwarder forwarder = new HttpTunnelPayloadForwarder(channel);
forwarder.forward(payload(3, "o"));
forwarder.forward(payload(2, "ll"));
forwarder.forward(payload(1, "he"));
assertThat(out.toByteArray(), equalTo("hello".getBytes()));
}
@Test
public void overflow() throws Exception {
WritableByteChannel channel = Channels.newChannel(new ByteArrayOutputStream());
HttpTunnelPayloadForwarder forwarder = new HttpTunnelPayloadForwarder(channel);
this.thrown.expect(IllegalStateException.class);
this.thrown.expectMessage("Too many messages queued");
for (int i = 2; i < 130; i++) {
forwarder.forward(payload(i, "data" + i));
}
}
private HttpTunnelPayload payload(long sequence, String data) {
return new HttpTunnelPayload(sequence, ByteBuffer.wrap(data.getBytes()));
}
}

@ -0,0 +1,151 @@
/*
* Copyright 2012-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.boot.developertools.tunnel.payload;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.net.SocketTimeoutException;
import java.nio.ByteBuffer;
import java.nio.channels.Channels;
import java.nio.channels.ReadableByteChannel;
import java.nio.channels.WritableByteChannel;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.springframework.http.HttpInputMessage;
import org.springframework.http.HttpOutputMessage;
import org.springframework.http.server.ServletServerHttpRequest;
import org.springframework.http.server.ServletServerHttpResponse;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.nullValue;
import static org.junit.Assert.assertThat;
import static org.mockito.BDDMockito.given;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.mock;
/**
* Tests for {@link HttpTunnelPayload}.
*
* @author Phillip Webb
*/
public class HttpTunnelPayloadTests {
@Rule
public ExpectedException thrown = ExpectedException.none();
@Test
public void sequenceMustBePositive() throws Exception {
this.thrown.expect(IllegalArgumentException.class);
this.thrown.expectMessage("Sequence must be positive");
new HttpTunnelPayload(0, ByteBuffer.allocate(1));
}
@Test
public void dataMustNotBeNull() throws Exception {
this.thrown.expect(IllegalArgumentException.class);
this.thrown.expectMessage("Data must not be null");
new HttpTunnelPayload(1, null);
}
@Test
public void getSequence() throws Exception {
HttpTunnelPayload payload = new HttpTunnelPayload(1, ByteBuffer.allocate(1));
assertThat(payload.getSequence(), equalTo(1L));
}
@Test
public void getData() throws Exception {
ByteBuffer data = ByteBuffer.wrap("hello".getBytes());
HttpTunnelPayload payload = new HttpTunnelPayload(1, data);
assertThat(getData(payload), equalTo(data.array()));
}
@Test
public void assignTo() throws Exception {
ByteBuffer data = ByteBuffer.wrap("hello".getBytes());
HttpTunnelPayload payload = new HttpTunnelPayload(2, data);
MockHttpServletResponse servletResponse = new MockHttpServletResponse();
HttpOutputMessage response = new ServletServerHttpResponse(servletResponse);
payload.assignTo(response);
assertThat(servletResponse.getHeader("x-seq"), equalTo("2"));
assertThat(servletResponse.getContentAsString(), equalTo("hello"));
}
@Test
public void getNoData() throws Exception {
MockHttpServletRequest servletRequest = new MockHttpServletRequest();
HttpInputMessage request = new ServletServerHttpRequest(servletRequest);
HttpTunnelPayload payload = HttpTunnelPayload.get(request);
assertThat(payload, nullValue());
}
@Test
public void getWithMissingHeader() throws Exception {
MockHttpServletRequest servletRequest = new MockHttpServletRequest();
servletRequest.setContent("hello".getBytes());
HttpInputMessage request = new ServletServerHttpRequest(servletRequest);
this.thrown.expect(IllegalStateException.class);
this.thrown.expectMessage("Missing sequence header");
HttpTunnelPayload.get(request);
}
@Test
public void getWithData() throws Exception {
MockHttpServletRequest servletRequest = new MockHttpServletRequest();
servletRequest.setContent("hello".getBytes());
servletRequest.addHeader("x-seq", 123);
HttpInputMessage request = new ServletServerHttpRequest(servletRequest);
HttpTunnelPayload payload = HttpTunnelPayload.get(request);
assertThat(payload.getSequence(), equalTo(123L));
assertThat(getData(payload), equalTo("hello".getBytes()));
}
@Test
public void getPayloadData() throws Exception {
ReadableByteChannel channel = Channels.newChannel(new ByteArrayInputStream(
"hello".getBytes()));
ByteBuffer payloadData = HttpTunnelPayload.getPayloadData(channel);
ByteArrayOutputStream out = new ByteArrayOutputStream();
WritableByteChannel writeChannel = Channels.newChannel(out);
while (payloadData.hasRemaining()) {
writeChannel.write(payloadData);
}
assertThat(out.toByteArray(), equalTo("hello".getBytes()));
}
@Test
public void getPayloadDataWithTimeout() throws Exception {
ReadableByteChannel channel = mock(ReadableByteChannel.class);
given(channel.read(any(ByteBuffer.class)))
.willThrow(new SocketTimeoutException());
ByteBuffer payload = HttpTunnelPayload.getPayloadData(channel);
assertThat(payload, nullValue());
}
private byte[] getData(HttpTunnelPayload payload) throws IOException {
ByteArrayOutputStream out = new ByteArrayOutputStream();
WritableByteChannel channel = Channels.newChannel(out);
payload.writeTo(channel);
return out.toByteArray();
}
}

@ -0,0 +1,55 @@
/*
* Copyright 2012-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.boot.developertools.tunnel.server;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
/**
* Tests for {@link HttpTunnelServerHandler}.
*
* @author Phillip Webb
*/
public class HttpTunnelServerHandlerTests {
@Rule
public ExpectedException thrown = ExpectedException.none();
@Test
public void serverMustNotBeNull() throws Exception {
this.thrown.expect(IllegalArgumentException.class);
this.thrown.expectMessage("Server must not be null");
new HttpTunnelServerHandler(null);
}
@Test
public void handleDelegatesToServer() throws Exception {
HttpTunnelServer server = mock(HttpTunnelServer.class);
HttpTunnelServerHandler handler = new HttpTunnelServerHandler(server);
ServerHttpRequest request = mock(ServerHttpRequest.class);
ServerHttpResponse response = mock(ServerHttpResponse.class);
handler.handle(request, response);
verify(server).handle(request, response);
}
}

@ -0,0 +1,480 @@
/*
* Copyright 2012-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.boot.developertools.tunnel.server;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.net.SocketTimeoutException;
import java.nio.ByteBuffer;
import java.nio.channels.ByteChannel;
import java.nio.channels.Channels;
import java.util.concurrent.BlockingDeque;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import org.springframework.boot.developertools.tunnel.payload.HttpTunnelPayload;
import org.springframework.boot.developertools.tunnel.server.HttpTunnelServer.HttpConnection;
import org.springframework.http.HttpStatus;
import org.springframework.http.server.ServerHttpAsyncRequestControl;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.http.server.ServletServerHttpRequest;
import org.springframework.http.server.ServletServerHttpResponse;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
import static org.hamcrest.Matchers.equalTo;
import static org.junit.Assert.assertThat;
import static org.mockito.BDDMockito.given;
import static org.mockito.Matchers.anyInt;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
/**
* Tests for {@link HttpTunnelServer}.
*
* @author Phillip Webb
*/
public class HttpTunnelServerTests {
private static final int DEFAULT_LONG_POLL_TIMEOUT = 10000;
private static final byte[] NO_DATA = {};
private static final String SEQ_HEADER = "x-seq";
@Rule
public ExpectedException thrown = ExpectedException.none();
private HttpTunnelServer server;
@Mock
private TargetServerConnection serverConnection;
private MockHttpServletRequest servletRequest;
private MockHttpServletResponse servletResponse;
private ServerHttpRequest request;
private ServerHttpResponse response;
private MockServerChannel serverChannel;
@Before
public void setup() throws Exception {
MockitoAnnotations.initMocks(this);
this.server = new HttpTunnelServer(this.serverConnection);
given(this.serverConnection.open(anyInt())).willAnswer(new Answer<ByteChannel>() {
@Override
public ByteChannel answer(InvocationOnMock invocation) throws Throwable {
MockServerChannel channel = HttpTunnelServerTests.this.serverChannel;
channel.setTimeout((Integer) invocation.getArguments()[0]);
return channel;
}
});
this.servletRequest = new MockHttpServletRequest();
this.servletRequest.setAsyncSupported(true);
this.servletResponse = new MockHttpServletResponse();
this.request = new ServletServerHttpRequest(this.servletRequest);
this.response = new ServletServerHttpResponse(this.servletResponse);
this.serverChannel = new MockServerChannel();
}
@Test
public void serverConnectionIsRequired() throws Exception {
this.thrown.expect(IllegalArgumentException.class);
this.thrown.expectMessage("ServerConnection must not be null");
new HttpTunnelServer(null);
}
@Test
public void serverConnectedOnFirstRequest() throws Exception {
verify(this.serverConnection, never()).open(anyInt());
this.server.handle(this.request, this.response);
verify(this.serverConnection, times(1)).open(DEFAULT_LONG_POLL_TIMEOUT);
}
@Test
public void longPollTimeout() throws Exception {
this.server.setLongPollTimeout(800);
this.server.handle(this.request, this.response);
verify(this.serverConnection, times(1)).open(800);
}
@Test
public void longPollTimeoutMustBePositiveValue() throws Exception {
this.thrown.expect(IllegalArgumentException.class);
this.thrown.expectMessage("LongPollTimeout must be a positive value");
this.server.setLongPollTimeout(0);
}
@Test
public void initialRequestIsSentToServer() throws Exception {
this.servletRequest.addHeader(SEQ_HEADER, "1");
this.servletRequest.setContent("hello".getBytes());
this.server.handle(this.request, this.response);
this.serverChannel.disconnect();
this.server.getServerThread().join();
this.serverChannel.verifyReceived("hello");
}
@Test
public void intialRequestIsUsedForFirstServerResponse() throws Exception {
this.servletRequest.addHeader(SEQ_HEADER, "1");
this.servletRequest.setContent("hello".getBytes());
this.server.handle(this.request, this.response);
System.out.println("sending");
this.serverChannel.send("hello");
this.serverChannel.disconnect();
this.server.getServerThread().join();
assertThat(this.servletResponse.getContentAsString(), equalTo("hello"));
this.serverChannel.verifyReceived("hello");
}
@Test
public void initialRequestHasNoPayload() throws Exception {
this.server.handle(this.request, this.response);
this.serverChannel.disconnect();
this.server.getServerThread().join();
this.serverChannel.verifyReceived(NO_DATA);
}
@Test
public void typicalReqestResponseTraffic() throws Exception {
MockHttpConnection h1 = new MockHttpConnection();
this.server.handle(h1);
MockHttpConnection h2 = new MockHttpConnection("hello server", 1);
this.server.handle(h2);
this.serverChannel.verifyReceived("hello server");
this.serverChannel.send("hello client");
h1.verifyReceived("hello client", 1);
MockHttpConnection h3 = new MockHttpConnection("1+1", 2);
this.server.handle(h3);
this.serverChannel.send("=2");
h2.verifyReceived("=2", 2);
MockHttpConnection h4 = new MockHttpConnection("1+2", 3);
this.server.handle(h4);
this.serverChannel.send("=3");
h3.verifyReceived("=3", 3);
this.serverChannel.disconnect();
this.server.getServerThread().join();
}
@Test
public void clientIsAwareOfServerClose() throws Exception {
MockHttpConnection h1 = new MockHttpConnection("1", 1);
this.server.handle(h1);
this.serverChannel.disconnect();
this.server.getServerThread().join();
assertThat(h1.getServletResponse().getStatus(), equalTo(410));
}
@Test
public void clientCanCloseServer() throws Exception {
MockHttpConnection h1 = new MockHttpConnection();
this.server.handle(h1);
MockHttpConnection h2 = new MockHttpConnection("DISCONNECT", 1);
h2.getServletRequest().addHeader("Content-Type", "application/x-disconnect");
this.server.handle(h2);
this.server.getServerThread().join();
assertThat(h1.getServletResponse().getStatus(), equalTo(410));
assertThat(this.serverChannel.isOpen(), equalTo(false));
}
@Test
public void neverMoreThanTwoHttpConnections() throws Exception {
MockHttpConnection h1 = new MockHttpConnection();
this.server.handle(h1);
MockHttpConnection h2 = new MockHttpConnection("1", 2);
this.server.handle(h2);
MockHttpConnection h3 = new MockHttpConnection("2", 3);
this.server.handle(h3);
h1.waitForResponse();
assertThat(h1.getServletResponse().getStatus(), equalTo(429));
this.serverChannel.disconnect();
this.server.getServerThread().join();
}
@Test
public void requestRecievedOutOfOrder() throws Exception {
MockHttpConnection h1 = new MockHttpConnection();
MockHttpConnection h2 = new MockHttpConnection("1+2", 1);
MockHttpConnection h3 = new MockHttpConnection("+3", 2);
this.server.handle(h1);
this.server.handle(h3);
this.server.handle(h2);
this.serverChannel.verifyReceived("1+2+3");
this.serverChannel.disconnect();
this.server.getServerThread().join();
}
@Test
public void httpConnectionsAreClosedAfterLongPollTimeout() throws Exception {
this.server.setDisconnectTimeout(1000);
this.server.setLongPollTimeout(100);
MockHttpConnection h1 = new MockHttpConnection();
this.server.handle(h1);
MockHttpConnection h2 = new MockHttpConnection();
this.server.handle(h2);
Thread.sleep(400);
this.serverChannel.disconnect();
this.server.getServerThread().join();
assertThat(h1.getServletResponse().getStatus(), equalTo(204));
assertThat(h2.getServletResponse().getStatus(), equalTo(204));
}
@Test
public void disconnectTimeout() throws Exception {
this.server.setDisconnectTimeout(100);
this.server.setLongPollTimeout(100);
MockHttpConnection h1 = new MockHttpConnection();
this.server.handle(h1);
this.serverChannel.send("hello");
this.server.getServerThread().join();
assertThat(this.serverChannel.isOpen(), equalTo(false));
}
@Test
public void disconnectTimeoutMustBePositive() throws Exception {
this.thrown.expect(IllegalArgumentException.class);
this.thrown.expectMessage("DisconnectTimeout must be a positive value");
this.server.setDisconnectTimeout(0);
}
@Test
public void httpConnectionRespondWithPayload() throws Exception {
HttpConnection connection = new HttpConnection(this.request, this.response);
connection.waitForResponse();
connection.respond(new HttpTunnelPayload(1, ByteBuffer.wrap("hello".getBytes())));
assertThat(this.servletResponse.getStatus(), equalTo(200));
assertThat(this.servletResponse.getContentAsString(), equalTo("hello"));
assertThat(this.servletResponse.getHeader(SEQ_HEADER), equalTo("1"));
}
@Test
public void httpConnectionRespondWithStatus() throws Exception {
HttpConnection connection = new HttpConnection(this.request, this.response);
connection.waitForResponse();
connection.respond(HttpStatus.I_AM_A_TEAPOT);
assertThat(this.servletResponse.getStatus(), equalTo(418));
assertThat(this.servletResponse.getContentLength(), equalTo(0));
}
@Test
public void httpConnectionAsync() throws Exception {
ServerHttpAsyncRequestControl async = mock(ServerHttpAsyncRequestControl.class);
ServerHttpRequest request = mock(ServerHttpRequest.class);
given(request.getAsyncRequestControl(this.response)).willReturn(async);
HttpConnection connection = new HttpConnection(request, this.response);
connection.waitForResponse();
verify(async).start();
connection.respond(HttpStatus.NO_CONTENT);
verify(async).complete();
}
@Test
public void httpConnectionNonAsync() throws Exception {
testHttpConnectionNonAsync(0);
testHttpConnectionNonAsync(100);
}
private void testHttpConnectionNonAsync(long sleepBeforeResponse) throws IOException,
InterruptedException {
ServerHttpRequest request = mock(ServerHttpRequest.class);
given(request.getAsyncRequestControl(this.response)).willThrow(
new IllegalArgumentException());
final HttpConnection connection = new HttpConnection(request, this.response);
final AtomicBoolean responded = new AtomicBoolean();
Thread connectionThread = new Thread() {
@Override
public void run() {
connection.waitForResponse();
responded.set(true);
}
};
connectionThread.start();
assertThat(responded.get(), equalTo(false));
Thread.sleep(sleepBeforeResponse);
connection.respond(HttpStatus.NO_CONTENT);
connectionThread.join();
assertThat(responded.get(), equalTo(true));
}
@Test
public void httpConnectionRunning() throws Exception {
HttpConnection connection = new HttpConnection(this.request, this.response);
assertThat(connection.isOlderThan(100), equalTo(false));
Thread.sleep(200);
assertThat(connection.isOlderThan(100), equalTo(true));
}
/**
* Mock {@link ByteChannel} used to simulate the server connection.
*/
private static class MockServerChannel implements ByteChannel {
private static final ByteBuffer DISCONNECT = ByteBuffer.wrap(NO_DATA);
private int timeout;
private BlockingDeque<ByteBuffer> outgoing = new LinkedBlockingDeque<ByteBuffer>();
private ByteArrayOutputStream written = new ByteArrayOutputStream();
private AtomicBoolean open = new AtomicBoolean(true);
public void setTimeout(int timeout) {
this.timeout = timeout;
}
public void send(String content) {
send(content.getBytes());
}
public void send(byte[] bytes) {
this.outgoing.addLast(ByteBuffer.wrap(bytes));
}
public void disconnect() {
this.outgoing.addLast(DISCONNECT);
}
public void verifyReceived(String expected) {
verifyReceived(expected.getBytes());
}
public void verifyReceived(byte[] expected) {
synchronized (this.written) {
assertThat(this.written.toByteArray(), equalTo(expected));
this.written.reset();
}
}
@Override
public int read(ByteBuffer dst) throws IOException {
try {
ByteBuffer bytes = this.outgoing.pollFirst(this.timeout,
TimeUnit.MILLISECONDS);
if (bytes == null) {
throw new SocketTimeoutException();
}
if (bytes == DISCONNECT) {
this.open.set(false);
return -1;
}
int initialRemaining = dst.remaining();
bytes.limit(Math.min(bytes.limit(), initialRemaining));
dst.put(bytes);
bytes.limit(bytes.capacity());
return initialRemaining - dst.remaining();
}
catch (InterruptedException ex) {
throw new IllegalStateException(ex);
}
}
@Override
public int write(ByteBuffer src) throws IOException {
int remaining = src.remaining();
synchronized (this.written) {
Channels.newChannel(this.written).write(src);
}
return remaining;
}
@Override
public boolean isOpen() {
return this.open.get();
}
@Override
public void close() throws IOException {
this.open.set(false);
}
}
/**
* Mock {@link HttpConnection}.
*/
private static class MockHttpConnection extends HttpConnection {
public MockHttpConnection() {
super(new ServletServerHttpRequest(new MockHttpServletRequest()),
new ServletServerHttpResponse(new MockHttpServletResponse()));
}
public MockHttpConnection(String content, int seq) {
this();
MockHttpServletRequest request = getServletRequest();
request.setContent(content.getBytes());
request.addHeader(SEQ_HEADER, String.valueOf(seq));
}
@Override
protected ServerHttpAsyncRequestControl startAsync() {
getServletRequest().setAsyncSupported(true);
return super.startAsync();
}
@Override
protected void complete() {
super.complete();
getServletResponse().setCommitted(true);
}
public MockHttpServletRequest getServletRequest() {
return (MockHttpServletRequest) ((ServletServerHttpRequest) getRequest())
.getServletRequest();
}
public MockHttpServletResponse getServletResponse() {
return (MockHttpServletResponse) ((ServletServerHttpResponse) getResponse())
.getServletResponse();
}
public void verifyReceived(String expectedContent, int expectedSeq)
throws Exception {
waitForServletResponse();
MockHttpServletResponse resp = getServletResponse();
assertThat(resp.getContentAsString(), equalTo(expectedContent));
assertThat(resp.getHeader(SEQ_HEADER), equalTo(String.valueOf(expectedSeq)));
}
public void waitForServletResponse() throws InterruptedException {
while (!getServletResponse().isCommitted()) {
Thread.sleep(10);
}
}
}
}

@ -0,0 +1,178 @@
/*
* Copyright 2012-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.boot.developertools.tunnel.server;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.SocketTimeoutException;
import java.nio.ByteBuffer;
import java.nio.channels.ByteChannel;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import org.junit.Before;
import org.junit.Test;
import org.springframework.util.SocketUtils;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
import static org.hamcrest.Matchers.lessThan;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.fail;
/**
* Tests for {@link SocketTargetServerConnection}.
*
* @author Phillip Webb
*/
public class SocketTargetServerConnectionTests {
private static final int DEFAULT_TIMEOUT = 1000;
private int port;
private MockServer server;
private SocketTargetServerConnection connection;
@Before
public void setup() throws IOException {
this.port = SocketUtils.findAvailableTcpPort();
this.server = new MockServer(this.port);
StaticPortProvider portProvider = new StaticPortProvider(this.port);
this.connection = new SocketTargetServerConnection(portProvider);
}
@Test
public void readData() throws Exception {
this.server.willSend("hello".getBytes());
this.server.start();
ByteChannel channel = this.connection.open(DEFAULT_TIMEOUT);
ByteBuffer buffer = ByteBuffer.allocate(5);
channel.read(buffer);
assertThat(buffer.array(), equalTo("hello".getBytes()));
}
@Test
public void writeData() throws Exception {
this.server.expect("hello".getBytes());
this.server.start();
ByteChannel channel = this.connection.open(DEFAULT_TIMEOUT);
ByteBuffer buffer = ByteBuffer.wrap("hello".getBytes());
channel.write(buffer);
this.server.closeAndVerify();
}
@Test
public void timeout() throws Exception {
this.server.delay(1000);
this.server.start();
ByteChannel channel = this.connection.open(10);
long startTime = System.currentTimeMillis();
try {
channel.read(ByteBuffer.allocate(5));
fail("No socket timeout thrown");
}
catch (SocketTimeoutException ex) {
// Expected
long runTime = System.currentTimeMillis() - startTime;
assertThat(runTime, greaterThanOrEqualTo(10L));
assertThat(runTime, lessThan(10000L));
}
}
private static class MockServer {
private ServerSocketChannel serverSocket;
private byte[] send;
private byte[] expect;
private int delay;
private ByteBuffer actualRead;
private ServerThread thread;
public MockServer(int port) throws IOException {
this.serverSocket = ServerSocketChannel.open();
this.serverSocket.bind(new InetSocketAddress(port));
}
public void delay(int delay) {
this.delay = delay;
}
public void willSend(byte[] send) {
this.send = send;
}
public void expect(byte[] expect) {
this.expect = expect;
}
public void start() {
this.thread = new ServerThread();
this.thread.start();
}
public void closeAndVerify() throws InterruptedException {
close();
assertThat(this.actualRead.array(), equalTo(this.expect));
}
public void close() throws InterruptedException {
while (this.thread.isAlive()) {
Thread.sleep(10);
}
}
private class ServerThread extends Thread {
@Override
public void run() {
try {
SocketChannel channel = MockServer.this.serverSocket.accept();
Thread.sleep(MockServer.this.delay);
if (MockServer.this.send != null) {
ByteBuffer buffer = ByteBuffer.wrap(MockServer.this.send);
while (buffer.hasRemaining()) {
channel.write(buffer);
}
}
if (MockServer.this.expect != null) {
ByteBuffer buffer = ByteBuffer
.allocate(MockServer.this.expect.length);
while (buffer.hasRemaining()) {
channel.read(buffer);
}
MockServer.this.actualRead = buffer;
}
channel.close();
}
catch (Exception ex) {
ex.printStackTrace();
fail();
}
}
}
}
}

@ -0,0 +1,49 @@
/*
* Copyright 2012-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.boot.developertools.tunnel.server;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import static org.hamcrest.Matchers.equalTo;
import static org.junit.Assert.assertThat;
/**
* Tests for {@link StaticPortProvider}.
*
* @author Phillip Webb
*/
public class StaticPortProviderTests {
@Rule
public ExpectedException thrown = ExpectedException.none();
@Test
public void portMustBePostive() throws Exception {
this.thrown.expect(IllegalArgumentException.class);
this.thrown.expectMessage("Port must be positive");
new StaticPortProvider(0);
}
@Test
public void getPort() throws Exception {
StaticPortProvider provider = new StaticPortProvider(123);
assertThat(provider.getPort(), equalTo(123));
}
}
Loading…
Cancel
Save