Skip to content

Commit

Permalink
Merge pull request #485 from PsychoinformaticsLab/tf/hub/load
Browse files Browse the repository at this point in the history
Add TFHubAudioExtractor + SPICE Tests
  • Loading branch information
adelavega authored Nov 15, 2022
2 parents ecdf208 + d92c637 commit f0d7499
Show file tree
Hide file tree
Showing 3 changed files with 164 additions and 61 deletions.
3 changes: 2 additions & 1 deletion pliers/extractors/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@
from .misc import MetricExtractor
from .models import (TensorFlowKerasApplicationExtractor,
TFHubImageExtractor, TFHubTextExtractor,
TFHubExtractor)
TFHubExtractor, TFHubAudioExtractor)
from .text import (ComplexTextExtractor, DictionaryExtractor,
PredefinedDictionaryExtractor, LengthExtractor,
NumUniqueWordsExtractor, PartOfSpeechExtractor,
Expand Down Expand Up @@ -128,6 +128,7 @@
'TFHubExtractor',
'TFHubImageExtractor',
'TFHubTextExtractor',
'TFHubAudioExtractor',
'ComplexTextExtractor',
'DictionaryExtractor',
'PredefinedDictionaryExtractor',
Expand Down
201 changes: 142 additions & 59 deletions pliers/extractors/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from pliers.extractors.image import ImageExtractor
from pliers.extractors.base import Extractor, ExtractorResult
from pliers.filters.image import ImageResizingFilter
from pliers.stimuli import ImageStim, TextStim
from pliers.stimuli import ImageStim, TextStim, AudioStim
from pliers.stimuli.base import Stim
from pliers.support.exceptions import MissingDependencyError
from pliers.utils import (attempt_to_import, verify_dependencies,
Expand All @@ -26,34 +26,34 @@ class TFHubExtractor(Extractor):
Args:
url_or_path (str): url or path to TFHub model. You can
browse models at https://tfhub.dev/.
features (optional): list of labels (for classification)
or other feature names. The number of items must
match the number of features in the output. For example,
if a classification model with 1000 output classes is passed
features (optional): list of feature names matching output dimensions
For example, for a classification model with 1000 output classes
this must be a list containing 1000 items.
(e.g. EfficientNet B6,
see https://tfhub.dev/tensorflow/efficientnet/b6/classification/1),
this must be a list containing 1000 items. If a text encoder
outputting 768-dimensional encoding is passed (e.g. base BERT),
this must be a list containing 768 items. Each dimension in the
model output will be returned as a separate feature in the
ExtractorResult.
https://tfhub.dev/tensorflow/efficientnet/b6/classification/1),
Alternatively, the model output can be packed into a single
feature (i.e. a vector) by passing a single-element list
(e.g. ['encoding']) or a string. Along the lines of
the previous examples, if a single feature name is
passed here (e.g. if features=['encoding']) for a TFHub model
that outputs a 768-dimensional encoding, the extractor will
return only one feature named 'encoding', which contains the
encoding vector as a 1-d array wrapped in a list.
or a string. For example, for a model that outputs a
768-dimensional encoding, the value 'encoding' will result
in a 1-d array wrapped in a list named 'encoding'.
If no value is passed, the extractor will automatically
compute the number of features in the model output
and return an equal number of features in pliers, labeling
each feature with a generic prefix + its positional index
in the model output (feature_0, feature_1, ... ,feature_n).
Note that for saved models, the feature names are inferred
from the output signature, but can be over-ridden.
transform_out (optional): function to transform model
output for compatibility with extractor result
transform_inp (optional): function to transform Stim.data
for compatibility with model input format
output_key (str): key of the desired entry in the output
dictionary. Set to None if the output is not a dictionary,
or to output all keys in the dictionary.
keras_kwargs (dict): arguments to hub.KerasLayer call
'''

Expand All @@ -62,24 +62,36 @@ class TFHubExtractor(Extractor):

def __init__(self, url_or_path, features=None,
transform_out=None, transform_inp=None,
keras_kwargs=None):
output_key=None, keras_kwargs=None):
verify_dependencies(['tensorflow_hub'])
if keras_kwargs is None:
keras_kwargs = {}
self.keras_kwargs = keras_kwargs
self.output_key = output_key
self.model = hub.KerasLayer(url_or_path, **keras_kwargs)

self.url_or_path = url_or_path
self.features = features
self.transform_out = transform_out
self.transform_inp = transform_inp
super().__init__()

def get_feature_names(self, out):
# Manual feature names always take precedence
if self.features:
return listify(self.features)
# Infer feature names from output
else:
return ['feature_' + str(i)
for i in range(out.shape[-1])]
# If dict, use provided output key, or all keys
if isinstance(out, dict):
if self.output_key:
return [self.output_key]
else:
return list(out.keys())
# Worst case, use generic feature names
else:
return ['feature_' + str(i)
for i in range(out.shape[-1])]

def _preprocess(self, stim):
if self.transform_inp:
Expand All @@ -91,17 +103,56 @@ def _preprocess(self, stim):
return stim.data

def _postprocess(self, out):
# If key is provided, return only that key
if self.output_key:
try:
out = out[self.output_key]
except KeyError:
raise ValueError(f'{self.output_key} is not a valid key.'
'Check which keys are available in the output '
'at the model URL ({self.url_or_path})')
except (IndexError, TypeError):
raise ValueError(f'Model output is not a dictionary. '
'Try initialize the extractor with output_key=None.')

# If output is a dict and no output key, return all keys
if isinstance(out, dict):
out = np.vstack(list(out.values())).T
elif isinstance(out, tf.Tensor):
out = out.numpy()

# Always squeeze last dimension if it is 1
out = out.squeeze()

if self.transform_out:
out = self.transform_out(out)
return out.numpy().squeeze()
return out

def _get_timing(self, out, stim):
""" Returns the timing of the output.
Args:
out: output of the model
stim: input stimulus
Returns:
onsets: onsets of the output
durations: durations of the output
orders: order of the output
"""

return stim.onset, stim.duration, None

def _extract(self, stim):
inp = self._preprocess(stim)
out = self.model(inp)
out = self._postprocess(out)
features = self.get_feature_names(out)
out = self._postprocess(out)

onsets, durations, orders = self._get_timing(out, stim)

return ExtractorResult(listify(out), stim, self,
features=features)
onsets=onsets, durations=durations,
features=features, orders=orders)

def _to_df(self, result):
if len(result.features) == 1:
Expand All @@ -116,41 +167,84 @@ def _to_df(self, result):

class TFHubImageExtractor(TFHubExtractor):

''' TFHub Extractor class for image models
''' TFHub Extractor class for image models.
Note that some models may require specific input shapes.
You can reshape inputs using filters such as ImageResizingFilter
or ImageRescaleFilter.
Args:
url_or_path (str): url or path to TFHub model
features (optional): list of labels (for classification)
or other feature names. If not specified, returns
numbered features (feature_0, feature_1, ... ,feature_n)
keras_kwargs (dict): arguments to hub.KerasLayer call
input_dtype (optional): dtype of input data. Defaults to tf.float32
'''

_input_type = ImageStim
_log_attributes = ('url_or_path', 'features', 'keras_kwargs')

def __init__(self,
url_or_path,
features=None,
input_dtype=None,
keras_kwargs=None):
**kwargs):

self.input_dtype = input_dtype if input_dtype else tf.float32
if keras_kwargs is None:
keras_kwargs = {}
self.keras_kwargs = keras_kwargs

logging.warning('Some models may require specific input shapes.'
' Incompatible shapes may raise errors'
' at extraction. If needed, you can reshape'
' your input image using ImageResizingFilter, '
' and rescale using ImageRescalingFilter')
super().__init__(url_or_path, features, keras_kwargs=keras_kwargs)
super().__init__(url_or_path, **kwargs)

def _preprocess(self, stim):
x = tf.convert_to_tensor(stim.data, dtype=self.input_dtype)
x = tf.expand_dims(x, axis=0)
return x

class TFHubAudioExtractor(TFHubExtractor):

    ''' TFHub Extractor class for audio models.
    Note that some models may require a specific sampling frequency.
    You can resample inputs using AudioResamplingFilter.
    Args:
        url_or_path (str): url or path to TFHub model
        input_dtype (optional): dtype of input data. Defaults to tf.float32
        kwargs: additional keyword arguments forwarded to the
            TFHubExtractor initializer
    '''

    _input_type = AudioStim

    def __init__(self,
                 url_or_path,
                 input_dtype=None,
                 **kwargs):

        self.input_dtype = input_dtype if input_dtype else tf.float32

        super().__init__(url_or_path, **kwargs)

    def _preprocess(self, stim):
        # Convert the stimulus data to a tensor of the requested dtype
        x = tf.convert_to_tensor(stim.data, dtype=self.input_dtype)
        return x

    def _get_timing(self, out, stim):
        """ Returns the timing of the output.
        Assumes model returns a fixed sampling frequency,
        and deduces durations and onsets from the sampling frequency.
        The first axis of the output is taken as the number of frames,
        and the stimulus duration is split evenly across them.
        Args:
            out: output of the model
            stim: input stimulus
        Returns:
            onsets: onsets of the output
            durations: durations of the output
            orders: order of the output
        """

        n_frames = out.shape[0]
        frame_duration = stim.duration / n_frames
        durations = [frame_duration] * n_frames
        # Build onsets from integer frame indices. np.arange with a
        # non-integer step can return one element more or fewer than
        # n_frames due to floating point rounding, which would leave
        # onsets and durations with different lengths.
        onsets = np.arange(n_frames) * frame_duration
        if stim.onset is not None:
            onsets += stim.onset
        onsets = onsets.tolist()
        orders = range(0, len(onsets))

        return onsets, durations, orders

class TFHubTextExtractor(TFHubExtractor):

Expand All @@ -162,7 +256,11 @@ class TFHubTextExtractor(TFHubExtractor):
The number of items must match the number of features
in the model output. For example, if a text encoder
outputting 768-dimensional encoding is passed
(e.g. base BERT), this must be a list containing 768 items.
Each dimension in the model output will be returned as a
separate feature in the ExtractorResult.
Alternatively, the model output can be packed into a single
Expand All @@ -176,7 +274,8 @@ class TFHubTextExtractor(TFHubExtractor):
output_key (str): key to desired embedding in output
dictionary (see documentation at
https://www.tensorflow.org/hub/common_saved_model_apis/text).
Set to None is the output is not a dictionary.
Set to None if the output is not a dictionary, or to
output all keys.
preprocessor_url_or_path (str): if the model requires
preprocessing through another TFHub model, specifies the
url or path to the preprocessing module. Information on
Expand All @@ -196,7 +295,7 @@ def __init__(self,
preprocessor_kwargs=None,
keras_kwargs=None,
**kwargs):
super().__init__(url_or_path, features,
super().__init__(url_or_path, features, output_key=output_key,
keras_kwargs=keras_kwargs,
**kwargs)
self.output_key = output_key
Expand All @@ -217,22 +316,6 @@ def _preprocess(self, stim):
self.preprocessor_url_or_path, **self.preprocessor_kwargs)
x = preprocessor(x)
return x

def _postprocess(self, out):
if not self.output_key:
return out.numpy().squeeze()
else:
try:
return out[self.output_key].numpy().squeeze()
except KeyError:
raise ValueError(f'{self.output_key} is not a valid key.'
'Check which keys are available in the output '
'embedding dictionary in TFHub docs '
'(https://www.tensorflow.org/hub/common_saved_model_apis/text)'
f' or at the model URL ({self.url_or_path})')
except (IndexError, TypeError):
raise ValueError(f'Model output is not a dictionary. '
'Try initialize the extractor with output_key=None.')


class TensorFlowKerasApplicationExtractor(ImageExtractor):
Expand Down
21 changes: 20 additions & 1 deletion pliers/tests/extractors/test_model_extractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from pliers.extractors import (TensorFlowKerasApplicationExtractor,
TFHubExtractor,
TFHubImageExtractor,
TFHubAudioExtractor,
TFHubTextExtractor,
BertExtractor,
BertSequenceEncodingExtractor,
Expand Down Expand Up @@ -40,7 +41,7 @@
TOKENIZER_URL = 'https://tfhub.dev/tensorflow/bert_en_uncased_preprocess/2'
ELECTRA_URL = 'https://tfhub.dev/google/electra_small/2'
SPEECH_URL = 'https://tfhub.dev/google/speech_embedding/1'

SPICE_URL = 'https://tfhub.dev/google/spice/2'

pytestmark = pytest.mark.skipif(
environ.get('skip_high_memory', False) == 'true', reason='high memory')
Expand Down Expand Up @@ -447,3 +448,21 @@ def compute_expected_length(stim, ext):
with pytest.raises(ValueError) as err:
AudiosetLabelExtractor(top_n=10, labels=labels)
assert 'Top_n and labels are mutually exclusive' in str(err.value)

def test_spice_extractor():
    """Check output shape, timing and first-frame values of SPICE."""
    audio_stim = AudioStim(join(AUDIO_DIR, 'homer.wav'))
    audio_filter = AudioResamplingFilter(target_sr=16000)
    # NOTE(review): audio_resampled is never passed to the extractor; the
    # expected values below are checked against the original stimulus.
    # Confirm whether the resampled stimulus was meant to be used instead.
    audio_resampled = audio_filter.transform(audio_stim)

    ext = TFHubAudioExtractor(SPICE_URL, keras_kwargs=dict(
        signature='serving_default', signature_outputs_as_dict=True))
    r_orig = ext.transform(audio_stim).to_df()
    assert r_orig.shape == (74, 6)

    # These comparisons must be asserted, otherwise they are no-ops.
    # Floats are compared with a tolerance rather than exact equality.
    assert r_orig.onset.min() == 0.0
    assert r_orig.duration.min() == pytest.approx(0.04594594594594594)
    assert r_orig.duration.max() == pytest.approx(0.04594594594594594)
    assert r_orig.uncertainty[0] == pytest.approx(0.974131, abs=1e-4)
    assert r_orig.pitch[0] == pytest.approx(0.171392, abs=1e-4)



0 comments on commit f0d7499

Please sign in to comment.