Skip to content

Commit 8dedd00

Browse files
fix(dspy): solving for null response bug in Cortex API
1 parent b0b18a7 commit 8dedd00

File tree

1 file changed

+19
-6
lines changed

1 file changed

+19
-6
lines changed

dsp/modules/snowflake.py

Lines changed: 19 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,8 +52,18 @@ def __init__(self, model: str = "mixtral-8x7b", credentials=None, **kwargs):
5252
super().__init__(model)
5353

5454
self.model = model
55-
cortex_models = ["llama3-8b","llama3-70b","reka-core","snowflake-arctic","mistral-large", "reka-flash", "mixtral-8x7b",
56-
"llama2-70b-chat", "mistral-7b", "gemma-7b"]
55+
cortex_models = [
56+
"llama3-8b",
57+
"llama3-70b",
58+
"reka-core",
59+
"snowflake-arctic",
60+
"mistral-large",
61+
"reka-flash",
62+
"mixtral-8x7b",
63+
"llama2-70b-chat",
64+
"mistral-7b",
65+
"gemma-7b",
66+
]
5767

5868
if model in cortex_models:
5969
self.available_args = {
@@ -81,9 +91,8 @@ def __init__(self, model: str = "mixtral-8x7b", credentials=None, **kwargs):
8191

8292
@classmethod
8393
def _init_cortex(cls, credentials: dict) -> None:
84-
8594
session = Session.builder.configs(credentials).create()
86-
session.query_tag = {"origin":"sf_sit", "name":"dspy", "version":{"major":1, "minor":0}}
95+
session.query_tag = {"origin": "sf_sit", "name": "dspy", "version": {"major": 1, "minor": 0}}
8796

8897
return session
8998

@@ -103,9 +112,13 @@ def _cortex_complete_request(self, prompt: str, **kwargs) -> dict:
103112
snow_func.lit([{"role": "user", "content": prompt}]),
104113
snow_func.lit(kwargs),
105114
)
106-
raw_response = self.client.range(1).withColumn("complete_cal", cortex_complete_args).collect()[0].COMPLETE_CAL
115+
raw_response = self.client.range(1).withColumn("complete_cal", cortex_complete_args).collect()
107116

108-
return json.loads(raw_response)
117+
if len(raw_response) > 0:
118+
return json.loads(raw_response[0].COMPLETE_CAL)
119+
120+
else:
121+
return json.loads('{"choices": [{"messages": "None"}]}')
109122

110123
def basic_request(self, prompt: str, **kwargs) -> list:
111124
raw_kwargs = kwargs

0 commit comments

Comments
 (0)