Skip to content

Commit 5573aac

Browse files
committed
Add session reset
1 parent 1676a9e commit 5573aac

File tree

1 file changed

+57
-41
lines changed

1 file changed

+57
-41
lines changed

src/superannotate/lib/infrastructure/stream_data_handler.py

Lines changed: 57 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,10 @@
22
import json
33
import logging
44
import os
5+
import threading
6+
import time
57
import typing
8+
from functools import lru_cache
69
from typing import Callable
710

811
import aiohttp
@@ -24,6 +27,7 @@
2427
class StreamedAnnotations:
2528
DELIMITER = "\\n;)\\n"
2629
DELIMITER_LEN = len(DELIMITER)
30+
VERIFY_SSL = False
2731

2832
def __init__(
2933
self,
@@ -50,15 +54,16 @@ def get_json(self, data: bytes):
5054
async def fetch(
5155
self,
5256
method: str,
53-
session: AIOHttpSession,
5457
url: str,
5558
data: dict = None,
5659
params: dict = None,
5760
):
5861
kwargs = {"params": params, "json": data}
5962
if data:
6063
kwargs["json"].update(data)
61-
response = await session.request(method, url, **kwargs, timeout=TIMEOUT) # noqa
64+
response = await self.get_session().request(
65+
method, url, **kwargs, timeout=TIMEOUT
66+
) # noqa
6267
if not response.ok:
6368
logger.error(response.text)
6469
buffer = ""
@@ -103,33 +108,47 @@ async def fetch(
103108
)
104109
break
105110

111+
@lru_cache(maxsize=32)
112+
def _get_session(self, thread_id, ttl=None): # noqa
113+
del ttl
114+
del thread_id
115+
return AIOHttpSession(
116+
headers=self._headers,
117+
timeout=TIMEOUT,
118+
connector=aiohttp.TCPConnector(
119+
ssl=self.VERIFY_SSL, keepalive_timeout=2**32
120+
),
121+
raise_for_status=True,
122+
)
123+
124+
def get_session(self):
125+
return self._get_session(
126+
thread_id=threading.get_ident(), ttl=round(time.time() / 360)
127+
)
128+
129+
def rest_session(self):
130+
self._get_session.cache_clear()
131+
106132
async def list_annotations(
107133
self,
108134
method: str,
109135
url: str,
110136
data: typing.List[int] = None,
111137
params: dict = None,
112-
verify_ssl=False,
113138
):
114139
params = copy.copy(params)
115140
params["limit"] = len(data)
116141
annotations = []
117-
async with AIOHttpSession(
118-
headers=self._headers,
119-
timeout=TIMEOUT,
120-
connector=aiohttp.TCPConnector(ssl=verify_ssl, keepalive_timeout=2**32),
121-
raise_for_status=True,
122-
) as session:
123-
async for annotation in self.fetch(
124-
method,
125-
session,
126-
url,
127-
self._process_data(data),
128-
params=copy.copy(params),
129-
):
130-
annotations.append(
131-
self._callback(annotation) if self._callback else annotation
132-
)
142+
143+
async for annotation in self.fetch(
144+
method,
145+
url,
146+
self._process_data(data),
147+
params=copy.copy(params),
148+
):
149+
annotations.append(
150+
self._callback(annotation) if self._callback else annotation
151+
)
133152

134153
return annotations
135154

@@ -143,28 +162,22 @@ async def download_annotations(
143162
):
144163
params = copy.copy(params)
145164
params["limit"] = len(data)
146-
async with AIOHttpSession(
147-
headers=self._headers,
148-
timeout=TIMEOUT,
149-
connector=aiohttp.TCPConnector(ssl=False, keepalive_timeout=2**32),
150-
raise_for_status=True,
151-
) as session:
152-
async for annotation in self.fetch(
153-
method,
154-
session,
155-
url,
156-
self._process_data(data),
157-
params=params,
158-
):
159-
self._annotations.append(
160-
self._callback(annotation) if self._callback else annotation
161-
)
162-
self._store_annotation(
163-
download_path,
164-
annotation,
165-
self._callback,
166-
)
167-
self._items_downloaded += 1
165+
166+
async for annotation in self.fetch(
167+
method,
168+
url,
169+
self._process_data(data),
170+
params=params,
171+
):
172+
self._annotations.append(
173+
self._callback(annotation) if self._callback else annotation
174+
)
175+
self._store_annotation(
176+
download_path,
177+
annotation,
178+
self._callback,
179+
)
180+
self._items_downloaded += 1
168181

169182
@staticmethod
170183
def _store_annotation(path, annotation: dict, callback: Callable = None):
@@ -177,3 +190,6 @@ def _process_data(self, data):
177190
if data and self._map_function:
178191
return self._map_function(data)
179192
return data
193+
194+
def __del__(self):
195+
self._get_session.cache_clear()

0 commit comments

Comments
 (0)