up 0 down

Я строю U-Net для бинарной сегментации изображений, используя API tf.nn в TensorFlow. Входное изображение имеет размер (256, 256, 3), а выходное бинарное изображение — (256, 256, 1). Выход U-Net-модели должен иметь форму (1, 256, 256, 1), но при выводе (inference) получается форма (7, 256, 256, 3). Для свёрточных ядер я использую инициализатор truncated_normal из TensorFlow, везде с типом данных float32. Не создаю ли я случайно несколько выходных слоёв где-то в коде?

def get_filter(shape, na):
    """Create (or fetch from the current variable scope) a conv-kernel variable.

    shape -- kernel shape [height, width, in_channels, out_channels]
    na    -- variable name used by tf.get_variable

    The kernel is float32 and initialized from a truncated normal distribution.
    """
    initializer = tf.truncated_normal_initializer(dtype='float32')
    kernel = tf.get_variable(name=na,
                             shape=shape,
                             dtype='float32',
                             initializer=initializer)
    return kernel
def _encoder_stage(x, in_ch, out_ch, w_idx, pool_idx):
    """One contracting-path stage: conv -> ReLU -> BN -> conv -> ReLU -> 2x2 max-pool.

    Returns (skip, pooled): `skip` is the pre-pool activation kept for the
    matching decoder stage, `pooled` feeds the next encoder stage.
    Variable/op names reproduce the original hand-unrolled graph exactly
    (w_%d, conv_%d_1, re_%d_1, bn_%d, ..., pool_%d).
    """
    x = tf.nn.conv2d(x, filter=get_filter([3, 3, in_ch, out_ch], na='w_%d' % w_idx),
                     strides=[1, 1, 1, 1], padding='SAME', name='conv_%d_1' % out_ch)
    x = tf.nn.relu(x, name='re_%d_1' % out_ch)
    x = tf.layers.batch_normalization(x, axis=-1, name='bn_%d' % out_ch)
    x = tf.nn.conv2d(x, filter=get_filter([3, 3, out_ch, out_ch], na='w_%d' % (w_idx + 1)),
                     strides=[1, 1, 1, 1], padding='SAME', name='conv_%d_2' % out_ch)
    skip = tf.nn.relu(x, name='re_%d_2' % out_ch)
    pooled = tf.nn.max_pool(skip, ksize=[1, 2, 2, 1], strides=[1, 2, 2, 1],
                            padding='SAME', name='pool_%d' % pool_idx)
    return skip, pooled


def _decoder_stage(x, skip, out_ch, size, w_idx):
    """One expanding-path stage: upsample -> conv -> ReLU -> concat(skip) ->
    conv -> ReLU -> BN -> conv -> ReLU.

    `size` is the target (height, width) for nearest-neighbour upsampling;
    it assumes the fixed 256x256 network input.
    """
    x = tf.image.resize_images(images=x, size=size,
                               method=tf.image.ResizeMethod.NEAREST_NEIGHBOR)
    # Incoming tensor has out_ch*2 channels (output of the previous, deeper stage).
    x = tf.nn.conv2d(x, filter=get_filter([3, 3, out_ch * 2, out_ch], na='w_%d' % w_idx),
                     strides=[1, 1, 1, 1], padding='SAME', name='mer_%d_1' % out_ch)
    x = tf.nn.relu(x, name='rel_%d_1' % out_ch)
    # BUG FIX: concatenate on axis=3 (channels), NOT axis=0 (batch).  The
    # original batch-axis concats grew the batch from 1 to 7 across the six
    # merges, producing the reported (7, 256, 256, ...) output.
    x = tf.concat([skip, x], axis=3, name='mer_%d_2' % out_ch)
    # BUG FIX: after a channel-axis concat the tensor has out_ch*2 channels,
    # so this kernel's input-channel dimension must be out_ch*2 (the original
    # used out_ch, which only "worked" because of the batch-axis concat).
    x = tf.nn.conv2d(x, filter=get_filter([3, 3, out_ch * 2, out_ch], na='w_%d' % (w_idx + 1)),
                     strides=[1, 1, 1, 1], padding='SAME', name='mer_%d_3' % out_ch)
    x = tf.nn.relu(x, name='rel_%d_2' % out_ch)
    x = tf.layers.batch_normalization(x, axis=-1, name='mer_bn_%d' % out_ch)
    x = tf.nn.conv2d(x, filter=get_filter([3, 3, out_ch, out_ch], na='w_%d' % (w_idx + 2)),
                     strides=[1, 1, 1, 1], padding='SAME', name='mer_%d_4' % out_ch)
    x = tf.nn.relu(x, name='rel_%d_3' % out_ch)
    return x


def unet(inp):
    """U-Net for binary image segmentation.

    inp: float32 tensor of shape (batch, 256, 256, 3).
    Returns per-pixel logits of shape (batch, 256, 256, 1); apply
    tf.nn.sigmoid (and a sigmoid cross-entropy loss) on top.

    Fixes vs. the original unrolled version: skip connections are merged on
    the channel axis (axis=3) rather than the batch axis (axis=0), and the
    post-merge convolutions accept the doubled channel count.  All variable
    names (w_1 .. w_33) and op names are unchanged.
    """
    # --- Contracting path: six stages, halving spatial size each time. ---
    skips = []
    x = inp
    in_ch, w_idx = 3, 1
    for pool_idx, out_ch in enumerate([16, 32, 64, 128, 256, 512], start=1):
        skip, x = _encoder_stage(x, in_ch, out_ch, w_idx, pool_idx)
        skips.append(skip)
        in_ch, w_idx = out_ch, w_idx + 2

    # --- Bottleneck at 4x4 spatial resolution (no pooling afterwards). ---
    x = tf.nn.conv2d(x, filter=get_filter([3, 3, 512, 1024], na='w_13'),
                     strides=[1, 1, 1, 1], padding='SAME', name='conv_1024_1')
    x = tf.nn.relu(x, name='re_1024_1')
    x = tf.layers.batch_normalization(x, axis=-1, name='bn_1024')
    x = tf.nn.conv2d(x, filter=get_filter([3, 3, 1024, 1024], na='w_14'),
                     strides=[1, 1, 1, 1], padding='SAME', name='conv_1024_2')
    x = tf.nn.relu(x, name='re_1024_2')

    # --- Expanding path: mirror the encoder, deepest skip first. ---
    w_idx = 15
    for out_ch, size, skip in zip([512, 256, 128, 64, 32, 16],
                                  [[8, 8], [16, 16], [32, 32],
                                   [64, 64], [128, 128], [256, 256]],
                                  reversed(skips)):
        x = _decoder_stage(x, skip, out_ch, size, w_idx)
        w_idx += 3

    # 1x1 conv down to a single logit channel per pixel.
    fin_img = tf.nn.conv2d(x, filter=get_filter([1, 1, 16, 1], na='w_33'),
                           strides=[1, 1, 1, 1], padding='SAME', name='final_image')
    # fin_img = tf.nn.sigmoid(fin_img)  # apply at loss/inference time instead
    return fin_img