Kaggle知识点:使用BERT完成NLP任务

发表时间:2023-06-14 11:29作者:沃恩智慧

在日常生活中新闻具备有多的信息,在AINWIN互联网舆情企业风险事件的识别和预警 比赛中参赛选手需要根据新闻识别主体新闻类型

比赛官网(报名即可下载数据集):http://ailab.aiwin.org.cn/competitions/48

比赛给定了的1w条左右的训练数据,以及部分企业主体名称的汇总。接下来我们看如何一步一步完成本场比赛的。

比赛思路

对比赛要求进行分析后,可以将赛题任务具体划分为:

  • 任务1:企业主体抽取:抽取出新闻中主要的企业名称,并与完整的企业名称进行对应;(NER任务)
  • 任务2:新闻类型分类:根据新闻的内容将新闻的类型进行具体分类;(文本分类任务)

完成思路:

  • 将任务1和任务2,分开完成;
  • 将任务1和任务2,一起用Bert建模;

任务1:使用TFIFD完成

对文本进行分词:

import jieba

def strcut(s):
    seg_list = jieba.cut(s)
    return ' '.join(list(seg_list))

train_title = train_data['NEWS_TITLE'].apply(strcut)

TFIDF + 线性模型:


from sklearn.feature_extraction.text import TfidfVectorizer

tfidf = TfidfVectorizer(ngram_range=(1,1))
train_title_ttidf = tfidf.fit_transform(train_title)

验证集分类精度约89%。

任务1:使用BERT分类

进行token处理:

from transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')
train_encoding = tokenizer(list(tr_x), truncation=True, padding=True, max_length=128)
val_encoding = tokenizer(list(val_x), truncation=True, padding=True, max_length=128)

读取模型并定义优化器:

import torch
from transformers import AutoModelForSequenceClassification, AdamW, get_linear_schedule_with_warmup
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model.to(device)

optim = AdamW(model.parameters(), lr=5e-5)
total_steps = len(train_loader) * 1

任务2:使用正则匹配

通过公司主体与数据集中字符串匹配:

for row in train_data.iloc[:100].iterrows():
    match1 = company_name[company_name['name'].apply(lambda x: x in row[1].NEWS_TITLE)]
    if match1.shape[0] > 0:
        match1.loc[:, 'name_len'] = match1['name'].apply(len)
        match1 = match1.sort_values(by='name_len')
        match1 = match1.iloc[-1]['name']
    else:
        match1 = ''
   
    match2 = company_name[company_name['name_short'].apply(lambda x: x in row[1].NEWS_TITLE)]
    if match2.shape[0] > 0 and match1 == '':
        match2.loc[:, 'name_len'] = match2['name_short'].apply(len)
        match2 = match2.sort_values(by='name_len')
        match2 = match2.iloc[-1]['name']
    else:
        match2 = ''

识别结果:

标题: 东阳光(600673.SH):控股股东一致行动人宜昌东阳光药业质押2500万股
主体标签: 宜昌东阳光药业股份有限公司
主体识别结果: 宜昌东阳光药业股份有限公司

标题: 千亿市值蒸发超九成,康美药业财务造假坑了谁?
主体标签: 康美药业股份有限公司
主体识别结果: 康美药业股份有限公司

任务2:BERT 序列标注

加载BERT序列标注模型:

import torch
from transformers import BertForTokenClassification, AdamW, get_linear_schedule_with_warmup
model = BertForTokenClassification.from_pretrained('bert-base-chinese', num_labels=7)

device = torch.device("cuda:1" if torch.cuda.is_available() else "cpu")
model.to(device)

识别结果:

标题:山东省政府办公厅就平邑县玉荣商贸有限公司石膏矿坍塌事故发出通报
机构: 山东省政府办公厅
机构: 平邑县玉荣商贸有限公司

标题:[新闻直播间]黑龙江:龙煤集团一煤矿发生火灾事故
位置: 黑龙江
机构: 龙煤集团

代码&数据

比赛报名地址:http://ailab.aiwin.org.cn/competitions/48

文章完整代码:https://github.com/datawhalechina/competition-baseline


分享到: