From e87f38c0897ff75b91cb2fdb68230c07eaf5c4b3 Mon Sep 17 00:00:00 2001 From: Li Xinqi Date: Thu, 2 Jul 2020 14:03:50 +0800 Subject: [PATCH] Dev construct and infer op (#3120) * ConstructAndInferOp * reformat oneflow_internal.h Former-commit-id: 5a6a0912dfaaaafdf297dac6ad659384acfcb629 --- oneflow/core/common/maybe.h | 4 +- oneflow/core/graph/op_graph.cpp | 6 +- oneflow/core/job/job_build_and_infer_ctx.cpp | 9 +- oneflow/core/operator/op_attribute.proto | 7 ++ oneflow/core/operator/operator.cpp | 122 ++++++++++++++++++- oneflow/core/operator/operator.h | 7 +- oneflow/python/job_build_and_infer_if.h | 26 ++-- oneflow/python/oneflow_internal.h | 13 +- 8 files changed, 162 insertions(+), 32 deletions(-) diff --git a/oneflow/core/common/maybe.h b/oneflow/core/common/maybe.h index b1dffa24a81..e45c5a6cd95 100644 --- a/oneflow/core/common/maybe.h +++ b/oneflow/core/common/maybe.h @@ -40,7 +40,9 @@ class Maybe< return str; } - T GetDataAndSerializedErrorProto(std::string* error_str, const T& default_for_error) const { + template + Type GetDataAndSerializedErrorProto(std::string* error_str, const Type& default_for_error) const { + static_assert(std::is_same::value, "error type for argument 1"); if (IsOk()) { google::protobuf::TextFormat::PrintToString(ErrorProto(), error_str); return *Data_YouAreNotAllowedToCallThisFuncOutsideThisFile(); diff --git a/oneflow/core/graph/op_graph.cpp b/oneflow/core/graph/op_graph.cpp index e75b0ff10b0..96af784c8d3 100644 --- a/oneflow/core/graph/op_graph.cpp +++ b/oneflow/core/graph/op_graph.cpp @@ -472,11 +472,11 @@ void OpGraph::InferOpNodeSbpSignature(OpNode* op_node, const SbpSignature& sbp_s ibn2sbp_infer_hint.emplace(ibn, SbpInferHint(parallel_desc, logical_blob_desc, sbp, batch_axis)); } - const auto& BatchAxis4Lbi = [&](const LogicalBlobId& lbi) -> Maybe { - return op_node->BatchAxis4Lbi(lbi); + const auto& BatchAxis4BnInOp = [&](const std::string& bn_in_op) -> Maybe { + return op_node->op().BatchAxis4BnInOp(bn_in_op); }; 
CHECK_JUST(InferOpSbpSignature(op_node->mut_op(), sbp_sig_conf, op_node->parallel_desc(), - ibn2sbp_infer_hint, BatchAxis4Lbi)); + ibn2sbp_infer_hint, BatchAxis4BnInOp)); op_node->InitLbi2SbpParallel(); } diff --git a/oneflow/core/job/job_build_and_infer_ctx.cpp b/oneflow/core/job/job_build_and_infer_ctx.cpp index 2d46dd701f2..2a3e7bfcf3e 100644 --- a/oneflow/core/job/job_build_and_infer_ctx.cpp +++ b/oneflow/core/job/job_build_and_infer_ctx.cpp @@ -197,9 +197,8 @@ Maybe JobBuildAndInferCtx::InferMirroredSignature(Operator* op, Maybe JobBuildAndInferCtx::InferOpOutSbpParallel(Operator* op, const SbpSignature& sbp_sig_conf, const ParallelDesc& parallel_desc) { - const auto& BatchAxis4Lbi = [&](const LogicalBlobId& lbi) -> Maybe { - const auto& op = *JUST(Op4OpName(lbi.op_name())); - return op.BatchAxis4BnInOp(*JUST(op.obn4lbi(lbi))); + const auto& BatchAxis4BnInOp = [&](const std::string& bn_in_op) -> Maybe { + return op->BatchAxis4BnInOp(bn_in_op); }; HashMap ibn2sbp_infer_hint; for (const std::string& ibn : op->input_bns()) { @@ -216,11 +215,11 @@ Maybe JobBuildAndInferCtx::InferOpOutSbpParallel(Operator* op, << "when infer op_name: " << op->op_name() << " consumed op_name: " << lbi.op_name() << " blob_name: " << lbi.blob_name() << " not infer split axis"; const SbpParallel* sbp_parallel = &lbi2sbp_parallel_from_producer_view_.at(lbi); - const OptInt64* batch_axis = JUST(BatchAxis4Lbi(lbi)); + const OptInt64* batch_axis = JUST(BatchAxis4BnInOp(ibn)); ibn2sbp_infer_hint.emplace(ibn, SbpInferHint(pd, logical_blob_desc, sbp_parallel, batch_axis)); } - JUST(InferOpSbpSignature(op, sbp_sig_conf, parallel_desc, ibn2sbp_infer_hint, BatchAxis4Lbi)); + JUST(InferOpSbpSignature(op, sbp_sig_conf, parallel_desc, ibn2sbp_infer_hint, BatchAxis4BnInOp)); const auto& bn2sbp_parallel = JUST(op->sbp_signature())->bn_in_op2sbp_parallel(); for (const auto& obn : op->output_bns()) { diff --git a/oneflow/core/operator/op_attribute.proto b/oneflow/core/operator/op_attribute.proto index 
29fe6bd2cf9..008282546b5 100644 --- a/oneflow/core/operator/op_attribute.proto +++ b/oneflow/core/operator/op_attribute.proto @@ -27,3 +27,10 @@ message OpAttribute { optional BlobDescSignature logical_blob_desc_signature = 106; optional BatchAxisSignature batch_axis_signature = 107; } + +message UpstreamSignature { + optional SbpSignature sbp_signature = 1; + optional MirroredSignature mirrored_signature = 2; + optional BlobDescSignature logical_blob_desc_signature = 3; + optional BatchAxisSignature batch_axis_signature = 4; +} diff --git a/oneflow/core/operator/operator.cpp b/oneflow/core/operator/operator.cpp index bac7a18dd36..283353211d6 100644 --- a/oneflow/core/operator/operator.cpp +++ b/oneflow/core/operator/operator.cpp @@ -632,7 +632,7 @@ Maybe ParseDisableBoxingFlag(const std::string& lbn_with_hint, bool* disab Maybe InferOpSbpSignature( Operator* op, const SbpSignature& sbp_sig_conf, const ParallelDesc& parallel_desc, const HashMap& ibn2sbp_infer_hint, - std::function(const LogicalBlobId&)> BatchAxis4Lbi) { + std::function(const std::string&)> BatchAxis4BnInOp) { auto SbpInferHint4Ibn = [&](const std::string& ibn) -> Maybe { auto it = ibn2sbp_infer_hint.find(ibn); if (it == ibn2sbp_infer_hint.end()) { @@ -644,14 +644,14 @@ Maybe InferOpSbpSignature( std::function CalcOrderValue4SbpSig; auto OrderValue4HasBatchAxis = [&](const std::string& bn, const SbpParallel& sbp_parallel) -> int32_t { - const auto& batch_axis = *CHECK_JUST(BatchAxis4Lbi(op->BnInOp2Lbi(bn))); + const auto& batch_axis = *CHECK_JUST(BatchAxis4BnInOp(bn)); return -1 * (batch_axis.has_value() && sbp_parallel.has_split_parallel() && sbp_parallel.split_parallel().axis() == batch_axis.value()); }; auto OrderValue4HasNoBatchAxis = [&](const std::string& ibn, const SbpParallel& sbp_parallel) -> int32_t { - const auto& batch_axis = *CHECK_JUST(BatchAxis4Lbi(op->BnInOp2Lbi(ibn))); + const auto& batch_axis = *CHECK_JUST(BatchAxis4BnInOp(ibn)); return -2 * (batch_axis.has_value() == false && 
CHECK_JUST(SbpInferHint4Ibn(ibn))->sbp_parallel().has_split_parallel() == false @@ -724,4 +724,120 @@ bool operator==(const OperatorConf& lhs, const OperatorConf& rhs) { return PbMd().Equals(lhs, rhs); } +namespace { + +Maybe InferOpOutBlobDescs( + Operator* op, const std::function& BlobDesc4BnInOp) { + ParallelContext parallel_ctx; + parallel_ctx.set_parallel_id(0); + parallel_ctx.set_parallel_num(1); + JUST(op->InferOutBlobDescsIf(BlobDesc4BnInOp, &parallel_ctx, CHECK_JUST(op->sbp_signature()), + [](OpContext*) {})); + return Maybe::Ok(); +} + +Maybe InferOpOutSbpParallel( + Operator* op, const UpstreamSignature& upstream_signature, + const std::function& ConstBlobDesc4Ibn, + const SbpSignature& sbp_sig_conf, const ParallelDesc& parallel_desc) { + const auto& BatchAxis4BnInOp = [&](const std::string& bn_in_op) -> Maybe { + return op->BatchAxis4BnInOp(bn_in_op); + }; + const auto& SbpParallel4Ibn = [&](const std::string& ibn) -> const SbpParallel* { + const auto& map = upstream_signature.sbp_signature().bn_in_op2sbp_parallel(); + return &map.at(ibn); + }; + HashMap ibn2sbp_infer_hint; + for (const std::string& ibn : op->input_bns()) { + const ParallelDesc* pd = &parallel_desc; + const BlobDesc* logical_blob_desc = &ConstBlobDesc4Ibn(ibn); + const SbpParallel* sbp_parallel = SbpParallel4Ibn(ibn); + const OptInt64* batch_axis = JUST(BatchAxis4BnInOp(ibn)); + ibn2sbp_infer_hint.emplace(ibn, SbpInferHint(pd, logical_blob_desc, sbp_parallel, batch_axis)); + } + + JUST(InferOpSbpSignature(op, sbp_sig_conf, parallel_desc, ibn2sbp_infer_hint, BatchAxis4BnInOp)); + return Maybe::Ok(); +} + +Maybe InferMirroredSignature(Operator* op, const UpstreamSignature& upstream_signature, + bool is_mirrored, const ParallelDesc& parallel_desc) { + HashMap ibn2mirrored_sig_infer_hint; + for (const std::string& ibn : op->input_bns()) { + const auto& map = upstream_signature.mirrored_signature().bn_in_op2opt_mirrored_parallel(); + const auto& opt_mirrored_parallel = map.at(ibn); +
ibn2mirrored_sig_infer_hint.emplace( + ibn, MirroredSigInferHint(&parallel_desc, opt_mirrored_parallel.has_mirrored_parallel())); + } + const auto& MirroredSigInferHint4Ibn = + [&](const std::string& ibn) -> Maybe { + const auto& iter = ibn2mirrored_sig_infer_hint.find(ibn); + CHECK_OR_RETURN(iter != ibn2mirrored_sig_infer_hint.end()) + << "input blob not found. ibn: " << ibn; + return &iter->second; + }; + JUST(op->InferMirroredSignatureIf(MirroredSigInferHint4Ibn, is_mirrored, parallel_desc)); + return Maybe::Ok(); +} + +Maybe CheckOpInputSignature(const Operator& op, const UpstreamSignature& upstream_signature) { + for (const auto& ibn : op.input_bns()) { + { + const auto& map = upstream_signature.logical_blob_desc_signature().bn_in_op2blob_desc(); + CHECK_OR_RETURN(map.find(ibn) != map.end()); + } + { + const auto& map = upstream_signature.sbp_signature().bn_in_op2sbp_parallel(); + CHECK_OR_RETURN(map.find(ibn) != map.end()); + } + { + const auto& map = upstream_signature.mirrored_signature().bn_in_op2opt_mirrored_parallel(); + CHECK_OR_RETURN(map.find(ibn) != map.end()); + } + { + const auto& map = upstream_signature.batch_axis_signature().bn_in_op2batch_axis(); + CHECK_OR_RETURN(map.find(ibn) != map.end()); + } + } + return Maybe::Ok(); +} + +} // namespace + +Maybe ConstructAndInferOp(const OperatorConf& op_conf, + const UpstreamSignature& upstream_signature, + const ParallelConf& parallel_conf, bool is_mirrored, + const JobDesc& job_desc) { + const auto& op = ConstructOp(op_conf, &job_desc); + JUST(CheckOpInputSignature(*op, upstream_signature)); + ParallelDesc parallel_desc(parallel_conf); + HashMap> bn_in_op2blob_desc; + for (const auto& ibn : op->input_bns()) { + const auto& map = upstream_signature.logical_blob_desc_signature().bn_in_op2blob_desc(); + bn_in_op2blob_desc[ibn].reset(new BlobDesc(map.at(ibn))); + } + const auto& ConstBlobDesc4Ibn = [&](const std::string& ibn) -> const BlobDesc& { + return *bn_in_op2blob_desc.at(ibn); + }; + const auto&
BatchAxis4Ibn = [&](const std::string& ibn) -> Maybe { + const auto& map = upstream_signature.batch_axis_signature().bn_in_op2batch_axis(); + const auto& iter = map.find(ibn); + CHECK_OR_RETURN(iter != map.end()); + return &iter->second; + }; + JUST(op->InferBatchAxisIf(ConstBlobDesc4Ibn, BatchAxis4Ibn)); + JUST(InferMirroredSignature(op.get(), upstream_signature, is_mirrored, parallel_desc)); + SbpSignature sbp_sig_conf; + JUST(InferOpOutSbpParallel(op.get(), upstream_signature, ConstBlobDesc4Ibn, sbp_sig_conf, + parallel_desc)); + const auto& BlobDesc4BnInOp = [&](const std::string& bn_in_op) -> BlobDesc* { + if (!bn_in_op2blob_desc[bn_in_op]) { + bn_in_op2blob_desc[bn_in_op].reset(new BlobDesc(DataType::kInvalidDataType)); + } + return bn_in_op2blob_desc[bn_in_op].get(); + }; + JUST(InferOpOutBlobDescs(op.get(), BlobDesc4BnInOp)); + return op; +} + } // namespace oneflow diff --git a/oneflow/core/operator/operator.h b/oneflow/core/operator/operator.h index 075b4388e7b..db5b095c78a 100644 --- a/oneflow/core/operator/operator.h +++ b/oneflow/core/operator/operator.h @@ -410,7 +410,7 @@ Maybe ParseDisableBoxingFlag(const std::string& lbn_with_hint, bool* disab Maybe InferOpSbpSignature( Operator* op, const SbpSignature& sbp_sig_conf, const ParallelDesc& parallel_desc, const HashMap& ibn2sbp_infer_hint, - std::function(const LogicalBlobId&)> BatchAxis4Lbi); + std::function(const std::string&)> BatchAxis4BnInOp); std::string GetInputLbnInOpCustomizedConf(const PbMessage& msg, const std::string& fd_name_may_have_idx); @@ -419,6 +419,11 @@ void ReplaceInputLbnInOpCustomizedConf(PbMessage* msg, const std::string& fd_nam bool operator==(const OperatorConf& lhs, const OperatorConf& rhs); +Maybe ConstructAndInferOp(const OperatorConf& op_conf, + const UpstreamSignature& upstream_signature, + const ParallelConf& parallel_conf, bool is_mirrored, + const JobDesc& job_desc); + } // namespace oneflow namespace std { diff --git a/oneflow/python/job_build_and_infer_if.h 
b/oneflow/python/job_build_and_infer_if.h index e7b10dd389e..e1adddb6bea 100644 --- a/oneflow/python/job_build_and_infer_if.h +++ b/oneflow/python/job_build_and_infer_if.h @@ -12,8 +12,8 @@ void JobBuildAndInferCtx_Open(const std::string& job_name, std::string* error_st } std::string JobBuildAndInferCtx_GetCurrentJobName(std::string* error_str) { - return oneflow::JobBuildAndInferCtx_GetCurrentJobName().GetDataAndSerializedErrorProto(error_str, - ""); + return oneflow::JobBuildAndInferCtx_GetCurrentJobName().GetDataAndSerializedErrorProto( + error_str, std::string("")); } void JobBuildAndInferCtx_Close(std::string* error_str) { @@ -44,7 +44,7 @@ std::string CurJobBuildAndInferCtx_AddAndInferMirroredOp( std::string* error_str) { return oneflow::CurJobBuildAndInferCtx_AddAndInferMirroredOp(serialized_op_conf, serialized_parallel_conf) - .GetDataAndSerializedErrorProto(error_str, ""); + .GetDataAndSerializedErrorProto(error_str, std::string("")); } std::string CurJobBuildAndInferCtx_AddAndInferConsistentOp( @@ -52,7 +52,7 @@ std::string CurJobBuildAndInferCtx_AddAndInferConsistentOp( std::string* error_str) { return oneflow::CurJobBuildAndInferCtx_AddAndInferConsistentOp(serialized_op_conf, serialized_parallel_conf) - .GetDataAndSerializedErrorProto(error_str, ""); + .GetDataAndSerializedErrorProto(error_str, std::string("")); } void CurJobBuildAndInferCtx_AddLossLogicalBlobName(const std::string& lbn, std::string* error_str) { @@ -70,7 +70,7 @@ std::string JobBuildAndInferCtx_GetSerializedIdListAsStaticShape(const std::stri const std::string& lbn, std::string* error_str) { return oneflow::JobBuildAndInferCtx_GetSerializedIdListAsStaticShape(job_name, lbn) - .GetDataAndSerializedErrorProto(error_str, ""); + .GetDataAndSerializedErrorProto(error_str, std::string("")); } long long JobBuildAndInferCtx_GetDataType(const std::string& job_name, const std::string& lbn, @@ -100,20 +100,20 @@ bool JobBuildAndInferCtx_IsTensorList(const std::string& job_name, const std::st 
std::string JobBuildAndInferCtx_GetBatchAxis(const std::string& job_name, const std::string& lbn, std::string* error_str) { return oneflow::JobBuildAndInferCtx_GetBatchAxis(job_name, lbn) - .GetDataAndSerializedErrorProto(error_str, ""); + .GetDataAndSerializedErrorProto(error_str, std::string("")); } std::string JobBuildAndInferCtx_GetSplitAxisFromProducerView(const std::string& job_name, const std::string& lbn, std::string* error_str) { return oneflow::JobBuildAndInferCtx_GetSplitAxisFromProducerView(job_name, lbn) - .GetDataAndSerializedErrorProto(error_str, ""); + .GetDataAndSerializedErrorProto(error_str, std::string("")); } std::string JobBuildAndInferCtx_GetSerializedParallelConfFromProducerView( const std::string& job_name, const std::string& lbn, std::string* error_str) { return oneflow::JobBuildAndInferCtx_GetSerializedParallelConfFromProducerView(job_name, lbn) - .GetDataAndSerializedErrorProto(error_str, ""); + .GetDataAndSerializedErrorProto(error_str, std::string("")); } bool JobBuildAndInferCtx_IsMirroredBlob(const std::string& job_name, const std::string& lbn, @@ -132,13 +132,13 @@ std::string JobBuildAndInferCtx_MirroredBlobGetSerializedSubLbi(const std::strin const std::string& lbn, int index, std::string* error_str) { return oneflow::JobBuildAndInferCtx_MirroredBlobGetSubLbi(job_name, lbn, index) - .GetDataAndSerializedErrorProto(error_str, ""); + .GetDataAndSerializedErrorProto(error_str, std::string("")); } std::string JobBuildAndInferCtx_MirroredBlobGetSerializedIdListAsStaticShape( const std::string& job_name, const std::string& lbn, std::string* error_str) { return oneflow::JobBuildAndInferCtx_MirroredBlobGetSerializedIdListAsStaticShape(job_name, lbn) - .GetDataAndSerializedErrorProto(error_str, ""); + .GetDataAndSerializedErrorProto(error_str, std::string("")); } long long JobBuildAndInferCtx_MirroredBlobGetDataType(const std::string& job_name, @@ -164,18 +164,18 @@ std::string JobBuildAndInferCtx_MirroredBlobGetBatchAxis(const 
std::string& job_ const std::string& lbn, std::string* error_str) { return oneflow::JobBuildAndInferCtx_MirroredBlobGetBatchAxis(job_name, lbn) - .GetDataAndSerializedErrorProto(error_str, ""); + .GetDataAndSerializedErrorProto(error_str, std::string("")); } std::string JobBuildAndInferCtx_MirroredBlobGetSplitAxisFromProducerView( const std::string& job_name, const std::string& lbn, std::string* error_str) { return oneflow::JobBuildAndInferCtx_MirroredBlobGetSplitAxisFromProducerView(job_name, lbn) - .GetDataAndSerializedErrorProto(error_str, ""); + .GetDataAndSerializedErrorProto(error_str, std::string("")); } std::string JobBuildAndInferCtx_MirroredBlobGetSerializedParallelConfFromProducerView( const std::string& job_name, const std::string& lbn, std::string* error_str) { return oneflow::JobBuildAndInferCtx_MirroredBlobGetSerializedParallelConfFromProducerView( job_name, lbn) - .GetDataAndSerializedErrorProto(error_str, ""); + .GetDataAndSerializedErrorProto(error_str, std::string("")); } diff --git a/oneflow/python/oneflow_internal.h b/oneflow/python/oneflow_internal.h index d1d63044a94..3ad449764eb 100644 --- a/oneflow/python/oneflow_internal.h +++ b/oneflow/python/oneflow_internal.h @@ -17,7 +17,7 @@ bool IsOpTypeNameCpuSupportOnly(const std::string& op_type_name, std::string* er } std::string CurrentResource(std::string* error_str) { - return oneflow::CurrentResource().GetDataAndSerializedErrorProto(error_str, ""); + return oneflow::CurrentResource().GetDataAndSerializedErrorProto(error_str, std::string("")); } void EnableEagerExecution(bool enable_eager_execution) { @@ -66,15 +66,16 @@ void StopGlobalSession(std::string* error_str) { } std::string GetSerializedInterUserJobInfo(std::string* error_str) { - return oneflow::GetSerializedInterUserJobInfo().GetDataAndSerializedErrorProto(error_str, ""); + return oneflow::GetSerializedInterUserJobInfo().GetDataAndSerializedErrorProto(error_str, + std::string("")); } std::string GetSerializedJobSet(std::string* 
error_str) { - return oneflow::GetSerializedJobSet().GetDataAndSerializedErrorProto(error_str, ""); + return oneflow::GetSerializedJobSet().GetDataAndSerializedErrorProto(error_str, std::string("")); } std::string GetFunctionConfigDef(std::string* error_str) { - return oneflow::GetFunctionConfigDef().GetDataAndSerializedErrorProto(error_str, ""); + return oneflow::GetFunctionConfigDef().GetDataAndSerializedErrorProto(error_str, std::string("")); } void LaunchJob(const std::shared_ptr& cb, std::string* error_str) { @@ -90,13 +91,13 @@ long DeviceType4DeviceTag(const std::string& device_tag, std::string* error_str) std::string GetMachine2DeviceIdListOFRecordFromParallelConf(const std::string& parallel_conf, std::string* error_str) { return oneflow::GetSerializedMachineId2DeviceIdListOFRecord(parallel_conf) - .GetDataAndSerializedErrorProto(error_str, ""); + .GetDataAndSerializedErrorProto(error_str, std::string("")); } std::string CheckAndCompleteUserOpConf(const std::string& serialized_op_conf, std::string* error_str) { return oneflow::CheckAndCompleteUserOpConf(serialized_op_conf) - .GetDataAndSerializedErrorProto(error_str, ""); + .GetDataAndSerializedErrorProto(error_str, std::string("")); } long CurrentMachineId(std::string* error_str) {