-
-
Notifications
You must be signed in to change notification settings - Fork 4.1k
fix(qwen-asr): lazy-load forced_aligner and populate word-level timestamps #10054
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
d37e18d
29ed84e
39933e8
c404785
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -3,6 +3,7 @@ | |
| gRPC server of LocalAI for Qwen3-ASR (transformers backend, non-vLLM). | ||
| """ | ||
| from concurrent import futures | ||
| import threading | ||
| import time | ||
| import argparse | ||
| import signal | ||
|
|
@@ -108,22 +109,28 @@ def LoadModel(self, request, context): | |
| ) | ||
| if attn_implementation: | ||
| load_kwargs["attn_implementation"] = attn_implementation | ||
|
|
||
| # Save for lazy-loading the forced-aligner variant later. | ||
| self.model_path = model_path | ||
| self._load_kwargs = dict(load_kwargs) | ||
| self._ts_model = None | ||
| self._ts_lock = threading.Lock() | ||
| self._forced_aligner_name = forced_aligner | ||
| self._forced_aligner_kwargs = {} | ||
| if forced_aligner: | ||
| load_kwargs["forced_aligner"] = forced_aligner | ||
| forced_aligner_kwargs = dict( | ||
| self._forced_aligner_kwargs = dict( | ||
| dtype=load_dtype, | ||
| device_map=device_map, | ||
| ) | ||
| if attn_implementation: | ||
| forced_aligner_kwargs["attn_implementation"] = attn_implementation | ||
| load_kwargs["forced_aligner_kwargs"] = forced_aligner_kwargs | ||
| self._forced_aligner_kwargs["attn_implementation"] = attn_implementation | ||
|
|
||
| try: | ||
| print(f"Loading Qwen3-ASR from {model_path}", file=sys.stderr) | ||
| if attn_implementation: | ||
| print(f"Using attn_implementation: {attn_implementation}", file=sys.stderr) | ||
| if forced_aligner: | ||
| print(f"Loading with forced_aligner: {forced_aligner}", file=sys.stderr) | ||
| # Load the base model WITHOUT forced_aligner — keeps VRAM lean | ||
| # when timestamps are not needed. | ||
| self.model = Qwen3ASRModel.from_pretrained(model_path, **load_kwargs) | ||
| print("Qwen3-ASR model loaded successfully", file=sys.stderr) | ||
| except Exception as err: | ||
|
|
@@ -134,6 +141,36 @@ def LoadModel(self, request, context): | |
|
|
||
| return backend_pb2.Result(message="Model loaded successfully", success=True) | ||
|
|
||
| def _get_ts_model(self): | ||
| """Return a model instance with forced_aligner loaded (lazy, cached). | ||
|
|
||
| The first call loads a second model copy with the forced_aligner | ||
| attached; subsequent calls return the cached instance. Thread-safe. | ||
| """ | ||
| if self._ts_model is not None: | ||
| return self._ts_model | ||
| if not self._forced_aligner_name: | ||
| print("WARNING: timestamps requested but no forced_aligner configured; " | ||
| "returning plain text without timestamps", file=sys.stderr) | ||
| return None # no aligner configured — signal caller to fall back | ||
| with self._ts_lock: | ||
| if self._ts_model is not None: | ||
| return self._ts_model | ||
| load_kwargs = dict(self._load_kwargs) | ||
| load_kwargs["forced_aligner"] = self._forced_aligner_name | ||
| if self._forced_aligner_kwargs: | ||
| load_kwargs["forced_aligner_kwargs"] = self._forced_aligner_kwargs | ||
| print(f"Lazy-loading forced_aligner: {self._forced_aligner_name}", file=sys.stderr) | ||
| self._ts_model = Qwen3ASRModel.from_pretrained( | ||
| self.model_path, **load_kwargs | ||
| ) | ||
| # Drop the base-only copy to avoid holding both in VRAM. | ||
| if self.model is not None: | ||
| del self.model | ||
| self.model = None | ||
| print("Forced-aligner model loaded", file=sys.stderr) | ||
|
Comment on lines
+163
to
+171
|
||
| return self._ts_model | ||
|
|
||
| @staticmethod | ||
| def _is_cjk(ch): | ||
| """Check if a character is CJK (Chinese/Japanese/Korean).""" | ||
|
|
@@ -205,82 +242,76 @@ def _extract_word_info(ts): | |
| return (0.0, 0.0, "") | ||
|
|
||
| @staticmethod | ||
| def _compute_gap_threshold(time_stamps): | ||
| def _compute_gap_threshold_from_extracted(extracted): | ||
| """Compute a gap threshold for sentence boundary detection. | ||
|
|
||
| Uses the median inter-item gap multiplied by a factor, with a | ||
| minimum floor of 0.3s. Returns 0 if there are too few items. | ||
| Accepts pre-extracted (start, end, text) tuples. Uses the median | ||
| inter-item gap multiplied by a factor, with a minimum floor of 0.3s. | ||
| Returns 0 if there are too few items. | ||
| """ | ||
| if len(time_stamps) < 2: | ||
| if len(extracted) < 2: | ||
| return 0.0 | ||
| gaps = [] | ||
| for i in range(1, len(time_stamps)): | ||
| prev_s, prev_e, _ = BackendServicer._extract_word_info(time_stamps[i - 1]) | ||
| curr_s, _, _ = BackendServicer._extract_word_info(time_stamps[i]) | ||
| gaps.append(curr_s - prev_e) | ||
| for i in range(1, len(extracted)): | ||
| gaps.append(extracted[i][0] - extracted[i - 1][1]) | ||
| if not gaps: | ||
| return 0.0 | ||
| gaps.sort() | ||
| median = gaps[len(gaps) // 2] | ||
| # threshold = max(median * 4, 0.3s) | ||
| return max(median * 4, 0.3) | ||
|
|
||
| def _build_segments(self, time_stamps, granularity): | ||
| """Build TranscriptSegment list from forced-aligner output. | ||
|
|
||
| granularity: | ||
| - "word": one segment per aligned item (character / word) | ||
| - "segment" (default): merge consecutive items, splitting at | ||
| time gaps that exceed a dynamic threshold (sentence boundaries). | ||
| For "word" granularity, each word is placed in the ``words`` field | ||
| of the enclosing sentence-level segment (populated via gap-based | ||
| merging). This mirrors the OpenAI ``verbose_json`` format where | ||
| ``segments[].words`` contains the word-level alignment. | ||
|
|
||
| For "segment" granularity (default), only sentence-level segments | ||
| are returned with no ``words`` children. | ||
| """ | ||
| if granularity == "word": | ||
| result = [] | ||
| for idx, ts in enumerate(time_stamps): | ||
| s, e, t = self._extract_word_info(ts) | ||
| result.append(backend_pb2.TranscriptSegment( | ||
| id=idx, | ||
| start=int(s * 1_000_000_000), | ||
| end=int(e * 1_000_000_000), | ||
| text=t, | ||
| )) | ||
| return result | ||
|
|
||
| # segment mode — merge at time-gap boundaries | ||
| threshold = self._compute_gap_threshold(time_stamps) | ||
| result = [] | ||
| buf_text = [] | ||
| buf_start = None | ||
| buf_end = 0.0 | ||
| # Always compute sentence-level segments via gap merging. | ||
| # Extract word info once and reuse throughout. | ||
| extracted = [self._extract_word_info(ts) for ts in time_stamps] | ||
| threshold = self._compute_gap_threshold_from_extracted(extracted) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Now that @staticmethod
def _compute_gap_threshold(time_stamps):
return BackendServicer._compute_gap_threshold_from_extracted(
[BackendServicer._extract_word_info(ts) for ts in time_stamps]
)
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Done — removed |
||
| sentence_groups = [] # list of list of (s, e, t) | ||
| buf = [] | ||
| prev_end = None | ||
|
|
||
| for ts in time_stamps: | ||
| s, e, t = self._extract_word_info(ts) | ||
|
|
||
| # Detect sentence boundary via time gap | ||
| if prev_end is not None and (s - prev_end) >= threshold and buf_text: | ||
| result.append(backend_pb2.TranscriptSegment( | ||
| id=len(result), | ||
| start=int(buf_start * 1_000_000_000), | ||
| end=int(buf_end * 1_000_000_000), | ||
| text=self._smart_join(buf_text), | ||
| )) | ||
| buf_text = [] | ||
| buf_start = None | ||
|
|
||
| if buf_start is None: | ||
| buf_start = s | ||
| buf_text.append(t) | ||
| buf_end = e | ||
| for info in extracted: | ||
| s, e, t = info | ||
| if prev_end is not None and (s - prev_end) >= threshold and buf: | ||
| sentence_groups.append(buf) | ||
| buf = [] | ||
| buf.append(info) | ||
| prev_end = e | ||
| if buf: | ||
| sentence_groups.append(buf) | ||
|
|
||
| # flush remaining | ||
| if buf_text and buf_start is not None: | ||
| result.append(backend_pb2.TranscriptSegment( | ||
| result = [] | ||
| for group in sentence_groups: | ||
| seg_start = group[0][0] | ||
| seg_end = group[-1][1] | ||
| seg_text = self._smart_join([w[2] for w in group if w[2]]) | ||
|
|
||
| seg = backend_pb2.TranscriptSegment( | ||
| id=len(result), | ||
| start=int(buf_start * 1_000_000_000), | ||
| end=int(buf_end * 1_000_000_000), | ||
| text=self._smart_join(buf_text), | ||
| )) | ||
| start=int(seg_start * 1_000_000_000), | ||
| end=int(seg_end * 1_000_000_000), | ||
| text=seg_text, | ||
| ) | ||
|
|
||
| if granularity == "word": | ||
| for ws, we, wt in group: | ||
| if wt: | ||
| seg.words.append(backend_pb2.TranscriptWord( | ||
| start=int(ws * 1_000_000_000), | ||
| end=int(we * 1_000_000_000), | ||
| text=wt, | ||
| )) | ||
|
|
||
| result.append(seg) | ||
|
|
||
| return result | ||
|
|
||
|
|
@@ -303,16 +334,31 @@ def AudioTranscription(self, request, context): | |
|
|
||
| # Determine requested granularity (default: segment) | ||
| granularities = list(request.timestamp_granularities) if request.timestamp_granularities else [] | ||
| want_timestamps = len(granularities) > 0 | ||
| granularity = "word" if "word" in granularities else "segment" | ||
|
|
||
| has_aligner = getattr(self.model, 'forced_aligner', None) is not None | ||
| # Select model: with or without forced aligner | ||
| if want_timestamps: | ||
| ts_model = self._get_ts_model() | ||
| if ts_model is None: | ||
| # No aligner configured — fall back to plain transcription | ||
| model = self.model | ||
| has_aligner = False | ||
| want_timestamps = False | ||
| else: | ||
| model = ts_model | ||
| has_aligner = True | ||
| else: | ||
| model = self.model | ||
| has_aligner = False | ||
|
Comment on lines
+340
to
+353
|
||
|
|
||
| try: | ||
| results = self.model.transcribe( | ||
| results = model.transcribe( | ||
| audio=audio_path, language=language, context=ctx, | ||
| return_time_stamps=has_aligner, | ||
| ) | ||
| except TypeError: | ||
| results = self.model.transcribe(audio=audio_path, language=language, context=ctx) | ||
| results = model.transcribe(audio=audio_path, language=language, context=ctx) | ||
|
|
||
| if not results: | ||
| return backend_pb2.TranscriptResult(segments=[], text="") | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The VRAM table in the PR description lists the timestamps-requested path as ~4.7 GB both before and after, but this loads a second full `Qwen3ASRModel` (base ASR weights + aligner) while `self.model` (the base copy) is never freed. So once timestamps are requested you hold both — roughly `self.model` (~3.5 GB) + `self._ts_model` (~4.7 GB) ≈ 8.2 GB, which is actually worse than the old eager behavior for any workload that does request timestamps. The win is real only for pure-transcription workloads.

That may be an acceptable trade-off, but two questions:

1. Does `from_pretrained(forced_aligner=...)` actually reload the full ASR backbone, or can the aligner be attached to the existing `self.model` in place? If it can attach, the duplication disappears.
2. Should `self.model` be dropped once `_ts_model` is loaded (under the lock) to avoid double-holding VRAM?

Either way, the description's table should be corrected.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Good catch on the VRAM table — it was misleading. Fixed:
`from_pretrained(forced_aligner=...)` **does** reload the full backbone, so we now `del self.model` immediately after loading `_ts_model` to avoid holding both copies.

This keeps peak VRAM at ~4.7 GB (single copy) instead of ~8.2 GB (double). Updated VRAM table: