ユーザ用ツール

サイト用ツール


max_pooling

文書の過去の版を表示しています。


マックスプーリング

import numpy as np
 
def max_pooling(array, kernel_size, stride):
    row_size = array.shape[0] - kernel_size + 1
    col_size = array.shape[1] - kernel_size + 1
 
    output_array = np.zeros((int(row_size/stride), int(col_size/stride)))
 
    for row_index in range(0, row_size, stride):
        for col_index in range(0, col_size, stride):
            # 起点からカーネルフィルタに入るarrayの要素を順次比較し、最大値を得る
            max_value = 0
            for r_index in range(0, kernel_size):
                for c_index in range(0, kernel_size):
                    if array[row_index+r_index][col_index+c_index] > max_value:
                        max_value = array[row_index+r_index][col_index+c_index]
            output_array[int(row_index/stride)][int(col_index/stride)] = max_value
 
    return output_array
 
# 入力データの作成
array = np.zeros((6, 6))
array[0][0] = 10
array[0][1] = 9
array[1][1] = 8
array[2][0] = 6
array[2][2] = 12
array[4][4] = 17
 
kernel_size = 2
stride = 2
 
output = max_pooling(array, kernel_size, stride)
 
print(output)
max_pooling.1738076774.txt.gz · 最終更新: 2025/01/28 15:06 by 118.158.174.226

Donate Powered by PHP Valid HTML5 Valid CSS Driven by DokuWiki