diff --git a/src/include/kompute/Sequence.hpp b/src/include/kompute/Sequence.hpp index de9b9f69..0787f74c 100644 --- a/src/include/kompute/Sequence.hpp +++ b/src/include/kompute/Sequence.hpp @@ -41,7 +41,7 @@ class Sequence : public std::enable_shared_from_this * function also requires the Sequence to be recording, otherwise it will * not be able to add the operation. * - * @param op Object derived from kp::BaseOp that will be recoreded by the + * @param op Object derived from kp::BaseOp that will be recorded by the * sequence which will be used when the operation is evaluated. * @return shared_ptr of the Sequence class itself */ @@ -53,37 +53,18 @@ class Sequence : public std::enable_shared_from_this * function also requires the Sequence to be recording, otherwise it will * not be able to add the operation. * - * @param tensors Vector of tensors to use for the operation + * @param param Template parameter that is used to initialise the operation. * @param TArgs Template parameters that are used to initialise operation * which allows for extensible configurations on initialisation. * @return shared_ptr of the Sequence class itself */ template std::shared_ptr record( - std::vector> tensors, + typename T::ConstructorParameterType param, TArgs&&... params) { - std::shared_ptr op{ new T(tensors, std::forward(params)...) }; - return this->record(op); - } - /** - * Record function for operation to be added to the GPU queue in batch. This - * template requires classes to be derived from the OpBase class. This - * function also requires the Sequence to be recording, otherwise it will - * not be able to add the operation. - * - * @param algorithm Algorithm to use for the record often used for OpAlgo - * operations - * @param TArgs Template parameters that are used to initialise operation - * which allows for extensible configurations on initialisation. - * @return shared_ptr of the Sequence class itself - */ - template - std::shared_ptr record(std::shared_ptr algorithm, - TArgs&&... params) - { - std::shared_ptr op{ new T(algorithm, - std::forward(params)...) }; + static_assert(std::is_base_of::value, "T must derive from OpBase"); + std::shared_ptr op{ new T(param, std::forward(params)...) }; return this->record(op); } @@ -108,34 +89,18 @@ class Sequence : public std::enable_shared_from_this * Eval sends all the recorded and stored operations in the vector of * operations into the gpu as a submit job with a barrier. * - * @param tensors Vector of tensors to use for the operation + * @param param Template parameter that is used to initialise the operation. * @param TArgs Template parameters that are used to initialise operation * which allows for extensible configurations on initialisation. * @return shared_ptr of the Sequence class itself */ template - std::shared_ptr eval(std::vector> tensors, - TArgs&&... params) - { - std::shared_ptr op{ new T(tensors, std::forward(params)...) }; - return this->eval(op); - } - /** - * Eval sends all the recorded and stored operations in the vector of - * operations into the gpu as a submit job with a barrier. - * - * @param algorithm Algorithm to use for the record often used for OpAlgo - * operations - * @param TArgs Template parameters that are used to initialise operation - * which allows for extensible configurations on initialisation. - * @return shared_ptr of the Sequence class itself - */ - template - std::shared_ptr eval(std::shared_ptr algorithm, - TArgs&&... params) + std::shared_ptr eval( + typename T::ConstructorParameterType param, + TArgs&&... params) { - std::shared_ptr op{ new T(algorithm, - std::forward(params)...) }; + static_assert(std::is_base_of::value, "T must derive from OpBase"); + std::shared_ptr op{ new T(param, std::forward(params)...) }; return this->eval(op); } @@ -148,6 +113,7 @@ class Sequence : public std::enable_shared_from_this * @return Boolean stating whether execution was successful. */ std::shared_ptr evalAsync(); + /** * Clears currnet operations to record provided one in the vector of * operations into the gpu as a submit job without a barrier. EvalAwait() @@ -157,39 +123,23 @@ class Sequence : public std::enable_shared_from_this * @return Boolean stating whether execution was successful. */ std::shared_ptr evalAsync(std::shared_ptr op); + /** * Eval sends all the recorded and stored operations in the vector of * operations into the gpu as a submit job with a barrier. * - * @param tensors Vector of tensors to use for the operation + * @param param Template parameter that is used to initialise the operation. * @param TArgs Template parameters that are used to initialise operation * which allows for extensible configurations on initialisation. * @return shared_ptr of the Sequence class itself */ template std::shared_ptr evalAsync( - std::vector> tensors, + typename T::ConstructorParameterType param, TArgs&&... params) { - std::shared_ptr op{ new T(tensors, std::forward(params)...) }; - return this->evalAsync(op); - } - /** - * Eval sends all the recorded and stored operations in the vector of - * operations into the gpu as a submit job with a barrier. - * - * @param algorithm Algorithm to use for the record often used for OpAlgo - * operations - * @param TArgs Template parameters that are used to initialise operation - * which allows for extensible configurations on initialisation. - * @return shared_ptr of the Sequence class itself - */ - template - std::shared_ptr evalAsync(std::shared_ptr algorithm, - TArgs&&... params) - { - std::shared_ptr op{ new T(algorithm, - std::forward(params)...) }; + static_assert(std::is_base_of::value, "T must derive from OpBase"); + std::shared_ptr op{ new T(param, std::forward(params)...) }; return this->evalAsync(op); } diff --git a/src/include/kompute/operations/OpAlgoDispatch.hpp b/src/include/kompute/operations/OpAlgoDispatch.hpp index e91598f0..bd58fb6d 100644 --- a/src/include/kompute/operations/OpAlgoDispatch.hpp +++ b/src/include/kompute/operations/OpAlgoDispatch.hpp @@ -17,6 +17,8 @@ namespace kp { class OpAlgoDispatch : public OpBase { public: + using ConstructorParameterType = std::shared_ptr; + /** * Constructor that stores the algorithm to use as well as the relevant * push constants to override when recording. diff --git a/src/include/kompute/operations/OpBase.hpp b/src/include/kompute/operations/OpBase.hpp index 73767084..23a217a8 100644 --- a/src/include/kompute/operations/OpBase.hpp +++ b/src/include/kompute/operations/OpBase.hpp @@ -18,6 +18,8 @@ namespace kp { class OpBase { public: + using ConstructorParameterType = void; + /** * Default destructor for OpBase class. This OpBase destructor class should * always be called to destroy and free owned resources unless it is diff --git a/src/include/kompute/operations/OpMemoryBarrier.hpp b/src/include/kompute/operations/OpMemoryBarrier.hpp index 4a232232..35a23113 100644 --- a/src/include/kompute/operations/OpMemoryBarrier.hpp +++ b/src/include/kompute/operations/OpMemoryBarrier.hpp @@ -18,6 +18,8 @@ namespace kp { class OpMemoryBarrier : public OpBase { public: + using ConstructorParameterType = std::vector>; + /** * Constructor that stores tensors as well as memory barrier parameters to * be used to create a pipeline barrier on the respective primary or staging diff --git a/src/include/kompute/operations/OpMult.hpp b/src/include/kompute/operations/OpMult.hpp index f75ccc4f..2d4f0eca 100644 --- a/src/include/kompute/operations/OpMult.hpp +++ b/src/include/kompute/operations/OpMult.hpp @@ -21,6 +21,8 @@ namespace kp { class OpMult : public OpAlgoDispatch { public: + using ConstructorParameterType = std::vector>; + /** * Default constructor with parameters that provides the bare minimum * requirements for the operations to be able to create and manage their diff --git a/src/include/kompute/operations/OpTensorCopy.hpp b/src/include/kompute/operations/OpTensorCopy.hpp index 968c1065..6438b108 100644 --- a/src/include/kompute/operations/OpTensorCopy.hpp +++ b/src/include/kompute/operations/OpTensorCopy.hpp @@ -18,6 +18,8 @@ namespace kp { class OpTensorCopy : public OpBase { public: + using ConstructorParameterType = std::vector>; + /** * Default constructor with parameters that provides the core vulkan * resources and the tensors that will be used in the operation. diff --git a/src/include/kompute/operations/OpTensorSyncDevice.hpp b/src/include/kompute/operations/OpTensorSyncDevice.hpp index 3a1792ac..7460a6ea 100644 --- a/src/include/kompute/operations/OpTensorSyncDevice.hpp +++ b/src/include/kompute/operations/OpTensorSyncDevice.hpp @@ -18,6 +18,8 @@ namespace kp { class OpTensorSyncDevice : public OpBase { public: + using ConstructorParameterType = std::vector>; + /** * Default constructor with parameters that provides the core vulkan * resources and the tensors that will be used in the operation. The tensos diff --git a/src/include/kompute/operations/OpTensorSyncLocal.hpp b/src/include/kompute/operations/OpTensorSyncLocal.hpp index 4216003e..95426775 100644 --- a/src/include/kompute/operations/OpTensorSyncLocal.hpp +++ b/src/include/kompute/operations/OpTensorSyncLocal.hpp @@ -20,6 +20,8 @@ namespace kp { class OpTensorSyncLocal : public OpBase { public: + using ConstructorParameterType = std::vector>; + /** * Default constructor with parameters that provides the core vulkan * resources and the tensors that will be used in the operation. The tensors