|
1 | 1 | import anyio |
2 | | -from typing import AsyncIterator, Optional |
| 2 | +from typing import AsyncIterator, Optional, Tuple |
3 | 3 | from fastapi import WebSocketDisconnect |
4 | 4 |
|
5 | 5 | from lib.store import Store |
@@ -114,6 +114,123 @@ async def _wait_for_reconnection(self, peer_type: str) -> bool: |
114 | 114 | self.warning(f"◆ {peer_type.capitalize()} did not reconnect in time") |
115 | 115 | return False |
116 | 116 |
|
async def _get_next_chunk(
    self, last_chunk_id: str, is_range_request: bool
) -> Optional[Tuple[str, Optional[bytes]]]:
    """Fetch the chunk that follows ``last_chunk_id`` from the store.

    Range requests read through ``store.get_chunk_by_range``; all other
    requests read through ``store.get_next_chunk`` using ``STREAM_TIMEOUT``.

    Args:
        last_chunk_id: ID of the last chunk already consumed.
        is_range_request: Selects which store read is used.

    Returns:
        * ``None`` when a range read found nothing and
          ``_should_wait_for_sender`` says to stop.
        * ``('wait', None)`` when a range read found nothing but the caller
          should poll again.
        * Otherwise the store's result, which the caller unpacks as
          ``(chunk_id, chunk_data)``.
    """
    if not is_range_request:
        # Keyword arguments keep this call independent of the store's
        # parameter order.
        return await self.store.get_next_chunk(
            timeout=self.STREAM_TIMEOUT,
            last_id=last_chunk_id,
        )

    result = await self.store.get_chunk_by_range(last_chunk_id)
    if result:
        return result

    # Nothing readable yet: either give up or tell the caller to retry.
    if await self._should_wait_for_sender():
        return ('wait', None)
    return None
| 128 | + |
async def _should_wait_for_sender(self) -> bool:
    """Decide whether the receiver should keep polling for sender data.

    Returns:
        ``False`` when the sender state is ``COMPLETE``, or when it is
        ``DISCONNECTED`` and ``_wait_for_reconnection("sender")`` fails (the
        receiver state is set to ``ERROR`` in that case); ``True`` otherwise.
    """
    state = await self.store.get_sender_state()

    if state == ClientState.COMPLETE:
        return False

    if state == ClientState.DISCONNECTED and not await self._wait_for_reconnection("sender"):
        await self.store.set_receiver_state(ClientState.ERROR)
        return False

    return True
| 139 | + |
| 140 | + def _adjust_chunk_for_range(self, chunk_data: bytes, stream_position: int, |
| 141 | + start_byte: int, bytes_sent: int, bytes_to_send: int) -> Tuple[Optional[bytes], int]: |
| 142 | + """Adjust chunk data for byte range. Returns (data_to_send, new_stream_position).""" |
| 143 | + new_position = stream_position |
| 144 | + |
| 145 | + # Skip bytes before start_byte |
| 146 | + if stream_position < start_byte: |
| 147 | + skip = min(len(chunk_data), start_byte - stream_position) |
| 148 | + chunk_data = chunk_data[skip:] |
| 149 | + new_position += skip |
| 150 | + |
| 151 | + # Still haven't reached start? Skip entire chunk |
| 152 | + if new_position < start_byte: |
| 153 | + new_position += len(chunk_data) |
| 154 | + return None, new_position |
| 155 | + |
| 156 | + # Trim to remaining bytes needed |
| 157 | + if chunk_data and bytes_sent + len(chunk_data) > bytes_to_send: |
| 158 | + chunk_data = chunk_data[:bytes_to_send - bytes_sent] |
| 159 | + |
| 160 | + return chunk_data if chunk_data else None, new_position |
| 161 | + |
async def _save_progress_if_needed(self, stream_position: int, last_chunk_id: str, force: bool = False):
    """Persist download progress when forced or on a 64 KiB boundary.

    Args:
        stream_position: Value stored as ``bytes_downloaded``.
        last_chunk_id: Value stored as ``last_read_id``.
        force: Save unconditionally and emit a debug log line.
    """
    # NOTE(review): the periodic save only fires when the position is an exact
    # multiple of 64 KiB; confirm chunk sizes make that happen often enough.
    on_boundary = stream_position % (64 * 1024) == 0
    if not (force or on_boundary):
        return

    await self.store.save_download_progress(
        bytes_downloaded=stream_position,
        last_read_id=last_chunk_id,
    )

    if force:
        self.debug(f"▼ Progress saved: {stream_position} bytes")
| 171 | + |
| 172 | + async def _initialize_download_state(self, start_byte: int, is_range_request: bool) -> Tuple[int, str]: |
| 173 | + """Initialize download state and return (stream_position, last_chunk_id).""" |
| 174 | + stream_position = 0 |
| 175 | + last_chunk_id = '0' |
| 176 | + |
| 177 | + if start_byte > 0: |
| 178 | + self.info(f"▼ Starting download from byte {start_byte}") |
| 179 | + if not is_range_request: |
| 180 | + progress = await self.store.get_download_progress() |
| 181 | + if progress and progress.bytes_downloaded >= start_byte: |
| 182 | + last_chunk_id = progress.last_read_id |
| 183 | + stream_position = progress.bytes_downloaded |
| 184 | + |
| 185 | + return stream_position, last_chunk_id |
| 186 | + |
async def _finalize_download_status(self, bytes_sent: int, stream_position: int,
                                    start_byte: int, end_byte: Optional[int],
                                    last_chunk_id: str):
    """Record the outcome once the streaming loop has ended.

    A request with an ``end_byte`` is only logged; no receiver state is
    changed. Otherwise the receiver is marked ``COMPLETE`` when
    ``start_byte + bytes_sent`` reaches ``self.file.size``; if it falls
    short, progress is force-saved instead.
    """
    if end_byte is not None:
        self.info(f"▼ Range download complete ({bytes_sent} bytes from {start_byte}-{end_byte})")
        return

    total_downloaded = start_byte + bytes_sent

    if total_downloaded < self.file.size:
        self.info(f"▼ Download incomplete ({total_downloaded}/{self.file.size} bytes)")
        await self._save_progress_if_needed(stream_position, last_chunk_id, force=True)
        return

    self.info("▼ Full download complete")
    await self.store.set_receiver_state(ClientState.COMPLETE)
| 202 | + |
async def _handle_download_disconnect(self, error: Exception, stream_position: int, last_chunk_id: str):
    """React to the receiver connection dropping mid-download.

    Logs the error, saves progress, marks the receiver ``DISCONNECTED`` and
    waits for it to reconnect. If it does not, the receiver is marked
    ``ERROR`` and ``set_interrupted`` is called.
    """
    self.warning(f"▼ Download disconnected: {error}")

    # Save the position first so a reconnecting receiver can resume from it.
    await self.store.save_download_progress(
        bytes_downloaded=stream_position,
        last_read_id=last_chunk_id,
    )
    await self.store.set_receiver_state(ClientState.DISCONNECTED)

    if await self._wait_for_reconnection("receiver"):
        return

    await self.store.set_receiver_state(ClientState.ERROR)
    await self.set_interrupted()
| 215 | + |
async def _handle_download_timeout(self, stream_position: int, last_chunk_id: str):
    """Decide what a read timeout means based on the sender's state.

    ``stream_position`` and ``last_chunk_id`` are accepted but not used here.

    Returns:
        ``True`` when the sender was disconnected and reconnected in time;
        ``False`` when it did not reconnect (the receiver is then marked
        ``ERROR``).

    Raises:
        TimeoutError: When the sender state is anything other than
            ``DISCONNECTED``.
    """
    self.info("▼ Timeout waiting for data")

    state = await self.store.get_sender_state()
    if state != ClientState.DISCONNECTED:
        raise TimeoutError("Download timeout")

    if await self._wait_for_reconnection("sender"):
        return True

    await self.store.set_receiver_state(ClientState.ERROR)
    return False
| 227 | + |
async def _handle_download_fatal_error(self, error: Exception):
    """Log an unexpected failure with traceback, mark the receiver ``ERROR``
    and call ``set_interrupted``."""
    message = f"▼ Unexpected download error: {error}"
    self.error(message, exc_info=True)

    await self.store.set_receiver_state(ClientState.ERROR)
    await self.set_interrupted()
| 233 | + |
117 | 234 | async def collect_upload(self, stream: AsyncIterator[bytes], resume_from: int = 0) -> None: |
118 | 235 | """Collect file data from sender and store in Redis stream.""" |
119 | 236 | bytes_uploaded = resume_from |
@@ -195,135 +312,61 @@ async def collect_upload(self, stream: AsyncIterator[bytes], resume_from: int = |
195 | 312 |
|
async def supply_download(self, start_byte: int = 0, end_byte: Optional[int] = None) -> AsyncIterator[bytes]:
    """Stream file data to the receiver.

    Args:
        start_byte: First byte offset to deliver.
        end_byte: Last byte offset to deliver (inclusive). ``None`` means
            deliver through ``self.file.size``. A non-``None`` value makes
            this a range request.

    Yields:
        Chunks of file data trimmed to the requested range.

    Read timeouts are handled per iteration: if the sender was disconnected
    and reconnects, streaming resumes. If the sender does not reconnect, the
    generator ends. Any other timeout is reported by the generic error
    handler.
    """
    is_range_request = end_byte is not None

    # Test against None rather than truthiness: end_byte == 0 (range "0-0")
    # is a valid one-byte range, not an open-ended download.
    if is_range_request:
        bytes_to_send = end_byte - start_byte + 1
    else:
        bytes_to_send = self.file.size - start_byte
    bytes_sent = 0

    stream_position, last_chunk_id = await self._initialize_download_state(start_byte, is_range_request)
    await self.store.set_receiver_state(ClientState.ACTIVE)

    end_label = end_byte if is_range_request else 'end'
    self.debug(f"▼ Range request: {start_byte}-{end_label}, to_send: {bytes_to_send}")

    try:
        while bytes_sent < bytes_to_send:
            # The timeout is handled inside the loop so that a successful
            # sender reconnect resumes streaming instead of ending it.
            try:
                result = await self._get_next_chunk(last_chunk_id, is_range_request)
            except TimeoutError:
                # A TimeoutError raised by the handler itself propagates to
                # the generic ``except Exception`` clause below.
                if not await self._handle_download_timeout(stream_position, last_chunk_id):
                    return
                continue

            if result is None:
                break
            if result[0] == 'wait':
                # No data yet for a range read; poll again shortly.
                await anyio.sleep(0.1)
                continue

            chunk_id, chunk_data = result
            last_chunk_id = chunk_id

            # Control markers end the stream.
            if chunk_data == self.DONE_FLAG:
                self.debug("▼ Done marker received")
                await self.store.set_receiver_state(ClientState.COMPLETE)
                break
            elif chunk_data == self.DEAD_FLAG:
                self.warning("▼ Dead marker received")
                await self.store.set_receiver_state(ClientState.ERROR)
                return

            # Trim the chunk to the requested byte range.
            chunk_to_send, stream_position = self._adjust_chunk_for_range(
                chunk_data, stream_position, start_byte, bytes_sent, bytes_to_send
            )

            if chunk_to_send:
                yield chunk_to_send
                bytes_sent += len(chunk_to_send)
                await self._save_progress_if_needed(stream_position, last_chunk_id)

        # Reached only when the loop ended without an early return.
        await self._finalize_download_status(
            bytes_sent, stream_position, start_byte, end_byte, last_chunk_id
        )

    except (ConnectionError, WebSocketDisconnect) as e:
        await self._handle_download_disconnect(e, stream_position, last_chunk_id)
    except Exception as e:
        await self._handle_download_fatal_error(e)
327 | 370 |
|
328 | 371 | async def finalize_download(self): |
329 | 372 | """Finalize download and potentially clean up.""" |
|
0 commit comments