transpose means transposing, i.e. reordering the axes of an array. A simple example:

a2 = transpose(a1, perm=None, name='transpose')

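a1 is an ndarray, and the important argument is perm. Suppose our data is in [n, h, w, c] format and we want [c, n, h, w]. Then we set perm = [3, 0, 1, 2]: perm[i] says which axis of the original array becomes axis i of the new array (here, a2). For example, perm[0] = 3 means that axis 0 of the new array is axis 3 of the original (i.e. c). (The signature above matches TensorFlow's tf.transpose; NumPy's np.transpose takes the same permutation through its axes argument.)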
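A minimal sketch of the perm semantics in NumPy, assuming a made-up 4-D array in [n, h, w, c] layout. np.transpose is used here with the permutation passed as axes; tf.transpose with perm=[3, 0, 1, 2] should produce the same shape. The inverse permutation (to get back to [n, h, w, c]) is np.argsort(perm):

```python
import numpy as np

# Made-up batch in [n, h, w, c] layout: 2 images, 4x5 pixels, 3 channels
a1 = np.arange(2 * 4 * 5 * 3).reshape(2, 4, 5, 3)

# New axis i takes original axis perm[i]: [n, h, w, c] -> [c, n, h, w]
perm = (3, 0, 1, 2)
a2 = np.transpose(a1, perm)

print(a1.shape)  # (2, 4, 5, 3)
print(a2.shape)  # (3, 2, 4, 5)

# Element-wise, a2[c, n, h, w] is a1[n, h, w, c]
print(a2[2, 1, 3, 4] == a1[1, 3, 4, 2])  # True

# Undo the transpose with the inverse permutation
inverse = np.argsort(perm).tolist()  # [1, 2, 3, 0]
back = np.transpose(a2, inverse)
```
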

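When visualizing a network, whether feature maps or filters, the crudest approach is to use a few nested for loops to copy each tile into one large matrix and display that. A simpler way is to combine transpose with reshape: transpose the data to [c, n, h, w] so that the axis to be tiled comes first, add padding, and then reshape it into a grid. For reference: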
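To make the idea concrete, here is a sketch that builds the same grid twice, once with nested loops and once with reshape + transpose + reshape. The function names are hypothetical, and it assumes grayscale tiles whose count is already a perfect square n*n with no padding (the function below handles the general case by padding):

```python
import numpy as np

def tile_with_loops(tiles, n):
    """Copy n*n tiles of shape (h, w) into an (n*h, n*w) grid with nested loops."""
    h, w = tiles.shape[1:3]
    grid = np.zeros((n * h, n * w), dtype=tiles.dtype)
    for row in range(n):
        for col in range(n):
            grid[row * h:(row + 1) * h, col * w:(col + 1) * w] = tiles[row * n + col]
    return grid

def tile_with_transpose(tiles, n):
    """Build the same grid with reshape + transpose + reshape (no Python loops)."""
    h, w = tiles.shape[1:3]
    grid = tiles.reshape(n, n, h, w)    # (n*n, h, w) -> (row, col, y, x)
    grid = grid.transpose(0, 2, 1, 3)   # -> (row, y, col, x)
    return grid.reshape(n * h, n * w)   # merge (row, y) and (col, x)

tiles = np.arange(9 * 2 * 3).reshape(9, 2, 3)  # 9 made-up tiles of 2x3
print(np.array_equal(tile_with_loops(tiles, 3), tile_with_transpose(tiles, 3)))  # True
```

The transpose is the key step: after it, the pixel rows (y) of all tiles in one grid row sit next to each other in memory order, so the final reshape reads them out as one contiguous image row.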

import numpy as np
import matplotlib.pyplot as plt

def vis_square(data):
    """Take an array of shape (n, height, width) or (n, height, width, 3)
       and visualize each (height, width) thing in a grid of size approx. sqrt(n) by sqrt(n)"""
    
    # normalize data for display
    data = (data - data.min()) / (data.max() - data.min())
    
    # force the number of filters to be square
    n = int(np.ceil(np.sqrt(data.shape[0])))
    padding = (((0, n ** 2 - data.shape[0]),
               (0, 1), (0, 1))                 # add some space between filters
               + ((0, 0),) * (data.ndim - 3))  # don't pad the last dimension (if there is one)
    data = np.pad(data, padding, mode='constant', constant_values=1)  # pad with ones (white)
    
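    # tile the filters into an image:
    # (n*n, h, w[, 3]) -> (n, n, h, w[, 3]) -> (n, h, n, w[, 3]) -> (n*h, n*w[, 3])
    # note: data.ndim inside transpose() still refers to the padded array (3 or 4 dims)
    data = data.reshape((n, n) + data.shape[1:]).transpose((0, 2, 1, 3) + tuple(range(4, data.ndim + 1)))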
    data = data.reshape((n * data.shape[1], n * data.shape[3]) + data.shape[4:])
    
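    # a gray colormap makes the padding (value 1) show as white; cmap is ignored for RGB input
    plt.imshow(data, cmap='gray')
    plt.axis('off')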
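Putting the pieces together: a sketch that transposes [n, h, w, c] activations to [c, n, h, w], takes one sample, and tiles it. tile_square is a hypothetical non-plotting variant of vis_square (it returns the array instead of calling plt.imshow, and guards against a constant input). The activation and filter shapes are made up, and the [n, c, h, w] filter layout is an assumption:

```python
import numpy as np

def tile_square(data):
    """Hypothetical variant of vis_square that returns the tiled array instead of plotting it."""
    data = data.astype(float)
    span = data.max() - data.min()
    # scale into [0, 1]; a constant input would otherwise divide by zero
    data = (data - data.min()) / span if span > 0 else np.zeros_like(data)
    n = int(np.ceil(np.sqrt(data.shape[0])))
    # pad the tile count up to n*n, add a 1-pixel border, leave a trailing RGB axis unpadded
    padding = ((0, n ** 2 - data.shape[0]), (0, 1), (0, 1)) + ((0, 0),) * (data.ndim - 3)
    data = np.pad(data, padding, mode='constant', constant_values=1)
    # (n*n, h, w[, 3]) -> (n, n, h, w[, 3]) -> (n, h, n, w[, 3]) -> (n*h, n*w[, 3])
    data = data.reshape((n, n) + data.shape[1:])
    data = data.transpose((0, 2, 1, 3) + tuple(range(4, data.ndim)))
    return data.reshape((n * data.shape[1], n * data.shape[3]) + data.shape[4:])

rng = np.random.default_rng(0)

# Feature maps: made-up activations in [n, h, w, c] layout (2 images, 8x8, 10 channels)
acts = rng.random((2, 8, 8, 10))
maps = acts.transpose(3, 0, 1, 2)[:, 0]  # [c, n, h, w], then sample 0 -> (10, 8, 8)
grid = tile_square(maps)
print(grid.shape)  # (36, 36): a 4x4 grid of 9x9 tiles (8 px + 1 px padding)

# RGB filters: made-up weights assumed to be in [n, c, h, w] layout
filters = rng.random((96, 3, 11, 11))
rgb_grid = tile_square(filters.transpose(0, 2, 3, 1))  # -> (96, 11, 11, 3)
print(rgb_grid.shape)  # (120, 120, 3): a 10x10 grid of 12x12 color tiles
```

The [c, n, h, w] transpose is what turns each channel of one sample into its own tile. For the filters, transpose(0, 2, 3, 1) moves the color axis last, so the extra axis 4 in the inner transpose carries it through untouched.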