python 多分类任务中按照类别分层采样
内容导读
互联网集市收集整理的这篇技术教程文章主要介绍了python 多分类任务中按照类别分层采样,小编现在分享给大家,供广大互联网技能从业者学习和参考。文章包含9099字,纯文字阅读大概需要13分钟。
内容图文
![python 多分类任务中按照类别分层采样](/upload/InfoBanner/zyjiaocheng/782/91d0751361574fe6afbc908b26f5746d.jpg)
在机器学习多分类任务中,有时需要按类别进行分层采样。例如对于类别不均衡的数据,随机采样会使训练集、验证集、测试集中各类别的数据比例不一致,这在一定程度上会影响分类器的性能。这时就需要进行分层采样,保证训练集、验证集、测试集中每个类别的数据比例大致相同。
下面是 Python 代码。
![python 多分类任务中按照类别分层采样 - 文章图片](/upload/getfiles/0001/2021/5/5/20210505014744303.jpg)
![python 多分类任务中按照类别分层采样 - 文章图片](/upload/getfiles/0001/2021/5/5/20210505014744426.jpg)
# Split the data into train/val/test sets, stratified by class.
import os


def _load_documents(ssdfile_dir):
    """Read every regular file in *ssdfile_dir* into a ``{key: text}`` dict.

    The key is the file name with '.txt' removed. All newlines are removed
    from the text so that a multi-line document fits on a single output line.
    """
    documents = {}
    for entry in os.listdir(ssdfile_dir):
        path = os.path.join(ssdfile_dir, entry)
        if not os.path.isfile(path):
            # Sub-directories and other non-files cannot be read as text.
            continue
        with open(path, 'r', encoding='utf-8') as fh:
            # One read + replace is linear in the document size; accumulating
            # line by line and re-running replace() each time was quadratic.
            documents[entry.replace('.txt', '')] = fh.read().replace('\n', '')
    return documents


def _read_labelled_rows(filename, documents):
    """Return ``(name, label)`` pairs from *filename* whose name is in *documents*.

    Each line is split naively on ',' (no CSV quoting): column 0 is the
    document name, column 1 the label. The line terminator is stripped first
    so that a label in the last column still compares equal to its category.
    """
    rows = []
    with open(filename, 'r', encoding='utf-8') as fh:
        for raw in fh:
            fields = raw.rstrip('\r\n').split(',')
            if len(fields) < 2:
                # Malformed line without a label column: skip instead of crashing.
                continue
            name, label = fields[0], fields[1]
            if name in documents:
                rows.append((name, label))
    return rows


def _print_split_stats(split_name, counts):
    """Print per-class counts of one split followed by their percentage distribution."""
    print(split_name + ":")
    print(counts)
    total = sum(counts)
    # An empty split (tiny classes, empty input) would otherwise divide by zero.
    distribution = [(c / total) * 100 if total else 0.0 for c in counts]
    print("分布:")
    print(distribution)


def save_file_stratified(filename, ssdfile_dir, categories,
                         output_dir='../data/usefuldata-711depart',
                         train_percent=70, val_percent=15):
    """Split labelled documents into train/val/test files, stratified by class.

    For every category, the first ``train_percent`` percent of its documents
    (in the order they appear in *filename*) go to train.txt, the next
    ``val_percent`` percent to val.txt, and the rest to test.txt. Each output
    line is the label, a tab character, then the document text. No shuffling
    happens here, so *filename* should already be in random order
    (NOTE(review): the default input path suggests it is pre-shuffled -
    confirm).

    Args:
        filename: label table read line by line; each line is split on ','
            with the document name in column 0 and the label in column 1.
        ssdfile_dir: directory of document files; a document's name is its
            file name with '.txt' removed.
        categories: labels to keep, in the order used for the printed
            per-class statistics. Lines with an empty label are counted as
            blank; lines with any other unlisted label are skipped and counted.
        output_dir: existing directory that receives train.txt, val.txt and
            test.txt (overwritten).
        train_percent: share of each class, in percent, written to train.txt.
        val_percent: share of each class, in percent, written to val.txt.
    """
    documents = _load_documents(ssdfile_dir)
    rows = _read_labelled_rows(filename, documents)

    # First matching category wins, like the original if/elif chain.
    category_index = {}
    for i, category in enumerate(categories):
        category_index.setdefault(category, i)
    n_classes = len(categories)

    # Number of usable documents per class.
    doc_count = [0] * n_classes
    for _, label in rows:
        idx = category_index.get(label)
        if idx is not None:
            doc_count[idx] += 1

    # Cumulative cut-offs; integer * percent / 100 keeps them exact.
    train_cut = [count * train_percent / 100 for count in doc_count]
    val_cut = [count * (train_percent + val_percent) / 100 for count in doc_count]

    seen = [0] * n_classes
    train_class_tag = [0] * n_classes
    val_class_tag = [0] * n_classes
    test_class_tag = [0] * n_classes
    blank_tag = 0
    unknown_tag = 0

    with open(os.path.join(output_dir, 'train.txt'), 'w', encoding='utf-8') as f_train, \
            open(os.path.join(output_dir, 'val.txt'), 'w', encoding='utf-8') as f_val, \
            open(os.path.join(output_dir, 'test.txt'), 'w', encoding='utf-8') as f_test:
        for name, label in rows:
            if label == '':
                blank_tag += 1
                continue
            idx = category_index.get(label)
            if idx is None:
                unknown_tag += 1
                continue
            # 0-based position within the class. Comparing a 1-based counter
            # with '<' (as before) sent one document too few to train
            # (e.g. 6 instead of 7 out of 10).
            position = seen[idx]
            seen[idx] += 1
            record = label + '\t' + documents[name] + '\n'
            if position < train_cut[idx]:
                f_train.write(record)
                train_class_tag[idx] += 1
            elif position < val_cut[idx]:
                f_val.write(record)
                val_class_tag[idx] += 1
            else:
                f_test.write(record)
                test_class_tag[idx] += 1

    print("有" + str(blank_tag) + "个文书的行业标记为空!")
    if unknown_tag:
        print("有" + str(unknown_tag) + "个文书的行业标记不在类别列表中,已跳过!")
    _print_split_stats("train", train_class_tag)
    _print_split_stats("val", val_class_tag)
    _print_split_stats("test", test_class_tag)


if __name__ == '__main__':
    categories = [
        "class1", "class2", "class3", "class4", "class5", "class6", "class7",
        "class8", "class9", "class10", "class11", "class12", "class13",
    ]
    save_file_stratified('../data/qwdata/shuffle-try3/classified_table_ms.txt',
                         '../data/qwdata/ms-ygscplusssdqw', categories)
运行后会打印每个类别在训练集、验证集、测试集中的数量及其分布。
需要注意的是:这是我早期写的文章。通常只需在训练集和验证集上做分层采样即可,测试集最好保持原样,不要改动。
内容总结
以上是互联网集市为您收集整理的python 多分类任务中按照类别分层采样全部内容,希望文章能够帮你解决python 多分类任务中按照类别分层采样所遇到的程序开发问题。 如果觉得互联网集市技术教程内容还不错,欢迎将互联网集市网站推荐给程序员好友。
内容备注
版权声明:本文内容由互联网用户自发贡献,该文观点与技术仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容, 请发送邮件至 gblab@vip.qq.com 举报,一经查实,本站将立刻删除。
内容手机端
扫描二维码推送至手机访问。