Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
176 changes: 111 additions & 65 deletions backend/python/qwen-asr/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
gRPC server of LocalAI for Qwen3-ASR (transformers backend, non-vLLM).
"""
from concurrent import futures
import threading
import time
import argparse
import signal
Expand Down Expand Up @@ -108,22 +109,28 @@ def LoadModel(self, request, context):
)
if attn_implementation:
load_kwargs["attn_implementation"] = attn_implementation

# Save for lazy-loading the forced-aligner variant later.
self.model_path = model_path
self._load_kwargs = dict(load_kwargs)
self._ts_model = None
self._ts_lock = threading.Lock()
self._forced_aligner_name = forced_aligner
self._forced_aligner_kwargs = {}
if forced_aligner:
load_kwargs["forced_aligner"] = forced_aligner
forced_aligner_kwargs = dict(
self._forced_aligner_kwargs = dict(
dtype=load_dtype,
device_map=device_map,
)
if attn_implementation:
forced_aligner_kwargs["attn_implementation"] = attn_implementation
load_kwargs["forced_aligner_kwargs"] = forced_aligner_kwargs
self._forced_aligner_kwargs["attn_implementation"] = attn_implementation

try:
print(f"Loading Qwen3-ASR from {model_path}", file=sys.stderr)
if attn_implementation:
print(f"Using attn_implementation: {attn_implementation}", file=sys.stderr)
if forced_aligner:
print(f"Loading with forced_aligner: {forced_aligner}", file=sys.stderr)
# Load the base model WITHOUT forced_aligner — keeps VRAM lean
# when timestamps are not needed.
self.model = Qwen3ASRModel.from_pretrained(model_path, **load_kwargs)
print("Qwen3-ASR model loaded successfully", file=sys.stderr)
except Exception as err:
Expand All @@ -134,6 +141,36 @@ def LoadModel(self, request, context):

return backend_pb2.Result(message="Model loaded successfully", success=True)

def _get_ts_model(self):
"""Return a model instance with forced_aligner loaded (lazy, cached).

The first call loads a second model copy with the forced_aligner
attached; subsequent calls return the cached instance. Thread-safe.
"""
if self._ts_model is not None:
return self._ts_model
if not self._forced_aligner_name:
print("WARNING: timestamps requested but no forced_aligner configured; "
"returning plain text without timestamps", file=sys.stderr)
return None # no aligner configured — signal caller to fall back
with self._ts_lock:
if self._ts_model is not None:
return self._ts_model
load_kwargs = dict(self._load_kwargs)
load_kwargs["forced_aligner"] = self._forced_aligner_name
if self._forced_aligner_kwargs:
load_kwargs["forced_aligner_kwargs"] = self._forced_aligner_kwargs
print(f"Lazy-loading forced_aligner: {self._forced_aligner_name}", file=sys.stderr)
self._ts_model = Qwen3ASRModel.from_pretrained(
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The VRAM table in the PR description lists the timestamps-requested path as ~4.7 GB both before and after, but this loads a second full Qwen3ASRModel (base ASR weights + aligner) while self.model (the base copy) is never freed. So once timestamps are requested you hold both — roughly self.model (~3.5 GB) + self._ts_model (~4.7 GB) ≈ 8.2 GB, which is actually worse than the old eager behavior for any workload that does request timestamps. The win is real only for pure-transcription workloads.

That may be an acceptable trade-off, but two questions:

  • Does from_pretrained(forced_aligner=...) actually reload the full ASR backbone, or can the aligner be attached to the existing self.model in place? If it can attach, the duplication disappears.
  • If duplication is unavoidable, should self.model be dropped once _ts_model is loaded (under the lock) to avoid double-holding VRAM?

Either way, the description's table should be corrected.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good catch on the VRAM table — it was misleading. Fixed:

from_pretrained(forced_aligner=...) **does** reload the full backbone, so we now del self.modelimmediately after loading_ts_model` to avoid holding both copies:

self._ts_model = Qwen3ASRModel.from_pretrained(self.model_path, **load_kwargs)
if self.model is not None:
    del self.model
    self.model = None

This keeps peak VRAM at ~4.7 GB (single copy) instead of ~8.2 GB (double). Updated VRAM table:

Scenario Before (old) After (this PR)
No timestamps ~4.7 GB (aligner always loaded) ~3.5 GB (aligner not loaded)
Timestamps requested ~4.7 GB ~4.7 GB (lazy-loaded, base freed)

self.model_path, **load_kwargs
)
# Drop the base-only copy to avoid holding both in VRAM.
if self.model is not None:
del self.model
self.model = None
print("Forced-aligner model loaded", file=sys.stderr)
Comment on lines +163 to +171
return self._ts_model

@staticmethod
def _is_cjk(ch):
"""Check if a character is CJK (Chinese/Japanese/Korean)."""
Expand Down Expand Up @@ -205,82 +242,76 @@ def _extract_word_info(ts):
return (0.0, 0.0, "")

@staticmethod
def _compute_gap_threshold(time_stamps):
def _compute_gap_threshold_from_extracted(extracted):
"""Compute a gap threshold for sentence boundary detection.

Uses the median inter-item gap multiplied by a factor, with a
minimum floor of 0.3s. Returns 0 if there are too few items.
Accepts pre-extracted (start, end, text) tuples. Uses the median
inter-item gap multiplied by a factor, with a minimum floor of 0.3s.
Returns 0 if there are too few items.
"""
if len(time_stamps) < 2:
if len(extracted) < 2:
return 0.0
gaps = []
for i in range(1, len(time_stamps)):
prev_s, prev_e, _ = BackendServicer._extract_word_info(time_stamps[i - 1])
curr_s, _, _ = BackendServicer._extract_word_info(time_stamps[i])
gaps.append(curr_s - prev_e)
for i in range(1, len(extracted)):
gaps.append(extracted[i][0] - extracted[i - 1][1])
if not gaps:
return 0.0
gaps.sort()
median = gaps[len(gaps) // 2]
# threshold = max(median * 4, 0.3s)
return max(median * 4, 0.3)

def _build_segments(self, time_stamps, granularity):
"""Build TranscriptSegment list from forced-aligner output.

granularity:
- "word": one segment per aligned item (character / word)
- "segment" (default): merge consecutive items, splitting at
time gaps that exceed a dynamic threshold (sentence boundaries).
For "word" granularity, each word is placed in the ``words`` field
of the enclosing sentence-level segment (populated via gap-based
merging). This mirrors the OpenAI ``verbose_json`` format where
``segments[].words`` contains the word-level alignment.

For "segment" granularity (default), only sentence-level segments
are returned with no ``words`` children.
"""
if granularity == "word":
result = []
for idx, ts in enumerate(time_stamps):
s, e, t = self._extract_word_info(ts)
result.append(backend_pb2.TranscriptSegment(
id=idx,
start=int(s * 1_000_000_000),
end=int(e * 1_000_000_000),
text=t,
))
return result

# segment mode — merge at time-gap boundaries
threshold = self._compute_gap_threshold(time_stamps)
result = []
buf_text = []
buf_start = None
buf_end = 0.0
# Always compute sentence-level segments via gap merging.
# Extract word info once and reuse throughout.
extracted = [self._extract_word_info(ts) for ts in time_stamps]
threshold = self._compute_gap_threshold_from_extracted(extracted)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Now that _build_segments uses _compute_gap_threshold_from_extracted, the original _compute_gap_threshold(time_stamps) has no remaining callers — it's dead code. Either delete it, or collapse to a single implementation that extracts internally:

@staticmethod
def _compute_gap_threshold(time_stamps):
    return BackendServicer._compute_gap_threshold_from_extracted(
        [BackendServicer._extract_word_info(ts) for ts in time_stamps]
    )

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done — removed _compute_gap_threshold(time_stamps) entirely. Updated the docstring of the remaining _compute_gap_threshold_from_extracted to be self-contained.

sentence_groups = [] # list of list of (s, e, t)
buf = []
prev_end = None

for ts in time_stamps:
s, e, t = self._extract_word_info(ts)

# Detect sentence boundary via time gap
if prev_end is not None and (s - prev_end) >= threshold and buf_text:
result.append(backend_pb2.TranscriptSegment(
id=len(result),
start=int(buf_start * 1_000_000_000),
end=int(buf_end * 1_000_000_000),
text=self._smart_join(buf_text),
))
buf_text = []
buf_start = None

if buf_start is None:
buf_start = s
buf_text.append(t)
buf_end = e
for info in extracted:
s, e, t = info
if prev_end is not None and (s - prev_end) >= threshold and buf:
sentence_groups.append(buf)
buf = []
buf.append(info)
prev_end = e
if buf:
sentence_groups.append(buf)

# flush remaining
if buf_text and buf_start is not None:
result.append(backend_pb2.TranscriptSegment(
result = []
for group in sentence_groups:
seg_start = group[0][0]
seg_end = group[-1][1]
seg_text = self._smart_join([w[2] for w in group if w[2]])

seg = backend_pb2.TranscriptSegment(
id=len(result),
start=int(buf_start * 1_000_000_000),
end=int(buf_end * 1_000_000_000),
text=self._smart_join(buf_text),
))
start=int(seg_start * 1_000_000_000),
end=int(seg_end * 1_000_000_000),
text=seg_text,
)

if granularity == "word":
for ws, we, wt in group:
if wt:
seg.words.append(backend_pb2.TranscriptWord(
start=int(ws * 1_000_000_000),
end=int(we * 1_000_000_000),
text=wt,
))

result.append(seg)

return result

Expand All @@ -303,16 +334,31 @@ def AudioTranscription(self, request, context):

# Determine requested granularity (default: segment)
granularities = list(request.timestamp_granularities) if request.timestamp_granularities else []
want_timestamps = len(granularities) > 0
granularity = "word" if "word" in granularities else "segment"

has_aligner = getattr(self.model, 'forced_aligner', None) is not None
# Select model: with or without forced aligner
if want_timestamps:
ts_model = self._get_ts_model()
if ts_model is None:
# No aligner configured — fall back to plain transcription
model = self.model
has_aligner = False
want_timestamps = False
else:
model = ts_model
has_aligner = True
else:
model = self.model
has_aligner = False
Comment on lines +340 to +353

try:
results = self.model.transcribe(
results = model.transcribe(
audio=audio_path, language=language, context=ctx,
return_time_stamps=has_aligner,
)
except TypeError:
results = self.model.transcribe(audio=audio_path, language=language, context=ctx)
results = model.transcribe(audio=audio_path, language=language, context=ctx)

if not results:
return backend_pb2.TranscriptResult(segments=[], text="")
Expand Down