Skip to content

Commit 0b9084f

Browse files
committed
fixed merge conflicts; temporarily added granite 3.3 to micro models using 3.2, as the config is the same
Signed-off-by: Joshua Rosenkranz <jmrosenk@us.ibm.com>
2 parents 3aa311d + 7d214a0 commit 0b9084f

File tree

3 files changed

+8
-20
lines changed

3 files changed

+8
-20
lines changed

tests/models/test_decoders.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -34,18 +34,20 @@
3434
except ImportError:
3535
GPTQ_ENABLED = False
3636

37-
ORIGINAL_HF_HOME = os.environ.get("HF_HOME", None)
3837
MICRO_MODELS_HOME = os.environ.get("FMS_TEST_SHAPES_MICRO_MODELS_HOME", "/mnt/home/models/tiny-models")
3938

4039
# Add models to test here
4140
LLAMA_3p1_8B_INSTRUCT = "meta-llama/Llama-3.1-8B-Instruct"
4241
GRANITE_3p2_8B_INSTRUCT = "ibm-granite/granite-3.2-8b-instruct"
42+
GRANITE_3p3_8B_INSTRUCT = "ibm-granite/granite-3.3-8b-instruct"
4343
GRANITE_20B_CODE_INSTRUCT_8K = "ibm-granite/granite-20b-code-instruct-8k"
4444
LLAMA_3p1_70B_INSTRUCT = "meta-llama/Llama-3.1-70B-Instruct"
4545

4646
micro_model_mapping = {
4747
LLAMA_3p1_8B_INSTRUCT: os.path.join(MICRO_MODELS_HOME, "llama-3.1-8b-layers-3-step-24000"),
4848
GRANITE_3p2_8B_INSTRUCT: os.path.join(MICRO_MODELS_HOME, "granite-3.2-8b-layers-3-step-100000"),
49+
# FIXME: Because this uses the same config as 3.2, re-using here, but should update
50+
GRANITE_3p3_8B_INSTRUCT: os.path.join(MICRO_MODELS_HOME, "granite-3.2-8b-layers-3-step-100000"),
4951
LLAMA_3p1_70B_INSTRUCT: os.path.join(MICRO_MODELS_HOME, "llama-3.1-70b-layers-3-step-24000")
5052
}
5153

@@ -67,6 +69,7 @@
6769
[
6870
LLAMA_3p1_8B_INSTRUCT,
6971
GRANITE_3p2_8B_INSTRUCT,
72+
GRANITE_3p3_8B_INSTRUCT,
7073
GRANITE_20B_CODE_INSTRUCT_8K,
7174
LLAMA_3p1_70B_INSTRUCT,
7275
],
@@ -149,6 +152,10 @@
149152
2.3919514417648315,
150153
0.0005767398688476533,
151154
),
155+
(GRANITE_3p3_8B_INSTRUCT, False): (
156+
2.4444521379470827,
157+
0.0004970188625156878,
158+
),
152159
(GRANITE_20B_CODE_INSTRUCT_8K, False): (
153160
2.640706129074097,
154161
0.00034344267623964697,
@@ -171,10 +178,6 @@ def reset_compiler():
171178
torch.compiler.reset()
172179
torch._dynamo.reset()
173180
os.environ.pop("COMPILATION_MODE", None)
174-
if ORIGINAL_HF_HOME is None:
175-
os.environ.pop("HF_HOME", None)
176-
else:
177-
os.environ["HF_HOME"] = ORIGINAL_HF_HOME
178181

179182

180183
# TODO: Currently, gptq does not have the same level of support as non-gptq models for get_model. This method provides the extra requirements for gptq for get_model,
@@ -316,9 +319,6 @@ def test_common_shapes(model_path, batch_size, seq_length, max_new_tokens):
316319
torch.manual_seed(42)
317320
os.environ["COMPILATION_MODE"] = "offline_decoder"
318321

319-
if "HF_HOME" not in os.environ:
320-
os.environ["HF_HOME"] = "/tmp/models/hf_cache"
321-
322322
dprint(
323323
f"testing model={model_path}, batch_size={batch_size}, seq_length={seq_length}, max_new_tokens={max_new_tokens}, micro_model={USE_MICRO_MODELS}"
324324
)

tests/models/test_encoders.py

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,6 @@
1010
import os
1111
import numpy as np
1212

13-
ORIGINAL_HF_HOME = os.environ.get("HF_HOME", None)
14-
1513
# Add models to test here
1614
ROBERTA_SQUAD_V2 = "deepset/roberta-base-squad2"
1715

@@ -81,17 +79,10 @@ def reset_compiler():
8179
torch.compiler.reset()
8280
torch._dynamo.reset()
8381
os.environ.pop('COMPILATION_MODE', None)
84-
if ORIGINAL_HF_HOME is None:
85-
os.environ.pop('HF_HOME', None)
86-
else:
87-
os.environ['HF_HOME'] = ORIGINAL_HF_HOME
8882

8983
@pytest.mark.parametrize("model_path,batch_size,seq_length", common_shapes)
9084
def test_common_shapes(model_path, batch_size, seq_length):
9185
os.environ["COMPILATION_MODE"] = "offline"
92-
93-
if "HF_HOME" not in os.environ:
94-
os.environ["HF_HOME"] = "/tmp/models/hf_cache"
9586

9687
dprint(f"testing model={model_path}, batch_size={batch_size}, seq_length={seq_length}")
9788

tests/models/test_model_expectations.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,6 @@
1313

1414
os.environ["COMPILATION_MODE"] = "offline"
1515

16-
if "HF_HOME" not in os.environ:
17-
os.environ["HF_HOME"] = "/tmp/models/hf_cache"
18-
1916
model_dir = os.environ.get("FMS_TESTING_MODEL_DIR", "/tmp/models")
2017
LLAMA_3p1_8B_INSTRUCT = "meta-llama/Llama-3.1-8B-Instruct"
2118
GRANITE_3p2_8B_INSTRUCT = "ibm-granite/granite-3.2-8b-instruct"

0 commit comments

Comments
 (0)