[JAX FE] : Implement jax.lax.iota operation #28221

Open: wants to merge 5 commits into base: master

@11happy (Contributor) commented Dec 28, 2024

Overview:

Testing:

  • TODO: fix. I am currently running into some assertion failures while implementing the tests.

Screenshot from 2024-12-29 01-10-02

CC:

Signed-off-by: 11happy <[email protected]>
@11happy 11happy requested review from a team as code owners December 28, 2024 19:40
@github-actions github-actions bot added category: TF FE OpenVINO TensorFlow FrontEnd category: JAX FE OpenVINO JAX FrontEnd labels Dec 28, 2024
@sys-openvino-ci sys-openvino-ci added the ExternalPR External contributor label Dec 28, 2024
@11happy (Contributor Author) commented Dec 28, 2024

Hello @rkazants, could you please help me with the tests?
Thank you!

@rkazants rkazants self-assigned this Dec 30, 2024
@rkazants (Member) commented:

build_jenkins

@rkazants (Member) commented Jan 1, 2025

@11happy, please fix the failing CI jobs.

@rkazants rkazants added this to the 2025.0 milestone Jan 1, 2025
Signed-off-by: 11happy <[email protected]>
@rkazants (Member) commented Jan 2, 2025

build_jenkins

@rkazants (Member) commented Jan 3, 2025

Hi @11happy, please fix the build:

/__w/openvino/openvino/openvino/bin/aarch64/Release/libopenvino_util.a  -ldl  vcpkg_installed/arm64-android/lib/libpugixml.a  -latomic -lm && :
ld: error: undefined symbol: ov::element::Type ov::frontend::jax::NodeContext::const_named_param<ov::element::Type>(std::__ndk1::basic_string<char, std::__ndk1::char_traits<char>, std::__ndk1::allocator<char> > const&) const
>>> referenced by iota.cpp:19 (src/frontends/jax/src/op/iota.cpp:19)
>>>               src/frontends/jax/src/CMakeFiles/openvino_jax_frontend.dir/op/iota.cpp.o:(ov::frontend::jax::op::translate_iota(ov::frontend::jax::NodeContext const&))
clang++: error: linker command failed with exit code 1 (use -v to see invocation)

There is no instantiation of const_named_param for the ov::element::Type type, so you need to define one.

Best regards,
Roman
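The linker error above is the usual symptom of a member function template that is declared for any `T` but only defined through explicit specializations for some types: the call compiles, but no translation unit provides the symbol. The standalone sketch below reproduces that pattern with hypothetical stand-ins (`ParamContext` for `NodeContext`, `ElementType` for `ov::element::Type`, and an assumed integer encoding of the dtype); it only illustrates how an explicit specialization supplies the missing symbol, not how the JAX frontend should actually obtain the type.

```cpp
#include <cstdint>
#include <map>
#include <stdexcept>
#include <string>
#include <utility>

// Hypothetical stand-in for ov::element::Type.
enum class ElementType { i32, i64, f32 };

// Hypothetical stand-in for NodeContext: the member template is only declared here.
class ParamContext {
public:
    explicit ParamContext(std::map<std::string, int64_t> params) : m_params(std::move(params)) {}

    // Declared for any T, defined only via explicit specializations below.
    // Calling it with a T that has no specialization compiles but fails at link time.
    template <typename T>
    T const_named_param(const std::string& name) const;

private:
    std::map<std::string, int64_t> m_params;
};

// Specialization for integer parameters (analogous to reading "size").
template <>
int64_t ParamContext::const_named_param<int64_t>(const std::string& name) const {
    auto it = m_params.find(name);
    if (it == m_params.end()) throw std::runtime_error("missing param: " + name);
    return it->second;
}

// The missing piece: an explicit specialization for the type-valued parameter.
// Assumption for this sketch only: the dtype is stored as an integer code.
template <>
ElementType ParamContext::const_named_param<ElementType>(const std::string& name) const {
    switch (const_named_param<int64_t>(name)) {
        case 0: return ElementType::i32;
        case 1: return ElementType::i64;
        case 2: return ElementType::f32;
        default: throw std::runtime_error("unsupported dtype code for: " + name);
    }
}
```

With both specializations defined in the same library as the callers, `ctx.const_named_param<ElementType>("dtype")` links as expected.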

@11happy (Contributor Author) commented Jan 4, 2025

I have defined it; however, I am not sure if this is the correct way.
Screenshot from 2025-01-04 17-58-59
I am assuming this creates an uninitialized constant, and that we can then get the type via get_element_type().

@11happy (Contributor Author) commented Jan 4, 2025

Also, I wanted to ask: do we have any pre-commit hooks or something similar? Most of the time when I push changes, the build fails because of formatting errors.

@11happy (Contributor Author) commented Jan 6, 2025

@rkazants humble ping!
Thank you

Comment on lines 175 to 180
template <>
ov::element::Type NodeContext::const_named_param<ov::element::Type>(const std::string& name) const {
    auto c = get_constant_from_params(*this, name);
    return c->get_element_type();
}

Member

I think this is not the correct way to retrieve the output type for the iota operation, because it tries to create a constant whose value stands in for the type.
We need a different mechanism to parse and pass along attribute values that store a type.
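One possible shape for such a mechanism, sketched with hypothetical names (`AttributeMap`, `parse_dtype`, `ElementType`) rather than the real OpenVINO JAX frontend API: keep type-valued attributes in a typed attribute store instead of materializing them as constants, and convert the JAX dtype name once when the operation's parameters are parsed. The dtype names in the table are assumed for illustration.

```cpp
#include <cstdint>
#include <map>
#include <stdexcept>
#include <string>
#include <variant>

// Hypothetical stand-in for ov::element::Type.
enum class ElementType { i32, i64, f32 };

// Hypothetical: map a JAX dtype name (e.g. "int32") to an element type.
ElementType parse_dtype(const std::string& name) {
    static const std::map<std::string, ElementType> table = {
        {"int32", ElementType::i32}, {"int64", ElementType::i64}, {"float32", ElementType::f32}};
    auto it = table.find(name);
    if (it == table.end()) throw std::invalid_argument("unsupported dtype: " + name);
    return it->second;
}

// Attribute values keep their C++ type; nothing is turned into a constant node.
using AttributeValue = std::variant<int64_t, ElementType>;

class AttributeMap {
public:
    void set(const std::string& name, AttributeValue value) { m_attrs[name] = value; }

    template <typename T>
    T get(const std::string& name) const {
        auto it = m_attrs.find(name);
        if (it == m_attrs.end()) throw std::out_of_range("missing attribute: " + name);
        return std::get<T>(it->second);  // throws std::bad_variant_access on a type mismatch
    }

private:
    std::map<std::string, AttributeValue> m_attrs;
};
```

The design choice here is that a wrong request (asking for a type where an integer is stored) fails loudly at the lookup, instead of silently reading the element type of a placeholder constant.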

OutputVector translate_iota(const NodeContext& context) {
    num_inputs_check(context, 2, 2);
    auto dtype = context.const_named_param<ov::element::Type>("dtype");
    auto size = context.const_named_param<int64_t>("size");

Member

Are we sure that size is an attribute of the iota operation and cannot be a non-constant input?
Can you please experiment with this and try to create a JAX graph where size is an input to the iota operation?

Contributor Author

Sure thing! I will experiment with it on my end and update soon.
Thanks


class TestIota(JaxLayerTest):
    def _prepare_input(self):
        return (self.input_type, self.input_shape)

Member

This is not correct, because you generate input data containing the dtype value, as if it were fed to the neural network.
dtype is used only for model creation, nothing more.
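Because dtype and size only parameterize model creation, the expected result of iota is fully determined by them: `jax.lax.iota(dtype, size)` returns the 1-D array 0, 1, ..., size - 1 in the given dtype. A small C++ reference for that expected output, using `std::iota` from `<numeric>`, is sketched below; the helper name `reference_iota` is hypothetical and not part of the PR.

```cpp
#include <cstddef>
#include <cstdint>
#include <numeric>
#include <stdexcept>
#include <vector>

// Reference semantics of jax.lax.iota(dtype, size) for a 1-D result:
// the values 0, 1, ..., size - 1 stored as element type T.
template <typename T>
std::vector<T> reference_iota(int64_t size) {
    if (size < 0) throw std::invalid_argument("size must be non-negative");
    std::vector<T> out(static_cast<std::size_t>(size));
    std::iota(out.begin(), out.end(), T{0});  // fills 0, 1, 2, ...
    return out;
}
```

For example, `reference_iota<int32_t>(4)` yields `{0, 1, 2, 3}`, which no input tensor has to supply.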

Signed-off-by: 11happy <[email protected]>
@11happy (Contributor Author) commented Jan 20, 2025

@rkazants I think we don't need to define const_named_param, since there is a convert_dtype helper, which I referenced from the argmax implementation, and it works. All tests pass on my end.
Screenshot from 2025-01-20 20-57-42

Development

Successfully merging this pull request may close these issues.

[Good First Issue][JAX FE]: Support jax.lax.iota operation for JAX
3 participants