ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

TensorFlow2子类模型多输入多输出

2021-09-20 16:33:22  阅读:255  来源: 互联网

标签:TensorFlow2 子类 self initializer x2 tf x1 输入 size


在最近的一次项目中,因为需要模型具有多输入多输出,而且我的一个输出是一个包含张量的列表,所以无法使用函数式API或者容器去造模型,因为列表的添加操作不是一个层,而这两类的输出必须是层的结果,虽然可以用tf.keras.layers.Lambda将此操作变成层,但总归是牵强的,所以使用子类模型。

class Test(keras.Model):
    def __init__(self):
        super(Test, self).__init__()
        filters = 64
        initializer = tf.random_normal_initializer(0., 0.02)
        self.conv1 = Conv2D(filters, 4, 2, 'same', use_bias=False, 
                                              kernel_initializer=initializer)
        self.bn1 = BatchNormalization()
        self.conv2 = Conv2D(filters*2, 4, 2, 'same', use_bias=False, 
                                              kernel_initializer=initializer)
        self.bn2 = BatchNormalization()
        self.conv3 = Conv2D(filters*4, 4, 2, 'same', use_bias=False, 
                                              kernel_initializer=initializer)
        self.bn3 = BatchNormalization()

    def call(self, inputs):
        x1 = inputs[0]
        x2 = inputs[1]
        skips = []      # 存结果的列表
        x1_1 = tf.nn.relu(self.bn1(self.conv1(x1)))
        x2_1 = tf.nn.relu(self.bn1(self.conv1(x2)))
        skips.append(x1_1)

        x1_2 = tf.nn.relu(self.bn1(self.conv1(x1_1)))
        x2_2 = tf.nn.relu(self.bn1(self.conv1(x2_1)))
        skips.append(x1_2)

        x1_3 = tf.nn.relu(self.bn1(self.conv1(x1_2)))
        x2_3 = tf.nn.relu(self.bn1(self.conv1(x2_2)))
        skips.append(x1_3)

        return [skips, x2_3]


model = Test()
model.build(input_shape=[(batch_size, data_size), (batch_size, data_size)])
input1 = tf.random.normal([batch_size, data_size])
input2 = tf.random.normal([batch_size, data_size])
out_put1, out_put2 = model([input1, input2])

TF2用着真的是太难受了,网上的教程都比较泛,对一些细节的处理实例太难找了,找着了还大概率是tf.compat.v1。。。  做完这次我真的好好去看看Torch了。。。

另外,在TF2的图执行模式里,是无法使用for等循环的,但有专门的库函数tf.while_loop,反正我还不怎么会用,或者直接可以转eager模式就可以解决。

我的问题可能在一些大佬看来很低级,但确实给我造成了麻烦,我本以为教程上的东西就能解决一切问题的了,还是太弱。若要朋友想指正我的说法或者想要交流TF2里的坑,请私信我。

标签:TensorFlow2,子类,self,initializer,x2,tf,x1,输入,size
来源: https://blog.csdn.net/ReichQin/article/details/120392116

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有