数据挖掘 沪深股市预测

导入基本模块库

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from statsmodels.tsa.arima_model import ARMA
import warnings
from itertools import product
from datetime import datetime
warnings.filterwarnings(ignore)

加载数据

# 数据加载
df = pd.read_csv(./shanghai_1990-12-19_to_2019-2-28.csv)

将时间作为df的索引

df.Timestamp = pd.to_datetime(df.Timestamp)
df.index = df.Timestamp

效果如图

数据探索

print(df.head())

按照月,季度,年来统计

df_month = df.resample(M).mean()
df_Q = df.resample(Q-DEC).mean()
df_year = df.resample(A-DEC).mean()

按照天,月,季度,年来显示比特币的走势

fig = plt.figure(figsize=[15, 7])
plt.rcParams[font.sans-serif]=[SimHei] #用来正常显示中文标签
plt.suptitle(上证指数, fontsize=20)
plt.subplot(221)
plt.plot(df.Price, -, label=按天)
plt.legend()
plt.subplot(222)
plt.plot(df_month.Price, -, label=按月)
plt.legend()
plt.subplot(223)
plt.plot(df_Q.Price, -, label=按季度)
plt.legend()
plt.subplot(224)
plt.plot(df_year.Price, -, label=按年)
plt.legend()
plt.show()

ARMA模型训练

    设置参数范围
ps = range(0, 3)
qs = range(0, 3)
parameters = product(ps, qs)
parameters_list = list(parameters)
    寻找最优ARMA模型参数,即best_aic最小
results = []
best_aic = float("inf") # 正无穷
for param in parameters_list:
    try:
        model = ARMA(df_month.Price,order=(param[0], param[1])).fit()
    except ValueError:
        print(参数错误:, param)
        continue
    aic = model.aic
    if aic < best_aic:
        best_model = model
        best_aic = aic
        best_param = param
    results.append([param, model.aic])
    输出最优模型
result_table = pd.DataFrame(results)
result_table.columns = [parameters, aic]
print(最优模型: , best_model.summary())

指数预测

我们预测今年一年的上证指数走势,使用pd.date_range生成月字段,freq=MS’代表每月开始的日期。

df_month2 = df_month[[Price]]
date_list=pd.date_range(2019-3-31,2019-12-31, freq=M).tolist()
future = pd.DataFrame(index=date_list, columns= df_month.columns)
df_month2 = pd.concat([df_month2, future])
df_month2[forecast] = best_model.predict(start=0, end=350)

预测结果显示

plt.figure(figsize=(20,7))
df_month2.Price.plot(label=实际指数)
df_month2.forecast.plot(color=r, ls=--, label=预测指数)
plt.legend()
plt.title(指数(月))
plt.xlabel(时间)
plt.ylabel(指数)
plt.show()

预测结果图片显示 预测结果数值显示

经验分享 程序员 微信小程序 职场和发展