|
1 | 1 | import os |
2 | 2 | from dataclasses import dataclass |
3 | 3 | from pathlib import Path |
4 | | -from typing import Any, Dict, Iterable, List |
| 4 | +from typing import Any, Iterable, List |
5 | 5 |
|
6 | 6 | import click |
7 | 7 | import structlog |
|
10 | 10 | from IPython.core.magic_arguments import argument, magic_arguments |
11 | 11 | from IPython.utils.process import arg_split |
12 | 12 | from rich import print as rprint |
13 | | -from rich.progress import Progress, TaskID |
| 13 | +from tqdm.auto import tqdm |
14 | 14 | from traitlets import Float, Unicode |
15 | 15 | from traitlets.config import Configurable |
16 | 16 |
|
@@ -142,29 +142,27 @@ def process_file_update_stream(path: str, stream: DatasetOperationStream): |
142 | 142 | error_message = None |
143 | 143 | complete_message = None |
144 | 144 |
|
145 | | - with Progress() as progress: |
146 | | - tasks_by_file_path: Dict[str, TaskID] = {} |
| 145 | + progress_bars = {} |
147 | 146 |
|
| 147 | + try: |
148 | 148 | for msg in stream: |
149 | 149 | if isinstance(msg, StreamErrorMessage): |
150 | 150 | error_message = msg.content.detail |
151 | 151 | break |
152 | 152 | elif isinstance(msg, FileProgressUpdateMessage): |
153 | 153 | got_file_update_msg = True |
154 | 154 |
|
155 | | - if msg.content.file_name not in tasks_by_file_path: |
156 | | - tasks_by_file_path[msg.content.file_name] = progress.add_task( |
157 | | - msg.content.file_name, total=100.0 |
158 | | - ) |
| 155 | + if msg.content.file_name not in progress_bars: |
| 156 | + progress_bars[msg.content.file_name] = tqdm(total=100.0, desc=msg.content.file_name) |
159 | 157 |
|
160 | | - progress.update( |
161 | | - tasks_by_file_path[msg.content.file_name], |
162 | | - completed=msg.content.percent_complete * 100.0, |
163 | | - ) |
| 158 | + progress_bars[msg.content.file_name].update(msg.content.percent_complete * 100.0) |
164 | 159 | elif isinstance(msg, FileProgressStartMessage): |
165 | | - progress.console.print(msg.content.message) |
| 160 | + print(msg.content.message) |
166 | 161 | elif isinstance(msg, FileProgressEndMessage) and got_file_update_msg: |
167 | 162 | complete_message = msg.content.message |
| 163 | + finally: |
| 164 | + for bar in progress_bars.values(): |
| 165 | + bar.close() |
168 | 166 |
|
169 | 167 | if error_message: |
170 | 168 | rprint(f"[red]{error_message}[/red]") |
|
0 commit comments