---
title: embeddings
description: Calculate sentence embeddings on node strings using PyTorch.
---

# embeddings

import { Cards } from 'nextra/components'
import GitHub from '/components/icons/GitHub'
import { Callout } from 'nextra/components'

The embeddings module provides tools for calculating sentence embeddings on node strings using PyTorch.

<Cards>
  <Cards.Card
    icon={<GitHub />}
    title="Source code"
    href="https://github.com/memgraph/mage/blob/main/python/embeddings.py"
  />
</Cards>

| Trait               | Value                |
| ------------------- | -------------------- |
| **Module type**     | algorithm            |
| **Implementation**  | Python               |
| **Parallelism**     | parallel             |

## Procedures

### `node_sentence()`

The procedure computes sentence embeddings from the string properties of nodes and
stores each embedding as a property on the corresponding node.

{<h4 className="custom-header"> Input: </h4>}

- `input_nodes: List[Vertex]` (**OPTIONAL**) ➡ The list of nodes to compute the embeddings for. If not provided, the embeddings are computed for all nodes in the graph.
- `configuration: mgp.Map` (**OPTIONAL**) ➡ User-defined parameters from the query module. Defaults to `{}`.

**Configuration options:**

| Name | Type | Default | Description |
|----------------------------|--------------|-------------------|----------------------------------------------------------------------------------------------------------|
| `embedding_property` | string | `"embedding"` | The name of the node property to store the embeddings in. |
| `excluded_properties` | List[string] | `[]` | The list of properties to exclude from the embeddings computation. |
| `model_name` | string | `"all-MiniLM-L6-v2"` | The name of the model to use for the embeddings computation, provided by the `sentence-transformers` library. |
| `return_embeddings` | bool | `False` | Whether to return the embeddings as an additional output. |
| `batch_size` | int | `2000` | The batch size to use for the embeddings computation. |
| `chunk_size` | int | `48` | The number of batches per chunk. Chunks are used when computing embeddings across multiple GPUs, which requires spawning multiple processes; each spawned process computes the embeddings for a single chunk. |
| `device` | NULL\|string\|int\|List[string\|int] | `NULL` | The device(s) to use for the embeddings computation (see below). |
<Callout type="info">
The `device` parameter can be one of the following:
 - `NULL` (default) - Use the first GPU if available, otherwise use the CPU.
 - `"cpu"` - Use the CPU for computation.
 - `"cuda"` or `"all"` - Use all available CUDA devices for computation.
 - `"cuda:id"` - Use a specific CUDA device for computation.
 - `id` - Use a specific device for computation.
 - `[id1, id2, ...]` - Use a list of device IDs for computation.
 - `["cuda:id1", "cuda:id2", ...]` - Use a list of CUDA devices for computation.

**Note**: If you're running on a GPU device, make sure to start your container
with the `--gpus=all` flag.
For more details, see the [Install MAGE
documentation](/advanced-algorithms/install-mage).
</Callout>


{<h4 className="custom-header"> Output: </h4>}

- `success: bool` ➡ Whether the embeddings computation was successful.
- `embeddings: List[List[float]]|NULL` ➡ The list of embeddings. Only returned if the
`return_embeddings` parameter is set to `true` in the configuration, otherwise `NULL`.
- `dimension: int` ➡ The dimension of the embeddings.

{<h4 className="custom-header"> Usage: </h4>}

To compute the embeddings across the entire graph with the default parameters,
use the following query:

```cypher
CALL embeddings.node_sentence()
YIELD success;
```

To compute the embeddings for a specific list of nodes, use the following query:

```cypher
MATCH (n)
WITH n ORDER BY id(n)
LIMIT 5
WITH collect(n) AS subset
CALL embeddings.node_sentence(subset)
YIELD success;
```

To run the computation on specific device(s), use the following query:

```cypher
WITH {device: "cuda:1"} AS configuration
CALL embeddings.node_sentence(NULL, configuration)
YIELD success;
```
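
The `device` option also accepts a list of devices, so the computation can be spread
across several GPUs. The sketch below assumes that two CUDA devices, `cuda:0` and
`cuda:1`, are visible to Memgraph (for example, in a container started with `--gpus=all`):

```cypher
// Assumes two CUDA devices are available; adjust the list to match your setup.
WITH {device: ["cuda:0", "cuda:1"]} AS configuration
CALL embeddings.node_sentence(NULL, configuration)
YIELD success;
```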

To return the embeddings as an additional output, use the following query:

```cypher
WITH {return_embeddings: true} AS configuration
CALL embeddings.node_sentence(NULL, configuration)
YIELD success, embeddings;
```
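
Several configuration options can be combined in a single map. The following sketch
stores the embeddings under a custom property and skips one property during the
computation; the names `title_embedding` and `internal_id` are hypothetical and only
illustrate the `embedding_property` and `excluded_properties` options:

```cypher
// "title_embedding" and "internal_id" are hypothetical names used for illustration.
WITH {embedding_property: "title_embedding",
      excluded_properties: ["internal_id"],
      batch_size: 1000} AS configuration
CALL embeddings.node_sentence(NULL, configuration)
YIELD success;
```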


### `text()`

Given a list of strings, the procedure returns a list of sentence embeddings, one per string.

{<h4 className="custom-header"> Input: </h4>}

- `strings: List[string]` ➡ The list of strings to compute the embeddings for.
- `configuration: mgp.Map` (**OPTIONAL**) ➡ User-defined parameters from the query module. Defaults to `{}`.

**Configuration options:**

| Name | Type | Default | Description |
|----------------------------|--------------|-------------------|----------------------------------------------------------------------------------------------------------|
| `model_name` | string | `"all-MiniLM-L6-v2"` | The name of the model to use for the embeddings computation, provided by the `sentence-transformers` library. |
| `batch_size` | int | `2000` | The batch size to use for the embeddings computation. |
| `chunk_size` | int | `48` | The number of batches per chunk. Chunks are used when computing embeddings across multiple GPUs, which requires spawning multiple processes; each spawned process computes the embeddings for a single chunk. |
| `device` | NULL\|string\|int\|List[string\|int] | `NULL` | The device(s) to use for the embeddings computation. Accepts the same values as in `node_sentence()`. |


{<h4 className="custom-header"> Output: </h4>}

- `success: bool` ➡ Whether the embeddings computation was successful.
- `embeddings: List[List[float]]` ➡ The list of embeddings.
- `dimension: int` ➡ The dimension of the embeddings.

{<h4 className="custom-header"> Usage: </h4>}

To compute the embeddings for a list of strings, use the following query:

```cypher
CALL embeddings.text(["Hello", "World"])
YIELD success, embeddings;
```
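
The configuration map can be passed as the second argument as well. The following
sketch pins the computation to the CPU and explicitly names the default model; both
values are shown only for illustration:

```cypher
// Explicit model name and CPU device, passed as an inline configuration map.
CALL embeddings.text(["Hello", "World"], {model_name: "all-MiniLM-L6-v2", device: "cpu"})
YIELD success, embeddings, dimension;
```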

### `model_info()`

The procedure returns information about the model used for the embeddings computation.

{<h4 className="custom-header"> Input: </h4>}

- `configuration: mgp.Map` (**OPTIONAL**) ➡ User-defined parameters from the query module. Defaults to `{}`.
Use the `model_name` key to specify the name of the model to return the information for.

{<h4 className="custom-header"> Output: </h4>}

- `info: mgp.Map` ➡ The information about the model used for the embeddings computation. The returned map contains the following keys (values shown for the default model):

| Name | Type | Default | Description |
|----------------------------|--------------|-------------------|----------------------------------------------------------------------------------------------------------|
| `model_name` | string | `"all-MiniLM-L6-v2"` | The name of the model used for the embeddings computation, provided by the `sentence-transformers` library. |
| `dimension` | int | `384` | The dimension of the embeddings. |
| `max_sequence_length` | int | `256` | The maximum sequence length. |

## Example

Create the following graph:

```cypher
CREATE (a:Node {id: 1, Title: "Stilton", Description: "A stinky cheese from the UK"}),
(b:Node {id: 2, Title: "Roquefort", Description: "A blue cheese from France"}),
(c:Node {id: 3, Title: "Cheddar", Description: "A yellow cheese from the UK"}),
(d:Node {id: 4, Title: "Gouda", Description: "A Dutch cheese"}),
(e:Node {id: 5, Title: "Parmesan", Description: "An Italian cheese"}),
(f:Node {id: 6, Title: "Red Leicester", Description: "The best cheese in the world"});
```

Run the following queries to compute the embeddings and inspect the results:

```cypher
CALL embeddings.node_sentence()
YIELD success;

MATCH (n)
WHERE n.embedding IS NOT NULL
RETURN n.Title, n.embedding;
```

Results:

```plaintext
+---------+
| success |
+---------+
| true    |
+---------+
+-----------------+----------------------------------------------------------------------+
| n.Title         | n.embedding                                                          |
+-----------------+----------------------------------------------------------------------+
| "Stilton"       | [-0.0485366, -0.021823, 0.0159757, 0.0376443, 0.00594089, -0.0044... |
| "Roquefort"     | [-0.0252884, 0.0250485, -0.0249728, 0.0571037, 0.0386177, 0.03863... |
| "Cheddar"       | [-0.0129724, -0.00756301, -0.00379329, 0.0037531, -0.0134941, 0.0... |
| "Gouda"         | [0.0128716, 0.025435, -0.0288951, 0.0177759, -0.0624398, 0.043577... |
| "Parmesan"      | [-0.0755439, 0.00906182, -0.010977, 0.0208911, -0.0527448, 0.0085... |
| "Red Leicester" | [-0.0244318, -0.0280038, -0.0373183, 0.0284436, -0.0277753, 0.066... |
+-----------------+----------------------------------------------------------------------+
```
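
The stored embeddings can then be used directly from Cypher. The following query is
only a sketch of one possible follow-up: it computes the cosine similarity between the
Stilton and Roquefort embeddings with plain list expressions (`reduce`, `range`, `size`,
`sqrt`), without relying on any additional module:

```cypher
// Cosine similarity between two stored embeddings, using plain Cypher list expressions.
MATCH (a:Node {Title: "Stilton"}), (b:Node {Title: "Roquefort"})
WITH a.embedding AS e1, b.embedding AS e2
WITH e1, e2, range(0, size(e1) - 1) AS idx
RETURN reduce(dot = 0.0, i IN idx | dot + e1[i] * e2[i]) /
       (sqrt(reduce(n1 = 0.0, i IN idx | n1 + e1[i] * e1[i])) *
        sqrt(reduce(n2 = 0.0, i IN idx | n2 + e2[i] * e2[i]))) AS cosine_similarity;
```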

To compute the embeddings for a list of strings, use the following query:

```cypher
CALL embeddings.text(["Hello", "World"])
YIELD success, embeddings;
```

Results:

```plaintext
+---------+-----------------------------------------------------------------------------------+
| success | embeddings                                                                        |
+---------+-----------------------------------------------------------------------------------+
| true    | [[-0.0627718, 0.0549588, 0.0521648, 0.08579, -0.0827489, -0.074573, 0.0685547... |
+---------+-----------------------------------------------------------------------------------+
```

To get the information about the model used for the embeddings computation, use the following query:

```cypher
CALL embeddings.model_info()
YIELD info;
```

Results:

```plaintext
+-----------------------------------------------------------------------------+
| info                                                                        |
+-----------------------------------------------------------------------------+
| {dimension: 384, max_sequence_length: 256, model_name: "all-MiniLM-L6-v2"} |
+-----------------------------------------------------------------------------+
```