Skip to content

Commit cb245f5

Browse files
committed
Map promoted oneHotEncoding
1 parent 6cc2ddb commit cb245f5

File tree

2 files changed

+32
-1
lines changed

2 files changed

+32
-1
lines changed

graphdatascience/tests/integration/test_util_ops.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,13 @@ def test_util_nodeProperty(gds: GraphDataScience, G: Graph) -> None:
103103
assert result == 1337
104104

105105

106-
def test_ml_oneHotEncoding(gds: GraphDataScience) -> None:
106+
@pytest.mark.filterwarnings("ignore: .*gds.alpha.ml.oneHotEncoding.*")
107+
def test_alpha_ml_oneHotEncoding(gds: GraphDataScience) -> None:
107108
result = gds.alpha.ml.oneHotEncoding(["Chinese", "Indian", "Italian"], ["Italian"])
108109
assert result == [0, 0, 1]
110+
111+
112+
@pytest.mark.compatible_with(min_inclusive=ServerVersion(2, 24, 0))
113+
def test_util_oneHotEncoding(gds: GraphDataScience) -> None:
114+
result = gds.util.oneHotEncoding(["Chinese", "Indian", "Italian"], ["Italian"])
115+
assert result == [0, 0, 1]

graphdatascience/utils/util_proc_runner.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from typing import Any
22

3+
from graphdatascience.call_parameters import CallParameters
4+
35
from ..error.illegal_attr_checker import IllegalAttrChecker
46
from ..error.uncallable_namespace import UncallableNamespace
57
from ..utils.util_node_property_func_runner import NodePropertyFuncRunner
@@ -43,3 +45,25 @@ def asNodes(self, node_ids: list[int]) -> list[Any]:
4345
@property
4446
def nodeProperty(self) -> NodePropertyFuncRunner:
4547
return NodePropertyFuncRunner(self._query_runner, self._namespace + ".nodeProperty", self._server_version)
48+
49+
def oneHotEncoding(self, available_values: list[Any], selected_values: list[Any]) -> list[int]:
50+
"""
51+
One hot encode a list of values.
52+
53+
Args:
54+
available_values: The available values to encode.
55+
selected_values: The values to encode.
56+
57+
Returns:
58+
The one hot encoded values.
59+
"""
60+
namespace = self._namespace + ".oneHotEncoding"
61+
62+
params = CallParameters(
63+
available_values=available_values,
64+
selected_values=selected_values,
65+
)
66+
query = f"RETURN {namespace}($available_values, $selected_values) AS encoded"
67+
result = self._query_runner.run_cypher(query=query, params=params)
68+
69+
return result.iat[0, 0] # type: ignore

0 commit comments

Comments
 (0)