11from typing import Callable , Dict , List , Optional
22
3+ from tenacity import ( # for exponential backoff
4+ retry ,
5+ stop_after_attempt ,
6+ wait_random_exponential ,
7+ )
8+
39from redisvl .vectorize .base import BaseVectorizer
410
511
612class OpenAITextVectorizer (BaseVectorizer ):
13+ # TODO - add docstring
714 def __init__ (self , model : str , api_config : Optional [Dict ] = None ):
815 dims = 1536
916 super ().__init__ (model , dims , api_config )
@@ -18,42 +25,112 @@ def __init__(self, model: str, api_config: Optional[Dict] = None):
1825 openai .api_key = api_config .get ("api_key" , None )
1926 self ._model_client = openai .Embedding
2027
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
def embed_many(
    self,
    texts: List[str],
    preprocess: Optional[Callable] = None,
    batch_size: Optional[int] = 10,
    as_buffer: Optional[bool] = False,
) -> List[List[float]]:
    """Embed many chunks of text using the OpenAI API.

    Args:
        texts (List[str]): List of text chunks to embed.
        preprocess (Optional[Callable], optional): Optional preprocessing
            callable forwarded to ``batchify``. Defaults to None.
        batch_size (int, optional): Batch size of texts to use when creating
            embeddings. Defaults to 10.
        as_buffer (Optional[bool], optional): Flag forwarded to
            ``_process_embedding``; intended to request the embedding as a
            byte string instead of a list of floats. Defaults to False.

    Returns:
        List[List[float]]: Embeddings collected from every batch response,
            each passed through ``_process_embedding``.
    """
    # NOTE: the retry decorator wraps the whole method, so a failure in any
    # batch restarts the loop from the first batch on the next attempt.
    embeddings: List = []
    for batch in self.batchify(texts, batch_size, preprocess):
        response = self._model_client.create(input=batch, engine=self._model)
        embeddings.extend(
            self._process_embedding(item["embedding"], as_buffer)
            for item in response["data"]
        )
    return embeddings
3258
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
def embed(
    self,
    text: str,
    preprocess: Optional[Callable] = None,
    as_buffer: Optional[bool] = False,
) -> List[float]:
    """Embed a chunk of text using the OpenAI API.

    Args:
        text (str): Chunk of text to embed.
        preprocess (Optional[Callable], optional): Optional preprocessing
            callable applied to ``text`` before vectorization. Defaults to
            None.
        as_buffer (Optional[bool], optional): Flag forwarded to
            ``_process_embedding``; intended to request the embedding as a
            byte string instead of a list of floats. Defaults to False.

    Returns:
        List[float]: Embedding of ``text`` after ``_process_embedding``.
    """
    if preprocess:
        text = preprocess(text)
    # The client is given a one-element list, so the single result is at
    # index 0 of the response data.
    result = self._model_client.create(input=[text], engine=self._model)
    return self._process_embedding(result["data"][0]["embedding"], as_buffer)
4082
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
async def aembed_many(
    self,
    texts: List[str],
    preprocess: Optional[Callable] = None,
    batch_size: int = 1000,
    as_buffer: Optional[bool] = False,
) -> List[List[float]]:
    """Asynchronously embed many chunks of text using the OpenAI API.

    Args:
        texts (List[str]): List of text chunks to embed.
        preprocess (Optional[Callable], optional): Optional preprocessing
            callable forwarded to ``batchify``. Defaults to None.
        batch_size (int, optional): Batch size of texts to use when creating
            embeddings. Defaults to 1000.
        as_buffer (Optional[bool], optional): Flag forwarded to
            ``_process_embedding``; intended to request the embedding as a
            byte string instead of a list of floats. Defaults to False.

    Returns:
        List[List[float]]: Embeddings collected from every batch response,
            each passed through ``_process_embedding``.
    """
    # NOTE: the retry decorator wraps the whole coroutine, so a failure in
    # any batch restarts the loop from the first batch on the next attempt.
    embeddings: List = []
    for batch in self.batchify(texts, batch_size, preprocess):
        # Batches are awaited one after another, not concurrently.
        response = await self._model_client.acreate(input=batch, engine=self._model)
        embeddings.extend(
            self._process_embedding(item["embedding"], as_buffer)
            for item in response["data"]
        )
    return embeddings
52113
@retry(wait=wait_random_exponential(min=1, max=60), stop=stop_after_attempt(6))
async def aembed(
    self,
    text: str,
    preprocess: Optional[Callable] = None,
    as_buffer: Optional[bool] = False,
) -> List[float]:
    """Asynchronously embed a single chunk of text using the OpenAI API.

    Args:
        text (str): Chunk of text to embed.
        preprocess (Optional[Callable], optional): Optional preprocessing
            callable applied to ``text`` before vectorization. Defaults to
            None.
        as_buffer (Optional[bool], optional): Flag forwarded to
            ``_process_embedding``; intended to request the embedding as a
            byte string instead of a list of floats. Defaults to False.

    Returns:
        List[float]: Embedding of ``text`` after ``_process_embedding``.
    """
    prepared = preprocess(text) if preprocess else text
    # A one-element input list is sent, so the only result sits at index 0.
    response = await self._model_client.acreate(
        input=[prepared], engine=self._model
    )
    raw_embedding = response["data"][0]["embedding"]
    return self._process_embedding(raw_embedding, as_buffer)
0 commit comments