Skip to content

Commit

Permalink
Improve HTTP validation for inlining and specialized types
Browse files Browse the repository at this point in the history
  • Loading branch information
franz1981 authored and vietj committed Jan 15, 2025
1 parent bff88a0 commit 5554385
Show file tree
Hide file tree
Showing 3 changed files with 278 additions and 43 deletions.
209 changes: 169 additions & 40 deletions src/main/java/io/vertx/core/http/impl/HttpUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -863,67 +863,181 @@ public static void validateHeader(CharSequence name, Iterable<? extends CharSequ
});
}

public static void validateHeaderValue(CharSequence seq) {
public static void validateHeaderValue(CharSequence value) {
if (value instanceof AsciiString) {
validateAsciiHeaderValue((AsciiString) value);
} else if (value instanceof String) {
validateStringHeaderValue((String) value);
} else {
validateSequenceHeaderValue(value);
}
}

private static void validateAsciiHeaderValue(AsciiString value) {
final int length = value.length();
if (length == 0) {
return;
}
byte[] asciiChars = value.array();
int off = value.arrayOffset();
if (off == 0 && length == asciiChars.length) {
for (int index = 0; index < asciiChars.length; index++) {
int latinChar = asciiChars[index] & 0xFF;
if (latinChar == 0x7F) {
throw new IllegalArgumentException("a header value contains a prohibited character '127': " + value);
}
// non-printable chars are rare so let's make it a fall-back method, whilst still accepting HTAB
if (latinChar < 32 && latinChar != 0x09) {
validateSequenceHeaderValue(value, index - off);
break;
}
}
} else {
validateAsciiRangeHeaderValue(value, off, length, asciiChars);
}
}

/**
* This method is the slow-path generic version of {@link #validateAsciiHeaderValue(AsciiString)} which
* is optimized for {@link AsciiString} instances which are backed by a 0-offset full-blown byte array.
*/
private static void validateAsciiRangeHeaderValue(AsciiString value, int off, int length, byte[] asciiChars) {
int end = off + length;
for (int index = off; index < end; index++) {
int latinChar = asciiChars[index] & 0xFF;
if (latinChar == 0x7F) {
throw new IllegalArgumentException("a header value contains a prohibited character '127': " + value);
}
// non-printable chars are rare so let's make it a fall-back method, whilst still accepting HTAB
if (latinChar < 32 && latinChar != 0x09) {
validateSequenceHeaderValue(value, index - off);
break;
}
}
}

private static void validateStringHeaderValue(String value) {
final int length = value.length();
if (length == 0) {
return;
}

for (int index = 0; index < length; index++) {
char latinChar = value.charAt(index);
if (latinChar == 0x7F) {
throw new IllegalArgumentException("a header value contains a prohibited character '127': " + value);
}
// non-printable chars are rare so let's make it a fall-back method, whilst still accepting HTAB
if (latinChar < 32 && latinChar != 0x09) {
validateSequenceHeaderValue(value, index);
break;
}
}
}

int state = 0;
// Start looping through each of the character
for (int index = 0; index < seq.length(); index++) {
state = validateValueChar(seq, state, seq.charAt(index));
private static void validateSequenceHeaderValue(CharSequence value) {
final int length = value.length();
if (length == 0) {
return;
}

if (state != 0) {
throw new IllegalArgumentException("a header value must not end with '\\r' or '\\n':" + seq);
for (int index = 0; index < length; index++) {
char latinChar = value.charAt(index);
if (latinChar == 0x7F) {
throw new IllegalArgumentException("a header value contains a prohibited character '127': " + value);
}
// non-printable chars are rare so let's make it a fall-back method, whilst still accepting HTAB
if (latinChar < 32 && latinChar != 0x09) {
validateSequenceHeaderValue(value, index);
break;
}
}
}

private static final int HIGHEST_INVALID_VALUE_CHAR_MASK = ~0x1F;
private static final int NO_CR_LF_STATE = 0;
private static final int CR_STATE = 1;
private static final int LF_STATE = 2;

/**
* This method is taken as we need to validate the header value for the non-printable characters.
*/
private static void validateSequenceHeaderValue(CharSequence seq, int index) {
// we already expect the very-first character to be non-printable
int state = validateValueChar(seq, NO_CR_LF_STATE, seq.charAt(index));
for (int i = index + 1; i < seq.length(); i++) {
state = validateValueChar(seq, state, seq.charAt(i));
}
if (state != NO_CR_LF_STATE) {
throw new IllegalArgumentException("a header value must not end with '\\r' or '\\n':" + seq);
}
}

private static int validateValueChar(CharSequence seq, int state, char character) {
private static int validateValueChar(CharSequence seq, int state, char ch) {
/*
* State:
* 0: Previous character was neither CR nor LF
* 1: The previous character was CR
* 2: The previous character was LF
*/
if ((character & HIGHEST_INVALID_VALUE_CHAR_MASK) == 0 || character == 0x7F) { // 0x7F is "DEL".
// The only characters allowed in the range 0x00-0x1F are : HTAB, LF and CR
switch (character) {
case 0x09: // Horizontal tab - HTAB
case 0x0a: // Line feed - LF
case 0x0d: // Carriage return - CR
break;
default:
throw new IllegalArgumentException("a header value contains a prohibited character '" + (int) character + "': " + seq);
if (ch == 0x7F) {
throw new IllegalArgumentException("a header value contains a prohibited character '127': " + seq);
}
if ((ch & HIGHEST_INVALID_VALUE_CHAR_MASK) == 0) {
// this is a rare scenario
validateNonPrintableCtrlChar(seq, ch);
// this can include LF and CR as they are non-printable characters
if (state == NO_CR_LF_STATE) {
// Check the CRLF (HT | SP) pattern
switch (ch) {
case '\r':
return CR_STATE;
case '\n':
return LF_STATE;
}
return NO_CR_LF_STATE;
}
}
if (state != NO_CR_LF_STATE) {
// this is a rare scenario
return validateCrLfChar(seq, state, ch);
} else {
return NO_CR_LF_STATE;
}
}

// Check the CRLF (HT | SP) pattern
private static int validateCrLfChar(CharSequence seq, int state, char ch) {
switch (state) {
case 0:
switch (character) {
case '\r':
return 1;
case '\n':
return 2;
}
break;
case 1:
switch (character) {
case '\n':
return 2;
default:
throw new IllegalArgumentException("only '\\n' is allowed after '\\r': " + seq);
case CR_STATE:
if (ch == '\n') {
return LF_STATE;
}
case 2:
switch (character) {
throw new IllegalArgumentException("only '\\n' is allowed after '\\r': " + seq);
case LF_STATE:
switch (ch) {
case '\t':
case ' ':
return 0;
// return to the normal state
return NO_CR_LF_STATE;
default:
throw new IllegalArgumentException("only ' ' and '\\t' are allowed after '\\n': " + seq);
}
default:
// this should never happen
throw new AssertionError();
}
}

private static void validateNonPrintableCtrlChar(CharSequence seq, int ch) {
// The only characters allowed in the range 0x00-0x1F are : HTAB, LF and CR
switch (ch) {
case 0x09: // Horizontal tab - HTAB
case 0x0a: // Line feed - LF
case 0x0d: // Carriage return - CR
break;
default:
throw new IllegalArgumentException("a header value contains a prohibited character '" + (int) ch + "': " + seq);
}
return state;
}

private static final boolean[] VALID_H_NAME_ASCII_CHARS;
Expand Down Expand Up @@ -959,13 +1073,15 @@ private static int validateValueChar(CharSequence seq, int state, char character
public static void validateHeaderName(CharSequence value) {
if (value instanceof AsciiString) {
// no need to check for ASCII-ness anymore
validateHeaderName((AsciiString) value);
validateAsciiHeaderName((AsciiString) value);
} else if(value instanceof String) {
validateStringHeaderName((String) value);
} else {
validateHeaderName0(value);
validateSequenceHeaderName(value);
}
}

private static void validateHeaderName(AsciiString value) {
private static void validateAsciiHeaderName(AsciiString value) {
final int len = value.length();
final int off = value.arrayOffset();
final byte[] asciiChars = value.array();
Expand All @@ -981,7 +1097,20 @@ private static void validateHeaderName(AsciiString value) {
}
}

private static void validateHeaderName0(CharSequence value) {
private static void validateStringHeaderName(String value) {
for (int i = 0; i < value.length(); i++) {
final char c = value.charAt(i);
// Check to see if the character is not an ASCII character, or invalid
if (c > 0x7f) {
throw new IllegalArgumentException("a header name cannot contain non-ASCII character: " + value);
}
if (!VALID_H_NAME_ASCII_CHARS[c & 0x7F]) {
throw new IllegalArgumentException("a header name cannot contain some prohibited characters, such as : " + value);
}
}
}

private static void validateSequenceHeaderName(CharSequence value) {
for (int i = 0; i < value.length(); i++) {
final char c = value.charAt(i);
// Check to see if the character is not an ASCII character, or invalid
Expand Down
3 changes: 0 additions & 3 deletions src/test/benchmarks/io/vertx/benchmarks/BenchmarkBase.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,6 @@
@Threads(1)
@BenchmarkMode(Mode.Throughput)
@Fork(value = 1, jvmArgs = {
"-XX:+UseBiasedLocking",
"-XX:BiasedLockingStartupDelay=0",
"-XX:+AggressiveOpts",
"-Djmh.executor=CUSTOM",
"-Djmh.executor.class=io.vertx.core.impl.VertxExecutorService"
})
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
package io.vertx.benchmarks;

import java.util.concurrent.TimeUnit;
import java.util.function.Consumer;

import org.openjdk.jmh.annotations.Benchmark;
import org.openjdk.jmh.annotations.CompilerControl;
import org.openjdk.jmh.annotations.Measurement;
import org.openjdk.jmh.annotations.Param;
import org.openjdk.jmh.annotations.Scope;
import org.openjdk.jmh.annotations.Setup;
import org.openjdk.jmh.annotations.State;
import org.openjdk.jmh.annotations.Warmup;

import io.netty.handler.codec.DefaultHeaders;
import io.netty.handler.codec.http.DefaultHttpHeadersFactory;
import io.vertx.core.http.impl.HttpUtils;

@State(Scope.Thread)
@Warmup(iterations = 10, time = 1, timeUnit = TimeUnit.SECONDS)
@Measurement(iterations = 5, time = 400, timeUnit = TimeUnit.MILLISECONDS)
public class HeadersValidationBenchmark extends BenchmarkBase {

@Param({ "true", "false" })
public boolean ascii;

private CharSequence[] headerNames;
private CharSequence[] headerValues;
private static final DefaultHeaders.NameValidator<CharSequence> nettyNameValidator = DefaultHttpHeadersFactory.headersFactory().getNameValidator();
private static final DefaultHeaders.ValueValidator<CharSequence> nettyValueValidator = DefaultHttpHeadersFactory.headersFactory().getValueValidator();
private static final Consumer<CharSequence> vertxNameValidator = HttpUtils::validateHeaderName;
private static final Consumer<CharSequence> vertxValueValidator = HttpUtils::validateHeaderValue;

@Setup
public void setup() {
headerNames = new CharSequence[4];
headerValues = new CharSequence[4];
headerNames[0] = io.vertx.core.http.HttpHeaders.CONTENT_TYPE;
headerValues[0] = io.vertx.core.http.HttpHeaders.createOptimized("text/plain");
headerNames[1] = io.vertx.core.http.HttpHeaders.CONTENT_LENGTH;
headerValues[1] = io.vertx.core.http.HttpHeaders.createOptimized("20");
headerNames[2] = io.vertx.core.http.HttpHeaders.SERVER;
headerValues[2] = io.vertx.core.http.HttpHeaders.createOptimized("vert.x");
headerNames[3] = io.vertx.core.http.HttpHeaders.DATE;
headerValues[3] = io.vertx.core.http.HttpHeaders.createOptimized(HeadersUtils.DATE_FORMAT.format(new java.util.Date(0)));
if (!ascii) {
for (int i = 0; i < headerNames.length; i++) {
headerNames[i] = headerNames[i].toString();
}
}
if (!ascii) {
for (int i = 0; i < headerValues.length; i++) {
headerValues[i] = headerValues[i].toString();
}
}
}

@Benchmark
public void validateNameNetty() {
for (CharSequence headerName : headerNames) {
nettyNameValidation(headerName);
}
}

@CompilerControl(CompilerControl.Mode.DONT_INLINE)
private void nettyNameValidation(CharSequence headerName) {
nettyNameValidator.validateName(headerName);
}

@Benchmark
public void validateNameVertx() {
for (CharSequence headerName : headerNames) {
vertxNameValidation(headerName);
}
}

@CompilerControl(CompilerControl.Mode.DONT_INLINE)
private void vertxNameValidation(CharSequence headerName) {
vertxNameValidator.accept(headerName);
}


@Benchmark
public void validateValueNetty() {
for (CharSequence headerValue : headerValues) {
nettyValueValidation(headerValue);
}
}

@CompilerControl(CompilerControl.Mode.DONT_INLINE)
private void nettyValueValidation(CharSequence headerValue) {
nettyValueValidator.validate(headerValue);
}

@Benchmark
public void validateValueVertx() {
for (CharSequence headerValue : headerValues) {
vertxValueValidation(headerValue);
}
}

@CompilerControl(CompilerControl.Mode.DONT_INLINE)
private void vertxValueValidation(CharSequence headerValue) {
vertxValueValidator.accept(headerValue);
}



}

0 comments on commit 5554385

Please sign in to comment.