Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions docs/changelog/135873.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
pr: 135873
summary: Convert `BytesTransportResponse` when proxying response from/to local node
area: "Network"
type: bug
issues: []
Original file line number Diff line number Diff line change
Expand Up @@ -607,7 +607,13 @@ static void registerNodeSearchAction(
}
}
);
TransportActionProxy.registerProxyAction(transportService, NODE_SEARCH_ACTION_NAME, true, NodeQueryResponse::new);
TransportActionProxy.registerProxyAction(
transportService,
NODE_SEARCH_ACTION_NAME,
true,
NodeQueryResponse::new,
namedWriteableRegistry
);
}

private static void releaseLocalContext(
Expand Down Expand Up @@ -845,7 +851,10 @@ void onShardDone() {
out.close();
}
}
ActionListener.respondAndRelease(channelListener, new BytesTransportResponse(out.moveToBytesReference()));
ActionListener.respondAndRelease(
channelListener,
new BytesTransportResponse(out.moveToBytesReference(), out.getTransportVersion())
);
}

private void maybeFreeContext(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.elasticsearch.client.internal.OriginSettingClient;
import org.elasticsearch.client.internal.node.NodeClient;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
Expand Down Expand Up @@ -384,7 +385,11 @@ public void writeTo(StreamOutput out) throws IOException {
}
}

public static void registerRequestHandler(TransportService transportService, SearchService searchService) {
public static void registerRequestHandler(
TransportService transportService,
SearchService searchService,
NamedWriteableRegistry namedWriteableRegistry
) {
final TransportRequestHandler<ScrollFreeContextRequest> freeContextHandler = (request, channel, task) -> {
logger.trace("releasing search context [{}]", request.id());
boolean freed = searchService.freeReaderContext(request.id());
Expand All @@ -401,7 +406,8 @@ public static void registerRequestHandler(TransportService transportService, Sea
transportService,
FREE_CONTEXT_SCROLL_ACTION_NAME,
false,
SearchFreeContextResponse::readFrom
SearchFreeContextResponse::readFrom,
namedWriteableRegistry
);

// TODO: remove this handler once the lowest compatible version stops using it
Expand All @@ -411,7 +417,13 @@ public static void registerRequestHandler(TransportService transportService, Sea
OriginalIndices.readOriginalIndices(in);
return res;
}, freeContextHandler);
TransportActionProxy.registerProxyAction(transportService, FREE_CONTEXT_ACTION_NAME, false, SearchFreeContextResponse::readFrom);
TransportActionProxy.registerProxyAction(
transportService,
FREE_CONTEXT_ACTION_NAME,
false,
SearchFreeContextResponse::readFrom,
namedWriteableRegistry
);

transportService.registerRequestHandler(
CLEAR_SCROLL_CONTEXTS_ACTION_NAME,
Expand All @@ -426,7 +438,8 @@ public static void registerRequestHandler(TransportService transportService, Sea
transportService,
CLEAR_SCROLL_CONTEXTS_ACTION_NAME,
false,
(in) -> ActionResponse.Empty.INSTANCE
(in) -> ActionResponse.Empty.INSTANCE,
namedWriteableRegistry
);

transportService.registerRequestHandler(
Expand All @@ -435,7 +448,7 @@ public static void registerRequestHandler(TransportService transportService, Sea
ShardSearchRequest::new,
(request, channel, task) -> searchService.executeDfsPhase(request, (SearchShardTask) task, new ChannelActionListener<>(channel))
);
TransportActionProxy.registerProxyAction(transportService, DFS_ACTION_NAME, true, DfsSearchResult::new);
TransportActionProxy.registerProxyAction(transportService, DFS_ACTION_NAME, true, DfsSearchResult::new, namedWriteableRegistry);

transportService.registerRequestHandler(
QUERY_ACTION_NAME,
Expand All @@ -451,7 +464,8 @@ public static void registerRequestHandler(TransportService transportService, Sea
transportService,
QUERY_ACTION_NAME,
true,
(request) -> ((ShardSearchRequest) request).numberOfShards() == 1 ? QueryFetchSearchResult::new : QuerySearchResult::new
(request) -> ((ShardSearchRequest) request).numberOfShards() == 1 ? QueryFetchSearchResult::new : QuerySearchResult::new,
namedWriteableRegistry
);

transportService.registerRequestHandler(
Expand All @@ -465,7 +479,13 @@ public static void registerRequestHandler(TransportService transportService, Sea
channel.getVersion()
)
);
TransportActionProxy.registerProxyAction(transportService, QUERY_ID_ACTION_NAME, true, QuerySearchResult::new);
TransportActionProxy.registerProxyAction(
transportService,
QUERY_ID_ACTION_NAME,
true,
QuerySearchResult::new,
namedWriteableRegistry
);

transportService.registerRequestHandler(
QUERY_SCROLL_ACTION_NAME,
Expand All @@ -478,7 +498,13 @@ public static void registerRequestHandler(TransportService transportService, Sea
channel.getVersion()
)
);
TransportActionProxy.registerProxyAction(transportService, QUERY_SCROLL_ACTION_NAME, true, ScrollQuerySearchResult::new);
TransportActionProxy.registerProxyAction(
transportService,
QUERY_SCROLL_ACTION_NAME,
true,
ScrollQuerySearchResult::new,
namedWriteableRegistry
);

transportService.registerRequestHandler(
QUERY_FETCH_SCROLL_ACTION_NAME,
Expand All @@ -490,7 +516,13 @@ public static void registerRequestHandler(TransportService transportService, Sea
new ChannelActionListener<>(channel)
)
);
TransportActionProxy.registerProxyAction(transportService, QUERY_FETCH_SCROLL_ACTION_NAME, true, ScrollQueryFetchSearchResult::new);
TransportActionProxy.registerProxyAction(
transportService,
QUERY_FETCH_SCROLL_ACTION_NAME,
true,
ScrollQueryFetchSearchResult::new,
namedWriteableRegistry
);

final TransportRequestHandler<RankFeatureShardRequest> rankShardFeatureRequest = (request, channel, task) -> searchService
.executeRankFeaturePhase(request, (SearchShardTask) task, new ChannelActionListener<>(channel));
Expand All @@ -500,7 +532,13 @@ public static void registerRequestHandler(TransportService transportService, Sea
RankFeatureShardRequest::new,
rankShardFeatureRequest
);
TransportActionProxy.registerProxyAction(transportService, RANK_FEATURE_SHARD_ACTION_NAME, true, RankFeatureResult::new);
TransportActionProxy.registerProxyAction(
transportService,
RANK_FEATURE_SHARD_ACTION_NAME,
true,
RankFeatureResult::new,
namedWriteableRegistry
);

final TransportRequestHandler<ShardFetchRequest> shardFetchRequestHandler = (request, channel, task) -> searchService
.executeFetchPhase(request, (SearchShardTask) task, new ChannelActionListener<>(channel));
Expand All @@ -510,7 +548,13 @@ public static void registerRequestHandler(TransportService transportService, Sea
ShardFetchRequest::new,
shardFetchRequestHandler
);
TransportActionProxy.registerProxyAction(transportService, FETCH_ID_SCROLL_ACTION_NAME, true, FetchSearchResult::new);
TransportActionProxy.registerProxyAction(
transportService,
FETCH_ID_SCROLL_ACTION_NAME,
true,
FetchSearchResult::new,
namedWriteableRegistry
);

transportService.registerRequestHandler(
FETCH_ID_ACTION_NAME,
Expand All @@ -520,15 +564,27 @@ public static void registerRequestHandler(TransportService transportService, Sea
ShardFetchSearchRequest::new,
shardFetchRequestHandler
);
TransportActionProxy.registerProxyAction(transportService, FETCH_ID_ACTION_NAME, true, FetchSearchResult::new);
TransportActionProxy.registerProxyAction(
transportService,
FETCH_ID_ACTION_NAME,
true,
FetchSearchResult::new,
namedWriteableRegistry
);

transportService.registerRequestHandler(
QUERY_CAN_MATCH_NODE_NAME,
transportService.getThreadPool().executor(ThreadPool.Names.SEARCH_COORDINATION),
CanMatchNodeRequest::new,
(request, channel, task) -> searchService.canMatch(request, new ChannelActionListener<>(channel))
);
TransportActionProxy.registerProxyAction(transportService, QUERY_CAN_MATCH_NODE_NAME, true, CanMatchNodeResponse::new);
TransportActionProxy.registerProxyAction(
transportService,
QUERY_CAN_MATCH_NODE_NAME,
true,
CanMatchNodeResponse::new,
namedWriteableRegistry
);
}

private static Executor buildFreeContextExecutor(TransportService transportService) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,13 @@ public TransportOpenPointInTimeAction(
ShardOpenReaderRequest::new,
new ShardOpenReaderRequestHandler()
);
TransportActionProxy.registerProxyAction(transportService, OPEN_SHARD_READER_CONTEXT_NAME, false, ShardOpenReaderResponse::new);
TransportActionProxy.registerProxyAction(
transportService,
OPEN_SHARD_READER_CONTEXT_NAME,
false,
ShardOpenReaderResponse::new,
namedWriteableRegistry
);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ public TransportSearchAction(
this.searchPhaseController = searchPhaseController;
this.searchTransportService = searchTransportService;
this.remoteClusterService = searchTransportService.getRemoteClusterService();
SearchTransportService.registerRequestHandler(transportService, searchService);
SearchTransportService.registerRequestHandler(transportService, searchService, namedWriteableRegistry);
SearchQueryThenFetchAsyncAction.registerNodeSearchAction(
searchTransportService,
searchService,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,20 +9,41 @@

package org.elasticsearch.transport;

import org.elasticsearch.TransportVersion;
import org.elasticsearch.common.bytes.ReleasableBytesReference;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;

import java.io.IOException;
import java.util.Objects;

/**
* A specialized, bytes only response, that can potentially be optimized on the network layer.
*/
public class BytesTransportResponse extends TransportResponse implements BytesTransportMessage {

private final ReleasableBytesReference bytes;
private final TransportVersion version;

public BytesTransportResponse(ReleasableBytesReference bytes) {
public BytesTransportResponse(ReleasableBytesReference bytes, TransportVersion version) {
this.bytes = bytes;
this.version = Objects.requireNonNull(version);
}

/**
* Does the binary response need conversion before being sent to the provided target version?
*/
public boolean mustConvertResponseForVersion(TransportVersion targetVersion) {
return version.equals(targetVersion) == false;
}

/**
* Returns a {@link StreamInput} configured to read the underlying bytes that this response holds.
*/
public StreamInput streamInput() throws IOException {
StreamInput streamInput = bytes.streamInput();
streamInput.setTransportVersion(version);
return streamInput;
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
package org.elasticsearch.transport;

import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.common.io.stream.NamedWriteableAwareStreamInput;
import org.elasticsearch.common.io.stream.NamedWriteableRegistry;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
import org.elasticsearch.common.io.stream.Writeable;
Expand All @@ -18,6 +20,7 @@
import org.elasticsearch.tasks.TaskId;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.Map;
import java.util.concurrent.Executor;
import java.util.function.Function;
Expand All @@ -36,15 +39,18 @@ private static class ProxyRequestHandler<T extends ProxyRequest<TransportRequest
private final TransportService service;
private final String action;
private final Function<TransportRequest, Writeable.Reader<? extends TransportResponse>> responseFunction;
private final NamedWriteableRegistry namedWriteableRegistry;

ProxyRequestHandler(
TransportService service,
String action,
Function<TransportRequest, Writeable.Reader<? extends TransportResponse>> responseFunction
Function<TransportRequest, Writeable.Reader<? extends TransportResponse>> responseFunction,
NamedWriteableRegistry namedWriteableRegistry
) {
this.service = service;
this.action = action;
this.responseFunction = responseFunction;
this.namedWriteableRegistry = namedWriteableRegistry;
}

@Override
Expand All @@ -62,7 +68,28 @@ public Executor executor() {

@Override
public void handleResponse(TransportResponse response) {
channel.sendResponse(response);
// This is a short term solution to ensure data node responses for batched search go back to the coordinating
// node in the expected format when a proxy data node proxies the request to itself. The response would otherwise
// be sent directly via DirectResponseChannel, skipping the read and write step that this handler normally performs.
if (response instanceof BytesTransportResponse btr && btr.mustConvertResponseForVersion(channel.getVersion())) {
try (
NamedWriteableAwareStreamInput in = new NamedWriteableAwareStreamInput(
btr.streamInput(),
namedWriteableRegistry
)
) {
TransportResponse convertedResponse = responseFunction.apply(wrappedRequest).read(in);
try {
channel.sendResponse(convertedResponse);
} finally {
convertedResponse.decRef();
}
} catch (IOException e) {
throw new UncheckedIOException(e);
}
} else {
channel.sendResponse(response);
}
}

@Override
Expand All @@ -73,7 +100,7 @@ public void handleException(TransportException exp) {
@Override
public TransportResponse read(StreamInput in) throws IOException {
if (in.getTransportVersion().equals(channel.getVersion()) && in.supportReadAllToReleasableBytesReference()) {
return new BytesTransportResponse(in.readAllToReleasableBytesReference());
return new BytesTransportResponse(in.readAllToReleasableBytesReference(), in.getTransportVersion());
} else {
return responseFunction.apply(wrappedRequest).read(in);
}
Expand Down Expand Up @@ -144,7 +171,9 @@ public static void registerProxyActionWithDynamicResponseType(
TransportService service,
String action,
boolean cancellable,
Function<TransportRequest, Writeable.Reader<? extends TransportResponse>> responseFunction
Function<TransportRequest, Writeable.Reader<? extends TransportResponse>> responseFunction,
NamedWriteableRegistry namedWriteableRegistry

) {
RequestHandlerRegistry<? extends TransportRequest> requestHandler = service.getRequestHandler(action);
service.registerRequestHandler(
Expand All @@ -155,7 +184,7 @@ public static void registerProxyActionWithDynamicResponseType(
in -> cancellable
? new CancellableProxyRequest<>(in, requestHandler::newRequest)
: new ProxyRequest<>(in, requestHandler::newRequest),
new ProxyRequestHandler<>(service, action, responseFunction)
new ProxyRequestHandler<>(service, action, responseFunction, namedWriteableRegistry)
);
}

Expand All @@ -167,9 +196,10 @@ public static void registerProxyAction(
TransportService service,
String action,
boolean cancellable,
Writeable.Reader<? extends TransportResponse> reader
Writeable.Reader<? extends TransportResponse> reader,
NamedWriteableRegistry namedWriteableRegistry
) {
registerProxyActionWithDynamicResponseType(service, action, cancellable, request -> reader);
registerProxyActionWithDynamicResponseType(service, action, cancellable, request -> reader, namedWriteableRegistry);
}

private static final String PROXY_ACTION_PREFIX = "internal:transport/proxy/";
Expand Down
Loading