Skip to content

Commit becdc2c

Browse files
authored
Support imported fleets in dstack fleet get (#3773)
Support the `<project>/<fleet>` syntax in `dstack fleet get`. Example: ``` $ dstack fleet get my-project/my-fleet --json ```
1 parent f3686d6 commit becdc2c

3 files changed

Lines changed: 56 additions & 10 deletions

File tree

src/dstack/_internal/cli/commands/fleet.py

Lines changed: 23 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
)
1515
from dstack._internal.cli.utils.fleet import get_fleets_table, print_fleets_table
1616
from dstack._internal.core.errors import CLIError, ResourceNotExistsError
17+
from dstack._internal.core.models.common import EntityReference
1718
from dstack._internal.utils.json_utils import pydantic_orjson_dumps_with_indent
1819

1920

@@ -49,6 +50,7 @@ def _register(self):
4950
)
5051
delete_parser.add_argument(
5152
"name",
53+
type=EntityReference.parse,
5254
help="The name of the fleet",
5355
).completer = FleetNameCompleter() # type: ignore[attr-defined]
5456
delete_parser.add_argument(
@@ -73,6 +75,7 @@ def _register(self):
7375
"name",
7476
nargs="?",
7577
metavar="NAME",
78+
type=EntityReference.parse,
7679
help="The name of the fleet",
7780
).completer = FleetNameCompleter() # type: ignore[attr-defined]
7881
name_group.add_argument(
@@ -112,35 +115,43 @@ def _list(self, args: argparse.Namespace):
112115
pass
113116

114117
def _delete(self, args: argparse.Namespace):
118+
if args.name.project is not None:
119+
console.print(
120+
"The [code]<project>/<fleet>[/] format is not supported for fleet names."
121+
" Can only delete fleets or instances owned by the current project"
122+
)
123+
exit(1)
124+
name = args.name.name
125+
115126
try:
116-
self.api.client.fleets.get(project_name=self.api.project, name=args.name)
127+
self.api.client.fleets.get(project_name=self.api.project, name=name)
117128
except ResourceNotExistsError:
118-
console.print(f"Fleet [code]{args.name}[/] does not exist")
129+
console.print(f"Fleet [code]{name}[/] does not exist")
119130
exit(1)
120131

121132
if not args.instances:
122-
if not args.yes and not confirm_ask(f"Delete the fleet [code]{args.name}[/]?"):
133+
if not args.yes and not confirm_ask(f"Delete the fleet [code]{name}[/]?"):
123134
console.print("\nExiting...")
124135
return
125136

126137
with console.status("Deleting fleet..."):
127-
self.api.client.fleets.delete(project_name=self.api.project, names=[args.name])
138+
self.api.client.fleets.delete(project_name=self.api.project, names=[name])
128139

129-
console.print(f"Fleet [code]{args.name}[/] deleted")
140+
console.print(f"Fleet [code]{name}[/] deleted")
130141
return
131142

132143
if not args.yes and not confirm_ask(
133-
f"Delete the fleet [code]{args.name}[/] instances [code]{args.instances}[/]?"
144+
f"Delete the fleet [code]{name}[/] instances [code]{args.instances}[/]?"
134145
):
135146
console.print("\nExiting...")
136147
return
137148

138149
with console.status("Deleting fleet instances..."):
139150
self.api.client.fleets.delete_instances(
140-
project_name=self.api.project, name=args.name, instance_nums=args.instances
151+
project_name=self.api.project, name=name, instance_nums=args.instances
141152
)
142153

143-
console.print(f"Fleet [code]{args.name}[/] instances deleted")
154+
console.print(f"Fleet [code]{name}[/] instances deleted")
144155

145156
def _get(self, args: argparse.Namespace):
146157
# TODO: Implement non-json output format
@@ -157,7 +168,10 @@ def _get(self, args: argparse.Namespace):
157168
project_name=self.api.project, fleet_id=fleet_id
158169
)
159170
else:
160-
fleet = self.api.client.fleets.get(project_name=self.api.project, name=args.name)
171+
fleet = self.api.client.fleets.get(
172+
project_name=args.name.project or self.api.project,
173+
name=args.name.name,
174+
)
161175
except ResourceNotExistsError:
162176
console.print(f"Fleet [code]{args.name or args.id}[/] not found")
163177
exit(1)

src/dstack/_internal/core/models/common.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -160,12 +160,17 @@ class EntityReference(CoreModel):
160160
def parse(cls, v: Union[str, "EntityReference"]) -> "EntityReference":
161161
if isinstance(v, EntityReference):
162162
return v
163+
invalid_ref_error = ValueError(
164+
"Invalid entity reference. Only `<name>` or `<project>/<name>` formats are allowed"
165+
)
163166
parts = v.split("/")
167+
if any(len(part) == 0 for part in parts):
168+
raise invalid_ref_error
164169
if len(parts) == 1:
165170
return cls(project=None, name=parts[0])
166171
if len(parts) == 2:
167172
return cls(project=parts[0], name=parts[1])
168-
raise ValueError("Invalid entity reference. Only `<project>/<name>` format is allowed")
173+
raise invalid_ref_error
169174

170175
def format(self) -> str:
171176
if self.project is None:
Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import pytest
2+
3+
from dstack._internal.core.models.common import EntityReference
4+
5+
6+
class TestEntityReferenceParse:
7+
@pytest.mark.parametrize(
8+
"value, expected",
9+
[
10+
("fleet", EntityReference(project=None, name="fleet")),
11+
("project/fleet", EntityReference(project="project", name="fleet")),
12+
(
13+
EntityReference(project="proj", name="fleet"),
14+
EntityReference(project="proj", name="fleet"),
15+
),
16+
],
17+
)
18+
def test_valid(self, value, expected):
19+
assert EntityReference.parse(value) == expected
20+
21+
@pytest.mark.parametrize(
22+
"value",
23+
["", "/name", "name/", "/", "a/b/c"],
24+
)
25+
def test_invalid(self, value: str):
26+
with pytest.raises(ValueError, match="Invalid entity reference"):
27+
EntityReference.parse(value)

0 commit comments

Comments
 (0)