@@ -8831,6 +8831,23 @@ pi_result _pi_buffer::free() {
88318831
88328832// / command-buffer Extension
88338833
8834+ // / Helper function to take a list of pi_ext_sync_points and fill the provided
8835+ // / vector with the associated ZeEvents
8836+ static pi_result getEventsFromSyncPoints (
8837+ const std::unordered_map<pi_ext_sync_point, pi_event> &SyncPoints,
8838+ size_t NumSyncPointsInWaitList, const pi_ext_sync_point *SyncPointWaitList,
8839+ std::vector<ze_event_handle_t > &ZeEventList) {
8840+ for (size_t i = 0 ; i < NumSyncPointsInWaitList; i++) {
8841+ if (auto EventHandle = SyncPoints.find (SyncPointWaitList[i]);
8842+ EventHandle != SyncPoints.end ()) {
8843+ ZeEventList.push_back (EventHandle->second ->ZeEvent );
8844+ } else {
8845+ return PI_ERROR_INVALID_VALUE;
8846+ }
8847+ }
8848+ return PI_SUCCESS;
8849+ }
8850+
88348851pi_result piextCommandBufferCreate (pi_context Context, pi_device Device,
88358852 const pi_ext_command_buffer_desc *Desc,
88368853 pi_ext_command_buffer *RetCommandBuffer) {
@@ -8935,19 +8952,16 @@ pi_result piextCommandBufferNDRangeKernel(
89358952
89368953 ZE_CALL (zeKernelSetGroupSize, (Kernel->ZeKernel , WG[0 ], WG[1 ], WG[2 ]));
89378954
8938- std::vector<ze_event_handle_t > ZeEventList (NumSyncPointsInWaitList);
8939- for (size_t i = 0 ; i < NumSyncPointsInWaitList; i++) {
8940- if (auto EventHandle = CommandBuffer->SyncPoints .find (SyncPointWaitList[i]);
8941- EventHandle != CommandBuffer->SyncPoints .end ()) {
8942- ZeEventList[i] = CommandBuffer->SyncPoints [SyncPointWaitList[i]]->ZeEvent ;
8943- } else {
8944- return PI_ERROR_INVALID_VALUE;
8945- }
8955+ std::vector<ze_event_handle_t > ZeEventList;
8956+ pi_result Res = getEventsFromSyncPoints (CommandBuffer->SyncPoints ,
8957+ NumSyncPointsInWaitList,
8958+ SyncPointWaitList, ZeEventList);
8959+ if (Res) {
8960+ return Res;
89468961 }
8947-
89488962 pi_event LaunchEvent;
8949- auto res = EventCreate (CommandBuffer->Context , nullptr , true , &LaunchEvent);
8950- if (res )
8963+ Res = EventCreate (CommandBuffer->Context , nullptr , true , &LaunchEvent);
8964+ if (Res )
89518965 return PI_ERROR_OUT_OF_HOST_MEMORY;
89528966
89538967 LaunchEvent->CommandData = (void *)Kernel;
@@ -8972,6 +8986,41 @@ pi_result piextCommandBufferNDRangeKernel(
89728986 return PI_SUCCESS;
89738987}
89748988
8989+ pi_result piextCommandBufferMemcpyUSM (
8990+ pi_ext_command_buffer CommandBuffer, void *DstPtr, const void *SrcPtr,
8991+ size_t Size, pi_uint32 NumSyncPointsInWaitList,
8992+ const pi_ext_sync_point *SyncPointWaitList, pi_ext_sync_point *SyncPoint) {
8993+ if (!DstPtr) {
8994+ return PI_ERROR_INVALID_VALUE;
8995+ }
8996+
8997+ std::vector<ze_event_handle_t > ZeEventList;
8998+ pi_result Res = getEventsFromSyncPoints (CommandBuffer->SyncPoints ,
8999+ NumSyncPointsInWaitList,
9000+ SyncPointWaitList, ZeEventList);
9001+ if (Res) {
9002+ return Res;
9003+ }
9004+
9005+ pi_event LaunchEvent;
9006+ Res = EventCreate (CommandBuffer->Context , nullptr , true , &LaunchEvent);
9007+ if (Res)
9008+ return PI_ERROR_OUT_OF_HOST_MEMORY;
9009+
9010+ ZE_CALL (zeCommandListAppendMemoryCopy,
9011+ (CommandBuffer->ZeCommandList , DstPtr, SrcPtr, Size,
9012+ LaunchEvent->ZeEvent , ZeEventList.size (), ZeEventList.data ()));
9013+
9014+ urPrint (" calling zeCommandListAppendMemoryCopy() with"
9015+ " ZeEvent %#lx\n " ,
9016+ ur_cast<std::uintptr_t >(LaunchEvent->ZeEvent ));
9017+
9018+ // Get sync point and register the event with it.
9019+ *SyncPoint = CommandBuffer->GetNextSyncPoint ();
9020+ CommandBuffer->RegisterSyncPoint (*SyncPoint, LaunchEvent);
9021+ return PI_SUCCESS;
9022+ }
9023+
89759024pi_result piextEnqueueCommandBuffer (pi_ext_command_buffer CommandBuffer,
89769025 pi_queue Queue,
89779026 pi_uint32 NumEventsInWaitList,
0 commit comments