Skip to content

Commit 49197ca

Browse files
authored
Merge pull request #122 from jayzhenghan/master
add resumable download file
2 parents abcfe6d + 76bc9aa commit 49197ca

File tree

4 files changed

+311
-1
lines changed

4 files changed

+311
-1
lines changed

qcloud_cos/cos_client.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
from .cos_exception import CosServiceError
2626
from .version import __version__
2727
from .select_event_stream import EventStream
28+
from .resumable_downloader import ResumableDownLoader
2829
logger = logging.getLogger(__name__)
2930

3031

@@ -185,7 +186,7 @@ def __init__(self, conf, retry=1, session=None):
185186
else:
186187
self._session = session
187188

188-
def get_conf():
189+
def get_conf(self):
    """Return the CosConfig object this client was constructed with."""
    return self._conf
191192

@@ -2944,6 +2945,30 @@ def _check_all_upload_parts(self, bucket, key, uploadid, local_path, parts_num,
29442945
already_exist_parts[part_num] = part['ETag']
29452946
return True
29462947

2948+
def download_file(self, Bucket, Key, DestFilePath, PartSize=20, MAZThread=5, EnableCRC=False, **Kwargs):
    """Download a COS object to a local file.

    Objects no larger than 20MB are fetched with a single GET request;
    larger objects are downloaded in parts with resume support via
    ResumableDownLoader.

    :param Bucket(string): bucket name.
    :param Key(string): key (path) of the COS object.
    :param DestFilePath(string): destination path of the downloaded file.
    :param PartSize(int): part size in MB for the multipart download path.
    :param MAZThread(int): maximum number of concurrent download threads.
        NOTE(review): the name looks like a typo for ``MAXThread`` (the
        docstring and ``upload_file`` use that spelling), but it is kept
        unchanged for backward compatibility with existing callers.
    :param EnableCRC(bool): verify the downloaded file against the source
        object with CRC64 when True.
    :param Kwargs(dict): extra request headers.
    """
    logger.debug("Start to download file, bucket: {0}, key: {1}, dest_filename: {2}, part_size: {3}MB, "
                 "max_thread: {4}".format(Bucket, Key, DestFilePath, PartSize, MAZThread))

    object_info = self.head_object(Bucket, Key)
    # BUGFIX: Content-Length is returned as a header *string*; comparing it
    # to an int never takes the simple-download branch on Python 2 and
    # raises TypeError on Python 3.  Cast to int first (ResumableDownLoader
    # already does the same with int(object_info['Content-Length'])).
    file_size = int(object_info['Content-Length'])
    if file_size <= 1024 * 1024 * 20:
        # Small object: plain GET streamed straight to the destination path.
        response = self.get_object(Bucket, Key, **Kwargs)
        response['Body'].get_stream_to_file(DestFilePath)
        return

    # Large object: multi-threaded, resumable part download.
    downloader = ResumableDownLoader(self, Bucket, Key, DestFilePath, object_info, PartSize, MAZThread, EnableCRC, **Kwargs)
    downloader.start()
2971+
29472972
def upload_file(self, Bucket, Key, LocalFilePath, PartSize=1, MAXThread=5, EnableMD5=False, **kwargs):
29482973
"""小于等于20MB的文件简单上传,大于20MB的文件使用分块上传
29492974

qcloud_cos/resumable_downloader.py

Lines changed: 206 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,206 @@
1+
# -*- coding: utf-8 -*-
2+
3+
import json
4+
import os
5+
import sys
6+
import threading
7+
import logging
8+
import uuid
9+
import hashlib
10+
import crcmod
11+
from .cos_comm import *
12+
from .streambody import StreamBody
13+
from .cos_threadpool import SimpleThreadPool
14+
logger = logging.getLogger(__name__)
15+
16+
class ResumableDownLoader(object):
    """Multi-threaded, resumable downloader for a single COS object.

    Progress is persisted as a JSON record file under
    ``~/.cos_download_tmp_file`` so an interrupted download can resume:
    parts already listed in the record are skipped on the next run.  Data
    is written into a temporary file next to the destination and renamed
    into place once every part has completed.

    NOTE(review): the ``0L`` / ``0x...L`` long literals in __check_crc are
    Python 2 only syntax.
    """
    def __init__(self, cos_client, bucket, key, dest_filename, object_info, part_size=20, max_thread=5, enable_crc=False, **kwargs):
        # cos_client: client used to issue the ranged get_object requests.
        # object_info: head_object headers (Content-Length, ETag,
        # Last-Modified, ...) used for part sizing and record validation.
        self.__cos_client = cos_client
        self.__bucket = bucket
        self.__key = key
        self.__dest_file_path = os.path.abspath(dest_filename)
        self.__object_info = object_info
        self.__max_thread = max_thread
        self.__enable_crc = enable_crc
        self.__headers = kwargs  # extra request headers forwarded to every part GET

        self.__max_part_count = 100  # depends on whether the server side limits concurrency
        self.__min_part_size = 1024 * 1024  # 1M
        self.__part_size = self.__determine_part_size_internal(int(object_info['Content-Length']), part_size)
        self.__finished_parts = []
        self.__lock = threading.Lock()
        self.__record = None  # current resume context (mirrors the record file)
        self.__dump_record_dir = os.path.join(os.path.expanduser('~'), '.cos_download_tmp_file')

        record_filename = self.__get_record_filename(bucket, key, self.__dest_file_path)
        self.__record_filepath = os.path.join(self.__dump_record_dir, record_filename)
        self.__tmp_file = None

        if not os.path.exists(self.__dump_record_dir):
            os.makedirs(self.__dump_record_dir)

        logger.debug('resumale downloader init finish, bucket: {0}, key: {1}'.format(bucket, key))

    def start(self):
        """Run the download: resume from the record, fetch missing parts
        concurrently, rename the temp file into place, then optionally
        verify CRC64 and delete the record file.

        :raises CosClientError: if any part still fails after retries, or
            the CRC check fails.
        """
        logger.debug('start resumable downloade, bucket: {0}, key: {1}'.format(self.__bucket, self.__key))
        self.__load_record()  # restore the resume context from the record file

        assert self.__tmp_file
        # Ensure the temp file exists so worker threads can open it 'rb+'.
        open(self.__tmp_file, 'a').close()

        parts_need_to_download = self.__get_parts_need_to_download()
        logger.debug('parts_need_to_download: {0}'.format(parts_need_to_download))
        pool = SimpleThreadPool(self.__max_thread)
        for part in parts_need_to_download:
            # HTTP Range header is inclusive on both ends.
            part_range = "bytes=" + str(part.start) + "-" + str(part.start + part.length - 1)
            headers = dict.copy(self.__headers)
            headers["Range"] = part_range
            pool.add_task(self.__download_part, part, headers)

        pool.wait_completion()
        result = pool.get_result()
        if not result['success_all']:
            raise CosClientError('some download_part fail after max_retry, please downloade_file again')

        if os.path.exists(self.__dest_file_path):
            os.remove(self.__dest_file_path)
        os.rename(self.__tmp_file, self.__dest_file_path)

        if self.__enable_crc:
            self.__check_crc()

        self.__del_record()
        logger.debug('download success, bucket: {0}, key: {1}'.format(self.__bucket, self.__key))

    def __get_record_filename(self, bucket, key, dest_file_path):
        """Build a unique record file name from bucket, key and destination.

        NOTE(review): hashlib.md5 of a str works on Python 2 only; Python 3
        would require encoding key/dest_file_path to bytes first.
        """
        dest_file_path_md5 = hashlib.md5(dest_file_path).hexdigest()
        key_md5 = hashlib.md5(key).hexdigest()
        return '{0}_{1}.{2}'.format(bucket, key_md5, dest_file_path_md5)

    def __determine_part_size_internal(self, file_size, part_size):
        """Return the effective part size in bytes: at least 1MB, doubled
        until at most __max_part_count parts cover the whole object."""
        real_part_size = part_size * 1024 * 1024  # MB
        if real_part_size < self.__min_part_size:
            real_part_size = self.__min_part_size

        while real_part_size * self.__max_part_count < file_size:
            real_part_size = real_part_size * 2
        logger.debug('finish to determine part size, file_size: {0}, part_size: {1}'.format(file_size, real_part_size))
        return real_part_size

    def __splite_to_parts(self):
        """Split the object into PartInfo entries; the last part carries the
        remainder.  NOTE(review): `/` relies on Python 2 integer division."""
        parts = []
        file_size = int(self.__object_info['Content-Length'])
        num_parts = (file_size + self.__part_size - 1) / self.__part_size
        for i in range(num_parts):
            start = i * self.__part_size
            if i == num_parts - 1:
                length = file_size - start
            else:
                length = self.__part_size

            parts.append(PartInfo(i + 1, start, length))
        return parts

    def __get_parts_need_to_download(self):
        """Return the parts not yet marked finished (set difference relies on
        PartInfo's __eq__/__hash__)."""
        all_set = set(self.__splite_to_parts())
        logger.debug('all_set: {0}'.format(len(all_set)))
        finished_set = set(self.__finished_parts)
        logger.debug('finished_set: {0}'.format(len(finished_set)))
        return list(all_set - finished_set)

    def __download_part(self, part, headers):
        """Worker task: ranged GET for one part, written at part.start in the
        temp file, then recorded as finished."""
        with open(self.__tmp_file, 'rb+') as f:
            f.seek(part.start, 0)
            # Pulled out only for the debug log below.
            range = None
            traffic_limit = None
            if 'Range' in headers:
                range = headers['Range']

            if 'TrafficLimit' in headers:
                traffic_limit = headers['TrafficLimit']
            logger.debug("part_id: {0}, part_range: {1}, traffic_limit:{2}".format(part.part_id, range, traffic_limit))
            result = self.__cos_client.get_object(Bucket=self.__bucket, Key=self.__key, **headers)
            result["Body"].pget_stream_to_file(f, part.start, part.length)

        self.__finish_part(part)

    def __finish_part(self, part):
        """Mark a part finished and persist the updated record (lock guards
        the shared list/record against concurrent worker threads)."""
        logger.debug('download part finished,bucket: {0}, key: {1}, part_id: {2}'.
                     format(self.__bucket, self.__key, part.part_id))
        with self.__lock:
            self.__finished_parts.append(part)
            self.__record['parts'].append({'part_id': part.part_id,
                                           'start': part.start,
                                           'length': part.length})
            self.__dump_record(self.__record)

    def __dump_record(self, record):
        """Serialize the resume context to the record file as JSON."""
        with open(self.__record_filepath, 'w') as f:
            json.dump(record, f)
            logger.debug('dump record to {0}, bucket: {1}, key: {2}'.
                         format(self.__record_filepath, self.__bucket, self.__key))

    def __load_record(self):
        """Load and validate the resume context; start fresh when there is no
        usable record (missing, stale, or its temp file disappeared)."""
        record = None

        if os.path.exists(self.__record_filepath):
            with open(self.__record_filepath, 'r') as f:
                record = json.load(f)

            ret = self.__check_record(record)
            # Drop the record if it no longer matches the head_object metadata.
            if ret == False:
                self.__del_record()
                record = None
            else:
                self.__part_size = record['part_size']
                self.__tmp_file = record['tmp_filename']
                if not os.path.exists(self.__tmp_file):
                    # Temp file is gone: the partial data is lost, restart.
                    record = None
                    self.__tmp_file = None
                    self.__del_record()
                else:
                    self.__finished_parts = list(PartInfo(p['part_id'], p['start'], p['length']) for p in record['parts'])
                    logger.debug('load record: finished parts nums: {0}'.format(len(self.__finished_parts)))
                    self.__record = record

        if not record:
            # Fresh download: new temp file name and a brand-new record.
            self.__tmp_file = "{file_name}_{uuid}".format(file_name=self.__dest_file_path, uuid=uuid.uuid4().hex)
            record = {'bucket': self.__bucket, 'key': self.__key, 'tmp_filename': self.__tmp_file,
                      'mtime': self.__object_info['Last-Modified'], 'etag': self.__object_info['ETag'],
                      'file_size': self.__object_info['Content-Length'], 'part_size': self.__part_size, 'parts': []}
            self.__record = record
            self.__dump_record(record)

    def __check_record(self, record):
        """Return True when the record still matches the remote object
        (ETag, Last-Modified and size all unchanged)."""
        return record['etag'] == self.__object_info['ETag'] and\
            record['mtime'] == self.__object_info['Last-Modified'] and\
            record['file_size'] == self.__object_info['Content-Length']

    def __del_record(self):
        """Remove the record file (called on completion or invalidation)."""
        os.remove(self.__record_filepath)
        logger.debug('ResumableDownLoader delete record_file, path: {0}'.format(self.__record_filepath))

    def __check_crc(self):
        """Compare the local CRC64-ECMA of the downloaded file against the
        object's x-cos-hash-crc64ecma header; raise CosClientError on
        mismatch."""
        logger.debug('start to check crc')
        c64 = crcmod.mkCrcFun(0x142F0E1EBA9EA3693L, initCrc=0L, xorOut=0xffffffffffffffffL, rev=True)
        with open(self.__dest_file_path, 'rb') as f:
            local_crc64 = str(c64(f.read()))
        object_crc64 = self.__object_info['x-cos-hash-crc64ecma']
        if local_crc64 is not None and object_crc64 is not None and local_crc64 != object_crc64:
            raise CosClientError('crc of client: {0} is mismatch with cos: {1}'.format(local_crc64, object_crc64))
192+
193+
class PartInfo(object):
    """Value object describing one download part.

    Carries the 1-based part id plus the byte offset and length of the
    part within the object.  Equality and hashing are defined over all
    three fields so parts can live in sets (used to diff finished parts
    against the full part list).
    """

    def __init__(self, part_id, start, length):
        self.part_id = part_id
        self.start = start
        self.length = length

    def __identity(self):
        # Single source of truth for equality and hashing.
        return (self.part_id, self.start, self.length)

    def __eq__(self, other):
        return self.__identity() == other.__identity()

    def __hash__(self):
        return hash(self.__identity())

qcloud_cos/streambody.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,3 +53,34 @@ def get_stream_to_file(self, file_name, auto_decompress=False):
5353
if os.path.exists(file_name):
5454
os.remove(file_name)
5555
os.rename(tmp_file_name, file_name)
56+
57+
def pget_stream_to_file(self, fdst, offset, expected_len, auto_decompress=False):
    """Write the response body into *fdst* starting at byte *offset*.

    Used by the resumable downloader to place one ranged part at its
    position inside a shared temp file.

    :param fdst: writable, seekable binary file object.
    :param offset(int): byte position in *fdst* where writing starts.
    :param expected_len(int): expected body length; mismatches raise
        IOError unless the response is chunked or was auto-decompressed.
    :param auto_decompress(bool): when True and the response carries
        Content-Encoding, write the decompressed stream instead of the
        raw bytes.
    :raises IOError: missing length information, or incomplete body.
    """
    headers = self._rt.headers
    is_chunked = 'Transfer-Encoding' in headers and headers['Transfer-Encoding'] == "chunked"
    # Without chunked transfer we need Content-Length to know when we are done.
    if not is_chunked and 'Content-Length' not in headers:
        raise IOError("download failed without Content-Length header or Transfer-Encoding header")
    has_encoding = 'Content-Encoding' in headers

    bytes_written = 0
    fdst.seek(offset, 0)

    if has_encoding and not auto_decompress:
        # Copy the raw (still compressed) stream untouched.
        while True:
            piece = self._rt.raw.read(1024)
            if not piece:
                break
            bytes_written += len(piece)
            fdst.write(piece)
    else:
        # Let requests decode the body (handles Content-Encoding for us).
        for piece in self._rt.iter_content(chunk_size=1024):
            if piece:
                bytes_written += len(piece)
                fdst.write(piece)

    # Length is unknowable for chunked or decompressed bodies; otherwise a
    # short read means the download is incomplete.
    if not is_chunked and not (has_encoding and auto_decompress) and bytes_written != expected_len:
        raise IOError("download failed with incomplete file")
86+

ut/test.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1164,6 +1164,53 @@ def _test_get_object_sensitive_content_recognition():
11641164
print(response)
11651165
assert response
11661166

1167+
def test_download_file():
    """Exercise the resumable download interface end to end."""
    # Plain download.
    client.download_file(copy_test_bucket, test_object, 'test_download_file.local')
    if os.path.exists('test_download_file.local'):
        os.remove('test_download_file.local')

    # Traffic-limited download.
    client.download_file(copy_test_bucket, test_object, 'test_download_traffic_limit.local', TrafficLimit='819200')
    if os.path.exists('test_download_traffic_limit.local'):
        os.remove('test_download_traffic_limit.local')

    # CRC64 verification switch.
    client.download_file(copy_test_bucket, test_object, 'test_download_crc.local', EnableCRC=True)
    if os.path.exists('test_download_crc.local'):
        os.remove('test_download_crc.local')

    # Compare the source file's md5 with the downloaded copy's md5.
    file_size = 25  # MB, above the 20MB threshold so the multipart path runs
    file_id = str(random.randint(0, 1000)) + str(random.randint(0, 1000))
    file_name = "tmp" + file_id + "_" + str(file_size) + "MB"
    gen_file(file_name, file_size)

    source_file_md5 = None
    dest_file_md5 = None
    with open(file_name, 'rb') as f:
        source_file_md5 = get_raw_md5(f.read())

    client.put_object_from_local_file(
        Bucket=copy_test_bucket,
        LocalFilePath=file_name,
        Key=file_name
    )

    client.download_file(copy_test_bucket, file_name, 'test_download_md5.local')
    if os.path.exists('test_download_md5.local'):
        with open('test_download_md5.local', 'rb') as f:
            dest_file_md5 = get_raw_md5(f.read())
    assert source_file_md5 and dest_file_md5 and source_file_md5 == dest_file_md5

    # Release remote and local resources.
    client.delete_object(
        Bucket=copy_test_bucket,
        Key=file_name
    )
    if os.path.exists(file_name):
        os.remove(file_name)
    # BUGFIX: the downloaded copy used for the md5 comparison was leaked
    # before; remove it like every other file this test creates.
    if os.path.exists('test_download_md5.local'):
        os.remove('test_download_md5.local')
11671214

11681215
if __name__ == "__main__":
11691216
setUp()
@@ -1190,6 +1237,7 @@ def _test_get_object_sensitive_content_recognition():
11901237
test_put_get_delete_bucket_domain()
11911238
test_select_object()
11921239
_test_get_object_sensitive_content_recognition()
1240+
test_download_file()
11931241
"""
11941242

11951243
tearDown()

0 commit comments

Comments
 (0)