# coding:utf8
# Reconstructed from git patch 55ea0a524ad8c4e5f59127cdc7be81da24ef3771
# (subject: "split dataset"), which adds gen_ocr_train_val_test.py.
"""Split a PPOCRLabel-annotated dataset into train/val/test sets.

Both the detection data (label file plus full images) and the recognition
data (label file plus cropped images) are shuffled and copied into
``train``/``val``/``test`` folders, and matching ``train.txt``/``val.txt``/
``test.txt`` list files are written next to those folders.
"""
import os
import shutil
import random
import argparse
from contextlib import ExitStack


def isCreateOrDeleteFolder(path, flag):
    """Recreate ``path/flag`` as an empty directory.

    Any existing directory at that location is removed together with its
    contents before the new one is created.

    Args:
        path: Parent directory of the split folder.
        flag: Name of the split folder ("train", "val" or "test").

    Returns:
        The absolute path of the freshly created, empty directory.
    """
    flagPath = os.path.join(path, flag)

    if os.path.exists(flagPath):
        shutil.rmtree(flagPath)

    os.makedirs(flagPath)
    return os.path.abspath(flagPath)


def _parseSplitRatio(ratioText):
    """Turn a "train:val:test" string into cumulative (train, train+val) fractions.

    The three parts are normalised by their sum, so "6:2:2", "60:20:20" and
    "3:1:1" all describe the same split.

    Raises:
        ValueError: If the text is not three non-negative numbers with a
            positive sum.
    """
    parts = ratioText.split(":")
    if len(parts) != 3:
        raise ValueError(
            "trainValTestRatio must look like 'train:val:test', got {!r}".format(ratioText))
    try:
        # float() instead of eval(): the value comes straight from the command line.
        trainPart, valPart, testPart = (float(part) for part in parts)
    except ValueError as err:
        raise ValueError(
            "trainValTestRatio must contain only numbers, got {!r}".format(ratioText)) from err

    total = trainPart + valPart + testPart
    if min(trainPart, valPart, testPart) < 0 or total <= 0:
        raise ValueError(
            "trainValTestRatio must be non-negative with a positive sum, got {!r}".format(ratioText))

    return trainPart / total, (trainPart + valPart) / total


def splitTrainVal(root, absTrainRootPath, absValRootPath, absTestRootPath, trainTxt, valTxt, testTxt, flag,
                  options=None):
    """Shuffle one label file and distribute its records over train/val/test.

    Each record is a line of the form ``<image path>\\t<label>``. Only the base
    name of the image path is used: the image is looked up in ``root`` for
    detection, or in ``root/<recImageDirName>`` for recognition, copied into
    the chosen split folder, and a ``<copied path>\\t<label>`` line is appended
    to the matching list file.

    Args:
        root: Directory holding the label file and the images.
        absTrainRootPath: Destination directory for training images.
        absValRootPath: Destination directory for validation images.
        absTestRootPath: Destination directory for test images.
        trainTxt: Writable text file receiving the training records.
        valTxt: Writable text file receiving the validation records.
        testTxt: Writable text file receiving the test records.
        flag: "det" for detection data, "rec" for recognition data.
        options: Namespace providing ``trainValTestRatio``, ``detLabelFileName``,
            ``recLabelFileName`` and ``recImageDirName``. When omitted, the
            module-level ``args`` created by the command-line entry point is
            used, as before.

    Raises:
        ValueError: If ``flag`` is unknown, the ratio is malformed, or a
            non-blank label line contains no tab separator.
    """
    if options is None:
        # Backward compatibility: the original relied on the global set under __main__.
        options = args

    dataAbsPath = os.path.abspath(root)

    if flag == "det":
        labelFilePath = os.path.join(dataAbsPath, options.detLabelFileName)
        imageDirPath = dataAbsPath
    elif flag == "rec":
        labelFilePath = os.path.join(dataAbsPath, options.recLabelFileName)
        imageDirPath = os.path.join(dataAbsPath, options.recImageDirName)
    else:
        raise ValueError("flag must be 'det' or 'rec', got {!r}".format(flag))

    # Parsed once here; it does not change between records.
    trainBound, valBound = _parseSplitRatio(options.trainValTestRatio)

    records = []
    with open(labelFilePath, "r", encoding="UTF-8") as labelFile:
        for lineNumber, line in enumerate(labelFile, start=1):
            if not line.strip():
                continue  # tolerate blank lines, e.g. at the end of the file
            fields = line.rstrip("\n").split("\t")
            if len(fields) < 2:
                raise ValueError(
                    "{}:{}: expected '<image path>\\t<label>'".format(labelFilePath, lineNumber))
            records.append((os.path.basename(fields[0]), fields[1]))

    random.shuffle(records)
    recordCount = len(records)

    for index, (imageName, imageLabel) in enumerate(records):
        curRatio = index / recordCount

        if curRatio < trainBound:
            targetDir, targetTxt = absTrainRootPath, trainTxt
        elif curRatio < valBound:
            targetDir, targetTxt = absValRootPath, valTxt
        else:
            targetDir, targetTxt = absTestRootPath, testTxt

        imageCopyPath = os.path.join(targetDir, imageName)
        shutil.copy(os.path.join(imageDirPath, imageName), imageCopyPath)
        # The newline is written explicitly so a final source line without one
        # cannot run into the next record.
        targetTxt.write("{}\t{}\n".format(imageCopyPath, imageLabel))


def removeFile(path):
    """Delete the file at ``path`` if it exists; do nothing otherwise."""
    if os.path.exists(path):
        os.remove(path)


def genDetRecTrainVal(args):
    """Build the detection and recognition train/val/test splits.

    The ``train``/``val``/``test`` folders under ``args.detRootPath`` and
    ``args.recRootPath`` are emptied and recreated, the six list files are
    rewritten from scratch, and the data found in ``args.datasetRootPath`` is
    distributed over them. Recognition data is only processed when the dataset
    root directly contains the ``args.recImageDirName`` folder.

    Args:
        args: Namespace with ``trainValTestRatio``, ``datasetRootPath``,
            ``detRootPath``, ``recRootPath``, ``detLabelFileName``,
            ``recLabelFileName`` and ``recImageDirName``.
    """
    subsets = ("train", "val", "test")

    detDirs = [isCreateOrDeleteFolder(args.detRootPath, subset) for subset in subsets]
    recDirs = [isCreateOrDeleteFolder(args.recRootPath, subset) for subset in subsets]

    # ExitStack closes all six list files, even if a split fails half-way.
    with ExitStack() as stack:
        # Mode "w" truncates any previous list file, replacing remove-then-append.
        detTxts = [
            stack.enter_context(
                open(os.path.join(args.detRootPath, subset + ".txt"), "w", encoding="UTF-8"))
            for subset in subsets
        ]
        recTxts = [
            stack.enter_context(
                open(os.path.join(args.recRootPath, subset + ".txt"), "w", encoding="UTF-8"))
            for subset in subsets
        ]

        splitTrainVal(args.datasetRootPath, *detDirs, *detTxts, "det", args)

        # Only the top level of the dataset root is inspected, matching the
        # original walk, which stopped after the first directory listing.
        cropDirPath = os.path.join(args.datasetRootPath, args.recImageDirName)
        if os.path.isdir(cropDirPath):
            splitTrainVal(args.datasetRootPath, *recDirs, *recTxts, "rec", args)


def _buildArgParser():
    """Create the command-line parser for the dataset split script."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--trainValTestRatio",
        type=str,
        default="6:2:2",
        help="ratio of trainset:valset:testset")
    parser.add_argument(
        "--datasetRootPath",
        type=str,
        default="../train_data/drivingData",
        help="path to the dataset marked by ppocrlabel, E.g, dataset folder named 1,2,3..."
    )
    parser.add_argument(
        "--detRootPath",
        type=str,
        default="../train_data/det",
        help="the path where the divided detection dataset is placed")
    parser.add_argument(
        "--recRootPath",
        type=str,
        default="../train_data/rec",
        help="the path where the divided recognition dataset is placed"
    )
    parser.add_argument(
        "--detLabelFileName",
        type=str,
        default="Label.txt",
        help="the name of the detection annotation file")
    parser.add_argument(
        "--recLabelFileName",
        type=str,
        default="rec_gt.txt",
        help="the name of the recognition annotation file"
    )
    parser.add_argument(
        "--recImageDirName",
        type=str,
        default="crop_img",
        help="the name of the folder where the cropped recognition dataset is located"
    )
    return parser


if __name__ == "__main__":
    # Purpose: split the detection and the recognition data into training,
    # validation and test sets.
    # Note: adjust the parameters to your own paths and needs. Image data is
    # often annotated in batches by several people, each batch living in its
    # own folder labelled with PPOCRLabel, so several annotated folders may
    # need to be gathered and split into train/val/test sets.
    # `args` stays a module-level name because splitTrainVal falls back to it.
    args = _buildArgParser().parse_args()
    genDetRecTrainVal(args)