[docs]classSiliconFlowEmbeddings(Embeddings):"""Embedding class for SiliconFlow BGE-M3 API."""def__init__(self,api_key:str,api_base_url:str,model_name:str,org_id:str=None,**model_kwargs,):""" Initialize the SiliconFlow BGE-M3 embedding model. Args: api_key (str): The API key for SiliconFlow. api_base_url (str): The base URL for the API. model_name (str): The name of the model to use. org_id (str, optional): The organization ID. Defaults to None. **model_kwargs: Additional keyword arguments for the model. """self.api_key=api_keyself.api_base_url=api_base_urlself.model_name=model_nameself.org_id=org_idself.model_kwargs=model_kwargs# Initialize the Embed class from embed.pyself.embed_client=Embed(api_key=self.api_key,api_base_url=self.api_base_url,model_name=self.model_name,org_id=self.org_id,**self.model_kwargs,)
[docs]defembed_query(self,text:str)->List[float]:""" Generate an embedding for a single query text. Args: text (str): The query text to embed. Returns: List[float]: The embedding of the query text. """returnself.embed_documents(text)[0]
[docs]defembed_documents(self,texts:Union[List[str],str,Document,List[Document]])->List[List[float]]:""" Generate embeddings for a list of texts or documents. Args: texts (Union[List[str], str, Document, List[Document]]): The input texts or documents to embed. Returns: List[List[float]]: A list of embeddings for the input texts or documents. Raises: ValueError: If the input list contains mixed types of strings and Documents. TypeError: If the input is not a string, Document, list of strings, or list of Documents. """# Normalize input to a list of stringsifisinstance(texts,str):texts=[texts]elifisinstance(texts,Document):texts=[texts.page_content]elifisinstance(texts,list):ifall(isinstance(text,str)fortextintexts):pass# Already a list of stringselifall(isinstance(text,Document)fortextintexts):texts=[text.page_contentfortextintexts]else:raiseValueError("Input list must contain only strings or only Document instances.")else:raiseTypeError("Input must be a string, Document, list of strings, or list of Documents.")# Generate embeddings using the Embed clientreturnself.embed_client.embed(texts)
if__name__=="__main__":importargparsefromcicada.common.utilsimportload_configparser=argparse.ArgumentParser(description="Feedback Judge")parser.add_argument("--config",default="config.yaml",help="Path to the configuration YAML file")args=parser.parse_args()embed_config=load_config(args.config,"embed")embedding_model=SiliconFlowEmbeddings(embed_config["api_key"],embed_config.get("api_base_url"),embed_config.get("model_name","text-embedding-3-small"),embed_config.get("org_id"),**embed_config.get("model_kwargs",{}),)# Generate embeddings for a list of textstexts=["This is a test document.","Another test document."]embeddings=embedding_model.embed_documents(texts)print(embeddings)# Generate an embedding for a single queryquery_embedding=embedding_model.embed_query("This is a query.")print(query_embedding)