@@ -139,8 +139,7 @@ void* allocate_device_mem(const size_t num_bytes, L0Device& device) {
139139 return mem;
140140}
141141
142- L0DataFetcher::L0DataFetcher (const L0Driver& driver, ze_device_handle_t device)
143- : device_(device), driver_(driver) {
142+ L0Device::L0DataFetcher::L0DataFetcher (L0Device& device) : my_device_(device) {
144143 ze_command_queue_desc_t command_queue_fetch_desc = {
145144 ZE_STRUCTURE_TYPE_COMMAND_QUEUE_DESC,
146145 nullptr ,
@@ -149,77 +148,87 @@ L0DataFetcher::L0DataFetcher(const L0Driver& driver, ze_device_handle_t device)
149148 0 ,
150149 ZE_COMMAND_QUEUE_MODE_ASYNCHRONOUS,
151150 ZE_COMMAND_QUEUE_PRIORITY_NORMAL};
152- L0_SAFE_CALL (zeCommandQueueCreate (
153- driver.ctx (), device_, &command_queue_fetch_desc, &queue_handle_));
154- current_cl_bytes = {{}, 0 };
155- L0_SAFE_CALL (
156- zeCommandListCreate (driver.ctx (), device_, &cl_desc, ¤t_cl_bytes.first ));
157- }
158-
159- L0DataFetcher::~L0DataFetcher () {
151+ L0_SAFE_CALL (zeCommandQueueCreate (my_device_.driver_ .ctx (),
152+ my_device_.device_ ,
153+ &command_queue_fetch_desc,
154+ &queue_handle_));
155+ cur_cl_bytes_ = {{}, 0 };
156+ L0_SAFE_CALL (zeCommandListCreate (my_device_.driver_ .ctx (),
157+ my_device_.device_ ,
158+ &cl_desc_,
159+ &cur_cl_bytes_.cl_handle_ ));
160+ }
161+
162+ L0Device::L0DataFetcher::~L0DataFetcher () {
160163 zeCommandQueueDestroy (queue_handle_);
161- zeCommandListDestroy (current_cl_bytes. first );
162- for (auto & dead_handle : graveyard ) {
164+ zeCommandListDestroy (cur_cl_bytes_. cl_handle_ );
165+ for (auto & dead_handle : graveyard_ ) {
163166 zeCommandListDestroy (dead_handle);
164167 }
165- for (auto & cl_handle : recycled ) {
168+ for (auto & cl_handle : recycled_ ) {
166169 zeCommandListDestroy (cl_handle);
167170 }
168171}
169172
170- void L0DataFetcher::recycleGraveyard () {
171- while (recycled .size () < GRAVEYARD_LIMIT && graveyard .size ()) {
172- recycled .push_back (graveyard .front ());
173- graveyard .pop_front ();
174- L0_SAFE_CALL (zeCommandListReset (recycled .back ()));
173+ void L0Device:: L0DataFetcher::recycleGraveyard () {
174+ while (recycled_ .size () < GRAVEYARD_LIMIT && graveyard_ .size ()) {
175+ recycled_ .push_back (graveyard_ .front ());
176+ graveyard_ .pop_front ();
177+ L0_SAFE_CALL (zeCommandListReset (recycled_ .back ()));
175178 }
176- for (auto & dead_handle : graveyard) {
177- L0_SAFE_CALL (zeCommandListDestroy (recycled.back ()));
179+ for (auto & dead_handle : graveyard_) {
180+ L0_SAFE_CALL (zeCommandListDestroy (dead_handle));
181+ }
182+ graveyard_.clear ();
183+ }
184+
185+ void L0Device::L0DataFetcher::setCLRecycledOrNew () {
186+ cur_cl_bytes_ = {{}, 0 };
187+ if (recycled_.size ()) {
188+ cur_cl_bytes_.cl_handle_ = recycled_.front ();
189+ recycled_.pop_front ();
190+ } else {
191+ L0_SAFE_CALL (zeCommandListCreate (my_device_.driver_ .ctx (),
192+ my_device_.device_ ,
193+ &cl_desc_,
194+ &cur_cl_bytes_.cl_handle_ ));
178195 }
179- graveyard.clear ();
180196}
181197
182- void L0DataFetcher::appendCopyCommand (void * dst,
183- const void * src,
184- const size_t num_bytes) {
185- std::unique_lock<std::mutex> cl_lock (current_cl_lock );
198+ void L0Device:: L0DataFetcher::appendCopyCommand (void * dst,
199+ const void * src,
200+ const size_t num_bytes) {
201+ std::unique_lock<std::mutex> cl_lock (cur_cl_lock_ );
186202 L0_SAFE_CALL (zeCommandListAppendMemoryCopy (
187- current_cl_bytes.first , dst, src, num_bytes, nullptr , 0 , nullptr ));
188- current_cl_bytes.second += num_bytes;
189- if (current_cl_bytes.second >= 128 * 1024 * 1024 ) {
190- ze_command_list_handle_t cl_h_copy = current_cl_bytes.first ;
191- graveyard.push_back (current_cl_bytes.first );
192- current_cl_bytes = {{}, 0 };
193- if (recycled.size ()) {
194- current_cl_bytes.first = recycled.front ();
195- recycled.pop_front ();
196- } else {
197- L0_SAFE_CALL (
198- zeCommandListCreate (driver_.ctx (), device_, &cl_desc, ¤t_cl_bytes.first ));
199- }
203+ cur_cl_bytes_.cl_handle_ , dst, src, num_bytes, nullptr , 0 , nullptr ));
204+ cur_cl_bytes_.bytes_ += num_bytes;
205+ if (cur_cl_bytes_.bytes_ >= CL_BYTES_LIMIT) {
206+ ze_command_list_handle_t cl_h_copy = cur_cl_bytes_.cl_handle_ ;
207+ graveyard_.push_back (cur_cl_bytes_.cl_handle_ );
208+ setCLRecycledOrNew ();
200209 cl_lock.unlock ();
201210 L0_SAFE_CALL (zeCommandListClose (cl_h_copy));
202211 L0_SAFE_CALL (
203212 zeCommandQueueExecuteCommandLists (queue_handle_, 1 , &cl_h_copy, nullptr ));
204213 }
205214}
206215
207- void L0DataFetcher::sync () {
208- if (current_cl_bytes. second ) {
209- L0_SAFE_CALL (zeCommandListClose (current_cl_bytes. first ));
216+ void L0Device:: L0DataFetcher::sync () {
217+ if (cur_cl_bytes_. bytes_ ) {
218+ L0_SAFE_CALL (zeCommandListClose (cur_cl_bytes_. cl_handle_ ));
210219 L0_SAFE_CALL (zeCommandQueueExecuteCommandLists (
211- queue_handle_, 1 , ¤t_cl_bytes. first , nullptr ));
220+ queue_handle_, 1 , &cur_cl_bytes_. cl_handle_ , nullptr ));
212221 }
213222 L0_SAFE_CALL (
214223 zeCommandQueueSynchronize (queue_handle_, std::numeric_limits<uint32_t >::max ()));
215- L0_SAFE_CALL (zeCommandListReset (current_cl_bytes. first ));
216- if (graveyard .size () > GRAVEYARD_LIMIT) {
224+ L0_SAFE_CALL (zeCommandListReset (cur_cl_bytes_. cl_handle_ ));
225+ if (graveyard_ .size () > GRAVEYARD_LIMIT) {
217226 recycleGraveyard ();
218227 }
219228}
220229
221230L0Device::L0Device (const L0Driver& driver, ze_device_handle_t device)
222- : device_(device), driver_(driver), data_fetcher(driver, device ) {
231+ : device_(device), driver_(driver), data_fetcher_(* this ) {
223232 ze_command_queue_handle_t queue_handle;
224233 ze_command_queue_desc_t command_queue_desc = {ZE_STRUCTURE_TYPE_COMMAND_QUEUE_DESC,
225234 nullptr ,
@@ -271,6 +280,14 @@ unsigned L0Device::maxSharedLocalMemory() const {
271280 return compute_props_.maxSharedLocalMemory ;
272281}
273282
283+ void L0Device::transferToDevice (void * dst, const void * src, const size_t num_bytes) {
284+ data_fetcher_.appendCopyCommand (dst, src, num_bytes);
285+ }
286+
287+ void L0Device::syncDataTransfers () {
288+ data_fetcher_.sync ();
289+ }
290+
274291L0CommandQueue::L0CommandQueue (ze_command_queue_handle_t handle) : handle_(handle) {}
275292
276293ze_command_queue_handle_t L0CommandQueue::handle () const {
@@ -420,7 +437,7 @@ void L0Manager::copyHostToDeviceAsync(int8_t* device_ptr,
420437 CHECK_LT (device_num, drivers_[0 ]->devices ().size ());
421438
422439 auto & device = drivers ()[0 ]->devices ()[device_num];
423- device->data_fetcher . appendCopyCommand (device_ptr, host_ptr, num_bytes);
440+ device->transferToDevice (device_ptr, host_ptr, num_bytes);
424441}
425442
426443void L0Manager::copyHostToDeviceAsyncIfPossible (int8_t * device_ptr,
@@ -438,7 +455,7 @@ void L0Manager::synchronizeDeviceDataStream(const int device_num) {
438455 CHECK_GE (device_num, 0 );
439456 CHECK_LT (device_num, drivers_[0 ]->devices ().size ());
440457 auto & device = drivers ()[0 ]->devices ()[device_num];
441- device->data_fetcher . sync ();
458+ device->syncDataTransfers ();
442459}
443460
444461void L0Manager::copyDeviceToHost (int8_t * host_ptr,
0 commit comments