2 changes: 1 addition & 1 deletion docs/ops.md
@@ -63,7 +63,7 @@ Legend:
| IM2COL_3D | ❌ | ❌ | ✅ | ✅ | ❌ | ❌ | ❌ | ❌ | ❌ |
| L2_NORM | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
| LEAKY_RELU | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ✅ | ❌ |
| LOG | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ❌ | ❌ |
| LOG | ❌ | ✅ | ✅ | ✅ | ❌ | ❌ | ✅ | ✅ | ❌ |
| MEAN | ❌ | ✅ | ✅ | ✅ | ✅ | ❌ | ✅ | ❌ | ❌ |
| MUL | ❌ | ✅ | ✅ | ✅ | 🟡 | 🟡 | ✅ | ✅ | ❌ |
| MUL_MAT | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 | 🟡 |
4 changes: 2 additions & 2 deletions docs/ops/Vulkan.csv
@@ -5431,13 +5431,13 @@
"Vulkan0","OUT_PROD","type_a=iq2_xxs,type_b=f16,m=256,n=16,k=16,bs=[3,3],nr=[2,2],trans_b=0","support","0","no","Vulkan"
"Vulkan0","SQR","type=f16,ne=[10,5,4,3]","support","0","no","Vulkan"
"Vulkan0","SQRT","type=f16,ne=[10,3,3,2]","support","0","no","Vulkan"
"Vulkan0","LOG","type=f16,ne=[10,5,4,3]","support","0","no","Vulkan"
"Vulkan0","LOG","type=f16,ne=[10,5,4,3]","support","1","yes","Vulkan"
"Vulkan0","SIN","type=f16,ne=[10,2,2,2]","support","0","no","Vulkan"
"Vulkan0","COS","type=f16,ne=[10,2,2,2]","support","0","no","Vulkan"
"Vulkan0","CLAMP","type=f16,ne=[10,5,4,3],min=-0.500000,max=0.500000","support","0","no","Vulkan"
"Vulkan0","SQR","type=f32,ne=[10,5,4,3]","support","1","yes","Vulkan"
"Vulkan0","SQRT","type=f32,ne=[10,3,3,2]","support","0","no","Vulkan"
"Vulkan0","LOG","type=f32,ne=[10,5,4,3]","support","0","no","Vulkan"
"Vulkan0","LOG","type=f32,ne=[10,5,4,3]","support","1","yes","Vulkan"
"Vulkan0","SIN","type=f32,ne=[10,2,2,2]","support","1","yes","Vulkan"
"Vulkan0","COS","type=f32,ne=[10,2,2,2]","support","1","yes","Vulkan"
"Vulkan0","CLAMP","type=f32,ne=[10,5,4,3],min=-0.500000,max=0.500000","support","1","yes","Vulkan"
25 changes: 25 additions & 0 deletions ggml/src/ggml-vulkan/ggml-vulkan.cpp
@@ -626,6 +626,7 @@ struct vk_device_struct {
vk_pipeline pipeline_sqrt_f32;
vk_pipeline pipeline_sin_f32;
vk_pipeline pipeline_cos_f32;
vk_pipeline pipeline_log[2];
vk_pipeline pipeline_clamp_f32;
vk_pipeline pipeline_pad_f32;
vk_pipeline pipeline_roll_f32;
@@ -3710,6 +3711,8 @@ static void ggml_vk_load_shaders(vk_device& device) {
ggml_vk_create_pipeline(device, device->pipeline_sqrt_f32, "sqrt_f32", sqrt_f32_len, sqrt_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_sin_f32, "sin_f32", sin_f32_len, sin_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_cos_f32, "cos_f32", cos_f32_len, cos_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_log[0], "log_f32", log_f32_len, log_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);
ggml_vk_create_pipeline(device, device->pipeline_log[1], "log_f16", log_f16_len, log_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);

ggml_vk_create_pipeline(device, device->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1);

@@ -8233,6 +8236,12 @@ static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const
return ctx->device->pipeline_cos_f32;
}
return nullptr;
case GGML_OP_LOG:
if (src0->type == dst->type &&
(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16)) {
return ctx->device->pipeline_log[dst->type == GGML_TYPE_F16];
}
return nullptr;
case GGML_OP_CLAMP:
if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) {
return ctx->device->pipeline_clamp_f32;
@@ -8637,6 +8646,7 @@ static bool ggml_vk_op_supports_incontiguous(ggml_op op) {
case GGML_OP_SQRT:
case GGML_OP_SIN:
case GGML_OP_COS:
case GGML_OP_LOG:
case GGML_OP_CLAMP:
case GGML_OP_PAD:
case GGML_OP_REPEAT:
@@ -8929,6 +8939,7 @@ static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, co
case GGML_OP_SQRT:
case GGML_OP_SIN:
case GGML_OP_COS:
case GGML_OP_LOG:
case GGML_OP_CLAMP:
case GGML_OP_PAD:
case GGML_OP_ROLL:
@@ -9537,6 +9548,10 @@ static void ggml_vk_cos(ggml_backend_vk_context * ctx, vk_context& subctx, const
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_COS, vk_op_unary_push_constants_init(src0, dst));
}

static void ggml_vk_log(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, nullptr, dst, GGML_OP_LOG, vk_op_unary_push_constants_init(src0, dst));
}

static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst) {
vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst);
p.param1 = ggml_get_op_params_f32(dst, 0);
@@ -11329,6 +11344,7 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_OP_SQRT:
case GGML_OP_SIN:
case GGML_OP_COS:
case GGML_OP_LOG:
case GGML_OP_CLAMP:
case GGML_OP_PAD:
case GGML_OP_ROLL:
@@ -11553,6 +11569,10 @@ static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgr
case GGML_OP_COS:
ggml_vk_cos(ctx, compute_ctx, src0, node);

break;
case GGML_OP_LOG:
ggml_vk_log(ctx, compute_ctx, src0, node);

break;
case GGML_OP_CLAMP:
ggml_vk_clamp(ctx, compute_ctx, src0, node);
@@ -11821,6 +11841,7 @@ static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph *
case GGML_OP_SQRT:
case GGML_OP_SIN:
case GGML_OP_COS:
case GGML_OP_LOG:
case GGML_OP_CLAMP:
case GGML_OP_PAD:
case GGML_OP_ROLL:
@@ -13672,6 +13693,8 @@ static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggm
case GGML_OP_OPT_STEP_ADAMW:
case GGML_OP_OPT_STEP_SGD:
return op->src[0]->type == GGML_TYPE_F32;
case GGML_OP_LOG:
return op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16;
case GGML_OP_ARGSORT:
return op->ne[0] <= max_argsort_cols;
case GGML_OP_UPSCALE:
@@ -14167,6 +14190,8 @@ static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph *
tensor_clone = ggml_sin(ggml_ctx, src_clone[0]);
} else if (tensor->op == GGML_OP_COS) {
tensor_clone = ggml_cos(ggml_ctx, src_clone[0]);
} else if (tensor->op == GGML_OP_LOG) {
tensor_clone = ggml_log(ggml_ctx, src_clone[0]);
} else if (tensor->op == GGML_OP_CLAMP) {
const float * params = (const float *)tensor->op_params;
tensor_clone = ggml_clamp(ggml_ctx, src_clone[0], params[0], params[1]);
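Not part of the diff: a minimal host-side sketch of how the new op could be exercised end-to-end through the public ggml API, following the path the changes above wire up (supports_op → build_graph → ggml_vk_log → pipeline dispatch). The shape and input values are arbitrary; it assumes a Vulkan-enabled build and omits error handling.

```cpp
#include <cstdio>

#include "ggml.h"
#include "ggml-alloc.h"
#include "ggml-backend.h"
#include "ggml-vulkan.h"

int main() {
    ggml_backend_t backend = ggml_backend_vk_init(0); // Vulkan device 0

    // no_alloc = true: tensor data will live in a backend buffer, not host memory
    struct ggml_init_params params = {
        /*.mem_size   =*/ ggml_tensor_overhead() * 8 + ggml_graph_overhead(),
        /*.mem_buffer =*/ NULL,
        /*.no_alloc   =*/ true,
    };
    struct ggml_context * ctx = ggml_init(params);

    struct ggml_tensor * a   = ggml_new_tensor_1d(ctx, GGML_TYPE_F32, 4);
    struct ggml_tensor * out = ggml_log(ctx, a); // the op this PR implements

    struct ggml_cgraph * gf = ggml_new_graph(ctx);
    ggml_build_forward_expand(gf, out);

    ggml_backend_buffer_t buf = ggml_backend_alloc_ctx_tensors(ctx, backend);

    const float src[4] = { 1.0f, 2.71828f, 10.0f, 0.5f };
    ggml_backend_tensor_set(a, src, 0, sizeof(src));

    ggml_backend_graph_compute(backend, gf);

    float dst[4];
    ggml_backend_tensor_get(out, dst, 0, sizeof(dst));
    for (int i = 0; i < 4; ++i) {
        printf("log(%g) = %g\n", src[i], dst[i]); // expect ~0, ~1, ~2.3026, ~-0.6931
    }

    ggml_backend_buffer_free(buf);
    ggml_free(ctx);
    ggml_backend_free(backend);
    return 0;
}
```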
17 changes: 17 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/log.comp
@@ -0,0 +1,17 @@
#version 450

#include "types.glsl"
#include "generic_unary_head.glsl"

layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;

void main() {
const uint idx = get_idx();

if (idx >= p.ne) {
return;
}

const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]);
data_d[get_doffset() + dst_idx(idx)] = D_TYPE(log(val));
}
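For the contiguous case, each shader invocation above reduces to the scalar computation in this hypothetical C++ reference (the real shader additionally applies the stride/offset mapping from generic_unary_head.glsl, so non-contiguous tensors work too):

```cpp
#include <cmath>
#include <cstdint>

// One GPU invocation per element; indices past ne return early in the shader.
// GLSL log(), like std::log, is the natural (base-e) logarithm.
static void log_reference(const float * src, float * dst, uint32_t ne) {
    for (uint32_t idx = 0; idx < ne; ++idx) {
        dst[idx] = std::log(src[idx]);
    }
}
```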
3 changes: 3 additions & 0 deletions ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp
@@ -802,6 +802,9 @@ void process_shaders() {

string_to_spv("cos_f32", "cos.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});

string_to_spv("log_f32", "log.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});
string_to_spv("log_f16", "log.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});

string_to_spv("clamp_f32", "clamp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}});

string_to_spv("pad_f32", "pad.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}});
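For context on what string_to_spv produces: it compiles log.comp once per macro set and emits a SPIR-V blob plus its length, which ggml_vk_create_pipeline consumes (see the log_f32_len/log_f32_data arguments earlier in this diff). The declarations below are an assumption inferred from that usage; the actual generated header may differ in detail.

```cpp
// Shape of the symbols emitted into the generated shader header (assumed from
// how they are passed to ggml_vk_create_pipeline above):
extern const unsigned char log_f32_data[]; // compiled SPIR-V for the f32 variant
extern const uint64_t      log_f32_len;    // size of that blob in bytes
extern const unsigned char log_f16_data[]; // compiled SPIR-V for the f16 variant
extern const uint64_t      log_f16_len;
```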