|
|
|
|
公众号矩阵

教你使用TensorFlow2判断细胞图像是否感染

在本教程中,我们将使用 TensorFlow (Keras API) 实现一个用于二进制分类任务的深度学习模型,该任务包括将细胞的图像标记为感染或未感染疟疾。

作者:小sen来源:Python之王|2021-06-11 05:37

在本教程中,我们将使用 TensorFlow (Keras API) 实现一个用于二进制分类任务的深度学习模型,该任务包括将细胞的图像标记为感染或未感染疟疾。

数据集来源:https://www.kaggle.com/iarunava/cell-images-for-detecting-malaria

数据集包含2个文件夹

  • 感染::13780张图片
  • 未感染:13780张图片

总共27558张图片。

此数据集取自NIH官方网站:https://ceb.nlm.nih.gov/repositories/malaria-datasets/

环境:kaggle,天池实验室或者gogole colab都可以。

导入相关模块

  1. import cv2 
  2. import tensorflow as tf 
  3. from tensorflow.keras.models import Sequential  
  4. from tensorflow.keras.layers import Dense, Conv2D, MaxPool2D, Flatten, Activation 
  5. from sklearn.model_selection import train_test_split 
  6. import numpy as np 
  7. import matplotlib.pyplot as plt 
  8. import glob 
  9. import os 

对于图片数据存在形状不一样的情况,因此需要使用 OpenCV 进行图像预处理。

将图片变成 numpy 数组(数字格式)的形式转换为灰度,并将其调整为一个(70x70)形状。

  1. img_dir="../input/cell-images-for-detecting-malaria/cell_images"   
  2. img_size=70 
  3. def load_img_data(path): 
  4.     # 打乱数据 
  5.     image_files = glob.glob(os.path.join(path, "Parasitized/*.png")) + \ 
  6.                   glob.glob(os.path.join(path, "Uninfected/*.png")) 
  7.     X, y = [], [] 
  8.     for image_file in image_files: 
  9.         # 命名标签  0 for uninfected and 1 for infected 
  10.         label = 0 if "Uninfected" in image_file else 1 
  11.         # load the image in gray scale 变成灰度图片 
  12.         img_arr = cv2.imread(image_file, cv2.IMREAD_GRAYSCALE) 
  13.         # resize the image to (70x70)  调整图片大小 
  14.         img_resized = cv2.resize(img_arr, (img_size, img_size)) 
  15.         X.append(img_resized) 
  16.         y.append(label) 
  17.     return X, y 
  18. X, y = load_img_data(img_dir) 

查看X的shape。

  1. print(X.shape) 

X的shape为(27558, 70, 70, 1),27558表示图片的数据,70*70表示图片的长和宽像素。

另外,为了帮助网络更快收敛,我们应该进行数据归一化。在sklearn 中有一些缩放方法,例如:

在这里我们将除以255,因为像素可以达到的最大值是255,这将导致应用缩放后像素范围在 0 和 1 之间。

  1. X, y = load_img_data(img_dir) 
  2. # reshape to (n_samples, 70, 70, 1) (to fit the NN) 
  3. X = np.array(X).reshape(-1, img_size, img_size, 1) 
  4. #从[0,255]到[0,1]缩放像素 帮助神经网络更快地训练 
  5. X = X / 255 
  6.  
  7. # shuffle & split the dataset 
  8. X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, stratify=y) 
  9. print("Total training samples:", X_train.shape) 
  10. print("Total validation samples:", X_test.shape[0]) 

使用sklearn的train_test_split()方法将数据集划分为训练集和测试集,我们使用总数据的 10% 稍后对其进行验证。

在建立的模型中,我们将添加 3 个卷积层,然后Flatten是由层组成的全连接Dense层。

  1. model = Sequential() 
  2. model.add(Conv2D(64, (3, 3), input_shape=X_train.shape[1:])) 
  3. model.add(Activation("relu")) 
  4. model.add(MaxPool2D(pool_size=(2, 2))) 
  5.  
  6. model.add(Conv2D(64, (3, 3))) 
  7. model.add(Activation("relu")) 
  8. model.add(MaxPool2D(pool_size=(2, 2))) 
  9.  
  10. model.add(Conv2D(64, (3, 3))) 
  11. model.add(Activation("relu")) 
  12. model.add(MaxPool2D(pool_size=(2, 2))) 
  13.  
  14. model.add(Flatten()) 
  15.  
  16. model.add(Dense(64)) 
  17. model.add(Activation("relu")) 
  18.  
  19. model.add(Dense(64)) 
  20. model.add(Activation("relu")) 
  21.  
  22. model.add(Dense(1)) 
  23. model.add(Activation("sigmoid")) 
  24.  
  25. model.compile(loss="binary_crossentropy", optimizer="adam", metrics=["accuracy"]) 
  26. print(model.summary()) 

 

由于输出是二进制的(感染或未感染),我们使用Sigmoid 函数作为输出层的激活函数。

  1. # train the model with 10 epochs, 64 batch size 
  2. model.fit(X_train, np.array(y_train), batch_size=64, epochs=10, validation_split=0.2) 

在训练数据集及其验证拆分上实现了94%的准确率。

现在使用evaluate() 来评估测试数据集上的模型

  1. loss, accuracy = model.evaluate(X_test, np.array(y_test), verbose=0) 
  2. print(f"Testing on {len(X_test)} images, the results are\n Accuracy: {accuracy} | Loss: {loss}"

输出如下

  1. Testing on 2756 images, the results are 
  2. Accuracy: 0.9404934644699097 | Loss: 0.1666732281446457 

该模型在测试数据中也表现OK,准确率达到94%

最后,我们将通过保存我们的模型来结束所有这个过程。

  1. model.save("model.h5"

【编辑推荐】

  1. 鸿蒙官方战略合作共建——HarmonyOS技术社区
  2. 鸿蒙,就算套壳安卓又能怎样呢?
  3. Kubernetes为什么要弃用Docker?
  4. 从“PPT系统”走向现实:HarmonyOS,你真香了吗?
  5. 这6款Python IDE&代码编辑器,你都用过吗?
  6. Kubernetes实践之优雅终止
【责任编辑:姜华 TEL:(010)68476606】

点赞 0
分享:
大家都在看
猜你喜欢

订阅专栏+更多

带你轻松入门 RabbitMQ

带你轻松入门 RabbitMQ

轻松入门RabbitMQ
共4章 | loong576

12人订阅学习

数据湖与数据仓库的分析实践攻略

数据湖与数据仓库的分析实践攻略

助力现代化数据管理:数据湖与数据仓库的分析实践攻略
共3章 | 创世达人

9人订阅学习

云原生架构实践

云原生架构实践

新技术引领移动互联网进入急速赛道
共3章 | KaliArch

40人订阅学习

订阅51CTO邮刊

点击这里查看样刊

订阅51CTO邮刊

51CTO服务号

51CTO官微