# 动手实现深度学习（13）池化层的实现

2022-09-12 19:05:05  阅读：78  来源： 互联网

#### 10.1 池化层的运算

github: Leezhen2014: https://github.com/Leezhen2014/python_deep_learning

##### 池化层的forward

池化层的 forward 在输入的每个 pool_h×pool_w 窗口（窗口按 stride 滑动）内计算一个汇总值。常见的池化有 max-pool 和 mean-pool（average-pool），min-pool 较少使用；本章只讨论 max-pool，即取每个窗口内的最大值。

##### 池化层的backward的运算

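Max-pool 的反向传播：每个输出单元的梯度只传回它在前向时对应窗口（pool_h×pool_w）中最大值所在的位置，窗口内其余位置的梯度填充为 0。实现上就是把每个输出单元扩展为 pool_h×pool_w 个元素，只在 argmax 处写入梯度（当 stride 等于窗口大小时，这等价于把原来的单元扩大 stride_h×stride_w 倍）。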

#### 10.2 池化层的实现

class Pooling:
    """Max-pooling layer operating on NCHW arrays.

    ``forward`` takes the maximum of every ``pool_h x pool_w`` window, with
    windows advancing by ``stride`` and the input padded by ``pad`` on each
    side of H and W. ``backward`` sends each upstream gradient back to the
    position that produced the corresponding maximum; every other position
    receives 0.

    Relies on the project's ``im2col`` / ``col2im`` helpers (defined elsewhere).
    """

    def __init__(self, pool_h, pool_w, stride=1, pad=0):
        """Store the pooling hyper-parameters.

        Args:
            pool_h: Window height.
            pool_w: Window width.
            stride: Step between neighbouring windows (both directions).
            pad: Padding applied to each side of H and W.
        """
        self.pool_h = pool_h
        self.pool_w = pool_w
        self.stride = stride
        # forward/backward pass this to im2col/col2im, so it must be stored.
        self.pad = pad

        # Filled in by forward() and consumed by backward().
        self.x = None
        self.arg_max = None

    def forward(self, x):
        """Max-pool ``x``.

        Args:
            x: Array of shape (N, C, H, W).

        Returns:
            Array of shape (N, C, out_h, out_w), where
            out_h = (H + 2*pad - pool_h) // stride + 1 (likewise for out_w).
        """
        N, C, H, W = x.shape
        # Padding enlarges the spatial extent, so it must enter the output
        # size; otherwise the reshape below fails whenever pad > 0.
        out_h = (H + 2 * self.pad - self.pool_h) // self.stride + 1
        out_w = (W + 2 * self.pad - self.pool_w) // self.stride + 1

        col = im2col(x, self.pool_h, self.pool_w, self.stride, self.pad)
        # One row per pooling window of a single channel.
        # NOTE(review): assumes im2col orders rows as (n, oh, ow) and columns
        # as (c, kh, kw), which is what the reshape/transpose below expects.
        # If im2col zero-pads, padded cells take part in the max as 0 —
        # confirm that is intended for inputs that can be negative.
        col = col.reshape(-1, self.pool_h * self.pool_w)

        arg_max = np.argmax(col, axis=1)
        out = np.max(col, axis=1)
        # Rows are ordered (n, oh, ow, c); move channels back to axis 1.
        out = out.reshape(N, out_h, out_w, C).transpose(0, 3, 1, 2)

        # Cached for backward().
        self.x = x
        self.arg_max = arg_max

        return out

    def backward(self, dout):
        """Back-propagate ``dout`` through the max-pooling.

        Args:
            dout: Upstream gradient of shape (N, C, out_h, out_w), i.e. the
                shape returned by forward().

        Returns:
            Gradient with respect to the forward input, shaped like ``self.x``.
        """
        # Reorder to (n, oh, ow, c) so it lines up with the rows of forward().
        dout = dout.transpose(0, 2, 3, 1)
        N, out_h, out_w, C = dout.shape

        pool_size = self.pool_h * self.pool_w
        # One row per pooling window: the gradient goes only to the column
        # that held the maximum; every other column stays 0.
        dmax = np.zeros((dout.size, pool_size))
        dmax[np.arange(self.arg_max.size), self.arg_max.flatten()] = dout.flatten()

        # Back to im2col's layout: one row per (n, oh, ow), C * pool_size columns.
        dcol = dmax.reshape(N * out_h * out_w, -1)
        dx = col2im(dcol, self.x.shape, self.pool_h, self.pool_w,
                    self.stride, self.pad)

        return dx

#### 10.3 pool单元测试

im2col以后的数据：

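Maxpool以后的数据（对下面测试中 1×3×4×4 的输入做 2×2、stride=2 的 max-pool，输出形状为 (1, 3, 2, 2)，三个通道的结果分别为 [[6, 4], [3, 3]]、[[3, 4], [4, 3]]、[[4, 4], [6, 5]]）：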

"""Smoke test for the max-pooling layer (Pooling.forward).

File: test_im2col.py (despite the name, it exercises Pooling)
Author: lizhen
Date: 2020/2/14
"""
import numpy as np

from src.common.util import im2col, col2im
from src.common.layers import Convolution, Pooling


def main():
    """Max-pool a 1x3x4x4 NCHW input with a 2x2 window and stride 2."""
    raw_data = [
        3, 0, 4, 2,
        6, 5, 4, 3,
        3, 0, 2, 3,
        1, 0, 3, 1,

        1, 2, 0, 1,
        3, 0, 2, 4,
        1, 0, 3, 2,
        4, 3, 0, 1,

        4, 2, 0, 1,
        1, 2, 0, 4,
        3, 0, 4, 2,
        6, 2, 4, 5,
    ]
    x = np.array(raw_data).reshape(1, 3, 4, 4)  # NCHW: 3 channels of 4x4

    print("===================")
    pool = Pooling(pool_h=2, pool_w=2, stride=2, pad=0)
    out = pool.forward(x)
    print(out.shape)
    print(out)

    # Hand-computed maximum of each non-overlapping 2x2 window, per channel.
    expected = np.array([[[[6, 4], [3, 3]],
                          [[3, 4], [4, 3]],
                          [[4, 4], [6, 5]]]])
    np.testing.assert_array_equal(out, expected)


if __name__ == '__main__':
    main()

