构建机器学习模型前,通常要检查数据,判断不用机器学习能不能轻松完成任务,或者需要的信息有没有包含在数据中。检查数据也是发现异常值和特殊值的好办法。
检查数据的最佳方法之一就是可视化,一种是绘制散点图,将一个特征作为x轴,另一个作为y轴,将每个数据点绘制为图上的点。为了解决3个或更多特征的数据集作图的问题,可以绘制散点图矩阵。
以鸢尾花数据集为例,首先将Numpy数组转换为pandas DataFrame。pandas有一个绘制散点图矩阵的函数,叫做scatter_matrix。
import mglearn
import matplotlib.pyplot as plt
iris_dataset=load_iris()
X_train,X_test,y_train,y_test=train_test_split(
iris_dataset['data'],iris_dataset['target'],random_state=0
)
iris_dataframe=pd.DataFrame(X_train,columns=iris_dataset.feature_names)
grr=pd.plotting.scatter_matrix(iris_dataframe,c=y_train,figsize=(15,15),marker='0',hist_kwds={'bins':20},s=60,alpha=.8,cmap=mglearn.cm3)
plt.show()