diff --git a/mlx/io/safetensors.cpp b/mlx/io/safetensors.cpp index 0243a2eb40..a8d8627bac 100644 --- a/mlx/io/safetensors.cpp +++ b/mlx/io/safetensors.cpp @@ -28,6 +28,7 @@ using json = nlohmann::json; #define ST_U32 "U32" #define ST_U64 "U64" #define ST_F8_E4M3 "F8_E4M3" +#define ST_F8_E8M0 "F8_E8M0" // Note: Complex numbers aren't in the spec yet so this could change - // https://github.com/huggingface/safetensors/issues/389 @@ -97,6 +98,8 @@ Dtype dtype_from_safetensor_str(std::string_view str) { return complex64; } else if (str == ST_F8_E4M3) { return uint8; + } else if (str == ST_F8_E8M0) { + return uint8; } else { std::ostringstream msg; msg << "[safetensor] unsupported dtype" << str;