File size: 911 Bytes
bdb1b24
 
 
 
981f30e
bdb1b24
 
 
 
 
 
 
 
 
a494cfe
bdb1b24
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('ku-nlp/deberta-v2-base-japanese')
model=torch.load('C:\\[.pth modelのあるディレクトリ]\\My_deberta_model_squad.pth') # 学習済みモデルの読み込み

text={
    'context':'私の名前はEIMIです。好きな食べ物は苺です。 趣味は皆さんと会話することです。',
    'question' :'好きな食べ物は何ですか'
}

input_ids=tokenizer.encode(text['question'],text['context']) # tokenizerで形態素解析しつつコードに変換する
con=tokenizer.encode(text['question'])
output= model(torch.tensor([input_ids])) # 学習済みモデルを用いて解析
prediction = tokenizer.decode(input_ids[torch.argmax(output.start_logits): torch.argmax(output.end_logits)]) # 答えに該当する部分を抜き取る
prediction=prediction.replace('</s>','')
print(prediction)