Merge pull request #16535 from ayudovin

* pr/16535:
  Polish "Fix connection timeout configuration for Netty"
  Fix connection timeout configuration for Netty
  Chain predicates in PropertyMapper when methods

Closes gh-16535
pull/17391/head
Phillip Webb 5 years ago
commit 5e3438f095

@ -24,7 +24,6 @@ import org.springframework.boot.autoconfigure.web.ServerProperties;
import org.springframework.boot.cloud.CloudPlatform;
import org.springframework.boot.context.properties.PropertyMapper;
import org.springframework.boot.web.embedded.netty.NettyReactiveWebServerFactory;
import org.springframework.boot.web.embedded.netty.NettyServerCustomizer;
import org.springframework.boot.web.server.WebServerFactoryCustomizer;
import org.springframework.core.Ordered;
import org.springframework.core.env.Environment;
@ -58,11 +57,11 @@ public class NettyWebServerFactoryCustomizer
@Override
public void customize(NettyReactiveWebServerFactory factory) {
factory.setUseForwardHeaders(getOrDeduceUseForwardHeaders(this.serverProperties, this.environment));
PropertyMapper propertyMapper = PropertyMapper.get();
propertyMapper.from(this.serverProperties::getMaxHttpHeaderSize).whenNonNull().asInt(DataSize::toBytes)
PropertyMapper propertyMapper = PropertyMapper.get().alwaysApplyingWhenNonNull();
propertyMapper.from(this.serverProperties::getMaxHttpHeaderSize)
.to((maxHttpRequestHeaderSize) -> customizeMaxHttpHeaderSize(factory, maxHttpRequestHeaderSize));
propertyMapper.from(this.serverProperties::getConnectionTimeout).whenNonNull().asInt(Duration::toMillis)
.to((duration) -> factory.addServerCustomizers(getConnectionTimeOutCustomizer(duration)));
propertyMapper.from(this.serverProperties::getConnectionTimeout)
.to((connectionTimeout) -> customizeConnectionTimeout(factory, connectionTimeout));
}
private boolean getOrDeduceUseForwardHeaders(ServerProperties serverProperties, Environment environment) {
@ -73,14 +72,17 @@ public class NettyWebServerFactoryCustomizer
return platform != null && platform.isUsingForwardHeaders();
}
private void customizeMaxHttpHeaderSize(NettyReactiveWebServerFactory factory, Integer maxHttpHeaderSize) {
factory.addServerCustomizers((NettyServerCustomizer) (httpServer) -> httpServer.httpRequestDecoder(
(httpRequestDecoderSpec) -> httpRequestDecoderSpec.maxHeaderSize(maxHttpHeaderSize)));
private void customizeMaxHttpHeaderSize(NettyReactiveWebServerFactory factory, DataSize maxHttpHeaderSize) {
factory.addServerCustomizers((httpServer) -> httpServer.httpRequestDecoder(
(httpRequestDecoderSpec) -> httpRequestDecoderSpec.maxHeaderSize((int) maxHttpHeaderSize.toBytes())));
}
private NettyServerCustomizer getConnectionTimeOutCustomizer(int duration) {
return (httpServer) -> httpServer.tcpConfiguration(
(tcpServer) -> tcpServer.selectorOption(ChannelOption.CONNECT_TIMEOUT_MILLIS, duration));
private void customizeConnectionTimeout(NettyReactiveWebServerFactory factory, Duration connectionTimeout) {
if (!connectionTimeout.isZero()) {
long timeoutMillis = connectionTimeout.isNegative() ? 0 : connectionTimeout.toMillis();
factory.addServerCustomizers((httpServer) -> httpServer.tcpConfiguration((tcpServer) -> tcpServer
.selectorOption(ChannelOption.CONNECT_TIMEOUT_MILLIS, (int) timeoutMillis)));
}
}
}

@ -16,21 +16,38 @@
package org.springframework.boot.autoconfigure.web.embedded;
import java.time.Duration;
import java.util.Map;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.ChannelOption;
import org.junit.Before;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.MockitoAnnotations;
import reactor.netty.http.server.HttpServer;
import reactor.netty.tcp.TcpServer;
import org.springframework.boot.autoconfigure.web.ServerProperties;
import org.springframework.boot.context.properties.source.ConfigurationPropertySources;
import org.springframework.boot.web.embedded.netty.NettyReactiveWebServerFactory;
import org.springframework.boot.web.embedded.netty.NettyServerCustomizer;
import org.springframework.mock.env.MockEnvironment;
import org.springframework.test.util.ReflectionTestUtils;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
/**
* Tests for {@link NettyWebServerFactoryCustomizer}.
*
* @author Brian Clozel
* @author Artsiom Yudovin
*/
public class NettyWebServerFactoryCustomizerTests {
@ -40,8 +57,12 @@ public class NettyWebServerFactoryCustomizerTests {
private NettyWebServerFactoryCustomizer customizer;
@Captor
private ArgumentCaptor<NettyServerCustomizer> customizerCaptor;
@Before
public void setup() {
MockitoAnnotations.initMocks(this);
this.environment = new MockEnvironment();
this.serverProperties = new ServerProperties();
ConfigurationPropertySources.attach(this.environment);
@ -71,4 +92,49 @@ public class NettyWebServerFactoryCustomizerTests {
verify(factory).setUseForwardHeaders(true);
}
@Test
public void setConnectionTimeoutAsZero() {
setupConnectionTimeout(Duration.ZERO);
NettyReactiveWebServerFactory factory = mock(NettyReactiveWebServerFactory.class);
this.customizer.customize(factory);
verifyConnectionTimeout(factory, null);
}
@Test
public void setConnectionTimeoutAsMinusOne() {
setupConnectionTimeout(Duration.ofNanos(-1));
NettyReactiveWebServerFactory factory = mock(NettyReactiveWebServerFactory.class);
this.customizer.customize(factory);
verifyConnectionTimeout(factory, 0);
}
@Test
public void setConnectionTimeout() {
setupConnectionTimeout(Duration.ofSeconds(1));
NettyReactiveWebServerFactory factory = mock(NettyReactiveWebServerFactory.class);
this.customizer.customize(factory);
verifyConnectionTimeout(factory, 1000);
}
@SuppressWarnings("unchecked")
private void verifyConnectionTimeout(NettyReactiveWebServerFactory factory, Integer expected) {
if (expected == null) {
verify(factory, never()).addServerCustomizers(any(NettyServerCustomizer.class));
return;
}
verify(factory, times(1)).addServerCustomizers(this.customizerCaptor.capture());
NettyServerCustomizer serverCustomizer = this.customizerCaptor.getValue();
HttpServer httpServer = serverCustomizer.apply(HttpServer.create());
TcpServer tcpConfiguration = ReflectionTestUtils.invokeMethod(httpServer, "tcpConfiguration");
ServerBootstrap bootstrap = tcpConfiguration.configure();
Map<Object, Object> options = (Map<Object, Object>) ReflectionTestUtils.getField(bootstrap, "options");
assertThat(options).containsEntry(ChannelOption.CONNECT_TIMEOUT_MILLIS, expected);
}
private void setupConnectionTimeout(Duration connectionTimeout) {
this.serverProperties.setUseForwardHeaders(null);
this.serverProperties.setMaxHttpHeaderSize(null);
this.serverProperties.setConnectionTimeout(connectionTimeout);
}
}

@ -50,6 +50,7 @@ import org.springframework.util.StringUtils;
* {@link Source#toInstance(Function) new instance}.
*
* @author Phillip Webb
* @author Artsiom Yudovin
* @since 2.0.0
*/
public final class PropertyMapper {
@ -288,7 +289,7 @@ public final class PropertyMapper {
*/
public Source<T> whenNot(Predicate<T> predicate) {
Assert.notNull(predicate, "Predicate must not be null");
return new Source<>(this.supplier, predicate.negate());
return when(predicate.negate());
}
/**
@ -299,7 +300,7 @@ public final class PropertyMapper {
*/
public Source<T> when(Predicate<T> predicate) {
Assert.notNull(predicate, "Predicate must not be null");
return new Source<>(this.supplier, predicate);
return new Source<>(this.supplier, (this.predicate != null) ? this.predicate.and(predicate) : predicate);
}
/**

@ -28,6 +28,7 @@ import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException
* Tests for {@link PropertyMapper}.
*
* @author Phillip Webb
* @author Artsiom Yudovin
*/
public class PropertyMapperTests {
@ -190,6 +191,17 @@ public class PropertyMapperTests {
assertThat(source.getCount()).isOne();
}
@Test
public void whenWhenValueNotMatchesShouldSupportChainedCalls() {
this.map.from("123").when("456"::equals).when("123"::equals).toCall(Assert::fail);
}
@Test
public void whenWhenValueMatchesShouldSupportChainedCalls() {
String result = this.map.from("123").when((s) -> s.contains("2")).when("123"::equals).toInstance(String::new);
assertThat(result).isEqualTo("123");
}
@Test
public void alwaysApplyingWhenNonNullShouldAlwaysApplyNonNullToSource() {
this.map.alwaysApplyingWhenNonNull().from(() -> null).toCall(Assert::fail);

Loading…
Cancel
Save