Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 15 additions & 2 deletions backends/exllamav3/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -996,6 +996,7 @@ async def generate_gen(
max_rq_tokens=self.max_rq_tokens,
filters=grammar_handler.filters,
)
self.active_job_ids[request_id] = job

generated_tokens = 0
full_response = ""
Expand All @@ -1013,8 +1014,21 @@ async def generate_gen(
if chunk:
chunk_tokens = result.get("token_ids", self.tokenizer.encode(chunk))
full_response += chunk

# Extract token IDs as a plain list for downstream consumers
if isinstance(chunk_tokens, torch.Tensor):
token_id_list = chunk_tokens.flatten().tolist()
generated_tokens += chunk_tokens.size(dim=0)
elif isinstance(chunk_tokens, tuple):
first = chunk_tokens[0]
if isinstance(first, torch.Tensor):
token_id_list = first.flatten().tolist()
else:
token_id_list = list(first)
generated_tokens += len(token_id_list)
else:
token_id_list = list(chunk_tokens)
generated_tokens += len(token_id_list)

# Increase penalty range to generated token amount
# TODO:
Expand All @@ -1024,6 +1038,7 @@ async def generate_gen(
generation = {
"request_id": request_id,
"text": chunk,
"token_ids": token_id_list,
"prompt_tokens": context_len,
"generated_tokens": generated_tokens,
"offset": len(full_response),
Expand All @@ -1044,8 +1059,6 @@ async def generate_gen(

yield finish_chunk
break
# Assign the active job to the request ID
self.active_job_ids[request_id] = job

except asyncio.CancelledError:
await job.cancel()
Expand Down
39 changes: 39 additions & 0 deletions common/templating.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from jinja2.ext import loopcontrols
from jinja2.sandbox import ImmutableSandboxedEnvironment
from loguru import logger
from markupsafe import Markup
from packaging import version


Expand All @@ -24,12 +25,17 @@ class TemplateLoadError(Exception):
pass


VALID_TOOL_CALL_FORMATS = {"json", "xml", "auto"}


@dataclass
class TemplateMetadata:
"""Represents the parsed metadata from a template."""

stop_strings: List[str] = field(default_factory=list)
tool_start: Optional[str] = None
tool_end: Optional[str] = None
tool_call_format: str = "json"


class PromptTemplate:
Expand All @@ -46,6 +52,22 @@ class PromptTemplate:
)
metadata: Optional[TemplateMetadata] = None

@staticmethod
def _tojson_compat(value, indent=None, ensure_ascii=True):
    """Sandbox-friendly replacement for Jinja's ``tojson`` filter.

    Certain chat templates invoke ``tojson(ensure_ascii=False)``, a
    keyword that the stock Jinja filter may reject when running in the
    sandboxed environment, so this shim forwards the arguments straight
    to ``json.dumps`` and wraps the result as markup-safe text.
    """
    serialized = json.dumps(
        value,
        indent=indent,
        ensure_ascii=ensure_ascii,
        separators=(",", ": "),
    )
    # Markup prevents Jinja's autoescape from mangling the JSON output.
    return Markup(serialized)

async def extract_metadata(self, template_vars: dict):
"""
Returns deserialized template metadata from a chat template.
Expand Down Expand Up @@ -76,6 +98,22 @@ async def extract_metadata(self, template_vars: dict):
if isinstance(template_module.tool_start, str):
template_metadata.tool_start = template_module.tool_start

if hasattr(template_module, "tool_end"):
if isinstance(template_module.tool_end, str):
template_metadata.tool_end = template_module.tool_end

if hasattr(template_module, "tool_call_format"):
fmt = template_module.tool_call_format
if isinstance(fmt, str) and fmt in VALID_TOOL_CALL_FORMATS:
template_metadata.tool_call_format = fmt
logger.debug(f"Template tool_call_format: {fmt}")
else:
logger.warning(
f"Invalid tool_call_format '{fmt}' in template, "
f"defaulting to 'json'. "
f"Valid values: {VALID_TOOL_CALL_FORMATS}"
)

self.metadata = template_metadata
return template_metadata

Expand Down Expand Up @@ -107,6 +145,7 @@ def raise_exception(message):

self.environment.globals["strftime_now"] = strftime_now
self.environment.globals["raise_exception"] = raise_exception
self.environment.filters["tojson"] = self._tojson_compat

return self.environment.from_string(template_str)

Expand Down
6 changes: 5 additions & 1 deletion endpoints/OAI/types/chat_completion.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from uuid import uuid4

from endpoints.OAI.types.common import UsageStats, CommonCompletionRequest
from endpoints.OAI.types.tools import ToolSpec, ToolCall
from endpoints.OAI.types.tools import NamedToolChoice, ToolSpec, ToolCall


class ChatCompletionLogprob(BaseModel):
Expand Down Expand Up @@ -71,6 +71,10 @@ class ChatCompletionRequest(CommonCompletionRequest):

tools: Optional[List[ToolSpec]] = None
functions: Optional[List[Dict]] = None
tool_choice: Optional[
Union[Literal["none", "auto", "required"], NamedToolChoice]
] = None
parallel_tool_calls: Optional[bool] = True

# Chat completions requests do not have a BOS token preference. Backend
# respects the tokenization config for the individual model.
Expand Down
26 changes: 23 additions & 3 deletions endpoints/OAI/types/tools.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pydantic import BaseModel, Field
from typing import Dict, Literal
from typing import Dict, Literal, Optional
from uuid import uuid4


Expand Down Expand Up @@ -28,8 +28,28 @@ class Tool(BaseModel):


class ToolCall(BaseModel):
    """Represents a single OAI-style tool call emitted by the model.

    ``index`` stays optional so it can be dropped from non-streaming
    responses (where OpenAI omits it) via ``exclude_none=True``, while
    streaming deltas set it explicitly to satisfy strict validators
    such as the Vercel AI SDK.
    """

    # OpenAI-style identifier: "call_" followed by 24 hex characters
    id: str = Field(default_factory=lambda: "call_" + uuid4().hex[:24])
    function: Tool
    type: Literal["function"] = "function"
    index: Optional[int] = None


class NamedToolFunction(BaseModel):
    """Represents a named function reference for tool_choice."""

    # Name of the function the client is forcing the model to call
    name: str


class NamedToolChoice(BaseModel):
    """Represents a named tool choice (forces a specific function call)."""

    # The specific function the model must invoke
    function: NamedToolFunction
    # OAI schema defines only the "function" discriminator for this shape
    type: Literal["function"] = "function"
Loading