File size: 564 Bytes
055d938
 
 
bf718e9
 
 
 
d3e257b
 
 
bf718e9
d3e257b
bf718e9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import torch

# CodeT5 checkpoint used for every rewrite request. The tokenizer and the
# seq2seq model are both built from it once, at import time, and fix_code()
# reads them as module globals.
model_name = "Salesforce/codet5-base"
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_name),
    AutoModelForSeq2SeqLM.from_pretrained(model_name),
)

def fix_code(
    code: str,
    *,
    max_input_length: int = 512,
    max_output_length: int = 512,
    num_beams: int = 4,
) -> str:
    """Ask the module-level CodeT5 model to rewrite *code* and return its output.

    The source is wrapped in a ``"fix Python: "`` prompt, encoded with the
    module-level ``tokenizer``, decoded with beam search by the module-level
    ``model``, and the first generated sequence is returned as text.

    Args:
        code: Source text to send to the model.
        max_input_length: Token cap on the encoded prompt. Longer prompts are
            truncated, so the model never sees the tail of very long inputs
            and the result will only cover the leading part of *code*.
        max_output_length: Maximum length, in tokens, of the generated sequence.
        num_beams: Beam width passed to ``model.generate``.

    Returns:
        The decoded generation with special tokens stripped.
    """
    prompt = f"fix Python: {code}"
    # A single prompt needs no padding. The encoded tensors are moved to
    # whatever device the model lives on. Otherwise a model moved to GPU
    # would receive CPU tensors and generate() would raise a device mismatch.
    inputs = tokenizer(
        prompt,
        return_tensors="pt",
        truncation=True,
        max_length=max_input_length,
    ).to(model.device)
    with torch.no_grad():
        output = model.generate(
            **inputs,
            max_length=max_output_length,
            num_beams=num_beams,
            early_stopping=True,
        )
    return tokenizer.decode(output[0], skip_special_tokens=True)