community: Added new Utility runnables for NVIDIA Riva. (#15966)
**Please tag this issue with `nvidia_genai`**

- **Description:** Added new Runnables for integrating NVIDIA Riva into
LCEL chains for Automatic Speech Recognition (ASR) and Text To Speech
(TTS).
- **Issue:** N/A
- **Dependencies:** To use these runnables, the NVIDIA Riva client
libraries are required. If they are not installed, an error is raised
with installation instructions. The Runnables can be safely imported
without the Riva client libraries.
- **Twitter handle:** N/A
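
The import behaviour described under Dependencies can be sketched as a deferred import guard. This is a minimal illustration and not the code from this commit: `require_optional` and its parameters are hypothetical names chosen for the example.

```python
import importlib
from types import ModuleType


def require_optional(module_name: str, install_hint: str) -> ModuleType:
    """Import an optional dependency on first use (hypothetical helper).

    The import runs when this function is called, not when the enclosing
    module is imported, so the enclosing module stays importable even if
    the dependency is missing.
    """
    try:
        return importlib.import_module(module_name)
    except ImportError as err:
        # Re-raise with a message that tells the user what to install.
        raise ImportError(
            f"Could not import '{module_name}'. {install_hint}"
        ) from err
```

A runnable following this pattern would call the helper when it is instantiated, for example `require_optional("riva.client", "Install the NVIDIA Riva client libraries.")`, assuming the client is importable under that module name.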

All of the Riva Runnables are inside a single folder in the Utilities
module. In this folder are four files:
- common.py - Contains all code that is common to both TTS and ASR
- stream.py - Contains a class representing an audio stream that allows
the end user to put data into the stream like a queue.
- asr.py - Contains the RivaASR runnable
- tts.py - Contains the RivaTTS runnable
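
As a rough picture of what "put data into the stream like a queue" can mean, the sketch below wraps a standard-library queue with an explicit end marker. It is an assumption-laden illustration, not the `AudioStream` class from stream.py: `SketchAudioStream` and its methods are hypothetical.

```python
import queue
from typing import Iterator, Optional


class SketchAudioStream:
    """Queue-like byte stream with an explicit end marker (illustrative only)."""

    _END = None  # sentinel placed on the queue by close()

    def __init__(self, maxsize: int = 0) -> None:
        self._queue: "queue.Queue[Optional[bytes]]" = queue.Queue(maxsize=maxsize)
        self._closed = False

    def put(self, chunk: bytes) -> None:
        """Add a chunk of audio; blocks if the queue is full."""
        if self._closed:
            raise RuntimeError("Cannot put data into a closed stream.")
        self._queue.put(chunk)

    def close(self) -> None:
        """Signal that no more audio will be written."""
        self._closed = True
        self._queue.put(self._END)

    def __iter__(self) -> Iterator[bytes]:
        """Yield chunks in order until the end marker is reached."""
        while True:
            chunk = self._queue.get()
            if chunk is self._END:
                return
            yield chunk
```

The sentinel lets a consumer distinguish "no data yet" from "no more data", which matters when audio arrives while the consumer is already reading.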

The following Python function is an example of creating a chain that
makes use of both of these Runnables:

```python
def create(
    config: Configuration,
    audio_encoding: RivaAudioEncoding,
    sample_rate: int,
    audio_channels: int = 1,
) -> Runnable[ASRInputType, TTSOutputType]:
    """Create a new instance of the chain."""
    _LOGGER.info("Instantiating the chain.")

    # create the riva asr client
    riva_asr = RivaASR(
        url=str(config.riva_asr.service.url),
        ssl_cert=config.riva_asr.service.ssl_cert,
        encoding=audio_encoding,
        audio_channel_count=audio_channels,
        sample_rate_hertz=sample_rate,
        profanity_filter=config.riva_asr.profanity_filter,
        enable_automatic_punctuation=config.riva_asr.enable_automatic_punctuation,
        language_code=config.riva_asr.language_code,
    )

    # create the prompt template
    prompt = PromptTemplate.from_template("{user_input}")

    # model = ChatOpenAI()
    model = ChatNVIDIA(model="mixtral_8x7b")  # type: ignore

    # create the riva tts client
    riva_tts = RivaTTS(
        url=str(config.riva_asr.service.url),
        ssl_cert=config.riva_asr.service.ssl_cert,
        output_directory=config.riva_tts.output_directory,
        language_code=config.riva_tts.language_code,
        voice_name=config.riva_tts.voice_name,
    )

    # construct and return the chain
    return {"user_input": riva_asr} | prompt | model | riva_tts  # type: ignore
```

The following code is an example of creating a new audio stream for
Riva:

```python
input_stream = AudioStream(maxsize=1000)
# Send bytes into the stream
for chunk in audio_chunks:
    await input_stream.aput(chunk)
input_stream.close()
```

The following code is an example of how to execute the chain with
RivaASR and RivaTTS:

```python
output_stream = asyncio.Queue()
while not input_stream.complete:
    async for chunk in chain.astream(input_stream):
        # asyncio.Queue.put is a coroutine and must be awaited
        await output_stream.put(chunk)
```

Everything should be async safe and thread safe. Audio data can be put
into the input stream while the chain is running without interruptions.
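
The idea of feeding data in while a consumer is already running can be illustrated with plain `asyncio`, independent of Riva. The sketch below covers only the async case, not thread safety, and the function names (`producer`, `consumer`, `run_demo`) are hypothetical.

```python
import asyncio
from typing import List, Optional


async def producer(q: "asyncio.Queue[Optional[bytes]]", chunks: List[bytes]) -> None:
    """Feed chunks into the queue, waiting whenever the queue is full."""
    for chunk in chunks:
        await q.put(chunk)
    await q.put(None)  # end-of-stream marker


async def consumer(q: "asyncio.Queue[Optional[bytes]]") -> List[bytes]:
    """Drain the queue until the end-of-stream marker arrives."""
    received: List[bytes] = []
    while True:
        chunk = await q.get()
        if chunk is None:
            return received
        received.append(chunk)


async def run_demo() -> List[bytes]:
    # A small maxsize forces the producer to wait for the consumer,
    # so both coroutines make progress concurrently.
    q: "asyncio.Queue[Optional[bytes]]" = asyncio.Queue(maxsize=2)
    _, received = await asyncio.gather(
        producer(q, [b"\x00\x01", b"\x02\x03", b"\x04\x05"]),
        consumer(q),
    )
    return received
```

Running `asyncio.run(run_demo())` returns the chunks in the order they were produced.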

---------

Co-authored-by: Hayden Wolff <[email protected]>
Co-authored-by: Erick Friis <[email protected]>
5 people authored Feb 6, 2024
1 parent 2d80155 commit f027696
Showing 7 changed files with 3,272 additions and 2,174 deletions.
27 changes: 27 additions & 0 deletions libs/community/langchain_community/utilities/__init__.py
@@ -272,6 +272,24 @@ def _import_nasa() -> Any:
    return NasaAPIWrapper


def _import_nvidia_riva_asr() -> Any:
    from langchain_community.utilities.nvidia_riva import RivaASR

    return RivaASR


def _import_nvidia_riva_tts() -> Any:
    from langchain_community.utilities.nvidia_riva import RivaTTS

    return RivaTTS


def _import_nvidia_riva_stream() -> Any:
    from langchain_community.utilities.nvidia_riva import AudioStream

    return AudioStream


def __getattr__(name: str) -> Any:
    if name == "AlphaVantageAPIWrapper":
        return _import_alpha_vantage()
@@ -321,6 +339,12 @@ def __getattr__(name: str) -> Any:
        return _import_metaphor_search()
    elif name == "NasaAPIWrapper":
        return _import_nasa()
    elif name == "NVIDIARivaASR":
        return _import_nvidia_riva_asr()
    elif name == "NVIDIARivaStream":
        return _import_nvidia_riva_stream()
    elif name == "NVIDIARivaTTS":
        return _import_nvidia_riva_tts()
    elif name == "OpenWeatherMapAPIWrapper":
        return _import_openweathermap()
    elif name == "OutlineAPIWrapper":
@@ -388,6 +412,9 @@ def __getattr__(name: str) -> Any:
    "MerriamWebsterAPIWrapper",
    "MetaphorSearchAPIWrapper",
    "NasaAPIWrapper",
    "NVIDIARivaASR",
    "NVIDIARivaStream",
    "NVIDIARivaTTS",
    "OpenWeatherMapAPIWrapper",
    "OutlineAPIWrapper",
    "Portkey",
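
The hunks above rely on a module-level `__getattr__` (PEP 562) so that each wrapper is imported only when its name is first accessed; note that the exported names (`NVIDIARivaASR`, `NVIDIARivaStream`, `NVIDIARivaTTS`) differ from the class names they resolve to (`RivaASR`, `AudioStream`, `RivaTTS`). The sketch below reproduces the pattern in a self-contained way. The names `sketch_utilities`, `PublicFraction`, and `load_log` are hypothetical, and `fractions.Fraction` stands in for an optional dependency.

```python
import types
from typing import Any, List

load_log: List[str] = []  # records which deferred imports actually ran


def _import_fraction() -> Any:
    # Stand-in for a heavy or optional import; deferred until first access.
    from fractions import Fraction

    load_log.append("Fraction")
    return Fraction


def _lazy_getattr(name: str) -> Any:
    """Module-level __getattr__: called only when normal lookup fails."""
    if name == "PublicFraction":  # exported name may differ from the class name
        return _import_fraction()
    raise AttributeError(f"module 'sketch_utilities' has no attribute {name!r}")


# Build a throwaway module so the sketch does not need a separate file.
sketch_utilities = types.ModuleType("sketch_utilities")
sketch_utilities.__getattr__ = _lazy_getattr
```

Accessing `sketch_utilities.PublicFraction` triggers the deferred import, while any unknown name raises `AttributeError` as it would on an ordinary module.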