diff --git a/lib/openai/internal/util.rb b/lib/openai/internal/util.rb index f463559f..d028d8a4 100644 --- a/lib/openai/internal/util.rb +++ b/lib/openai/internal/util.rb @@ -428,11 +428,11 @@ def close @stream.to_a.join in Integer @buf << @stream.next while @buf.length < max_len - @buf.slice!(..max_len) + @buf.slice!(0, max_len) end rescue StopIteration @stream = nil - @buf.slice!(0..) + @buf.empty? ? nil : @buf.slice!(0..) end # @api private @@ -442,21 +442,23 @@ def close # # @return [String, nil] def read(max_len = nil, out_string = nil) - case @stream - in nil - nil - in IO | StringIO - @stream.read(max_len, out_string) - in Enumerator - read = read_enum(max_len) - case out_string - in String - out_string.replace(read) + read = + case @stream in nil - read + nil + in IO | StringIO + return @stream.read(max_len, out_string).tap(&@blk) + in Enumerator + read_enum(max_len) end - end - .tap(&@blk) + + case out_string + in String + out_string.replace(read || "") + read.nil? ? nil : out_string + in nil + read + end.tap(&@blk) end # @api private diff --git a/test/openai/internal/util_test.rb b/test/openai/internal/util_test.rb index 8566246b..9f52ee4d 100644 --- a/test/openai/internal/util_test.rb +++ b/test/openai/internal/util_test.rb @@ -329,6 +329,58 @@ def test_copy_read end end + def test_read_respects_max_len_for_enumerator + input = + Enumerator.new do |y| + y << "ab" + y << "cd" + y << "ef" + end + + # rubocop:disable Lint/EmptyBlock + adapter = OpenAI::Internal::Util::ReadIOAdapter.new(input) {} + # rubocop:enable Lint/EmptyBlock + + assert_equal("abc", adapter.read(3)) + assert_equal("def", adapter.read(3)) + assert_nil(adapter.read(3)) + end + + def test_read_clears_out_string_at_eof_for_enumerator + input = + Enumerator.new do |y| + y << "ab" + y << "cd" + y << "ef" + end + + # rubocop:disable Lint/EmptyBlock + adapter = OpenAI::Internal::Util::ReadIOAdapter.new(input) {} + # rubocop:enable Lint/EmptyBlock + out = +"seed" + + assert_same(out, adapter.read(3, out)) + assert_equal("abc", out) + assert_same(out, adapter.read(3, out)) + assert_equal("def", out) + assert_nil(adapter.read(3, out)) + assert_equal("", out) + end + + def test_copy_read_for_enumerator_exactly_on_copy_stream_boundary + body = "a" * 16_384 + input = Enumerator.new { _1 << body } + io = StringIO.new + + # rubocop:disable Lint/EmptyBlock + adapter = OpenAI::Internal::Util::ReadIOAdapter.new(input) {} + # rubocop:enable Lint/EmptyBlock + + IO.copy_stream(adapter, io) + + assert_equal(body, io.string) + end + def test_copy_write cases = { StringIO.new => "",