forked from pytorch/pytorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
net_async_task_future.cc
111 lines (93 loc) · 2.79 KB
/
net_async_task_future.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
#include "caffe2/core/net_async_task_future.h"

#include <utility>

#include "c10/util/Logging.h"
#include "caffe2/core/common.h"
namespace caffe2 {
AsyncTaskFuture::AsyncTaskFuture() : completed_(false), failed_(false) {}
// Creates a future that completes once every future in `futures` has
// completed, and fails if any of them failed.
//
// - With more than one input, a shared ParentCounter counts completions down;
//   the error messages of all failed inputs are joined with ", ".
// - With exactly one input, that input's outcome is forwarded directly.
// - An empty `futures` is rejected: it falls into the single-input branch and
//   fails CAFFE_ENFORCE_EQ(futures.size(), 1).
//
// The callbacks are registered through SetCallback(), so they run while the
// input future's mutex_ is held, and they run immediately (inside this
// constructor) for any input that is already completed.
// NOTE(review): the callbacks capture `this` and are kept across ResetState();
// this future must outlive every completion of its inputs.
AsyncTaskFuture::AsyncTaskFuture(const std::vector<AsyncTaskFuture*>& futures)
    : completed_(false), failed_(false) {
  if (futures.size() > 1) {
    parent_counter_ = std::make_unique<ParentCounter>(futures.size());
    for (auto future : futures) {
      future->SetCallback([this](const AsyncTaskFuture* f) {
        // Record the failure under err_mutex *before* decrementing the
        // counter, so the callback that reaches zero sees every message.
        if (f->IsFailed()) {
          std::unique_lock<std::mutex> lock(parent_counter_->err_mutex);
          if (parent_counter_->parent_failed) {
            parent_counter_->err_msg += ", " + f->ErrorMessage();
          } else {
            parent_counter_->parent_failed = true;
            parent_counter_->err_msg = f->ErrorMessage();
          }
        }
        // The callback whose decrement brings the count to zero completes
        // this future.
        int count = --parent_counter_->parent_count;
        if (count == 0) {
          // thread safe to use parent_counter here
          // (every other input has already decremented, and each one recorded
          // its error before doing so).
          // NOTE(review): relies on parent_count's decrement ordering the
          // earlier err_msg writes (e.g. an atomic) — confirm in the header.
          if (!parent_counter_->parent_failed) {
            SetCompleted();
          } else {
            SetCompleted(parent_counter_->err_msg.c_str());
          }
        }
      });
    }
  } else {
    CAFFE_ENFORCE_EQ(futures.size(), (size_t)1);
    auto future = futures.back();
    // Single input: mirror its result (success, or failure with its message).
    future->SetCallback([this](const AsyncTaskFuture* f) {
      if (!f->IsFailed()) {
        SetCompleted();
      } else {
        SetCompleted(f->ErrorMessage().c_str());
      }
    });
  }
}
// Reports whether SetCompleted() has run since construction or since the last
// ResetState(). Does not take mutex_.
bool AsyncTaskFuture::IsCompleted() const {
  const bool done = completed_;
  return done;
}
// Reports whether the future was completed with an error message.
// Must not lock mutex_: completion callbacks call IsFailed() on the input
// future while that future's (non-recursive) mutex_ is already held.
bool AsyncTaskFuture::IsFailed() const {
  const bool has_error = failed_;
  return has_error;
}
// Returns a copy of the message passed to a failing SetCompleted();
// ResetState() clears it.
// Reads err_msg_ without locking because completion callbacks call this while
// mutex_ is already held (std::mutex is not recursive).
// NOTE(review): unsynchronized read — only meaningful after completion has
// been observed (e.g. via IsCompleted()/Wait() or inside a callback).
std::string AsyncTaskFuture::ErrorMessage() const {
  std::string message(err_msg_);
  return message;
}
void AsyncTaskFuture::Wait() const {
std::unique_lock<std::mutex> lock(mutex_);
while (!completed_) {
cv_completed_.wait(lock);
}
}
// Registers `callback` to run on every completion of this future. Callbacks
// are kept across ResetState(), so they fire again each time the future is
// completed.
// If the future is already completed, the callback is also invoked right away,
// on the calling thread. Callbacks always run while mutex_ is held, so they
// must not call methods of this future that lock mutex_.
void AsyncTaskFuture::SetCallback(
    std::function<void(const AsyncTaskFuture*)> callback) {
  std::unique_lock<std::mutex> lock(mutex_);
  if (!completed_) {
    // Common case: just store it. `callback` is a by-value sink, so move it
    // rather than copying the std::function and everything it captured.
    callbacks_.push_back(std::move(callback));
    return;
  }
  // Already completed: keep a stored copy for later completions (after
  // ResetState()) and run this one immediately.
  callbacks_.push_back(callback);
  callback(this);
}
void AsyncTaskFuture::SetCompleted(const char* err_msg) {
std::unique_lock<std::mutex> lock(mutex_);
CAFFE_ENFORCE(!completed_, "Calling SetCompleted on a completed future");
completed_ = true;
if (err_msg) {
failed_ = true;
err_msg_ = err_msg;
}
for (auto& callback : callbacks_) {
callback(this);
}
cv_completed_.notify_all();
}
// Returns the future to its pending state so that it can be completed again.
// Intended to be called on a completed future.
// NOTE(review): that precondition is not enforced here.
// Registered callbacks are deliberately kept, because they encode the task
// graph structure. For a multi-input future, the shared parent counter is
// reset too.
void AsyncTaskFuture::ResetState() {
  std::unique_lock<std::mutex> guard(mutex_);
  if (parent_counter_ != nullptr) {
    parent_counter_->Reset();
  }
  completed_ = false;
  failed_ = false;
  err_msg_.clear();
}
// Defined out of line and defaulted: the member destructors do all cleanup.
AsyncTaskFuture::~AsyncTaskFuture() = default;
} // namespace caffe2