[JAX FE] : Implement jax.lax.iota operation #28221
base: master
Signed-off-by: 11happy <[email protected]>
Hello @rkazants, could you please help me with the tests?
build_jenkins
@11happy, please fix the failing CI jobs.
Hi @11happy, please fix the build:
There is no instantiation for
Best regards,
Also, I wanted to ask: do we have any pre-commit hooks or something similar? Most of the time when I push changes, the build fails because of formatting errors.
@rkazants humble ping!
template <>
ov::element::Type NodeContext::const_named_param<ov::element::Type>(const std::string& name) const {
    auto c = get_constant_from_params(*this, name);
    return c->get_element_type();
}
I don't think this is the correct way to retrieve the output type for the iota operation, because it tries to create a constant that stores a value equal to the type. We need a different mechanism to parse and transfer attribute values that store a type.
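To illustrate this point, the following sketch (assuming a recent JAX release; it is not part of this PR) traces jax.lax.iota with jax.make_jaxpr and inspects the resulting equation. In the jaxpr, the element type appears as an equation parameter holding a NumPy dtype object, and the iota equation has no tensor inputs, which suggests dtype should be read as a type-valued attribute rather than materialized as a constant. The function name iota_model is hypothetical.

```python
import jax
import jax.numpy as jnp
import numpy as np


def iota_model():
    # Hypothetical model: dtype and size are plain Python values fixed at trace time.
    return jax.lax.iota(jnp.float32, 4)


# Trace the function into a jaxpr (no device execution needed).
closed_jaxpr = jax.make_jaxpr(iota_model)()

# Find the equation produced by lax.iota.
iota_eqn = next(e for e in closed_jaxpr.jaxpr.eqns if e.primitive.name == "iota")

# The element type is an equation parameter holding a NumPy dtype object,
# not a value carried by a constant tensor.
print("dtype param:", iota_eqn.params["dtype"])
print("is a NumPy dtype:", isinstance(iota_eqn.params["dtype"], np.dtype))
print("shape param:", iota_eqn.params["shape"])

# The equation consumes no tensor inputs at all.
print("tensor inputs:", len(iota_eqn.invars))
```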
OutputVector translate_iota(const NodeContext& context) {
    num_inputs_check(context, 2, 2);
    auto dtype = context.const_named_param<ov::element::Type>("dtype");
    auto size = context.const_named_param<int64_t>("size");
Are we sure that size is an attribute of the iota operation and cannot be a non-constant input?
Can you please experiment with it and try to create a JAX graph where size is an input to the iota operation?
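A minimal sketch of such an experiment, assuming default JAX settings (the experimental dynamic-shapes mode is off): it first passes size as a traced argument under jax.jit, then marks it static with static_argnums. Under these assumptions, we expect the traced case to fail with a TypeError, because JAX requires concrete shapes, and the static case to record size only inside the iota equation's shape parameter. The helper iota_with_size is hypothetical.

```python
import jax
import jax.numpy as jnp


def iota_with_size(size):
    # Hypothetical helper: size arrives as a function argument,
    # so under jax.jit it becomes a tracer unless marked static.
    return jax.lax.iota(jnp.int32, size)


# Experiment 1: size as a traced (non-constant) input.
try:
    jax.jit(iota_with_size)(5)
    size_can_be_traced = True
except TypeError as err:
    # With default settings JAX needs a concrete shape, so a traced size is rejected.
    size_can_be_traced = False
    print("traced size rejected:", type(err).__name__)

# Experiment 2: size marked static, i.e. fixed when the graph is built.
static_iota = jax.jit(iota_with_size, static_argnums=0)
print(static_iota(5))

# In the traced graph, size survives only as part of the iota 'shape' parameter,
# and the static argument does not become a graph input.
jaxpr = jax.make_jaxpr(iota_with_size, static_argnums=0)(5)
iota_eqn = next(e for e in jaxpr.jaxpr.eqns if e.primitive.name == "iota")
print("shape param:", iota_eqn.params["shape"])
print("graph inputs:", len(jaxpr.jaxpr.invars))
```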
Sure thing! I will experiment with it on my end and update soon.
Thanks!
class TestIota(JaxLayerTest):
    def _prepare_input(self):
        return (self.input_type, self.input_shape)
This is not correct, because you generate input data from the dtype value, as if dtype were passed to the neural network as an input.
dtype is used only for model creation, nothing more.
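To make the distinction concrete, here is a plain-JAX sketch. It does not use the JaxLayerTest API, whose hooks are not shown here, and make_iota_model is a hypothetical helper. dtype and size are consumed when the model is built, so the traced graph has no inputs and there is no input data to generate.

```python
import jax
import jax.numpy as jnp


def make_iota_model(dtype, size):
    # Hypothetical helper: dtype and size are consumed here, while the model
    # is being built; they never reach the model as runtime tensors.
    def model():
        return jax.lax.iota(dtype, size)

    return model


model = make_iota_model(jnp.float32, 6)
jaxpr = jax.make_jaxpr(model)()

# The traced graph has no inputs, so there is no input data to generate.
print("graph inputs:", len(jaxpr.jaxpr.invars))

out = model()
print(out.dtype, out.shape)
```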
@rkazants I think we don't need to define