Skip to content

Commit c7acd26

Browse files
Copybara import of the project:
-- de00614 by hsuan-lun-chiang <hsuan-lun.chiang@cienet.com>: Migrate Gpt3 to NNX. COPYBARA_INTEGRATE_REVIEW=#2062 from CIeNET-International:feat/Migrate-Gpt3-to-NNX de00614 PiperOrigin-RevId: 839048626
1 parent 177a75f commit c7acd26

File tree

2 files changed

+204
-163
lines changed

2 files changed

+204
-163
lines changed

src/MaxText/layers/decoders.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -412,7 +412,7 @@ def get_decoder_layers(self):
412412
case DecoderBlockType.GEMMA3:
413413
return [gemma3.Gemma3DecoderLayerToLinen]
414414
case DecoderBlockType.GPT3:
415-
return [gpt3.Gpt3DecoderLayer]
415+
return [gpt3.Gpt3DecoderLayerToLinen]
416416
case DecoderBlockType.GPT_OSS:
417417
return [gpt_oss.GptOssScannableBlockToLinen] if self.config.scan_layers else [gpt_oss.GptOssDecoderLayerToLinen]
418418
case DecoderBlockType.QWEN3:
@@ -598,7 +598,7 @@ def _apply_embedding(
598598
name="position_embedder",
599599
config=cfg,
600600
mesh=self.mesh,
601-
)(decoder_positions, model_mode=model_mode)
601+
)(decoder_positions.astype("int32"), model_mode=model_mode)
602602
return y
603603

604604
@nn.compact

0 commit comments

Comments
 (0)