From bd9cc0f33a4b1e7269a3cf19e23de2cef2576358 Mon Sep 17 00:00:00 2001
From: lxy268263
Date: Thu, 14 Jul 2022 11:16:53 +0800
Subject: [PATCH] [Embedding] Enable constructing the saver graph when EV Ops
 are placed on GPU.

---
 tensorflow/core/kernels/kv_variable_ops.cc    | 205 ++++++++++++++++-
 tensorflow/core/kernels/kv_variable_ops.h     | 215 ++++++++++++++++++
 .../core/kernels/kv_variable_ops_gpu.cu.cc    |  58 +++++
 tensorflow/core/kernels/kv_variable_ops_gpu.h |  62 ++++-
 tensorflow/core/ops/kv_variable_ops.cc        |   8 +-
 tensorflow/python/ops/kv_variable_ops.py      |  10 +-
 tensorflow/python/training/saver.py           |  14 +-
 .../training/saving/saveable_object_util.py   |  18 +-
 8 files changed, 570 insertions(+), 20 deletions(-)

diff --git a/tensorflow/core/kernels/kv_variable_ops.cc b/tensorflow/core/kernels/kv_variable_ops.cc
index 12d162097cf..3b3c4e885e5 100644
--- a/tensorflow/core/kernels/kv_variable_ops.cc
+++ b/tensorflow/core/kernels/kv_variable_ops.cc
@@ -47,8 +47,6 @@ limitations under the License.
 
 namespace tensorflow {
 
-using GPUDevice = Eigen::GpuDevice;
-
 namespace {
 const int64 kEmbeddingVarUseDB = -214;
 const int64 kInitializableEmbeddingVarUseDB = -215;
@@ -967,7 +965,6 @@ REGISTER_KERNELS_ALL_INDEX(float);
 #undef REGISTER_KERNELS_ALL_INDEX
 #undef REGISTER_KERNELS
 
-
 #if GOOGLE_CUDA
 #if TF_ENABLE_GPU_EV
 #define REGISTER_KV_VAR_HANDLE(ktype, vtype) \
@@ -1335,7 +1332,11 @@ class KvResourceExportOpGPU : public OpKernel {
   REGISTER_KERNEL_BUILDER(Name("KvResourceExport") \
                               .Device(DEVICE_GPU) \
                               .TypeConstraint<ktype>("Tkeys") \
-                              .TypeConstraint<vtype>("Tvalues"), \
+                              .TypeConstraint<vtype>("Tvalues") \
+                              .HostMemory("keys") \
+                              .HostMemory("values") \
+                              .HostMemory("versions") \
+                              .HostMemory("freqs"), \
                           KvResourceExportOpGPU<ktype, vtype>);
 #define REGISTER_KERNELS_ALL_INDEX(type) \
   REGISTER_KERNELS(int32, type) \
@@ -1346,6 +1347,202 @@ REGISTER_KERNELS_ALL_INDEX(float);
 #undef REGISTER_KERNELS_ALL_INDEX
 #undef REGISTER_KERNELS
 
+template <typename TKey, typename TValue>
+class KvResourceImportV2OpGPU: public OpKernel {
+ public:
+  explicit KvResourceImportV2OpGPU(OpKernelConstruction* c)
+      : OpKernel(c) {
+    OP_REQUIRES_OK(c, c->GetAttr("dtype", &dtype_));
+    OP_REQUIRES_OK(c, c->GetAttr("counter_type", &counter_type_));
+    OP_REQUIRES_OK(c, c->GetAttr("shape", &shape_));
+    OP_REQUIRES(c, shape_.dims() == 1,
+                errors::InvalidArgument("KvVariable dimension must be 1"));
+    OP_REQUIRES_OK(c, c->GetAttr("steps_to_live", &steps_to_live_));
+    OP_REQUIRES(c, steps_to_live_ >= 0,
+                errors::InvalidArgument(
+                    "steps_to_live must >= 0, ",
+                    std::to_string(steps_to_live_)));
+    OP_REQUIRES_OK(c, c->GetAttr("partition_id", &partition_id_));
+    OP_REQUIRES(c, partition_id_ >= 0,
+                errors::InvalidArgument(
+                    "partition_id must >= 0, ",
+                    std::to_string(partition_id_)));
+    OP_REQUIRES_OK(c, c->GetAttr("partition_num", &partition_num_));
+    OP_REQUIRES(c, partition_num_ >= 1,
+                errors::InvalidArgument(
+                    "partition_num must >= 1, ",
+                    std::to_string(partition_num_)));
+    //OP_REQUIRES_OK(c, c->GetAttr("restore_versions", &restore_versions_));
+    OP_REQUIRES_OK(c, c->GetAttr("ht_type", &ht_type_));
+    OP_REQUIRES_OK(c, c->GetAttr("ht_partition_num", &ht_partition_num_));
+    OP_REQUIRES_OK(c, c->GetAttr("emb_index", &emb_index_));
+    OP_REQUIRES_OK(c, c->GetAttr("slot_index", &slot_index_));
+    OP_REQUIRES_OK(c, c->GetAttr("filter_freq", &filter_freq_));
+    OP_REQUIRES_OK(c, c->GetAttr("block_num", &block_num_));
+    OP_REQUIRES_OK(c, c->GetAttr("max_element_size", &max_element_size_));
+    OP_REQUIRES_OK(c, c->GetAttr("false_positive_probability",
+                                 &false_positive_probability_));
c->GetAttr("l2_weight_threshold", + &l2_weight_threshold_)); + OP_REQUIRES_OK(c, c->GetAttr("layout", &layout_)); + OP_REQUIRES_OK(c, c->GetAttr("max_freq", &max_freq_)); + OP_REQUIRES_OK(c, c->GetAttr("default_value_dim", + &default_value_dim_)); + OP_REQUIRES_OK(c, c->GetAttr("slot_num", &slot_num_)); + int64 storage_type = 0; + OP_REQUIRES_OK(c, c->GetAttr("storage_type", &storage_type)); + storage_type_ = static_cast(storage_type); + + OP_REQUIRES_OK(c, c->GetAttr("storage_path", &storage_path_)); + OP_REQUIRES_OK(c, c->GetAttr("storage_size", &storage_size_)); + OP_REQUIRES_OK(c, c->GetAttr("record_freq", &record_freq_)); + OP_REQUIRES_OK(c, c->GetAttr("record_version", &record_version_)); + } + + void Compute(OpKernelContext* context) override { + const Tensor& file_name = context->input(0); + const std::string file_name_string = file_name.scalar()(); + const Tensor& name = context->input(4); + const std::string name_string = name.scalar()(); + const Tensor& default_values = context->input(3); + OP_REQUIRES(context, dtype_ == default_values.dtype(), + errors::InvalidArgument( + "Variable and ddd value dtypes don't match; respectively, ", + dtype_, " and ", default_values.dtype())); + + ResourceHandle handle_self = HandleFromInput(context, 1); + ResourceHandle handle_primary = HandleFromInput(context, 2); + std::string opname = handle_self.name(); + EmbeddingVarGPU* ev = nullptr; + if (handle_self.name() == handle_primary.name() && + handle_self.container() == handle_primary.container()) { + OP_REQUIRES_OK( + context, + LookupOrCreateResource>( + context, handle_self, &ev, + [this, default_values, opname, context, + handle_self](EmbeddingVarGPU** ptr) { + GPUHashTable* ht = + new GPUHashTable(-1, + context->get_allocator(AllocatorAttributes())); + *ptr = new EmbeddingVarGPU(handle_self.name(), + ht, context->get_allocator(AllocatorAttributes()), + EmbeddingConfig(emb_index_ + block_num_ * slot_index_, + emb_index_, + block_num_, slot_num_, + opname + "-primary", + steps_to_live_, filter_freq_, max_freq_, + l2_weight_threshold_, layout_, + max_element_size_, + false_positive_probability_, + counter_type_, default_value_dim_)); + return (*ptr)->Init(default_values, default_value_dim_); + })); + } else { + EmbeddingVarGPU* primary_variable = nullptr; + + OP_REQUIRES_OK( + context, + LookupOrCreateResource>( + context, handle_primary, &primary_variable, + [this, default_values, opname, context, + handle_primary](EmbeddingVarGPU** ptr) { + int64 primary_slot_index(0), primary_emb_index(0); + GPUHashTable* ht = + new GPUHashTable(-1, + context->get_allocator(AllocatorAttributes())); + *ptr = new EmbeddingVarGPU(handle_primary.name(), + ht, context->get_allocator(AllocatorAttributes()), + EmbeddingConfig( + primary_emb_index + block_num_ * primary_slot_index, + primary_emb_index, + block_num_, slot_num_, opname + "-primary", + steps_to_live_, filter_freq_, max_freq_, + l2_weight_threshold_, layout_, + max_element_size_, + false_positive_probability_, + counter_type_)); + return (*ptr)->Init(default_values, default_value_dim_); + })); + + + OP_REQUIRES_OK( + context, + LookupOrCreateResource>( + context, handle_self, &ev, + [this, default_values, opname, primary_variable, context, + handle_self](EmbeddingVarGPU** ptr) { + *ptr = new EmbeddingVarGPU(handle_self.name(), + primary_variable->kv(), + context->get_allocator(AllocatorAttributes()), + EmbeddingConfig(emb_index_ + block_num_ * slot_index_, + emb_index_, + block_num_, slot_num_, opname, + steps_to_live_, 0, + max_freq_, 
+                                    l2_weight_threshold_,
+                                    layout_, 0, -1.0, counter_type_,
+                                    default_value_dim_));
+                return (*ptr)->Init(default_values, default_value_dim_);
+              }));
+      core::ScopedUnref unref_me(primary_variable);
+    }
+    core::ScopedUnref unref_me(ev);
+
+    BundleReader reader(Env::Default(), file_name_string);
+    auto s = reader.status();
+    if (!s.ok()) {
+      LOG(FATAL) << "Restore EV failure, create BundleReader error:"
+                 << s.ToString();
+    }
+
+    EVRestoreDynamicallyGPU(
+        ev, name_string, partition_id_, partition_num_, context, &reader,
+        "-partition_offset", "-keys", "-values", "-versions", "-freqs");
+    ev->SetInitialized();
+  }
+
+ private:
+  int64 partition_id_;
+  int64 partition_num_;
+  DataType dtype_;
+  DataType counter_type_;
+  int64 max_element_size_;
+  float false_positive_probability_;
+  TensorShape shape_;
+  int64 steps_to_live_;
+  bool restore_versions_;
+  string ht_type_;
+  int64 ht_partition_num_;
+  int64 emb_index_;
+  int64 slot_index_;
+  int64 block_num_;
+  int64 slot_num_;
+  int64 filter_freq_;
+  float l2_weight_threshold_;
+  std::string layout_;
+  int64 max_freq_;
+  embedding::StorageType storage_type_;
+  std::string storage_path_;
+  std::vector<int64> storage_size_;
+  int64 default_value_dim_;
+  bool record_freq_;
+  bool record_version_;
+};
+
+#define REGISTER_KERNELS(ktype, vtype) \
+  REGISTER_KERNEL_BUILDER(Name("KvResourceImportV2") \
+                              .Device(DEVICE_GPU) \
+                              .TypeConstraint<ktype>("Tkeys") \
+                              .TypeConstraint<vtype>("dtype"), \
+                          KvResourceImportV2OpGPU<ktype, vtype>);
+#define REGISTER_KERNELS_ALL_INDEX(type) \
+  REGISTER_KERNELS(int32, type) \
+  REGISTER_KERNELS(int64, type)
+TF_CALL_float(REGISTER_KERNELS_ALL_INDEX);
+TF_CALL_double(REGISTER_KERNELS_ALL_INDEX);
+//TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS_ALL_INDEX);
+#undef REGISTER_KERNELS_ALL_INDEX
+#undef REGISTER_KERNELS
+
 #endif // TF_ENABLE_GPU_EV
 #endif // GOOGLE_CUDA
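Note on the KvResourceExport registration above: pinning all four outputs to HostMemory means the GPU kernel hands the saver host-resident tensors, so SaveV2 can consume them without an extra device-to-host copy in the graph. A minimal sketch of the resulting export call from Python, mirroring the identity() helper added later in this patch (the variable ev and its dtypes here are illustrative, not part of the patch):

    # Sketch: export a GPU-resident EV; outputs land in host memory.
    with ops.device("/gpu:0"):
      keys, values, versions, freqs = gen_kv_variable_ops.kv_resource_export(
          ev._handle, Tkeys=dtypes.int64, Tvalues=dtypes.float32)
    # keys/versions/freqs are [N] tensors; values is [N, dim]; all host-resident.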
#include "tensorflow/core/lib/random/random.h" #include "tensorflow/core/lib/random/random_distributions.h" #include "tensorflow/core/util/tensor_bundle/tensor_bundle.h" +#if GOOGLE_CUDA +#define EIGEN_USE_GPU +#if TF_ENABLE_GPU_EV +#include "tensorflow/core/kernels/kv_variable_ops_gpu.h" +#endif // TF_ENABLE_GPU_EV +#endif // GOOGLE_CUDA namespace tensorflow { +using GPUDevice = Eigen::GpuDevice; + namespace { const int kSavedPartitionNum = 1000; } @@ -1072,6 +1080,213 @@ Status EVRestoreDynamically(EmbeddingVar* ev, } return Status::OK(); } +#if GOOGLE_CUDA +#if TF_ENABLE_GPU_EV +template +Status EVRestoreNoPartitionGPU(EmbeddingVarGPU* ev, BundleReader* reader, + std::string tensor_key, std::string tensor_value, OpKernelContext* context, + std::string tensor_version, std::string tensor_freq) { + TensorShape key_shape; + TensorShape value_shape; + + Status st; + reader->LookupTensorShape(tensor_key, &key_shape); + reader->LookupTensorShape(tensor_value, &value_shape); + const cudaStream_t& stream = context->eigen_device().stream(); + bool filter_flag = true; + bool restore_filter_flag = true; + st = reader->LookupHeader(tensor_key, + sizeof(K) * key_shape.dim_size(0)); + if (!st.ok()) + return st; + st = reader->LookupHeader(tensor_value, + sizeof(V) * value_shape.dim_size(0) * value_shape.dim_size(1)); + if (!st.ok()) + return st; + + size_t buffer_size = 8 << 20; + RestoreBuffer restore_buff; + restore_buff.key_buffer = new char[buffer_size]; + restore_buff.value_buffer = new char[buffer_size]; + + size_t key_bytes_read = 0; + size_t value_bytes_read = 0; + + int64 tot_key_num = key_shape.dim_size(0); + size_t value_unit_bytes = sizeof(V) * value_shape.dim_size(1); + std::string key_str = "|"; + while(tot_key_num > 0) { + size_t read_key_num = std::min( + std::min(buffer_size / sizeof(K), + buffer_size / value_unit_bytes), buffer_size / sizeof(int64)); + read_key_num = std::min((int64)read_key_num, tot_key_num); + reader->LookupSegment(tensor_key, read_key_num * sizeof(K), + restore_buff.key_buffer, key_bytes_read); + reader->LookupSegment(tensor_value, read_key_num * value_unit_bytes, + restore_buff.value_buffer, value_bytes_read); + + if (key_bytes_read > 0) { + read_key_num = key_bytes_read / sizeof(K); + VLOG(2) << "restore, read_key_num:" << read_key_num; + + st = ev->Import(restore_buff, read_key_num, 1, 0, 1, false, stream); + if (!st.ok()) + return st; + tot_key_num -= read_key_num; + } + } + + return Status::OK(); +} + +template +Status EVRestoreDynamicallyGPU(EmbeddingVarGPU* ev, + const std::string& name_string, int partition_id, + int partition_num, OpKernelContext* context, + BundleReader* reader, const std::string& part_offset_tensor_suffix, + const std::string& key_suffix, const std::string& value_suffix, + const std::string& version_suffix, const std::string& freq_suffix) { + + // first check whether there is partition + if (name_string.find(part_str) == std::string::npos) { + Status s = EVRestoreNoPartitionGPU( + ev, reader, name_string + key_suffix, + name_string + value_suffix, context, name_string + version_suffix, + name_string + freq_suffix); + if (!s.ok()) { + LOG(FATAL) << "EV restoring fail:" << s.ToString(); + } + return s; + } + + const string& curr_partid_str = std::to_string(partition_id); + // first find out which sub parts we should load + std::vector loaded_parts; + for (int i = 0; i < kSavedPartitionNum; i++) { + if (i % partition_num == partition_id) { + loaded_parts.push_back(i); + } + } + + // then we use primary partition number to compose with + // 
+  VLOG(1) << "new form:" << name_string
+          << ", partition_id:" << partition_id
+          << ", partition_num:" << partition_num;
+
+  int orig_partnum = 0;
+  size_t buffer_size = 8 << 20;
+  RestoreBuffer restore_buff;
+  restore_buff.key_buffer = new char[buffer_size];
+  restore_buff.value_buffer = new char[buffer_size];
+  const cudaStream_t& stream = context->eigen_device<GPUDevice>().stream();
+
+  for (; ; orig_partnum++) {
+    string part_id = std::to_string(orig_partnum);
+    string pre_subname = name_string.substr(0, name_string.find(part_str));
+    string post_subname = name_string.substr(name_string.find(part_str) +
+        part_str.size() + curr_partid_str.size());
+    string tensor_name = pre_subname + part_str + part_id + post_subname;
+
+    // First check whether this is the old ckpt form.
+    string tensor_key = tensor_name + key_suffix;
+    string tensor_value = tensor_name + value_suffix;
+    TensorShape key_shape, value_shape, version_shape, freq_shape;
+    Status st = reader->LookupTensorShape(tensor_key, &key_shape);
+    if (!st.ok()) {
+      VLOG(1) << "ev part " << tensor_key
+              << " does not exist, reach the end of restoring";
+      break;
+    }
+    st = reader->LookupTensorShape(tensor_value, &value_shape);
+    if (!st.ok()) {
+      break;
+    }
+
+    st = reader->LookupHeader(tensor_key, sizeof(K) * key_shape.dim_size(0));
+    if (!st.ok()) {
+      break;
+    }
+    st = reader->LookupHeader(tensor_value,
+        sizeof(V) * value_shape.dim_size(0) * value_shape.dim_size(1));
+    if (!st.ok()) {
+      break;
+    }
+
+    TensorShape part_offset_shape;
+    DataType part_offset_type;
+    string offset_tensor_name = tensor_name + part_offset_tensor_suffix;
+    st = reader->LookupDtypeAndShape(offset_tensor_name,
+        &part_offset_type, &part_offset_shape);
+    if (!st.ok()) {
+      LOG(FATAL) << "EV restoring fail:" << st.ToString();
+    }
+    Tensor part_offset_tensor;
+    st = context->allocate_temp(part_offset_type,
+        part_offset_shape, &part_offset_tensor);
+    if (!st.ok()) {
+      LOG(FATAL) << "EV restoring fail:" << st.ToString();
+    }
+    st = reader->Lookup(offset_tensor_name, &part_offset_tensor);
+    if (!st.ok()) {
+      LOG(FATAL) << "EV restoring fail:" << st.ToString();
+    }
+    auto part_offset_flat = part_offset_tensor.flat<int32>();
+
+    for (size_t i = 0; i < loaded_parts.size(); i++) {
+      int subpart_id = loaded_parts[i];
+      int subpart_offset = part_offset_flat(subpart_id);
+
+      size_t value_unit_bytes = sizeof(V) * value_shape.dim_size(1);
+      int64 tot_key_num = part_offset_flat(subpart_id + 1) - subpart_offset;
+      int64 key_part_offset = subpart_offset * sizeof(K);
+      int64 value_part_offset = subpart_offset * value_unit_bytes;
+
+      VLOG(1) << "dynamically load ev : " << name_string
+              << ", subpartid:" << loaded_parts[i]
+              << ", subpart_offset:" << subpart_offset
+              << ", partition_id:" << partition_id
+              << ", partition_num:" << partition_num
+              << ", keynum:" << tot_key_num;
+
+      int64 tot_key_bytes_read(0);
+      int64 tot_value_bytes_read(0);
+      size_t key_bytes_read = 0;
+      size_t value_bytes_read = 0;
+      while (tot_key_num > 0) {
+        size_t read_key_num = std::min(std::min(buffer_size / sizeof(K),
+            buffer_size / value_unit_bytes), buffer_size / sizeof(int64));
+        read_key_num = std::min((int64)read_key_num, tot_key_num);
+        reader->LookupSegmentOffset(tensor_key,
+            key_part_offset + tot_key_bytes_read, read_key_num * sizeof(K),
+            restore_buff.key_buffer, key_bytes_read);
+
+        reader->LookupSegmentOffset(tensor_value,
+            value_part_offset + tot_value_bytes_read,
+            read_key_num * value_unit_bytes, restore_buff.value_buffer,
+            value_bytes_read);
+
+        if (key_bytes_read > 0) {
+          read_key_num = key_bytes_read / sizeof(K);
+          VLOG(2) << "restore, read_key_num:" << read_key_num;
+          st = ev->Import(restore_buff, read_key_num, kSavedPartitionNum,
+              partition_id, partition_num, false, stream);
+          if (!st.ok()) {
+            LOG(FATAL) << "EV restoring fail:" << st.ToString();
+          }
+        }
+        tot_key_num -= read_key_num;
+        tot_key_bytes_read += key_bytes_read;
+        tot_value_bytes_read += value_bytes_read;
+      }
+    }
+  }
+  return Status::OK();
+}
+#endif // TF_ENABLE_GPU_EV
+#endif // GOOGLE_CUDA
 
 }  // namespace tensorflow
diff --git a/tensorflow/core/kernels/kv_variable_ops_gpu.cu.cc b/tensorflow/core/kernels/kv_variable_ops_gpu.cu.cc
index 96fed8cc41a..daf274ea703 100644
--- a/tensorflow/core/kernels/kv_variable_ops_gpu.cu.cc
+++ b/tensorflow/core/kernels/kv_variable_ops_gpu.cu.cc
@@ -292,6 +292,59 @@ struct KvLookupCreateEmb {
   }
 };
 
+template <typename Key, typename Value>
+__global__ void kv_update_emb_kernel(const Key* key_first,
+                                     Value* default_v,
+                                     int64 dim,
+                                     int32* item_idxs,
+                                     int32 slot_idx,
+                                     Value** d_banks,
+                                     bool** d_flags,
+                                     int32 slot_num,
+                                     int32 default_v_num,
+                                     int32 bank_size) {
+  // One block per item: locate the item's (bank, offset) slot and fill it
+  // with the imported value only if nothing is stored there yet.
+  auto item_idx = blockIdx.x;
+  auto item_pos = item_idxs[item_idx];
+  auto bank_idx = item_pos / bank_size;
+  auto offset_in_bank = item_pos % bank_size;
+  auto slot_offset = bank_idx * slot_num + slot_idx;
+  bool stored = d_flags[slot_offset][offset_in_bank];
+  __syncthreads();
+  if (stored == false) {
+    d_flags[slot_offset][offset_in_bank] = true;
+    for (auto id = threadIdx.x; id < dim; id += blockDim.x) {
+      int32 default_v_idx;
+      default_v_idx = item_idx % default_v_num;
+      d_banks[slot_offset][offset_in_bank * dim + id] =
+          default_v[default_v_idx * dim + id];
+    }
+  }
+}
+
+template <typename Key, typename Value>
+struct KvUpdateEmb {
+  void operator()(const Key* key_first,
+                  Value* default_v,
+                  int64 dim,
+                  int32* item_idxs,
+                  int32 num_items,
+                  int32 slot_idx,
+                  int32 default_v_num,
+                  Value** d_banks,
+                  bool** d_flags,
+                  int32 slot_num,
+                  int32 bank_size,
+                  cudaStream_t stream) {
+    auto const block_size = 256;
+    auto const grid_size = num_items;
+    TF_CHECK_OK(GpuLaunchKernel(kv_update_emb_kernel<Key, Value>,
+                                grid_size, block_size, 0, stream,
+                                key_first, default_v, dim,
+                                item_idxs, slot_idx,
+                                d_banks, d_flags,
+                                slot_num, default_v_num, bank_size));
+  }
+};
+
 template <typename Key, typename Value,
@@ -430,6 +483,11 @@ template struct functor::KvEmbGetSnapshot<int32, float>;
 template struct functor::KvEmbGetSnapshot<int32, double>;
 template struct functor::KvEmbGetSnapshot<int64, float>;
 template struct functor::KvEmbGetSnapshot<int64, double>;
+template struct functor::KvUpdateEmb<int32, float>;
+template struct functor::KvUpdateEmb<int32, double>;
+template struct functor::KvUpdateEmb<int64, float>;
+template struct functor::KvUpdateEmb<int64, double>;
+
 }  // namespace tensorflow
 
 #endif // TF_ENABLE_GPU_EV
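The bank addressing used by kv_update_emb_kernel above is the same scheme the existing snapshot and lookup kernels use. A tiny pure-Python sketch of the index arithmetic, for reference (function name is illustrative):

    def locate(item_pos, bank_size, slot_num, slot_idx):
      # Mirrors the CUDA kernel: banked storage, one flag per (slot, offset).
      bank_idx = item_pos // bank_size
      offset_in_bank = item_pos % bank_size
      slot_offset = bank_idx * slot_num + slot_idx
      return slot_offset, offset_in_bank

    assert locate(item_pos=1030, bank_size=1024, slot_num=2, slot_idx=0) == (2, 6)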
diff --git a/tensorflow/core/kernels/kv_variable_ops_gpu.h b/tensorflow/core/kernels/kv_variable_ops_gpu.h
index 149c56b95e4..ed05ffebbd7 100644
--- a/tensorflow/core/kernels/kv_variable_ops_gpu.h
+++ b/tensorflow/core/kernels/kv_variable_ops_gpu.h
@@ -84,6 +84,22 @@ struct KvLookupCreateEmb {
                   cudaStream_t stream);
 };
 
+template <typename Key, typename Value>
+struct KvUpdateEmb {
+  void operator()(const Key* key_first,
+                  Value* default_v,
+                  int64 dim,
+                  int32* item_idxs,
+                  int32 num_items,
+                  int32 slot_idx,
+                  int32 default_v_num,
+                  Value** d_banks,
+                  bool** d_flags,
+                  int32 slot_num,
+                  int32 bank_size,
+                  cudaStream_t stream);
+};
+
 template <typename Key, typename Value>
 struct KvKeyGetSnapshot {
   void operator()(Key* key_first,
@@ -200,13 +216,21 @@ class EmbeddingVarGPU : public ResourceBase {
   void GetSnapshot(K* keys, V* values, cudaStream_t stream) {
     int32* item_idxs = TypedAllocator::Allocate<int32>(alloc_, Size(), AllocationAttributes());
-    functor::KvKeyGetSnapshot<K, V>()(keys, item_idxs, emb_config_.emb_index, emb_config_.primary_emb_index,
+    // Snapshot into device staging buffers, then copy back to the
+    // host-memory outputs of the export kernel.
+    K* keys_gpu = TypedAllocator::Allocate<K>(alloc_, Size(), AllocationAttributes());
+    V* values_gpu = TypedAllocator::Allocate<V>(alloc_, Size() * ValueLen(), AllocationAttributes());
+
+    functor::KvKeyGetSnapshot<K, V>()(keys_gpu, item_idxs, emb_config_.emb_index, emb_config_.primary_emb_index,
         kv_->d_existence_flag_ptrs, kv_->mem_bank_num,
        (emb_config_.block_num * (1 + emb_config_.slot_num)),
         kv_->initial_bank_size, kv_, Size(), stream);
-    functor::KvEmbGetSnapshot<K, V>()(keys, values, -1, value_len_, item_idxs, Size(), emb_config_.emb_index,
+    functor::KvEmbGetSnapshot<K, V>()(keys_gpu, values_gpu, -1, value_len_, item_idxs, Size(), emb_config_.emb_index,
         kv_->d_bank_ptrs, kv_->mem_bank_num,
        (emb_config_.block_num * (1 + emb_config_.slot_num)),
         kv_->initial_bank_size, stream);
+    // Ensure the snapshot kernels have finished before the D2H copies.
+    cudaStreamSynchronize(stream);
+    cudaMemcpy(keys, keys_gpu, Size() * sizeof(K), cudaMemcpyDeviceToHost);
+    cudaMemcpy(values, values_gpu, Size() * ValueLen() * sizeof(V), cudaMemcpyDeviceToHost);
+    TypedAllocator::Deallocate(alloc_, item_idxs, Size());
+    TypedAllocator::Deallocate(alloc_, keys_gpu, Size());
+    TypedAllocator::Deallocate(alloc_, values_gpu, Size() * ValueLen());
   }
 
   int64 Size() const {
@@ -261,6 +285,40 @@ class EmbeddingVarGPU : public ResourceBase {
     return emb_config_.default_value_dim;
   }
 
+  Status Import(RestoreBuffer& restore_buff,
+                int64 key_num,
+                int bucket_num,
+                int64 partition_id,
+                int64 partition_num,
+                bool is_filter,
+                cudaStream_t stream) {
+    K* key_buff = (K*)restore_buff.key_buffer;
+    V* value_buff = (V*)restore_buff.value_buffer;
+    // Keep only the keys that belong to this partition.
+    std::vector<K> key_import;
+    key_import.reserve(key_num);
+    std::vector<V> value_import;
+    value_import.reserve(key_num * value_len_);
+    for (int64 i = 0; i < key_num; ++i) {
+      if (*(key_buff + i) % bucket_num % partition_num != partition_id) {
+        LOG(INFO) << "skip EV key:" << *(key_buff + i);
+        continue;
+      }
+      key_import.emplace_back(*(key_buff + i));
+      for (int j = 0; j < value_len_; j++) {
+        value_import.emplace_back(*(value_buff + i * value_len_ + j));
+      }
+    }
+    int n = key_import.size();
+    int32* item_idxs = TypedAllocator::Allocate<int32>(alloc_, n, AllocationAttributes());
+    LookupOrCreateKey(key_import.data(), item_idxs, n, stream);
+    V* value_gpu = TypedAllocator::Allocate<V>(alloc_, value_import.size(), AllocationAttributes());
+    cudaMemcpy(value_gpu, value_import.data(), value_import.size() * sizeof(V), cudaMemcpyHostToDevice);
+    functor::KvUpdateEmb<K, V>()(key_import.data(), value_gpu, value_len_, item_idxs, n, emb_config_.emb_index, key_import.size(),
+        kv_->d_bank_ptrs, kv_->d_existence_flag_ptrs,
+        (emb_config_.block_num * (1 + emb_config_.slot_num)),
+        kv_->initial_bank_size, stream);
+    TypedAllocator::Deallocate(alloc_, item_idxs, n);
+    TypedAllocator::Deallocate(alloc_, value_gpu, value_import.size());
+    return Status::OK();
+  }
+
  private:
   std::string name_;
   GPUHashTable<K, V>* kv_;
diff --git a/tensorflow/core/ops/kv_variable_ops.cc b/tensorflow/core/ops/kv_variable_ops.cc
index d87b55dc396..736c7fa7165 100644
--- a/tensorflow/core/ops/kv_variable_ops.cc
+++ b/tensorflow/core/ops/kv_variable_ops.cc
@@ -466,10 +466,10 @@ REGISTER_OP("KvResourceExport")
     .Attr("Tvalues: type")
     .SetShapeFn([](InferenceContext* c) {
       ShapeHandle values = c->UnknownShape();
-      TF_RETURN_IF_ERROR(c->WithRankAtLeast(values, 1, &values));
-      ShapeHandle keys = c->UnknownShapeOfRank(2);
-      ShapeHandle versions = c->UnknownShapeOfRank(3);
-      ShapeHandle freqs = c->UnknownShapeOfRank(4);
+      TF_RETURN_IF_ERROR(c->WithRankAtLeast(values, 2, &values));
+      ShapeHandle keys = c->UnknownShapeOfRank(1);
+      ShapeHandle versions = c->UnknownShapeOfRank(1);
+      ShapeHandle freqs = c->UnknownShapeOfRank(1);
       c->set_output(0, keys);
       c->set_output(1, values);
       c->set_output(2, versions);
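A quick way to read the shape-function fix above: KvResourceExport emits rank-1 key/version/freq vectors and a rank-2 [N, dim] value matrix, so the old declarations (values rank >= 1, keys 2, versions 3, freqs 4) could never match the kernel's outputs. A small self-contained sanity sketch of the intended shapes (pure Python; N and dim are illustrative):

    # Intended output ranks of KvResourceExport after this change.
    expected_ranks = {"keys": 1, "values": 2, "versions": 1, "freqs": 1}
    # Example: an EV holding N=1000 keys with embedding dim=8.
    shapes = {"keys": (1000,), "values": (1000, 8),
              "versions": (1000,), "freqs": (1000,)}
    assert all(len(shapes[k]) == r for k, r in expected_ranks.items())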
diff --git a/tensorflow/python/ops/kv_variable_ops.py b/tensorflow/python/ops/kv_variable_ops.py
index 049839f4bc1..3ff848d0f6a 100644
--- a/tensorflow/python/ops/kv_variable_ops.py
+++ b/tensorflow/python/ops/kv_variable_ops.py
@@ -325,7 +325,7 @@ def _init_from_args(self,
             list=attr_value_pb2.AttrValue.ListValue(
                 s=[compat.as_bytes("loc:@%s" % handle_name)]))
         with ops.get_default_graph()._attr_scope({"_class": attr}):
-          with ops.name_scope("Initializer"), ops.device(None):
+          with ops.name_scope("Initializer"):
             initial_value = ops.convert_to_tensor(
                 initial_value(), name="initial_value", dtype=dtype)
             rank = initial_value.get_shape().rank - 1
@@ -903,6 +903,14 @@ def blocknum(self):
 def _dense_var_to_tensor(var, dtype=None, name=None, as_ref=False):
   return var._dense_var_to_tensor(dtype=dtype, name=name, as_ref=as_ref)  # pylint: disable=protected-access
 
+def identity(var):
+  # On GPU, materialize the EV content via KvResourceExport so the saver can
+  # consume real tensors; on CPU, the resource handle is saved as before.
+  if "GPU" in var.device:
+    with ops.device(var.device):
+      keys, values, versions, freqs = gen_kv_variable_ops.kv_resource_export(var._handle, Tkeys=var._invalid_key_type, Tvalues=var.dtype)
+    return [keys, values, versions, freqs]
+  else:
+    return var.handle
+
 # Register a conversion function which reads the value of the variable,
 # allowing instances of the class to be used as tensors.
diff --git a/tensorflow/python/training/saver.py b/tensorflow/python/training/saver.py
index 07e3e4959ef..63d59f3a239 100644
--- a/tensorflow/python/training/saver.py
+++ b/tensorflow/python/training/saver.py
@@ -227,10 +227,16 @@ def save_op(self, filename_tensor, saveables):
 
     for saveable in saveables:
       if isinstance(saveable, BaseSaverBuilder.EmbeddingVariableSaveable):
-        tensor_names.append(saveable.name)
-        tensors.append(saveable.handle_op)
-        tensor_slices.append("")
-        ev_key_types.append(saveable.key_type)
+        if "GPU" in saveable.var.device:
+          # GPU EVs already exported concrete tensors; save them like
+          # ordinary specs instead of passing the resource handle.
+          for spec in saveable.specs:
+            tensor_names.append(spec.name)
+            tensors.append(spec.tensor)
+            tensor_slices.append(spec.slice_spec)
+        else:
+          tensor_names.append(saveable.name)
+          tensors.append(saveable.handle_op)
+          tensor_slices.append("")
+          ev_key_types.append(saveable.key_type)
         continue
       for spec in saveable.specs:
         tensor_names.append(spec.name)
diff --git a/tensorflow/python/training/saving/saveable_object_util.py b/tensorflow/python/training/saving/saveable_object_util.py
index d128bca583b..33f49793211 100644
--- a/tensorflow/python/training/saving/saveable_object_util.py
+++ b/tensorflow/python/training/saving/saveable_object_util.py
@@ -169,13 +169,21 @@ def f():
       return array_ops.identity(x)
     return f
     #unused_tensor = _read_variable_closure(var)
-    unused_tensor = var.handle
+    #unused_tensor = var.handle
+    unused_tensor = kv_variable_ops.identity(var)
 
     specs = []
-    specs.append(saveable_object.SaveSpec(unused_tensor, "", name + "-keys", dtype=self.key_type, device=var.device))
-    specs.append(saveable_object.SaveSpec(unused_tensor, "", name + "-values", dtype=dtypes.float32, device=var.device))
-    specs.append(saveable_object.SaveSpec(unused_tensor, "", name + "-versions", dtype=dtypes.int64, device=var.device))
-    specs.append(saveable_object.SaveSpec(unused_tensor, "", name + "-freqs", dtype=dtypes.int64, device=var.device))
+    if isinstance(unused_tensor, list):
+      # GPU path: identity() returned [keys, values, versions, freqs].
+      specs.append(saveable_object.SaveSpec(unused_tensor[0], "", name + "-keys", dtype=self.key_type, device=unused_tensor[0].device))
+      specs.append(saveable_object.SaveSpec(unused_tensor[1], "", name + "-values", dtype=dtypes.float32, device=unused_tensor[1].device))
device=unused_tensor[1].device)) + specs.append(saveable_object.SaveSpec(unused_tensor[2], "", name + "-versions", dtype=dtypes.int64, device=unused_tensor[2].device)) + specs.append(saveable_object.SaveSpec(unused_tensor[3], "", name + "-freqs", dtype=dtypes.int64, device=unused_tensor[2].device)) + else: + specs.append(saveable_object.SaveSpec(unused_tensor, "", name + "-keys", dtype=self.key_type, device=var.device)) + specs.append(saveable_object.SaveSpec(unused_tensor, "", name + "-values", dtype=dtypes.float32, device=var.device)) + specs.append(saveable_object.SaveSpec(unused_tensor, "", name + "-versions", dtype=dtypes.int64, device=var.device)) + specs.append(saveable_object.SaveSpec(unused_tensor, "", name + "-freqs", dtype=dtypes.int64, device=var.device)) + # pylint: disable=protected-access super(EmbeddingVariableSaveable, self).__init__(var, specs, name) self.is_sparse = var._is_sparse