@@ -1896,53 +1896,12 @@ public class Gemma3n: Module, VLMModel, KVCacheDimensionProvider {
18961896 sanitizedWeights [ k] = v
18971897 }
18981898 }
1899+ sanitizedWeights = visionTower. sanitize ( weights: sanitizedWeights)
1900+ // TODO: The audio and language sanitization is not done in the Python implementation. Is this needed?
1901+ sanitizedWeights = audioTower. sanitize ( weights: sanitizedWeights)
1902+ sanitizedWeights = languageModel. sanitize ( weights: sanitizedWeights)
18991903 return sanitizedWeights
19001904 }
1901-
1902- public static func fromPretrained( pathOrHfRepo: String ) throws -> Gemma3n {
1903- let path = URL ( fileURLWithPath: pathOrHfRepo)
1904-
1905- let configPath = path. appendingPathComponent ( " config.json " )
1906- let configData = try Data ( contentsOf: configPath)
1907-
1908- let decoder = JSONDecoder ( )
1909- decoder. keyDecodingStrategy = . convertFromSnakeCase
1910- let modelConfig = try decoder. decode ( ModelConfig . self, from: configData)
1911-
1912- let model = Gemma3n ( modelConfig)
1913-
1914- // Load all weight files into a single dictionary
1915- let weightFiles = try FileManager . default. contentsOfDirectory ( atPath: path. path)
1916- . filter { $0. hasSuffix ( " .safetensors " ) }
1917- guard !weightFiles. isEmpty else {
1918- throw NSError (
1919- domain: " ModelLoading " , code: 1 ,
1920- userInfo: [ NSLocalizedDescriptionKey: " No safetensors found in \( path. path) " ] )
1921- }
1922-
1923- var weights = [ String: MLXArray] ( )
1924- for weightFile in weightFiles {
1925- let fileWeights = try loadArrays ( url: path. appendingPathComponent ( weightFile) )
1926- weights. merge ( fileWeights) { _, new in new }
1927- }
1928-
1929- var sanitizedWeights = model. sanitize ( weights: weights)
1930- sanitizedWeights = model. visionTower. sanitize ( weights: sanitizedWeights)
1931- // The audio and language sanitization is not done in the Python implementation
1932- // sanitizedWeights = model.audioTower.sanitize(weights: sanitizedWeights)
1933- // sanitizedWeights = model.languageModel.sanitize(weights: sanitizedWeights)
1934-
1935- // Handle tied lm_head weights
1936- if sanitizedWeights [ " language_model.lm_head.weight " ] == nil {
1937- if let embedWeight = sanitizedWeights [ " language_model.model.embed_tokens.weight " ] {
1938- sanitizedWeights [ " language_model.lm_head.weight " ] = embedWeight
1939- }
1940- }
1941-
1942- // Load the weights
1943- try model. update ( parameters: ModuleParameters . unflattened ( sanitizedWeights) , verify: [ . all] )
1944- return model
1945- }
19461905}
19471906
19481907// MARK: - Audio Model Components
0 commit comments