Skip to content

Commit

Permalink
community: Fix rank-llm import paths for new 0.20.3 version (#29154)
Browse files Browse the repository at this point in the history
# **PR title**: "community: Fix rank-llm import paths for new 0.20.3
version"
- The "community" package is being modified to handle updated import
paths for the new `rank-llm` version.

---

## Description
This PR updates the import paths for the `rank-llm` package to account
for changes introduced in version `0.20.3`. The changes ensure
compatibility with both pre- and post-revamp versions of `rank-llm`,
specifically version `0.12.8`. Conditional imports are introduced based
on the detected version of `rank-llm` to handle different path
structures for `VicunaReranker`, `ZephyrReranker`, and `SafeOpenai`.

## Issue
RankLLMRerank usage throws an error when used GPT (not only) when
rank-llm version is > 0.12.8 - #29156

## Dependencies
This change relies on the `packaging` and `pkg_resources` libraries to
handle version checks.

## Twitter handle
@tymzar
  • Loading branch information
tymzar authored Jan 13, 2025
1 parent 0e31153 commit 689592f
Showing 1 changed file with 23 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,14 @@

from copy import deepcopy
from enum import Enum
from importlib.metadata import version
from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence

from langchain.retrievers.document_compressors.base import BaseDocumentCompressor
from langchain_core.callbacks.manager import Callbacks
from langchain_core.documents import Document
from langchain_core.utils import get_from_dict_or_env
from packaging.version import Version
from pydantic import ConfigDict, Field, PrivateAttr, model_validator

if TYPE_CHECKING:
Expand Down Expand Up @@ -49,6 +51,10 @@ def validate_environment(cls, values: Dict) -> Any:
if not values.get("client"):
client_name = values.get("model", "zephyr")

is_pre_rank_llm_revamp = Version(version=version("rank_llm")) <= Version(
"0.12.8"
)

try:
model_enum = ModelType(client_name.lower())
except ValueError:
Expand All @@ -58,15 +64,29 @@ def validate_environment(cls, values: Dict) -> Any:

try:
if model_enum == ModelType.VICUNA:
from rank_llm.rerank.vicuna_reranker import VicunaReranker
if is_pre_rank_llm_revamp:
from rank_llm.rerank.vicuna_reranker import VicunaReranker
else:
from rank_llm.rerank.listwise.vicuna_reranker import (
VicunaReranker,
)

values["client"] = VicunaReranker()
elif model_enum == ModelType.ZEPHYR:
from rank_llm.rerank.zephyr_reranker import ZephyrReranker
if is_pre_rank_llm_revamp:
from rank_llm.rerank.zephyr_reranker import ZephyrReranker
else:
from rank_llm.rerank.listwise.zephyr_reranker import (
ZephyrReranker,
)

values["client"] = ZephyrReranker()
elif model_enum == ModelType.GPT:
from rank_llm.rerank.rank_gpt import SafeOpenai
if is_pre_rank_llm_revamp:
from rank_llm.rerank.rank_gpt import SafeOpenai
else:
from rank_llm.rerank.listwise.rank_gpt import SafeOpenai

from rank_llm.rerank.reranker import Reranker

openai_api_key = get_from_dict_or_env(
Expand Down

0 comments on commit 689592f

Please sign in to comment.