🍨 本文为:[🔗365天深度学习训练营] 中的学习记录博客
🍖 原作者:[K同学啊 | 接辅导、项目定制]
要求:
- 本地读取并加载数据;
- 了解循环神经网络RNN的构建过程;
- 测试集accuracy达到87%;
一、 基础配置
- 语言环境:Python3.7
- 编译器选择:Pycharm
- 深度学习环境:TensorFlow2.4.1
- 数据集:私有数据集
二、 前期准备
1.设置GPU
import tensorflow as tf
gpus = tf.config.list_physical_devices("GPU")
if gpus:
tf.config.experimental.set_memory_growth(gpus[0], True) #设置GPU显存用量按需使用
tf.config.set_visible_devices([gpus[0]],"GPU")
# 打印显卡信息,确认GPU可用
print(gpus)
根据个人设备情况,选择使用GPU/CPU进行训练,若GPU可用则输出:
[PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
由于在设备上安装的CUDA版本与TensorFlow版本不一致,故这里直接安装了CPU版的TensorFlow,无上述输出。
2. 导入数据
本项目所采用的数据集未收录于公开数据中,故需要自己在文件目录中导入相应数据集合,并设置对应文件目录,以供后续学习过程中使用。
运行下述代码:
import pandas as pd
df = pd.read_csv("./data/heart.csv")
print(df)
得到如下输出:
age sex cp trestbps chol fbs ... exang oldpeak slope ca thal target
0 63 1 3 145 233 1 ... 0 2.3 0 0 1 1
1 37 1 2 130 250 0 ... 0 3.5 0 0 2 1
2 41 0 1 130 204 0 ... 0 1.4 2 0 2 1
3 56 1 1 120 236 0 ... 0 0.8 2 0 2 1
4 57 0 0 120 354 0 ... 1 0.6 2 0 2 1
.. ... ... .. ... ... ... ... ... ... ... .. ... ...
298 57 0 0 140 241 0 ... 1 0.2 1 0 3 0
299 45 1 3 110 264 0 ... 0 1.2 1 0 3 0
300 68 1 0 144 193 1 ... 0 3.4 1 2 3 0
301 57 1 0 130 131 0 ... 1 1.2 1 1 3 0
302 57 0 1 130 236 0 ... 0 0.0 1 1 2 0
[303 rows x 14 columns]
3.检查数据
由于本次所采用的数据集为表格类型,因此,我们需要观察数据中是否有空值、异常值等错误数据的存在,这里我们先观察是否有空值:
dfnull = df.isnull().sum()
print(dfnull)
可以得到如下输出:
age 0
sex 0
cp 0
trestbps 0
chol 0
fbs 0
restecg 0
thalach 0
exang 0
oldpeak 0
slope 0
ca 0
thal 0
target 0
dtype: int64
可见,数据集中无空值存在。
三、数据预处理
1.划分数据集
测试集与验证集的关系:
- 验证集并没有参与训练过程中的梯度下降过程,即没有参与模型的参数更新过程;
- 广义上讲,验证集参与了“人工调参”的过程,我们根据每个epoch的训练结果,选择调整对应的超参数
- 可以认为,验证集参与了训练,但并没有使模型overfit验证集;
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
X = df.iloc[:,:-1]
y = df.iloc[:,-1]
X_train,X_test,y_train,y_test = train_test_split(X,y,test_size= 0.1 , random_state= 1)
print(X_train.shape,y_train.shape)
得到如下输出:
(272, 13) (272,)
其中:
- X = df.iloc[:,:-1]: 从DataFrame df 中选择所有行和除最后一列外的所有列作为特征矩阵 X;
- y = df.iloc[:,-1]: 从DataFrame df 中选择所有行和最后一列作为目标变量 y;
- X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=1): 将特征矩阵 X 和目标变量 y 划分为训练集和测试集。其中,test_size=0.1 表示测试集占总数据集的比例为10%,random_state=1 是随机种子,保证每次运行代码时划分的结果相同
2.标准化
sc = StandardScaler()
X_train = sc.fit_transform(X_train)
X_test = sc.transform(X_test)
X_train = X_train.reshape(X_train.shape[0],X_train.shape[1],1)
X_test = X_test .reshape(X_test .shape[0],X_test .shape[1],1)
其中:
- sc = StandardScaler(): 实例化了StandardScaler类,用于对特征进行标准化处理;
- X_train = sc.fit_transform(X_train): 使用fit_transform方法对训练集的特征矩阵进行标准化处理,即将每个特征的数值缩放到均值为0、方差为1的范围内;
- X_test = sc.transform(X_test): 使用transform方法对测试集的特征矩阵进行相同的标准化处理,使用的均值和方差参数来自于训练集;
- X_train = X_train.reshape(X_train.shape[0], X_train.shape[1], 1): 调整训练集特征矩阵的形状,将其变为三维数组,通常用于适配某些深度学习模型(比如卷积神经网络)的输入要求;
- X_test = X_test.reshape(X_test.shape[0], X_test.shape[1], 1): 同样地,调整测试集特征矩阵的形状,使其符合相同的三维格式要求;
四、构建网络
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense,LSTM,SimpleRNN
model = Sequential()
model.add(SimpleRNN(200,input_shape= (13,1),activation='relu'))
model.add(Dense(100,activation='relu'))
model.add(Dense(1,activation='sigmoid'))
model.summary()
可以得到如下输出:
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
simple_rnn (SimpleRNN) (None, 200) 40400
_________________________________________________________________
dense (Dense) (None, 100) 20100
_________________________________________________________________
dense_1 (Dense) (None, 1) 101
=================================================================
Total params: 60,601
Trainable params: 60,601
Non-trainable params: 0
_________________________________________________________________
我们调用了tf.keras.layers.SimpleRNN这个原型函数:
tf.keras.layers.SimpleRNN(units, activation='tanh', use_bias=True, kernel_initializer='glorot_uniform', recurrent_initializer='orthogonal', bias_initializer='zeros', return_sequences=False)
其中:
- units: 整数,表示输出空间的维度(神经元的数量);
- activation: 激活函数的名称,可选,默认为 'tanh'。可以是内置的激活函数,也可以是自定义的激活函数;
-
use_bias: 布尔值,表示是否使用偏置项,默认为 True;
-
kernel_initializer: 权重矩阵的初始化方法,默认为 'glorot_uniform',也称为 Xavier 初始化;
-
recurrent_initializer: 循环权重矩阵的初始化方法,默认为 'orthogonal',用于循环连接的权重;
-
bias_initializer: 偏置项的初始化方法,默认为 'zeros';
-
return_sequences: 布尔值,表示在输出序列中是否返回完整的序列,默认为 False。如果为 True,则返回整个输出序列;如果为 False,则只返回最后一个时间步的输出;
五、 编译模型
通过下列示例代码:
opt = tf.keras.optimizers.Adam(learning_rate=1e-5)
model.compile(loss='binary_crossentropy',
optimizer=opt,
metrics="accuracy")
六、训练模型
通过下列示例代码:
epochs = 100
history = model.fit(X_train,y_train,
epochs = epochs,
batch_size = 128,
validation_data=(X_test,y_test),
verbose = 1)
运行得到如下输出:
Epoch 1/100
3/3 [==============================] - 1s 119ms/step - loss: 0.6851 - accuracy: 0.5728 - val_loss: 0.6811 - val_accuracy: 0.6452
Epoch 2/100
3/3 [==============================] - 0s 9ms/step - loss: 0.6831 - accuracy: 0.5796 - val_loss: 0.6795 - val_accuracy: 0.6452
Epoch 3/100
3/3 [==============================] - 0s 9ms/step - loss: 0.6829 - accuracy: 0.5784 - val_loss: 0.6779 - val_accuracy: 0.6452
Epoch 4/100
3/3 [==============================] - 0s 8ms/step - loss: 0.6815 - accuracy: 0.5812 - val_loss: 0.6762 - val_accuracy: 0.6452
Epoch 5/100
3/3 [==============================] - 0s 11ms/step - loss: 0.6796 - accuracy: 0.5947 - val_loss: 0.6745 - val_accuracy: 0.6452
Epoch 6/100
3/3 [==============================] - 0s 8ms/step - loss: 0.6797 - accuracy: 0.6031 - val_loss: 0.6729 - val_accuracy: 0.6774
Epoch 7/100
3/3 [==============================] - 0s 7ms/step - loss: 0.6774 - accuracy: 0.6253 - val_loss: 0.6713 - val_accuracy: 0.7097
Epoch 8/100
3/3 [==============================] - 0s 9ms/step - loss: 0.6760 - accuracy: 0.6394 - val_loss: 0.6697 - val_accuracy: 0.7097
Epoch 9/100
3/3 [==============================] - 0s 9ms/step - loss: 0.6760 - accuracy: 0.6239 - val_loss: 0.6682 - val_accuracy: 0.7097
Epoch 10/100
3/3 [==============================] - 0s 7ms/step - loss: 0.6756 - accuracy: 0.6380 - val_loss: 0.6668 - val_accuracy: 0.7097
Epoch 11/100
3/3 [==============================] - 0s 10ms/step - loss: 0.6743 - accuracy: 0.6564 - val_loss: 0.6655 - val_accuracy: 0.7097
Epoch 12/100
3/3 [==============================] - 0s 9ms/step - loss: 0.6738 - accuracy: 0.6524 - val_loss: 0.6640 - val_accuracy: 0.7419
Epoch 13/100
3/3 [==============================] - 0s 10ms/step - loss: 0.6714 - accuracy: 0.6631 - val_loss: 0.6625 - val_accuracy: 0.7419
Epoch 14/100
3/3 [==============================] - 0s 10ms/step - loss: 0.6715 - accuracy: 0.6524 - val_loss: 0.6610 - val_accuracy: 0.7742
Epoch 15/100
3/3 [==============================] - 0s 9ms/step - loss: 0.6694 - accuracy: 0.6850 - val_loss: 0.6595 - val_accuracy: 0.7419
Epoch 16/100
3/3 [==============================] - 0s 8ms/step - loss: 0.6683 - accuracy: 0.6811 - val_loss: 0.6580 - val_accuracy: 0.7419
Epoch 17/100
3/3 [==============================] - 0s 8ms/step - loss: 0.6688 - accuracy: 0.6598 - val_loss: 0.6565 - val_accuracy: 0.7419
Epoch 18/100
3/3 [==============================] - 0s 8ms/step - loss: 0.6677 - accuracy: 0.6741 - val_loss: 0.6550 - val_accuracy: 0.7419
Epoch 19/100
3/3 [==============================] - 0s 7ms/step - loss: 0.6665 - accuracy: 0.7028 - val_loss: 0.6535 - val_accuracy: 0.7419
Epoch 20/100
3/3 [==============================] - 0s 11ms/step - loss: 0.6665 - accuracy: 0.6949 - val_loss: 0.6520 - val_accuracy: 0.7419
Epoch 21/100
3/3 [==============================] - 0s 10ms/step - loss: 0.6645 - accuracy: 0.7027 - val_loss: 0.6505 - val_accuracy: 0.7742
Epoch 22/100
3/3 [==============================] - 0s 9ms/step - loss: 0.6632 - accuracy: 0.7143 - val_loss: 0.6489 - val_accuracy: 0.7742
Epoch 23/100
3/3 [==============================] - 0s 11ms/step - loss: 0.6619 - accuracy: 0.7084 - val_loss: 0.6474 - val_accuracy: 0.8065
Epoch 24/100
3/3 [==============================] - 0s 9ms/step - loss: 0.6624 - accuracy: 0.6997 - val_loss: 0.6459 - val_accuracy: 0.8065
Epoch 25/100
3/3 [==============================] - 0s 9ms/step - loss: 0.6618 - accuracy: 0.7083 - val_loss: 0.6445 - val_accuracy: 0.8387
Epoch 26/100
3/3 [==============================] - 0s 8ms/step - loss: 0.6601 - accuracy: 0.7169 - val_loss: 0.6431 - val_accuracy: 0.8387
Epoch 27/100
3/3 [==============================] - 0s 8ms/step - loss: 0.6601 - accuracy: 0.7236 - val_loss: 0.6417 - val_accuracy: 0.8387
Epoch 28/100
3/3 [==============================] - 0s 9ms/step - loss: 0.6577 - accuracy: 0.7283 - val_loss: 0.6402 - val_accuracy: 0.8387
Epoch 29/100
3/3 [==============================] - 0s 9ms/step - loss: 0.6567 - accuracy: 0.7282 - val_loss: 0.6388 - val_accuracy: 0.8387
Epoch 30/100
3/3 [==============================] - 0s 9ms/step - loss: 0.6585 - accuracy: 0.7394 - val_loss: 0.6374 - val_accuracy: 0.8387
Epoch 31/100
3/3 [==============================] - 0s 11ms/step - loss: 0.6551 - accuracy: 0.7501 - val_loss: 0.6359 - val_accuracy: 0.8387
Epoch 32/100
3/3 [==============================] - 0s 9ms/step - loss: 0.6549 - accuracy: 0.7462 - val_loss: 0.6345 - val_accuracy: 0.8710
Epoch 33/100
3/3 [==============================] - 0s 9ms/step - loss: 0.6523 - accuracy: 0.7597 - val_loss: 0.6331 - val_accuracy: 0.8710
Epoch 34/100
3/3 [==============================] - 0s 8ms/step - loss: 0.6523 - accuracy: 0.7671 - val_loss: 0.6317 - val_accuracy: 0.9032
Epoch 35/100
3/3 [==============================] - 0s 8ms/step - loss: 0.6512 - accuracy: 0.7749 - val_loss: 0.6303 - val_accuracy: 0.9355
Epoch 36/100
3/3 [==============================] - 0s 9ms/step - loss: 0.6507 - accuracy: 0.7748 - val_loss: 0.6289 - val_accuracy: 0.9355
Epoch 37/100
3/3 [==============================] - 0s 9ms/step - loss: 0.6512 - accuracy: 0.7641 - val_loss: 0.6275 - val_accuracy: 0.9032
Epoch 38/100
3/3 [==============================] - 0s 9ms/step - loss: 0.6495 - accuracy: 0.7641 - val_loss: 0.6261 - val_accuracy: 0.9032
Epoch 39/100
3/3 [==============================] - 0s 10ms/step - loss: 0.6499 - accuracy: 0.7659 - val_loss: 0.6247 - val_accuracy: 0.9032
Epoch 40/100
3/3 [==============================] - 0s 10ms/step - loss: 0.6479 - accuracy: 0.7785 - val_loss: 0.6233 - val_accuracy: 0.9032
Epoch 41/100
3/3 [==============================] - 0s 9ms/step - loss: 0.6483 - accuracy: 0.7640 - val_loss: 0.6218 - val_accuracy: 0.9032
Epoch 42/100
3/3 [==============================] - 0s 10ms/step - loss: 0.6445 - accuracy: 0.7874 - val_loss: 0.6205 - val_accuracy: 0.9032
Epoch 43/100
3/3 [==============================] - 0s 9ms/step - loss: 0.6447 - accuracy: 0.7747 - val_loss: 0.6192 - val_accuracy: 0.9032
Epoch 44/100
3/3 [==============================] - 0s 9ms/step - loss: 0.6459 - accuracy: 0.7718 - val_loss: 0.6179 - val_accuracy: 0.9032
Epoch 45/100
3/3 [==============================] - 0s 10ms/step - loss: 0.6435 - accuracy: 0.7765 - val_loss: 0.6165 - val_accuracy: 0.9032
Epoch 46/100
3/3 [==============================] - 0s 9ms/step - loss: 0.6445 - accuracy: 0.7745 - val_loss: 0.6152 - val_accuracy: 0.9032
Epoch 47/100
3/3 [==============================] - 0s 9ms/step - loss: 0.6443 - accuracy: 0.7695 - val_loss: 0.6138 - val_accuracy: 0.9032
Epoch 48/100
3/3 [==============================] - 0s 9ms/step - loss: 0.6418 - accuracy: 0.7802 - val_loss: 0.6124 - val_accuracy: 0.9032
Epoch 49/100
3/3 [==============================] - 0s 9ms/step - loss: 0.6410 - accuracy: 0.7783 - val_loss: 0.6110 - val_accuracy: 0.9032
Epoch 50/100
3/3 [==============================] - 0s 11ms/step - loss: 0.6410 - accuracy: 0.7763 - val_loss: 0.6096 - val_accuracy: 0.9032
Epoch 51/100
3/3 [==============================] - 0s 10ms/step - loss: 0.6392 - accuracy: 0.7831 - val_loss: 0.6082 - val_accuracy: 0.9032
Epoch 52/100
3/3 [==============================] - 0s 10ms/step - loss: 0.6383 - accuracy: 0.7812 - val_loss: 0.6068 - val_accuracy: 0.9032
Epoch 53/100
3/3 [==============================] - 0s 9ms/step - loss: 0.6387 - accuracy: 0.7657 - val_loss: 0.6053 - val_accuracy: 0.9032
Epoch 54/100
3/3 [==============================] - 0s 11ms/step - loss: 0.6326 - accuracy: 0.8007 - val_loss: 0.6037 - val_accuracy: 0.9032
Epoch 55/100
3/3 [==============================] - 0s 10ms/step - loss: 0.6375 - accuracy: 0.7694 - val_loss: 0.6023 - val_accuracy: 0.9032
Epoch 56/100
3/3 [==============================] - 0s 9ms/step - loss: 0.6389 - accuracy: 0.7800 - val_loss: 0.6008 - val_accuracy: 0.9032
Epoch 57/100
3/3 [==============================] - 0s 10ms/step - loss: 0.6328 - accuracy: 0.7994 - val_loss: 0.5993 - val_accuracy: 0.9032
Epoch 58/100
3/3 [==============================] - 0s 11ms/step - loss: 0.6310 - accuracy: 0.7994 - val_loss: 0.5978 - val_accuracy: 0.9032
Epoch 59/100
3/3 [==============================] - 0s 10ms/step - loss: 0.6297 - accuracy: 0.8034 - val_loss: 0.5962 - val_accuracy: 0.9032
Epoch 60/100
3/3 [==============================] - 0s 11ms/step - loss: 0.6304 - accuracy: 0.7945 - val_loss: 0.5946 - val_accuracy: 0.9032
Epoch 61/100
3/3 [==============================] - 0s 11ms/step - loss: 0.6256 - accuracy: 0.8023 - val_loss: 0.5930 - val_accuracy: 0.9032
Epoch 62/100
3/3 [==============================] - 0s 13ms/step - loss: 0.6312 - accuracy: 0.7883 - val_loss: 0.5915 - val_accuracy: 0.9032
Epoch 63/100
3/3 [==============================] - 0s 10ms/step - loss: 0.6247 - accuracy: 0.8147 - val_loss: 0.5899 - val_accuracy: 0.9032
Epoch 64/100
3/3 [==============================] - 0s 11ms/step - loss: 0.6240 - accuracy: 0.8059 - val_loss: 0.5884 - val_accuracy: 0.9032
Epoch 65/100
3/3 [==============================] - 0s 10ms/step - loss: 0.6280 - accuracy: 0.7981 - val_loss: 0.5869 - val_accuracy: 0.9032
Epoch 66/100
3/3 [==============================] - 0s 11ms/step - loss: 0.6249 - accuracy: 0.8069 - val_loss: 0.5854 - val_accuracy: 0.9032
Epoch 67/100
3/3 [==============================] - 0s 11ms/step - loss: 0.6244 - accuracy: 0.7932 - val_loss: 0.5838 - val_accuracy: 0.9032
Epoch 68/100
3/3 [==============================] - 0s 10ms/step - loss: 0.6209 - accuracy: 0.8137 - val_loss: 0.5822 - val_accuracy: 0.9032
Epoch 69/100
3/3 [==============================] - 0s 11ms/step - loss: 0.6193 - accuracy: 0.8059 - val_loss: 0.5806 - val_accuracy: 0.9032
Epoch 70/100
3/3 [==============================] - 0s 10ms/step - loss: 0.6199 - accuracy: 0.8020 - val_loss: 0.5790 - val_accuracy: 0.9032
Epoch 71/100
3/3 [==============================] - 0s 10ms/step - loss: 0.6185 - accuracy: 0.8097 - val_loss: 0.5774 - val_accuracy: 0.9032
Epoch 72/100
3/3 [==============================] - 0s 10ms/step - loss: 0.6165 - accuracy: 0.8126 - val_loss: 0.5759 - val_accuracy: 0.9032
Epoch 73/100
3/3 [==============================] - 0s 10ms/step - loss: 0.6126 - accuracy: 0.8253 - val_loss: 0.5744 - val_accuracy: 0.9032
Epoch 74/100
3/3 [==============================] - 0s 11ms/step - loss: 0.6141 - accuracy: 0.8214 - val_loss: 0.5729 - val_accuracy: 0.9032
Epoch 75/100
3/3 [==============================] - 0s 10ms/step - loss: 0.6150 - accuracy: 0.8077 - val_loss: 0.5713 - val_accuracy: 0.9032
Epoch 76/100
3/3 [==============================] - 0s 11ms/step - loss: 0.6177 - accuracy: 0.7980 - val_loss: 0.5698 - val_accuracy: 0.9032
Epoch 77/100
3/3 [==============================] - 0s 10ms/step - loss: 0.6156 - accuracy: 0.8048 - val_loss: 0.5682 - val_accuracy: 0.9032
Epoch 78/100
3/3 [==============================] - 0s 10ms/step - loss: 0.6132 - accuracy: 0.8068 - val_loss: 0.5666 - val_accuracy: 0.9032
Epoch 79/100
3/3 [==============================] - 0s 9ms/step - loss: 0.6120 - accuracy: 0.8146 - val_loss: 0.5650 - val_accuracy: 0.9032
Epoch 80/100
3/3 [==============================] - 0s 9ms/step - loss: 0.6077 - accuracy: 0.8136 - val_loss: 0.5635 - val_accuracy: 0.9032
Epoch 81/100
3/3 [==============================] - 0s 10ms/step - loss: 0.6121 - accuracy: 0.8009 - val_loss: 0.5620 - val_accuracy: 0.9032
Epoch 82/100
3/3 [==============================] - 0s 11ms/step - loss: 0.6085 - accuracy: 0.8087 - val_loss: 0.5604 - val_accuracy: 0.9032
Epoch 83/100
3/3 [==============================] - 0s 10ms/step - loss: 0.6060 - accuracy: 0.8136 - val_loss: 0.5587 - val_accuracy: 0.9032
Epoch 84/100
3/3 [==============================] - 0s 12ms/step - loss: 0.6073 - accuracy: 0.8010 - val_loss: 0.5571 - val_accuracy: 0.9032
Epoch 85/100
3/3 [==============================] - 0s 9ms/step - loss: 0.6042 - accuracy: 0.8088 - val_loss: 0.5553 - val_accuracy: 0.9032
Epoch 86/100
3/3 [==============================] - 0s 13ms/step - loss: 0.6068 - accuracy: 0.7950 - val_loss: 0.5535 - val_accuracy: 0.9032
Epoch 87/100
3/3 [==============================] - 0s 10ms/step - loss: 0.6001 - accuracy: 0.8136 - val_loss: 0.5516 - val_accuracy: 0.9032
Epoch 88/100
3/3 [==============================] - 0s 10ms/step - loss: 0.5991 - accuracy: 0.8048 - val_loss: 0.5498 - val_accuracy: 0.9032
Epoch 89/100
3/3 [==============================] - 0s 11ms/step - loss: 0.5977 - accuracy: 0.8107 - val_loss: 0.5481 - val_accuracy: 0.9032
Epoch 90/100
3/3 [==============================] - 0s 10ms/step - loss: 0.5990 - accuracy: 0.7970 - val_loss: 0.5465 - val_accuracy: 0.9032
Epoch 91/100
3/3 [==============================] - 0s 9ms/step - loss: 0.5981 - accuracy: 0.8165 - val_loss: 0.5448 - val_accuracy: 0.9032
Epoch 92/100
3/3 [==============================] - 0s 10ms/step - loss: 0.5960 - accuracy: 0.8213 - val_loss: 0.5431 - val_accuracy: 0.9032
Epoch 93/100
3/3 [==============================] - 0s 10ms/step - loss: 0.5971 - accuracy: 0.8126 - val_loss: 0.5413 - val_accuracy: 0.9032
Epoch 94/100
3/3 [==============================] - 0s 11ms/step - loss: 0.5944 - accuracy: 0.8115 - val_loss: 0.5395 - val_accuracy: 0.9032
Epoch 95/100
3/3 [==============================] - 0s 10ms/step - loss: 0.5964 - accuracy: 0.8066 - val_loss: 0.5376 - val_accuracy: 0.9032
Epoch 96/100
3/3 [==============================] - 0s 9ms/step - loss: 0.5940 - accuracy: 0.8165 - val_loss: 0.5357 - val_accuracy: 0.9032
Epoch 97/100
3/3 [==============================] - 0s 9ms/step - loss: 0.5911 - accuracy: 0.8087 - val_loss: 0.5338 - val_accuracy: 0.9032
Epoch 98/100
3/3 [==============================] - 0s 8ms/step - loss: 0.5889 - accuracy: 0.8125 - val_loss: 0.5319 - val_accuracy: 0.9032
Epoch 99/100
3/3 [==============================] - 0s 9ms/step - loss: 0.5901 - accuracy: 0.8096 - val_loss: 0.5299 - val_accuracy: 0.9032
Epoch 100/100
3/3 [==============================] - 0s 8ms/step - loss: 0.5876 - accuracy: 0.8105 - val_loss: 0.5279 - val_accuracy: 0.9032
Process finished with exit code 0
模型训练结果为:val_accuracy = 90.32%
六、 模型评估
1.Loss与Accuracy图
import matplotlib.pyplot as plt
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']
loss = history.history['loss']
val_loss = history.history['val_loss']
epochs_range = range(epochs)
plt.figure(figsize=(14, 4))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy-Adam')
plt.plot(epochs_range, val_acc, label='Validation Accuracy-Adam')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss-Adam')
plt.plot(epochs_range, val_loss, label='Validation Loss-Adam')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()
得到的可视化结果:
七、个人理解
本项目为通过RNN来实现心脏病的预测,需要根据给定的CSV文件来实现该目标。由于CSV为表格类文件,故可能存在数据缺失的情况,这与之前的图片类数据有明显的差异,因此在进入网络模型训练前需要针对表格中的数据做一定的处理,由于初次接触表格数据,故只完成了数据是否为0的检查,之后的学习过程中将完善异常值处理等其他数据操作。