import torch as th

def get_score(input_embs, label_ids, model_control, t=None):
    """Score embeddings under the control model.

    Masks the first 65 positions of `label_ids` to -100 (CrossEntropyLoss's
    ignore_index), runs `model_control`, and returns the per-example sum of
    token-level cross-entropy losses over the unmasked tail.

    Returns:
        A Python list with one summed loss per batch element.
    """
    masked_labels = label_ids.clone()
    masked_labels[:, :65] = -100  # -100 == ignore_index: prefix is not scored
    model_out = model_control(input_embs=input_embs,
                              labels=masked_labels, t=t)
    print(model_out.loss, 'final end')

    # Shift for next-token prediction: logits at position i predict token i+1.
    logits = model_out.logits[:, :-1].contiguous()
    targets = masked_labels[:, 1:].contiguous()
    criterion = th.nn.CrossEntropyLoss(reduction='none')
    per_token = criterion(logits.view(-1, logits.size(-1)), targets.view(-1))
    per_token = per_token.view(targets.shape)
    return per_token.sum(dim=-1).tolist()


def langevin_fn3(debug_lst, model_control, model3, label_ids, step_size, sample, mean, sigma,
                 alpha, t, prev_sample):  # current best.
    """Controlled-denoising step: nudge the latent toward low classifier loss.

    Runs up to K Adagrad updates on `sample` so that `model_control` assigns
    low loss to the (latent prefix, embedded target suffix) sequence, while a
    quadratic term keeps the latent close to the diffusion posterior mean.

    Args:
        debug_lst: unused here; kept for interface parity with langevin_fn1.
        model_control: classifier taking (input_embs, labels, t), exposing .loss.
        model3: embedding module mapping token ids -> embeddings.
        label_ids: token ids; columns past sample.size(1) are the target span.
        step_size: Adagrad learning rate for the inner updates.
        sample: current diffusion latent being optimized.
        mean, sigma: diffusion posterior mean/variance anchoring the prior term.
        alpha, prev_sample: unused; kept for interface compatibility.
        t: per-batch timestep tensor; t[0] selects the schedule.

    Returns:
        The updated latent tensor (detached data of the optimized parameter).
    """
    # Skip the inner optimization near the end of sampling (t < 10).
    K = 0 if t[0].item() < 10 else 3

    # Classifier timestep is one behind the sampler, wrapping to 200 at t == 0.
    tt = t[0].item() - 1 if t[0].item() > 0 else 200

    label_ids = label_ids.cuda()
    # Embed only the target suffix; the prefix comes from the latent itself.
    tgt_embs = model3(label_ids[:, sample.size(1):])

    label_ids2 = label_ids.clone()
    label_ids2[:, :65] = -100  # ignore the prefix in the classifier loss

    input_embs_param = th.nn.Parameter(sample)
    with th.enable_grad():
        for i in range(K):
            # Recreate the optimizer each step: the Parameter object is rebuilt
            # at the end of the loop body, so stale Adagrad state would point
            # at the previous tensor.
            optimizer = th.optim.Adagrad([input_embs_param], lr=step_size)
            optimizer.zero_grad()
            input_embs = th.cat([input_embs_param, tgt_embs], dim=1)
            model_out = model_control(input_embs=input_embs,
                                      labels=label_ids2, t=tt)

            coef = 0.01  # weight of the Gaussian prior term
            if sigma.mean() == 0:
                logp_term = coef * ((mean - input_embs_param) ** 2 / 1.).mean(dim=0).sum()
            else:
                logp_term = coef * ((mean - input_embs_param) ** 2 / sigma).mean(dim=0).sum()
            loss = model_out.loss + logp_term
            loss.backward()
            optimizer.step()
            # Langevin noise term; currently disabled via the 0.0 multiplier.
            epsilon = th.randn_like(input_embs_param.data)
            input_embs_param = th.nn.Parameter(
                (input_embs_param.data + 0.0 * sigma.mean().item() * epsilon).detach())

    return input_embs_param.data

def langevin_fn4(debug_lst, model_control, model3, label_ids, step_size, sample, mean, sigma,
                 alpha, t, prev_sample): # current best.
    """Controlled-denoising step where the latent is classified directly.

    Like langevin_fn3, but the whole latent is fed to `model_control` with
    `label_ids` passed as `pos_ids` (no target-embedding suffix is appended).

    Args:
        debug_lst, model3, alpha, prev_sample: unused; kept for interface parity.
        model_control: classifier taking (input_embs, pos_ids, t), exposing .loss.
        label_ids: control ids, moved to CUDA and used as pos_ids.
        step_size: Adagrad learning rate for the inner updates.
        sample: current diffusion latent being optimized.
        mean, sigma: diffusion posterior statistics anchoring the prior term.
        t: per-batch timestep tensor.

    Returns:
        The updated latent tensor (detached).
    """
    # Skip the inner optimization near the end of sampling (t < 10).
    K = 0 if t[0].item() < 10 else 3

    # Classifier timestep is one behind the sampler, wrapping to 200 at t == 0.
    tt = t[0].item() - 1 if t[0].item() > 0 else 200

    label_ids = label_ids.cuda()
    input_embs_param = th.nn.Parameter(sample)
    with th.enable_grad():
        for i in range(K):
            # Optimizer is rebuilt every step because the Parameter is rebuilt below.
            optimizer = th.optim.Adagrad([input_embs_param], lr=step_size)
            optimizer.zero_grad()
            model_out = model_control(input_embs=input_embs_param, pos_ids=label_ids, t=tt)

            coef = 0.0001  # prev default.
            if sigma.mean() == 0:
                logp_term = coef * ((mean - input_embs_param) ** 2 / 1.).mean(dim=0).sum()
            else:
                logp_term = coef * ((mean - input_embs_param)**2 / sigma).mean(dim=0).sum()
            print(model_out.loss, f'start_{i}', logp_term.item(),
                  t[0].item(), sigma.mean().item())
            loss = model_out.loss + logp_term
            loss.backward()
            optimizer.step()
            # Langevin noise term; currently disabled via the 0.0 multiplier.
            epsilon = th.randn_like(input_embs_param.data)
            input_embs_param = th.nn.Parameter(
                (input_embs_param.data + 0.0 * sigma.mean().item() * epsilon).detach())

    model_out = model_control(input_embs=input_embs_param, pos_ids=label_ids, t=tt)
    print(model_out.loss, 'end')

    return input_embs_param.data

def langevin_fn_length(coeff, diffusion, partial_mask, diff_model, tgt_embs, step_size, sample, mean, sigma,
                 alpha, t, prev_sample): # current best.
    """Langevin-style infilling step.

    Optimizes the latent so the model's predicted x_0 matches `tgt_embs` on
    the fixed positions (where `partial_mask` is False), plus a quadratic
    prior toward the diffusion posterior mean weighted by `coeff`.

    Args:
        coeff: weight of the Gaussian prior term.
        diffusion: diffusion process exposing p_mean_variance.
        partial_mask: boolean mask; True marks positions free to change.
        diff_model: denoising model passed to p_mean_variance.
        tgt_embs: target embeddings enforced where ~partial_mask.
        step_size: Adagrad learning rate for the inner updates.
        sample: current diffusion latent being optimized.
        mean, sigma: diffusion posterior statistics anchoring the prior term.
        alpha, prev_sample: unused; kept for interface parity.
        t: per-batch timestep tensor.

    Returns:
        The updated latent tensor (detached).
    """
    # Skip the inner optimization near the end of sampling (t < 10).
    K = 0 if t[0].item() < 10 else 3

    input_embs_param = th.nn.Parameter(sample)
    with th.enable_grad():
        for i in range(K):
            # Optimizer is rebuilt every step because the Parameter is rebuilt below.
            optimizer = th.optim.Adagrad([input_embs_param], lr=step_size)
            optimizer.zero_grad()
            # Predict x_0 from the current latent at timestep t.
            out = diffusion.p_mean_variance(
                diff_model,
                input_embs_param,
                t,
                clip_denoised=False,
                denoised_fn=None,
                model_kwargs={},
            )

            coef = coeff
            if sigma.mean() == 0:
                logp_term = coef * ((mean - input_embs_param) ** 2 / 1.).mean(dim=0).sum()
                infill_loss = (out['pred_xstart'][~partial_mask] - tgt_embs[~partial_mask]) ** 2
                infill_loss = infill_loss.mean(dim=0).sum()
            else:
                logp_term = coef * ((mean - input_embs_param)**2 / sigma).mean(dim=0).sum()
                # Boolean indexing flattens the selection; restore a
                # (batch, -1, dim) shape so the mean over dim=0 matches the
                # reduction used by the prior term.
                infill_loss = ((out['pred_xstart'][~partial_mask] - tgt_embs[~partial_mask]) ** 2).view(tgt_embs.size(0), -1, tgt_embs.size(-1))
                infill_loss = (infill_loss / sigma.mean()).mean(dim=0).sum()
            print(infill_loss, f'start_{i}', logp_term.item(),
                  t[0].item(), sigma.mean().item())
            loss = logp_term + infill_loss
            loss.backward()
            optimizer.step()
            # Langevin noise term; currently disabled via the 0.0 multiplier.
            epsilon = th.randn_like(input_embs_param.data)
            input_embs_param = th.nn.Parameter(
                (input_embs_param.data + 0.0 * sigma.mean().item() * epsilon).detach())

    return input_embs_param.data

def langevin_fn_tree(coeff, model_control, model3, label_ids, step_size, sample, mean, sigma,
                 alpha, t, prev_sample): # current best.
    """Controlled-denoising step guided by a syntax-tree classifier.

    Performs a few Adagrad updates on the latent so that `model_control`
    (which consumes the latent together with a parse chart) assigns it low
    loss, regularized toward the diffusion posterior mean with weight `coeff`.
    `model3`, `alpha` and `prev_sample` are unused and kept for interface
    parity with the sibling langevin_fn* functions.
    """
    # No inner optimization close to the end of sampling.
    num_steps = 3 if t[0].item() >= 10 else 0

    # Classifier timestep: one behind the sampler, wrapping to 200 at t == 0.
    if t[0].item() > 0:
        tt = t[0].item() - 1
    else:
        tt = 200

    label_ids = label_ids.cuda()
    input_embs_param = th.nn.Parameter(sample)

    with th.enable_grad():
        for step in range(num_steps):
            # Parameter object is rebuilt each iteration, so the optimizer is too.
            opt = th.optim.Adagrad([input_embs_param], lr=step_size)
            opt.zero_grad()
            model_out = model_control(input_embs=input_embs_param, parse_chart=label_ids, t=tt)

            # Gaussian prior toward the posterior mean; weight chosen by caller.
            sq_dist = (mean - input_embs_param) ** 2
            if sigma.mean() == 0:
                logp_term = coeff * (sq_dist / 1.).mean(dim=0).sum()
            else:
                logp_term = coeff * (sq_dist / sigma).mean(dim=0).sum()
            total = model_out.loss + logp_term
            total.backward()
            opt.step()
            # Langevin noise term, disabled via the 0.0 factor; the randn call
            # is kept so the RNG stream matches the original implementation.
            noise = th.randn_like(input_embs_param.data)
            input_embs_param = th.nn.Parameter(
                (input_embs_param.data + 0.0 * sigma.mean().item() * noise).detach())

    return input_embs_param.data

def langevin_fn1(debug_lst, model_control, model3, label_ids, step_size, sample, mean, sigma,
                 alpha, t, prev_sample):  # current best.
    """One-step controlled update with a hand-rolled gradient step.

    Differs from langevin_fn3 in three ways: K is at most 1, there is no
    Gaussian prior term, and instead of stepping Adagrad the latent is
    updated in-place by `- 3. * sigma.mean() * grad`.  Pre-update scores
    are pushed onto `debug_lst`.  `alpha` and `prev_sample` are unused.
    """
    # At most one inner step; skipped entirely near the end of sampling.
    if t[0].item() < 10:
        K = 0
    else:
        K = 1
    # K = 3

    # Classifier timestep: one behind the sampler, wrapping to 200 at t == 0.
    if t[0].item() > 0:
        tt = t[0].item() - 1
    else:
        tt = 200
    label_ids = label_ids.cuda()
    # Embed the target suffix (everything past the latent's sequence length).
    tgt_embs = model3(label_ids[:, sample.size(1):])

    label_ids2 = label_ids.clone()
    label_ids2[:, :65] = -100  # -100 == ignore_index: prefix is not scored
    input_embs_param = th.nn.Parameter(sample)
    # Record pre-update per-example scores for debugging.
    if True:
        input_embs = th.cat([input_embs_param, tgt_embs], dim=1)
        debug_lst.append(get_score(input_embs, label_ids2, model_control, t=tt))
    with th.enable_grad():
        for i in range(K):
            # NOTE: optimizer is created and zeroed but never stepped — the
            # update below is applied manually to .data instead.
            optimizer = th.optim.Adagrad([input_embs_param], lr=step_size)
            optimizer.zero_grad()
            input_embs = th.cat([input_embs_param, tgt_embs], dim=1)
            model_out = model_control(input_embs=input_embs,
                                      labels=label_ids2, t=tt)

            # coef = 0.0
            # if sigma.mean() == 0:
            #     logp_term = coef * ((mean - input_embs_param) ** 2 / 1.).mean(dim=0).sum()
            # else:
            #     logp_term = coef * ((mean - input_embs_param) ** 2 / sigma).mean(dim=0).sum()
            print(model_out.loss, f'start_{i}', t[0].item(), sigma.mean().item())
            coef = 3.
            loss = model_out.loss # + logp_term
            loss.backward()
            # print(input_embs_param.grad.shape, )
            # Manual gradient step, scaled by the current diffusion noise level.
            input_embs_param.data = input_embs_param.data - coef * sigma.mean().item() * input_embs_param.grad
            # optimizer.step()
            # epsilon = th.randn_like(input_embs_param.data)
            # input_embs_param = th.nn.Parameter((input_embs_param.data + 0.0 * sigma.mean().item() * epsilon).detach())
            # input_embs_param = th.nn.Parameter((input_embs_param.data +
            #                                    np.sqrt(2*sigma.mean().item()) * epsilon).detach())

    # Evaluate once more after the update and report the final loss.
    input_embs = th.cat([input_embs_param, tgt_embs], dim=1)
    model_out = model_control(input_embs=input_embs,
                              labels=label_ids2,
                              t=tt)
    print(model_out.loss, 'end')
    # if True:
    #     debug_lst.append(get_score(input_embs, label_ids2, model_control, t=tt))

    return input_embs_param.data


def langevin_fn3_compose(debug_lst, model_control, model3, label_ids_lst, step_size, sample, mean, sigma,
                 alpha, t, prev_sample):  # current best.
    """Compositional variant of langevin_fn3.

    Optimizes a single latent against several control targets at once: each
    entry of `label_ids_lst` contributes its own target embeddings and masked
    labels, and the per-target classifier losses are summed before the
    Gaussian prior term is added.  Pre-optimization scores are appended to
    `debug_lst` as a list of per-target score lists.
    """
    if t[0].item() < 10:
        num_steps = 0
    else:
        num_steps = 3

    if t[0].item() > 0:
        tt = t[0].item() - 1
    else:
        tt = 200

    prefix_len = sample.size(1)
    # Target-span embeddings for every control sequence.
    tgt_embs_lst = [model3(ids[:, prefix_len:]) for ids in label_ids_lst]

    # Labels with the prefix masked out (-100 == ignore_index).
    label_ids2_lst = []
    for ids in label_ids_lst:
        masked = ids.clone()
        masked[:, :65] = -100
        label_ids2_lst.append(masked)

    input_embs_param = th.nn.Parameter(sample)

    # Record pre-optimization scores, one score list per control target.
    debug_lst.append([
        get_score(th.cat([input_embs_param, embs], dim=1), masked, model_control, t=tt)
        for embs, masked in zip(tgt_embs_lst, label_ids2_lst)
    ])

    with th.enable_grad():
        for i in range(num_steps):
            # Parameter object is rebuilt each iteration, so the optimizer is too.
            optimizer = th.optim.Adagrad([input_embs_param], lr=step_size)
            optimizer.zero_grad()
            cum_loss = 0
            for embs, masked in zip(tgt_embs_lst, label_ids2_lst):
                joint = th.cat([input_embs_param, embs], dim=1)
                out = model_control(input_embs=joint, labels=masked, t=tt)
                cum_loss += out.loss

            coef = 0.01  # weight of the Gaussian prior term
            sq_dist = (mean - input_embs_param) ** 2
            if sigma.mean() == 0:
                logp_term = coef * (sq_dist / 1.).mean(dim=0).sum()
            else:
                logp_term = coef * (sq_dist / sigma).mean(dim=0).sum()
            print(cum_loss, f'start_{i}', logp_term.item(), t[0].item(), sigma.mean().item())
            total = cum_loss + logp_term
            total.backward()
            optimizer.step()
            # Langevin noise term, disabled via the 0.0 factor; the randn call
            # is kept so the RNG stream matches the original implementation.
            noise = th.randn_like(input_embs_param.data)
            input_embs_param = th.nn.Parameter(
                (input_embs_param.data + 0.0 * sigma.mean().item() * noise).detach())

    # Final scores are computed (for their printed side effects) but the
    # results are discarded, exactly as in the original implementation.
    # NOTE(review): possibly intended to be appended to debug_lst — confirm.
    for embs, masked in zip(tgt_embs_lst, label_ids2_lst):
        get_score(th.cat([input_embs_param, embs], dim=1), masked, model_control, t=tt)

    return input_embs_param.data