您的位置:首页 > 脚本大全 > > 正文

tensorflow自定义初始化(Tensorflow分类器项目自定义数据读入的实现)

更多 时间:2022-03-31 12:31:59 类别:脚本大全 浏览量:604

tensorflow自定义初始化

Tensorflow分类器项目自定义数据读入的实现

在照着tensorflow官网的demo敲了一遍分类器项目的代码后,运行倒是成功了,结果也不错。但是最终还是要训练自己的数据,所以尝试准备加载自定义的数据,然而demo中只是出现了fashion_mnist.load_data()并没有详细的读取过程,随后我又找了些资料,把读取的过程记录在这里。

首先提一下需要用到的模块:

  • ?
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • import os
  • import keras
  • import matplotlib.pyplot as plt
  • from pil import image
  • from keras.preprocessing.image import imagedatagenerator
  • from sklearn.model_selection import train_test_split
  • 图片分类器项目,首先确定你要处理的图片分辨率将是多少,这里的例子为30像素:

    img_size_x = 30
    img_size_y = 30

    其次确定你图片的方式目录:

  • ?
  • 1
  • 2
  • 3
  • 4
  • image_path = r'd:\projects\imageclassifier\data\set'
  • path = ".\data"
  • # 你也可以使用相对路径的方式
  • # image_path =os.path.join(path, "set")
  • 目录下的结构如下:

    tensorflow自定义初始化(Tensorflow分类器项目自定义数据读入的实现)

    相应的label.txt如下:

    动漫
    风景
    美女
    物语
    樱花

    接下来是接在labels.txt,如下:

  • ?
  • 1
  • 2
  • 3
  • label_name = "labels.txt"
  • label_path = os.path.join(path, label_name)
  • class_names = np.loadtxt(label_path, type(""))
  • 这里简便起见,直接利用了numpy的loadtxt函数直接加载。

    之后便是正式处理图片数据了,注释就写在里面了:

  • ?
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • 24
  • 25
  • 26
  • 27
  • 28
  • 29
  • 30
  • 31
  • 32
  • 33
  • 34
  • 35
  • 36
  • 37
  • 38
  • 39
  • 40
  • 41
  • 42
  • 43
  • 44
  • 45
  • 46
  • 47
  • 48
  • 49
  • 50
  • 51
  • 52
  • 53
  • 54
  • 55
  • 56
  • 57
  • 58
  • 59
  • 60
  • 61
  • 62
  • 63
  • re_load = false
  • re_build = false
  • # re_load = true
  • re_build = true
  •  
  • data_name = "data.npz"
  • data_path = os.path.join(path, data_name)
  • model_name = "model.h5"
  • model_path = os.path.join(path, model_name)
  •  
  • count = 0
  •  
  • # 这里判断是否存在序列化之后的数据,re_load是一个开关,是否强制重新处理,测试用,可以去除。
  • if not os.path.exists(data_path) or re_load:
  •   labels = []
  •   images = []
  •   print('handle images')
  •   # 由于label.txt是和图片防止目录的分类目录一一对应的,即每个子目录的目录名就是labels.txt里的一个label,所以这里可以通过读取class_names的每一项去拼接path后读取
  •   for index, name in enumerate(class_names):
  •     # 这里是拼接后的子目录path
  •     classpath = os.path.join(image_path, name)
  •     # 先判断一下是否是目录
  •     if not os.path.isdir(classpath):
  •       continue
  •     # limit是测试时候用的这里可以去除
  •     limit = 0
  •     for image_name in os.listdir(classpath):
  •       if limit >= max_size:
  •         break
  •       # 这里是拼接后的待处理的图片path
  •       imagepath = os.path.join(classpath, image_name)
  •       count = count + 1
  •       limit = limit + 1
  •       # 利用image打开图片
  •       img = image.open(imagepath)
  •       # 缩放到你最初确定要处理的图片分辨率大小
  •       img = img.resize((img_size_x, img_size_y))
  •       # 转为灰度图片,这里彩色通道会干扰结果,并且会加大计算量
  •       img = img.convert("l")
  •       # 转为numpy数组
  •       img = np.array(img)
  •       # 由(30,30)转为(1,30,30)(即`channels_first`),当然你也可以转换为(30,30,1)(即`channels_last`)但为了之后预览处理后的图片方便这里采用了(1,30,30)的格式存放
  •       img = np.reshape(img, (1, img_size_x, img_size_y))
  •       # 这里利用循环生成labels数据,其中存放的实际是class_names中对应元素的索引
  •       labels.append([index])
  •       # 添加到images中,最后统一处理
  •       images.append(img)
  •       # 循环中一些状态的输出,可以去除
  •       print("{} class: {} {} limit: {} {}"
  •          .format(count, index + 1, class_names[index], limit, imagepath))
  •   # 最后一次性将images和labels都转换成numpy数组
  •   npy_data = np.array(images)
  •   npy_labels = np.array(labels)
  •   # 处理数据只需要一次,所以我们选择在这里利用numpy自带的方法将处理之后的数据序列化存储
  •   np.savez(data_path, x=npy_data, y=npy_labels)
  •   print("save images by npz")
  • else:
  •   # 如果存在序列化号的数据,便直接读取,提高速度
  •   npy_data = np.load(data_path)["x"]
  •   npy_labels = np.load(data_path)["y"]
  •   print("load images by npz")
  • image_data = npy_data
  • labels_data = npy_labels
  • 到了这里原始数据的加工预处理便已经完成,只需要最后一步,就和demo中fashion_mnist.load_data()返回的结果一样了。代码如下:

  • ?
  • 1
  • 2
  • 3
  • # 最后一步就是将原始数据分成训练数据和测试数据
  • train_images, test_images, train_labels, test_labels = \
  •   train_test_split(image_data, labels_data, test_size=0.2, random_state=6)
  • 这里将相关信息打印的方法也附上:

  • ?
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • print("_________________________________________________________________")
  • print("%-28s %-s" % ("name", "shape"))
  • print("=================================================================")
  • print("%-28s %-s" % ("image data", image_data.shape))
  • print("%-28s %-s" % ("labels data", labels_data.shape))
  • print("=================================================================")
  •  
  • print('split train and test data,p=%')
  • print("_________________________________________________________________")
  • print("%-28s %-s" % ("name", "shape"))
  • print("=================================================================")
  • print("%-28s %-s" % ("train images", train_images.shape))
  • print("%-28s %-s" % ("test images", test_images.shape))
  • print("%-28s %-s" % ("train labels", train_labels.shape))
  • print("%-28s %-s" % ("test labels", test_labels.shape))
  • print("=================================================================")
  • 之后别忘了归一化哟:

  • ?
  • 1
  • 2
  • 3
  • print("normalize images")
  • train_images = train_images / 255.0
  • test_images = test_images / 255.0
  • 最后附上读取自定义数据的完整代码:

  • ?
  • 1
  • 2
  • 3
  • 4
  • 5
  • 6
  • 7
  • 8
  • 9
  • 10
  • 11
  • 12
  • 13
  • 14
  • 15
  • 16
  • 17
  • 18
  • 19
  • 20
  • 21
  • 22
  • 23
  • import os
  •  
  • import keras
  • import matplotlib.pyplot as plt
  • from pil import image
  • from keras.layers import *
  • from keras.models import *
  • from keras.optimizers import adam
  • from keras.preprocessing.image import imagedatagenerator
  • from sklearn.model_selection import train_test_split
  •  
  • os.environ['tf_cpp_min_log_level'] = '2'
  • # 支持中文
  • plt.rcparams['font.sans-serif'] = ['simhei'] # 用来正常显示中文标签
  • plt.rcparams['axes.unicode_minus'] = false # 用来正常显示负号
  • re_load = false
  • re_build = false
  • # re_load = true
  • re_build = true
  • epochs = 50
  • batch_size = 5
  • count = 0
  • max_size = 2000000000
  • 以上就是本文的全部内容,希望对大家的学习有所帮助,也希望大家多多支持开心学习网。

    原文链接:https://segmentfault.com/a/1190000018099185