@@ -386,76 +386,42 @@ public struct ModelConfig: Codable, Sendable {
386386
387387// MARK: - Language Model Components
388388
389- // Base protocol for RMSNorm variants
390- private protocol Gemma3nRMSNormProtocol : UnaryLayer {
391- func callAsFunction( _ x: MLXArray ) -> MLXArray
392- }
393-
394- // RMSNorm with scale parameter
395- private class Gemma3nRMSNormWithScale : Module , Gemma3nRMSNormProtocol {
389+ private class Gemma3nRMSNorm : Module {
396390 let eps : Float
397- let scaleShift : Float
398- let weight : MLXArray
391+ let scaleShift : Float ?
392+ let weight : MLXArray ?
399393
400- init ( dim: Int , eps: Float = 1e-6 , scaleShift: Float = 0.0 ) {
394+ init ( dim: Int , eps: Float = 1e-6 , scaleShift: Float ? = nil ) {
401395 self . eps = eps
402396 self . scaleShift = scaleShift
403- self . weight = MLXArray . ones ( [ dim] )
397+ self . weight = scaleShift != nil ? MLXArray . ones ( [ dim] ) : nil
404398 super. init ( )
405399 }
406400
407401 func callAsFunction( _ x: MLXArray ) -> MLXArray {
408402 let output = norm ( x. asType ( . float32) )
409- return ( output * ( weight + scaleShift) ) . asType ( x. dtype)
410- }
411403
412- private func norm( _ x: MLXArray ) -> MLXArray {
413- return x * rsqrt( x. square ( ) . mean ( axis: - 1 , keepDims: true ) + eps)
414- }
415- }
416-
417- // RMSNorm without scale parameter (no weight to load from checkpoint)
418- private class Gemma3nRMSNormNoScale : Module , Gemma3nRMSNormProtocol {
419- let eps : Float
420-
421- init ( dim: Int , eps: Float = 1e-6 ) {
422- self . eps = eps
423- super. init ( )
424- }
425-
426- func callAsFunction( _ x: MLXArray ) -> MLXArray {
427- let output = norm ( x. asType ( . float32) )
428- return output. asType ( x. dtype)
404+ if let weight, let scaleShift {
405+ return ( output * ( weight + scaleShift) ) . asType ( x. dtype)
406+ } else {
407+ return output. asType ( x. dtype)
408+ }
429409 }
430410
431411 private func norm( _ x: MLXArray ) -> MLXArray {
432412 return x * rsqrt( x. square ( ) . mean ( axis: - 1 , keepDims: true ) + eps)
433413 }
434414}
435415
436- // Factory function to create the appropriate RMSNorm variant
437- private func createGemma3nRMSNorm(
438- dim: Int ,
439- eps: Float = 1e-6 ,
440- scaleShift: Float = 0.0 ,
441- withScale: Bool = true
442- ) -> any Gemma3nRMSNormProtocol {
443- if withScale {
444- return Gemma3nRMSNormWithScale ( dim: dim, eps: eps, scaleShift: scaleShift)
445- } else {
446- return Gemma3nRMSNormNoScale ( dim: dim, eps: eps)
447- }
448- }
449-
450416private class Gemma3nLaurelBlock : Module {
451417 @ModuleInfo ( key: " linear_left " ) var linearLeft : Linear
452418 @ModuleInfo ( key: " linear_right " ) var linearRight : Linear
453- @ModuleInfo ( key: " post_laurel_norm " ) var postLaurelNorm : Gemma3nRMSNormWithScale
419+ @ModuleInfo ( key: " post_laurel_norm " ) var postLaurelNorm : Gemma3nRMSNorm
454420
455421 init ( config: TextConfig ) {
456422 self . _linearLeft. wrappedValue = Linear ( config. hiddenSize, config. laurelRank, bias: false )
457423 self . _linearRight. wrappedValue = Linear ( config. laurelRank, config. hiddenSize, bias: false )
458- self . _postLaurelNorm. wrappedValue = Gemma3nRMSNormWithScale (
424+ self . _postLaurelNorm. wrappedValue = Gemma3nRMSNorm (
459425 dim: config. hiddenSize,
460426 eps: config. rmsNormEps,
461427 scaleShift: 0.0
@@ -570,9 +536,9 @@ private class Gemma3nAttention: Module {
570536 @ModuleInfo ( key: " k_proj " ) var kProj : Linear
571537 @ModuleInfo ( key: " v_proj " ) var vProj : Linear
572538 @ModuleInfo ( key: " o_proj " ) var oProj : Linear
573- @ModuleInfo ( key: " q_norm " ) var qNorm : Gemma3nRMSNormWithScale
574- @ModuleInfo ( key: " k_norm " ) var kNorm : Gemma3nRMSNormWithScale
575- @ModuleInfo ( key: " v_norm " ) var vNorm : Gemma3nRMSNormNoScale
539+ @ModuleInfo ( key: " q_norm " ) var qNorm : Gemma3nRMSNorm
540+ @ModuleInfo ( key: " k_norm " ) var kNorm : Gemma3nRMSNorm
541+ @ModuleInfo ( key: " v_norm " ) var vNorm : Gemma3nRMSNorm
576542
577543 init ( config: TextConfig , layerIdx: Int ) {
578544 self . isSliding =
@@ -594,11 +560,11 @@ private class Gemma3nAttention: Module {
594560 self . _vProj. wrappedValue = Linear ( dim, numKVHeads * headDim, bias: false )
595561 self . _oProj. wrappedValue = Linear ( numHeads * headDim, dim, bias: false )
596562
597- self . _qNorm. wrappedValue = Gemma3nRMSNormWithScale (
563+ self . _qNorm. wrappedValue = Gemma3nRMSNorm (
598564 dim: config. headDim, eps: config. rmsNormEps)
599- self . _kNorm. wrappedValue = Gemma3nRMSNormWithScale (
565+ self . _kNorm. wrappedValue = Gemma3nRMSNorm (
600566 dim: config. headDim, eps: config. rmsNormEps)
601- self . _vNorm. wrappedValue = Gemma3nRMSNormNoScale (
567+ self . _vNorm. wrappedValue = Gemma3nRMSNorm (
602568 dim: config. headDim,
603569 eps: config. rmsNormEps
604570 )
@@ -749,7 +715,7 @@ private class Gemma3nAltUp: Module {
749715 @ModuleInfo ( key: " correction_coefs " ) var correctionCoefs : Linear
750716 @ModuleInfo ( key: " prediction_coefs " ) var predictionCoefs : Linear
751717 @ModuleInfo ( key: " modality_router " ) var modalityRouter : Linear
752- @ModuleInfo ( key: " router_norm " ) var routerNorm : Gemma3nRMSNormWithScale
718+ @ModuleInfo ( key: " router_norm " ) var routerNorm : Gemma3nRMSNorm
753719 private let _routerInputScale : MLXArray
754720
755721 let config : TextConfig
@@ -773,7 +739,7 @@ private class Gemma3nAltUp: Module {
773739 config. altupNumInputs,
774740 bias: false
775741 )
776- self . _routerNorm. wrappedValue = Gemma3nRMSNormWithScale (
742+ self . _routerNorm. wrappedValue = Gemma3nRMSNorm (
777743 dim: config. hiddenSize,
778744 eps: config. rmsNormEps,
779745 scaleShift: 0.0
@@ -784,8 +750,13 @@ private class Gemma3nAltUp: Module {
784750 }
785751
786752 func computeRouterModalities( _ x: MLXArray ) -> MLXArray {
787- let routerInputs =
788- routerNorm ( x) * _routerInputScale. asType ( routerNorm. weight. dtype)
753+ guard let routerNormWeight = routerNorm. weight else {
754+ // This should never happen, since `routerNorm` is assigned `Gemma3nRMSNorm` with `scaleShift`, so `weight` should not be nil
755+ fatalError ( " routerNorm.weight is nil " )
756+ }
757+
758+ let routerInputs = routerNorm ( x) * _routerInputScale. asType ( routerNormWeight. dtype)
759+
789760 let routed = modalityRouter ( routerInputs) . asType ( . float32)
790761 return tanh ( routed)
791762 }
@@ -875,17 +846,15 @@ private class Gemma3nDecoderLayer: Module {
875846
876847 @ModuleInfo ( key: " self_attn " ) var selfAttn : Gemma3nAttention
877848 @ModuleInfo var mlp : MLP
878- @ModuleInfo ( key: " input_layernorm " ) var inputLayernorm : Gemma3nRMSNormWithScale
879- @ModuleInfo ( key: " post_attention_layernorm " ) var postAttentionLayernorm : Gemma3nRMSNormWithScale
880- @ModuleInfo ( key: " pre_feedforward_layernorm " ) var preFeedforwardLayernorm :
881- Gemma3nRMSNormWithScale
882- @ModuleInfo ( key: " post_feedforward_layernorm " ) var postFeedforwardLayernorm :
883- Gemma3nRMSNormWithScale
849+ @ModuleInfo ( key: " input_layernorm " ) var inputLayernorm : Gemma3nRMSNorm
850+ @ModuleInfo ( key: " post_attention_layernorm " ) var postAttentionLayernorm : Gemma3nRMSNorm
851+ @ModuleInfo ( key: " pre_feedforward_layernorm " ) var preFeedforwardLayernorm : Gemma3nRMSNorm
852+ @ModuleInfo ( key: " post_feedforward_layernorm " ) var postFeedforwardLayernorm : Gemma3nRMSNorm
884853 @ModuleInfo var altup : Gemma3nAltUp
885854 @ModuleInfo var laurel : Gemma3nLaurelBlock
886855 @ModuleInfo ( key: " per_layer_input_gate " ) var perLayerInputGate : Linear
887856 @ModuleInfo ( key: " per_layer_projection " ) var perLayerProjection : Linear
888- @ModuleInfo ( key: " post_per_layer_input_norm " ) var postPerLayerInputNorm : Gemma3nRMSNormWithScale
857+ @ModuleInfo ( key: " post_per_layer_input_norm " ) var postPerLayerInputNorm : Gemma3nRMSNorm
889858
890859 init ( config: TextConfig , layerIdx: Int ) {
891860 self . config = config
@@ -901,23 +870,23 @@ private class Gemma3nDecoderLayer: Module {
901870 == " sliding_attention "
902871
903872 self . _mlp. wrappedValue = MLP ( config: config, layerIdx: layerIdx)
904- self . _inputLayernorm. wrappedValue = Gemma3nRMSNormWithScale (
873+ self . _inputLayernorm. wrappedValue = Gemma3nRMSNorm (
905874 dim: hiddenSize,
906875 eps: config. rmsNormEps,
907876 scaleShift: 0.0
908877 )
909878
910- self . _postAttentionLayernorm. wrappedValue = Gemma3nRMSNormWithScale (
879+ self . _postAttentionLayernorm. wrappedValue = Gemma3nRMSNorm (
911880 dim: hiddenSize,
912881 eps: config. rmsNormEps,
913882 scaleShift: 0.0
914883 )
915- self . _preFeedforwardLayernorm. wrappedValue = Gemma3nRMSNormWithScale (
884+ self . _preFeedforwardLayernorm. wrappedValue = Gemma3nRMSNorm (
916885 dim: hiddenSize,
917886 eps: config. rmsNormEps,
918887 scaleShift: 0.0
919888 )
920- self . _postFeedforwardLayernorm. wrappedValue = Gemma3nRMSNormWithScale (
889+ self . _postFeedforwardLayernorm. wrappedValue = Gemma3nRMSNorm (
921890 dim: hiddenSize,
922891 eps: config. rmsNormEps,
923892 scaleShift: 0.0
@@ -936,7 +905,7 @@ private class Gemma3nDecoderLayer: Module {
936905 hiddenSize,
937906 bias: false
938907 )
939- self . _postPerLayerInputNorm. wrappedValue = Gemma3nRMSNormWithScale (
908+ self . _postPerLayerInputNorm. wrappedValue = Gemma3nRMSNorm (
940909 dim: hiddenSize,
941910 eps: config. rmsNormEps,
942911 scaleShift: 0.0
@@ -1049,13 +1018,12 @@ private class Gemma3Model: Module {
10491018 @ModuleInfo ( key: " layers " ) var layers : [ Gemma3nDecoderLayer ]
10501019 @ModuleInfo ( key: " embed_tokens_per_layer " ) var embedTokensPerLayer : Embedding
10511020 @ModuleInfo ( key: " per_layer_model_projection " ) var perLayerModelProjection : Linear
1052- @ModuleInfo ( key: " per_layer_projection_norm " ) var perLayerProjectionNorm :
1053- Gemma3nRMSNormWithScale
1021+ @ModuleInfo ( key: " per_layer_projection_norm " ) var perLayerProjectionNorm : Gemma3nRMSNorm
10541022
10551023 @ModuleInfo ( key: " altup_projections " ) var altupProjections : [ Linear ]
10561024 @ModuleInfo ( key: " altup_unembed_projections " ) var altupUnembedProjections : [ Linear ]
10571025
1058- @ModuleInfo var norm : Gemma3nRMSNormWithScale
1026+ @ModuleInfo var norm : Gemma3nRMSNorm
10591027 @ModuleInfo ( key: " rope_embedding " ) var ropeEmbedding : Gemma3nRotaryEmbedding
10601028 @ModuleInfo ( key: " rope_embedding_local " ) var ropeEmbeddingLocal : Gemma3nRotaryEmbedding
10611029
@@ -1090,7 +1058,7 @@ private class Gemma3Model: Module {
10901058 bias: false
10911059 )
10921060
1093- self . _perLayerProjectionNorm. wrappedValue = Gemma3nRMSNormWithScale (
1061+ self . _perLayerProjectionNorm. wrappedValue = Gemma3nRMSNorm (
10941062 dim: config. hiddenSizePerLayerInput,
10951063 eps: config. rmsNormEps,
10961064 scaleShift: 0.0
@@ -1103,7 +1071,7 @@ private class Gemma3Model: Module {
11031071 Linear ( config. hiddenSize, config. hiddenSize, bias: false )
11041072 }
11051073
1106- self . _norm. wrappedValue = Gemma3nRMSNormWithScale (
1074+ self . _norm. wrappedValue = Gemma3nRMSNorm (
11071075 dim: config. hiddenSize,
11081076 eps: config. rmsNormEps,
11091077 scaleShift: 0.0
@@ -1375,11 +1343,11 @@ private class Gemma3nMultimodalEmbedder: Module, UnaryLayer {
13751343 let textHiddenSize : Int
13761344
13771345 @ModuleInfo var embedding : Embedding
1378- @ModuleInfo ( key: " hard_embedding_norm " ) var hardEmbeddingNorm : Gemma3nRMSNormWithScale
1379- @ModuleInfo ( key: " soft_embedding_norm " ) var softEmbeddingNorm : Gemma3nRMSNormWithScale
1346+ @ModuleInfo ( key: " hard_embedding_norm " ) var hardEmbeddingNorm : Gemma3nRMSNorm
1347+ @ModuleInfo ( key: " soft_embedding_norm " ) var softEmbeddingNorm : Gemma3nRMSNorm
13801348 @ModuleInfo ( key: " embedding_projection " ) var embeddingProjection : Linear
13811349 @ModuleInfo ( key: " embedding_post_projection_norm " ) var embeddingPostProjectionNorm :
1382- Gemma3nRMSNormNoScale
1350+ Gemma3nRMSNorm
13831351
13841352 init ( multimodalConfig: any MultimodalConfig , textConfig: TextConfig ) {
13851353 self . multimodalHiddenSize = multimodalConfig. hiddenSize
@@ -1392,11 +1360,11 @@ private class Gemma3nMultimodalEmbedder: Module, UnaryLayer {
13921360 embeddingCount: vocabSize,
13931361 dimensions: multimodalHiddenSize
13941362 )
1395- self . _hardEmbeddingNorm. wrappedValue = Gemma3nRMSNormWithScale (
1363+ self . _hardEmbeddingNorm. wrappedValue = Gemma3nRMSNorm (
13961364 dim: multimodalHiddenSize,
13971365 eps: eps
13981366 )
1399- self . _softEmbeddingNorm. wrappedValue = Gemma3nRMSNormWithScale (
1367+ self . _softEmbeddingNorm. wrappedValue = Gemma3nRMSNorm (
14001368 dim: multimodalHiddenSize,
14011369 eps: eps
14021370 )
@@ -1405,7 +1373,7 @@ private class Gemma3nMultimodalEmbedder: Module, UnaryLayer {
14051373 textHiddenSize,
14061374 bias: false
14071375 )
1408- self . _embeddingPostProjectionNorm. wrappedValue = Gemma3nRMSNormNoScale (
1376+ self . _embeddingPostProjectionNorm. wrappedValue = Gemma3nRMSNorm (
14091377 dim: textHiddenSize,
14101378 eps: eps
14111379 )
@@ -2538,21 +2506,21 @@ private class Gemma3nAudioConformerAttention: Module {
25382506 let postInFeatures : Int
25392507 private let _gradientClipping : MLXArray
25402508
2541- @ModuleInfo ( key: " pre_attn_norm " ) var preAttnNorm : Gemma3nRMSNormWithScale
2509+ @ModuleInfo ( key: " pre_attn_norm " ) var preAttnNorm : Gemma3nRMSNorm
25422510 @ModuleInfo var attn : Gemma3nAudioAttention
25432511 @ModuleInfo var post : Linear
2544- @ModuleInfo ( key: " post_norm " ) var postNorm : Gemma3nRMSNormWithScale
2512+ @ModuleInfo ( key: " post_norm " ) var postNorm : Gemma3nRMSNorm
25452513
25462514 init ( config: AudioConfig ) {
25472515 self . config = config
25482516 let headDim = config. hiddenSize / config. confNumAttentionHeads
25492517 self . postInFeatures = config. hiddenSize
25502518 self . _gradientClipping = MLXArray ( config. gradientClipping)
25512519
2552- self . _preAttnNorm. wrappedValue = Gemma3nRMSNormWithScale ( dim: config. hiddenSize)
2520+ self . _preAttnNorm. wrappedValue = Gemma3nRMSNorm ( dim: config. hiddenSize)
25532521 self . _attn. wrappedValue = Gemma3nAudioAttention ( config: config)
25542522 self . _post. wrappedValue = Linear ( postInFeatures, config. hiddenSize, bias: false )
2555- self . _postNorm. wrappedValue = Gemma3nRMSNormWithScale ( dim: config. hiddenSize)
2523+ self . _postNorm. wrappedValue = Gemma3nRMSNorm ( dim: config. hiddenSize)
25562524
25572525 super. init ( )
25582526 }
@@ -2581,20 +2549,20 @@ private class Gemma3nAudioConformerFeedForward: Module {
25812549 private let _gradientClipping : MLXArray
25822550 private let _postLayerScale : MLXArray
25832551
2584- @ModuleInfo ( key: " pre_layer_norm " ) var preLayerNorm : Gemma3nRMSNormWithScale
2552+ @ModuleInfo ( key: " pre_layer_norm " ) var preLayerNorm : Gemma3nRMSNorm
25852553 @ModuleInfo ( key: " ffw_layer_1 " ) var ffwLayer1 : Linear
25862554 @ModuleInfo ( key: " ffw_layer_2 " ) var ffwLayer2 : Linear
2587- @ModuleInfo ( key: " post_layer_norm " ) var postLayerNorm : Gemma3nRMSNormWithScale
2555+ @ModuleInfo ( key: " post_layer_norm " ) var postLayerNorm : Gemma3nRMSNorm
25882556
25892557 init ( config: AudioConfig ) {
25902558 self . config = config
25912559 self . _gradientClipping = MLXArray ( config. gradientClipping)
25922560 self . _postLayerScale = MLXArray ( config. confResidualWeight)
25932561
2594- self . _preLayerNorm. wrappedValue = Gemma3nRMSNormWithScale ( dim: config. hiddenSize)
2562+ self . _preLayerNorm. wrappedValue = Gemma3nRMSNorm ( dim: config. hiddenSize)
25952563 self . _ffwLayer1. wrappedValue = Linear ( config. hiddenSize, config. hiddenSize * 4 , bias: false )
25962564 self . _ffwLayer2. wrappedValue = Linear ( config. hiddenSize * 4 , config. hiddenSize, bias: false )
2597- self . _postLayerNorm. wrappedValue = Gemma3nRMSNormWithScale ( dim: config. hiddenSize)
2565+ self . _postLayerNorm. wrappedValue = Gemma3nRMSNorm ( dim: config. hiddenSize)
25982566
25992567 super. init ( )
26002568 }
@@ -2618,18 +2586,18 @@ private class Gemma3nAudioConformerLightConv1d: Module {
26182586 private let _gradientClipping : MLXArray
26192587 let causalPadding : Int
26202588
2621- @ModuleInfo ( key: " pre_layer_norm " ) var preLayerNorm : Gemma3nRMSNormWithScale
2589+ @ModuleInfo ( key: " pre_layer_norm " ) var preLayerNorm : Gemma3nRMSNorm
26222590 @ModuleInfo ( key: " linear_start " ) var linearStart : Linear
26232591 @ModuleInfo ( key: " depthwise_conv1d " ) var depthwiseConv1d : Conv1d
2624- @ModuleInfo ( key: " conv_norm " ) var convNorm : Gemma3nRMSNormWithScale
2592+ @ModuleInfo ( key: " conv_norm " ) var convNorm : Gemma3nRMSNorm
26252593 @ModuleInfo ( key: " linear_end " ) var linearEnd : Linear
26262594
26272595 init ( config: AudioConfig ) {
26282596 self . config = config
26292597 self . _gradientClipping = MLXArray ( config. gradientClipping)
26302598 self . causalPadding = config. confConvKernelSize - 1
26312599
2632- self . _preLayerNorm. wrappedValue = Gemma3nRMSNormWithScale (
2600+ self . _preLayerNorm. wrappedValue = Gemma3nRMSNorm (
26332601 dim: config. hiddenSize,
26342602 eps: config. rmsNormEps
26352603 )
@@ -2647,7 +2615,7 @@ private class Gemma3nAudioConformerLightConv1d: Module {
26472615 groups: config. hiddenSize,
26482616 bias: false
26492617 )
2650- self . _convNorm. wrappedValue = Gemma3nRMSNormWithScale (
2618+ self . _convNorm. wrappedValue = Gemma3nRMSNorm (
26512619 dim: config. hiddenSize,
26522620 eps: config. rmsNormEps
26532621 )
@@ -2690,7 +2658,7 @@ private class Gemma3nAudioConformerBlock: Module {
26902658 @ModuleInfo var attention : Gemma3nAudioConformerAttention
26912659 @ModuleInfo var lconv1d : Gemma3nAudioConformerLightConv1d
26922660 @ModuleInfo ( key: " ffw_layer_end " ) var ffwLayerEnd : Gemma3nAudioConformerFeedForward
2693- @ModuleInfo var norm : Gemma3nRMSNormWithScale
2661+ @ModuleInfo var norm : Gemma3nRMSNorm
26942662
26952663 init ( config: AudioConfig ) {
26962664 self . config = config
@@ -2700,7 +2668,7 @@ private class Gemma3nAudioConformerBlock: Module {
27002668 self . _attention. wrappedValue = Gemma3nAudioConformerAttention ( config: config)
27012669 self . _lconv1d. wrappedValue = Gemma3nAudioConformerLightConv1d ( config: config)
27022670 self . _ffwLayerEnd. wrappedValue = Gemma3nAudioConformerFeedForward ( config: config)
2703- self . _norm. wrappedValue = Gemma3nRMSNormWithScale ( dim: config. hiddenSize)
2671+ self . _norm. wrappedValue = Gemma3nRMSNorm ( dim: config. hiddenSize)
27042672
27052673 super. init ( )
27062674 }
0 commit comments