# Implementing Deep Learning by Hand (13): Implementing the Pooling Layer

2022-09-12 19:05:05  Views: 78  Source: Internet

#### 10.1 Pooling layer operations

github: Leezhen2014: https://github.com/Leezhen2014/python_deep_learning

##### Pooling layer forward pass

Pooling comes in three kinds: mean-pool, max-pool, and min-pool. This chapter discusses only max-pool.
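To make the max-pool forward pass concrete, here is a minimal sketch that uses explicit loops over a single 2-D feature map. The function name `max_pool_2d_naive` and its parameters are illustrative assumptions and do not come from the author's repository; it ignores padding and is meant only to show what each output cell contains.

```python
import numpy as np


def max_pool_2d_naive(x, pool_h, pool_w, stride):
    """Max-pool one (H, W) feature map with explicit loops (no padding)."""
    H, W = x.shape
    out_h = (H - pool_h) // stride + 1
    out_w = (W - pool_w) // stride + 1
    out = np.zeros((out_h, out_w), dtype=x.dtype)
    for i in range(out_h):
        for j in range(out_w):
            # Slice out the window that maps to output cell (i, j).
            window = x[i * stride:i * stride + pool_h,
                       j * stride:j * stride + pool_w]
            out[i, j] = window.max()
    return out


fmap = np.array([[3, 0, 4, 2],
                 [6, 5, 4, 3],
                 [3, 0, 2, 3],
                 [1, 0, 3, 1]])
pooled = max_pool_2d_naive(fmap, pool_h=2, pool_w=2, stride=2)
print(pooled)  # [[6 4]
               #  [3 3]]
```

Each output value is simply the largest entry of its 2x2 window; a vectorized implementation has to produce the same numbers.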

##### Pooling layer backward pass

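The backward pass of max-pool expands each upstream gradient element back over its pooling window (stride_h*stride_w cells when the stride equals the window size): the gradient goes to the position that held the maximum in the forward pass, and every other position is filled with 0.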
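The gradient routing described above can be sketched with explicit loops on one feature map. The name `max_pool_2d_backward_naive` is a hypothetical helper for illustration, not the author's code, and ties inside a window are resolved by `np.argmax`, which picks the first maximum.

```python
import numpy as np


def max_pool_2d_backward_naive(x, dout, pool_h, pool_w, stride):
    """Route each upstream gradient to the argmax cell of its window."""
    dx = np.zeros_like(x, dtype=float)
    out_h, out_w = dout.shape
    for i in range(out_h):
        for j in range(out_w):
            r0, c0 = i * stride, j * stride
            window = x[r0:r0 + pool_h, c0:c0 + pool_w]
            # Position of the maximum inside the window (first one on ties).
            r, c = np.unravel_index(np.argmax(window), window.shape)
            dx[r0 + r, c0 + c] += dout[i, j]
    return dx


fmap = np.array([[3, 0, 4, 2],
                 [6, 5, 4, 3],
                 [3, 0, 2, 3],
                 [1, 0, 3, 1]])
dout = np.array([[1.0, 2.0],
                 [3.0, 4.0]])
dx = max_pool_2d_backward_naive(fmap, dout, pool_h=2, pool_w=2, stride=2)
print(dx)
```

For this input the four gradients land at positions (1, 0), (0, 2), (2, 0), and (2, 3), which are the cells that supplied the maxima 6, 4, 3, and 3; all other cells stay 0.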

#### 10.2 Pooling layer implementation

```
import numpy as np

from src.common.util import im2col, col2im


class Pooling:
    """Max-pooling layer built on im2col / col2im."""

    def __init__(self, pool_h, pool_w, stride=1, pad=0):
        self.pool_h = pool_h
        self.pool_w = pool_w
        self.stride = stride
        self.pad = pad

        # Cached in forward() for use in backward().
        self.x = None
        self.arg_max = None

    def forward(self, x):
        N, C, H, W = x.shape
        # Output size; identical to 1 + (H - pool_h) / stride when pad == 0.
        out_h = (H + 2 * self.pad - self.pool_h) // self.stride + 1
        out_w = (W + 2 * self.pad - self.pool_w) // self.stride + 1

        # Unfold the input, then put one pooling window on each row.
        col = im2col(x, self.pool_h, self.pool_w, self.stride, self.pad)
        col = col.reshape(-1, self.pool_h * self.pool_w)

        # Remember which element of each window was the maximum.
        arg_max = np.argmax(col, axis=1)
        out = np.max(col, axis=1)
        # Rows are ordered (N, out_h, out_w, C); convert back to NCHW.
        out = out.reshape(N, out_h, out_w, C).transpose(0, 3, 1, 2)

        self.x = x
        self.arg_max = arg_max

        return out

    def backward(self, dout):
        # NCHW -> (N, out_h, out_w, C) to match the row order used in forward().
        dout = dout.transpose(0, 2, 3, 1)

        pool_size = self.pool_h * self.pool_w
        # One row per window: the gradient goes to the argmax slot, 0 elsewhere.
        dmax = np.zeros((dout.size, pool_size))
        dmax[np.arange(self.arg_max.size), self.arg_max.flatten()] = dout.flatten()
        dmax = dmax.reshape(dout.shape + (pool_size,))

        # Fold the window gradients back into the shape of the input.
        dcol = dmax.reshape(dmax.shape[0] * dmax.shape[1] * dmax.shape[2], -1)
        dx = col2im(dcol, self.x.shape, self.pool_h, self.pool_w, self.stride, self.pad)

        return dx
```
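The class depends on `im2col` and `col2im` from `src.common.util`, which are not shown in this section. The sketch below is a stand-in written under the assumption that they follow the common layout in which `im2col` returns a matrix of shape `(N*out_h*out_w, C*filter_h*filter_w)`; the repository's actual helpers may differ. It then replays the reshape, max, and transpose steps from `forward` on a small input.

```python
import numpy as np


def im2col(input_data, filter_h, filter_w, stride=1, pad=0):
    """Unfold NCHW data into rows of shape (C * filter_h * filter_w,)."""
    N, C, H, W = input_data.shape
    out_h = (H + 2 * pad - filter_h) // stride + 1
    out_w = (W + 2 * pad - filter_w) // stride + 1
    img = np.pad(input_data, [(0, 0), (0, 0), (pad, pad), (pad, pad)], 'constant')
    col = np.zeros((N, C, filter_h, filter_w, out_h, out_w))
    for y in range(filter_h):
        y_max = y + stride * out_h
        for x in range(filter_w):
            x_max = x + stride * out_w
            col[:, :, y, x, :, :] = img[:, :, y:y_max:stride, x:x_max:stride]
    return col.transpose(0, 4, 5, 1, 2, 3).reshape(N * out_h * out_w, -1)


def col2im(col, input_shape, filter_h, filter_w, stride=1, pad=0):
    """Fold rows back into NCHW, summing where windows overlap."""
    N, C, H, W = input_shape
    out_h = (H + 2 * pad - filter_h) // stride + 1
    out_w = (W + 2 * pad - filter_w) // stride + 1
    col = col.reshape(N, out_h, out_w, C, filter_h, filter_w).transpose(0, 3, 4, 5, 1, 2)
    img = np.zeros((N, C, H + 2 * pad + stride - 1, W + 2 * pad + stride - 1))
    for y in range(filter_h):
        y_max = y + stride * out_h
        for x in range(filter_w):
            x_max = x + stride * out_w
            img[:, :, y:y_max:stride, x:x_max:stride] += col[:, :, y, x, :, :]
    return img[:, :, pad:H + pad, pad:W + pad]


# Replay the forward steps on a (1, 3, 4, 4) input with 2x2 windows, stride 2.
x = np.arange(48.0).reshape(1, 3, 4, 4)
N, C, H, W = x.shape
col = im2col(x, 2, 2, stride=2, pad=0)          # shape (4, 12)
windows = col.reshape(-1, 2 * 2)                # one window per row: (12, 4)
out = windows.max(axis=1).reshape(N, 2, 2, C).transpose(0, 3, 1, 2)
roundtrip = col2im(col, x.shape, 2, 2, stride=2, pad=0)
print(col.shape, windows.shape, out.shape)
```

Because the values of `x` increase along rows and columns, the maximum of every 2x2 window is its bottom-right element, so `out` should equal `x[:, :, 1::2, 1::2]`. With non-overlapping windows, `col2im` applied to the output of `im2col` reproduces `x` exactly.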

#### 10.3 Pooling unit test

[Figure: data after im2col]

[Figure: data after max-pooling]

```
# -*- coding: utf-8 -*-
# @File  : test_im2col.py
# @Author: lizhen
# @Date  : 2020/2/14
# @Desc  : Test im2col
import numpy as np

from src.common.util import im2col, col2im
from src.common.layers import Convolution, Pooling


if __name__ == '__main__':
    # Three 4x4 channels, listed channel by channel.
    raw_data = [3, 0, 4, 2,
                6, 5, 4, 3,
                3, 0, 2, 3,
                1, 0, 3, 1,

                1, 2, 0, 1,
                3, 0, 2, 4,
                1, 0, 3, 2,
                4, 3, 0, 1,

                4, 2, 0, 1,
                1, 2, 0, 4,
                3, 0, 4, 2,
                6, 2, 4, 5
    ]

    # Two filters; only used by the commented-out convolution check below.
    raw_filter = [
        1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1,
        2, 2, 2, 2, 2, 2,
        2, 2, 2, 2, 2, 2,
    ]

    input_data = np.array(raw_data)
    filter_data = np.array(raw_filter)

    x = input_data.reshape(1, 3, 4, 4)   # NCHW
    W = filter_data.reshape(2, 3, 2, 2)  # (FN, C, FH, FW)
    b = np.zeros(2)
    # b = b.reshape((2,1))
    # print(col1)

    # print("input_data.shape=%s"%str(input_data.shape))
    # print("W.shape=%s"%str(W.shape))
    # print("b.shape=%s"%str(b.shape))
    # conv = Convolution(W,b) # def __init__(self, W, b, stride=1, pad=0)
    # out = conv.forward(x)
    # print("bout.shape=%s"%str(out.shape))
    # print(out)

    print("===================")
    pool = Pooling(pool_h=2, pool_w=2, stride=2, pad=0)
    out = pool.forward(x)
    print(out.shape)
    print(out)
```
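As a cross-check that does not depend on `im2col`, non-overlapping 2x2 max pooling can also be computed with a plain NumPy reshape. This is a sketch that assumes the stride equals the window size and that H and W are divisible by 2; it uses the same input values as the unit test, and `expected` is a name introduced here for illustration.

```python
import numpy as np

raw_data = [3, 0, 4, 2,
            6, 5, 4, 3,
            3, 0, 2, 3,
            1, 0, 3, 1,

            1, 2, 0, 1,
            3, 0, 2, 4,
            1, 0, 3, 2,
            4, 3, 0, 1,

            4, 2, 0, 1,
            1, 2, 0, 4,
            3, 0, 4, 2,
            6, 2, 4, 5]

x = np.array(raw_data).reshape(1, 3, 4, 4)  # NCHW
N, C, H, W = x.shape

# Split H into (H/2, 2) and W into (W/2, 2), then take the max over the
# two within-window axes.
expected = x.reshape(N, C, H // 2, 2, W // 2, 2).max(axis=(3, 5))
print(expected.shape)  # (1, 3, 2, 2)
print(expected)
```

For this input the three pooled channels are `[[6, 4], [3, 3]]`, `[[3, 4], [4, 3]]`, and `[[4, 4], [6, 5]]`, which is what `pool.forward(x)` in the test should print if `Pooling` behaves as a standard max-pool.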
