Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
BiteTheDDDDt committed Jan 26, 2025
1 parent 34ed2e2 commit 35aa379
Show file tree
Hide file tree
Showing 4 changed files with 79 additions and 30 deletions.
45 changes: 26 additions & 19 deletions be/src/vec/functions/array/function_array_flatten.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,36 +44,43 @@ class FunctionArrayFlatten : public IFunction {
size_t get_number_of_arguments() const override { return 1; }

DataTypePtr get_return_type_impl(const DataTypes& arguments) const override {
DataTypePtr arg_0 = arguments[0];
DCHECK(is_array(arg_0));
return remove_nullable(assert_cast<const DataTypeArray*>(arg_0.get())->get_nested_type());
DataTypePtr arg = arguments[0];
while (is_array(arg)) {
arg = remove_nullable(assert_cast<const DataTypeArray*>(arg.get())->get_nested_type());
}
return std::make_shared<DataTypeArray>(make_nullable(arg));
}

Status execute_impl(FunctionContext* context, Block& block, const ColumnNumbers& arguments,
uint32_t result, size_t input_rows_count) const override {
auto src_column =
block.get_by_position(arguments[0]).column->convert_to_full_column_if_const();
const auto& src_column_array =
assert_cast<const ColumnArray&>(*remove_nullable(src_column));
const auto& nested_src_column_array =
assert_cast<const ColumnArray&>(*remove_nullable(src_column_array.get_data_ptr()));
auto* src_column_array_ptr =
assert_cast<ColumnArray*>(remove_nullable(src_column)->assume_mutable().get());
ColumnArray* nested_src_column_array_ptr = src_column_array_ptr;

DataTypePtr src_column_type = block.get_by_position(arguments[0]).type;
auto nested_type = assert_cast<const DataTypeArray&>(*src_column_type).get_nested_type();
auto result_column_offsets =
assert_cast<ColumnArray::ColumnOffsets&>(src_column_array_ptr->get_offsets_column())
.clone();
auto* offsets = assert_cast<ColumnArray::ColumnOffsets*>(result_column_offsets.get())
->get_data()
.data();

auto result_column_offsets = ColumnArray::ColumnOffsets::create(input_rows_count);
auto* offsets = result_column_offsets->get_data().data();
while (src_column_array_ptr->get_data_ptr()->is_column_array()) {
nested_src_column_array_ptr = assert_cast<ColumnArray*>(
remove_nullable(src_column_array_ptr->get_data_ptr())->assume_mutable().get());

for (size_t i = 0; i < input_rows_count; ++i) {
offsets[i] =
nested_src_column_array.get_offsets()[src_column_array.get_offsets()[i] - 1];
for (size_t i = 0; i < input_rows_count; ++i) {
offsets[i] = nested_src_column_array_ptr->get_offsets()[offsets[i] - 1];
}
src_column_array_ptr = nested_src_column_array_ptr;
}

block.replace_by_position(result,
ColumnArray::create(assert_cast<const ColumnNullable&>(
nested_src_column_array.get_data())
.clone(),
std::move(result_column_offsets)));
block.replace_by_position(
result, ColumnArray::create(assert_cast<const ColumnNullable&>(
nested_src_column_array_ptr->get_data())
.clone(),
std::move(result_column_offsets)));
return Status::OK();
}
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,12 @@

import org.apache.doris.catalog.FunctionSignature;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.functions.CustomSignature;
import org.apache.doris.nereids.trees.expressions.functions.ExplicitlyCastableSignature;
import org.apache.doris.nereids.trees.expressions.functions.PropagateNullable;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.ArrayType;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.coercion.AnyDataType;
import org.apache.doris.nereids.types.coercion.FollowToAnyDataType;

Expand All @@ -35,11 +37,7 @@
* ScalarFunction 'array_flatten'
*/
public class ArrayFlatten extends ScalarFunction
implements ExplicitlyCastableSignature, PropagateNullable {

public static final List<FunctionSignature> SIGNATURES = ImmutableList.of(
FunctionSignature.ret(ArrayType.of(new FollowToAnyDataType(0)))
.args(ArrayType.of(ArrayType.of(new AnyDataType(0)))));
implements CustomSignature, PropagateNullable {

/**
* constructor with 1 arguments.
Expand All @@ -48,6 +46,15 @@ public ArrayFlatten(Expression arg) {
super("array_flatten", arg);
}

@Override
public FunctionSignature customSignature() {
DataType dataType = getArgument(0).getDataType();
while (dataType instanceof ArrayType) {
dataType = ((ArrayType)dataType).getItemType();
}
return FunctionSignature.ret(ArrayType.of(dataType)).args(getArgument(0).getDataType());
}

/**
* withChildren.
*/
Expand All @@ -61,10 +68,4 @@ public ArrayFlatten withChildren(List<Expression> children) {
public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
return visitor.visitArrayFlatten(this, context);
}

@Override
public List<FunctionSignature> getSignatures() {
return SIGNATURES;
}

}
30 changes: 30 additions & 0 deletions regression-test/data/nereids_function_p0/scalar_function/Array.out
Original file line number Diff line number Diff line change
Expand Up @@ -16923,3 +16923,33 @@ false false
-- !sql --
false false

-- !sql --
[1, 2, 3, 4, 5]

-- !sql --
[]

-- !sql --
[1]

-- !sql --
[1, 2, 3]

-- !sql --
[1, 2, 3, 4, 5]

-- !sql --
[null, null]

-- !sql --
[1, 2, 3, 4, 5]

-- !sql --
[1, 2, 3, 4, 5]

-- !sql --
[1, 2, 3, 4, 5, 6, 7, 8, 9]

-- !sql --
[1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]

Original file line number Diff line number Diff line change
Expand Up @@ -1424,4 +1424,15 @@ suite("nereids_scalar_fn_Array") {
// map_contains_value
qt_sql """ select map_contains_value(map(1,1), 257), map_contains_value(map(1,2), 258);"""

qt_sql """select array_flatten([[1,2,3],[4,5]]);"""
qt_sql """select array_flatten([[],[]]);"""
qt_sql """select array_flatten([[1],[]]);"""
qt_sql """select array_flatten([[1,2,3],null]);"""
qt_sql """select array_flatten([[1,2,3],null,[4,5]]);"""
qt_sql """select array_flatten([null,null]);"""
qt_sql """select array_flatten([[1,2,3,4,5]]);"""
qt_sql """select array_flatten([[[1,2,3,4,5]]]);;"""
qt_sql """select array_flatten([ [[1,2,3,4,5]],[[6,7],[8,9]] ]);"""
qt_sql """select array_flatten([[[[[[1,2,3,4,5],[6,7],[8,9],[10,11],[12]]]]]]);"""

}

0 comments on commit 35aa379

Please sign in to comment.