1. Recurrent Neural Network (RNN)
- A neural network well suited to modeling sequence data such as time series or natural language
- Examples: stock prices, text data, audio data
- Sequence: connected information, such as the words of a sentence
![RNN diagram](https://velog.velcdn.com/images/softwarerbfl/post/d6c16b45-ef7d-4f3c-9a4a-efa54d2de587/image.png)
1-1. How an RNN Works
- The defining feature: the hidden layer's activation output is sent on toward the output layer and is also fed back as input to the hidden layer's next computation
- Cell: the unit in the hidden layer that emits the activation result; it acts as a kind of memory, retaining information from previous steps
- Hidden state: the value the cell sends toward the output layer and to itself at the next time step
rnn = torch.nn.RNN(input_size, hidden_size)
outputs, state = rnn(input_data)
# state: hidden state
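The update rule described above can be checked against `nn.RNN`'s own weights. The sketch below (sizes and seed are my own choices, not from the notebook) reproduces one step of the standard Elman recurrence h_t = tanh(W_ih·x_t + b_ih + W_hh·h_{t-1} + b_hh) by hand, using the parameter names `torch.nn.RNN` exposes:

```python
import torch

torch.manual_seed(0)
rnn = torch.nn.RNN(input_size=4, hidden_size=2)

x = torch.randn(1, 1, 4)    # (seq_len=1, batch=1, input_size=4)
h0 = torch.zeros(1, 1, 2)   # initial hidden state

out, state = rnn(x, h0)

# Reproduce the same step by hand with the layer's own weights
manual = torch.tanh(
    x[0] @ rnn.weight_ih_l0.T + rnn.bias_ih_l0
    + h0[0] @ rnn.weight_hh_l0.T + rnn.bias_hh_l0
)
print(torch.allclose(out[0], manual))  # True: the hand computation matches
```
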
1-2. Input Size
- When a word is given as input, each character must first be converted into a vector via one-hot encoding
- "hello"
- h = [1, 0, 0, 0]
- e = [0, 1, 0, 0]
- l = [0, 0, 1, 0]
- o = [0, 0, 0, 1]
- input_size = 4
- Passed as the third dimension of input_data
1-3. Hidden State Size
- The hidden state size is the third dimension of the output
- It is the same as the output size
- The result computed in the cell is split two ways: one copy is emitted as the output, and the other is passed on unchanged to the next step as the hidden state
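Because the output and the hidden state come from the same cell computation, the final hidden state returned by `nn.RNN` equals the last time step of the output sequence. A quick check (sizes and seed are arbitrary, using PyTorch's default seq-first layout):

```python
import torch

torch.manual_seed(0)
rnn = torch.nn.RNN(input_size=4, hidden_size=2)

x = torch.randn(5, 3, 4)   # (seq_len=5, batch=3, input_size=4)
outputs, state = rnn(x)

print(outputs.shape)       # torch.Size([5, 3, 2])
print(state.shape)         # torch.Size([1, 3, 2])
print(torch.allclose(outputs[-1], state[0]))  # True
```
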
1-4. Sequence Length
- Indicates how many steps the sequence contains
- Feeding "hello" as input gives a sequence length of 5
- PyTorch infers the sequence length from the input tensor, so it does not need to be passed as a parameter
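This can be seen by running the same layer on inputs of different lengths — the sequence length is read straight off the input tensor's shape (the sizes below are my own illustration):

```python
import torch

torch.manual_seed(0)
rnn = torch.nn.RNN(input_size=4, hidden_size=2)

for seq_len in (5, 11):
    x = torch.randn(seq_len, 1, 4)  # (seq_len, batch=1, input_size=4)
    outputs, state = rnn(x)
    print(outputs.shape)            # seq_len carries through to the output
```
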
1-5. Batch Size
- Several samples are grouped into one batch for training
- The sequences built from h, e, l, o are grouped into batches of the chosen size for training
- The model infers the batch size from the data as well; note that nn.RNN puts the batch in the first dimension of the input and output only when constructed with batch_first=True — its default layout is (seq_len, batch, input_size)
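The difference between the two layouts is easy to see on one tensor (sizes and seed are arbitrary). With the default layout the first dimension is read as the sequence length; with `batch_first=True` it is read as the batch:

```python
import torch

torch.manual_seed(0)
x = torch.randn(3, 5, 4)

# Default: input interpreted as (seq_len=3, batch=5, input_size=4)
rnn_default = torch.nn.RNN(input_size=4, hidden_size=2)
out1, st1 = rnn_default(x)
print(out1.shape, st1.shape)  # torch.Size([3, 5, 2]) torch.Size([1, 5, 2])

# batch_first: input interpreted as (batch=3, seq_len=5, input_size=4)
rnn_bf = torch.nn.RNN(input_size=4, hidden_size=2, batch_first=True)
out2, st2 = rnn_bf(x)
print(out2.shape, st2.shape)  # torch.Size([3, 5, 2]) torch.Size([1, 3, 2])
```

The output shape happens to match in both cases here, but the hidden-state shape reveals which dimension was treated as the batch.
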
import torch
import numpy as np
from torch.nn import RNN
input_size = 4
hidden_size = 2
h = [1, 0, 0, 0]
e = [0, 1, 0, 0]
l = [0, 0, 1, 0]
o = [0, 0, 0, 1]
input_data_np = np.array([[h, e, l, l, o],
[e, o, l, l, l],
[l, l, e, e, l]], dtype=np.float32)
input_data = torch.Tensor(input_data_np)
rnn = RNN(input_size, hidden_size)
outputs, state = rnn(input_data)
state
tensor([[[ 0.0057, -0.8686],
         [ 0.0137, -0.8827],
         [-0.2091, -0.8930],
         [-0.2091, -0.8930],
         [ 0.0177, -0.8881]]], grad_fn=<StackBackward0>)
test = 'hello! word'
string_set = list(set(test))
print(string_set)
['w', '!', 'e', 'r', 'd', 'h', 'o', ' ', 'l']
string_dic = {c: i for i, c in enumerate(string_set)}
print(string_dic)
{'w': 0, '!': 1, 'e': 2, 'r': 3, 'd': 4, 'h': 5, 'o': 6, ' ': 7, 'l': 8}
input_size = len(string_dic)
print(input_size)
hidden_size = len(string_dic)
print(hidden_size)
9
9
test_idx = [string_dic[c] for c in test]
print(test_idx)
[5, 2, 8, 8, 6, 1, 7, 0, 6, 3, 4]
x_data = [test_idx[:]]
print(x_data)
[[5, 2, 8, 8, 6, 1, 7, 0, 6, 3, 4]]
x_one_hot = [np.eye(input_size)[x] for x in x_data]
print(x_one_hot)
[array([[0., 0., 0., 0., 0., 1., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0.],
        [1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0.],
        [0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 0.]])]
y_data = [test_idx[:]]
print(y_data)
[[5, 2, 8, 8, 6, 1, 7, 0, 6, 3, 4]]
X = torch.FloatTensor(x_one_hot)
y = torch.LongTensor(y_data)
<ipython-input-30-25a499c695e1>:1: UserWarning: Creating a tensor from a list of numpy.ndarrays is extremely slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor. (Triggered internally at ../torch/csrc/utils/tensor_new.cpp:274.) X = torch.FloatTensor(x_one_hot)
rnn = RNN(input_size, hidden_size)
loss_fun = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(rnn.parameters(), lr=0.01)
# outputs, state = rnn(X)
# each step prints: loss, prediction (indices), prediction_str (decoded string)
for i in range(100):
    optimizer.zero_grad()
    outputs, state = rnn(X)
    # (batch size, sequence length, hidden size)
    # outputs: (1, 11, 9) -> (11, 9)
    loss = loss_fun(outputs.view(-1, input_size), y.view(-1))
    loss.backward()
    optimizer.step()
    result = outputs.data.numpy().argmax(axis=2)
    result_str = ''.join([string_set[c] for c in np.squeeze(result)])
    print(i, "loss: ", loss.item(), "prediction: ", result, "prediction str: ", result_str)
0 loss: 1.0980031490325928 prediction: [[5 2 8 8 6 1 7 0 6 3 4]] prediction str: hello! word
1 loss: 1.0937079191207886 prediction: [[5 2 8 8 6 1 7 0 6 3 4]] prediction str: hello! word
2 loss: 1.0894898176193237 prediction: [[5 2 8 8 6 1 7 0 6 3 4]] prediction str: hello! word
...
98 loss: 0.8840788006782532 prediction: [[5 2 8 8 6 1 7 0 6 3 4]] prediction str: hello! word
99 loss: 0.8830446600914001 prediction: [[5 2 8 8 6 1 7 0 6 3 4]] prediction str: hello! word