@@ -1146,18 +1146,18 @@ private class Gemma3Model: Module {
11461146 perLayerInputs: MLXArray ? = nil
11471147 ) -> MLXArray {
11481148 var h : MLXArray
1149- if let inputsEmbeds = inputsEmbeds {
1149+ if let inputsEmbeds {
11501150 h = inputsEmbeds
1151- } else if let inputs = inputs {
1151+ } else if let inputs {
11521152 h = embedTokens ( inputs)
11531153 } else {
11541154 fatalError ( " Either inputs or inputsEmbeds must be provided " )
11551155 }
11561156
11571157 let perLayerInputsProcessed : MLXArray
1158- if let perLayerInputs = perLayerInputs {
1158+ if let perLayerInputs {
11591159 perLayerInputsProcessed = perLayerInputs
1160- } else if let inputs = inputs {
1160+ } else if let inputs {
11611161 perLayerInputsProcessed = getPerLayerInputs ( inputs)
11621162 } else {
11631163 fatalError ( " Cannot generate per layer inputs without input ids " )
@@ -1213,7 +1213,7 @@ private class Gemma3Model: Module {
12131213 == " global_attention "
12141214
12151215 let localMask : MLXFast . ScaledDotProductAttentionMaskMode
1216- if let mask = mask {
1216+ if let mask {
12171217 localMask = mask
12181218 } else if isGlobal {
12191219 localMask = fullMask
@@ -1437,9 +1437,9 @@ private class Gemma3nMultimodalEmbedder: Module, UnaryLayer {
14371437 }
14381438
14391439 let embNorm : MLXArray
1440- if let inputsEmbeds = inputsEmbeds {
1440+ if let inputsEmbeds {
14411441 embNorm = softEmbeddingNorm ( inputsEmbeds)
1442- } else if let inputIds = inputIds {
1442+ } else if let inputIds {
14431443 let hardEmb = embedding ( inputIds - vocabOffset)
14441444 embNorm = hardEmbeddingNorm ( hardEmb)
14451445 } else {
@@ -1490,7 +1490,7 @@ private func gemma3nAttentionWithCacheUpdate(
14901490 // Update cache and get cached keys/values (matches Python's cache.update_and_fetch)
14911491 let ( cachedKeys, cachedValues) : ( MLXArray , MLXArray )
14921492
1493- if let cache = cache {
1493+ if let cache {
14941494 ( cachedKeys, cachedValues) = cache. update ( keys: keys, values: values)
14951495 } else {
14961496 ( cachedKeys, cachedValues) = ( keys, values)
@@ -1667,7 +1667,6 @@ private func maskedScatter(
16671667private func checkArrayShape( _ arr: MLXArray ) -> Bool {
16681668 let shape = arr. shape
16691669 guard shape. count == 4 else {
1670- print ( " 🔍 checkArrayShape: Array has \( shape. count) dimensions, not 4 " )
16711670 return false
16721671 }
16731672
@@ -1792,7 +1791,7 @@ public class Gemma3n: Module, VLMModel, KVCacheDimensionProvider {
17921791 }
17931792
17941793 // Process audio features
1795- if let inputFeatures = inputFeatures , let inputFeaturesMask = inputFeaturesMask {
1794+ if let inputFeatures, let inputFeaturesMask = inputFeaturesMask {
17961795 let ( audioFeatures, audioMask) = getAudioFeatures ( inputFeatures, .! inputFeaturesMask)
17971796 let audioPaddingIds = MLXArray ( [ config. vocabSize - 1 ] ) . expandedDimensions ( axis: 0 )
17981797 let audioPaddingEmbs = embedAudio. callAsFunction ( audioPaddingIds, inputsEmbeds: nil )
@@ -1862,7 +1861,7 @@ public class Gemma3n: Module, VLMModel, KVCacheDimensionProvider {
18621861 ) -> MLXArray {
18631862 let specialModalityMask : MLXArray
18641863
1865- if let inputIds = inputIds {
1864+ if let inputIds {
18661865 specialModalityMask = expandedDimensions ( inputIds .== tokenId, axis: - 1 )
18671866 } else {
18681867 // When inputIds is nil, create mask by comparing embeddings
@@ -1924,10 +1923,9 @@ public class Gemma3n: Module, VLMModel, KVCacheDimensionProvider {
19241923
19251924 // In class Gemma3n
19261925 public func sanitize( weights: [ String : MLXArray ] ) -> [ String : MLXArray ] {
1927- print ( " 🔍 Gemma3n.sanitize: Starting with \( weights. count) weights " )
19281926 var sanitizedWeights = [ String: MLXArray] ( )
19291927
1930- // This function's ONLY job is to remove the "model." prefix from keys.
1928+ // Remove the "model." prefix from keys.
19311929 for (k, v) in weights {
19321930 if k. hasPrefix ( " model. " ) {
19331931 let newKey = k. split ( separator: " . " ) . dropFirst ( ) . joined ( separator: " . " )
@@ -1937,13 +1935,11 @@ public class Gemma3n: Module, VLMModel, KVCacheDimensionProvider {
19371935 }
19381936 }
19391937
1940- print ( " 🔍 Gemma3n.sanitize: After prefix removal, have \( sanitizedWeights. count) weights " )
19411938 return sanitizedWeights
19421939 }
19431940
19441941 public static func fromPretrained( pathOrHfRepo: String ) throws -> Gemma3n {
19451942 let path = URL ( fileURLWithPath: pathOrHfRepo)
1946- print ( " 🔍 Gemma3n.fromPretrained: Loading from \( pathOrHfRepo) " )
19471943
19481944 let configPath = path. appendingPathComponent ( " config.json " )
19491945 let configData = try Data ( contentsOf: configPath)
@@ -1968,30 +1964,25 @@ public class Gemma3n: Module, VLMModel, KVCacheDimensionProvider {
19681964 let fileWeights = try loadArrays ( url: path. appendingPathComponent ( weightFile) )
19691965 weights. merge ( fileWeights) { _, new in new }
19701966 }
1971- print ( " 🔍 Gemma3n.fromPretrained: Total weights loaded: \( weights. count) " )
19721967
1973- // Step 1: Main sanitization (remove "model." prefix)
1968+ // Main sanitization (remove "model." prefix)
19741969 var sanitizedWeights = model. sanitize ( weights: weights)
19751970
1976- // Step 2: Vision model sanitization (transpose conv weights)
1971+ // Vision model sanitization (transpose conv weights)
19771972 sanitizedWeights = Gemma3nVisionModel . sanitizeWeights ( sanitizedWeights)
19781973
1979- // Step 3: Audio model sanitization (transpose conv weights) - THIS WAS MISSING
1974+ // Audio model sanitization (transpose conv weights)
19801975 sanitizedWeights = model. audioTower. sanitize ( weights: sanitizedWeights)
19811976
1982- // Step 4: Handle tied lm_head weights
1977+ // Handle tied lm_head weights
19831978 if sanitizedWeights [ " language_model.lm_head.weight " ] == nil {
19841979 if let embedWeight = sanitizedWeights [ " language_model.model.embed_tokens.weight " ] {
1985- print ( " 🔍 Tying lm_head weight. " )
19861980 sanitizedWeights [ " language_model.lm_head.weight " ] = embedWeight
19871981 }
19881982 }
19891983
1990- // Step 5: Load the weights
1991- print ( " 🔍 Attempting to load \( sanitizedWeights. count) final weights... " )
1984+ // Load the weights
19921985 try model. update ( parameters: ModuleParameters . unflattened ( sanitizedWeights) , verify: [ . all] )
1993- print ( " ✅ Model loaded successfully! " )
1994-
19951986 return model
19961987 }
19971988}
@@ -2211,7 +2202,7 @@ private class Gemma3nCumulativeGroupNorm: Module {
22112202 let expectedInputSuffix = featureDims + [ numChannels]
22122203 assert ( Array ( x. shape. suffix ( expectedInputSuffix. count) ) == expectedInputSuffix)
22132204
2214- if let mask = mask {
2205+ if let mask {
22152206 assert ( mask. shape == Array ( x. shape. prefix ( 2 ) ) )
22162207 assert ( mask. dtype == . bool)
22172208 }
@@ -2221,7 +2212,7 @@ private class Gemma3nCumulativeGroupNorm: Module {
22212212 let xCalc = x. asType ( calcDtype)
22222213
22232214 let maskCalc : MLXArray
2224- if let mask = mask {
2215+ if let mask {
22252216 let maskSuffixShape = Array ( repeating: 1 , count: expectedInputSuffix. count)
22262217 maskCalc = mask. reshaped ( Array ( mask. shape) + maskSuffixShape) . asType ( calcDtype)
22272218 } else {
@@ -2848,7 +2839,7 @@ private func rmsNorm2d(
28482839 let vMean = mean ( v, axis: 1 , keepDims: true )
28492840 var result = x * rsqrt( vMean + eps)
28502841
2851- if let weight = weight {
2842+ if let weight {
28522843 let weightReshaped = weight. reshaped ( [ 1 , - 1 , 1 , 1 ] )
28532844 result = result. asType ( dtype) * weightReshaped
28542845 }
@@ -3061,7 +3052,7 @@ private class UniversalInvertedResidual: Module, UnaryLayer {
30613052 )
30623053
30633054 // Layer Scale
3064- if let layerScaleInitValue = layerScaleInitValue {
3055+ if let layerScaleInitValue {
30653056 self . _layerScale. wrappedValue = LayerScale2d (
30663057 dim: outChannels, initValues: layerScaleInitValue)
30673058 } else {
@@ -3420,7 +3411,7 @@ private class MobileAttention: Module, UnaryLayer {
34203411 }
34213412
34223413 // Layer scaling
3423- if let layerScaleInitValue = layerScaleInitValue {
3414+ if let layerScaleInitValue {
34243415 self . _layerScale. wrappedValue = LayerScale2d (
34253416 dim: outChannels, initValues: layerScaleInitValue)
34263417 } else {
@@ -3843,7 +3834,6 @@ private class Gemma3nVisionModel: Module {
38433834 sanitizedWeights [ k] = v
38443835 }
38453836 } else {
3846- // THIS IS THE MISSING BLOCK
38473837 // Copy all other weights (biases, norm layers, etc.)
38483838 sanitizedWeights [ k] = v
38493839 }
@@ -3955,7 +3945,7 @@ private class Gemma3nAudioModel: Module {
39553945 for (k, v) in weights {
39563946 if k. contains ( " conv.weight " ) {
39573947 // The checkArrayShape function is not robust.
3958- // The Python reference doesn't use it. It's safer to just transpose.
3948+ // The Python implementation doesn't use it. It's safer to just transpose.
39593949 // Assuming NCHW -> NHWC for Conv2d
39603950 if v. ndim == 4 {
39613951 sanitizedWeights [ k] = v. transposed ( 0 , 2 , 3 , 1 )
@@ -3970,7 +3960,6 @@ private class Gemma3nAudioModel: Module {
39703960 sanitizedWeights [ k] = v
39713961 }
39723962 } else {
3973- // THIS IS THE MISSING BLOCK
39743963 sanitizedWeights [ k] = v
39753964 }
39763965 }
@@ -4175,8 +4164,8 @@ public struct Gemma3nProcessorConfiguration: Codable, Sendable {
41754164 public let doPanAndScan : Bool ?
41764165
41774166 // Token identifiers - use default values that match Python implementation
4178- public var imageTokenId : Int { 262145 } // From Python: image_token_id = 262145
4179- public var audioTokenId : Int { 262273 } // From Python: audio_token_id = 262273
4167+ public var imageTokenId : Int { 262145 }
4168+ public var audioTokenId : Int { 262273 }
41804169
41814170 public struct ImageSize : Codable , Sendable {
41824171 public let height : Int
0 commit comments