diff --git a/flink-filesystems/flink-s3-fs-native/src/main/java/org/apache/flink/fs/s3native/NativeS3InputStream.java b/flink-filesystems/flink-s3-fs-native/src/main/java/org/apache/flink/fs/s3native/NativeS3InputStream.java index f6dca047f648d..f1c7084a6cc8d 100644 --- a/flink-filesystems/flink-s3-fs-native/src/main/java/org/apache/flink/fs/s3native/NativeS3InputStream.java +++ b/flink-filesystems/flink-s3-fs-native/src/main/java/org/apache/flink/fs/s3native/NativeS3InputStream.java @@ -27,7 +27,10 @@ import software.amazon.awssdk.services.s3.model.GetObjectRequest; import software.amazon.awssdk.services.s3.model.GetObjectResponse; +import javax.annotation.concurrent.GuardedBy; + import java.io.BufferedInputStream; +import java.io.EOFException; import java.io.IOException; import java.util.concurrent.locks.ReentrantLock; @@ -56,9 +59,16 @@ class NativeS3InputStream extends FSDataInputStream { private final long contentLength; private final int readBufferSize; + @GuardedBy("lock") private ResponseInputStream currentStream; + + @GuardedBy("lock") private BufferedInputStream bufferedStream; + + @GuardedBy("lock") private long position; + + @GuardedBy("lock") private volatile boolean closed; public NativeS3InputStream( @@ -108,24 +118,7 @@ private void lazyInitialize() throws IOException { private void openStreamAtCurrentPosition() throws IOException { lock.lock(); try { - if (bufferedStream != null) { - try { - bufferedStream.close(); - } catch (IOException e) { - LOG.warn("Error closing buffered stream for {}/{}", bucketName, key, e); - } finally { - bufferedStream = null; - } - } - if (currentStream != null) { - try { - currentStream.close(); - } catch (IOException e) { - LOG.warn("Error closing S3 response stream for {}/{}", bucketName, key, e); - } finally { - currentStream = null; - } - } + releaseStreams(); try { GetObjectRequest.Builder requestBuilder = @@ -143,20 +136,7 @@ private void openStreamAtCurrentPosition() throws IOException { currentStream = s3Client.getObject(requestBuilder.build()); bufferedStream = new BufferedInputStream(currentStream, readBufferSize); } catch (Exception e) { - if (bufferedStream != null) { - try { - bufferedStream.close(); - } catch (IOException ignored) { - } - bufferedStream = null; - } - if (currentStream != null) { - try { - currentStream.close(); - } catch (IOException ignored) { - } - currentStream = null; - } + releaseStreams(); throw new IOException("Failed to open S3 stream for " + bucketName + "/" + key, e); } } finally { @@ -164,6 +144,64 @@ private void openStreamAtCurrentPosition() throws IOException { } } + /** + * Aborts the in-flight HTTP connection so that subsequent {@code close()} calls on the stream + * do not drain remaining bytes over the network. + * + * @see ResponseInputStream#abort() + */ + @GuardedBy("lock") + private void abortCurrentStream() { + assert lock.isHeldByCurrentThread() : "abortCurrentStream() requires lock to be held"; + if (currentStream != null) { + try { + currentStream.abort(); + } catch (RuntimeException e) { + LOG.warn("Error aborting S3 response stream for {}/{}", bucketName, key, e); + } + } + } + + /** + * Aborts and closes both streams, nulling the references. The abort is called first to prevent + * {@link ResponseInputStream#close()} from draining remaining bytes over the network. + * + * @return the first {@link IOException} encountered (with subsequent ones added as suppressed), + * or {@code null} if cleanup succeeded without errors + */ + @GuardedBy("lock") + private IOException releaseStreams() { + assert lock.isHeldByCurrentThread() : "releaseStreams() requires lock to be held"; + abortCurrentStream(); + IOException exception = null; + + if (bufferedStream != null) { + try { + bufferedStream.close(); + } catch (IOException e) { + exception = e; + LOG.warn("Error closing buffered stream for {}/{}", bucketName, key, e); + } finally { + bufferedStream = null; + } + } + if (currentStream != null) { + try { + currentStream.close(); + } catch (IOException e) { + if (exception == null) { + exception = e; + } else { + exception.addSuppressed(e); + } + LOG.warn("Error closing S3 response stream for {}/{}", bucketName, key, e); + } finally { + currentStream = null; + } + } + return exception; + } + @Override public void seek(long desired) throws IOException { lock(); @@ -172,7 +210,14 @@ public void seek(long desired) throws IOException { throw new IOException("Stream is closed"); } if (desired < 0) { - throw new IOException("Cannot seek to negative position: " + desired); + throw new EOFException("Cannot seek to negative position: " + desired); + } + if (desired > contentLength) { + throw new EOFException( + "Cannot seek past end of stream: position=" + + desired + + ", length=" + + contentLength); } if (desired != position) { @@ -270,33 +315,8 @@ public void close() throws IOException { } closed = true; - IOException exception = null; - - if (bufferedStream != null) { - try { - bufferedStream.close(); - } catch (IOException e) { - exception = e; - LOG.warn("Error closing buffered stream for {}/{}", bucketName, key, e); - } finally { - bufferedStream = null; - } - } - if (currentStream != null) { - try { - currentStream.close(); - } catch (IOException e) { - if (exception == null) { - exception = e; - } else { - exception.addSuppressed(e); - } - LOG.warn("Error closing S3 response stream for {}/{}", bucketName, key, e); - } finally { - currentStream = null; - } - } + IOException exception = releaseStreams(); LOG.debug( "Closed S3 input stream - bucket: {}, key: {}, final position: {}/{}", diff --git a/flink-filesystems/flink-s3-fs-native/src/test/java/org/apache/flink/fs/s3native/NativeS3InputStreamTest.java b/flink-filesystems/flink-s3-fs-native/src/test/java/org/apache/flink/fs/s3native/NativeS3InputStreamTest.java new file mode 100644 index 0000000000000..7a628735bf4b8 --- /dev/null +++ b/flink-filesystems/flink-s3-fs-native/src/test/java/org/apache/flink/fs/s3native/NativeS3InputStreamTest.java @@ -0,0 +1,324 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.flink.fs.s3native; + +import org.junit.jupiter.api.Test; +import software.amazon.awssdk.core.ResponseInputStream; +import software.amazon.awssdk.http.Abortable; +import software.amazon.awssdk.http.AbortableInputStream; +import software.amazon.awssdk.services.s3.S3Client; +import software.amazon.awssdk.services.s3.model.GetObjectRequest; +import software.amazon.awssdk.services.s3.model.GetObjectResponse; + +import java.io.ByteArrayInputStream; +import java.io.EOFException; +import java.io.IOException; +import java.io.InputStream; +import java.time.Duration; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatThrownBy; + +/** Tests for {@link NativeS3InputStream}. */ +class NativeS3InputStreamTest { + + private static final String BUCKET = "test-bucket"; + private static final String KEY = "test-key"; + + private static final byte[] DATA; + + static { + DATA = new byte[256]; + for (int i = 0; i < DATA.length; i++) { + DATA[i] = (byte) i; + } + } + + private static class TrackingInputStream extends InputStream implements Abortable { + private final ByteArrayInputStream delegate; + private final AtomicBoolean aborted = new AtomicBoolean(); + private final AtomicBoolean closed = new AtomicBoolean(); + private volatile boolean abortedBeforeClose; + + TrackingInputStream(byte[] data, int offset) { + this.delegate = new ByteArrayInputStream(data, offset, data.length - offset); + } + + TrackingInputStream(byte[] data) { + this(data, 0); + } + + @Override + public int read() { + return delegate.read(); + } + + @Override + public int read(byte[] b, int off, int len) { + return delegate.read(b, off, len); + } + + @Override + public void abort() { + aborted.set(true); + } + + @Override + public void close() throws IOException { + if (aborted.get()) { + abortedBeforeClose = true; + } + closed.set(true); + delegate.close(); + } + + boolean wasAborted() { + return aborted.get(); + } + + boolean wasClosed() { + return closed.get(); + } + + boolean wasAbortedBeforeClose() { + return abortedBeforeClose; + } + } + + /** {@link S3Client} stub. */ + private static final class StubS3Client implements S3Client { + private final byte[] data; + private final AtomicInteger getObjectCalls = new AtomicInteger(); + private volatile TrackingInputStream lastStream; + + StubS3Client(byte[] data) { + this.data = data; + } + + @Override + public ResponseInputStream getObject(GetObjectRequest request) { + getObjectCalls.incrementAndGet(); + int offset = 0; + String range = request.range(); + if (range != null && range.startsWith("bytes=")) { + offset = Integer.parseInt(range.substring(6, range.indexOf('-'))); + } + TrackingInputStream tracking = new TrackingInputStream(data, offset); + lastStream = tracking; + AbortableInputStream abortable = AbortableInputStream.create(tracking, tracking); + return new ResponseInputStream<>( + GetObjectResponse.builder().build(), abortable, Duration.ZERO); + } + + @Override + public String serviceName() { + return "s3"; + } + + @Override + public void close() {} + + TrackingInputStream lastStream() { + return lastStream; + } + + int getObjectCalls() { + return getObjectCalls.get(); + } + } + + @Test + void closeAbortsUnderlyingStream() throws Exception { + StubS3Client client = new StubS3Client(DATA); + try (NativeS3InputStream in = new NativeS3InputStream(client, BUCKET, KEY, DATA.length)) { + assertThat(in.read()).isEqualTo(0); + assertThat(in.getPos()).isEqualTo(1); + } + assertThat(client.lastStream().wasAborted()).isTrue(); + } + + @Test + void closeAbortsAndThenClosesUnderlyingStream() throws Exception { + StubS3Client client = new StubS3Client(DATA); + try (NativeS3InputStream in = new NativeS3InputStream(client, BUCKET, KEY, DATA.length)) { + in.read(); + } + TrackingInputStream stream = client.lastStream(); + // abort() must be called to kill the HTTP connection (prevents drain) + assertThat(stream.wasAborted()).isTrue(); + // close() must still be called for SDK resource cleanup (connection pool return, etc.) + assertThat(stream.wasClosed()).isTrue(); + // abort() must happen BEFORE close() — otherwise close() drains remaining bytes + assertThat(stream.wasAbortedBeforeClose()).isTrue(); + } + + @Test + void seekAbortsAndClosesOldStreamBeforeOpeningNew() throws Exception { + StubS3Client client = new StubS3Client(DATA); + try (NativeS3InputStream in = new NativeS3InputStream(client, BUCKET, KEY, DATA.length)) { + in.read(); + TrackingInputStream first = client.lastStream(); + + in.seek(100); + + // old stream must be aborted, closed, and in the correct order + assertThat(first.wasAborted()).isTrue(); + assertThat(first.wasClosed()).isTrue(); + assertThat(first.wasAbortedBeforeClose()).isTrue(); + assertThat(client.getObjectCalls()).isEqualTo(2); + assertThat(in.getPos()).isEqualTo(100); + in.seek(100); + assertThat(client.getObjectCalls()).isEqualTo(2); + } + } + + @Test + void skipAbortsOldStreamAndOpensNew() throws Exception { + StubS3Client client = new StubS3Client(DATA); + try (NativeS3InputStream in = new NativeS3InputStream(client, BUCKET, KEY, DATA.length)) { + in.read(); + TrackingInputStream first = client.lastStream(); + assertThat(in.skip(100)).isEqualTo(100); + assertThat(first.wasAborted()).isTrue(); + assertThat(first.wasClosed()).isTrue(); + assertThat(first.wasAbortedBeforeClose()).isTrue(); + assertThat(client.getObjectCalls()).isEqualTo(2); + // skip(0) and skip(negative) are no-ops + assertThat(in.skip(0)).isZero(); + assertThat(in.skip(-5)).isZero(); + assertThat(client.getObjectCalls()).isEqualTo(2); + } + } + + @Test + void closeWithoutReadNeverOpensStream() throws Exception { + StubS3Client client = new StubS3Client(DATA); + try (NativeS3InputStream in = new NativeS3InputStream(client, BUCKET, KEY, DATA.length)) { + // lazy init means no getObject call + } + assertThat(client.getObjectCalls()).isEqualTo(0); + } + + @Test + void doubleCloseIsIdempotent() throws Exception { + StubS3Client client = new StubS3Client(DATA); + NativeS3InputStream in = new NativeS3InputStream(client, BUCKET, KEY, DATA.length); + in.read(); + + // Verify state after first close + in.close(); + assertThat(client.getObjectCalls()).isEqualTo(1); + assertThat(client.lastStream().wasAborted()).isTrue(); + + // Second close should be a no-op + in.close(); + assertThat(client.getObjectCalls()).isEqualTo(1); + assertThat(client.lastStream().wasAborted()).isTrue(); + } + + @Test + void seekBeforeFirstReadUpdatesPositionOnly() throws Exception { + StubS3Client client = new StubS3Client(DATA); + try (NativeS3InputStream in = new NativeS3InputStream(client, BUCKET, KEY, DATA.length)) { + in.seek(50); + assertThat(in.getPos()).isEqualTo(50); + assertThat(client.getObjectCalls()).isEqualTo(0); + } + } + + @Test + void readAndSeekReturnCorrectData() throws Exception { + StubS3Client client = new StubS3Client(DATA); + try (NativeS3InputStream in = new NativeS3InputStream(client, BUCKET, KEY, DATA.length)) { + // single-byte read + assertThat(in.read()).isEqualTo(0); + assertThat(in.getPos()).isEqualTo(1); + // bulk read returns correct bytes and advances position + byte[] buf = new byte[10]; + assertThat(in.read(buf, 0, 10)).isEqualTo(10); + assertThat(in.getPos()).isEqualTo(11); + for (int i = 0; i < 10; i++) { + assertThat(buf[i]).isEqualTo(DATA[i + 1]); + } + // available() reflects remaining bytes + assertThat(in.available()).isEqualTo(DATA.length - 11); + // seek then read returns data at the seeked position + in.seek(200); + assertThat(in.read()).isEqualTo(200); + assertThat(in.getPos()).isEqualTo(201); + // partial read at EOF returns only remaining bytes + in.seek(250); + byte[] tail = new byte[20]; + assertThat(in.read(tail, 0, 20)).isEqualTo(6); + assertThat(in.getPos()).isEqualTo(256); + // read past EOF + assertThat(in.read()).isEqualTo(-1); + assertThat(in.read(new byte[1], 0, 1)).isEqualTo(-1); + } + } + + @Test + void seekPastEofThrowsEofException() throws Exception { + StubS3Client client = new StubS3Client(DATA); + try (NativeS3InputStream in = new NativeS3InputStream(client, BUCKET, KEY, DATA.length)) { + assertThatThrownBy(() -> in.seek(DATA.length + 1)) + .isInstanceOf(EOFException.class) + .hasMessageContaining("past end of stream"); + in.seek(DATA.length); + assertThat(in.getPos()).isEqualTo(DATA.length); + assertThat(client.getObjectCalls()).isEqualTo(0); + } + } + + @Test + void readAtEofReturnsMinusOne() throws Exception { + StubS3Client client = new StubS3Client(DATA); + try (NativeS3InputStream in = new NativeS3InputStream(client, BUCKET, KEY, DATA.length)) { + in.seek(DATA.length); + assertThat(in.read()).isEqualTo(-1); + assertThat(in.read(new byte[8], 0, 8)).isEqualTo(-1); + assertThat(in.getPos()).isEqualTo(DATA.length); + assertThat(in.available()).isZero(); + } + } + + @Test + void rejectsInvalidArguments() throws Exception { + StubS3Client client = new StubS3Client(DATA); + try (NativeS3InputStream in = new NativeS3InputStream(client, BUCKET, KEY, DATA.length)) { + assertThatThrownBy(() -> in.read(null, 0, 1)).isInstanceOf(NullPointerException.class); + assertThatThrownBy(() -> in.read(new byte[5], -1, 1)) + .isInstanceOf(IndexOutOfBoundsException.class); + assertThatThrownBy(() -> in.read(new byte[5], 0, 6)) + .isInstanceOf(IndexOutOfBoundsException.class); + assertThatThrownBy(() -> in.seek(-1)) + .isInstanceOf(EOFException.class) + .hasMessageContaining("negative"); + assertThat(in.read(new byte[5], 0, 0)).isZero(); + } + NativeS3InputStream closed = new NativeS3InputStream(client, BUCKET, KEY, DATA.length); + closed.close(); + assertThatThrownBy(closed::read).isInstanceOf(IOException.class); + assertThatThrownBy(() -> closed.read(new byte[1], 0, 1)).isInstanceOf(IOException.class); + assertThatThrownBy(() -> closed.seek(0)).isInstanceOf(IOException.class); + assertThatThrownBy(closed::available).isInstanceOf(IOException.class); + } +}