@@ -395,12 +395,12 @@ private protocol Gemma3nRMSNormProtocol: UnaryLayer {
395395private class Gemma3nRMSNormWithScale : Module , Gemma3nRMSNormProtocol {
396396 let eps : Float
397397 let scaleShift : Float
398- @ ModuleInfo var weight : MLXArray
398+ let weight : MLXArray
399399
400400 init ( dim: Int , eps: Float = 1e-6 , scaleShift: Float = 0.0 ) {
401401 self . eps = eps
402402 self . scaleShift = scaleShift
403- self . _weight . wrappedValue = MLXArray . ones ( [ dim] )
403+ self . weight = MLXArray . ones ( [ dim] )
404404 super. init ( )
405405 }
406406
@@ -495,8 +495,8 @@ private class Gemma3nRotaryEmbedding: Module {
495495 let originalMaxSeqLen : Int
496496 let config : TextConfig
497497 let attentionScaling : Float
498- let _invFreq : MLXArray
499- let _originalInvFreq : MLXArray
498+ private let _invFreq : MLXArray
499+ private let _originalInvFreq : MLXArray
500500
501501 init ( config: TextConfig ) {
502502 if let ropeScaling = config. ropeScaling {
@@ -750,7 +750,7 @@ private class Gemma3nAltUp: Module {
750750 @ModuleInfo ( key: " prediction_coefs " ) var predictionCoefs : Linear
751751 @ModuleInfo ( key: " modality_router " ) var modalityRouter : Linear
752752 @ModuleInfo ( key: " router_norm " ) var routerNorm : Gemma3nRMSNormWithScale
753- let _routerInputScale : MLXArray
753+ private let _routerInputScale : MLXArray
754754
755755 let config : TextConfig
756756
@@ -2340,7 +2340,7 @@ private class Gemma3nAudioAttention: Module {
23402340
23412341 @ModuleInfo ( key: " relative_position_embedding " ) var relativePositionEmbedding :
23422342 Gemma3nAudioRelativePositionEmbedding
2343- @ ModuleInfo ( key : " per_dim_scale " ) var perDimScale : MLXArray
2343+ private let _perDimScale : MLXArray
23442344 @ModuleInfo ( key: " q_proj " ) var qProj : Linear
23452345 @ModuleInfo ( key: " k_proj " ) var kProj : Linear
23462346 @ModuleInfo ( key: " v_proj " ) var vProj : Linear
@@ -2359,7 +2359,7 @@ private class Gemma3nAudioAttention: Module {
23592359
23602360 self . _relativePositionEmbedding. wrappedValue = Gemma3nAudioRelativePositionEmbedding (
23612361 config: config)
2362- self . _perDimScale. wrappedValue = MLXArray . zeros ( [ headDim] )
2362+ self . _perDimScale = MLXArray . zeros ( [ headDim] )
23632363
23642364 self . _qProj. wrappedValue = Linear ( hiddenSize, numHeads * headDim, bias: false )
23652365 self . _kProj. wrappedValue = Linear ( hiddenSize, numHeads * headDim, bias: false )
@@ -2460,7 +2460,7 @@ private class Gemma3nAudioAttention: Module {
24602460 Array ( x. shape. dropLast ( ) ) + [ numHeads, headDim]
24612461 )
24622462
2463- let perDimScaleSp = logAddExp ( perDimScale , MLXArray ( 0.0 ) )
2463+ let perDimScaleSp = logAddExp ( _perDimScale , MLXArray ( 0.0 ) )
24642464 let broadcastShape = [ 1 , 1 , 1 , headDim]
24652465 let perDimScaleSpBroadcast = perDimScaleSp. reshaped ( broadcastShape)
24662466 let scaledQueryStates = queryStates * qScale * perDimScaleSpBroadcast
@@ -2728,19 +2728,19 @@ private class Gemma3nAudioConformerBlock: Module {
27282728// MARK: - Layer Scale 2D
27292729private class LayerScale2d : Module , UnaryLayer {
27302730 let inplace : Bool
2731- @ ModuleInfo var gamma : MLXArray
2731+ private let _gamma : MLXArray
27322732
27332733 init ( dim: Int , initValues: Float = 1e-5 , inplace: Bool = false ) {
27342734 self . inplace = inplace
2735- self . _gamma. wrappedValue = MLXArray ( initValues) * MLXArray. ones ( [ dim] )
2735+ self . _gamma = MLXArray ( initValues) * MLXArray. ones ( [ dim] )
27362736 super. init ( )
27372737 }
27382738
27392739 func callAsFunction( _ x: MLXArray ) -> MLXArray {
27402740 if inplace {
2741- return x * gamma
2741+ return x * _gamma
27422742 } else {
2743- return x * gamma
2743+ return x * _gamma
27442744 }
27452745 }
27462746}
0 commit comments