ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

神经网络dnn 多分类模型

2022-05-28 02:31:41  阅读:282  来源: 互联网

标签:模型 dnn list shape 神经网络 train tf col unit


import tensorflow.compat.v1 as tf
# from tensorflow.examples.tutorials.mnist import input_data
import os
import pandas as pd
import numpy as np
from tensorflow.python.keras.utils import to_categorical

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'


# 数据准备
df = pd.read_csv('./data/train_date_new.csv',sep=',',index_col=None,header=0)

# print(df)

# print(df.groupby(by="diabete").count())

X = df.iloc[:,1:].values.astype(np.float32)
Y = to_categorical(df.iloc[:,0].values).astype(np.float32)


print(X.shape,Y.shape)

train_split = int(df.shape[0]*0.8)
x_train,y_train,x_test,y_test = X[:train_split,:],Y[:train_split,:],X[train_split:,:],Y[train_split:,:]

ind,col = x_train.shape
y_ind,y_col = y_train.shape

# 全连接神经网络
def dense(x, w, b, keeppord):
    linear = tf.matmul(x, w) + b
    # activation = tf.nn.relu(linear)
    activation = tf.nn.sigmoid(linear)
    # activation = tf.nn.tanh(linear)
    # activation = tf.nn.softmax(linear)
    y = tf.nn.dropout(activation,keeppord)
    return y


def DNNModel(image, w, b, keeppord):
    global dense1
    for i in range(len(w)-1):
        if i==0:
            dense1 = dense(image, w[i], b[i],keeppord)
        else:
            dense1 = dense(dense1, w[i], b[i],keeppord)

    output = tf.matmul(dense1, w[-1]) + b[-1]
    return output

# 生成网络的权重
def gen_weights(unit_list):
    w = []
    b = []
    # 遍历层数
    for i in range(len(unit_list)-1):
        sub_w = tf.Variable(tf.random_normal(shape=[unit_list[i], unit_list[i+1]]))
        sub_b = tf.Variable(tf.random_normal(shape=[1,unit_list[i+1]]))
        w.append(sub_w)
        b.append(sub_b)
    return w, b


x = tf.placeholder(tf.float32, [None, col])
y = tf.placeholder(tf.float32, [None, y_col])
keepprob = tf.placeholder(tf.float32)

global_step = tf.Variable(0)

# unit_list = [784, 512, 256, 10]
unit_list = [col, 512,256, y_col] #  0.7543333

# unit_list = [col,1024,512,y_col]
duropt = 0.75

w, b = gen_weights(unit_list)
y_pre = DNNModel(x, w, b, keepprob)
loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=y_pre, labels=y))
tf.summary.scalar("loss", loss)                 # 收集标量

opt = tf.train.AdamOptimizer(0.01).minimize(loss, global_step=global_step)

predict = tf.equal(tf.argmax(y_pre, axis=1), tf.argmax(y, axis=1))       # 返回每行或者每列最大值的索引,判断是否相等
acc = tf.reduce_mean(tf.cast(predict, tf.float32))

tf.summary.scalar("acc", acc)                   # 收集标量
merged = tf.summary.merge_all()                 # 和并变量
saver = tf.train.Saver()                        # 保存和加载模型
init = tf.global_variables_initializer()        # 初始化全局变量



bach = 4
bach_0=bach-1
min_bach = int(ind/4)
print(bach_0,min_bach)


with tf.Session() as sess:
    sess.run(init)
    writer = tf.summary.FileWriter("./log/tensorboard", tf.get_default_graph())      # tensorboard 事件文件
    for i in range(10000):
        for j in range(bach):
            if j <= bach_0:
                x_train_bach, y_train_bach = x_train[(j * min_bach):(j + 1) * min_bach, :],\
                                             y_train[(j * min_bach):(j + 1) * min_bach,:]
            else:
                x_train_bach, y_train_bach = x_train[(j + 1) * min_bach:, :], y_train[(j + 1) * min_bach:, :]

            summary, _ = sess.run([merged, opt], feed_dict={x:x_train_bach, y:y_train_bach, keepprob: duropt})
            writer.add_summary(summary, i)              # 将每次迭代后的变量写入事件文件

        # 评估模型在验证集上的识别率
        if (i+1) % 1000 == 0:
            feeddict = {x: x_test, y: y_test, keepprob: 1.}      # 验证集
            valloss, accuracy = sess.run([loss, acc], feed_dict=feeddict)
            print(i, 'th batch val loss:', valloss, ', accuracy:', accuracy)

    saver.save(sess, './model/tfdnn.ckpt')        # 保存模型
    print('测试集准确度:', sess.run(acc, feed_dict={x:x_test, y:y_test, keepprob:1.}))

writer.close()

  

标签:模型,dnn,list,shape,神经网络,train,tf,col,unit
来源: https://www.cnblogs.com/wuzaipei/p/16319720.html

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有