diff --git a/backend/python/qwen-asr/backend.py b/backend/python/qwen-asr/backend.py index 196f8f439fb4..106284154918 100644 --- a/backend/python/qwen-asr/backend.py +++ b/backend/python/qwen-asr/backend.py @@ -3,6 +3,7 @@ gRPC server of LocalAI for Qwen3-ASR (transformers backend, non-vLLM). """ from concurrent import futures +import threading import time import argparse import signal @@ -108,22 +109,28 @@ def LoadModel(self, request, context): ) if attn_implementation: load_kwargs["attn_implementation"] = attn_implementation + + # Save for lazy-loading the forced-aligner variant later. + self.model_path = model_path + self._load_kwargs = dict(load_kwargs) + self._ts_model = None + self._ts_lock = threading.Lock() + self._forced_aligner_name = forced_aligner + self._forced_aligner_kwargs = {} if forced_aligner: - load_kwargs["forced_aligner"] = forced_aligner - forced_aligner_kwargs = dict( + self._forced_aligner_kwargs = dict( dtype=load_dtype, device_map=device_map, ) if attn_implementation: - forced_aligner_kwargs["attn_implementation"] = attn_implementation - load_kwargs["forced_aligner_kwargs"] = forced_aligner_kwargs + self._forced_aligner_kwargs["attn_implementation"] = attn_implementation try: print(f"Loading Qwen3-ASR from {model_path}", file=sys.stderr) if attn_implementation: print(f"Using attn_implementation: {attn_implementation}", file=sys.stderr) - if forced_aligner: - print(f"Loading with forced_aligner: {forced_aligner}", file=sys.stderr) + # Load the base model WITHOUT forced_aligner — keeps VRAM lean + # when timestamps are not needed. self.model = Qwen3ASRModel.from_pretrained(model_path, **load_kwargs) print("Qwen3-ASR model loaded successfully", file=sys.stderr) except Exception as err: @@ -134,6 +141,36 @@ def LoadModel(self, request, context): return backend_pb2.Result(message="Model loaded successfully", success=True) + def _get_ts_model(self): + """Return a model instance with forced_aligner loaded (lazy, cached). + + The first call loads a second model copy with the forced_aligner + attached; subsequent calls return the cached instance. Thread-safe. + """ + if self._ts_model is not None: + return self._ts_model + if not self._forced_aligner_name: + print("WARNING: timestamps requested but no forced_aligner configured; " + "returning plain text without timestamps", file=sys.stderr) + return None # no aligner configured — signal caller to fall back + with self._ts_lock: + if self._ts_model is not None: + return self._ts_model + load_kwargs = dict(self._load_kwargs) + load_kwargs["forced_aligner"] = self._forced_aligner_name + if self._forced_aligner_kwargs: + load_kwargs["forced_aligner_kwargs"] = self._forced_aligner_kwargs + print(f"Lazy-loading forced_aligner: {self._forced_aligner_name}", file=sys.stderr) + self._ts_model = Qwen3ASRModel.from_pretrained( + self.model_path, **load_kwargs + ) + # Drop the base-only copy to avoid holding both in VRAM. + if self.model is not None: + del self.model + self.model = None + print("Forced-aligner model loaded", file=sys.stderr) + return self._ts_model + @staticmethod def _is_cjk(ch): """Check if a character is CJK (Chinese/Japanese/Korean).""" @@ -205,82 +242,76 @@ def _extract_word_info(ts): return (0.0, 0.0, "") @staticmethod - def _compute_gap_threshold(time_stamps): + def _compute_gap_threshold_from_extracted(extracted): """Compute a gap threshold for sentence boundary detection. - Uses the median inter-item gap multiplied by a factor, with a - minimum floor of 0.3s. Returns 0 if there are too few items. + Accepts pre-extracted (start, end, text) tuples. Uses the median + inter-item gap multiplied by a factor, with a minimum floor of 0.3s. + Returns 0 if there are too few items. """ - if len(time_stamps) < 2: + if len(extracted) < 2: return 0.0 gaps = [] - for i in range(1, len(time_stamps)): - prev_s, prev_e, _ = BackendServicer._extract_word_info(time_stamps[i - 1]) - curr_s, _, _ = BackendServicer._extract_word_info(time_stamps[i]) - gaps.append(curr_s - prev_e) + for i in range(1, len(extracted)): + gaps.append(extracted[i][0] - extracted[i - 1][1]) if not gaps: return 0.0 gaps.sort() median = gaps[len(gaps) // 2] - # threshold = max(median * 4, 0.3s) return max(median * 4, 0.3) def _build_segments(self, time_stamps, granularity): """Build TranscriptSegment list from forced-aligner output. - granularity: - - "word": one segment per aligned item (character / word) - - "segment" (default): merge consecutive items, splitting at - time gaps that exceed a dynamic threshold (sentence boundaries). + For "word" granularity, each word is placed in the ``words`` field + of the enclosing sentence-level segment (populated via gap-based + merging). This mirrors the OpenAI ``verbose_json`` format where + ``segments[].words`` contains the word-level alignment. + + For "segment" granularity (default), only sentence-level segments + are returned with no ``words`` children. """ - if granularity == "word": - result = [] - for idx, ts in enumerate(time_stamps): - s, e, t = self._extract_word_info(ts) - result.append(backend_pb2.TranscriptSegment( - id=idx, - start=int(s * 1_000_000_000), - end=int(e * 1_000_000_000), - text=t, - )) - return result - - # segment mode — merge at time-gap boundaries - threshold = self._compute_gap_threshold(time_stamps) - result = [] - buf_text = [] - buf_start = None - buf_end = 0.0 + # Always compute sentence-level segments via gap merging. + # Extract word info once and reuse throughout. + extracted = [self._extract_word_info(ts) for ts in time_stamps] + threshold = self._compute_gap_threshold_from_extracted(extracted) + sentence_groups = [] # list of list of (s, e, t) + buf = [] prev_end = None - for ts in time_stamps: - s, e, t = self._extract_word_info(ts) - - # Detect sentence boundary via time gap - if prev_end is not None and (s - prev_end) >= threshold and buf_text: - result.append(backend_pb2.TranscriptSegment( - id=len(result), - start=int(buf_start * 1_000_000_000), - end=int(buf_end * 1_000_000_000), - text=self._smart_join(buf_text), - )) - buf_text = [] - buf_start = None - - if buf_start is None: - buf_start = s - buf_text.append(t) - buf_end = e + for info in extracted: + s, e, t = info + if prev_end is not None and (s - prev_end) >= threshold and buf: + sentence_groups.append(buf) + buf = [] + buf.append(info) prev_end = e + if buf: + sentence_groups.append(buf) - # flush remaining - if buf_text and buf_start is not None: - result.append(backend_pb2.TranscriptSegment( + result = [] + for group in sentence_groups: + seg_start = group[0][0] + seg_end = group[-1][1] + seg_text = self._smart_join([w[2] for w in group if w[2]]) + + seg = backend_pb2.TranscriptSegment( id=len(result), - start=int(buf_start * 1_000_000_000), - end=int(buf_end * 1_000_000_000), - text=self._smart_join(buf_text), - )) + start=int(seg_start * 1_000_000_000), + end=int(seg_end * 1_000_000_000), + text=seg_text, + ) + + if granularity == "word": + for ws, we, wt in group: + if wt: + seg.words.append(backend_pb2.TranscriptWord( + start=int(ws * 1_000_000_000), + end=int(we * 1_000_000_000), + text=wt, + )) + + result.append(seg) return result @@ -303,16 +334,31 @@ def AudioTranscription(self, request, context): # Determine requested granularity (default: segment) granularities = list(request.timestamp_granularities) if request.timestamp_granularities else [] + want_timestamps = len(granularities) > 0 granularity = "word" if "word" in granularities else "segment" - has_aligner = getattr(self.model, 'forced_aligner', None) is not None + # Select model: with or without forced aligner + if want_timestamps: + ts_model = self._get_ts_model() + if ts_model is None: + # No aligner configured — fall back to plain transcription + model = self.model + has_aligner = False + want_timestamps = False + else: + model = ts_model + has_aligner = True + else: + model = self.model + has_aligner = False + try: - results = self.model.transcribe( + results = model.transcribe( audio=audio_path, language=language, context=ctx, return_time_stamps=has_aligner, ) except TypeError: - results = self.model.transcribe(audio=audio_path, language=language, context=ctx) + results = model.transcribe(audio=audio_path, language=language, context=ctx) if not results: return backend_pb2.TranscriptResult(segments=[], text="")