-
Notifications
You must be signed in to change notification settings - Fork 0
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Extract tf.function decorator keyword arguments #144
base: main
Are you sure you want to change the base?
Changes from 10 commits
39c7fb5
f21e423
747d5c5
8f8c0b1
4ffa27a
21a8004
89e4dd0
b6f4e9c
b9c8107
d6b110c
26fd8f1
9a4cb62
d0d0875
015965a
6102351
a62af55
f7a9373
c9ee308
0ab5892
db5147f
b9bd9be
9db5396
ec84857
74b4bc4
e12fb16
7aef65b
9f49fea
279089b
ef5d5bb
64bd2b6
f53e997
1df7219
1b2ac89
c4184d4
d2ae4c9
b48960d
1eb7274
9195db7
eba9736
dd82611
df97664
54d78be
10cb16b
303a117
fc90f50
ecffbc5
3577a39
0c72e7d
94733e6
41bfe8d
9003639
6fd42b2
1682088
02ef49b
ece2522
93f4e9c
af61fe5
247e1d6
9b764d9
bc3ff06
cb23e2f
79d0ff3
2c93c3f
e690e4f
bc534a4
9163368
09d52cb
d845f0a
6a13a5e
4c79a5a
720bba7
5d63094
0c345f0
44b8b7b
c5b576b
f2238cf
45bcfd2
a8ddc42
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -29,6 +29,7 @@ | |
import org.python.pydev.parser.visitors.TypeInfo; | ||
|
||
import edu.cuny.citytech.refactoring.common.core.RefactorableProgramEntity; | ||
import edu.cuny.hunter.hybridize.core.analysis.TensorSpec.Dtype; | ||
|
||
/** | ||
* A representation of a Python function. | ||
|
@@ -38,10 +39,14 @@ | |
*/ | ||
public class Function extends RefactorableProgramEntity { | ||
|
||
public enum TfAutographExperimentalFeature { | ||
ALL, AUTO_CONTROL_DEPS, ASSERT_STATEMENTS, BUILTIN_FUNCTIONS, EQUALITY_OPERATORS, LISTS, NAME_SCOPES | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. What do each of these mean? What is the URL where we can find out more info? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Added an URL in the javadoc |
||
} | ||
|
||
/** | ||
* Parameters that may be passed to a tf.fuction decorator. Parameter descriptions found at: | ||
* https://tensorflow.org/versions/r2.9/api_docs/python/tf/function Note: We are also parsing the deprecated parameters specified in | ||
* the documentation. Users can still use these deprecated parameters. Therefore we need to be able to account for them. Please refer to | ||
* https://tensorflow.org/versions/r2.9/api_docs/python/tf/function Note: We are also parsing the deprecated parameters specified in the | ||
* documentation. Users can still use these deprecated parameters. Therefore we need to be able to account for them. Please refer to | ||
* https://github.com/ponder-lab/Hybridize-Functions-Refactoring/wiki/tf.function-parameter's-version-information to see more | ||
* information about the tf.function parameters according to the versions. | ||
*/ | ||
|
@@ -81,13 +86,14 @@ public class HybridizationParameters { | |
|
||
/** | ||
* Value of this {@link Function}'s {@link decoratorsType} parameter experimental_autograph_options. The values could be an optional | ||
* tuple or value of tf.autograph.experimental.Feature values or None. | ||
* tuple or value of tf.autograph.experimental.Feature values (e.g. | ||
* <code>tf.autograph.experimental.Feature.EQUALITY_OPERATORS</code>) or None. | ||
*/ | ||
private String experimentalAutographOptionsParam; | ||
private java.util.List<TfAutographExperimentalFeature> experimentalAutographOptionsParam; | ||
|
||
/** | ||
* Value of this {@link Function}'s {@link decoratorsType} parameter experimental_implements. The value could be None or a name of a | ||
* "known" function this implements. | ||
* "known" function this implements (e.g. <code>embedded_matmul</code>). | ||
*/ | ||
private String experimentalImplementsParam; | ||
|
||
|
@@ -101,7 +107,7 @@ public class HybridizationParameters { | |
* Value of this {@link Function}'s {@link decoratorsType} parameter input_signature. The value could be None, or a possibly nested | ||
* sequence of tf.TensorSpec objects specifying the shapes and dtypes of the Tensors that will be supplied to this function. | ||
*/ | ||
private ArrayList<TensorSpec> inputSignatureParam; | ||
private java.util.List<TensorSpec> inputSignatureParam; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
|
||
|
||
/** | ||
* Value of this {@link Function}'s {@link decoratorsType} parameter jit_compile. The values could be None, False or True. null | ||
|
@@ -151,10 +157,7 @@ public HybridizationParameters(IProgressMonitor monitor) throws BadLocationExcep | |
// Example of value: Name of function or None | ||
if (keyword.value instanceof Name) { | ||
Name value = (Name) keyword.value; | ||
if (value.id == "None") // Checking only literals | ||
// Default value | ||
this.funcParam = null; | ||
else | ||
if (value.id != "None") // Checking only literals | ||
throw new IllegalArgumentException("Unable to process " + FUNC + " argument."); | ||
} else { | ||
throw new IllegalArgumentException("Unable to process " + FUNC + " argument."); | ||
|
@@ -164,10 +167,7 @@ public HybridizationParameters(IProgressMonitor monitor) throws BadLocationExcep | |
// Example of value: None | ||
if (keyword.value instanceof Name) { | ||
Name value = (Name) keyword.value; | ||
if (value.id == "None") // Checking only literals | ||
// Default value | ||
this.inputSignatureParam = null; | ||
else | ||
if (value.id != "None") // Checking only literals | ||
throw new IllegalArgumentException("Unable to process " + INPUT_SIGNATURE + " argument."); | ||
// Example: (tf.TensorSpec(shape=[None], dtype=tf.float32),) | ||
} else if (keyword.value instanceof Tuple) { | ||
|
@@ -189,12 +189,9 @@ public HybridizationParameters(IProgressMonitor monitor) throws BadLocationExcep | |
// Example of value: True, False | ||
if (keyword.value instanceof Name) { | ||
Name value = (Name) keyword.value; | ||
if (value.id == "True")// Checking only literals | ||
// Default value | ||
this.autoGraphParam = true; | ||
else if (value.id == "False") | ||
if (value.id == "False") | ||
this.autoGraphParam = false; | ||
else | ||
else if (value.id != "True") | ||
throw new IllegalArgumentException("Unable to process " + AUTOGRAPH + " argument."); | ||
} else { | ||
throw new IllegalArgumentException("Unable to process " + AUTOGRAPH + " argument."); | ||
|
@@ -211,10 +208,7 @@ else if (value.id == "False") | |
this.jitCompileParam = true; | ||
else if (value.id == "False") | ||
this.jitCompileParam = false; | ||
else if (value.id == "None") | ||
// Default value | ||
this.jitCompileParam = null; | ||
else | ||
else if (value.id != "None") | ||
throw new IllegalArgumentException( | ||
"Unable to process " + JIT_COMPILE + "/" + EXPERIMENTAL_COMPILE + " argument."); | ||
} else { | ||
|
@@ -231,10 +225,7 @@ else if (value.id == "None") | |
Name value = (Name) keyword.value; | ||
if (value.id == "True") // Checking only literals | ||
this.reduceRetracingParam = true; | ||
else if (value.id == "False") // Checking only literals | ||
// Default value | ||
this.reduceRetracingParam = false; | ||
else | ||
else if (value.id != "False") // Checking only literals | ||
throw new IllegalArgumentException( | ||
"Unable to process " + REDUCE_RETRACING + "/" + EXPERIMENTAL_RELAX_SHAPES + " argument."); | ||
} else { | ||
|
@@ -250,45 +241,38 @@ else if (value.id == "False") // Checking only literals | |
// Example of value: None | ||
} else if (keyword.value instanceof Name) { | ||
Name value = (Name) keyword.value; | ||
if (value.id == "None") // Checking only literals | ||
// Default value | ||
this.experimentalImplementsParam = null; | ||
else | ||
if (value.id != "None") // Checking only literals | ||
throw new IllegalArgumentException("Unable to process " + EXPERIMENTAL_IMPLEMENTS + " argument."); | ||
} else { | ||
throw new IllegalArgumentException("Unable to process " + EXPERIMENTAL_IMPLEMENTS + " argument."); | ||
} | ||
} else if (name.id.equals(EXPERIMENTAL_AUTOGRAPH_OPTIONS)) { | ||
java.util.List<TfAutographExperimentalFeature> autographExperimental = new ArrayList<>(); | ||
// Found parameter experimental_autograph_options | ||
// Example of value: tf.autograph.experimental.Feature.EQUALITY_OPERATORS | ||
if (keyword.value instanceof Attribute) { | ||
Attribute keywordAttribute = (Attribute) keyword.value; | ||
this.experimentalAutographOptionsParam = processAttributeForAutographOptions(keywordAttribute); | ||
autographExperimental.add(processAttributeForAutographOptions(keywordAttribute)); | ||
this.experimentalAutographOptionsParam = autographExperimental; | ||
// Example of value: (tf.autograph.experimental.Feature.EQUALITY_OPERATORS, | ||
// tf.autograph.experimental.Feature.BUILTIN_FUNCTIONS) | ||
} else if (keyword.value instanceof Tuple) { | ||
Tuple keywordTuple = (Tuple) keyword.value; | ||
exprType[] keywordExpr = keywordTuple.elts; | ||
String finalTuple = ""; | ||
int count = 0; | ||
for (exprType expr : keywordExpr) { | ||
if (expr instanceof Attribute) { | ||
Attribute keywordAttribute = (Attribute) expr; | ||
if (count == 0) | ||
finalTuple += processAttributeForAutographOptions(keywordAttribute); | ||
else | ||
finalTuple += ", " + processAttributeForAutographOptions(keywordAttribute); | ||
autographExperimental.add(processAttributeForAutographOptions(keywordAttribute)); | ||
} else { | ||
throw new IllegalArgumentException( | ||
"Unable to process " + EXPERIMENTAL_AUTOGRAPH_OPTIONS + " arguments"); | ||
} | ||
count++; | ||
} | ||
this.experimentalAutographOptionsParam = "(" + finalTuple + ")"; | ||
this.experimentalAutographOptionsParam = autographExperimental; | ||
// Example of value: None | ||
} else if (keyword.value instanceof Name) { | ||
Name value = (Name) keyword.value; | ||
if (value.id == "None") // Checking only literals | ||
// Default value | ||
this.experimentalAutographOptionsParam = null; | ||
else | ||
if (value.id != "None") // Checking only literals | ||
throw new IllegalArgumentException( | ||
"Unable to process " + EXPERIMENTAL_AUTOGRAPH_OPTIONS + " argument."); | ||
} else { | ||
|
@@ -300,14 +284,11 @@ else if (value.id == "False") // Checking only literals | |
// Example of value: True, False, None | ||
if (keyword.value instanceof Name) { | ||
Name value = (Name) keyword.value; | ||
if (value.id == "None") // Checking only literals | ||
// Default value | ||
this.experimentaFollowTypeHintsParam = null; | ||
else if (value.id == "True") // Checking only literals | ||
if (value.id == "True") // Checking only literals | ||
this.experimentaFollowTypeHintsParam = true; | ||
else if (value.id == "False") // Checking only literals | ||
this.experimentaFollowTypeHintsParam = false; | ||
else | ||
else if (value.id != "None") | ||
throw new IllegalArgumentException( | ||
"Unable to process " + EXPERIMENTAL_FOLLOW_TYPE_HINTS + " argument."); | ||
} else { | ||
|
@@ -325,27 +306,20 @@ else if (value.id == "False") // Checking only literals | |
* | ||
* @return String of TensorSpec shape. | ||
*/ | ||
private String processTupleOrListForShape(exprType[] exprTupleOrList) { | ||
int count = 0; | ||
String tempString = ""; | ||
private java.util.List<Integer> processTupleOrListForShape(exprType[] exprTupleOrList) { | ||
java.util.List<Integer> shape = new ArrayList<>(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. We already have another import for List (org.python.pydev.parser.jython.ast.List). Therefore, I had to specify which List it was by doing java.util.List. It is a list of dimensions (https://www.tensorflow.org/versions/r2.9/api_docs/python/tf/TensorShape) which can be integers or |
||
|
||
for (exprType expr : exprTupleOrList) { | ||
if (expr instanceof Num) { | ||
if (count == 0) | ||
tempString += ((Num) expr).num; | ||
else | ||
tempString += ", " + ((Num) expr).num; | ||
count++; | ||
shape.add(Integer.parseInt(((Num) expr).num)); | ||
} else if (expr instanceof Name) { | ||
if (((Name) expr).id == "None") // Checking only literals | ||
tempString = ((Name) expr).id; | ||
else | ||
if (((Name) expr).id != "None") // Checking only literals | ||
throw new IllegalArgumentException("Unable to process " + INPUT_SIGNATURE + " argument."); | ||
} else | ||
throw new IllegalArgumentException("Unable to process " + INPUT_SIGNATURE + " argument."); | ||
} | ||
|
||
return tempString; | ||
return shape; | ||
|
||
} | ||
|
||
|
@@ -354,21 +328,94 @@ private String processTupleOrListForShape(exprType[] exprTupleOrList) { | |
* | ||
* @return String of autograph options that contains various attributes. | ||
*/ | ||
private String processAttributeForAutographOptions(Attribute keywordAttribute) { | ||
StringBuilder argument = new StringBuilder(); | ||
private TfAutographExperimentalFeature processAttributeForAutographOptions(Attribute keywordAttribute) { | ||
String attributeEnum; | ||
Attribute tempAttr = keywordAttribute; | ||
|
||
while (tempAttr.value instanceof Attribute) { | ||
if (tempAttr.value instanceof Attribute) { | ||
NameTok valueAttribute = (NameTok) tempAttr.attr; | ||
argument.insert(0, valueAttribute.id); | ||
argument.insert(0, "."); | ||
if (tempAttr.value instanceof Attribute) | ||
tempAttr = (Attribute) tempAttr.value; | ||
else | ||
throw new IllegalArgumentException("Unable to process " + EXPERIMENTAL_AUTOGRAPH_OPTIONS + " argument."); | ||
} | ||
attributeEnum = valueAttribute.id; | ||
} else | ||
throw new IllegalArgumentException("Unable to process " + EXPERIMENTAL_AUTOGRAPH_OPTIONS + " argument."); | ||
|
||
LOG.info(attributeEnum); | ||
|
||
if (attributeEnum.equals("ALL")) | ||
return TfAutographExperimentalFeature.ALL; | ||
else if (attributeEnum.equals("AUTO_CONTROL_DEPS")) | ||
return TfAutographExperimentalFeature.AUTO_CONTROL_DEPS; | ||
else if (attributeEnum.equals("ASSERT_STATEMENTS")) | ||
return TfAutographExperimentalFeature.ASSERT_STATEMENTS; | ||
else if (attributeEnum.equals("BUILTIN_FUNCTIONS")) | ||
return TfAutographExperimentalFeature.BUILTIN_FUNCTIONS; | ||
else if (attributeEnum.equals("EQUALITY_OPERATORS")) | ||
return TfAutographExperimentalFeature.EQUALITY_OPERATORS; | ||
else if (attributeEnum.equals("LISTS")) | ||
return TfAutographExperimentalFeature.LISTS; | ||
else if (attributeEnum.equals("NAME_SCOPES")) | ||
return TfAutographExperimentalFeature.NAME_SCOPES; | ||
else | ||
return null; | ||
} | ||
|
||
/** | ||
* Classifies the dtype of TensorSpec to return a dtype of the TensorSpec autograph options. | ||
* | ||
* @return Dtype of TensorSpec. | ||
*/ | ||
private Dtype determineDtypeForAutographOptions(String typeString) { | ||
|
||
if (typeString.equals("bfloat16")) | ||
return Dtype.bfloat16; | ||
else if (typeString.equals("bool")) | ||
return Dtype.bool; | ||
else if (typeString.equals("complex128")) | ||
return Dtype.complex128; | ||
else if (typeString.equals("complex64")) | ||
return Dtype.complex64; | ||
else if (typeString.equals("float16")) | ||
return Dtype.float16; | ||
else if (typeString.equals("float32")) | ||
return Dtype.float32; | ||
else if (typeString.equals("float64")) | ||
return Dtype.float64; | ||
else if (typeString.equals("half")) | ||
return Dtype.half; | ||
else if (typeString.equals("int16")) | ||
return Dtype.int16; | ||
else if (typeString.equals("int32")) | ||
return Dtype.int32; | ||
else if (typeString.equals("int64")) | ||
return Dtype.int64; | ||
else if (typeString.equals("int8")) | ||
return Dtype.int8; | ||
else if (typeString.equals("qint16")) | ||
return Dtype.qint16; | ||
else if (typeString.equals("qint32")) | ||
return Dtype.qint32; | ||
else if (typeString.equals("qint8")) | ||
return Dtype.qint8; | ||
else if (typeString.equals("quint16")) | ||
return Dtype.quint16; | ||
else if (typeString.equals("quint8")) | ||
return Dtype.quint8; | ||
else if (typeString.equals("resource")) | ||
return Dtype.resource; | ||
else if (typeString.equals("string")) | ||
return Dtype.string; | ||
else if (typeString.equals("uint16")) | ||
return Dtype.uint16; | ||
else if (typeString.equals("uint32")) | ||
return Dtype.uint32; | ||
else if (typeString.equals("uint64")) | ||
return Dtype.uint64; | ||
else if (typeString.equals("uint8")) | ||
return Dtype.uint8; | ||
else if (typeString.equals("variant")) | ||
return Dtype.variant; | ||
else | ||
return null; | ||
|
||
return ((Name) tempAttr.value).id + "." + ((NameTok) tempAttr.attr).id + argument.toString(); | ||
} | ||
|
||
/** | ||
|
@@ -394,7 +441,8 @@ else if (tensorArg instanceof List) | |
tensor.setShape(processTupleOrListForShape(((List) tensorArg).elts)); | ||
else if (tensorArg instanceof Attribute) { | ||
Attribute attrValue = (Attribute) tensorArg; | ||
tensor.setDType(((Name) attrValue.value).id + "." + ((NameTok) attrValue.attr).id); | ||
Dtype dtype = determineDtypeForAutographOptions(((NameTok) attrValue.attr).id); | ||
tensor.setDType(dtype); | ||
} else | ||
throw new IllegalArgumentException("Unable to process " + INPUT_SIGNATURE + " argument."); | ||
} | ||
|
@@ -407,7 +455,8 @@ else if (keyword.value instanceof List) | |
tensor.setShape(processTupleOrListForShape(((List) keyword.value).elts)); | ||
else if (keyword.value instanceof Attribute) { | ||
Attribute attrValue = (Attribute) keyword.value; | ||
tensor.setDType(((Name) attrValue.value).id + "." + ((NameTok) attrValue.attr).id); | ||
Dtype dtype = determineDtypeForAutographOptions(((NameTok) attrValue.attr).id); | ||
tensor.setDType(dtype); | ||
} else { | ||
throw new IllegalArgumentException("Unable to process " + INPUT_SIGNATURE + " argument."); | ||
} | ||
|
@@ -518,7 +567,7 @@ public boolean getAutoGraphArg() { | |
* | ||
* @return String of this {@link decoratorType} parameter experimental_autograph_options. | ||
*/ | ||
public String getExperimentalAutographOptArg() { | ||
public java.util.List<TfAutographExperimentalFeature> getExperimentalAutographOptArg() { | ||
return this.experimentalAutographOptionsParam; | ||
} | ||
|
||
|
@@ -554,7 +603,7 @@ public String getFuncArg() { | |
* | ||
* @return ArrayList of TensorSpecs of this {@link decoratorType} parameter input_signature. | ||
*/ | ||
public ArrayList<TensorSpec> getInputSignatureArg() { | ||
public java.util.List<TensorSpec> getInputSignatureArg() { | ||
return this.inputSignatureParam; | ||
} | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Need a javadoc comment here explaining what this is and a URL for more info.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added.