Source code for common.rerank

import logging
import os
from abc import ABC
from typing import Dict, List

import httpx
import requests
import tenacity

from cicada.common.utils import colorstring

logger = logging.getLogger(__name__)


class Rerank(ABC):
    """Thin HTTP client for a rerank endpoint.

    The client POSTs a query plus candidate documents to
    ``<api_base_url>/rerank`` and returns the ``"results"`` entry of the
    JSON response.
    """

    # Seconds to wait on a single HTTP request. ``requests`` never times out
    # unless told to, which would make the timeout-based retry policy on
    # ``rerank`` unreachable. Override on the class or an instance if needed.
    request_timeout: float = 30.0

    def __init__(
        self,
        api_key: str,
        api_base_url: str = "https://api.siliconflow.cn/v1",
        model_name: str = "BAAI/bge-reranker-v2-m3",
        **model_kwargs,
    ):
        """
        Initialize the Rerank class.

        Args:
            api_key (str): API key for authentication.
            api_base_url (str, optional): Base URL for the rerank API; a
                trailing slash is tolerated. Defaults to
                "https://api.siliconflow.cn/v1".
            model_name (str, optional): Name of the rerank model.
                Defaults to "BAAI/bge-reranker-v2-m3".
            **model_kwargs: Additional model-specific parameters; they are
                merged into every request payload.
        """
        self.api_key = api_key
        # Join with "/" explicitly: os.path.join uses the OS path separator
        # and would produce "...\\rerank" on Windows.
        self.api_base_url = api_base_url.rstrip("/") + "/rerank"
        self.model_name = model_name
        self.model_kwargs = model_kwargs

    @tenacity.retry(
        # Stop after 3 attempts or 30 seconds, whichever comes first.
        stop=tenacity.stop_after_attempt(3) | tenacity.stop_after_delay(30),
        wait=tenacity.wait_random_exponential(multiplier=1, min=2, max=10),
        # The request is sent with ``requests``, so its timeout exceptions are
        # the ones that must be matched; the httpx types are kept so the
        # policy stays a superset of the previous one.
        retry=tenacity.retry_if_exception_type(
            (
                requests.exceptions.Timeout,
                httpx.ReadTimeout,
                httpx.ConnectTimeout,
            )
        ),
        # Log before each retry sleep.
        before_sleep=tenacity.before_sleep_log(logger, logging.WARNING),
        reraise=True,
    )
    def rerank(
        self,
        query: str,
        documents: List[str],
        top_n: int = 4,
        return_documents: bool = False,
    ) -> List[Dict]:
        """
        Rerank a list of documents based on a query.

        Args:
            query (str): The query to rerank documents against.
            documents (List[str]): List of documents to rerank.
            top_n (int, optional): Number of top documents to return.
                Defaults to 4.
            return_documents (bool, optional): Whether to return the full
                documents or just scores. Defaults to False.

        Returns:
            List[Dict]: The ``"results"`` entry of the JSON response.

        Raises:
            requests.exceptions.RequestException: If the request fails or
                the server answers with an error status. Timeouts are
                retried first and re-raised once retries are exhausted.
            KeyError: If the response JSON has no ``"results"`` key.
        """
        payload = {
            "model": self.model_name,
            "query": query,
            "documents": documents,
            "top_n": top_n,
            "return_documents": return_documents,
            **self.model_kwargs,
        }
        headers = {
            "Authorization": f"Bearer {self.api_key}",
            "Content-Type": "application/json",
        }

        # Never write the real API key to the logs, even at debug level.
        loggable_headers = {**headers, "Authorization": "Bearer ***"}
        logger.debug(colorstring(f"Payload: {payload}", "blue"))
        logger.debug(colorstring(f"Headers: {loggable_headers}", "blue"))

        try:
            response = requests.post(
                self.api_base_url,
                json=payload,
                headers=headers,
                timeout=self.request_timeout,
            )
            response.raise_for_status()
            return response.json()["results"]
        except requests.exceptions.RequestException as e:
            logger.error(colorstring(f"Failed to rerank documents: {e}", "red"))
            raise
if __name__ == "__main__":
    # Manual smoke test: load the "rerank" section of a YAML config, build a
    # client from it and rerank a tiny fixed document set.
    import argparse

    from cicada.common.utils import colorstring, load_config, setup_logging

    cli = argparse.ArgumentParser(description="Reranking Model")
    cli.add_argument(
        "--config", default="config.yaml", help="Path to the configuration YAML file"
    )
    cli_args = cli.parse_args()
    setup_logging()

    settings = load_config(cli_args.config, "rerank")

    # The API key is mandatory; every other setting falls back to a default.
    key = settings["api_key"]
    base_url = settings.get("api_base_url", "https://api.siliconflow.cn/v1/")
    model = settings.get("model_name", "BAAI/bge-reranker-v2-m3")
    extra_kwargs = settings.get("model_kwargs", {})

    client = Rerank(
        api_key=key,
        api_base_url=base_url,
        model_name=model,
        **extra_kwargs,
    )

    sample_query = "Apple"
    sample_documents = ["苹果", "香蕉", "水果", "蔬菜"]
    ranking = client.rerank(
        sample_query,
        sample_documents,
        top_n=4,
        return_documents=False,
    )
    logger.info(colorstring(f"Reranked results: {ranking}", "white"))