For background on Chinese text encoding, see the companion article "【中文编码】利用bert-base-chinese中的Tokenizer实现中文编码嵌入".

All code and datasets: download the repository. Pretrained Chinese BERT: download bert-base-chinese from the mirror. The downloaded folder contains the four scripts described below.

1. bert_get_data.py — dataset and model preparation:

```python
import pandas as pd
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer
from torch import nn
from transformers import BertModel

bert_name = './bert-base-chinese'
tokenizer = BertTokenizer.from_pretrained(bert_name)

class MyDataset(Dataset):
    def __init__(self, df):
        # Tokenizer output can be collated automatically by the DataLoader
        self.texts = [tokenizer(text,
                                padding='max_length',  # pad to max_length
                                max_length=35,         # data analysis shows the maximum length is 35
                                truncation=True,
                                return_tensors="pt")
                      for text in df['text']]
        # The DataLoader converts these labels to tensors automatically
        self.labels = [label for label in df['label']]

    def __getitem__(self, idx):
        return self.texts[idx], self.labels[idx]

    def __len__(self):
        return len(self.labels)

class BertClassifier(nn.Module):
    def __init__(self):
        super(BertClassifier, self).__init__()
        self.bert = BertModel.from_pretrained(bert_name)
        self.dropout = nn.Dropout(0.5)
        self.linear = nn.Linear(768, 10)
        self.relu = nn.ReLU()

    def forward(self, input_id, mask):
        _, pooled_output = self.bert(input_ids=input_id, attention_mask=mask, return_dict=False)
        dropout_output = self.dropout(pooled_output)
        linear_output = self.linear(dropout_output)
        final_layer = self.relu(linear_output)
        return final_layer

def GenerateData(mode):
    train_data_path = './THUCNews/data/train.txt'
    dev_data_path = './THUCNews/data/dev.txt'
    test_data_path = './THUCNews/data/test.txt'

    train_df = pd.read_csv(train_data_path, sep='\t', header=None)
    dev_df = pd.read_csv(dev_data_path, sep='\t', header=None)
    test_df = pd.read_csv(test_data_path, sep='\t', header=None)

    new_columns = ['text', 'label']
    train_df = train_df.rename(columns=dict(zip(train_df.columns, new_columns)))
    dev_df = dev_df.rename(columns=dict(zip(dev_df.columns, new_columns)))
    test_df = test_df.rename(columns=dict(zip(test_df.columns, new_columns)))

    train_dataset = MyDataset(train_df)
    dev_dataset = MyDataset(dev_df)
    test_dataset = MyDataset(test_df)

    if mode == 'train':
        return train_dataset
    elif mode == 'val':
        return dev_dataset
    elif mode == 'test':
        return test_dataset
```
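Because each `MyDataset` item is the full dictionary returned by the tokenizer, and `return_tensors="pt"` gives every tensor a leading batch dimension of 1, the default DataLoader collation stacks items into shape `[batch, 1, 35]`; this is why the training and test scripts below call `.squeeze(1)` on `input_ids`. A quick shape-check sketch (not part of the original scripts; it assumes the THUCNews files are laid out as `GenerateData` expects):

```python
# Quick shape check for the Dataset/DataLoader pipeline.
# Not part of the original scripts; assumes ./THUCNews/data/train.txt
# exists as expected by GenerateData.
from torch.utils.data import DataLoader
from bert_get_data import GenerateData

train_dataset = GenerateData(mode='train')
loader = DataLoader(train_dataset, batch_size=4)

inputs, labels = next(iter(loader))
print(inputs['input_ids'].shape)             # torch.Size([4, 1, 35]) -- extra dim from return_tensors="pt"
print(inputs['attention_mask'].shape)        # torch.Size([4, 1, 35])
print(inputs['input_ids'].squeeze(1).shape)  # torch.Size([4, 35]) -- the shape BertModel expects
print(labels)                                # tensor of 4 integer class ids
```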
2. bert_train.py — model training:

```python
import torch
from torch import nn
from torch.optim import Adam
from tqdm import tqdm
import numpy as np
import pandas as pd
import random
import os
from torch.utils.data import Dataset, DataLoader
from bert_get_data import BertClassifier, MyDataset, GenerateData

def setup_seed(seed):
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True

def save_model(save_name):
    if not os.path.exists(save_path):
        os.makedirs(save_path)
    torch.save(model.state_dict(), os.path.join(save_path, save_name))

# Training hyperparameters
epoch = 5
batch_size = 64
lr = 1e-5
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
random_seed = 20240121
save_path = './bert_checkpoint'
setup_seed(random_seed)

# Build the model
model = BertClassifier()
# Loss function and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = Adam(model.parameters(), lr=lr)
model = model.to(device)
criterion = criterion.to(device)

# Build the datasets
train_dataset = GenerateData(mode='train')
dev_dataset = GenerateData(mode='val')
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
dev_loader = DataLoader(dev_dataset, batch_size=batch_size)

# Training loop
best_dev_acc = 0
for epoch_num in range(epoch):
    total_acc_train = 0
    total_loss_train = 0
    for inputs, labels in tqdm(train_loader):
        input_ids = inputs['input_ids'].squeeze(1).to(device)  # torch.Size([batch_size, 35])
        masks = inputs['attention_mask'].to(device)            # torch.Size([batch_size, 1, 35])
        labels = labels.to(device)
        output = model(input_ids, masks)

        batch_loss = criterion(output, labels)
        batch_loss.backward()
        optimizer.step()
        optimizer.zero_grad()

        acc = (output.argmax(dim=1) == labels).sum().item()
        total_acc_train += acc
        total_loss_train += batch_loss.item()

    # ----------- Validation -----------
    model.eval()
    total_acc_val = 0
    total_loss_val = 0
    with torch.no_grad():
        for inputs, labels in dev_loader:
            input_ids = inputs['input_ids'].squeeze(1).to(device)  # torch.Size([batch_size, 35])
            masks = inputs['attention_mask'].to(device)            # torch.Size([batch_size, 1, 35])
            labels = labels.to(device)
            output = model(input_ids, masks)

            batch_loss = criterion(output, labels)
            acc = (output.argmax(dim=1) == labels).sum().item()
            total_acc_val += acc
            total_loss_val += batch_loss.item()

    print(f'''Epochs: {epoch_num + 1}
      | Train Loss: {total_loss_train / len(train_dataset):.3f}
      | Train Accuracy: {total_acc_train / len(train_dataset):.3f}
      | Val Loss: {total_loss_val / len(dev_dataset):.3f}
      | Val Accuracy: {total_acc_val / len(dev_dataset):.3f}''')

    # Save the best model so far
    if total_acc_val / len(dev_dataset) > best_dev_acc:
        best_dev_acc = total_acc_val / len(dev_dataset)
        save_model('best.pt')
    model.train()

# Save the final model
save_model('last.pt')
```

Training output: (screenshot in the original post)

3. bert_test.py — model testing:

```python
import os
import torch
from bert_get_data import BertClassifier, GenerateData
from torch.utils.data import DataLoader

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
save_path = './bert_checkpoint'
model = BertClassifier()
model.load_state_dict(torch.load(os.path.join(save_path, 'best.pt')))
model = model.to(device)
model.eval()

def evaluate(model, dataset):
    model.eval()
    test_loader = DataLoader(dataset, batch_size=128)
    total_acc_test = 0
    with torch.no_grad():
        for test_input, test_label in test_loader:
            input_id = test_input['input_ids'].squeeze(1).to(device)
            mask = test_input['attention_mask'].to(device)
            test_label = test_label.to(device)
            output = model(input_id, mask)
            acc = (output.argmax(dim=1) == test_label).sum().item()
            total_acc_test += acc
    print(f'Test Accuracy: {total_acc_test / len(dataset):.3f}')

test_dataset = GenerateData(mode="test")
evaluate(model, test_dataset)
```

Test result: (screenshot in the original post)
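The test script reports only overall accuracy, which can mask per-class differences on an imbalanced test set. The sketch below is not part of the original scripts; it reuses `BertClassifier`, `GenerateData`, and the `class.txt` label file from the repository layout above to print a per-class breakdown:

```python
# Per-class accuracy on the test set -- a sketch, not part of the
# original scripts; assumes the same checkpoint and data layout as above.
import os
import torch
from torch.utils.data import DataLoader
from bert_get_data import BertClassifier, GenerateData

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = BertClassifier()
model.load_state_dict(torch.load(os.path.join('./bert_checkpoint', 'best.pt'), map_location=device))
model = model.to(device)
model.eval()

num_classes = 10
correct = torch.zeros(num_classes)
total = torch.zeros(num_classes)

test_loader = DataLoader(GenerateData(mode='test'), batch_size=128)
with torch.no_grad():
    for inputs, labels in test_loader:
        input_ids = inputs['input_ids'].squeeze(1).to(device)
        masks = inputs['attention_mask'].to(device)
        labels = labels.to(device)
        preds = model(input_ids, masks).argmax(dim=1)
        for c in range(num_classes):
            mask_c = labels == c
            total[c] += mask_c.sum().item()
            correct[c] += (preds[mask_c] == c).sum().item()

# Map class indices back to human-readable names
with open('./THUCNews/data/class.txt', 'r') as f:
    names = [row.strip() for row in f.readlines()]
for c in range(num_classes):
    acc = (correct[c] / total[c]).item() if total[c] > 0 else float('nan')
    print(f'{names[c]}: {acc:.3f}')
```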
4. bert_tuili.py — interactive inference:

```python
import os
from transformers import BertTokenizer
import torch
from bert_get_data import BertClassifier

bert_name = './bert-base-chinese'
tokenizer = BertTokenizer.from_pretrained(bert_name)
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
save_path = './bert_checkpoint'

model = BertClassifier()
model.load_state_dict(torch.load(os.path.join(save_path, 'best.pt')))
model = model.to(device)
model.eval()

# Load the human-readable class names
real_labels = []
with open('./THUCNews/data/class.txt', 'r') as f:
    for row in f.readlines():
        real_labels.append(row.strip())

while True:
    text = input('请输入新闻:')
    bert_input = tokenizer(text, padding='max_length',
                           max_length=35,
                           truncation=True,
                           return_tensors="pt")
    input_ids = bert_input['input_ids'].to(device)
    masks = bert_input['attention_mask'].unsqueeze(1).to(device)
    output = model(input_ids, masks)
    pred = output.argmax(dim=1)
    print(real_labels[pred])
```

Interactive demo: (screenshot in the original post)
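The interactive loop prints only the argmax label. If a confidence score is also wanted, a softmax over the model output gives a rough one; the sketch below extends the loop above (same `model`, `tokenizer`, `device`, and `real_labels`). Note that `BertClassifier` applies a ReLU after its final linear layer, so the outputs are not raw logits and the resulting scores are indicative rather than calibrated probabilities.

```python
# Variant of the interactive loop that also prints a softmax score.
# A sketch, not part of the original bert_tuili.py; reuses the model,
# tokenizer, device, and real_labels set up above.
import torch.nn.functional as F

while True:
    text = input('请输入新闻:')
    bert_input = tokenizer(text, padding='max_length', max_length=35,
                           truncation=True, return_tensors="pt")
    input_ids = bert_input['input_ids'].to(device)
    masks = bert_input['attention_mask'].unsqueeze(1).to(device)
    with torch.no_grad():
        output = model(input_ids, masks)
        probs = F.softmax(output, dim=1)
    conf, pred = probs.max(dim=1)
    # The ReLU in the classifier head means output is not a true logit
    # vector, so treat this score as a rough indicator only.
    print(f'{real_labels[pred.item()]} (score: {conf.item():.3f})')
```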
Reference: 微调BERT进行中文文本分类任务(Pytorch代码实现) (fine-tuning BERT for a Chinese text classification task, a PyTorch implementation).