diff --git a/src/mctpd.c b/src/mctpd.c index 1f28346..ab61d99 100644 --- a/src/mctpd.c +++ b/src/mctpd.c @@ -3163,10 +3163,8 @@ static int query_peer_properties(struct peer *peer) } for (unsigned int i = 0; i < peer->num_message_types; i++) { - if (peer->message_types[i] == - MCTP_GET_VDM_SUPPORT_IANA_FORMAT_ID || - peer->message_types[i] == - MCTP_GET_VDM_SUPPORT_PCIE_FORMAT_ID) { + if (peer->message_types[i] == MCTP_TYPE_VENDOR_IANA || + peer->message_types[i] == MCTP_TYPE_VENDOR_PCIE) { supports_vdm = true; break; } diff --git a/tests/mctpenv/__init__.py b/tests/mctpenv/__init__.py index d1de8d0..1460e70 100644 --- a/tests/mctpenv/__init__.py +++ b/tests/mctpenv/__init__.py @@ -347,6 +347,62 @@ def to_buf(self): return bytes([flags, self.cmd]) + self.data +class VDMType: + TYPE_PCI = 0x7E + TYPE_IANA = 0x7F + FORMAT_PCI = 0 + FORMAT_IANA = 1 + # type to (name, format, type_size) + type_map = { + TYPE_PCI: ("PCI", FORMAT_PCI, 2), + TYPE_IANA: ("IANA", FORMAT_IANA, 4), + } + # format to (type, dbus type) + fmt_map = { + FORMAT_PCI: (TYPE_PCI, 'q'), + FORMAT_IANA: (TYPE_IANA, 'u'), + } + + def __init__(self, msgtype, vdmtype, subtype=0): + self.msgtype = msgtype + self.vdmtype = vdmtype + self.subtype = subtype + + def __repr__(self): + (name, _, _) = self.type_map.get(self.msgtype) + return f"" + + def _key(self): + return (self.msgtype, self.vdmtype, self.subtype) + + def __eq__(self, value): + return self._key() == value._key() + + def __hash__(self): + return hash(self._key()) + + def format(self): + """Convert to the Get Vendor Defined Message Support response format""" + (_, vid_fmt, vid_size) = self.type_map.get(self.msgtype) + + return ( + vid_fmt.to_bytes(1) + + self.vdmtype.to_bytes(vid_size, 'big') + + self.subtype.to_bytes(2, 'big') + ) + + @classmethod + def parse_dbus(cls, dbus_res): + """Convert from a dbus GetVendorDefinedMessageTypes reply""" + types = [] + for t in dbus_res: + fmt, var, subtype = t + (msgtype, dbus_type) = cls.fmt_map.get(fmt) + assert var.signature == dbus_type + types.append(cls(msgtype, var.value, subtype)) + return types + + class Endpoint: def __init__( self, iface, lladdr, ep_uuid=None, eid=0, types=None, vdm_msg_types=None @@ -355,8 +411,12 @@ def __init__( self.lladdr = lladdr self.uuid = ep_uuid or uuid.uuid1() self.eid = eid - self.types = types or [0] self.vdm_msg_types = vdm_msg_types or [] + vdm_set = set([t.msgtype for t in self.vdm_msg_types]) + if types is None: + self.types = [0] + list(vdm_set) + else: + self.types = types self.bridged_eps = [] self.allocated_pool = None # or (start, size) @@ -433,21 +493,17 @@ async def handle_mctp_control(self, sock, addr, data): elif opcode == 6: # Get Vendor Defined Message Support - vdm_support = self.vdm_msg_types + vdm_types = self.vdm_msg_types + n_vdm_types = len(vdm_types) selector = data[2] - if selector >= len(vdm_support): + if selector >= n_vdm_types: await sock.send(raddr, bytes(hdr + [0x02])) return - vdm_format, vendor_id, cmd_set = vdm_support[selector] + vdm_data = vdm_types[selector].format() next_selector = ( - 0xFF if selector == (len(vdm_support) - 1) else selector + 1 + 0xFF if selector == (n_vdm_types - 1) else selector + 1 ) - resp = bytes(hdr + [0x00, next_selector, vdm_format]) - if vdm_format == 0: - resp += vendor_id.to_bytes(2, 'big') - elif vdm_format == 1: - resp += vendor_id.to_bytes(4, 'big') - resp += cmd_set.to_bytes(2, 'big') + resp = bytes(hdr + [0x00, next_selector]) + vdm_data await sock.send(raddr, resp) elif opcode == 8: diff --git a/tests/test_mctpd.py b/tests/test_mctpd.py index 955ce38..4f3d484 100644 --- a/tests/test_mctpd.py +++ b/tests/test_mctpd.py @@ -9,7 +9,13 @@ mctpd_mctp_endpoint_control_obj, mctpd_mctp_base_iface_obj, ) -from mctpenv import Endpoint, MCTPSockAddr, MCTPControlCommand, MctpdWrapper +from mctpenv import ( + Endpoint, + MCTPSockAddr, + MCTPControlCommand, + MctpdWrapper, + VDMType, +) # DBus constant symbol suffixes: # @@ -640,8 +646,11 @@ async def test_query_message_types(dbus, mctpd): async def test_query_vdm_types(dbus, mctpd): """Test that VendorDefinedMessageTypes is queried and populated.""" iface = mctpd.system.interfaces[0] - vdm_support = [[0, 0x1234, 0x5678], [1, 0xABCDEF12, 0x3456]] - ep = Endpoint(iface, bytes([0x1E]), eid=15, vdm_msg_types=vdm_support) + vdm_types = [ + VDMType(VDMType.TYPE_PCI, 0x1234, 0x5678), + VDMType(VDMType.TYPE_IANA, 0xABCDEF12, 0x3456), + ] + ep = Endpoint(iface, bytes([0x1E]), eid=15, vdm_msg_types=vdm_types) mctpd.network.add_endpoint(ep) mctp = await mctpd_mctp_iface_obj(dbus, iface) @@ -651,25 +660,40 @@ async def test_query_vdm_types(dbus, mctpd): ep_obj = await mctpd_mctp_endpoint_common_obj(dbus, path) - # Query VendorDefinedMessageTypes property - vdm_types = list(await ep_obj.get_vendor_defined_message_types()) + r = await ep_obj.get_vendor_defined_message_types() + ret_vdm_types = VDMType.parse_dbus(r) + assert set(vdm_types) == set(ret_vdm_types) - # Verify we got 2 VDM types - assert len(vdm_types) == 2 - # Verify first VDM type: PCIe format (0), VID 0x1234, cmd_set 0x5678 - assert vdm_types[0][0] == 0 # format: PCIe - assert ( - vdm_types[0][1].value == 0x1234 - ) # vendor_id (variant containing uint16) - assert vdm_types[0][2] == 0x5678 # cmd_set +async def test_query_vdm_types_no_control(dbus, mctpd): + """Test that we query VDM types if *only* a VDM type is reported in the + non-vendor Message Type Support response + """ + iface = mctpd.system.interfaces[0] + vdm_types = [ + VDMType(VDMType.TYPE_PCI, 0x1234, 0x5678), + ] + # only include the VDM type + ep = Endpoint( + iface, + bytes([0x1E]), + eid=15, + types=[VDMType.TYPE_PCI], + vdm_msg_types=vdm_types, + ) + mctpd.network.add_endpoint(ep) - # Verify second VDM type: IANA format (1), VID 0xabcdef12, cmd_set 0x3456 - assert vdm_types[1][0] == 1 # format: IANA - assert ( - vdm_types[1][1].value == 0xABCDEF12 - ) # vendor_id (variant containing uint32) - assert vdm_types[1][2] == 0x3456 # cmd_set + mctp = await mctpd_mctp_iface_obj(dbus, iface) + (eid, net, path, new) = await mctp.call_learn_endpoint(ep.lladdr) + + assert eid == ep.eid + + ep_obj = await mctpd_mctp_endpoint_common_obj(dbus, path) + + # Query VendorDefinedMessageTypes property + vdm_types = list(await ep_obj.get_vendor_defined_message_types()) + + assert len(vdm_types) == 1 class InvalidVDMEndpointBase(Endpoint):