
conv_general_dilated implementation




 

Reference:

https://github.com/DarrenZhang01/TensorFlow_GSoC/blob/2fd360c8b1b8c9106044034f6a8b5c2734db9c3d/tf_jax_stax/tf_conv_general.py
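
The code below re-implements the semantics of the JAX primitive jax.lax.conv_general_dilated on top of TensorFlow (same signature). For orientation, a minimal JAX call looks roughly like this; the shapes are illustrative and not part of the original code:

import jax.numpy as jnp
from jax import lax

x = jnp.ones((1, 8, 8, 3))    # NHWC input
k = jnp.ones((3, 3, 3, 16))   # HWIO filter
y = lax.conv_general_dilated(x, k, window_strides=(1, 1), padding="SAME",
                             dimension_numbers=("NHWC", "HWIO", "NHWC"))
print(y.shape)  # (1, 8, 8, 16)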

The result is probably still not right:


"""
This file contains a general conv operation for TensorFlow.
Zhibo Zhang, 2020.06.06
"""
import numpy as np
import tensorflow as tf



# # Translate dimension numbers from number representations into string
# #   representations
# #
# # For example, in a 1D case, "NWC" means "# batches", "spatial dimensions" and
# #   "# Channels (features)", so the according number representations based on order
# #   "batch dimension, feature dimension, spatial dimension" is tuple (0, 2, 1);
# #
# # The reverse translation is simple. Set the default mapping as follows:
# #   1D: "NCW" <=> (lhs_spec[0], lhs_spec[1], lhs_spec[2])
# #   2D: "NCHW" <=> (lhs_spec[0], lhs_spec[1], lhs_spec[2], lhs_spec[3])
# #   3D: "NCDHW" <=> (lhs_spec[0], lhs_spec[1], lhs_spec[2], lhs_spec[3], lhs_spec[4])
# # In order to get the translated string version from the number version, sort
# # strings and numbers at the same time based on numbers (0 to len(lhs_spec) - 1)
# def conv_dim_translator(lhs, lhs_spec, dim):
#   """ Translate ConvDimensionNumbers into string representations. """
#   str_maps = {1: ['N', 'C', 'W'], 2: ['N', 'C', 'H', 'W'],
#               3: ['N', 'C', 'D', 'H', 'W']}
#   str_list = str_maps[dim]
#   sorted_str = sort_together([list(lhs_spec), str_list])[1]
#   spatial_dim_maps = {1: 'W', 2: "HW", 3: "DHW"}
#   if not ((sorted_str[0] == 'N' and sorted_str[1] == 'C') or\
#           (sorted_str[0] == 'N' and sorted_str[-1] == 'C')):
#     # raise TypeError("The batch number should always be at the first dimension"
#     #                 "and the channel/feature number should always be at the"
#     #                 "second dimension or the very last dimension, but got ",
#     #                 sorted_str)
#     # The current data format is not valid, need to swap certain axis of the
#     # inputs to make it valid.
#     lhs = np.moveaxis(lhs, (sorted_str.index('N'), sorted_str.index('C')), (0, dim + 1))
#     return 'N' + spatial_dim_maps[dim] + 'C', lhs
#   output_str = "NC" + spatial_dim_maps[dim] \
#                 if sorted_str[0] == 'N' and sorted_str[1] == 'C' else \
#                 'N' + spatial_dim_maps[dim] + 'C'
#   return output_str, lhs
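# (Hypothetical illustration of the mapping described above: with dim = 2 and
#  lhs_spec = (0, 3, 1, 2), i.e. N on axis 0, C on axis 3, H on axis 1 and
#  W on axis 2, sorting ['N', 'C', 'H', 'W'] by those axis numbers gives "NHWC".)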


# For example, in the 3D case:
#   if lhs_dilation = 2, convert it to [2, 2, 2]
#   if lhs_dilation = (2, 2, 2), it also becomes [2, 2, 2]
def _conv_general_param_type_converter(window_strides, lhs_dilation, rhs_dilation, dim):
  """ Convert the inputs window_strides, lhs_dilation and rhs_dilation into the
  standard list-of-ints form expected by the TF conv calls below."""
  strides = [window_strides] * dim if isinstance(window_strides, int) else \
            list(window_strides)
  if lhs_dilation:
    lhs_dilation = [lhs_dilation] * dim if isinstance(lhs_dilation, int) else \
                    list(lhs_dilation)
  if rhs_dilation:
    rhs_dilation = [rhs_dilation] * dim if isinstance(rhs_dilation, int) else \
                    list(rhs_dilation)
  return (strides, lhs_dilation, rhs_dilation)
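
# (Hypothetical sanity check of the converter above:
#    _conv_general_param_type_converter(1, 2, None, dim=3)
#  returns ([1, 1, 1], [2, 2, 2], None).)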


# TODO: Support feature_group_count, batch_group_count and precision, and
#       allow lhs_dilation and rhs_dilation to happen at the same time.
def conv_general_dilated(lhs, rhs, window_strides, padding, lhs_dilation=None,
                         rhs_dilation=None, dimension_numbers=None,
                         feature_group_count=1, batch_group_count=1, precision=None):
  """ A general conv function that integrates normal conv, deconvolution,
  dilated convolution, etc."""
  # raise TypeError("lhs shape: {}, rhs shape: {}".format(lhs.shape, rhs.shape))
  dim = None
  lhs_spec, rhs_spec, out_spec = dimension_numbers
  if lhs_spec != out_spec:
    raise TypeError("Current implementation requires the `data_format` of the "
                    "inputs and outputs to be the same.")
  if len(lhs_spec) >= 6:
    raise TypeError("Current implementation does not support 4 or higher "
                    "dimensional convolution, but got: ", len(lhs_spec) - 2)
  dim = len(lhs_spec) - 2
  if lhs_dilation and rhs_dilation:
    if lhs_dilation == (1,) * dim and rhs_dilation == (1,) * dim:
      lhs_dilation, rhs_dilation = None, None
    else:
      raise TypeError("Current implementation does not support deconvolution "
                      "(lhs_dilation) and dilation (rhs_dilation) at the same "
                      "time, but got lhs_dilation: {}, rhs_dilation: {}".format(
                          lhs_dilation, rhs_dilation))
  print("the dim is: {}".format(dim))
  if padding not in ["SAME", "VALID"]:
    raise TypeError("Current implementation requires the padding parameter "
                    "to be either 'VALID' or 'SAME', but got: ", padding)
  # Convert params from int/Sequence[int] to list of ints.
  strides, lhs_dilation, rhs_dilation = _conv_general_param_type_converter(
      window_strides, lhs_dilation, rhs_dilation, dim)
  # Preprocess the shapes
  dim_maps = {}
  if isinstance(lhs_spec, str):
    dim_maps['I'] = list(rhs_spec).index('I')
    dim_maps['O'] = list(rhs_spec).index('O')
    dim_maps['N'] = list(lhs_spec).index('N')
    dim_maps['C'] = list(lhs_spec).index('C')
  else:
    dim_maps['I'] = rhs_spec[1]
    dim_maps['O'] = rhs_spec[0]
    dim_maps['N'] = lhs_spec[0]
    dim_maps['C'] = lhs_spec[1]
  # data_format, lhs = conv_dim_translator(lhs, lhs_spec, dim)
  lhs = np.moveaxis(lhs, (dim_maps['N'], dim_maps['C']), (0, dim + 1))
  # Adjust the filters, put the dimension 'I' and 'O' at last.
  rhs = np.moveaxis(rhs, (dim_maps['O'], dim_maps['I']), (dim + 1, dim))
  spatial_dim_maps = {1: 'W', 2: "HW", 3: "DHW"}
  data_format = 'N' + spatial_dim_maps[dim] + 'C'
  tf_nn_APIs = {1: [tf.nn.conv1d, tf.nn.conv1d_transpose],
                2: [tf.nn.conv2d, tf.nn.conv2d_transpose],
                3: [tf.nn.conv3d, tf.nn.conv3d_transpose]}

  output = None
  if rhs_dilation or (lhs_dilation is None and rhs_dilation is None):
    # Normal or dilated convolution.
    output = tf_nn_APIs[dim][0](lhs, rhs, strides, padding, data_format, rhs_dilation)
  else:
    # Deconvolution branch. Note that tf.nn.convNd_transpose expects
    # `output_shape` as its third positional argument, so this call as written
    # probably does not line up with the TF API, which may explain the wrong results.
    output = tf_nn_APIs[dim][1](lhs, rhs, strides, padding, data_format, lhs_dilation)
  output = np.moveaxis(output, (0, dim + 1), (dim_maps['N'], dim_maps['C']))
  return output
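
A minimal smoke test of the non-dilated path (illustrative shapes, not from the original post; assumes TF2 eager execution so np.moveaxis can convert the result back to a NumPy array):

lhs = np.random.randn(1, 8, 8, 3).astype(np.float32)   # NHWC input
rhs = np.random.randn(3, 3, 3, 16).astype(np.float32)  # HWIO filter
out = conv_general_dilated(lhs, rhs, window_strides=(1, 1), padding="SAME",
                           dimension_numbers=("NHWC", "HWIO", "NHWC"))
print(out.shape)  # expected: (1, 8, 8, 16)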

 

Source: https://blog.csdn.net/jacke121/article/details/113806678
