import { app } from "../../../scripts/app.js";

// Adds a bunch of context menu entries for quickly adding common steps

function addMenuHandler(nodeType, cb) {
	const getOpts = nodeType.prototype.getExtraMenuOptions;
	nodeType.prototype.getExtraMenuOptions = function () {
		const r = getOpts.apply(this, arguments);
		cb.apply(this, arguments);
		return r;
	};
}

function getOrAddVAELoader(node) {
	let vaeNode = app.graph._nodes.find((n) => n.type === "VAELoader");
	if (!vaeNode) {
		vaeNode = addNode("VAELoader", node);
	}
	return vaeNode;
}

function addNode(name, nextTo, options) {
	options = { select: true, shiftY: 0, before: false, ...(options || {}) };
	const node = LiteGraph.createNode(name);
	app.graph.add(node);
	node.pos = [
		options.before ? nextTo.pos[0] - node.size[0] - 30 : nextTo.pos[0] + nextTo.size[0] + 30,
		nextTo.pos[1] + options.shiftY,
	];
	if (options.select) {
		app.canvas.selectNode(node, false);
	}
	return node;
}

app.registerExtension({
	name: "pysssss.QuickNodes",
	async beforeRegisterNodeDef(nodeType, nodeData, app) {
		if (nodeData.input && nodeData.input.required) {
			const keys = Object.keys(nodeData.input.required);
			for (let i = 0; i < keys.length; i++) {
				if (nodeData.input.required[keys[i]][0] === "VAE") {
					addMenuHandler(nodeType, function (_, options) {
						options.unshift({
							content: "Use VAE",
							callback: () => {
								getOrAddVAELoader(this).connect(0, this, i);
							},
						});
					});
					break;
				}
			}
		}

		if (nodeData.name === "KSampler") {
			addMenuHandler(nodeType, function (_, options) {
				options.unshift(
					{
						content: "Add Blank Input",
						callback: () => {
							const imageNode = addNode("EmptyLatentImage", this, { before: true });
							imageNode.connect(0, this, 3);
						},
					},
					{
						content: "Add Hi-res Fix",
						callback: () => {
							const upscaleNode = addNode("LatentUpscale", this);
							this.connect(0, upscaleNode, 0);

							const sampleNode = addNode("KSampler", upscaleNode);

							for (let i = 0; i < 3; i++) {
								const l = this.getInputLink(i);
								if (l) {
									app.graph.getNodeById(l.origin_id).connect(l.origin_slot, sampleNode, i);
								}
							}

							upscaleNode.connect(0, sampleNode, 3);
						},
					},
					{
						content: "Add 2nd Pass",
						callback: () => {
							const upscaleNode = addNode("LatentUpscale", this);
							this.connect(0, upscaleNode, 0);

							const ckptNode = addNode("CheckpointLoaderSimple", this);
							const sampleNode = addNode("KSampler", ckptNode);

							const positiveLink = this.getInputLink(1);
							const negativeLink = this.getInputLink(2);
							const positiveNode = positiveLink
								? app.graph.add(app.graph.getNodeById(positiveLink.origin_id).clone())
								: addNode("CLIPTextEncode");
							const negativeNode = negativeLink
								? app.graph.add(app.graph.getNodeById(negativeLink.origin_id).clone())
								: addNode("CLIPTextEncode");

							ckptNode.connect(0, sampleNode, 0);
							ckptNode.connect(1, positiveNode, 0);
							ckptNode.connect(1, negativeNode, 0);
							positiveNode.connect(0, sampleNode, 1);
							negativeNode.connect(0, sampleNode, 2);
							upscaleNode.connect(0, sampleNode, 3);
						},
					},
					{
						content: "Add Save Image",
						callback: () => {
							const decodeNode = addNode("VAEDecode", this);
							this.connect(0, decodeNode, 0);

							getOrAddVAELoader(decodeNode).connect(0, decodeNode, 1);

							const saveNode = addNode("SaveImage", decodeNode);
							decodeNode.connect(0, saveNode, 0);
						},
					}
				);
			});
		}

		if (nodeData.name === "CheckpointLoaderSimple") {
			addMenuHandler(nodeType, function (_, options) {
				options.unshift({
					content: "Add Clip Skip",
					callback: () => {
						const clipSkipNode = addNode("CLIPSetLastLayer", this);
						const clipLinks = this.outputs[1].links ? this.outputs[1].links.map((l) => ({ ...graph.links[l] })) : [];

						this.disconnectOutput(1);
						this.connect(1, clipSkipNode, 0);

						for (const clipLink of clipLinks) {
							clipSkipNode.connect(0, clipLink.target_id, clipLink.target_slot);
						}
					},
				});
			});
		}

		if (
			nodeData.name === "CheckpointLoaderSimple" ||
			nodeData.name === "CheckpointLoader" ||
			nodeData.name === "CheckpointLoader|pysssss" ||
			nodeData.name === "LoraLoader" ||
			nodeData.name === "LoraLoader|pysssss"
		) {
			addMenuHandler(nodeType, function (_, options) {
				function addLora(type) {
					const loraNode = addNode(type, this);

					const modelLinks = this.outputs[0].links ? this.outputs[0].links.map((l) => ({ ...graph.links[l] })) : [];
					const clipLinks = this.outputs[1].links ? this.outputs[1].links.map((l) => ({ ...graph.links[l] })) : [];

					this.disconnectOutput(0);
					this.disconnectOutput(1);

					this.connect(0, loraNode, 0);
					this.connect(1, loraNode, 1);

					for (const modelLink of modelLinks) {
						loraNode.connect(0, modelLink.target_id, modelLink.target_slot);
					}

					for (const clipLink of clipLinks) {
						loraNode.connect(1, clipLink.target_id, clipLink.target_slot);
					}
				}
				options.unshift(
					{
						content: "Add LoRA",
						callback: () => addLora.call(this, "LoraLoader"),
					},
					{
						content: "Add 🐍 LoRA",
						callback: () => addLora.call(this, "LoraLoader|pysssss"),
					},
					{
						content: "Add Prompts",
						callback: () => {
							const positiveNode = addNode("CLIPTextEncode", this);
							const negativeNode = addNode("CLIPTextEncode", this, { shiftY: positiveNode.size[1] + 30 });

							this.connect(1, positiveNode, 0);
							this.connect(1, negativeNode, 0);
						},
					}
				);
			});
		}
	},
});