Skip to content

Commit

Permalink
Dev construct and infer op (#3120)
Browse files Browse the repository at this point in the history
* ConstructAndInferOp

* reformat oneflow_internal.h

Former-commit-id: 5a6a091
  • Loading branch information
lixinqi authored Jul 2, 2020
1 parent d09c6e7 commit e87f38c
Show file tree
Hide file tree
Showing 8 changed files with 162 additions and 32 deletions.
4 changes: 3 additions & 1 deletion oneflow/core/common/maybe.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,9 @@ class Maybe<
return str;
}

T GetDataAndSerializedErrorProto(std::string* error_str, const T& default_for_error) const {
template<typename Type = T>
Type GetDataAndSerializedErrorProto(std::string* error_str, const Type& default_for_error) const {
static_assert(std::is_same<T, Type>::value, "error type for argument 1");
if (IsOk()) {
google::protobuf::TextFormat::PrintToString(ErrorProto(), error_str);
return *Data_YouAreNotAllowedToCallThisFuncOutsideThisFile();
Expand Down
6 changes: 3 additions & 3 deletions oneflow/core/graph/op_graph.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -472,11 +472,11 @@ void OpGraph::InferOpNodeSbpSignature(OpNode* op_node, const SbpSignature& sbp_s
ibn2sbp_infer_hint.emplace(ibn,
SbpInferHint(parallel_desc, logical_blob_desc, sbp, batch_axis));
}
const auto& BatchAxis4Lbi = [&](const LogicalBlobId& lbi) -> Maybe<const OptInt64*> {
return op_node->BatchAxis4Lbi(lbi);
const auto& BatchAxis4BnInOp = [&](const std::string& bn_in_op) -> Maybe<const OptInt64*> {
return op_node->op().BatchAxis4BnInOp(bn_in_op);
};
CHECK_JUST(InferOpSbpSignature(op_node->mut_op(), sbp_sig_conf, op_node->parallel_desc(),
ibn2sbp_infer_hint, BatchAxis4Lbi));
ibn2sbp_infer_hint, BatchAxis4BnInOp));
op_node->InitLbi2SbpParallel();
}

Expand Down
9 changes: 4 additions & 5 deletions oneflow/core/job/job_build_and_infer_ctx.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -197,9 +197,8 @@ Maybe<void> JobBuildAndInferCtx::InferMirroredSignature(Operator* op,
Maybe<void> JobBuildAndInferCtx::InferOpOutSbpParallel(Operator* op,
const SbpSignature& sbp_sig_conf,
const ParallelDesc& parallel_desc) {
const auto& BatchAxis4Lbi = [&](const LogicalBlobId& lbi) -> Maybe<const OptInt64*> {
const auto& op = *JUST(Op4OpName(lbi.op_name()));
return op.BatchAxis4BnInOp(*JUST(op.obn4lbi(lbi)));
const auto& BatchAxis4BnInOp = [&](const std::string& bn_in_op) -> Maybe<const OptInt64*> {
return op->BatchAxis4BnInOp(bn_in_op);
};
HashMap<std::string, SbpInferHint> ibn2sbp_infer_hint;
for (const std::string& ibn : op->input_bns()) {
Expand All @@ -216,11 +215,11 @@ Maybe<void> JobBuildAndInferCtx::InferOpOutSbpParallel(Operator* op,
<< "when infer op_name: " << op->op_name() << " consumed op_name: " << lbi.op_name()
<< " blob_name: " << lbi.blob_name() << " not infer split axis";
const SbpParallel* sbp_parallel = &lbi2sbp_parallel_from_producer_view_.at(lbi);
const OptInt64* batch_axis = JUST(BatchAxis4Lbi(lbi));
const OptInt64* batch_axis = JUST(BatchAxis4BnInOp(ibn));
ibn2sbp_infer_hint.emplace(ibn, SbpInferHint(pd, logical_blob_desc, sbp_parallel, batch_axis));
}

JUST(InferOpSbpSignature(op, sbp_sig_conf, parallel_desc, ibn2sbp_infer_hint, BatchAxis4Lbi));
JUST(InferOpSbpSignature(op, sbp_sig_conf, parallel_desc, ibn2sbp_infer_hint, BatchAxis4BnInOp));

const auto& bn2sbp_parallel = JUST(op->sbp_signature())->bn_in_op2sbp_parallel();
for (const auto& obn : op->output_bns()) {
Expand Down
7 changes: 7 additions & 0 deletions oneflow/core/operator/op_attribute.proto
Original file line number Diff line number Diff line change
Expand Up @@ -27,3 +27,10 @@ message OpAttribute {
optional BlobDescSignature logical_blob_desc_signature = 106;
optional BatchAxisSignature batch_axis_signature = 107;
}

// Signatures describing the inputs of an op, as provided by whatever feeds it.
// ConstructAndInferOp (operator.cpp) looks each of these up by input blob name and
// requires every one of the four signatures to contain an entry for every input of the op.
message UpstreamSignature {
// Sbp parallel of each input blob.
optional SbpSignature sbp_signature = 1;
// Optional mirrored parallel of each input blob.
optional MirroredSignature mirrored_signature = 2;
// Logical blob desc of each input blob.
optional BlobDescSignature logical_blob_desc_signature = 3;
// Batch axis of each input blob.
optional BatchAxisSignature batch_axis_signature = 4;
}
122 changes: 119 additions & 3 deletions oneflow/core/operator/operator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -632,7 +632,7 @@ Maybe<bool> ParseDisableBoxingFlag(const std::string& lbn_with_hint, bool* disab
Maybe<void> InferOpSbpSignature(
Operator* op, const SbpSignature& sbp_sig_conf, const ParallelDesc& parallel_desc,
const HashMap<std::string, SbpInferHint>& ibn2sbp_infer_hint,
std::function<Maybe<const OptInt64*>(const LogicalBlobId&)> BatchAxis4Lbi) {
std::function<Maybe<const OptInt64*>(const std::string&)> BatchAxis4BnInOp) {
auto SbpInferHint4Ibn = [&](const std::string& ibn) -> Maybe<const SbpInferHint*> {
auto it = ibn2sbp_infer_hint.find(ibn);
if (it == ibn2sbp_infer_hint.end()) {
Expand All @@ -644,14 +644,14 @@ Maybe<void> InferOpSbpSignature(
std::function<int32_t(const SbpSignature&)> CalcOrderValue4SbpSig;
auto OrderValue4HasBatchAxis = [&](const std::string& bn,
const SbpParallel& sbp_parallel) -> int32_t {
const auto& batch_axis = *CHECK_JUST(BatchAxis4Lbi(op->BnInOp2Lbi(bn)));
const auto& batch_axis = *CHECK_JUST(BatchAxis4BnInOp(bn));
return -1
* (batch_axis.has_value() && sbp_parallel.has_split_parallel()
&& sbp_parallel.split_parallel().axis() == batch_axis.value());
};
auto OrderValue4HasNoBatchAxis = [&](const std::string& ibn,
const SbpParallel& sbp_parallel) -> int32_t {
const auto& batch_axis = *CHECK_JUST(BatchAxis4Lbi(op->BnInOp2Lbi(ibn)));
const auto& batch_axis = *CHECK_JUST(BatchAxis4BnInOp(ibn));
return -2
* (batch_axis.has_value() == false
&& CHECK_JUST(SbpInferHint4Ibn(ibn))->sbp_parallel().has_split_parallel() == false
Expand Down Expand Up @@ -724,4 +724,120 @@ bool operator==(const OperatorConf& lhs, const OperatorConf& rhs) {
return PbMd().Equals(lhs, rhs);
}

namespace {

// Infers the output blob descs of `op`.
//
// `BlobDesc4BnInOp` maps a blob name of the op to the BlobDesc that inference reads from or
// writes into. Inference is run against a fixed parallel context (parallel_id 0,
// parallel_num 1) and the sbp signature currently held by `op`.
//
// Any failure, including the op's sbp signature being unavailable, is returned through Maybe.
Maybe<void> InferOpOutBlobDescs(
    Operator* op, const std::function<BlobDesc*(const std::string&)>& BlobDesc4BnInOp) {
  ParallelContext parallel_ctx;
  parallel_ctx.set_parallel_id(0);
  parallel_ctx.set_parallel_num(1);
  // JUST rather than CHECK_JUST: this function returns Maybe, so an unavailable sbp signature
  // is reported to the caller instead of terminating the process.
  const auto* sbp_signature = JUST(op->sbp_signature());
  JUST(op->InferOutBlobDescsIf(BlobDesc4BnInOp, &parallel_ctx, sbp_signature,
                               [](OpContext*) {}));
  return Maybe<void>::Ok();
}

// Infers the sbp signature of `op`.
//
// For every input blob name an SbpInferHint is assembled from `parallel_desc`, the logical
// blob desc returned by `ConstBlobDesc4Ibn`, the sbp parallel recorded for that input in
// `upstream_signature`, and the batch axis reported by the op itself. The hints are then
// handed to InferOpSbpSignature together with `sbp_sig_conf`.
//
// The upstream sbp parallel is fetched with `at`, which does not report a missing key through
// Maybe; ConstructAndInferOp runs CheckOpInputSignature before calling this function.
Maybe<void> InferOpOutSbpParallel(
    Operator* op, const UpstreamSignature& upstream_signature,
    const std::function<const BlobDesc&(const std::string&)>& ConstBlobDesc4Ibn,
    const SbpSignature& sbp_sig_conf, const ParallelDesc& parallel_desc) {
  const auto& BatchAxis4BnInOp = [&](const std::string& bn_in_op) -> Maybe<const OptInt64*> {
    return op->BatchAxis4BnInOp(bn_in_op);
  };
  const auto& upstream_sbp_parallels = upstream_signature.sbp_signature().bn_in_op2sbp_parallel();

  HashMap<std::string, SbpInferHint> ibn2sbp_infer_hint;
  for (const std::string& input_bn : op->input_bns()) {
    // Gather the four hint ingredients in a fixed order before building the hint.
    const BlobDesc* input_blob_desc = &ConstBlobDesc4Ibn(input_bn);
    const SbpParallel* upstream_sbp = &upstream_sbp_parallels.at(input_bn);
    const OptInt64* input_batch_axis = JUST(BatchAxis4BnInOp(input_bn));
    ibn2sbp_infer_hint.emplace(
        input_bn, SbpInferHint(&parallel_desc, input_blob_desc, upstream_sbp, input_batch_axis));
  }

  JUST(InferOpSbpSignature(op, sbp_sig_conf, parallel_desc, ibn2sbp_infer_hint, BatchAxis4BnInOp));
  return Maybe<void>::Ok();
}

// Infers the mirrored signature of `op`.
//
// For every input blob name a MirroredSigInferHint is built from `parallel_desc` and from
// whether `upstream_signature` records a mirrored parallel for that input. The hints are then
// passed to Operator::InferMirroredSignatureIf along with `is_mirrored`.
//
// Returns an error if `upstream_signature` has no mirrored-parallel entry for some input.
Maybe<void> InferMirroredSignature(Operator* op, const UpstreamSignature& upstream_signature,
                                   bool is_mirrored, const ParallelDesc& parallel_desc) {
  // The upstream map does not change while iterating, so fetch it once.
  const auto& ibn2opt_mirrored_parallel =
      upstream_signature.mirrored_signature().bn_in_op2opt_mirrored_parallel();
  HashMap<std::string, MirroredSigInferHint> ibn2mirrored_sig_infer_hint;
  for (const std::string& ibn : op->input_bns()) {
    // Checked lookup: a missing entry is reported through Maybe like every other failure in
    // this function, rather than through the map's `at`.
    const auto upstream_iter = ibn2opt_mirrored_parallel.find(ibn);
    CHECK_OR_RETURN(upstream_iter != ibn2opt_mirrored_parallel.end())
        << "mirrored parallel not found in upstream signature. ibn: " << ibn;
    ibn2mirrored_sig_infer_hint.emplace(
        ibn, MirroredSigInferHint(&parallel_desc, upstream_iter->second.has_mirrored_parallel()));
  }
  const auto& MirroredSigInferHint4Ibn =
      [&](const std::string& ibn) -> Maybe<const MirroredSigInferHint*> {
    const auto& iter = ibn2mirrored_sig_infer_hint.find(ibn);
    CHECK_OR_RETURN(iter != ibn2mirrored_sig_infer_hint.end())
        << "input blob not found. ibn: " << ibn;
    return &iter->second;
  };
  JUST(op->InferMirroredSignatureIf(MirroredSigInferHint4Ibn, is_mirrored, parallel_desc));
  return Maybe<void>::Ok();
}

// Verifies that `upstream_signature` has an entry for every input blob name of `op` in each
// of its four signatures: logical blob desc, sbp, mirrored and batch axis.
//
// Returns an error naming the missing signature and the offending input blob name on the
// first entry that cannot be found.
Maybe<void> CheckOpInputSignature(const Operator& op, const UpstreamSignature& upstream_signature) {
  // The four maps do not depend on the input being checked, so fetch them once.
  const auto& ibn2blob_desc = upstream_signature.logical_blob_desc_signature().bn_in_op2blob_desc();
  const auto& ibn2sbp_parallel = upstream_signature.sbp_signature().bn_in_op2sbp_parallel();
  const auto& ibn2opt_mirrored_parallel =
      upstream_signature.mirrored_signature().bn_in_op2opt_mirrored_parallel();
  const auto& ibn2batch_axis = upstream_signature.batch_axis_signature().bn_in_op2batch_axis();

  for (const auto& ibn : op.input_bns()) {
    CHECK_OR_RETURN(ibn2blob_desc.find(ibn) != ibn2blob_desc.end())
        << "logical blob desc not found in upstream signature. ibn: " << ibn;
    CHECK_OR_RETURN(ibn2sbp_parallel.find(ibn) != ibn2sbp_parallel.end())
        << "sbp parallel not found in upstream signature. ibn: " << ibn;
    CHECK_OR_RETURN(ibn2opt_mirrored_parallel.find(ibn) != ibn2opt_mirrored_parallel.end())
        << "mirrored parallel not found in upstream signature. ibn: " << ibn;
    CHECK_OR_RETURN(ibn2batch_axis.find(ibn) != ibn2batch_axis.end())
        << "batch axis not found in upstream signature. ibn: " << ibn;
  }
  return Maybe<void>::Ok();
}

} // namespace

// Constructs the operator described by `op_conf` and runs its inference steps using the
// input information in `upstream_signature`.
//
// Steps, in order:
//   1. construct the op and check that `upstream_signature` covers every input;
//   2. seed the blob desc of each input from the upstream logical blob desc signature;
//   3. infer batch axis, mirrored signature and sbp signature (with an empty sbp conf);
//   4. infer the output blob descs.
// Returns the constructed op, or the first error raised by any step.
Maybe<Operator> ConstructAndInferOp(const OperatorConf& op_conf,
                                    const UpstreamSignature& upstream_signature,
                                    const ParallelConf& parallel_conf, bool is_mirrored,
                                    const JobDesc& job_desc) {
  const auto& op = ConstructOp(op_conf, &job_desc);
  JUST(CheckOpInputSignature(*op, upstream_signature));
  ParallelDesc parallel_desc(parallel_conf);

  // Blob descs by blob name. Inputs are filled here; any other name gets its entry on first
  // request through BlobDesc4BnInOp below.
  HashMap<std::string, std::unique_ptr<BlobDesc>> bn_in_op2blob_desc;
  const auto& upstream_blob_descs =
      upstream_signature.logical_blob_desc_signature().bn_in_op2blob_desc();
  for (const auto& ibn : op->input_bns()) {
    bn_in_op2blob_desc[ibn].reset(new BlobDesc(upstream_blob_descs.at(ibn)));
  }

  const auto& ConstBlobDesc4Ibn = [&](const std::string& ibn) -> const BlobDesc& {
    return *bn_in_op2blob_desc.at(ibn);
  };
  const auto& BatchAxis4Ibn = [&](const std::string& ibn) -> Maybe<const OptInt64*> {
    const auto& upstream_batch_axes =
        upstream_signature.batch_axis_signature().bn_in_op2batch_axis();
    const auto& found = upstream_batch_axes.find(ibn);
    CHECK_OR_RETURN(found != upstream_batch_axes.end());
    return &found->second;
  };
  const auto& BlobDesc4BnInOp = [&](const std::string& bn_in_op) -> BlobDesc* {
    // Create a placeholder desc the first time a name is requested.
    auto& blob_desc = bn_in_op2blob_desc[bn_in_op];
    if (!blob_desc) { blob_desc.reset(new BlobDesc(DataType::kInvalidDataType)); }
    return blob_desc.get();
  };

  JUST(op->InferBatchAxisIf(ConstBlobDesc4Ibn, BatchAxis4Ibn));
  JUST(InferMirroredSignature(op.get(), upstream_signature, is_mirrored, parallel_desc));
  SbpSignature sbp_sig_conf;
  JUST(InferOpOutSbpParallel(op.get(), upstream_signature, ConstBlobDesc4Ibn, sbp_sig_conf,
                             parallel_desc));
  JUST(InferOpOutBlobDescs(op.get(), BlobDesc4BnInOp));
  return op;
}

} // namespace oneflow
7 changes: 6 additions & 1 deletion oneflow/core/operator/operator.h
Original file line number Diff line number Diff line change
Expand Up @@ -410,7 +410,7 @@ Maybe<bool> ParseDisableBoxingFlag(const std::string& lbn_with_hint, bool* disab
Maybe<void> InferOpSbpSignature(
Operator* op, const SbpSignature& sbp_sig_conf, const ParallelDesc& parallel_desc,
const HashMap<std::string, SbpInferHint>& ibn2sbp_infer_hint,
std::function<Maybe<const OptInt64*>(const LogicalBlobId&)> BatchAxis4Lbi);
std::function<Maybe<const OptInt64*>(const std::string&)> BatchAxis4BnInOp);

std::string GetInputLbnInOpCustomizedConf(const PbMessage& msg,
const std::string& fd_name_may_have_idx);
Expand All @@ -419,6 +419,11 @@ void ReplaceInputLbnInOpCustomizedConf(PbMessage* msg, const std::string& fd_nam

bool operator==(const OperatorConf& lhs, const OperatorConf& rhs);

// Constructs the operator described by `op_conf` (with `job_desc`) and infers its batch axis,
// mirrored signature, sbp signature and output blob descs.
//
// `upstream_signature` supplies the logical blob desc, sbp parallel, mirrored parallel and
// batch axis of the op's inputs, keyed by input blob name; it is checked to contain an entry
// for every input before inference starts. `parallel_conf` is used to build the ParallelDesc
// passed to the inference steps, and `is_mirrored` is forwarded to mirrored-signature
// inference.
//
// Returns the constructed operator, or the error of the first step that fails.
Maybe<Operator> ConstructAndInferOp(const OperatorConf& op_conf,
const UpstreamSignature& upstream_signature,
const ParallelConf& parallel_conf, bool is_mirrored,
const JobDesc& job_desc);

} // namespace oneflow

namespace std {
Expand Down
26 changes: 13 additions & 13 deletions oneflow/python/job_build_and_infer_if.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,8 +12,8 @@ void JobBuildAndInferCtx_Open(const std::string& job_name, std::string* error_st
}

std::string JobBuildAndInferCtx_GetCurrentJobName(std::string* error_str) {
return oneflow::JobBuildAndInferCtx_GetCurrentJobName().GetDataAndSerializedErrorProto(error_str,
"");
return oneflow::JobBuildAndInferCtx_GetCurrentJobName().GetDataAndSerializedErrorProto(
error_str, std::string(""));
}

void JobBuildAndInferCtx_Close(std::string* error_str) {
Expand Down Expand Up @@ -44,15 +44,15 @@ std::string CurJobBuildAndInferCtx_AddAndInferMirroredOp(
std::string* error_str) {
return oneflow::CurJobBuildAndInferCtx_AddAndInferMirroredOp(serialized_op_conf,
serialized_parallel_conf)
.GetDataAndSerializedErrorProto(error_str, "");
.GetDataAndSerializedErrorProto(error_str, std::string(""));
}

std::string CurJobBuildAndInferCtx_AddAndInferConsistentOp(
const std::string& serialized_op_conf, const std::string& serialized_parallel_conf,
std::string* error_str) {
return oneflow::CurJobBuildAndInferCtx_AddAndInferConsistentOp(serialized_op_conf,
serialized_parallel_conf)
.GetDataAndSerializedErrorProto(error_str, "");
.GetDataAndSerializedErrorProto(error_str, std::string(""));
}

void CurJobBuildAndInferCtx_AddLossLogicalBlobName(const std::string& lbn, std::string* error_str) {
Expand All @@ -70,7 +70,7 @@ std::string JobBuildAndInferCtx_GetSerializedIdListAsStaticShape(const std::stri
const std::string& lbn,
std::string* error_str) {
return oneflow::JobBuildAndInferCtx_GetSerializedIdListAsStaticShape(job_name, lbn)
.GetDataAndSerializedErrorProto(error_str, "");
.GetDataAndSerializedErrorProto(error_str, std::string(""));
}

long long JobBuildAndInferCtx_GetDataType(const std::string& job_name, const std::string& lbn,
Expand Down Expand Up @@ -100,20 +100,20 @@ bool JobBuildAndInferCtx_IsTensorList(const std::string& job_name, const std::st
std::string JobBuildAndInferCtx_GetBatchAxis(const std::string& job_name, const std::string& lbn,
std::string* error_str) {
return oneflow::JobBuildAndInferCtx_GetBatchAxis(job_name, lbn)
.GetDataAndSerializedErrorProto(error_str, "");
.GetDataAndSerializedErrorProto(error_str, std::string(""));
}

std::string JobBuildAndInferCtx_GetSplitAxisFromProducerView(const std::string& job_name,
const std::string& lbn,
std::string* error_str) {
return oneflow::JobBuildAndInferCtx_GetSplitAxisFromProducerView(job_name, lbn)
.GetDataAndSerializedErrorProto(error_str, "");
.GetDataAndSerializedErrorProto(error_str, std::string(""));
}

std::string JobBuildAndInferCtx_GetSerializedParallelConfFromProducerView(
const std::string& job_name, const std::string& lbn, std::string* error_str) {
return oneflow::JobBuildAndInferCtx_GetSerializedParallelConfFromProducerView(job_name, lbn)
.GetDataAndSerializedErrorProto(error_str, "");
.GetDataAndSerializedErrorProto(error_str, std::string(""));
}

bool JobBuildAndInferCtx_IsMirroredBlob(const std::string& job_name, const std::string& lbn,
Expand All @@ -132,13 +132,13 @@ std::string JobBuildAndInferCtx_MirroredBlobGetSerializedSubLbi(const std::strin
const std::string& lbn, int index,
std::string* error_str) {
return oneflow::JobBuildAndInferCtx_MirroredBlobGetSubLbi(job_name, lbn, index)
.GetDataAndSerializedErrorProto(error_str, "");
.GetDataAndSerializedErrorProto(error_str, std::string(""));
}

std::string JobBuildAndInferCtx_MirroredBlobGetSerializedIdListAsStaticShape(
const std::string& job_name, const std::string& lbn, std::string* error_str) {
return oneflow::JobBuildAndInferCtx_MirroredBlobGetSerializedIdListAsStaticShape(job_name, lbn)
.GetDataAndSerializedErrorProto(error_str, "");
.GetDataAndSerializedErrorProto(error_str, std::string(""));
}

long long JobBuildAndInferCtx_MirroredBlobGetDataType(const std::string& job_name,
Expand All @@ -164,18 +164,18 @@ std::string JobBuildAndInferCtx_MirroredBlobGetBatchAxis(const std::string& job_
const std::string& lbn,
std::string* error_str) {
return oneflow::JobBuildAndInferCtx_MirroredBlobGetBatchAxis(job_name, lbn)
.GetDataAndSerializedErrorProto(error_str, "");
.GetDataAndSerializedErrorProto(error_str, std::string(""));
}

std::string JobBuildAndInferCtx_MirroredBlobGetSplitAxisFromProducerView(
const std::string& job_name, const std::string& lbn, std::string* error_str) {
return oneflow::JobBuildAndInferCtx_MirroredBlobGetSplitAxisFromProducerView(job_name, lbn)
.GetDataAndSerializedErrorProto(error_str, "");
.GetDataAndSerializedErrorProto(error_str, std::string(""));
}

std::string JobBuildAndInferCtx_MirroredBlobGetSerializedParallelConfFromProducerView(
const std::string& job_name, const std::string& lbn, std::string* error_str) {
return oneflow::JobBuildAndInferCtx_MirroredBlobGetSerializedParallelConfFromProducerView(
job_name, lbn)
.GetDataAndSerializedErrorProto(error_str, "");
.GetDataAndSerializedErrorProto(error_str, std::string(""));
}
13 changes: 7 additions & 6 deletions oneflow/python/oneflow_internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ bool IsOpTypeNameCpuSupportOnly(const std::string& op_type_name, std::string* er
}

std::string CurrentResource(std::string* error_str) {
return oneflow::CurrentResource().GetDataAndSerializedErrorProto(error_str, "");
return oneflow::CurrentResource().GetDataAndSerializedErrorProto(error_str, std::string(""));
}

void EnableEagerExecution(bool enable_eager_execution) {
Expand Down Expand Up @@ -66,15 +66,16 @@ void StopGlobalSession(std::string* error_str) {
}

std::string GetSerializedInterUserJobInfo(std::string* error_str) {
return oneflow::GetSerializedInterUserJobInfo().GetDataAndSerializedErrorProto(error_str, "");
return oneflow::GetSerializedInterUserJobInfo().GetDataAndSerializedErrorProto(error_str,
std::string(""));
}

std::string GetSerializedJobSet(std::string* error_str) {
return oneflow::GetSerializedJobSet().GetDataAndSerializedErrorProto(error_str, "");
return oneflow::GetSerializedJobSet().GetDataAndSerializedErrorProto(error_str, std::string(""));
}

std::string GetFunctionConfigDef(std::string* error_str) {
return oneflow::GetFunctionConfigDef().GetDataAndSerializedErrorProto(error_str, "");
return oneflow::GetFunctionConfigDef().GetDataAndSerializedErrorProto(error_str, std::string(""));
}

void LaunchJob(const std::shared_ptr<oneflow::ForeignJobInstance>& cb, std::string* error_str) {
Expand All @@ -90,13 +91,13 @@ long DeviceType4DeviceTag(const std::string& device_tag, std::string* error_str)
std::string GetMachine2DeviceIdListOFRecordFromParallelConf(const std::string& parallel_conf,
std::string* error_str) {
return oneflow::GetSerializedMachineId2DeviceIdListOFRecord(parallel_conf)
.GetDataAndSerializedErrorProto(error_str, "");
.GetDataAndSerializedErrorProto(error_str, std::string(""));
}

std::string CheckAndCompleteUserOpConf(const std::string& serialized_op_conf,
std::string* error_str) {
return oneflow::CheckAndCompleteUserOpConf(serialized_op_conf)
.GetDataAndSerializedErrorProto(error_str, "");
.GetDataAndSerializedErrorProto(error_str, std::string(""));
}

long CurrentMachineId(std::string* error_str) {
Expand Down

0 comments on commit e87f38c

Please sign in to comment.