# This module is currently disabled: every statement below is commented out,
# so importing the file defines no names and pulls in no dependencies.
# The code is kept only as a reference; remove the leading "# " from each
# line to re-enable it.
#
# NOTE(review): the special-token keys ("<start_of_text>", "<end_of_text>")
# and the word-final comma token (",</w>") were restored from what looked
# like markup-stripped text (empty-string keys and a bare ","). Confirm them
# against the open_clip tokenizer vocabulary before re-enabling.
#
# import open_clip.tokenizer
# import torch
#
# from modules import sd_hijack_clip, devices
# from modules.shared import opts
#
# tokenizer = open_clip.tokenizer._tokenizer
#
#
# class FrozenOpenCLIPEmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase):
#     def __init__(self, wrapped, hijack):
#         super().__init__(wrapped, hijack)
#
#         self.comma_token = [v for k, v in tokenizer.encoder.items() if k == ',</w>'][0]
#         self.id_start = tokenizer.encoder["<start_of_text>"]
#         self.id_end = tokenizer.encoder["<end_of_text>"]
#         self.id_pad = 0
#
#     def tokenize(self, texts):
#         assert not opts.use_old_emphasis_implementation, 'Old emphasis implementation not supported for Open Clip'
#
#         tokenized = [tokenizer.encode(text) for text in texts]
#
#         return tokenized
#
#     def encode_with_transformers(self, tokens):
#         # set self.wrapped.layer_idx here according to opts.CLIP_stop_at_last_layers
#         z = self.wrapped.encode_with_transformer(tokens)
#
#         return z
#
#     def encode_embedding_init_text(self, init_text, nvpt):
#         ids = tokenizer.encode(init_text)
#         ids = torch.asarray([ids], device=devices.device, dtype=torch.int)
#         embedded = self.wrapped.model.token_embedding.wrapped(ids).squeeze(0)
#
#         return embedded
#
#
# class FrozenOpenCLIPEmbedder2WithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWordsBase):
#     def __init__(self, wrapped, hijack):
#         super().__init__(wrapped, hijack)
#
#         self.comma_token = [v for k, v in tokenizer.encoder.items() if k == ',</w>'][0]
#         self.id_start = tokenizer.encoder["<start_of_text>"]
#         self.id_end = tokenizer.encoder["<end_of_text>"]
#         self.id_pad = 0
#
#     def tokenize(self, texts):
#         assert not opts.use_old_emphasis_implementation, 'Old emphasis implementation not supported for Open Clip'
#
#         tokenized = [tokenizer.encode(text) for text in texts]
#
#         return tokenized
#
#     def encode_with_transformers(self, tokens):
#         d = self.wrapped.encode_with_transformer(tokens)
#         z = d[self.wrapped.layer]
#
#         # Carry the pooled output along on the returned tensor when present.
#         pooled = d.get("pooled")
#         if pooled is not None:
#             z.pooled = pooled
#
#         return z
#
#     def encode_embedding_init_text(self, init_text, nvpt):
#         ids = tokenizer.encode(init_text)
#         ids = torch.asarray([ids], device=devices.device, dtype=torch.int)
#         # Move the ids to the embedding weight's device before the lookup.
#         embedded = self.wrapped.model.token_embedding.wrapped(ids.to(self.wrapped.model.token_embedding.wrapped.weight.device)).squeeze(0)
#
#         return embedded