2626import warnings
2727import tempfile
2828import operator
29- import struct
3029import mmap
3130import json
32- import math
3331
3432from contextlib import closing , contextmanager
3533from enum import Enum
4745
4846import safetensors
4947import safetensors .numpy
50- from safetensors import deserialize
5148
5249from mindnlp .core import nn
5350from mindnlp .core .nn import Parameter
6562
6663MAGIC_NUMBER = 0x1950A86A20F9469CFC6C
6764PROTOCOL_VERSION = 1001
65+ MAX_HEADER_SIZE = 100 * 1000 * 1000
6866
6967
7068@contextmanager
@@ -1433,50 +1431,148 @@ def _save(obj, zip_file, pickle_module, pickle_protocol):
14331431}
14341432
14351433
1436- def legacy_safe_load_file (filename ):
1437- """
1438- This function safely loads a file containing state dictionary data and converts it into a dictionary of MindSpore Parameters.
1434+ _DTYPE_SIZE = {
1435+ "BOOL" : 1 ,
1436+ "U8" : 1 ,
1437+ "I8" : 1 ,
1438+ "F8_E5M2" : 1 ,
1439+ "F8_E4M3" : 1 ,
1440+ "I16" : 2 ,
1441+ "U16" : 2 ,
1442+ "I32" : 4 ,
1443+ "U32" : 4 ,
1444+ "I64" : 8 ,
1445+ "U64" : 8 ,
1446+ "F16" : 2 ,
1447+ "BF16" : 2 ,
1448+ "F32" : 4 ,
1449+ "F64" : 8 ,
1450+ }
14391451
1440- Args:
1441- filename (str): The path to the file containing the state dictionary data to be loaded.
1452+ class PySafeSlice :
1453+ def __init__ (self , info , bufferfile , base_ptr , buffermmap ):
1454+ self .info = info
1455+ self .bufferfile = bufferfile
1456+ self .buffermmap = buffermmap
1457+ self .base_ptr = base_ptr
1458+
1459+ self .start = [0 for _ in self .shape ]
1460+ self .stop = list (self .shape )
1461+ self .step = [1 for _ in self .shape ]
1462+
1463+ @property
1464+ def ndim (self ):
1465+ return len (self .shape )
1466+
1467+ def get (self , * args , ** kwargs ):
1468+ nbytes = int (np .prod (self .shape )) * np .dtype (self .dtype ).itemsize
1469+ offset = self .start_offset
1470+ tensor = np .frombuffer (self .buffermmap , dtype = self .dtype , offset = offset ,
1471+ count = nbytes // np .dtype (self .dtype ).itemsize )
1472+ tensor = tensor .reshape (self .shape )
1473+ if not SUPPORT_BF16 and self .info ["dtype" ] == 'BF16' :
1474+ tensor = tensor .astype (np .float16 )
1475+ tensor = Tensor .from_numpy (tensor )
1476+ return tensor
1477+
1478+ @property
1479+ def start_offset (self ):
1480+ return self .base_ptr + self .info ["data_offsets" ][0 ]
1481+
1482+ def get_shape (self ):
1483+ return self .shape
1484+
1485+ @property
1486+ def shape (self ):
1487+ return self .info ["shape" ]
1488+
1489+ @property
1490+ def dtype (self ):
1491+ return _NP_TYPES [self .info ["dtype" ]]
1492+
1493+ @property
1494+ def nelements (self ):
1495+ return np .prod (self .info ["shape" ])
1496+
1497+ @property
1498+ def bits (self ):
1499+ return _DTYPE_SIZE [self .info ["dtype" ]]
1500+
1501+ @property
1502+ def nbytes (self ):
1503+ return self .nelements * self .bits
1504+
1505+ def getSize (fileobject ):
1506+ fileobject .seek (0 , 2 ) # move the cursor to the end of the file
1507+ size = fileobject .tell ()
1508+ fileobject .seek (0 ) # move the cursor to the start of the file
1509+ return size
1510+
1511+
1512+ def metadata_validate (metadata ):
1513+ start = 0
1514+ for key , info in metadata .items ():
1515+ s , e = info ["data_offsets" ]
1516+ if s != start or e < s :
1517+ raise ValueError (f"SafeTensorError::InvalidOffset({ key } )" )
1518+ start = e
1519+ nelements = np .prod (info ["shape" ])
1520+ nbytes = nelements * _DTYPE_SIZE [info ["dtype" ]]
1521+ if (e - s ) != nbytes :
1522+ raise ValueError ("SafeTensorError::TensorInvalidInfo" )
1523+ return start
1524+
1525+ def read_metadata (buffer ):
1526+ buffer_len = getSize (buffer )
1527+ if buffer_len < 8 :
1528+ raise ValueError ("SafeTensorError::HeaderTooSmall" )
1529+
1530+ n = np .frombuffer (buffer .read (8 ), dtype = np .uint64 ).item ()
1531+
1532+ if n > MAX_HEADER_SIZE :
1533+ raise ValueError ("SafeTensorError::HeaderTooLarge" )
1534+
1535+ stop = n + 8
1536+ if stop > buffer_len :
1537+ raise ValueError ("SafeTensorError::InvalidHeaderLength" )
1538+
1539+ tensors = json .loads (buffer .read (n ), object_pairs_hook = OrderedDict )
1540+
1541+ metadata = tensors .pop ("__metadata__" , None )
1542+ buffer_end = metadata_validate (tensors )
1543+
1544+ if buffer_end + 8 + n != buffer_len :
1545+ raise ValueError ("SafeTensorError::MetadataIncompleteBuffer" )
1546+
1547+ return stop , tensors , metadata
1548+
1549+
1550+ class fast_safe_open :
1551+ def __init__ (self , filename , framework = None , device = "cpu" ):
1552+ self .filename = filename
1553+ self .framework = framework
1554+ self .file = open (self .filename , "rb" )
1555+ self .file_mmap = mmap .mmap (self .file .fileno (), 0 , access = mmap .ACCESS_COPY )
1556+ self .base , self .tensors_decs , self .__metadata__ = read_metadata (self .file )
1557+ self .tensors = OrderedDict ()
1558+ for key , info in self .tensors_decs .items ():
1559+ self .tensors [key ] = PySafeSlice (info , self .file , self .base , self .file_mmap )
1560+ self .tensors [key ].key = key
14421561
1443- Returns :
1444- dict: A dictionary where keys are parameter names and values are MindSpore Parameters.
1562+ def __enter__ ( self ) :
1563+ return self
14451564
1446- Raises:
1447- FileNotFoundError: If the specified file 'filename' does not exist.
1448- ValueError: If the data in the file is not in the correct format to create MindSpore Parameters.
1449- """
1450- with open (filename , "rb" ) as f :
1451- data = f .read ()
1565+ def __exit__ (self , * args ):
1566+ self .file .close ()
14521567
1453- safeview = deserialize (data )
1568+ def metadata (self ):
1569+ return self .__metadata__
14541570
1455- result = {}
1456- try :
1457- for k , v in safeview :
1458- dtype = _MS_TYPES [v ["dtype" ]]
1459- if (not SUPPORT_BF16 and dtype != mindspore .bfloat16 ) or SUPPORT_BF16 :
1460- arr = Tensor .convert_bytes_to_tensor (
1461- bytes (v ["data" ]), tuple (v ["shape" ]), dtype
1462- )
1463- result [k ] = Tensor (arr )
1464- else :
1465- raise TypeError (
1466- "Do not support bfloat16 on current device, use numpy as convert buffer to boost load."
1467- )
1468- return result
1571+ def keys (self ):
1572+ return list (self .tensors .keys ())
14691573
1470- except Exception as e :
1471- for k , v in safeview :
1472- dtype = _NP_TYPES [v ["dtype" ]]
1473- arr = np .frombuffer (v ["data" ], dtype = dtype ).reshape (v ["shape" ])
1474-
1475- if (not SUPPORT_BF16 and dtype != bfloat16 ) or SUPPORT_BF16 :
1476- result [k ] = Tensor .from_numpy (arr )
1477- else :
1478- result [k ] = Tensor .from_numpy (arr .astype (np .float16 ))
1479- return result
1574+ def get_tensor (self , name ):
1575+ return self .tensors [name ].get ()
14801576
14811577
14821578def safe_load_file (filename ):
@@ -1494,39 +1590,10 @@ def safe_load_file(filename):
14941590 ValueError: If the data in the file is not in the correct format to create MindSpore Parameters.
14951591 """
14961592
1497- def convert (info : dict [str , Any ]):
1498- numpy_dtype = _NP_TYPES [info ["dtype" ]]
1499- ms_dtype = _MS_TYPES [info ["dtype" ]]
1500- shape : list [int ] = info ["shape" ]
1501- begin , end = info ["data_offsets" ]
1502- assert 0 <= begin <= end <= len (byte_buf )
1503- assert end - begin == math .prod (shape ) * np .dtype (numpy_dtype ).itemsize
1504- buf = byte_buf [begin :end ]
1505-
1506- array = np .frombuffer (buf , dtype = numpy_dtype ).reshape (shape )
1507-
1508- if array .dtype == bfloat16 and not SUPPORT_BF16 :
1509- logger .warning_once (
1510- "MindSpore do not support bfloat16 dtype, we will automaticlly convert to float16"
1511- )
1512- array = array .astype (np .float16 )
1513- array = array .astype (array .dtype )
1514- out = Tensor .from_numpy (array )
1515- return out
1516-
1517- with open (filename , "rb" ) as fp :
1518- (header_size ,) = struct .unpack ("<Q" , fp .read (8 ))
1519- header : dict [str , dict [str , Any ]] = json .loads (fp .read (header_size ))
1520- # Use mmap for the actual data to avoid race conditions with the file offset.
1521- mapped = memoryview (mmap .mmap (fp .fileno (), 0 , access = mmap .ACCESS_READ ))
1522- byte_buf = mapped [8 + header_size :]
1523-
1524- result = {
1525- name : convert (info )
1526- for (name , info ) in header .items ()
1527- if name != "__metadata__"
1528- }
1529-
1593+ result = {}
1594+ with fast_safe_open (filename , framework = "np" ) as f :
1595+ for k in f .keys ():
1596+ result [k ] = f .get_tensor (k )
15301597 return result
15311598
15321599
0 commit comments