import torch
from transformers import DistilBertTokenizer, DistilBertModel
from torch.nn.functional import cosine_similarity


class FastLightweightTextEncoder:
    def __init__(self, model_name='distilbert-base-uncased', cache_dir='/path/to/your/cache'):
        # Load tokenizer and encoder from the given cache directory and move the model to GPU
        self.tokenizer = DistilBertTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
        self.text_encoder = DistilBertModel.from_pretrained(model_name, cache_dir=cache_dir).eval().cuda()
    def encode_texts(self, prompts):
        # Batch-process the prompts to get their token-level embeddings
        inputs = self.tokenizer(prompts, return_tensors="pt", padding=True, truncation=True)
        input_ids = inputs['input_ids'].cuda()
        attention_mask = inputs['attention_mask'].cuda()
        with torch.no_grad():
            embeddings = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask).last_hidden_state
        # Normalize embeddings to get consistent vector representations
        embeddings = torch.nn.functional.normalize(embeddings, p=2, dim=-1)
        # Shape: (num_prompts, sequence_length, hidden_size)
        return embeddings
    def calculate_differences(self, embeddings):
        # Calculate differences between consecutive embeddings
        differences = []
        for i in range(1, embeddings.size(0)):
            diff = embeddings[i] - embeddings[i - 1]  # Shape: (sequence_length, hidden_size)
            differences.append(diff.unsqueeze(0))  # Add batch dimension
        # Concatenate differences along the batch dimension (f-1)
        concatenated_differences = torch.cat(differences, dim=0)  # Shape: (f-1, sequence_length, hidden_size)
        return concatenated_differences
# Example usage
if __name__ == '__main__':
    prompts = [
        "A smiling dog. Focal length: 24mm.",
        "A smiling dog. Focal length: 25mm.",
        "A smiling dog. Focal length: 26mm.",
        "A smiling dog. Focal length: 30mm.",
        "A smiling dog. Focal length: 36mm.",
    ]

    # Initialize the FastLightweightTextEncoder
    text_encoder = FastLightweightTextEncoder(cache_dir='/home/yuan418/lab/users/Yu/modules/')

    # Encode the prompts
    embeddings = text_encoder.encode_texts(prompts)
    print('embeddings', embeddings)
    print('embeddings shape:', embeddings.shape)

    # Calculate and concatenate differences between consecutive prompt embeddings
    concatenated_diffs = text_encoder.calculate_differences(embeddings)
    print("Concatenated differences shape:", concatenated_diffs.shape)