from typing import Dict, List, Union
import torch
from transformers import AutoTokenizer, GlmModel
from .base import ProcessorMixin


class CogView4GLMProcessor(ProcessorMixin):
    r"""
    Processor for the GLM family of models. This processor is used to encode text inputs and return the
    embeddings for the input text.

    This processor is specific to CogView4 but can be used with any other model.

    Args:
        output_names (`List[str]`):
            The names of the outputs that the processor should return. Exactly one output name is expected;
            it is used as the key for the text embeddings in the returned dictionary.
    """

    def __init__(self, output_names: List[str]) -> None:
        super().__init__()

        self.output_names = output_names
        assert len(self.output_names) == 1

    def forward(
        self,
        tokenizer: AutoTokenizer,
        text_encoder: GlmModel,
        caption: Union[str, List[str]],
        max_sequence_length: int,
    ) -> Dict[str, torch.Tensor]:
        r"""
        Encode the input text and return the embeddings for the input text.

        Args:
            tokenizer (`AutoTokenizer`):
                The tokenizer used to tokenize the input text.
            text_encoder (`GlmModel`):
                The text encoder used to encode the input text.
            caption (`Union[str, List[str]]`):
                The input text to be encoded.
            max_sequence_length (`int`):
                The maximum sequence length of the input text.

        Returns:
            `Dict[str, torch.Tensor]`:
                A dictionary mapping the single configured output name to the prompt embeddings.
        """
        if isinstance(caption, str):
            caption = [caption]

        device = text_encoder.device
        dtype = text_encoder.dtype

        text_inputs = tokenizer(
            caption,
            padding="longest",
            max_length=max_sequence_length,
            truncation=True,
            add_special_tokens=True,
            return_tensors="pt",
        )
        text_input_ids = text_inputs.input_ids.to(device)

        # Left-pad the token ids so the sequence length is a multiple of 16, as CogView4 expects.
        # The outer `% 16` keeps pad_length at 0 when the length is already aligned; without it,
        # an already-aligned sequence (e.g. length 32) would be padded by a full extra 16 tokens.
        current_length = text_input_ids.size(1)
        pad_length = (16 - current_length % 16) % 16
        if pad_length > 0:
            pad_ids = text_input_ids.new_full((text_input_ids.shape[0], pad_length), fill_value=tokenizer.pad_token_id)
            text_input_ids = torch.cat([pad_ids, text_input_ids], dim=1)

        # Use the second-to-last hidden state as the prompt embedding rather than the final layer output.
        prompt_embeds = text_encoder(text_input_ids, output_hidden_states=True).hidden_states[-2]
        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)

        return {self.output_names[0]: prompt_embeds}
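

# Minimal usage sketch (illustrative only, not part of the module). The checkpoint path below is a
# placeholder: any CogView4-style checkpoint that ships an `AutoTokenizer` and a `GlmModel` text
# encoder in `tokenizer`/`text_encoder` subfolders should work.
#
#     from transformers import AutoTokenizer, GlmModel
#
#     tokenizer = AutoTokenizer.from_pretrained("path/to/cogview4-checkpoint", subfolder="tokenizer")
#     text_encoder = GlmModel.from_pretrained("path/to/cogview4-checkpoint", subfolder="text_encoder")
#
#     processor = CogView4GLMProcessor(output_names=["prompt_embeds"])
#     outputs = processor.forward(tokenizer, text_encoder, caption="a photo of a cat", max_sequence_length=1024)
#     outputs["prompt_embeds"].shape  # (batch_size, seq_len padded to a multiple of 16, hidden_size)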