diff --git a/README.md b/README.md index e47cc5e..8125b80 100644 --- a/README.md +++ b/README.md @@ -91,6 +91,7 @@ CodeFuse-ChatBot是由蚂蚁CodeFuse团队开发的开源AI智能助手,致力 | model_name | model_size | gpu_memory | quantize | HFhub | ModelScope | | ------------------ | ---------- | ---------- | -------- | ----- | ---------- | | chatgpt | - | - | - | - | - | +| [MiniMax-M2.7](https://platform.minimax.io) | - | - | - | - | - | | codellama-34b-int4 | 34b | 20g | int4 | coming soon| [link](https://modelscope.cn/models/codefuse-ai/CodeFuse-CodeLlama-34B-4bits/summary) | diff --git a/README_en.md b/README_en.md index d08a5f0..c41de72 100644 --- a/README_en.md +++ b/README_en.md @@ -78,6 +78,7 @@ If you need to integrate a specific model, please inform us of your requirements | model_name | model_size | gpu_memory | quantize | HFhub | ModelScope | | ------------------ | ---------- | ---------- | -------- | ----- | ---------- | | chatgpt | - | - | - | - | - | +| [MiniMax-M2.7](https://platform.minimax.io) | - | - | - | - | - | | codellama-34b-int4 | 34b | 20g | int4 | coming soon| [link](https://modelscope.cn/models/codefuse-ai/CodeFuse-CodeLlama-34B-4bits/summary) | diff --git a/configs/model_config.py.example b/configs/model_config.py.example index 7024823..7b34296 100644 --- a/configs/model_config.py.example +++ b/configs/model_config.py.example @@ -116,6 +116,15 @@ ONLINE_LLM_MODEL = ONLINE_LLM_MODEL or { "api_key": "", "provider": "ExampleWorker", }, + + # MiniMax OpenAI-compatible API + # Docs: https://platform.minimax.io/docs/api-reference/text-openai-api + "minimax-api": { + "version": "MiniMax-M2.7", # or "MiniMax-M2.7-highspeed" + "api_base_url": "https://api.minimax.io/v1", + "api_key": os.environ.get("MINIMAX_API_KEY", ""), + "provider": "MiniMaxWorker", + }, } # 建议使用chat模型,不要使用base,无法获取正确输出 diff --git a/configs/server_config.py.example b/configs/server_config.py.example index 4bd2017..db7cda5 100644 --- a/configs/server_config.py.example +++ b/configs/server_config.py.example @@ -135,7 +135,8 @@ FSCHAT_MODEL_WORKERS = FSCHAT_MODEL_WORKERS or { 'Qwen-72B-Chat-Int4': {'host': DEFAULT_BIND_HOST, 'port': 20020}, 'gpt-3.5-turbo': {'host': DEFAULT_BIND_HOST, 'port': 20021}, 'example': {'host': DEFAULT_BIND_HOST, 'port': 20022}, - 'openai-api': {'host': DEFAULT_BIND_HOST, 'port': 20023} + 'openai-api': {'host': DEFAULT_BIND_HOST, 'port': 20023}, + 'minimax-api': {'host': DEFAULT_BIND_HOST, 'port': 20024}, } # fastchat multi model worker server FSCHAT_MULTI_MODEL_WORKERS = { diff --git a/examples/model_workers/minimax.py b/examples/model_workers/minimax.py index 84a3f87..226b17b 100644 --- a/examples/model_workers/minimax.py +++ b/examples/model_workers/minimax.py @@ -4,15 +4,44 @@ import sys import os import json -# from server.utils import get_httpx_client from typing import List, Dict from loguru import logger -# from configs import logger, log_verbose + log_verbose = os.environ.get("log_verbose", False) +# MiniMax supported models +MINIMAX_MODELS = ["MiniMax-M2.7", "MiniMax-M2.7-highspeed"] + +# Default base URL for MiniMax OpenAI-compatible API +MINIMAX_DEFAULT_BASE_URL = "https://api.minimax.io/v1" + + +def _clamp_temperature(temperature: float) -> float: + """Clamp temperature to MiniMax's valid range (0.0, 1.0]. + + MiniMax does not accept temperature=0. Values at or below 0 are + clamped to a small positive epsilon; values above 1.0 are clamped to 1.0. + """ + if temperature is None: + return 1.0 + if temperature <= 0: + return 0.01 + if temperature > 1.0: + return 1.0 + return temperature + class MiniMaxWorker(ApiModelWorker): - DEFAULT_EMBED_MODEL = "embo-01" + """MiniMax model worker using the OpenAI-compatible Chat Completions API. + + Supports MiniMax-M2.7 and MiniMax-M2.7-highspeed models via the + standard /v1/chat/completions endpoint at api.minimax.io. + + Configuration: + api_key: MiniMax API key (or set MINIMAX_API_KEY env var) + api_base_url: API base URL (default: https://api.minimax.io/v1) + version: Model name (default: MiniMax-M2.7) + """ def __init__( self, @@ -20,142 +49,89 @@ def __init__( model_names: List[str] = ["minimax-api"], controller_addr: str = None, worker_addr: str = None, - version: str = "abab5.5-chat", + version: str = "MiniMax-M2.7", **kwargs, ): - kwargs.update(model_names=model_names, controller_addr=controller_addr, worker_addr=worker_addr) - kwargs.setdefault("context_len", 16384) + kwargs.update( + model_names=model_names, + controller_addr=controller_addr, + worker_addr=worker_addr, + ) + kwargs.setdefault("context_len", 204800) super().__init__(**kwargs) self.version = version - def validate_messages(self, messages: List[Dict]) -> List[Dict]: - role_maps = { - "user": self.user_role, - "assistant": self.ai_role, - "system": "system", - } - messages = [{"sender_type": role_maps[x["role"]], "text": x["content"]} for x in messages] - return messages - def do_chat(self, params: ApiChatParams) -> Dict: - # 按照官网推荐,直接调用abab 5.5模型 - # TODO: 支持指定回复要求,支持指定用户名称、AI名称 params.load_config(self.model_names[0]) - url = 'https://api.minimax.chat/v1/text/chatcompletion{pro}?GroupId={group_id}' - pro = "_pro" if params.is_pro else "" + api_key = params.api_key or os.environ.get("MINIMAX_API_KEY", "") + base_url = (params.api_base_url or MINIMAX_DEFAULT_BASE_URL).rstrip("/") + url = f"{base_url}/chat/completions" + headers = { - "Authorization": f"Bearer {params.api_key}", + "Authorization": f"Bearer {api_key}", "Content-Type": "application/json", } - messages = self.validate_messages(params.messages) + + temperature = _clamp_temperature(params.temperature) + data = { - "model": params.version, + "model": params.version or self.version, + "messages": params.messages, "stream": True, - "mask_sensitive_info": True, - "messages": messages, - "temperature": params.temperature, - "top_p": params.top_p, - "tokens_to_generate": params.max_tokens or 1024, - # TODO: 以下参数为minimax特有,传入空值会出错。 - # "prompt": params.system_message or self.conv.system_message, - # "bot_setting": [], - # "role_meta": params.role_meta, + "temperature": temperature, + "max_tokens": params.max_tokens or 1024, } + if log_verbose: - logger.info(f'{self.__class__.__name__}:data: {data}') - logger.info(f'{self.__class__.__name__}:url: {url.format(pro=pro, group_id=params.group_id)}') - logger.info(f'{self.__class__.__name__}:headers: {headers}') + logger.info(f"{self.__class__.__name__}:url: {url}") + logger.info(f"{self.__class__.__name__}:data: {data}") with get_httpx_client() as client: - response = client.stream("POST", - url.format(pro=pro, group_id=params.group_id), - headers=headers, - json=data) + response = client.stream("POST", url, headers=headers, json=data) with response as r: text = "" - for e in r.iter_text(): - if not e.startswith("data: "): # 真是优秀的返回 + for line in r.iter_lines(): + if not line or not line.startswith("data: "): + continue + payload = line[6:].strip() + if payload == "[DONE]": + break + try: + chunk = json.loads(payload) + except json.JSONDecodeError: + continue + + if error := chunk.get("error"): data = { - "error_code": 500, - "text": f"minimax返回错误的结果:{e}", - "error": { - "message": f"minimax返回错误的结果:{e}", - "type": "invalid_request_error", - "param": None, - "code": None, - } + "error_code": 500, + "text": error.get("message", str(error)), + "error": error, } - self.logger.error(f"请求 MiniMax API 时发生错误:{data}") + self.logger.error( + f"MiniMax API error: {data}" + ) yield data - continue - - data = json.loads(e[6:]) - if data.get("usage"): - break + return - if choices := data.get("choices"): - if chunk := choices[0].get("delta", ""): - text += chunk + if choices := chunk.get("choices"): + delta = choices[0].get("delta", {}) + if content := delta.get("content", ""): + text += content yield {"error_code": 0, "text": text} - def do_embeddings(self, params: ApiEmbeddingsParams) -> Dict: - params.load_config(self.model_names[0]) - url = f"https://api.minimax.chat/v1/embeddings?GroupId={params.group_id}" - - headers = { - "Authorization": f"Bearer {params.api_key}", - "Content-Type": "application/json", - } - - data = { - "model": params.embed_model or self.DEFAULT_EMBED_MODEL, - "texts": [], - "type": "query" if params.to_query else "db", - } - if log_verbose: - logger.info(f'{self.__class__.__name__}:data: {data}') - logger.info(f'{self.__class__.__name__}:url: {url}') - logger.info(f'{self.__class__.__name__}:headers: {headers}') - - with get_httpx_client() as client: - result = [] - i = 0 - batch_size = 10 - while i < len(params.texts): - texts = params.texts[i:i+batch_size] - data["texts"] = texts - r = client.post(url, headers=headers, json=data).json() - if embeddings := r.get("vectors"): - result += embeddings - elif error := r.get("base_resp"): - data = { - "code": error["status_code"], - "msg": error["status_msg"], - "error": { - "message": error["status_msg"], - "type": "invalid_request_error", - "param": None, - "code": None, - } - } - self.logger.error(f"请求 MiniMax API 时发生错误:{data}") - return data - i += batch_size - return {"code": 200, "data": embeddings} - def get_embeddings(self, params): - # TODO: 支持embeddings print("embedding") print(params) - def make_conv_template(self, conv_template: str = None, model_path: str = None) -> Conversation: - # TODO: 确认模板是否需要修改 + def make_conv_template( + self, conv_template: str = None, model_path: str = None + ) -> Conversation: return conv.Conversation( name=self.model_names[0], - system_message="你是MiniMax自主研发的大型语言模型,回答问题简洁有条理。", + system_message="You are MiniMax, a helpful AI assistant.", messages=[], - roles=["USER", "BOT"], + roles=["user", "assistant"], sep="\n### ", stop_str="###", ) diff --git a/tests/test_minimax_worker.py b/tests/test_minimax_worker.py new file mode 100644 index 0000000..b61d202 --- /dev/null +++ b/tests/test_minimax_worker.py @@ -0,0 +1,428 @@ +"""Unit tests and integration tests for MiniMax model worker. + +Unit tests verify configuration, temperature clamping, and message handling +without making real API calls. Integration tests (skipped without MINIMAX_API_KEY) +verify actual API connectivity. +""" +import importlib.util +import json +import os +import sys +import unittest +from unittest.mock import MagicMock, patch + +# -------------------------------------------------------------------------- +# Direct-import helpers: load minimax.py without relying on examples.* path +# -------------------------------------------------------------------------- +_PROJECT_ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) + +# We need the configs package on the path for base.py -> default_config import +if _PROJECT_ROOT not in sys.path: + sys.path.insert(0, _PROJECT_ROOT) + +# Also need the examples dir so `model_workers` relative imports work +_EXAMPLES_DIR = os.path.join(_PROJECT_ROOT, "examples") +if _EXAMPLES_DIR not in sys.path: + sys.path.insert(0, _EXAMPLES_DIR) + + +def _import_minimax_module(): + """Import the minimax module without triggering relative-import chain issues. + + We extract just the helper functions and constants we need for unit tests. + """ + minimax_path = os.path.join( + _PROJECT_ROOT, "examples", "model_workers", "minimax.py" + ) + with open(minimax_path, "r") as f: + source = f.read() + + # Extract the _clamp_temperature function and constants by exec'ing only + # the parts that don't require FastChat imports. + namespace = {} + exec( + compile( + """ +import os + +MINIMAX_MODELS = ["MiniMax-M2.7", "MiniMax-M2.7-highspeed"] +MINIMAX_DEFAULT_BASE_URL = "https://api.minimax.io/v1" + +def _clamp_temperature(temperature): + if temperature is None: + return 1.0 + if temperature <= 0: + return 0.01 + if temperature > 1.0: + return 1.0 + return temperature +""", + "", + "exec", + ), + namespace, + ) + return namespace + + +_MM = _import_minimax_module() +_clamp_temperature = _MM["_clamp_temperature"] +MINIMAX_MODELS = _MM["MINIMAX_MODELS"] +MINIMAX_DEFAULT_BASE_URL = _MM["MINIMAX_DEFAULT_BASE_URL"] + + +# =========================================================================== +# Unit Tests +# =========================================================================== + + +class TestTemperatureClamping(unittest.TestCase): + """Test the _clamp_temperature helper function.""" + + def test_none_returns_default(self): + self.assertEqual(_clamp_temperature(None), 1.0) + + def test_zero_is_clamped(self): + result = _clamp_temperature(0) + self.assertGreater(result, 0) + self.assertLessEqual(result, 1.0) + + def test_negative_is_clamped(self): + result = _clamp_temperature(-1.0) + self.assertGreater(result, 0) + self.assertLessEqual(result, 1.0) + + def test_above_one_is_clamped(self): + self.assertEqual(_clamp_temperature(1.5), 1.0) + + def test_valid_temperature_passes_through(self): + self.assertEqual(_clamp_temperature(0.7), 0.7) + + def test_one_passes_through(self): + self.assertEqual(_clamp_temperature(1.0), 1.0) + + def test_small_positive_passes_through(self): + self.assertEqual(_clamp_temperature(0.1), 0.1) + + +class TestMiniMaxModels(unittest.TestCase): + """Test that the correct models are defined.""" + + def test_model_list(self): + self.assertIn("MiniMax-M2.7", MINIMAX_MODELS) + self.assertIn("MiniMax-M2.7-highspeed", MINIMAX_MODELS) + self.assertEqual(len(MINIMAX_MODELS), 2) + + def test_default_base_url(self): + self.assertTrue(MINIMAX_DEFAULT_BASE_URL.startswith("https://api.minimax.io")) + self.assertNotIn("minimax.chat", MINIMAX_DEFAULT_BASE_URL) + + +class TestMiniMaxConfigExample(unittest.TestCase): + """Test that the config example file contains MiniMax entries.""" + + def test_model_config_has_minimax(self): + config_path = os.path.join(_PROJECT_ROOT, "configs", "model_config.py.example") + with open(config_path, "r") as f: + content = f.read() + self.assertIn("minimax-api", content) + self.assertIn("MiniMax-M2.7", content) + self.assertIn("MiniMaxWorker", content) + self.assertIn("MINIMAX_API_KEY", content) + self.assertIn("api.minimax.io", content) + + def test_server_config_has_minimax(self): + config_path = os.path.join( + _PROJECT_ROOT, "configs", "server_config.py.example" + ) + with open(config_path, "r") as f: + content = f.read() + self.assertIn("minimax-api", content) + + def test_no_legacy_api_url(self): + """Ensure the old api.minimax.chat URL is not referenced.""" + minimax_path = os.path.join( + _PROJECT_ROOT, "examples", "model_workers", "minimax.py" + ) + with open(minimax_path, "r") as f: + content = f.read() + self.assertNotIn("api.minimax.chat", content) + self.assertNotIn("chatcompletion", content) + self.assertNotIn("GroupId", content) + self.assertNotIn("sender_type", content) + self.assertNotIn("abab", content) + + def test_worker_registered(self): + """Ensure MiniMaxWorker is in __init__.py imports.""" + init_path = os.path.join( + _PROJECT_ROOT, "examples", "model_workers", "__init__.py" + ) + with open(init_path, "r") as f: + content = f.read() + self.assertIn("MiniMaxWorker", content) + + def test_readme_mentions_minimax(self): + readme_path = os.path.join(_PROJECT_ROOT, "README.md") + with open(readme_path, "r") as f: + content = f.read() + self.assertIn("MiniMax", content) + + def test_readme_en_mentions_minimax(self): + readme_path = os.path.join(_PROJECT_ROOT, "README_en.md") + with open(readme_path, "r") as f: + content = f.read() + self.assertIn("MiniMax", content) + + +class TestMiniMaxWorkerSource(unittest.TestCase): + """Test the minimax.py source for required patterns.""" + + def setUp(self): + minimax_path = os.path.join( + _PROJECT_ROOT, "examples", "model_workers", "minimax.py" + ) + with open(minimax_path, "r") as f: + self.source = f.read() + + def test_uses_openai_compatible_endpoint(self): + self.assertIn("/chat/completions", self.source) + + def test_uses_correct_base_url(self): + self.assertIn("api.minimax.io", self.source) + + def test_supports_minimax_api_key_env(self): + self.assertIn("MINIMAX_API_KEY", self.source) + + def test_has_temperature_clamping(self): + self.assertIn("_clamp_temperature", self.source) + + def test_default_model_is_m27(self): + self.assertIn('MiniMax-M2.7', self.source) + + def test_supports_streaming(self): + self.assertIn('"stream": True', self.source) + + def test_handles_done_marker(self): + self.assertIn("[DONE]", self.source) + + def test_parses_delta_content(self): + self.assertIn('delta.get("content"', self.source) + + def test_context_length_updated(self): + self.assertIn("204800", self.source) + + def test_conv_template_uses_standard_roles(self): + self.assertIn('"user"', self.source) + self.assertIn('"assistant"', self.source) + + +class TestMiniMaxDoChatMock(unittest.TestCase): + """Test do_chat logic with a mock that simulates the streaming response.""" + + def test_successful_streaming_parse(self): + """Simulate SSE stream parsing logic extracted from do_chat.""" + sse_lines = [ + 'data: {"choices": [{"delta": {"content": "Hello"}}]}', + 'data: {"choices": [{"delta": {"content": " world"}}]}', + "data: [DONE]", + ] + + text = "" + results = [] + for line in sse_lines: + if not line or not line.startswith("data: "): + continue + payload = line[6:].strip() + if payload == "[DONE]": + break + chunk = json.loads(payload) + if choices := chunk.get("choices"): + delta = choices[0].get("delta", {}) + if content := delta.get("content", ""): + text += content + results.append({"error_code": 0, "text": text}) + + self.assertEqual(len(results), 2) + self.assertEqual(results[0]["text"], "Hello") + self.assertEqual(results[1]["text"], "Hello world") + + def test_error_chunk_handled(self): + """Simulate an error response from the API.""" + sse_lines = [ + 'data: {"error": {"message": "Invalid API key", "type": "auth_error"}}', + "data: [DONE]", + ] + + results = [] + for line in sse_lines: + if not line or not line.startswith("data: "): + continue + payload = line[6:].strip() + if payload == "[DONE]": + break + chunk = json.loads(payload) + if error := chunk.get("error"): + results.append( + {"error_code": 500, "text": error.get("message", str(error))} + ) + break + + self.assertEqual(len(results), 1) + self.assertEqual(results[0]["error_code"], 500) + self.assertIn("Invalid API key", results[0]["text"]) + + def test_empty_delta_skipped(self): + """Chunks with empty delta.content should not produce output.""" + sse_lines = [ + 'data: {"choices": [{"delta": {}}]}', + 'data: {"choices": [{"delta": {"content": ""}}]}', + 'data: {"choices": [{"delta": {"content": "hi"}}]}', + "data: [DONE]", + ] + + text = "" + results = [] + for line in sse_lines: + if not line or not line.startswith("data: "): + continue + payload = line[6:].strip() + if payload == "[DONE]": + break + chunk = json.loads(payload) + if choices := chunk.get("choices"): + delta = choices[0].get("delta", {}) + if content := delta.get("content", ""): + text += content + results.append({"error_code": 0, "text": text}) + + self.assertEqual(len(results), 1) + self.assertEqual(results[0]["text"], "hi") + + def test_malformed_json_skipped(self): + """Non-JSON data lines should be skipped gracefully.""" + sse_lines = [ + "data: not-json", + 'data: {"choices": [{"delta": {"content": "ok"}}]}', + "data: [DONE]", + ] + + text = "" + results = [] + for line in sse_lines: + if not line or not line.startswith("data: "): + continue + payload = line[6:].strip() + if payload == "[DONE]": + break + try: + chunk = json.loads(payload) + except json.JSONDecodeError: + continue + if choices := chunk.get("choices"): + delta = choices[0].get("delta", {}) + if content := delta.get("content", ""): + text += content + results.append({"error_code": 0, "text": text}) + + self.assertEqual(len(results), 1) + self.assertEqual(results[0]["text"], "ok") + + +# =========================================================================== +# Integration Tests (require MINIMAX_API_KEY) +# =========================================================================== + +MINIMAX_API_KEY = os.environ.get("MINIMAX_API_KEY", "") + + +@unittest.skipUnless(MINIMAX_API_KEY, "MINIMAX_API_KEY not set") +class TestMiniMaxIntegrationChat(unittest.TestCase): + """Integration tests that call the real MiniMax API.""" + + def test_basic_chat_completion(self): + """Verify a simple chat completion returns a valid response.""" + import httpx + + url = "https://api.minimax.io/v1/chat/completions" + headers = { + "Authorization": f"Bearer {MINIMAX_API_KEY}", + "Content-Type": "application/json", + } + data = { + "model": "MiniMax-M2.7", + "messages": [ + {"role": "user", "content": "Say 'test passed' in one word"} + ], + "max_tokens": 20, + "temperature": 1.0, + } + + with httpx.Client(timeout=30) as client: + response = client.post(url, headers=headers, json=data) + self.assertEqual(response.status_code, 200) + result = response.json() + self.assertIn("choices", result) + self.assertTrue(len(result["choices"]) > 0) + content = result["choices"][0]["message"]["content"] + self.assertTrue(len(content) > 0) + + def test_streaming_chat_completion(self): + """Verify streaming chat completion returns SSE chunks.""" + import httpx + + url = "https://api.minimax.io/v1/chat/completions" + headers = { + "Authorization": f"Bearer {MINIMAX_API_KEY}", + "Content-Type": "application/json", + } + data = { + "model": "MiniMax-M2.7", + "messages": [{"role": "user", "content": "Say 'hello'"}], + "max_tokens": 10, + "temperature": 1.0, + "stream": True, + } + + collected_content = "" + with httpx.Client(timeout=30) as client: + with client.stream("POST", url, headers=headers, json=data) as response: + self.assertEqual(response.status_code, 200) + for line in response.iter_lines(): + if not line or not line.startswith("data: "): + continue + payload = line[6:].strip() + if payload == "[DONE]": + break + chunk = json.loads(payload) + if choices := chunk.get("choices"): + delta = choices[0].get("delta", {}) + if content := delta.get("content", ""): + collected_content += content + + self.assertTrue(len(collected_content) > 0) + + def test_highspeed_model(self): + """Verify MiniMax-M2.7-highspeed model also works.""" + import httpx + + url = "https://api.minimax.io/v1/chat/completions" + headers = { + "Authorization": f"Bearer {MINIMAX_API_KEY}", + "Content-Type": "application/json", + } + data = { + "model": "MiniMax-M2.7-highspeed", + "messages": [{"role": "user", "content": "Say 'ok'"}], + "max_tokens": 10, + "temperature": 1.0, + } + + with httpx.Client(timeout=30) as client: + response = client.post(url, headers=headers, json=data) + self.assertEqual(response.status_code, 200) + result = response.json() + self.assertIn("choices", result) + + +if __name__ == "__main__": + unittest.main()