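# NOTE: This module is currently disabled. Every line below is commented out,
# so importing this file has no effect.
#
# When enabled, add() monkey-patches the forward() methods of ldm's
# BasicTransformerBlock, ResBlock and AttentionBlock so that each call runs
# self._forward through torch.utils.checkpoint.checkpoint. This is activation
# checkpointing: intermediate activations are recomputed during the backward
# pass instead of being kept in memory. The original forward methods are
# saved in `stored` in a fixed order (BasicTransformerBlock, ResBlock,
# AttentionBlock), and remove() restores them using that same order. add()
# returns early if the patch is already applied; remove() returns early if
# nothing is stored.
#
# NOTE(review): recent PyTorch releases warn unless `use_reentrant` is passed
# explicitly to checkpoint(). Confirm the intended mode before re-enabling.
#
# from torch.utils.checkpoint import checkpoint
#
# import ldm.modules.attention
# import ldm.modules.diffusionmodules.openaimodel
#
#
# def BasicTransformerBlock_forward(self, x, context=None):
#     return checkpoint(self._forward, x, context)
#
#
# def AttentionBlock_forward(self, x):
#     return checkpoint(self._forward, x)
#
#
# def ResBlock_forward(self, x, emb):
#     return checkpoint(self._forward, x, emb)
#
#
# stored = []
#
#
# def add():
#     if len(stored) != 0:
#         return
#
#     stored.extend([
#         ldm.modules.attention.BasicTransformerBlock.forward,
#         ldm.modules.diffusionmodules.openaimodel.ResBlock.forward,
#         ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward
#     ])
#
#     ldm.modules.attention.BasicTransformerBlock.forward = BasicTransformerBlock_forward
#     ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = ResBlock_forward
#     ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = AttentionBlock_forward
#
#
# def remove():
#     if len(stored) == 0:
#         return
#
#     ldm.modules.attention.BasicTransformerBlock.forward = stored[0]
#     ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = stored[1]
#     ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = stored[2]
#
#     stored.clear()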