How do I do reranking with this model?
#53, opened by keranli
My understanding is that reranking requires a cross-encoder. How can gte-qwen be used for reranking?
I used the third-party rerankers library, and the results were quite good.
import torch
import rerankers
import langchain_chroma
import langchain_ollama
# Import path may differ depending on your LangChain version
from langchain.retrievers import ContextualCompressionRetriever

class RagServer:
    def __init__(self, model: str, db_dir: str, embed_func: str, res_num=5, reranker_num=25):
        # First stage: Chroma vector store, with embeddings served by Ollama
        vector_store = langchain_chroma.Chroma(persist_directory=db_dir,
            embedding_function=langchain_ollama.OllamaEmbeddings(model=embed_func, num_gpu=0))
        retriever = vector_store.as_retriever(search_type="mmr", search_kwargs={"k": reranker_num, "score_threshold": 0.3})
        self.model = model
        # Second stage: cross-encoder reranker loaded through the rerankers library
        ranker = rerankers.Reranker('Alibaba-NLP/gte-multilingual-reranker-base', 'zh', model_type='cross-encoder', device='cuda:0', dtype=torch.float16,
            model_kwargs={'cache_dir':'cache_dir', 'trust_remote_code': True, 'local_files_only': True},
            tokenizer_kwargs={'cache_dir':'cache_dir', 'local_files_only': True})
        compressor = ranker.as_langchain_compressor(res_num)  # keep the top res_num documents
        self.compression_retriever = ContextualCompressionRetriever(base_compressor=compressor, base_retriever=retriever)
Then call it with self.compression_retriever.invoke(question).
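For anyone unfamiliar with the pattern, here is a rough plain-Python sketch of the retrieve-then-rerank flow that a compression retriever of this kind wraps. It is only an illustration: retrieve_then_rerank, overlap_score and jaccard_score are hypothetical stand-ins I made up, not functions from LangChain or rerankers. In the real setup the first stage is the Chroma retriever and the second stage is the cross-encoder.

```python
from typing import Callable, List, Tuple

Scorer = Callable[[str, str], float]


def _tokens(text: str) -> set:
    # Naive whitespace tokenizer, good enough for a toy example.
    return set(text.lower().split())


def overlap_score(query: str, doc: str) -> float:
    # Hypothetical cheap first-stage score: number of shared tokens.
    return float(len(_tokens(query) & _tokens(doc)))


def jaccard_score(query: str, doc: str) -> float:
    # Hypothetical stand-in for the reranker: Jaccard similarity of token sets.
    q, d = _tokens(query), _tokens(doc)
    union = q | d
    return len(q & d) / len(union) if union else 0.0


def retrieve_then_rerank(
    query: str,
    corpus: List[str],
    first_stage: Scorer,
    reranker: Scorer,
    reranker_num: int = 25,
    res_num: int = 5,
) -> List[Tuple[str, float]]:
    # Stage 1: score the whole corpus cheaply and keep reranker_num candidates.
    candidates = sorted(corpus, key=lambda d: first_stage(query, d), reverse=True)
    candidates = candidates[:reranker_num]
    # Stage 2: rescore only the candidates with the more expensive scorer.
    rescored = [(doc, reranker(query, doc)) for doc in candidates]
    rescored.sort(key=lambda pair: pair[1], reverse=True)
    # Return the res_num best documents together with their reranker scores.
    return rescored[:res_num]


if __name__ == "__main__":
    docs = [
        "how to rerank search results with a cross-encoder",
        "chroma stores embeddings on disk",
        "rerank",
        "ollama serves embedding models locally",
    ]
    for doc, score in retrieve_then_rerank(
        "rerank with cross-encoder", docs, overlap_score, jaccard_score,
        reranker_num=3, res_num=2,
    ):
        print(round(score, 3), doc)
```

The point of the two parameters is the same as in the class: reranker_num controls how many candidates the expensive scorer sees, and res_num controls how many documents are finally returned.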
Note: adjust the num_gpu, device='cuda:0' and 'local_files_only': True parameters to match your own GPU memory and network environment.
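One way to keep those environment-dependent values in one place is a small helper like the sketch below. reranker_settings is a hypothetical function of my own, not part of rerankers or LangChain, and the float16-on-GPU / float32-on-CPU split is just a common choice that I am assuming here. The dtype is returned as a name so the sketch runs without torch installed.

```python
from typing import Any, Dict


def reranker_settings(has_cuda: bool, offline: bool, cache_dir: str = "cache_dir") -> Dict[str, Any]:
    """Hypothetical helper: gather the parameters that depend on the machine."""
    return {
        # Use the first GPU when one is available, otherwise fall back to CPU.
        "device": "cuda:0" if has_cuda else "cpu",
        # Assumption: half precision on GPU, full precision on CPU.
        "dtype_name": "float16" if has_cuda else "float32",
        # local_files_only=True means nothing is downloaded, so the model
        # must already be present in cache_dir.
        "model_kwargs": {
            "cache_dir": cache_dir,
            "trust_remote_code": True,
            "local_files_only": offline,
        },
        "tokenizer_kwargs": {
            "cache_dir": cache_dir,
            "local_files_only": offline,
        },
    }


if __name__ == "__main__":
    settings = reranker_settings(has_cuda=False, offline=True)
    print(settings["device"], settings["dtype_name"])
```

With torch installed, getattr(torch, settings["dtype_name"]) turns the name back into a dtype, and the remaining entries can be passed to rerankers.Reranker in place of the hard-coded values in the class.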