Skip to content

Commit

Permalink
Minor simplification to CheckDims (where the tensor dim indexing is p…
Browse files Browse the repository at this point in the history
…erformed by the caller).

PiperOrigin-RevId: 570106474
  • Loading branch information
tensorflower-gardener committed Oct 2, 2023
1 parent b68b3b2 commit a008c44
Showing 1 changed file with 33 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ limitations under the License.
namespace xla {
namespace spmd {

using DimMap = StableHashMap</*instr. dim idx*/ int, /* mesh dim idx*/ int>;
using DimMap = StableHashMap</*tensor dim*/ int, /* mesh dim*/ int>;
using MeshDims = absl::Span<const int64_t>;

// Contains base functionality common to both DotHandler and ConvHandler.
Expand Down Expand Up @@ -97,16 +97,13 @@ class HandlerBase {
}
}

bool CheckDims(const HloInstruction* ins,
const tsl::protobuf::RepeatedField<int64_t>& instr_dims,
const DimMap& dim_map) const {
for (const auto& [instr_dim_idx, mesh_dim_idx] : dim_map) {
auto instr_dim = instr_dims.at(instr_dim_idx);
auto shape_dim = ins->shape().dimensions().at(instr_dim);
auto mesh_dim = device_mesh_.dim(mesh_dim_idx);
if (shape_dim < mesh_dim) return false;
bool CheckDims(const HloInstruction* ins, const DimMap& dim_map) const {
for (const auto& [tensor_dim, mesh_dim] : dim_map) {
auto shape_dim = ins->shape().dimensions().at(tensor_dim);
auto device_mesh_dim = device_mesh_.dim(mesh_dim);
if (shape_dim < device_mesh_dim) return false;
if (solver_option_.only_allow_divisible_intermediate &&
!IsDivisible(shape_dim, mesh_dim))
!IsDivisible(shape_dim, device_mesh_dim))
return false;
}
return true;
Expand Down Expand Up @@ -152,8 +149,8 @@ class DotHandler : public HandlerBase {
DCHECK_EQ(mesh_dims.size(), 2);
for (int64_t i = 0; i < lhs_space_dims_.size(); ++i) {
for (int64_t j = 0; j < rhs_space_dims_.size(); ++j) {
if (!CheckDims(lhs_, lhs_space_dims_, {{i, mesh_dims[0]}}) ||
!CheckDims(rhs_, rhs_space_dims_, {{j, mesh_dims[1]}}))
if (!CheckDims(lhs_, {{lhs_space_dims_[i], mesh_dims[0]}}) ||
!CheckDims(rhs_, {{rhs_space_dims_[j], mesh_dims[1]}}))
continue;
std::string name = absl::StrFormat("SS = SR x RS @ {%s}",
absl::StrJoin(mesh_dims, ","));
Expand All @@ -177,8 +174,8 @@ class DotHandler : public HandlerBase {
DCHECK_EQ(mesh_dims.size(), 2);
for (int64_t i = 0; i < lhs_space_dims_.size(); ++i) {
for (int64_t j = i + 1; j < lhs_space_dims_.size(); ++j) {
if (!CheckDims(lhs_, lhs_space_dims_,
{{i, mesh_dims[0]}, {j, mesh_dims[1]}}))
if (!CheckDims(lhs_, {{lhs_space_dims_[i], mesh_dims[0]},
{lhs_space_dims_[j], mesh_dims[1]}}))
continue;
std::string name = absl::StrFormat("SSR = SSR x RR @ {%s}",
absl::StrJoin(mesh_dims, ","));
Expand All @@ -199,8 +196,8 @@ class DotHandler : public HandlerBase {
DCHECK_EQ(mesh_dims.size(), 2);
for (int64_t i = 0; i < rhs_space_dims_.size(); ++i) {
for (int64_t j = i + 1; j < rhs_space_dims_.size(); ++j) {
if (!CheckDims(rhs_, rhs_space_dims_,
{{i, mesh_dims[0]}, {j, mesh_dims[1]}}))
if (!CheckDims(rhs_, {{rhs_space_dims_[i], mesh_dims[0]},
{rhs_space_dims_[j], mesh_dims[1]}}))
continue;
std::string name = absl::StrFormat("RSS = RR x RSS @ {%s}",
absl::StrJoin(mesh_dims, ","));
Expand Down Expand Up @@ -229,8 +226,8 @@ class DotHandler : public HandlerBase {
absl::StrJoin(mesh_dims, ","), mesh_dims[1]);
for (int64_t i = 0; i < lhs_space_dims_.size(); ++i) {
for (int64_t j = 0; j < lhs_con_dims_.size(); ++j) {
if (!CheckDims(lhs_, lhs_space_dims_, {{i, mesh_dims[0]}}) ||
!CheckDims(lhs_, lhs_con_dims_, {{j, mesh_dims[1]}}))
if (!CheckDims(lhs_, {{lhs_space_dims_[i], mesh_dims[0]},
{lhs_con_dims_[j], mesh_dims[1]}}))
continue;

HloSharding output_spec = Tile(ins_->shape(), {space_base_dim_ + i},
Expand Down Expand Up @@ -259,8 +256,8 @@ class DotHandler : public HandlerBase {
absl::StrJoin(mesh_dims, ","), mesh_dims[0]);
for (int64_t i = 0; i < rhs_space_dims_.size(); ++i) {
for (int64_t j = 0; j < lhs_con_dims_.size(); ++j) {
if (!CheckDims(rhs_, rhs_space_dims_, {{i, mesh_dims[1]}}) ||
!CheckDims(lhs_, lhs_con_dims_, {{j, mesh_dims[0]}}))
if (!CheckDims(rhs_, {{rhs_space_dims_[i], mesh_dims[1]}}) ||
!CheckDims(lhs_, {{lhs_con_dims_[j], mesh_dims[0]}}))
continue;
HloSharding output_spec =
Tile(ins_->shape(),
Expand Down Expand Up @@ -288,7 +285,7 @@ class DotHandler : public HandlerBase {
[](int64_t size) { return size > 1; }) == 1) {
for (int64_t i = 0; i < lhs_batch_dims_.size(); ++i) {
for (int64_t j = 0; j < device_mesh_.num_dimensions(); ++j) {
if (!CheckDims(lhs_, lhs_batch_dims_, {{i, j}})) continue;
if (!CheckDims(lhs_, {{lhs_batch_dims_[i], j}})) continue;
std::string name = absl::StrFormat("Sb_%d = Sb x Sb @ {%d}", i, j);
HloSharding output_spec = Tile(ins_->shape(), {i}, {j}, device_mesh_);
HloSharding lhs_spec =
Expand All @@ -306,8 +303,8 @@ class DotHandler : public HandlerBase {
DCHECK_EQ(mesh_dims.size(), 2);
if (lhs_batch_dims_.size() == 2 && device_mesh_.dim(mesh_dims[0]) > 1 &&
device_mesh_.dim(mesh_dims[1]) > 1) {
if (!CheckDims(lhs_, lhs_batch_dims_,
{{0, mesh_dims[0]}, {1, mesh_dims[1]}}))
if (!CheckDims(lhs_, {{lhs_batch_dims_[0], mesh_dims[0]},
{lhs_batch_dims_[1], mesh_dims[1]}}))
return;
std::string name =
absl::StrFormat("Sb = Sb x Sb @ {%s}", absl::StrJoin(mesh_dims, ","));
Expand All @@ -331,8 +328,8 @@ class DotHandler : public HandlerBase {
absl::StrJoin(mesh_dims, ","));
for (int64_t i = 0; i < lhs_space_dims_.size(); ++i) {
for (int64_t j = 0; j < lhs_batch_dims_.size(); ++j) {
if (!CheckDims(lhs_, lhs_space_dims_, {{i, mesh_dims[0]}}) ||
!CheckDims(lhs_, lhs_batch_dims_, {{j, mesh_dims[1]}}))
if (!CheckDims(lhs_, {{lhs_space_dims_[i], mesh_dims[0]},
{lhs_batch_dims_[j], mesh_dims[1]}}))
continue;
HloSharding output_spec = Tile(
ins_->shape(), {j, space_base_dim_ + i}, mesh_dims, device_mesh_);
Expand All @@ -356,8 +353,8 @@ class DotHandler : public HandlerBase {
absl::StrJoin(mesh_dims, ","));
for (int64_t i = 0; i < rhs_space_dims_.size(); ++i) {
for (int64_t j = 0; j < lhs_batch_dims_.size(); ++j) {
if (!CheckDims(rhs_, rhs_space_dims_, {{i, mesh_dims[1]}}) ||
!CheckDims(lhs_, lhs_batch_dims_, {{j, mesh_dims[0]}}))
if (!CheckDims(rhs_, {{rhs_space_dims_[i], mesh_dims[1]}}) ||
!CheckDims(lhs_, {{lhs_batch_dims_[j], mesh_dims[0]}}))
continue;
HloSharding output_spec =
Tile(ins_->shape(),
Expand Down Expand Up @@ -385,8 +382,8 @@ class DotHandler : public HandlerBase {
absl::StrJoin(mesh_dims, ","), mesh_dims[1]);
for (int64_t i = 0; i < lhs_con_dims_.size(); ++i) {
for (int64_t j = 0; j < lhs_batch_dims_.size(); ++j) {
if (!CheckDims(lhs_, lhs_con_dims_, {{i, mesh_dims[1]}}) ||
!CheckDims(lhs_, lhs_batch_dims_, {{j, mesh_dims[0]}}))
if (!CheckDims(lhs_, {{lhs_con_dims_[i], mesh_dims[1]},
{lhs_batch_dims_[j], mesh_dims[0]}}))
continue;
HloSharding output_spec =
Tile(ins_->shape(), {j}, {mesh_dims[0]}, device_mesh_);
Expand Down Expand Up @@ -418,10 +415,10 @@ class DotHandler : public HandlerBase {
absl::StrJoin(mesh_dims, ","), absl::StrJoin(mesh_dims, ", "));
for (int64_t i = 0; i < lhs_con_dims_.size(); ++i) {
for (int64_t j = i + 1; j < lhs_con_dims_.size(); ++j) {
if (!CheckDims(lhs_, lhs_con_dims_,
{{i, mesh_dims[0]}, {j, mesh_dims[1]}}) ||
!CheckDims(rhs_, rhs_con_dims_,
{{i, mesh_dims[0]}, {j, mesh_dims[1]}}))
if (!CheckDims(lhs_, {{lhs_con_dims_[i], mesh_dims[0]},
{lhs_con_dims_[j], mesh_dims[1]}}) ||
!CheckDims(rhs_, {{rhs_con_dims_[i], mesh_dims[0]},
{rhs_con_dims_[j], mesh_dims[1]}}))
continue;
HloSharding output_spec = HloSharding::Replicate();
HloSharding lhs_spec =
Expand All @@ -447,7 +444,7 @@ class DotHandler : public HandlerBase {
std::string name = absl::StrFormat("RR = RS x SR @ {%d} (allreduce @ %d)",
mesh_dims[0], mesh_dims[0]);
for (int64_t i = 0; i < lhs_con_dims_.size(); ++i) {
if (!CheckDims(lhs_, lhs_con_dims_, {{i, mesh_dims[0]}})) continue;
if (!CheckDims(lhs_, {{lhs_con_dims_[i], mesh_dims[0]}})) continue;
HloSharding output_spec = HloSharding::Replicate();
HloSharding lhs_spec = Tile(lhs_->shape(), {lhs_con_dims_[i]},
{mesh_dims[0]}, device_mesh_);
Expand Down Expand Up @@ -524,8 +521,8 @@ class DotHandler : public HandlerBase {
[](int64_t size) { return size > 1; }) > 1) {
int mesh_dim = 0;
for (int64_t i = 0; i < lhs_batch_dims_.size(); ++i) {
if (!CheckDims(lhs_, lhs_batch_dims_, {{i, mesh_dim}}) ||
!CheckDims(rhs_, rhs_batch_dims_, {{i, mesh_dim}}))
if (!CheckDims(lhs_, {{lhs_batch_dims_[i], mesh_dim}}) ||
!CheckDims(rhs_, {{rhs_batch_dims_[i], mesh_dim}}))
continue;
std::string name =
absl::StrFormat("Sb_%d = Sb x Sb @ {%d} 1d", i, mesh_dim);
Expand Down

0 comments on commit a008c44

Please sign in to comment.