How do you do reranking with this model?

#53
by keranli - opened

My understanding is that reranking requires a cross-encoder. How would gte-qwen, an embedding model, do reranking?
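For context on the distinction in the question: a bi-encoder embedding model such as gte-Qwen encodes the query and each document separately. On its own, the only "reranking" it can do is reorder candidates by embedding similarity. A cross-encoder instead reads the query and the document together and outputs one relevance score per pair. The sketch below shows only the bi-encoder side. The hand-made 3-d vectors stand in for real gte-Qwen embeddings, and the function name rerank_by_cosine is hypothetical.

```python
import math
from typing import Dict, List, Tuple


def cosine(a: List[float], b: List[float]) -> float:
    """Cosine similarity between two equal-length vectors."""
    dot = sum(x * y for x, y in zip(a, b))
    norm_a = math.sqrt(sum(x * x for x in a))
    norm_b = math.sqrt(sum(y * y for y in b))
    return dot / (norm_a * norm_b)


def rerank_by_cosine(query_vec: List[float],
                     doc_vecs: Dict[str, List[float]]) -> List[Tuple[str, float]]:
    """Bi-encoder style: compare the query embedding with pre-computed doc embeddings."""
    scored = [(doc_id, cosine(query_vec, vec)) for doc_id, vec in doc_vecs.items()]
    return sorted(scored, key=lambda item: item[1], reverse=True)


# Toy vectors standing in for real embeddings (hypothetical values).
query = [1.0, 0.0, 1.0]
docs = {"a": [0.0, 1.0, 0.0], "b": [1.0, 0.1, 0.9], "c": [0.5, 0.5, 0.5]}
print([doc_id for doc_id, _ in rerank_by_cosine(query, docs)])  # → ['b', 'c', 'a']
```

A cross-encoder would replace cosine(query_vec, vec) with a model call on the raw (query, document) text pair. Because that score depends on both texts, document vectors cannot be precomputed, so cross-encoders are usually applied only to a short candidate list.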

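I used the third-party rerankers library together with a separate cross-encoder reranker model, and the results were quite good.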

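import torch
import rerankers
import langchain_chroma
import langchain_ollama
from langchain.retrievers import ContextualCompressionRetriever

class RagServer:
    def __init__(self, model: str, db_dir: str, embed_func: str, res_num=5, reranker_num=25):
        # embed_func is the Ollama embedding model name; num_gpu=0 keeps the embedding model on the CPU
        vector_store = langchain_chroma.Chroma(persist_directory=db_dir,
                                               embedding_function=langchain_ollama.OllamaEmbeddings(model=embed_func, num_gpu=0))
        # MMR picks k results out of a pool of fetch_k candidates (default 20), so fetch_k must be >= k.
        # score_threshold is only used by search_type="similarity_score_threshold", so it is dropped here.
        retriever = vector_store.as_retriever(search_type="mmr", search_kwargs={"k": reranker_num, "fetch_k": reranker_num * 2})
        self.model = model
        # Cross-encoder reranker, loaded from the local Hugging Face cache in cache_dir
        ranker = rerankers.Reranker('Alibaba-NLP/gte-multilingual-reranker-base', 'zh', model_type='cross-encoder', device='cuda:0', dtype=torch.float16,
                                    model_kwargs={'cache_dir': 'cache_dir', 'trust_remote_code': True, 'local_files_only': True},
                                    tokenizer_kwargs={'cache_dir': 'cache_dir', 'local_files_only': True})
        # Keep only the top res_num documents after reranking
        compressor = ranker.as_langchain_compressor(res_num)
        self.compression_retriever = ContextualCompressionRetriever(base_compressor=compressor, base_retriever=retriever)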
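The retriever uses search_type="mmr" (maximal marginal relevance). MMR first pulls a pool of fetch_k nearest candidates, then greedily picks k of them, trading relevance to the query against redundancy with documents already picked (weighted by lambda_mult). In LangChain's Chroma integration the final picks are drawn from that pool, so a k larger than fetch_k cannot return more than fetch_k documents. The following is a minimal pure-Python sketch of the greedy selection. It assumes cosine similarity and uses toy 2-d vectors; it illustrates the technique and is not LangChain's actual implementation.

```python
import math
from typing import List


def cosine(a: List[float], b: List[float]) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b)))


def unit(deg: float) -> List[float]:
    """2-d unit vector at the given angle, used to build toy embeddings."""
    r = math.radians(deg)
    return [math.cos(r), math.sin(r)]


def mmr_select(query: List[float], candidates: List[List[float]],
               k: int, lambda_mult: float = 0.5) -> List[int]:
    """Greedy MMR over a candidate pool; returns the chosen indices in pick order.

    score(d) = lambda_mult * sim(query, d) - (1 - lambda_mult) * max_{s in selected} sim(d, s)
    """
    k = min(k, len(candidates))  # can never return more than the pool holds
    selected: List[int] = []
    while len(selected) < k:
        best_idx, best_score = -1, -math.inf
        for i, cand in enumerate(candidates):
            if i in selected:
                continue
            relevance = cosine(query, cand)
            redundancy = max((cosine(cand, candidates[j]) for j in selected), default=0.0)
            score = lambda_mult * relevance - (1 - lambda_mult) * redundancy
            if score > best_score:
                best_idx, best_score = i, score
        selected.append(best_idx)
    return selected


query = unit(0)
pool = [unit(10), unit(11), unit(-40)]  # candidates 0 and 1 are near-duplicates
print(mmr_select(query, pool, k=2))       # → [0, 2]
print(len(mmr_select(query, pool, k=25)))  # → 3
```

Ranking by relevance alone would return candidates 0 and 1. MMR skips the near-duplicate and picks candidate 2 instead. The second call shows that a pool of 3 caps the result at 3, which mirrors why fetch_k should be at least as large as k.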

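Then call it with self.compression_retriever.invoke(question), which returns the top res_num documents after reranking.
Note: adjust num_gpu, device='cuda:0', and 'local_files_only': True to match your GPU memory and network environment. num_gpu=0 runs the Ollama embedding model on the CPU. device selects the GPU for the reranker. 'local_files_only': True assumes the reranker weights have already been downloaded to cache_dir.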
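Conceptually, ContextualCompressionRetriever with a reranking compressor runs a two-stage pattern. The base retriever returns reranker_num candidates, and the compressor scores each (question, document) pair and keeps the top res_num. Below is a library-free sketch of that flow. It makes two assumptions. toy_pair_score is a hypothetical stand-in for the cross-encoder that just measures word overlap. first_stage is a hypothetical stand-in for the vector-store retriever.

```python
from typing import Callable, List


def toy_pair_score(query: str, doc: str) -> float:
    """Hypothetical stand-in for a cross-encoder: fraction of query words found in the doc."""
    q_words = set(query.lower().split())
    d_words = set(doc.lower().split())
    return len(q_words & d_words) / len(q_words) if q_words else 0.0


def retrieve_then_rerank(query: str,
                         first_stage: Callable[[str, int], List[str]],
                         pair_score: Callable[[str, str], float],
                         reranker_num: int = 25,
                         res_num: int = 5) -> List[str]:
    # Stage 1: cheap, recall-oriented candidate retrieval.
    candidates = first_stage(query, reranker_num)
    # Stage 2: score each (query, doc) pair, then keep the best res_num (stable on ties).
    ranked = sorted(candidates, key=lambda d: pair_score(query, d), reverse=True)
    return ranked[:res_num]


corpus = ["gte reranker base model", "chroma vector store",
          "qwen embedding model for retrieval", "cooking recipes"]


def first_stage(query: str, k: int) -> List[str]:
    # Hypothetical: a real system would query the vector store here.
    return corpus[:k]


print(retrieve_then_rerank("qwen embedding retrieval", first_stage, toy_pair_score,
                           reranker_num=3, res_num=2))
# → ['qwen embedding model for retrieval', 'gte reranker base model']
```

Keeping reranker_num several times larger than res_num lets the reranker promote documents that the first stage ranked low. The cost is one model call per candidate.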
