|
14 | 14 | import requests |
15 | 15 | import yaml |
16 | 16 |
|
| 17 | +from tensorrt_llm._utils import get_free_port |
17 | 18 | from tensorrt_llm.executor.result import GenerationResultBase |
18 | 19 | from tensorrt_llm.llmapi import CompletionOutput, RequestOutput, SamplingParams |
19 | 20 | from tensorrt_llm.llmapi.llm_args import LlmArgs |
@@ -67,37 +68,19 @@ def __exit__(self, exc_type, exc_val, exc_tb): |
67 | 68 | return False |
68 | 69 |
|
69 | 70 |
|
70 | | -def check_port_available(port: int) -> int: |
71 | | - import socket |
72 | | - try: |
73 | | - with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: |
74 | | - s.bind(('localhost', port)) |
75 | | - return port |
76 | | - except socket.error: |
77 | | - # find a free port |
78 | | - sock = socket.socket() |
79 | | - sock.bind(('', 0)) |
80 | | - return sock.getsockname()[1] |
81 | | - |
82 | | - |
83 | 71 | def revise_disaggregated_server_config_urls_with_free_ports( |
84 | 72 | disaggregated_server_config: Dict[str, Any]) -> Dict[str, Any]: |
85 | | - disaggregated_server_config['port'] = check_port_available( |
86 | | - disaggregated_server_config['port']) |
87 | | - ctx_urls = disaggregated_server_config["context_servers"]["urls"] |
88 | | - gen_urls = disaggregated_server_config["generation_servers"]["urls"] |
| 73 | + num_ctx_ports = len(disaggregated_server_config["context_servers"]["urls"]) |
| 74 | + num_gen_ports = len( |
| 75 | + disaggregated_server_config["generation_servers"]["urls"]) |
89 | 76 |
|
90 | | - new_ctx_urls = [] |
91 | | - new_gen_urls = [] |
92 | | - for url in ctx_urls: |
93 | | - port = check_port_available(int(url.split(":")[1])) |
94 | | - new_ctx_urls.append(f"localhost:{port}") |
95 | | - for url in gen_urls: |
96 | | - port = check_port_available(int(url.split(":")[1])) |
97 | | - new_gen_urls.append(f"localhost:{port}") |
98 | | - |
99 | | - disaggregated_server_config["context_servers"]["urls"] = new_ctx_urls |
100 | | - disaggregated_server_config["generation_servers"]["urls"] = new_gen_urls |
| 77 | + disaggregated_server_config['port'] = get_free_port() |
| 78 | + disaggregated_server_config["context_servers"]["urls"] = [ |
| 79 | + f"localhost:{get_free_port()}" for _ in range(num_ctx_ports) |
| 80 | + ] |
| 81 | + disaggregated_server_config["generation_servers"]["urls"] = [ |
| 82 | + f"localhost:{get_free_port()}" for _ in range(num_gen_ports) |
| 83 | + ] |
101 | 84 |
|
102 | 85 | return disaggregated_server_config |
103 | 86 |
|
|
0 commit comments