Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

balance dataset bugfix #28

Open
wants to merge 3 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 13 additions & 4 deletions core/kernels/data/balance_dataset_ops.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
#include "core/utility/semaphore.h"

#include <brpc/server.h>
#include <butil/rand_util.h>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
Expand Down Expand Up @@ -86,7 +85,7 @@ void BalanceInputDataInfo::ProcessBrpcDatasetPullReq(const DatasetPullRequest* r
CHECK(iter != op_elements_.end()) << "balance_handle " << balance_handle << " not registered.";
auto* elements = iter->second;
std::vector<Tensor> tensors;
if (elements->get(&tensors)) {
if (elements->get_wait(&tensors)) {
VariantTensorData variant_tensor;
{
for (auto& element : tensors) {
Expand Down Expand Up @@ -225,6 +224,13 @@ class BalanceDatasetOp::Dataset : public DatasetBase {
BufferQueueWithLock* q = data_info->op_elements_[dataset()->balance_handle_];
if (q->empty() && *end_of_sequence) {
GetDataFromBrpcInternal(end_of_sequence, out_tensors);

if (*end_of_sequence) {
LOG(INFO) << "Shard [" << PsCluster::Instance()->Rank()
<< "] consumed all data, total spend "
<< data_info->TimerElapsedInSecond() << " seconds.";
}

return Status::OK();
}

Expand Down Expand Up @@ -280,14 +286,17 @@ class BalanceDatasetOp::Dataset : public DatasetBase {

auto* data_info = BalanceInputDataInfo::Instance();
BufferQueueWithLock* q = data_info->op_elements_[dataset()->balance_handle_];
while (!q->buffer_full() && !*end_of_sequence) {
std::vector<std::vector<Tensor> > inputs;
size_t fill_count = q->fill_count();
for (size_t i = 0; i < fill_count && !*end_of_sequence; ++i) {
std::vector<Tensor> input_vec;
TF_RETURN_IF_ERROR(
input_impl_->GetNext(ctx, &input_vec, end_of_sequence));
if (!*end_of_sequence) {
q->put(std::move(input_vec));
inputs.emplace_back(std::move(input_vec));
}
}
q->put(std::move(inputs));

return Status::OK();
}
Expand Down
42 changes: 38 additions & 4 deletions core/kernels/data/balance_dataset_ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,10 @@
#include <vector>
#include <mutex>
#include <queue>
#include <condition_variable>
#include <chrono>

#include <butil/time.h>

#include "core/ps/ps_server_interface.h"
#include "core/ps/ps_cluster.h"
Expand Down Expand Up @@ -49,21 +53,30 @@ class BufferQueueWithLock {
// Enqueues one element (a vector of Tensors) and wakes a single waiter.
// Thread-safe: elements_ is guarded by mu_.
void put(std::vector<Tensor>&& element) {
    const std::lock_guard<std::mutex> lock(mu_);
    // std::move is required: `element` is a *named* rvalue reference and
    // is therefore an lvalue here — emplace(element) would deep-copy the
    // vector of Tensors instead of moving it.
    elements_.emplace(std::move(element));
    cv_.notify_one();
}

// Enqueues a batch of elements in one critical section and wakes all
// waiters. Thread-safe: elements_ is guarded by mu_.
void put(std::vector<std::vector<Tensor> >&& elements) {
    const std::lock_guard<std::mutex> lock(mu_);
    for (auto&& v : elements) {
        // `v` is a named reference (an lvalue); without std::move each
        // inner vector — and every Tensor in it — would be copied.
        elements_.emplace(std::move(v));
    }
    cv_.notify_all();
}

// Returns true when no elements are currently buffered.
// Takes the queue lock so the check is consistent with concurrent put/get.
bool empty() {
    const std::lock_guard<std::mutex> guard(mu_);
    return elements_.size() == 0;
}

bool buffer_full() {
size_t fill_count() {
const std::lock_guard<std::mutex> lock(mu_);
return elements_.size() > buffer_size_;
return buffer_size_ - elements_.size();
}

bool get(std::vector<Tensor>* tensors) {
const std::lock_guard<std::mutex> lock(mu_);

if (empty_unlock()) {
return false;
}
Expand All @@ -73,6 +86,17 @@ class BufferQueueWithLock {
return true;
}

bool get_wait(std::vector<Tensor>* tensors) {
{
std::unique_lock<std::mutex> lock(mu_);
if (empty_unlock()) {
cv_.wait_for(lock, std::chrono::seconds(timeout_s_));
}
}

return get(tensors);
}

void pop() {
const std::lock_guard<std::mutex> lock(mu_);
elements_.pop();
Expand All @@ -90,6 +114,8 @@ class BufferQueueWithLock {

private:
size_t buffer_size_ = 100;
int64_t timeout_s_ = 10;
std::condition_variable cv_;
std::mutex mu_;
std::queue<std::vector<Tensor> > elements_;
};
Expand All @@ -104,7 +130,6 @@ class BalanceInputDataInfo {
// Registers a per-op buffer queue and returns its handle.
// Handles are assigned sequentially as the current map size, so this
// assumes queues are registered once and never removed — otherwise a
// later registration would reuse (and overwrite) an existing handle.
// Does not take ownership of `elements`.
uint32_t Register(BufferQueueWithLock* elements) {
    const std::lock_guard<std::mutex> lock(mu_);
    uint32_t handle = op_elements_.size();
    op_elements_[handle] = elements;
    return handle;
}
Expand All @@ -118,6 +143,8 @@ class BalanceInputDataInfo {

finished_ = false;

timer_.start();

return 0;
}

Expand All @@ -144,13 +171,20 @@ class BalanceInputDataInfo {

void CopyDataToBuffer(const tensornet::DatasetPullResponse* resp, uint32_t balance_handle);

// Stops the wall-clock timer started in Init() (timer_.start()) and
// returns the elapsed time in seconds.
// NOTE(review): assumes butil::Timer allows stop() after an earlier
// stop() and reports start-to-latest-stop; confirm repeated calls are
// intended, since every call mutates timer_.
double TimerElapsedInSecond() {
    timer_.stop();
    return timer_.s_elapsed();
}

public:
std::mutex remaining_mu_;
std::set<uint32_t> remaining_shards_;

std::mutex mu_;
bool finished_ = false;
std::map<uint32_t, BufferQueueWithLock*> op_elements_;

butil::Timer timer_;
};

} // namespace tensorflow
Expand Down