1. NLI(Natual Language Inference)실습¶
- 두 개의 문장(전제와 가설) 사이의 논리적 관계를 결정하는 자연어 처리 문제
!pip install transformers
Requirement already satisfied: transformers in /usr/local/lib/python3.10/dist-packages (4.41.2) Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers) (3.15.4) Requirement already satisfied: huggingface-hub<1.0,>=0.23.0 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.23.4) Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (1.25.2) Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers) (24.1) Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (6.0.1) Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (2024.5.15) Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers) (2.32.3) Requirement already satisfied: tokenizers<0.20,>=0.19 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.19.1) Requirement already satisfied: safetensors>=0.4.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.4.3) Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers) (4.66.4) Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.23.0->transformers) (2023.6.0) Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.23.0->transformers) (4.12.2) Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.3.2) Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.7) Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2.0.7) Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2024.6.2)
# pipeline: 자연어 처리 작업을 간단한 코드로 여러 작업을 한번에 지원
# AutoTokenizer: 자동으로 적절한 토크나이저를 선택하여 모델의 토큰화를 지원
from transformers import pipeline, AutoTokenizer
classifier = pipeline(
'text-classification',
model = 'Huffon/klue-roberta-base-nli',
return_all_scores = True
)
/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_token.py:89: UserWarning: The secret `HF_TOKEN` does not exist in your Colab secrets. To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session. You will be able to reuse this secret in all of your notebooks. Please note that authentication is recommended but still optional to access public models or datasets. warnings.warn( /usr/local/lib/python3.10/dist-packages/transformers/pipelines/text_classification.py:104: UserWarning: `return_all_scores` is now deprecated, if want a similar functionality use `top_k=None` instead of `return_all_scores=True` or `top_k=1` instead of `return_all_scores=False`. warnings.warn(
tokenizer = AutoTokenizer.from_pretrained('Huffon/klue-roberta-base-nli')
tokenizer
BertTokenizerFast(name_or_path='Huffon/klue-roberta-base-nli', vocab_size=32000, model_max_length=512, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'unk_token': '[UNK]', 'sep_token': '[SEP]', 'pad_token': '[PAD]', 'cls_token': '[CLS]', 'mask_token': '[MASK]'}, clean_up_tokenization_spaces=True), added_tokens_decoder={ 0: AddedToken("[PAD]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True), 1: AddedToken("[UNK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True), 2: AddedToken("[CLS]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True), 3: AddedToken("[SEP]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True), 4: AddedToken("[MASK]", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True), }
tokenizer.sep_token
'[SEP]'
# ENTAILMENT: 첫 번째 문장이 두 번째 문장을 내포하거나 함축한 경우
# NEUTRAL: 두 문장의 관계가 특별히 긍정적이거나 부정적이지 않을 때(모순 관계가 없을 때)
# CONTRADICTION: 명백한 모순 관계가 있을 때
classifier(f'나는 게임을 너무 좋아해 {tokenizer.sep_token} 나는 게임이 너무 싫어')
[[{'label': 'ENTAILMENT', 'score': 0.0003334380453452468}, {'label': 'NEUTRAL', 'score': 0.0004345145425759256}, {'label': 'CONTRADICTION', 'score': 0.9992320537567139}]]
classifier(f'여러 남성들이 축구를 즐기고 있어요 {tokenizer.sep_token} 어떤 남자들은 공을 차고 있어요')
[[{'label': 'ENTAILMENT', 'score': 0.8508433103561401}, {'label': 'NEUTRAL', 'score': 0.14859437942504883}, {'label': 'CONTRADICTION', 'score': 0.0005623317556455731}]]
2. 문장 요약, 번역, 텍스트 생성 실습¶
- BART(Sequential Bidirectional AutoRegressive Transformers)
- 허깅페이스에서 개발한 자연어 처리 모델
- 자연어 생성 및 이해 작업을 위한 사전 훈련된 언어 모델
- 트랜스포머 아키텍처를 기반으로 양방향 인코더-디코더 구조를 가지고 있음
- 문장 요약, 번역, 텍스트 생성에 뛰어난 성능을 보여주는 모델
import torch
from transformers import PreTrainedTokenizerFast
from transformers import BartForConditionalGeneration
# PreTrainedTokenizerFast: 허깅페이스에서 개발한 Rust기반의 고성능 토크나이저
tokenizer = PreTrainedTokenizerFast.from_pretrained('digit82/kobart-summarization')
tokenizer
PreTrainedTokenizerFast(name_or_path='digit82/kobart-summarization', vocab_size=30000, model_max_length=1000000000000000019884624838656, is_fast=True, padding_side='right', truncation_side='right', special_tokens={'bos_token': '<s>', 'eos_token': '</s>', 'unk_token': '<unk>', 'pad_token': '<pad>', 'mask_token': '<mask>'}, clean_up_tokenization_spaces=True), added_tokens_decoder={ 0: AddedToken("<s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True), 1: AddedToken("</s>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True), 2: AddedToken("<usr>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True), 3: AddedToken("<pad>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True), 4: AddedToken("<sys>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True), 5: AddedToken("<unk>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True), 6: AddedToken("<mask>", rstrip=False, lstrip=False, single_word=False, normalized=False, special=True), ... }
model = BartForConditionalGeneration.from_pretrained('digit82/kobart-summarization')
You passed along `num_labels=3` with an incompatible id to label map: {'0': 'NEGATIVE', '1': 'POSITIVE'}. The number of labels wil be overwritten to 2.
model
BartForConditionalGeneration( (model): BartModel( (shared): Embedding(30000, 768, padding_idx=3) (encoder): BartEncoder( (embed_tokens): BartScaledWordEmbedding(30000, 768, padding_idx=3) (embed_positions): BartLearnedPositionalEmbedding(1028, 768) (layers): ModuleList( (0-5): 6 x BartEncoderLayer( (self_attn): BartSdpaAttention( (k_proj): Linear(in_features=768, out_features=768, bias=True) (v_proj): Linear(in_features=768, out_features=768, bias=True) (q_proj): Linear(in_features=768, out_features=768, bias=True) (out_proj): Linear(in_features=768, out_features=768, bias=True) ) (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True) (activation_fn): GELUActivation() (fc1): Linear(in_features=768, out_features=3072, bias=True) (fc2): Linear(in_features=3072, out_features=768, bias=True) (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True) ) ) (layernorm_embedding): LayerNorm((768,), eps=1e-05, elementwise_affine=True) ) (decoder): BartDecoder( (embed_tokens): BartScaledWordEmbedding(30000, 768, padding_idx=3) (embed_positions): BartLearnedPositionalEmbedding(1028, 768) (layers): ModuleList( (0-5): 6 x BartDecoderLayer( (self_attn): BartSdpaAttention( (k_proj): Linear(in_features=768, out_features=768, bias=True) (v_proj): Linear(in_features=768, out_features=768, bias=True) (q_proj): Linear(in_features=768, out_features=768, bias=True) (out_proj): Linear(in_features=768, out_features=768, bias=True) ) (activation_fn): GELUActivation() (self_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True) (encoder_attn): BartSdpaAttention( (k_proj): Linear(in_features=768, out_features=768, bias=True) (v_proj): Linear(in_features=768, out_features=768, bias=True) (q_proj): Linear(in_features=768, out_features=768, bias=True) (out_proj): Linear(in_features=768, out_features=768, bias=True) ) (encoder_attn_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True) (fc1): Linear(in_features=768, out_features=3072, bias=True) (fc2): Linear(in_features=3072, out_features=768, bias=True) (final_layer_norm): LayerNorm((768,), eps=1e-05, elementwise_affine=True) ) ) (layernorm_embedding): LayerNorm((768,), eps=1e-05, elementwise_affine=True) ) ) (lm_head): Linear(in_features=768, out_features=30000, bias=False) )
text = '''
기후 관련 신기록들이 나오고 있다. 지난 6월 5일 '세계기상기구'(WMO)는 향후 5년간 산업혁명 전 대비 지구표면 온도가 1.1℃에서 1.9℃까지 상승할 수 있다고 발표했다. 기후위기의 마지노선인 1.5℃ 상승이 앞으로 5년간 계속될 확률도 47%라고 한다. 1.5℃ 상승제한을 결정했던 2015년 유엔 파리기후협정 당시만 해도 2030년까지 1.5℃ 오를 가능성은 0%였다. 10년도 안 되어 사실상 기후 관련 국제사회의 결정이 무너지고 있다.
지난 6월 28일 발생해 시속 270km의 강풍을 동반, 최고등급인 5등급으로 발달한 허리케인 '베릴'(Beryl)이 중남미 카리브해의 나라들을 파괴했다. 7월 3일 영국 BBC방송은 베릴에 대해 지난 100년간의 허리케인 기록을 깼고, 현재 진행되는 기후변화의 위험성이 집중 조명된 사례로 보도했다. 지금까지 허리케인은 해수면 온도가 높아진 8월 말부터 발생했는데, 6월에 발생한 허리케인 베릴은 기후위기가 만든 최초 기록이라는 것이다.
'''
text = text.replace('\n', '')
raw_input_ids = tokenizer.encode(text)
print(raw_input_ids)
[21009, 14342, 14094, 17469, 14108, 17427, 15964, 14141, 15328, 16415, 14063, 15666, 9264, 11224, 17842, 18223, 286, 276, 278, 14831, 16090, 22030, 14982, 22635, 14038, 15171, 17581, 13497, 10586, 14488, 14973, 14035, 15275, 370, 266, 14030, 14035, 16157, 370, 266, 14129, 14919, 13594, 14032, 14857, 14615, 15615, 21009, 11973, 15525, 14087, 24828, 11268, 12037, 25423, 370, 266, 14919, 12034, 15023, 22030, 28851, 14181, 10432, 9866, 20665, 236, 14218, 19553, 25423, 370, 266, 14919, 18010, 12007, 14657, 14622, 17655, 17650, 14240, 16287, 13756, 21761, 14616, 10500, 15766, 14182, 250, 14515, 14129, 25423, 370, 266, 22727, 21452, 1700, 14859, 20029, 16889, 9866, 14105, 15380, 16500, 21009, 14342, 14725, 20516, 26193, 18413, 14429, 15964, 25560, 15328, 16838, 27116, 28827, 15433, 247, 16413, 12024, 14119, 23073, 17554, 243, 14614, 17503, 12037, 14144, 17503, 14027, 19104, 13590, 22226, 12934, 12037, 14063, 10849, 10482, 18223, 265, 14879, 320, 307, 15019, 14059, 9506, 10746, 14471, 10476, 11007, 13607, 12024, 15147, 14282, 18782, 15615, 15513, 16356, 15609, 15085, 15571, 15737, 12005, 14661, 10482, 11786, 14225, 14141, 14854, 15387, 12024, 22226, 12934, 12037, 20174, 1700, 9313, 14161, 14475, 21272, 21009, 10869, 16975, 15532, 14439, 15111, 22353, 9908, 15852, 10338, 15192, 15615, 16265, 22226, 12934, 14826, 23074, 10586, 14488, 14973, 26541, 15639, 28150, 14669, 20973, 15328, 11786, 17419, 22226, 12934, 12037, 14661, 10482, 12005, 21009, 11973, 14482, 15550, 15592, 14557, 14394, 16746]
input_ids = [tokenizer.bos_token_id] + raw_input_ids + [tokenizer.eos_token_id]
print(input_ids)
[0, 21009, 14342, 14094, 17469, 14108, 17427, 15964, 14141, 15328, 16415, 14063, 15666, 9264, 11224, 17842, 18223, 286, 276, 278, 14831, 16090, 22030, 14982, 22635, 14038, 15171, 17581, 13497, 10586, 14488, 14973, 14035, 15275, 370, 266, 14030, 14035, 16157, 370, 266, 14129, 14919, 13594, 14032, 14857, 14615, 15615, 21009, 11973, 15525, 14087, 24828, 11268, 12037, 25423, 370, 266, 14919, 12034, 15023, 22030, 28851, 14181, 10432, 9866, 20665, 236, 14218, 19553, 25423, 370, 266, 14919, 18010, 12007, 14657, 14622, 17655, 17650, 14240, 16287, 13756, 21761, 14616, 10500, 15766, 14182, 250, 14515, 14129, 25423, 370, 266, 22727, 21452, 1700, 14859, 20029, 16889, 9866, 14105, 15380, 16500, 21009, 14342, 14725, 20516, 26193, 18413, 14429, 15964, 25560, 15328, 16838, 27116, 28827, 15433, 247, 16413, 12024, 14119, 23073, 17554, 243, 14614, 17503, 12037, 14144, 17503, 14027, 19104, 13590, 22226, 12934, 12037, 14063, 10849, 10482, 18223, 265, 14879, 320, 307, 15019, 14059, 9506, 10746, 14471, 10476, 11007, 13607, 12024, 15147, 14282, 18782, 15615, 15513, 16356, 15609, 15085, 15571, 15737, 12005, 14661, 10482, 11786, 14225, 14141, 14854, 15387, 12024, 22226, 12934, 12037, 20174, 1700, 9313, 14161, 14475, 21272, 21009, 10869, 16975, 15532, 14439, 15111, 22353, 9908, 15852, 10338, 15192, 15615, 16265, 22226, 12934, 14826, 23074, 10586, 14488, 14973, 26541, 15639, 28150, 14669, 20973, 15328, 11786, 17419, 22226, 12934, 12037, 14661, 10482, 12005, 21009, 11973, 14482, 15550, 15592, 14557, 14394, 16746, 1]
summary_ids = model.generate(torch.tensor([input_ids]), num_beams=4, max_length=512, eos_token_id=1)
print(summary_ids)
tensor([[ 2, 14141, 15328, 16415, 14063, 15666, 9264, 11224, 17842, 18223, 286, 276, 278, 14831, 16090, 22030, 14982, 22635, 14038, 15171, 17581, 13497, 10586, 14488, 14973, 14035, 15275, 370, 266, 14030, 14035, 16157, 370, 266, 14129, 14919, 13594, 14032, 14857, 14615, 17432, 16090, 22030, 14982, 22635, 14038, 15171, 17581, 13497, 10586, 14488, 14973, 14035, 15275, 370, 266, 14030, 14035, 16157, 370, 266, 14129, 14919, 13594, 14032, 14857, 14615, 15615, 1]])
tokenizer.decode(summary_ids.squeeze().tolist(), skip_special_tokens=True)
"지난 6월 5일 '세계기상기구'(WMO)는 향후 5년간 산업혁명 전 대비 지구표면 온도가 1.1°C에서 1.9°C까지 상승할 수 있다고 발표했으며, 향후 5년간 산업혁명 전 대비 지구표면 온도가 1.1°C에서 1.9°C까지 상승할 수 있다고 발표했다."
3. KLUE(Korean Language Understanding Evaluation)¶
- 한국어 자연어 이해 평가 데이터셋
- 한국어 언어모델의 공정한 평가를 위한 목적으로 8개의 종류가 포함된 공개 데이터셋
- 뉴스 헤드라인 분류
- 문장 유사도 비교
- 자연어 추론
- 개체명 인식
- 관계 추출
- 형태소 및 의존 구문 분석
- 기계 독해 이해
- 대화 상태 추적
- 광범위한 주제와 다양한 스카일을 포괄하기 위해 다양한 출처에서 공개적으로 사용 가능한 한국어 말뭉치를 수집
- 약 62GB 크기의 최종 사전 학습 코퍼스를 구축
- MODU(국립국어원), CC-100-Kor(CC-Net을 사용한 다국어 웹 크롤링), 나무위키(웹 기반 백과사전), 뉴스스크롤(뉴스 집계 플랫폼), 청원(청와대 국민청원의 기사) 등
- ['ynat', 'sts', 'nli', 'ner', 're', 'dp', 'mrc', 'wos']
- ynat: 유튜브 비디오 댓글에서 자연스럽게 발생하는 대화 데이터를 이용한 태스크, 주어진 문장에 대해 답변하는 작업
- sts: 두 텍스트의 의미적 유사성을 평가하는 태스크
- nli: 전제와 가설이라는 두 문장 간의 논리적 관계를 판별하는 태스크(참, 거짓, 중립)
- ner: 문장에서 인명, 지명, 기관명 등 특정 개체명을 식별하고 분류하는 태스크
- re: 문장 또는 텍스트에서 개체들 간의 관계를 추출하는 태스크
- 예) "스티브잡스는 애플의 공동 창립자이다" -> 스티브잡스와 애플을 뽑아내고 그 관계를 알려줌
- dp: 문장 내 단어들 간의 문법적 관계를 파악하는 태스크. 각 단어가 어떻게 다른 단어들과 연결되어 있는지(종속성)를 구조적으로 분석
- mrc: 주어진 텍스트와 질문에 대해 답을 추출하거나 생성하는 태스크. 모델이 텍스트를 읽고 이해하여 질문에 답할 수 있도록 함
- wos: Web of Science 데이터 베이스에서 추출한 데이터를 활용하는 태스크. 특정 태스크보다는 데이터 출처를 나타냄
!pip install datasets
Requirement already satisfied: datasets in /usr/local/lib/python3.10/dist-packages (2.20.0) Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from datasets) (3.15.4) Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from datasets) (1.25.2) Requirement already satisfied: pyarrow>=15.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (16.1.0) Requirement already satisfied: pyarrow-hotfix in /usr/local/lib/python3.10/dist-packages (from datasets) (0.6) Requirement already satisfied: dill<0.3.9,>=0.3.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (0.3.8) Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets) (2.0.3) Requirement already satisfied: requests>=2.32.2 in /usr/local/lib/python3.10/dist-packages (from datasets) (2.32.3) Requirement already satisfied: tqdm>=4.66.3 in /usr/local/lib/python3.10/dist-packages (from datasets) (4.66.4) Requirement already satisfied: xxhash in /usr/local/lib/python3.10/dist-packages (from datasets) (3.4.1) Requirement already satisfied: multiprocess in /usr/local/lib/python3.10/dist-packages (from datasets) (0.70.16) Requirement already satisfied: fsspec[http]<=2024.5.0,>=2023.1.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (2023.6.0) Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets) (3.9.5) Requirement already satisfied: huggingface-hub>=0.21.2 in /usr/local/lib/python3.10/dist-packages (from datasets) (0.23.4) Requirement already satisfied: packaging in /usr/local/lib/python3.10/dist-packages (from datasets) (24.1) Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from datasets) (6.0.1) Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.3.1) Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (23.2.0) Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.4.1) Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (6.0.5) Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.9.4) Requirement already satisfied: async-timeout<5.0,>=4.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (4.0.3) Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub>=0.21.2->datasets) (4.12.2) Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (3.3.2) Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (3.7) Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (2.0.7) Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests>=2.32.2->datasets) (2024.6.2) Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2.8.2) Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2023.4) Requirement already satisfied: tzdata>=2022.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2024.1) Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.2->pandas->datasets) (1.16.0)
import datasets
from datasets import load_dataset, load_metric
import random
import pandas as pd
import numpy as np
from transformers import AutoModelForSequenceClassification, TrainingArguments, Trainer
model_checkpoint = 'klue/roberta-base'
batch_size = 64
task='ynat'
klue_datasets = load_dataset('klue', task)
klue_datasets['train'][0]
{'guid': 'ynat-v1_train_00000', 'title': '유튜브 내달 2일까지 크리에이터 지원 공간 운영', 'label': 3, 'url': 'https://news.naver.com/main/read.nhn?mode=LS2D&mid=shm&sid1=105&sid2=227&oid=001&aid=0008508947', 'date': '2016.06.30. 오전 10:36'}
def show_random_elements(dataset, num_examples=10):
picks = []
for _ in range(num_examples):
pick = random.randint(0, len(dataset)-1)
while pick in picks:
pick = random.randint(0, len(dataset)-1)
picks.append(pick)
df = pd.DataFrame(dataset[picks])
return df
show_random_elements(klue_datasets['train'])
guid | title | label | url | date | |
---|---|---|---|---|---|
0 | ynat-v1_train_28475 | 현대차 판매·매출·수익 모두 후진…회복열쇠는 신차 | 1 | https://news.naver.com/main/read.nhn?mode=LS2D... | 2018.04.26. 오후 3:37 |
1 | ynat-v1_train_17064 | 채식과 육식의 황금비율은…신간 채소의 인문학 | 3 | https://news.naver.com/main/read.nhn?mode=LS2D... | 2017.06.04. 오전 11:54 |
2 | ynat-v1_train_41899 | 그래픽 코스피·코스닥 추이 | 1 | https://news.naver.com/main/read.nhn?mode=LS2D... | 2019.08.02. 오후 4:00 |
3 | ynat-v1_train_17000 | 현대로템 현대차투자증권에 200억원 채권 매도 | 1 | https://news.naver.com/main/read.nhn?mode=LS2D... | 2018.05.28. 오전 9:17 |
4 | ynat-v1_train_28740 | 오재일 홈런포 두 방 두산 어린이날 3연전 스윕…LG... | 5 | https://sports.news.naver.com/news.nhn?oid=001... | 2018.05.06 18:23 |
5 | ynat-v1_train_38880 | NBA 웨스트브룩 통산 100호 트리플더블…역대 4번째 | 5 | https://sports.news.naver.com/news.nhn?oid=001... | 2018.03.14 11:32 |
6 | ynat-v1_train_17335 | 北 김정은 인민군 오직 내가 가리키는 방향으로만 가야종합 | 6 | https://news.naver.com/main/read.nhn?mode=LS2D... | 2016.02.04. 오전 10:43 |
7 | ynat-v1_train_43776 | 北 김정일이 폐기한 공산주의 용어 다시 사용 | 6 | https://news.naver.com/main/read.nhn?mode=LS2D... | 2016.04.08. 오전 9:02 |
8 | ynat-v1_train_09672 | 찌워야 산다…류현진·강정호 몸무게·근육량 증가 | 5 | https://sports.news.naver.com/news.nhn?oid=001... | 2019.02.27 06:00 |
9 | ynat-v1_train_06297 | 러 중앙선관위 총선 최종 개표 결과 발표…여당 76% 득표 압승 | 4 | https://news.naver.com/main/read.nhn?mode=LS2D... | 2016.09.23. 오후 4:22 |
fake_preds = np.random.randint(0, 2, size=(64,))
fake_labels = np.random.randint(0, 2, size=(64,))
fake_preds, fake_labels
(array([1, 1, 0, 0, 0, 1, 1, 0, 0, 1, 0, 1, 1, 0, 1, 1, 1, 1, 0, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 1, 0, 1, 1, 1, 1, 1, 1, 0, 1, 1, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0]), array([0, 1, 0, 0, 1, 1, 1, 0, 0, 1, 1, 0, 1, 0, 0, 0, 1, 1, 1, 1, 0, 0, 1, 1, 0, 1, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 1, 0, 1]))
metric = load_metric('f1')
<ipython-input-26-b37dd255292e>:1: FutureWarning: load_metric is deprecated and will be removed in the next major version of datasets. Use 'evaluate.load' instead, from the new library 🤗 Evaluate: https://huggingface.co/docs/evaluate metric = load_metric('f1')
metric.compute(predictions=fake_preds, references=fake_labels)
{'f1': 0.5396825396825397}
tokenizer = AutoTokenizer.from_pretrained(model_checkpoint, use_fast=True)
tokenizer('"평생 처음" 주민들의 충격 증언... 대한민국 곳곳 이상징후')
{'input_ids': [0, 6, 5577, 3790, 6, 3972, 2031, 2079, 5326, 8105, 18, 18, 18, 4892, 5844, 3658, 2976, 2158, 2], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}
print(tokenizer.cls_token_id, tokenizer.eos_token_id)
0 2
print(f"문장1: {klue_datasets['train'][0]['title']}")
문장1: 유튜브 내달 2일까지 크리에이터 지원 공간 운영
def preprocess_function(examples):
return tokenizer(
examples['title'],
truncation=True, # 최대 길이를 초과할 경우 초과된 부분을 잘라냄
return_token_type_ids=False
)
preprocess_function(klue_datasets['train'][:5])
{'input_ids': [[0, 10637, 8474, 22, 2210, 2299, 2118, 28940, 3691, 4101, 3792, 2], [0, 24905, 1042, 4795, 19982, 2129, 121, 6904, 16311, 3, 14392, 2], [0, 4172, 3797, 3728, 2107, 2134, 3777, 904, 6022, 2332, 2113, 2259, 4523, 1380, 2259, 2062, 2], [0, 12417, 2155, 7840, 604, 2859, 3873, 11554, 2522, 1539, 2073, 8446, 6626, 18818, 575, 2], [0, 13203, 2179, 2366, 4197, 7551, 2096, 8542, 2088, 2353, 886, 1244, 4393, 2027, 22, 2207, 8189, 2]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]}
encoded_datasets = klue_datasets.map(preprocess_function, batched=True)
num_labels = 7
# AutoModelForSequenceClassification
# 시퀀스 분류 작업(예: 감정분석, 텍스트 분류)
# 다양한 모델(BERT, RoBERTa 등) 지원
model = AutoModelForSequenceClassification.from_pretrained(model_checkpoint, num_labels=num_labels)
Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at klue/roberta-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight'] You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
def compute_metrics(eval_pred):
predictions, labels = eval_pred
predictions = np.argmax(predictions, axis=1)
return metric.compute(predictions=predictions, references=labels, average='macro')
!pip install accelerate>=0.21.0
metric_name = 'f1'
args = TrainingArguments(
'test-tc',
# 평가전략: 에폭이 끝날 때마다 평가를 수행
evaluation_strategy='epoch',
# 모델 체크포인트를 저장. 매 에포크마다 저장
save_strategy='epoch',
learning_rate=2e-5,
per_device_train_batch_size=batch_size,
per_device_eval_batch_size=batch_size,
num_train_epochs=5,
# 가중치 감소의 설정
weight_decay=0.01,
# 학습이 끝난 후 가장 좋은 성능을 보인 모델을 불러올지 여부
load_best_model_at_end=True,
# 최적의 모델을 결정할 때 사용할 매트릭을 지정
metric_for_best_model=metric_name
)
/usr/local/lib/python3.10/dist-packages/transformers/training_args.py:1474: FutureWarning: `evaluation_strategy` is deprecated and will be removed in version 4.46 of 🤗 Transformers. Use `eval_strategy` instead warnings.warn(
trainer = Trainer(
model,
args,
train_dataset=encoded_datasets['train'],
eval_dataset=encoded_datasets['validation'],
tokenizer=tokenizer,
compute_metrics=compute_metrics
)
trainer.train()
[1185/3570 4:47:11 < 9:38:59, 0.07 it/s, Epoch 1.66/5]
Epoch | Training Loss | Validation Loss | F1 |
---|---|---|---|
1 | 0.501900 | 0.387981 | 0.865757 |
trainer.evaluate()
best_model_checkpoint = trainer.satate.best_model_checkpoint
best_model_checkpoint
classifier = pipeline(
'text-classification',
model='./test-tc/checkpoint-2142',
return_all_scores=True
)
'''
0: IT과학
1: 경제
2: 사회
3: 문화
4: 스포츠
5: 정치
6: 연예
'''
classifier(f'"평생 처음" 주민들의 충격 증언... 대한민국 곳곳 이상징후')
'코딩 > 자연어 처리' 카테고리의 다른 글
워드 임베딩 시각화 (1) | 2024.07.18 |
---|---|
자연어 처리를 위한 모델 학습 (0) | 2024.07.18 |
RNN 기초 (0) | 2024.07.18 |
cbow text classification (0) | 2024.07.18 |
워드 임베딩 (0) | 2024.07.18 |