Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
#ifndef __PER_CHANNEL_QUANT_INT8_MOORE_API_H__
#define __PER_CHANNEL_QUANT_INT8_MOORE_API_H__
#include "../per_channel_quant_int8.h"

// Declares the Descriptor type for the Moore backend of this operator.
// NOTE(review): the DESCRIPTOR macro is presumably provided by
// ../per_channel_quant_int8.h and parameterized by the backend namespace — confirm there.
DESCRIPTOR(moore)

#endif // __PER_CHANNEL_QUANT_INT8_MOORE_API_H__
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
#include "../../../../devices/moore/moore_common.h"
#include "per_channel_quant_int8_moore.h"

#include "../../../../devices/moore/moore_kernel_common.h"
#include "../../../../reduce/cuda/reduce.cuh"
#include <cub/block/block_reduce.cuh>

#include "../cuda/kernel.cuh"

// Kernel entry point: block-per-row variant that takes an x_zero buffer.
// Forwards all arguments unchanged to blockPerChannelQuantI8Kernel.
// The host launcher per_channel_quant_int8Kernel selects this kernel when
// K >= 1024 and x_zero is non-null, launching M blocks of BLOCK_SIZE threads.
// NOTE(review): the quantization math lives in blockPerChannelQuantI8Kernel
// (presumably ../cuda/kernel.cuh) — confirm its preconditions there.
template <typename Tdata, unsigned int BLOCK_SIZE>
INFINIOP_MOORE_KERNEL blockPerChannelQuantI8(
int8_t *x_packed, float *x_scale, float *x_zero, const Tdata *x, int M, int K) {
blockPerChannelQuantI8Kernel<Tdata, BLOCK_SIZE>(x_packed, x_scale, x_zero, x, M, K);
}
// Kernel entry point: block-per-row variant without an x_zero buffer.
// Forwards all arguments unchanged to blockPerChannelQuantI8SymKernel.
// The host launcher per_channel_quant_int8Kernel selects this kernel when
// K >= 1024 and x_zero is nullptr, launching M blocks of BLOCK_SIZE threads.
// NOTE(review): "Sym" presumably means symmetric quantization (no zero point) —
// confirm against blockPerChannelQuantI8SymKernel.
template <typename Tdata, unsigned int BLOCK_SIZE>
INFINIOP_MOORE_KERNEL blockPerChannelQuantI8Sym(
int8_t *x_packed, float *x_scale, const Tdata *x, int M, int K) {
blockPerChannelQuantI8SymKernel<Tdata, BLOCK_SIZE>(x_packed, x_scale, x, M, K);
}

// Kernel entry point: variant used for short rows that takes an x_zero buffer.
// Forwards all arguments unchanged to warpPerChannelQuantI8Kernel.
// The host launcher per_channel_quant_int8Kernel selects this kernel when
// K < 1024 and x_zero is non-null, launching a 1-D grid of
// ceil(M / BLOCK_SIZE_y) blocks, each of BLOCK_SIZE_x x BLOCK_SIZE_y threads.
// NOTE(review): how rows map onto threadIdx.x / threadIdx.y is defined by
// warpPerChannelQuantI8Kernel, not visible here — confirm there.
template <typename Tdata, unsigned int BLOCK_SIZE_x, unsigned int BLOCK_SIZE_y>
INFINIOP_MOORE_KERNEL warpPerChannelQuantI8(
int8_t *x_packed, float *x_scale, float *x_zero, const Tdata *x, int M, int K) {
warpPerChannelQuantI8Kernel<Tdata, BLOCK_SIZE_x, BLOCK_SIZE_y>(x_packed, x_scale, x_zero, x, M, K);
}
// Kernel entry point: variant used for short rows without an x_zero buffer.
// Forwards all arguments unchanged to warpPerChannelQuantI8SymKernel.
// The host launcher per_channel_quant_int8Kernel selects this kernel when
// K < 1024 and x_zero is nullptr, launching a 1-D grid of
// ceil(M / BLOCK_SIZE_y) blocks, each of BLOCK_SIZE_x x BLOCK_SIZE_y threads.
// NOTE(review): "Sym" presumably means symmetric quantization (no zero point) —
// confirm against warpPerChannelQuantI8SymKernel.
template <typename Tdata, unsigned int BLOCK_SIZE_x, unsigned int BLOCK_SIZE_y>
INFINIOP_MOORE_KERNEL warpPerChannelQuantI8Sym(
int8_t *x_packed, float *x_scale, const Tdata *x, int M, int K) {
warpPerChannelQuantI8SymKernel<Tdata, BLOCK_SIZE_x, BLOCK_SIZE_y>(x_packed, x_scale, x, M, K);
}

namespace op::per_channel_quant_int8::moore {

// Backend-private state of the descriptor. Holds a shared reference to the
// Moore handle's internal object, which Descriptor::calculate queries for
// maxThreadsPerBlock() when choosing the kernel block size.
struct Descriptor::Opaque {
std::shared_ptr<device::moore::Handle::Internal> internal;
};

// Releases the backend-private Opaque state owned by this descriptor
// (presumably the object allocated with `new` in Descriptor::create).
Descriptor::~Descriptor() {
delete _opaque;
}

// Creates a Moore-backend descriptor for the per-channel int8 quantization operator.
//
// handle        : operator handle; reinterpreted as device::moore::Handle to obtain
//                 the backend's internal object.
// desc_ptr      : receives the newly allocated Descriptor on success.
// x_packed_desc, x_scale_desc, x_zero_desc, x_desc :
//                 tensor descriptors handed to PerChannelQuantI8Info for validation.
//
// Returns INFINI_STATUS_SUCCESS once the descriptor has been allocated.
// NOTE(review): CHECK_RESULT presumably returns the failing status early when the
// info could not be built, leaving *desc_ptr untouched — confirm the macro.
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle, Descriptor **desc_ptr,
    infiniopTensorDescriptor_t x_packed_desc,
    infiniopTensorDescriptor_t x_scale_desc,
    infiniopTensorDescriptor_t x_zero_desc,
    infiniopTensorDescriptor_t x_desc) {
    auto info_result = PerChannelQuantI8Info::createPerChannelQuantI8Info(
        x_packed_desc, x_scale_desc, x_zero_desc, x_desc);
    CHECK_RESULT(info_result);

    // Keep the handle's internal object alive for as long as the descriptor lives.
    auto *moore_handle = reinterpret_cast<device::moore::Handle *>(handle);
    auto *opaque = new Opaque{moore_handle->internal()};

    // The literal 0 is presumably the workspace size (this operator needs none) — confirm
    // against the Descriptor constructor.
    *desc_ptr = new Descriptor(
        opaque, info_result.take(), 0, handle->device, handle->device_id);
    return INFINI_STATUS_SUCCESS;
}

// Host-side launcher that picks and enqueues one of the four quantization kernels.
//
// Template parameters:
//   BLOCK_SIZE : threads per block used by the block-per-row kernels.
//   Tdata      : element type of the input tensor x.
// Parameters:
//   info     : supplies the problem sizes M and K (narrowed to int for the kernels).
//   x_packed : int8 output buffer, passed through to the kernel.
//   x_scale  : float output buffer, passed through to the kernel.
//   x_zero   : optional float buffer; nullptr selects the "Sym" kernel variants.
//   x        : input buffer, passed through to the kernel.
//   stream   : MUSA stream the kernel is enqueued on.
//
// Kernel selection:
//   K >= 1024 : block kernels, grid = M blocks of BLOCK_SIZE threads.
//   K <  1024 : warp kernels, grid = ceil(M / 32) blocks of 32 x 32 threads.
//
// Returns INFINI_STATUS_SUCCESS. The launch is asynchronous and its result is not
// inspected here.
// NOTE(review): values of info.M / info.K above INT_MAX are not detected by the
// narrowing casts below — confirm the info builder bounds them.
template <unsigned int BLOCK_SIZE, typename Tdata>
infiniStatus_t per_channel_quant_int8Kernel(const PerChannelQuantI8Info &info, int8_t *x_packed, float *x_scale, float *x_zero, const Tdata *x, musaStream_t stream) {
    // With no rows there is nothing to quantize. Launching anyway would request a
    // grid of zero blocks, which is an invalid launch configuration and would leave
    // an unchecked error pending in the runtime.
    if (info.M == 0) {
        return INFINI_STATUS_SUCCESS;
    }

    const int M = (int)info.M;
    const int K = (int)info.K;
    const bool symmetric = (x_zero == nullptr);

    if (K >= 1024) {
        // Long rows: one block per row.
        if (symmetric) {
            blockPerChannelQuantI8Sym<Tdata, BLOCK_SIZE>
                <<<M, BLOCK_SIZE, 0, stream>>>(x_packed, x_scale, x, M, K);
        } else {
            blockPerChannelQuantI8<Tdata, BLOCK_SIZE>
                <<<M, BLOCK_SIZE, 0, stream>>>(x_packed, x_scale, x_zero, x, M, K);
        }
    } else {
        // Short rows: 32 x 32 thread blocks, BLOCK_SIZE_y rows' worth of grid per block.
        constexpr unsigned int BLOCK_SIZE_x = 32;
        constexpr unsigned int BLOCK_SIZE_y = 32;
        int num_block_x = (M + BLOCK_SIZE_y - 1) / BLOCK_SIZE_y; // ceil-div over rows
        dim3 block_dim(BLOCK_SIZE_x, BLOCK_SIZE_y, 1);
        dim3 grid_dim(num_block_x, 1, 1);
        if (symmetric) {
            warpPerChannelQuantI8Sym<Tdata, BLOCK_SIZE_x, BLOCK_SIZE_y>
                <<<grid_dim, block_dim, 0, stream>>>(x_packed, x_scale, x, M, K);
        } else {
            warpPerChannelQuantI8<Tdata, BLOCK_SIZE_x, BLOCK_SIZE_y>
                <<<grid_dim, block_dim, 0, stream>>>(x_packed, x_scale, x_zero, x, M, K);
        }
    }

    return INFINI_STATUS_SUCCESS;
}

// Runs the quantization on the given stream.
//
// Parameters:
//   workspace, workspace_size : not used by this implementation.
//   x_packed : output buffer, reinterpreted as int8_t*.
//   x_scale  : output buffer, reinterpreted as float*.
//   x_zero   : reinterpreted as float*; when nullptr the launcher selects the
//              "Sym" kernel variants.
//   x        : input buffer, reinterpreted according to _info.dtype
//              (F16 -> half, F32 -> float, BF16 -> __mt_bfloat16).
//   stream_  : opaque stream pointer, cast to musaStream_t.
//
// Returns:
//   the launcher's status for a supported dtype and block size,
//   INFINI_STATUS_BAD_TENSOR_DTYPE for any other dtype,
//   INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED when the device's
//   maxThreadsPerBlock() is neither MOORE_BLOCK_SIZE_1024 nor MOORE_BLOCK_SIZE_512.
infiniStatus_t Descriptor::calculate(void *workspace, size_t workspace_size,
                                     void *x_packed, void *x_scale, void *x_zero, const void *x,
                                     void *stream_) const {
    musaStream_t stream = (musaStream_t)stream_;
    // Query the device limit once; it decides the compile-time BLOCK_SIZE below.
    const auto max_threads = _opaque->internal->maxThreadsPerBlock();

#define QUANT(BLOCK_SIZE, TDATA) \
    per_channel_quant_int8Kernel<BLOCK_SIZE, TDATA>(_info, (int8_t *)x_packed, (float *)x_scale, (float *)x_zero, (const TDATA *)x, stream)
#define QUANT_WITH_BLOCK_SIZE(BLOCK_SIZE)            \
    {                                                \
        if (_info.dtype == INFINI_DTYPE_F16)         \
            return QUANT(BLOCK_SIZE, half);          \
        else if (_info.dtype == INFINI_DTYPE_F32)    \
            return QUANT(BLOCK_SIZE, float);         \
        else if (_info.dtype == INFINI_DTYPE_BF16)   \
            return QUANT(BLOCK_SIZE, __mt_bfloat16); \
        else                                         \
            return INFINI_STATUS_BAD_TENSOR_DTYPE;   \
    }

    // Every path through QUANT_WITH_BLOCK_SIZE returns, so falling out of both
    // checks means the block size is unsupported.
    if (max_threads == MOORE_BLOCK_SIZE_1024) {
        QUANT_WITH_BLOCK_SIZE(MOORE_BLOCK_SIZE_1024)
    }
    if (max_threads == MOORE_BLOCK_SIZE_512) {
        QUANT_WITH_BLOCK_SIZE(MOORE_BLOCK_SIZE_512)
    }

// Keep the dispatch helpers local to this function.
#undef QUANT_WITH_BLOCK_SIZE
#undef QUANT

    return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
}

} // namespace op::per_channel_quant_int8::moore
15 changes: 15 additions & 0 deletions src/infiniop/ops/quant/per_channel_quant_int8/operator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_QY_API)
#include "nvidia/per_channel_quant_int8_nvidia.cuh"
#endif
#if defined(ENABLE_MOORE_API)
#include "moore/per_channel_quant_int8_moore.h"
#endif

__C infiniStatus_t infiniopCreatePerChannelQuantI8Descriptor(infiniopHandle_t handle,
infiniopPerChannelQuantI8Descriptor_t *desc_ptr,
Expand All @@ -27,6 +30,9 @@ __C infiniStatus_t infiniopCreatePerChannelQuantI8Descriptor(infiniopHandle_t ha
#endif
#ifdef ENABLE_QY_API
CREATE(INFINI_DEVICE_QY, nvidia)
#endif
#ifdef ENABLE_MOORE_API
CREATE(INFINI_DEVICE_MOORE, moore)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
Expand All @@ -45,6 +51,9 @@ __C infiniStatus_t infiniopGetPerChannelQuantI8WorkspaceSize(infiniopPerChannelQ
#endif
#ifdef ENABLE_QY_API
GET(INFINI_DEVICE_QY, nvidia)
#endif
#ifdef ENABLE_MOORE_API
GET(INFINI_DEVICE_MOORE, moore)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
Expand All @@ -71,6 +80,9 @@ __C infiniStatus_t infiniopPerChannelQuantI8(infiniopPerChannelQuantI8Descriptor
#endif
#ifdef ENABLE_QY_API
QUANT(INFINI_DEVICE_QY, nvidia)
#endif
#ifdef ENABLE_MOORE_API
QUANT(INFINI_DEVICE_MOORE, moore)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
Expand All @@ -90,6 +102,9 @@ __C infiniStatus_t infiniopDestroyPerChannelQuantI8Descriptor(infiniopPerChannel
#endif
#ifdef ENABLE_QY_API
DESTROY(INFINI_DEVICE_QY, nvidia)
#endif
#ifdef ENABLE_MOORE_API
DESTROY(INFINI_DEVICE_MOORE, moore)
#endif
default:
return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
Expand Down
2 changes: 1 addition & 1 deletion src/infiniop/ops/scaled_mm/info.h
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
#ifndef __GEMM_INFO_H__
#ifndef __I8GEMM_INFO_H__
#define __I8GEMM_INFO_H__

#include "../../../utils.h"
Expand Down
4 changes: 2 additions & 2 deletions src/infiniop/ops/scaled_mm/int8_gemm.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@
size_t workspace_size, \
infiniDtype_t out_dtype, \
infiniDevice_t device_type, int device_id) \
: InfiniopDescriptor{device_type, device_id}, _out_dtype(out_dtype), \
_opaque(opaque), _info(info), _workspace_size(workspace_size) {} \
: InfiniopDescriptor{device_type, device_id}, _opaque(opaque), \
_workspace_size(workspace_size), _info(info), _out_dtype(out_dtype) {} \
\
public: \
~Descriptor(); \
Expand Down
7 changes: 7 additions & 0 deletions src/infiniop/ops/scaled_mm/moore/int8_gemm_moore.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
#ifndef __INT8_GEMM_MOORE_API_H__
#define __INT8_GEMM_MOORE_API_H__
#include "../int8_gemm.h"

// Declares the int8 GEMM Descriptor type for the Moore backend.
// NOTE(review): the DESCRIPTOR macro is presumably the one defined in ../int8_gemm.h,
// parameterized by the backend namespace — confirm there.
DESCRIPTOR(moore)

#endif // __INT8_GEMM_MOORE_API_H__
Loading