ICode9

精准搜索请尝试: 精确搜索
首页 > 其他分享> 文章详细

pytorch 和 tensorflow的 upsampling 互通代码

2022-09-16 22:30:55  阅读:297  来源: 互联网

标签:11 10 import torch pytorch shape input tensorflow upsampling


pytorch 实现上采样

点击查看代码
import numpy as np
import torch.nn.functional as F
import torch
from torch import nn


input = torch.arange(0, 12, dtype=torch.float32).view(2, 2, 3).transpose(1, 2)
# size 和 scale_factor只能二选一
sample_layer = nn.Upsample(scale_factor=2, mode='nearest')
print(input)
print(sample_layer(input).transpose(1, 2), sample_layer(input).transpose(1, 2).shape)


输出

点击查看代码

tensor([[[ 0.,  3.],
         [ 1.,  4.],
         [ 2.,  5.]],

        [[ 6.,  9.],
         [ 7., 10.],
         [ 8., 11.]]])
tensor([[[ 0.,  1.,  2.],
         [ 0.,  1.,  2.],
         [ 3.,  4.,  5.],
         [ 3.,  4.,  5.]],

        [[ 6.,  7.,  8.],
         [ 6.,  7.,  8.],
         [ 9., 10., 11.],
         [ 9., 10., 11.]]]) torch.Size([2, 4, 3])

Process finished with exit code 0


tensorflow 的实现

点击查看代码

import tensorflow as tf
import numpy as np
from tensorflow.keras.layers import UpSampling1D

#


#   Arguments:
#     size: Integer. Upsampling factor.
#
#   Input shape:
#     3D tensor with shape: `(batch_size, steps, features)`.
#
#   Output shape:
#     3D tensor with shape: `(batch_size, upsampled_steps, features)`.
input_shape = (2, 2, 3)
x = np.arange(np.prod(input_shape)).reshape(input_shape)
print(x)
# [[[ 0  1  2]
#   [ 3  4  5]]
#  [[ 6  7  8]
#   [ 9 10 11]]]
y = tf.keras.layers.UpSampling1D(size=2)(x)
print(y)
# tf.Tensor(
#   [[[ 0  1  2]
#     [ 0  1  2]
#     [ 3  4  5]
#     [ 3  4  5]]
#    [[ 6  7  8]
#     [ 6  7  8]
#     [ 9 10 11]
#     [ 9 10 11]]], shape=(2, 4, 3), dtype=int64)


两者是完全等价的

标签:11,10,import,torch,pytorch,shape,input,tensorflow,upsampling
来源: https://www.cnblogs.com/boyknight/p/16701393.html

本站声明: 1. iCode9 技术分享网(下文简称本站)提供的所有内容,仅供技术学习、探讨和分享;
2. 关于本站的所有留言、评论、转载及引用,纯属内容发起人的个人观点,与本站观点和立场无关;
3. 关于本站的所有言论和文字,纯属内容发起人的个人观点,与本站观点和立场无关;
4. 本站文章均是网友提供,不完全保证技术分享内容的完整性、准确性、时效性、风险性和版权归属;如您发现该文章侵犯了您的权益,可联系我们第一时间进行删除;
5. 本站为非盈利性的个人网站,所有内容不会用来进行牟利,也不会利用任何形式的广告来间接获益,纯粹是为了广大技术爱好者提供技术内容和技术思想的分享性交流网站。

专注分享技术,共同学习,共同进步。侵权联系[81616952@qq.com]

Copyright (C)ICode9.com, All Rights Reserved.

ICode9版权所有