Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 35 additions & 4 deletions haystack/core/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from haystack.core.component.component import _hook_component_init
from haystack.core.errors import DeserializationError, SerializationError
from haystack.utils.auth import Secret, deserialize_secrets_inplace
from haystack.utils.device import ComponentDevice
from haystack.utils.type_serialization import thread_safe_import

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -182,8 +183,8 @@ def default_to_dict(obj: Any, **init_parameters: Any) -> dict[str, Any]:
instance of `obj` with `from_dict`. Omitting them might cause deserialisation
errors or unexpected behaviours later, when calling `from_dict`.

Secret instances in `init_parameters` are automatically serialized by calling
their `to_dict()` method.
Secret and ComponentDevice instances in `init_parameters` are automatically
serialized by calling their `to_dict()` method.

An example usage:

Expand Down Expand Up @@ -213,17 +214,38 @@ def to_dict(self):
:returns:
A dictionary representation of the instance.
"""
# Automatically serialize Secret instances
# Automatically serialize Secret and ComponentDevice instances
serialized_params = {}
for key, value in init_parameters.items():
if isinstance(value, Secret):
if isinstance(value, (Secret, ComponentDevice)):
serialized_params[key] = value.to_dict()
else:
serialized_params[key] = value

return {"type": generate_qualified_class_name(type(obj)), "init_parameters": serialized_params}


def _is_serialized_component_device(value: Any) -> bool:
"""
Check if a value is a serialized ComponentDevice dictionary.

A dictionary is considered a serialized ComponentDevice if:
- It has "type": "single" and a "device" key with a string value, or
- It has "type": "multiple" and a "device_map" key with a dict value

This matches the structure produced by ComponentDevice.to_dict().
"""
if not isinstance(value, dict):
return False

type_value = value.get("type")
if type_value == "single":
return "device" in value and isinstance(value["device"], str)
elif type_value == "multiple":
return "device_map" in value and isinstance(value["device_map"], dict)
return False


def default_from_dict(cls: type[T], data: dict[str, Any]) -> T:
"""
Utility function to deserialize a dictionary to an object.
Expand All @@ -240,6 +262,10 @@ def default_from_dict(cls: type[T], data: dict[str, Any]) -> T:
deserialized. A dictionary is considered a serialized Secret if it has a "type" key
with value "env_var".

Serialized ComponentDevice dictionaries in `init_parameters` are automatically detected
and deserialized. A dictionary is considered a serialized ComponentDevice if it has a
"type" key with value "single" and a string "device" entry, or a "type" key with value
"multiple" and a dict "device_map" entry.

:param cls:
The class to be used for deserialization.
:param data:
Expand All @@ -265,6 +291,11 @@ def default_from_dict(cls: type[T], data: dict[str, Any]) -> T:
if secret_keys:
deserialize_secrets_inplace(init_params, keys=secret_keys)

# Automatically detect and deserialize ComponentDevice instances
for key, value in init_params.items():
if _is_serialized_component_device(value):
init_params[key] = ComponentDevice.from_dict(value)

return cls(**init_params)


Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
---
features:
- |
component_from_dict() and component_to_dict() now work out of the box with custom components that take a ComponentDevice as an init parameter.
Users no longer need to define to_dict() and from_dict() methods in their custom components just to call device.to_dict() and ComponentDevice.from_dict().
Both functions now handle ComponentDevice serialization and deserialization automatically.
176 changes: 175 additions & 1 deletion test/core/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@
import_class_by_name,
)
from haystack.testing import factory
from haystack.utils import Secret
from haystack.utils import ComponentDevice, Secret
from haystack.utils.device import Device, DeviceMap


def test_default_component_to_dict():
Expand Down Expand Up @@ -253,3 +254,176 @@ def test_component_to_dict_and_from_dict_roundtrip_with_secret():
assert deserialized_comp.regular_param == "test"
assert deserialized_comp.api_key._env_vars == env_secret1._env_vars
assert deserialized_comp.token._env_vars == env_secret2._env_vars


@component
class CustomComponentWithDevice:
    """Minimal test component whose init parameters include optional ComponentDevice values."""

    def __init__(
        self,
        device: ComponentDevice | None = None,
        other_device: ComponentDevice | None = None,
        name: str | None = None,
    ):
        # Every init parameter is kept on an attribute of the same name.
        self.device, self.other_device, self.name = device, other_device, name

    @component.output_types(value=str)
    def run(self, value: str):
        """Return the input unchanged under the `value` output key."""
        output = {"value": value}
        return output


def test_component_to_dict_with_component_device():
    """Check that `component_to_dict` serializes ComponentDevice init parameters automatically."""

    def serialized_init_params(**init_kwargs):
        # Build the component, serialize it, and keep only the init parameters.
        instance = CustomComponentWithDevice(**init_kwargs)
        return component_to_dict(instance, "test_component")["init_parameters"]

    # Single device: CPU
    params = serialized_init_params(device=ComponentDevice.from_single(Device.cpu()))
    assert params["device"] == {"type": "single", "device": "cpu"}

    # Single device: GPU with an explicit id
    params = serialized_init_params(device=ComponentDevice.from_single(Device.gpu(1)))
    assert params["device"] == {"type": "single", "device": "cuda:1"}

    # No device at all
    params = serialized_init_params(device=None)
    assert params["device"] is None

    # Several devices described by a device map
    layer_devices = DeviceMap({"layer1": Device.gpu(0), "layer2": Device.gpu(1)})
    params = serialized_init_params(device=ComponentDevice.from_multiple(layer_devices))
    assert params["device"] == {
        "type": "multiple",
        "device_map": {"layer1": "cuda:0", "layer2": "cuda:1"},
    }

    # Two ComponentDevice parameters next to a plain string parameter
    cpu_device = ComponentDevice.from_single(Device.cpu())
    gpu_device = ComponentDevice.from_single(Device.gpu(0))
    params = serialized_init_params(device=cpu_device, other_device=gpu_device, name="test")
    assert params["device"] == {"type": "single", "device": "cpu"}
    assert params["other_device"] == {"type": "single", "device": "cuda:0"}
    assert params["name"] == "test"


def test_component_from_dict_with_component_device():
    """Check that `component_from_dict` turns serialized ComponentDevice dictionaries back into objects."""
    qualified_name = generate_qualified_class_name(CustomComponentWithDevice)

    def load(init_parameters):
        # Wrap the init parameters in a fresh serialized-component payload and deserialize it.
        payload = {"type": qualified_name, "init_parameters": init_parameters}
        return component_from_dict(CustomComponentWithDevice, payload, "test_component")

    # Single device: CPU
    loaded = load({"device": {"type": "single", "device": "cpu"}, "name": "test"})
    assert isinstance(loaded, CustomComponentWithDevice)
    assert isinstance(loaded.device, ComponentDevice)
    assert loaded.device.to_torch_str() == "cpu"
    assert loaded.name == "test"

    # Single device: GPU with an explicit id
    loaded = load({"device": {"type": "single", "device": "cuda:1"}, "name": "test"})
    assert isinstance(loaded.device, ComponentDevice)
    assert loaded.device.to_torch_str() == "cuda:1"

    # No device at all
    loaded = load({"device": None, "name": "test"})
    assert loaded.device is None
    assert loaded.name == "test"

    # Several devices described by a device map
    loaded = load({"device": {"type": "multiple", "device_map": {"layer1": "cuda:0", "layer2": "cuda:1"}}})
    assert isinstance(loaded.device, ComponentDevice)
    assert loaded.device.has_multiple_devices

    # A plain dict with a different structure must be passed through untouched
    loaded = load({"device": {"some": "dict"}, "name": "test"})
    assert loaded.device == {"some": "dict"}
    assert loaded.name == "test"

    # Two ComponentDevice parameters next to a plain string parameter
    loaded = load(
        {
            "device": {"type": "single", "device": "cpu"},
            "other_device": {"type": "single", "device": "cuda:0"},
            "name": "test",
        }
    )
    assert isinstance(loaded.device, ComponentDevice)
    assert isinstance(loaded.other_device, ComponentDevice)
    assert loaded.device.to_torch_str() == "cpu"
    assert loaded.other_device.to_torch_str() == "cuda:0"
    assert loaded.name == "test"


def test_component_to_dict_and_from_dict_roundtrip_with_component_device():
    """Check that a ComponentDevice survives a `component_to_dict` / `component_from_dict` roundtrip."""

    def dump(instance):
        return component_to_dict(instance, "test_component")

    def load(payload):
        return component_from_dict(CustomComponentWithDevice, payload, "test_component")

    # NOTE: assertions on the serialized payload are always made before it is handed to `load`.

    # Roundtrip: single CPU device
    cpu_device = ComponentDevice.from_single(Device.cpu())
    payload = dump(CustomComponentWithDevice(device=cpu_device))
    assert payload["init_parameters"]["device"]["type"] == "single"

    restored = load(payload)
    assert isinstance(restored.device, ComponentDevice)
    assert restored.device.to_torch_str() == cpu_device.to_torch_str()

    # Roundtrip: single GPU device
    gpu_device = ComponentDevice.from_single(Device.gpu(2))
    payload = dump(CustomComponentWithDevice(device=gpu_device))
    restored = load(payload)
    assert restored.device.to_torch_str() == "cuda:2"

    # Roundtrip: device map
    layer_devices = DeviceMap({"layer1": Device.gpu(0), "layer2": Device.cpu()})
    mapped_device = ComponentDevice.from_multiple(layer_devices)
    payload = dump(CustomComponentWithDevice(device=mapped_device))
    assert payload["init_parameters"]["device"]["type"] == "multiple"

    restored = load(payload)
    assert isinstance(restored.device, ComponentDevice)
    assert restored.device.has_multiple_devices

    # Roundtrip: two ComponentDevice parameters next to a plain string parameter
    first_device = ComponentDevice.from_single(Device.cpu())
    second_device = ComponentDevice.from_single(Device.gpu(0))
    payload = dump(CustomComponentWithDevice(device=first_device, other_device=second_device, name="test"))
    serialized_params = payload["init_parameters"]
    assert serialized_params["device"]["type"] == "single"
    assert serialized_params["other_device"]["type"] == "single"
    assert serialized_params["name"] == "test"

    restored = load(payload)
    assert isinstance(restored.device, ComponentDevice)
    assert isinstance(restored.other_device, ComponentDevice)
    assert restored.device.to_torch_str() == "cpu"
    assert restored.other_device.to_torch_str() == "cuda:0"
    assert restored.name == "test"
Loading