K近邻算法实现手写数字的识别，Python

# -*- coding: utf-8 -*-

"K邻算法实现"
from sklearn.datasets import load_digits
from sklearn.neighbors import KNeighborsClassifier
from sklearn.model_selection import GridSearchCV, train_test_split
import matplotlib.pyplot as plt
from tensorflow import keras
# NOTE(review): the Fashion-MNIST arrays loaded below are not referenced
# anywhere in this file -- the K-NN code works on sklearn's load_digits().
# This still runs at import time; keras' load_data() presumably downloads the
# dataset on first use (confirm). Consider removing these lines together with
# the tensorflow import.
fashion_mnist = keras.datasets.fashion_mnist
(train_images, train_labels), (test_images, test_labels) = fashion_mnist.load_data()

from sklearn.preprocessing import StandardScaler
def get_data(test_size=0.3, random_state=20):
    """Load sklearn's digits dataset, split it, and standardize the features.

    The scaler is fitted on the training split only and then applied to the
    test split, so no statistics from the test data leak into training.

    Args:
        test_size: Fraction of samples held out for testing (default 0.3).
        random_state: Seed passed to ``train_test_split`` so the split is
            reproducible (default 20).

    Returns:
        Tuple ``(train_x, test_x, train_y, test_y)`` with standardized
        feature matrices and the corresponding label arrays.
    """
    data = load_digits()
    x, y = data.data, data.target
    # Show the size of the feature matrix that is about to be split.
    print(x.shape)
    train_x, test_x, train_y, test_y = train_test_split(
        x, y, test_size=test_size, random_state=random_state)
    scaler = StandardScaler()
    # Fit on training data only; reuse the same statistics for the test set.
    train_x = scaler.fit_transform(train_x)
    test_x = scaler.transform(test_x)
    return train_x, test_x, train_y, test_y

def model_fit(x, y):
    """Grid-search K-NN hyper-parameters with 5-fold cross-validation.

    Searches ``n_neighbors`` in 5..10 and the Minkowski power ``p`` in
    {1, 2}. It prints the best parameter combination and its mean
    cross-validated accuracy.

    Args:
        x: Training features.
        y: Training labels.

    Returns:
        The fitted ``GridSearchCV`` object. Callers can reuse
        ``best_estimator_`` / ``best_params_`` instead of only reading the
        printed summary.
    """
    # Grid keys must be strings: GridSearchCV sets them as estimator params.
    paras = {"n_neighbors": [5, 6, 7, 8, 9, 10], "p": [1, 2]}
    model = KNeighborsClassifier()
    gs = GridSearchCV(model, paras, verbose=2, cv=5)
    gs.fit(x, y)
    print("最佳模型:", gs.best_params_, "准确率:", gs.best_score_)
    return gs
    
def train(train_x, test_x, train_y, test_y, n_neighbors=5, p=1):
    """Fit a K-NN classifier, report test accuracy, and plot sample predictions.

    Args:
        train_x: Training features.
        test_x: Test features.
        train_y: Training labels.
        test_y: Test labels.
        n_neighbors: Number of neighbours used for voting (default 5).
        p: Minkowski distance power passed to the classifier (default 1).

    Returns:
        Mean accuracy on the test split; the same value is also printed.
    """
    model = KNeighborsClassifier(n_neighbors=n_neighbors, p=p)
    model.fit(train_x, train_y)
    pre_y = model.predict(test_x)
    score = model.score(test_x, test_y)
    # Print before plotting: plt.show() normally blocks until the window is
    # closed, so the accuracy would otherwise appear only afterwards.
    print(score)
    show_img(test_x, pre_y, test_y)
    return score
def show_img(test_x, pre_x, test_y, num_row=5, num_col=3):
    """Plot a grid of test images, each labelled "prediction(truth)".

    Only the first ``num_row * num_col`` samples are drawn. If fewer samples
    are available, the remaining grid cells stay empty instead of raising
    IndexError.

    Args:
        test_x: Test images, each reshaped to 8x8 by ``show_num_img``.
        pre_x: Predicted labels, aligned with ``test_x``.
        test_y: True labels, aligned with ``test_x``.
        num_row: Number of grid rows (default 5).
        num_col: Number of grid columns (default 3).
    """
    plt.figure(figsize=(num_row, num_col * 2))
    n_show = min(num_row * num_col, len(test_x), len(pre_x), len(test_y))
    for i in range(n_show):
        plt.subplot(num_row, num_col, i + 1)
        # Disable the grid per subplot. A figure-level plt.grid() call before
        # any subplot creates a full-figure Axes, which newer matplotlib
        # versions no longer remove when subplots are added on top.
        plt.grid(False)
        show_num_img(test_x[i], pre_x[i], test_y[i])
    plt.tight_layout()
    plt.show()
def show_num_img(img, pre, y):
    """Draw one digit image on the current Axes with a coloured label.

    The x-axis label reads "prediction(truth)". It is green when the
    prediction matches the true label and red otherwise.

    Args:
        img: Image with 64 values, reshaped to 8x8 for display.
        pre: Predicted label.
        y: True label.
    """
    plt.xticks([])
    plt.yticks([])
    plt.imshow(img.reshape(8, 8), cmap=plt.cm.binary)
    color = "red" if pre != y else "green"
    # color belongs to xlabel; passed to str.format it was silently ignored.
    plt.xlabel("{0}({1})".format(pre, y), color=color)
if __name__ == "__main__":
    train_x, test_x, train_y, test_y = get_data()
    # To tune n_neighbors / p instead, grid-search on the training split:
    # model_fit(train_x, train_y)

    train(train_x, test_x, train_y, test_y)
经验分享 程序员 微信小程序 职场和发展