Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions src/mctpd.c
Original file line number Diff line number Diff line change
Expand Up @@ -3163,10 +3163,8 @@ static int query_peer_properties(struct peer *peer)
}

for (unsigned int i = 0; i < peer->num_message_types; i++) {
if (peer->message_types[i] ==
MCTP_GET_VDM_SUPPORT_IANA_FORMAT_ID ||
peer->message_types[i] ==
MCTP_GET_VDM_SUPPORT_PCIE_FORMAT_ID) {
if (peer->message_types[i] == MCTP_TYPE_VENDOR_IANA ||
peer->message_types[i] == MCTP_TYPE_VENDOR_PCIE) {
supports_vdm = true;
break;
}
Expand Down
78 changes: 67 additions & 11 deletions tests/mctpenv/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -347,6 +347,62 @@ def to_buf(self):
return bytes([flags, self.cmd]) + self.data


class VDMType:
TYPE_PCI = 0x7E
TYPE_IANA = 0x7F
FORMAT_PCI = 0
FORMAT_IANA = 1
# type to (name, format, type_size)
type_map = {
TYPE_PCI: ("PCI", FORMAT_PCI, 2),
TYPE_IANA: ("IANA", FORMAT_IANA, 4),
}
# format to (type, dbus type)
fmt_map = {
FORMAT_PCI: (TYPE_PCI, 'q'),
FORMAT_IANA: (TYPE_IANA, 'u'),
}

def __init__(self, msgtype, vdmtype, subtype=0):
self.msgtype = msgtype
self.vdmtype = vdmtype
self.subtype = subtype

def __repr__(self):
(name, _, _) = self.type_map.get(self.msgtype)
return f"<VDMType {name}: {self.vdmtype:x} {self.subtype:x}>"

def _key(self):
return (self.msgtype, self.vdmtype, self.subtype)

def __eq__(self, value):
return self._key() == value._key()

def __hash__(self):
return hash(self._key())

def format(self):
"""Convert to the Get Vendor Defined Message Support response format"""
(_, vid_fmt, vid_size) = self.type_map.get(self.msgtype)

return (
vid_fmt.to_bytes(1)
+ self.vdmtype.to_bytes(vid_size, 'big')
+ self.subtype.to_bytes(2, 'big')
)

@classmethod
def parse_dbus(cls, dbus_res):
"""Convert from a dbus GetVendorDefinedMessageTypes reply"""
types = []
for t in dbus_res:
fmt, var, subtype = t
(msgtype, dbus_type) = cls.fmt_map.get(fmt)
assert var.signature == dbus_type
types.append(cls(msgtype, var.value, subtype))
return types


class Endpoint:
def __init__(
self, iface, lladdr, ep_uuid=None, eid=0, types=None, vdm_msg_types=None
Expand All @@ -355,8 +411,12 @@ def __init__(
self.lladdr = lladdr
self.uuid = ep_uuid or uuid.uuid1()
self.eid = eid
self.types = types or [0]
self.vdm_msg_types = vdm_msg_types or []
vdm_set = set([t.msgtype for t in self.vdm_msg_types])
if types is None:
self.types = [0] + list(vdm_set)
else:
self.types = types
self.bridged_eps = []
self.allocated_pool = None # or (start, size)

Expand Down Expand Up @@ -433,21 +493,17 @@ async def handle_mctp_control(self, sock, addr, data):

elif opcode == 6:
# Get Vendor Defined Message Support
vdm_support = self.vdm_msg_types
vdm_types = self.vdm_msg_types
n_vdm_types = len(vdm_types)
selector = data[2]
if selector >= len(vdm_support):
if selector >= n_vdm_types:
await sock.send(raddr, bytes(hdr + [0x02]))
return
vdm_format, vendor_id, cmd_set = vdm_support[selector]
vdm_data = vdm_types[selector].format()
next_selector = (
0xFF if selector == (len(vdm_support) - 1) else selector + 1
0xFF if selector == (n_vdm_types - 1) else selector + 1
)
resp = bytes(hdr + [0x00, next_selector, vdm_format])
if vdm_format == 0:
resp += vendor_id.to_bytes(2, 'big')
elif vdm_format == 1:
resp += vendor_id.to_bytes(4, 'big')
resp += cmd_set.to_bytes(2, 'big')
resp = bytes(hdr + [0x00, next_selector]) + vdm_data
await sock.send(raddr, resp)

elif opcode == 8:
Expand Down
62 changes: 43 additions & 19 deletions tests/test_mctpd.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,13 @@
mctpd_mctp_endpoint_control_obj,
mctpd_mctp_base_iface_obj,
)
from mctpenv import Endpoint, MCTPSockAddr, MCTPControlCommand, MctpdWrapper
from mctpenv import (
Endpoint,
MCTPSockAddr,
MCTPControlCommand,
MctpdWrapper,
VDMType,
)

# DBus constant symbol suffixes:
#
Expand Down Expand Up @@ -640,8 +646,11 @@ async def test_query_message_types(dbus, mctpd):
async def test_query_vdm_types(dbus, mctpd):
"""Test that VendorDefinedMessageTypes is queried and populated."""
iface = mctpd.system.interfaces[0]
vdm_support = [[0, 0x1234, 0x5678], [1, 0xABCDEF12, 0x3456]]
ep = Endpoint(iface, bytes([0x1E]), eid=15, vdm_msg_types=vdm_support)
vdm_types = [
VDMType(VDMType.TYPE_PCI, 0x1234, 0x5678),
VDMType(VDMType.TYPE_IANA, 0xABCDEF12, 0x3456),
]
ep = Endpoint(iface, bytes([0x1E]), eid=15, vdm_msg_types=vdm_types)
mctpd.network.add_endpoint(ep)

mctp = await mctpd_mctp_iface_obj(dbus, iface)
Expand All @@ -651,25 +660,40 @@ async def test_query_vdm_types(dbus, mctpd):

ep_obj = await mctpd_mctp_endpoint_common_obj(dbus, path)

# Query VendorDefinedMessageTypes property
vdm_types = list(await ep_obj.get_vendor_defined_message_types())
r = await ep_obj.get_vendor_defined_message_types()
ret_vdm_types = VDMType.parse_dbus(r)
assert set(vdm_types) == set(ret_vdm_types)

# Verify we got 2 VDM types
assert len(vdm_types) == 2

# Verify first VDM type: PCIe format (0), VID 0x1234, cmd_set 0x5678
assert vdm_types[0][0] == 0 # format: PCIe
assert (
vdm_types[0][1].value == 0x1234
) # vendor_id (variant containing uint16)
assert vdm_types[0][2] == 0x5678 # cmd_set
async def test_query_vdm_types_no_control(dbus, mctpd):
"""Test that we query VDM types if *only* a VDM type is reported in the
non-vendor Message Type Support response
"""
iface = mctpd.system.interfaces[0]
vdm_types = [
VDMType(VDMType.TYPE_PCI, 0x1234, 0x5678),
]
# only include the VDM type
ep = Endpoint(
iface,
bytes([0x1E]),
eid=15,
types=[VDMType.TYPE_PCI],
vdm_msg_types=vdm_types,
)
mctpd.network.add_endpoint(ep)

# Verify second VDM type: IANA format (1), VID 0xabcdef12, cmd_set 0x3456
assert vdm_types[1][0] == 1 # format: IANA
assert (
vdm_types[1][1].value == 0xABCDEF12
) # vendor_id (variant containing uint32)
assert vdm_types[1][2] == 0x3456 # cmd_set
mctp = await mctpd_mctp_iface_obj(dbus, iface)
(eid, net, path, new) = await mctp.call_learn_endpoint(ep.lladdr)

assert eid == ep.eid

ep_obj = await mctpd_mctp_endpoint_common_obj(dbus, path)

# Query VendorDefinedMessageTypes property
vdm_types = list(await ep_obj.get_vendor_defined_message_types())

assert len(vdm_types) == 1


class InvalidVDMEndpointBase(Endpoint):
Expand Down