Skip to content

Commit 5a0e066

Browse files
committed
change to jnp.bfloat16
1 parent 77fd012 commit 5a0e066

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

tpu_inference/runner/compilation_manager.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def _precompile_pooling(self) -> None:
114114

115115
for num_tokens in self.runner.num_tokens_paddings:
116116
hidden_states = self._create_dummy_tensor(
117-
(num_tokens, hidden_size), t2j_dtype(dtype), sharding=hidden_sharding)
117+
(num_tokens, hidden_size), jnp.bfloat16, sharding=hidden_sharding)
118118

119119
for num_reqs in self.runner.num_reqs_paddings:
120120
if num_reqs == 0 or num_reqs > num_tokens:

0 commit comments

Comments
 (0)