K近邻（KNN）算法实现手写数字识别（Python）
# -*- coding: utf-8 -*-
"""K-nearest-neighbours (KNN) recognition of handwritten digits.

Loads scikit-learn's ``digits`` dataset (8x8 images, flattened), standardises
the features, fits a ``KNeighborsClassifier`` and plots a grid of test images
labelled ``prediction(truth)`` -- green when correct, red when wrong.
"""
import matplotlib.pyplot as plt
from sklearn.datasets import load_digits
from sklearn.model_selection import GridSearchCV, train_test_split
from sklearn.neighbors import KNeighborsClassifier
from sklearn.preprocessing import StandardScaler
from tensorflow import keras

# NOTE(review): the Fashion-MNIST arrays below are never used by this script.
# Loading them pulls in TensorFlow and may download the dataset on every run.
# They are kept only so these module-level names stay available; remove them
# if nothing else imports them.
fashion_mnist = keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()


def get_data():
    """Load the digits dataset and return a standardised train/test split.

    Returns:
        ``(train_x, test_x, train_y, test_y)``: a 70/30 split with a fixed
        ``random_state`` so runs are reproducible. The features are scaled
        with a ``StandardScaler`` fitted on the training part only.
    """
    data = load_digits()
    x, y = data.data, data.target
    print(x.shape)
    train_x, test_x, train_y, test_y = train_test_split(
        x, y, test_size=0.3, random_state=20
    )
    scaler = StandardScaler()
    # Fit on the training set only, then apply the same transform to the
    # test set, so that no test statistics leak into training.
    train_x = scaler.fit_transform(train_x)
    test_x = scaler.transform(test_x)
    return train_x, test_x, train_y, test_y


def model_fit(x, y):
    """Grid-search ``n_neighbors`` and the Minkowski power ``p`` with 5-fold CV.

    Prints the best parameter combination and its mean cross-validated
    accuracy.

    Returns:
        The fitted ``GridSearchCV`` object, for further inspection.
    """
    paras = {"n_neighbors": [5, 6, 7, 8, 9, 10], "p": [1, 2]}
    gs = GridSearchCV(KNeighborsClassifier(), paras, verbose=2, cv=5)
    gs.fit(x, y)
    print("最佳模型:", gs.best_params_, "准确率:", gs.best_score_)
    return gs


def train(train_x, test_x, train_y, test_y):
    """Fit a KNN classifier (k=5, p=1) and report results on the test set.

    Plots the first test images with their predicted and true labels, then
    prints the accuracy on the test set.
    """
    # k=5 and p=1 are presumably the values picked by model_fit(); TODO confirm.
    model = KNeighborsClassifier(n_neighbors=5, p=1)
    model.fit(train_x, train_y)
    pre_y = model.predict(test_x)
    show_img(test_x, pre_y, test_y)
    print(model.score(test_x, test_y))


def show_img(test_x, pre_x, test_y):
    """Show up to 15 test digits in a 5x3 grid, each labelled ``pred(true)``.

    Args:
        test_x: flattened images, one row per sample (reshaped to 8x8).
        pre_x: predicted labels, aligned with ``test_x``.
        test_y: true labels, aligned with ``test_x``.
    """
    num_row = 5
    num_col = 3
    plt.figure(figsize=(num_row, num_col * 2))
    # Never try to draw more panels than there are samples.
    n_panels = min(num_row * num_col, len(test_x))
    for i in range(n_panels):
        plt.subplot(num_row, num_col, i + 1)
        show_num_img(test_x[i], pre_x[i], test_y[i])
    plt.tight_layout()
    plt.show()


def show_num_img(img, pre, y):
    """Draw one 8x8 digit on the current axes, labelled ``pred(true)``.

    The label is green when the prediction matches the truth, red otherwise.
    """
    plt.xticks([])
    plt.yticks([])
    # Turn the grid off per panel. Calling plt.grid() before any subplot
    # exists would create a stray full-figure axes behind the grid.
    plt.grid(False)
    plt.imshow(img.reshape(8, 8), cmap=plt.cm.binary)
    color = "green" if pre == y else "red"
    # color belongs to xlabel; passed to str.format it is silently ignored.
    plt.xlabel("{0}({1})".format(pre, y), color=color)


if __name__ == "__main__":
    train_x, test_x, train_y, test_y = get_data()
    # Optional hyper-parameter search: model_fit(train_x, train_y)
    train(train_x, test_x, train_y, test_y)
上一篇:
通过多线程提高代码的执行效率例子
下一篇:
脉冲触发、JK触发介绍——数电5.3节