Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extract tf.function decorator keyword arguments #144

Open
wants to merge 78 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 10 commits
Commits
Show all changes
78 commits
Select commit Hold shift + click to select a range
39c7fb5
update
tatianacv Jan 31, 2023
f21e423
progress
tatianacv Jan 31, 2023
747d5c5
Progress
tatianacv Jan 31, 2023
8f8c0b1
Progress
tatianacv Jan 31, 2023
4ffa27a
Progress
tatianacv Jan 31, 2023
21a8004
Progress
tatianacv Jan 31, 2023
89e4dd0
Progress
tatianacv Feb 1, 2023
b6f4e9c
Progress
tatianacv Feb 1, 2023
b9c8107
adding more testing
tatianacv Feb 1, 2023
d6b110c
Fixing trailing whitespace of tests
tatianacv Feb 1, 2023
26fd8f1
Fix to test
tatianacv Feb 2, 2023
9a4cb62
adding other params
tatianacv Feb 2, 2023
d0d0875
Progress
tatianacv Feb 2, 2023
015965a
Modifying test resource
tatianacv Feb 2, 2023
6102351
Modifying test resource
tatianacv Feb 2, 2023
a62af55
Renaming variable
tatianacv Feb 5, 2023
f7a9373
Formatting
tatianacv Feb 5, 2023
c9ee308
Adding exceptions
tatianacv Feb 6, 2023
0ab5892
Add more info to exception
tatianacv Feb 6, 2023
db5147f
Adding new line to test resource
tatianacv Feb 6, 2023
b9bd9be
Restructuring exceptions
tatianacv Feb 6, 2023
9db5396
Adding new line to test resource
tatianacv Feb 6, 2023
ec84857
Fixing variable name
tatianacv Feb 6, 2023
74b4bc4
Progress
tatianacv Feb 7, 2023
e12fb16
Removing unnecessary else
tatianacv Feb 7, 2023
7aef65b
Progress
tatianacv Feb 8, 2023
9f49fea
Update
tatianacv Feb 8, 2023
279089b
Progress
tatianacv Feb 8, 2023
ef5d5bb
Progress
tatianacv Feb 8, 2023
64bd2b6
progress
tatianacv Feb 8, 2023
f53e997
Progress
tatianacv Feb 8, 2023
1df7219
Cleanup
tatianacv Feb 8, 2023
1b2ac89
making sure we are only checking literals
tatianacv Feb 10, 2023
c4184d4
Reorganization
tatianacv Feb 10, 2023
d2ae4c9
Renaming
tatianacv Feb 10, 2023
b48960d
Adding documentation
tatianacv Feb 10, 2023
1eb7274
Adding a test where there should be an exception
tatianacv Feb 10, 2023
9195db7
Formatting
tatianacv Feb 10, 2023
eba9736
Test
tatianacv Feb 10, 2023
dd82611
Revert "Test"
tatianacv Feb 10, 2023
df97664
Update
tatianacv Feb 10, 2023
54d78be
Adding new tests
tatianacv Feb 10, 2023
10cb16b
Removing unnecesary files, and adding new line
tatianacv Feb 10, 2023
303a117
Adding newline
tatianacv Feb 10, 2023
fc90f50
Update
tatianacv Feb 10, 2023
ecffbc5
Adding more tests
tatianacv Feb 14, 2023
3577a39
Trailing whitespace fix
tatianacv Feb 14, 2023
0c72e7d
Adding comments
tatianacv Feb 14, 2023
94733e6
Make sure we are dealing with TensorSpec (input_signature)
tatianacv Feb 14, 2023
41bfe8d
Revert "Make sure we are dealing with TensorSpec (input_signature)"
tatianacv Feb 15, 2023
9003639
Revert "Revert "Make sure we are dealing with TensorSpec (input_signa…
tatianacv Feb 24, 2023
6fd42b2
Fix build
tatianacv Feb 24, 2023
1682088
Adding comments
tatianacv Feb 27, 2023
02ef49b
Merge branch 'main' into 136-extract-tffunction-decorator-arguments
tatianacv Mar 4, 2023
ece2522
Update
tatianacv Mar 6, 2023
93f4e9c
Restructuring
tatianacv Mar 9, 2023
af61fe5
Restructure
tatianacv Mar 9, 2023
247e1d6
Restructure
tatianacv Mar 9, 2023
9b764d9
Remove redundancy
tatianacv Mar 9, 2023
bc3ff06
Restructuring
tatianacv Mar 10, 2023
cb23e2f
Adding more information
tatianacv Mar 10, 2023
79d0ff3
update
tatianacv Mar 13, 2023
2c93c3f
Merge branch 'main' into 136-extract-tffunction-decorator-arguments
tatianacv Mar 16, 2023
e690e4f
Merge branch 'main' into 136-extract-tffunction-decorator-arguments
khatchad Mar 21, 2023
bc534a4
Merge branch 'main' into 136-extract-tffunction-decorator-arguments
tatianacv Mar 21, 2023
9163368
Fixing asserts
tatianacv Mar 22, 2023
09d52cb
Merge branch 'main' into 136-extract-tffunction-decorator-arguments
tatianacv Mar 22, 2023
d845f0a
Progress
tatianacv Mar 23, 2023
6a13a5e
progress
tatianacv Mar 23, 2023
4c79a5a
Progress
tatianacv Mar 23, 2023
720bba7
Merge branch '136-extract-tffunction-decorator-arguments' of https://…
tatianacv Mar 23, 2023
5d63094
Changing dtype
tatianacv Mar 23, 2023
0c345f0
fixing comments
tatianacv Mar 24, 2023
44b8b7b
Merge branch 'main' into 136-extract-tffunction-decorator-arguments
khatchad Mar 28, 2023
c5b576b
Merge branch 'main' into 136-extract-tffunction-decorator-arguments
tatianacv Mar 31, 2023
f2238cf
Remove unnecessary comments, update comments, change tensorspecs shap…
tatianacv Mar 31, 2023
45bcfd2
Adding another test
tatianacv Mar 31, 2023
a8ddc42
remove file
tatianacv Mar 31, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
import org.python.pydev.parser.visitors.TypeInfo;

import edu.cuny.citytech.refactoring.common.core.RefactorableProgramEntity;
import edu.cuny.hunter.hybridize.core.analysis.TensorSpec.Dtype;

/**
* A representation of a Python function.
Expand All @@ -38,10 +39,14 @@
*/
public class Function extends RefactorableProgramEntity {

public enum TfAutographExperimentalFeature {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need a javadoc comment here explaining what this is and a URL for more info.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added.

ALL, AUTO_CONTROL_DEPS, ASSERT_STATEMENTS, BUILTIN_FUNCTIONS, EQUALITY_OPERATORS, LISTS, NAME_SCOPES
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What do each of these mean? What is the URL where we can find out more info?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added an URL in the javadoc

}

/**
* Parameters that may be passed to a tf.fuction decorator. Parameter descriptions found at:
* https://tensorflow.org/versions/r2.9/api_docs/python/tf/function Note: We are also parsing the deprecated parameters specified in
* the documentation. Users can still use these deprecated parameters. Therefore we need to be able to account for them. Please refer to
* https://tensorflow.org/versions/r2.9/api_docs/python/tf/function Note: We are also parsing the deprecated parameters specified in the
* documentation. Users can still use these deprecated parameters. Therefore we need to be able to account for them. Please refer to
* https://github.com/ponder-lab/Hybridize-Functions-Refactoring/wiki/tf.function-parameter's-version-information to see more
* information about the tf.function parameters according to the versions.
*/
Expand Down Expand Up @@ -81,13 +86,14 @@ public class HybridizationParameters {

/**
* Value of this {@link Function}'s {@link decoratorsType} parameter experimental_autograph_options. The values could be an optional
* tuple or value of tf.autograph.experimental.Feature values or None.
* tuple or value of tf.autograph.experimental.Feature values (e.g.
* <code>tf.autograph.experimental.Feature.EQUALITY_OPERATORS</code>) or None.
*/
private String experimentalAutographOptionsParam;
private java.util.List<TfAutographExperimentalFeature> experimentalAutographOptionsParam;

/**
* Value of this {@link Function}'s {@link decoratorsType} parameter experimental_implements. The value could be None or a name of a
* "known" function this implements.
* "known" function this implements (e.g. <code>embedded_matmul</code>).
*/
private String experimentalImplementsParam;

Expand All @@ -101,7 +107,7 @@ public class HybridizationParameters {
* Value of this {@link Function}'s {@link decoratorsType} parameter input_signature. The value could be None, or a possibly nested
* sequence of tf.TensorSpec objects specifying the shapes and dtypes of the Tensors that will be supplied to this function.
*/
private ArrayList<TensorSpec> inputSignatureParam;
private java.util.List<TensorSpec> inputSignatureParam;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. Why not just List?
  2. Why use List? Is ordering important? Can you have duplicates?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  1. We already have another import for List (org.python.pydev.parser.jython.ast.List). Therefore, I had to specify which List it was by doing java.util.List<TensorSpec>.
  2. Order is important, and we can have duplicates.


/**
* Value of this {@link Function}'s {@link decoratorsType} parameter jit_compile. The values could be None, False or True. null
Expand Down Expand Up @@ -151,10 +157,7 @@ public HybridizationParameters(IProgressMonitor monitor) throws BadLocationExcep
// Example of value: Name of function or None
if (keyword.value instanceof Name) {
Name value = (Name) keyword.value;
if (value.id == "None") // Checking only literals
// Default value
this.funcParam = null;
else
if (value.id != "None") // Checking only literals
throw new IllegalArgumentException("Unable to process " + FUNC + " argument.");
} else {
throw new IllegalArgumentException("Unable to process " + FUNC + " argument.");
Expand All @@ -164,10 +167,7 @@ public HybridizationParameters(IProgressMonitor monitor) throws BadLocationExcep
// Example of value: None
if (keyword.value instanceof Name) {
Name value = (Name) keyword.value;
if (value.id == "None") // Checking only literals
// Default value
this.inputSignatureParam = null;
else
if (value.id != "None") // Checking only literals
throw new IllegalArgumentException("Unable to process " + INPUT_SIGNATURE + " argument.");
// Example: (tf.TensorSpec(shape=[None], dtype=tf.float32),)
} else if (keyword.value instanceof Tuple) {
Expand All @@ -189,12 +189,9 @@ public HybridizationParameters(IProgressMonitor monitor) throws BadLocationExcep
// Example of value: True, False
if (keyword.value instanceof Name) {
Name value = (Name) keyword.value;
if (value.id == "True")// Checking only literals
// Default value
this.autoGraphParam = true;
else if (value.id == "False")
if (value.id == "False")
this.autoGraphParam = false;
else
else if (value.id != "True")
throw new IllegalArgumentException("Unable to process " + AUTOGRAPH + " argument.");
} else {
throw new IllegalArgumentException("Unable to process " + AUTOGRAPH + " argument.");
Expand All @@ -211,10 +208,7 @@ else if (value.id == "False")
this.jitCompileParam = true;
else if (value.id == "False")
this.jitCompileParam = false;
else if (value.id == "None")
// Default value
this.jitCompileParam = null;
else
else if (value.id != "None")
throw new IllegalArgumentException(
"Unable to process " + JIT_COMPILE + "/" + EXPERIMENTAL_COMPILE + " argument.");
} else {
Expand All @@ -231,10 +225,7 @@ else if (value.id == "None")
Name value = (Name) keyword.value;
if (value.id == "True") // Checking only literals
this.reduceRetracingParam = true;
else if (value.id == "False") // Checking only literals
// Default value
this.reduceRetracingParam = false;
else
else if (value.id != "False") // Checking only literals
throw new IllegalArgumentException(
"Unable to process " + REDUCE_RETRACING + "/" + EXPERIMENTAL_RELAX_SHAPES + " argument.");
} else {
Expand All @@ -250,45 +241,38 @@ else if (value.id == "False") // Checking only literals
// Example of value: None
} else if (keyword.value instanceof Name) {
Name value = (Name) keyword.value;
if (value.id == "None") // Checking only literals
// Default value
this.experimentalImplementsParam = null;
else
if (value.id != "None") // Checking only literals
throw new IllegalArgumentException("Unable to process " + EXPERIMENTAL_IMPLEMENTS + " argument.");
} else {
throw new IllegalArgumentException("Unable to process " + EXPERIMENTAL_IMPLEMENTS + " argument.");
}
} else if (name.id.equals(EXPERIMENTAL_AUTOGRAPH_OPTIONS)) {
java.util.List<TfAutographExperimentalFeature> autographExperimental = new ArrayList<>();
// Found parameter experimental_autograph_options
// Example of value: tf.autograph.experimental.Feature.EQUALITY_OPERATORS
if (keyword.value instanceof Attribute) {
Attribute keywordAttribute = (Attribute) keyword.value;
this.experimentalAutographOptionsParam = processAttributeForAutographOptions(keywordAttribute);
autographExperimental.add(processAttributeForAutographOptions(keywordAttribute));
this.experimentalAutographOptionsParam = autographExperimental;
// Example of value: (tf.autograph.experimental.Feature.EQUALITY_OPERATORS,
// tf.autograph.experimental.Feature.BUILTIN_FUNCTIONS)
} else if (keyword.value instanceof Tuple) {
Tuple keywordTuple = (Tuple) keyword.value;
exprType[] keywordExpr = keywordTuple.elts;
String finalTuple = "";
int count = 0;
for (exprType expr : keywordExpr) {
if (expr instanceof Attribute) {
Attribute keywordAttribute = (Attribute) expr;
if (count == 0)
finalTuple += processAttributeForAutographOptions(keywordAttribute);
else
finalTuple += ", " + processAttributeForAutographOptions(keywordAttribute);
autographExperimental.add(processAttributeForAutographOptions(keywordAttribute));
} else {
throw new IllegalArgumentException(
"Unable to process " + EXPERIMENTAL_AUTOGRAPH_OPTIONS + " arguments");
}
count++;
}
this.experimentalAutographOptionsParam = "(" + finalTuple + ")";
this.experimentalAutographOptionsParam = autographExperimental;
// Example of value: None
} else if (keyword.value instanceof Name) {
Name value = (Name) keyword.value;
if (value.id == "None") // Checking only literals
// Default value
this.experimentalAutographOptionsParam = null;
else
if (value.id != "None") // Checking only literals
throw new IllegalArgumentException(
"Unable to process " + EXPERIMENTAL_AUTOGRAPH_OPTIONS + " argument.");
} else {
Expand All @@ -300,14 +284,11 @@ else if (value.id == "False") // Checking only literals
// Example of value: True, False, None
if (keyword.value instanceof Name) {
Name value = (Name) keyword.value;
if (value.id == "None") // Checking only literals
// Default value
this.experimentaFollowTypeHintsParam = null;
else if (value.id == "True") // Checking only literals
if (value.id == "True") // Checking only literals
this.experimentaFollowTypeHintsParam = true;
else if (value.id == "False") // Checking only literals
this.experimentaFollowTypeHintsParam = false;
else
else if (value.id != "None")
throw new IllegalArgumentException(
"Unable to process " + EXPERIMENTAL_FOLLOW_TYPE_HINTS + " argument.");
} else {
Expand All @@ -325,27 +306,20 @@ else if (value.id == "False") // Checking only literals
*
* @return String of TensorSpec shape.
*/
private String processTupleOrListForShape(exprType[] exprTupleOrList) {
int count = 0;
String tempString = "";
private java.util.List<Integer> processTupleOrListForShape(exprType[] exprTupleOrList) {
java.util.List<Integer> shape = new ArrayList<>();
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

  • Javadoc says it's returning a String. Is that still true?
  • Why not just List?
  • I don't know how a list of integers represents tensor shape. Why a list? Doesn't a tensor shape just have two integers?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We already have another import for List (org.python.pydev.parser.jython.ast.List). Therefore, I had to specify which List it was by doing java.util.List. It is a list of dimensions (https://www.tensorflow.org/versions/r2.9/api_docs/python/tf/TensorShape) which can be integers or None.


for (exprType expr : exprTupleOrList) {
if (expr instanceof Num) {
if (count == 0)
tempString += ((Num) expr).num;
else
tempString += ", " + ((Num) expr).num;
count++;
shape.add(Integer.parseInt(((Num) expr).num));
} else if (expr instanceof Name) {
if (((Name) expr).id == "None") // Checking only literals
tempString = ((Name) expr).id;
else
if (((Name) expr).id != "None") // Checking only literals
throw new IllegalArgumentException("Unable to process " + INPUT_SIGNATURE + " argument.");
} else
throw new IllegalArgumentException("Unable to process " + INPUT_SIGNATURE + " argument.");
}

return tempString;
return shape;

}

Expand All @@ -354,21 +328,94 @@ private String processTupleOrListForShape(exprType[] exprTupleOrList) {
*
* @return String of autograph options that contains various attributes.
*/
private String processAttributeForAutographOptions(Attribute keywordAttribute) {
StringBuilder argument = new StringBuilder();
private TfAutographExperimentalFeature processAttributeForAutographOptions(Attribute keywordAttribute) {
String attributeEnum;
Attribute tempAttr = keywordAttribute;

while (tempAttr.value instanceof Attribute) {
if (tempAttr.value instanceof Attribute) {
NameTok valueAttribute = (NameTok) tempAttr.attr;
argument.insert(0, valueAttribute.id);
argument.insert(0, ".");
if (tempAttr.value instanceof Attribute)
tempAttr = (Attribute) tempAttr.value;
else
throw new IllegalArgumentException("Unable to process " + EXPERIMENTAL_AUTOGRAPH_OPTIONS + " argument.");
}
attributeEnum = valueAttribute.id;
} else
throw new IllegalArgumentException("Unable to process " + EXPERIMENTAL_AUTOGRAPH_OPTIONS + " argument.");

LOG.info(attributeEnum);

if (attributeEnum.equals("ALL"))
return TfAutographExperimentalFeature.ALL;
else if (attributeEnum.equals("AUTO_CONTROL_DEPS"))
return TfAutographExperimentalFeature.AUTO_CONTROL_DEPS;
else if (attributeEnum.equals("ASSERT_STATEMENTS"))
return TfAutographExperimentalFeature.ASSERT_STATEMENTS;
else if (attributeEnum.equals("BUILTIN_FUNCTIONS"))
return TfAutographExperimentalFeature.BUILTIN_FUNCTIONS;
else if (attributeEnum.equals("EQUALITY_OPERATORS"))
return TfAutographExperimentalFeature.EQUALITY_OPERATORS;
else if (attributeEnum.equals("LISTS"))
return TfAutographExperimentalFeature.LISTS;
else if (attributeEnum.equals("NAME_SCOPES"))
return TfAutographExperimentalFeature.NAME_SCOPES;
else
return null;
}

/**
* Classifies the dtype of TensorSpec to return a dtype of the TensorSpec autograph options.
*
* @return Dtype of TensorSpec.
*/
private Dtype determineDtypeForAutographOptions(String typeString) {

if (typeString.equals("bfloat16"))
return Dtype.bfloat16;
else if (typeString.equals("bool"))
return Dtype.bool;
else if (typeString.equals("complex128"))
return Dtype.complex128;
else if (typeString.equals("complex64"))
return Dtype.complex64;
else if (typeString.equals("float16"))
return Dtype.float16;
else if (typeString.equals("float32"))
return Dtype.float32;
else if (typeString.equals("float64"))
return Dtype.float64;
else if (typeString.equals("half"))
return Dtype.half;
else if (typeString.equals("int16"))
return Dtype.int16;
else if (typeString.equals("int32"))
return Dtype.int32;
else if (typeString.equals("int64"))
return Dtype.int64;
else if (typeString.equals("int8"))
return Dtype.int8;
else if (typeString.equals("qint16"))
return Dtype.qint16;
else if (typeString.equals("qint32"))
return Dtype.qint32;
else if (typeString.equals("qint8"))
return Dtype.qint8;
else if (typeString.equals("quint16"))
return Dtype.quint16;
else if (typeString.equals("quint8"))
return Dtype.quint8;
else if (typeString.equals("resource"))
return Dtype.resource;
else if (typeString.equals("string"))
return Dtype.string;
else if (typeString.equals("uint16"))
return Dtype.uint16;
else if (typeString.equals("uint32"))
return Dtype.uint32;
else if (typeString.equals("uint64"))
return Dtype.uint64;
else if (typeString.equals("uint8"))
return Dtype.uint8;
else if (typeString.equals("variant"))
return Dtype.variant;
else
return null;

return ((Name) tempAttr.value).id + "." + ((NameTok) tempAttr.attr).id + argument.toString();
}

/**
Expand All @@ -394,7 +441,8 @@ else if (tensorArg instanceof List)
tensor.setShape(processTupleOrListForShape(((List) tensorArg).elts));
else if (tensorArg instanceof Attribute) {
Attribute attrValue = (Attribute) tensorArg;
tensor.setDType(((Name) attrValue.value).id + "." + ((NameTok) attrValue.attr).id);
Dtype dtype = determineDtypeForAutographOptions(((NameTok) attrValue.attr).id);
tensor.setDType(dtype);
} else
throw new IllegalArgumentException("Unable to process " + INPUT_SIGNATURE + " argument.");
}
Expand All @@ -407,7 +455,8 @@ else if (keyword.value instanceof List)
tensor.setShape(processTupleOrListForShape(((List) keyword.value).elts));
else if (keyword.value instanceof Attribute) {
Attribute attrValue = (Attribute) keyword.value;
tensor.setDType(((Name) attrValue.value).id + "." + ((NameTok) attrValue.attr).id);
Dtype dtype = determineDtypeForAutographOptions(((NameTok) attrValue.attr).id);
tensor.setDType(dtype);
} else {
throw new IllegalArgumentException("Unable to process " + INPUT_SIGNATURE + " argument.");
}
Expand Down Expand Up @@ -518,7 +567,7 @@ public boolean getAutoGraphArg() {
*
* @return String of this {@link decoratorType} parameter experimental_autograph_options.
*/
public String getExperimentalAutographOptArg() {
public java.util.List<TfAutographExperimentalFeature> getExperimentalAutographOptArg() {
return this.experimentalAutographOptionsParam;
}

Expand Down Expand Up @@ -554,7 +603,7 @@ public String getFuncArg() {
*
* @return ArrayList of TensorSpecs of this {@link decoratorType} parameter input_signature.
*/
public ArrayList<TensorSpec> getInputSignatureArg() {
public java.util.List<TensorSpec> getInputSignatureArg() {
return this.inputSignatureParam;
}

Expand Down
Loading