Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow Passing in Size of Dynamic Dimensions to Inference Function #1025

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
88 changes: 46 additions & 42 deletions fx2ait/fx2ait/find_batch_size_dim.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,55 +21,59 @@ def find_batch_size_dim(
inputs: Any,
can_non_first_dim_be_dynamic: bool = True,
can_dim_value_one_be_dynamic: bool = True,
dynamic_size: int = -1,
# pyre-fixme Invalid type [31]
) -> []:
if isinstance(inputs, torch.Tensor) or len(inputs) <= 1:
return [0]
shapes = [i.shape for i in inputs]
frequency_map = {}
position_scores = {}
first_dims = set()
for shape in shapes:
if len(shape) < 2:
# By pass for rank-1 tensors. MRS model has rank-1 tensor carry no batch_size info
continue
# Dedup shape value for single tensor
first_dims.add(shape[0])
seen_dims = set()
valid_len = len(shape) if can_non_first_dim_be_dynamic else 1
for i in range(valid_len):
dim = shape[i]
if dim not in seen_dims:
frequency_map[dim] = frequency_map.get(dim, 0) + 1
position_scores[dim] = position_scores.get(dim, 0) + i
seen_dims.add(dim)

if len(first_dims) == 1:
# first dim is the same in every input: we use it as batch_size
batch_size = first_dims.pop()
elif frequency_map:
# first dims are different: we use the most frequent dim as batch_size
# if there is more than 1 most frequent dim, we choose the one with the
# lowest position score (i.e., the leftmost of the most frequent ones)
sorted_frequency = sorted(
frequency_map.items(),
key=lambda x: (-x[1], position_scores[x[0]]),
)
if len(sorted_frequency) > 1:
if not can_dim_value_one_be_dynamic and sorted_frequency[0][0] == 1:
# It's often that dim value one indicates a non-dynamic dimension.
# If the user says so, we pick the second most frequent value.
batch_size = sorted_frequency[1][0]
if dynamic_size > 0:
batch_size = dynamic_size
else:
shapes = [i.shape for i in inputs]
frequency_map = {}
position_scores = {}
first_dims = set()
for shape in shapes:
if len(shape) < 2:
# By pass for rank-1 tensors. MRS model has rank-1 tensor carry no batch_size info
continue
# Dedup shape value for single tensor
first_dims.add(shape[0])
seen_dims = set()
valid_len = len(shape) if can_non_first_dim_be_dynamic else 1
for i in range(valid_len):
dim = shape[i]
if dim not in seen_dims:
frequency_map[dim] = frequency_map.get(dim, 0) + 1
position_scores[dim] = position_scores.get(dim, 0) + i
seen_dims.add(dim)
if len(first_dims) == 1:
# first dim is the same in every input: we use it as batch_size
batch_size = first_dims.pop()
elif frequency_map:
# first dims are different: we use the most frequent dim as batch_size
# if there is more than 1 most frequent dim, we choose the one with the
# lowest position score (i.e., the leftmost of the most frequent ones)
sorted_frequency = sorted(
frequency_map.items(),
key=lambda x: (-x[1], position_scores[x[0]]),
)
if len(sorted_frequency) > 1:
if not can_dim_value_one_be_dynamic and sorted_frequency[0][0] == 1:
# It's often that dim value one indicates a non-dynamic dimension.
# If the user says so, we pick the second most frequent value.
batch_size = sorted_frequency[1][0]
else:
batch_size = sorted_frequency[0][0]
else:
batch_size = sorted_frequency[0][0]
if not can_dim_value_one_be_dynamic and sorted_frequency[0][0] == 1:
batch_size = -1
else:
batch_size = sorted_frequency[0][0]
else:
if not can_dim_value_one_be_dynamic and sorted_frequency[0][0] == 1:
batch_size = -1
else:
batch_size = sorted_frequency[0][0]
else:
# no dims to sort: no batch_size
batch_size = -1
# no dims to sort: no batch_size
batch_size = -1

bs_dim = []
for i in inputs:
Expand Down
1 change: 1 addition & 0 deletions fx2ait/fx2ait/lower/lower_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ class LowerSettings:
name: str = ""
dll_name: str = "ait_engine.so"
dynamic_profile_strategy: DynamicProfileStrategy = DynamicProfileStrategy.MAX
dynamic_size: int = -1
profile_devs: Any = None
# If None, infer the dtypes from the sample inputs.
precision: Optional[LowerPrecision] = LowerPrecision.FP16
Expand Down
6 changes: 5 additions & 1 deletion fx2ait/fx2ait/tensor_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,10 +477,14 @@ def find_batch_size_dim(
inputs: Any,
can_non_first_dim_be_dynamic: bool = True,
can_dim_value_one_be_dynamic: bool = True,
dynamic_size: int = -1,
# pyre-fixme Invalid type [31]
) -> []:
return find_batch_size_dim_impl(
inputs, can_non_first_dim_be_dynamic, can_dim_value_one_be_dynamic
inputs,
can_non_first_dim_be_dynamic,
can_dim_value_one_be_dynamic,
dynamic_size,
)

@classmethod
Expand Down
Loading