vgg = tf.keras.applications.VGG16(weights='imagenet', include_top=False, input_shape=input_size)
vgg.trainable = False
outputs = [vgg.get_layer(l).output for l in selected_layers]
model = tf.keras.Model(vgg.input, outputs)
I was trying to evaluate my network using a perceptual loss. The problem was that my model's output shape was (512, 512, 1), while VGG16 with the pre-trained 'imagenet' weights expects a 3-channel image.
So, before passing the reconstructed image to the loss function, I made it a 3-channel image by repeating the single channel: reconstruct_image = tf.keras.layers.Concatenate()([reconstruct_image, reconstruct_image, reconstruct_image])
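For anyone hitting the same shape mismatch, here is a minimal sketch of how the channel fix can sit inside the loss itself. The layer names and layer weights are taken from the question. The helper names (`build_feature_extractor`, `to_three_channels`, `make_perceptual_loss`) are hypothetical, and the per-layer mean squared difference is an assumption, since the body of the original loss was not posted.

```python
import tensorflow as tf

# Layer names and weights as given in the question.
SELECTED_LAYERS = ["block1_conv1", "block2_conv2", "block3_conv3"]
LAYER_WEIGHTS = [0.65, 0.3, 0.05]


def build_feature_extractor(input_shape, weights="imagenet"):
    """Frozen VGG16 that returns the activations of the selected layers."""
    vgg = tf.keras.applications.VGG16(
        weights=weights, include_top=False, input_shape=input_shape
    )
    vgg.trainable = False
    outputs = [vgg.get_layer(name).output for name in SELECTED_LAYERS]
    return tf.keras.Model(vgg.input, outputs)


def to_three_channels(image):
    """Repeat a single-channel batch (N, H, W, 1) to (N, H, W, 3)."""
    return tf.concat([image, image, image], axis=-1)


def make_perceptual_loss(feature_extractor):
    """Build a Keras-style loss(y_true, y_pred) around the extractor."""

    def perceptual_loss(y_true, y_pred):
        # VGG16 expects 3 channels, so expand a 1-channel prediction first.
        if y_pred.shape[-1] == 1:
            y_pred = to_three_channels(y_pred)
        true_features = feature_extractor(y_true)
        pred_features = feature_extractor(y_pred)
        loss = 0.0
        # Assumption: weighted mean squared difference per selected layer.
        for weight, t, p in zip(LAYER_WEIGHTS, true_features, pred_features):
            loss += weight * tf.reduce_mean(tf.square(t - p))
        return loss

    return perceptual_loss
```

Keeping the feature extractor in its own variable, separate from the network being trained, also makes it clear which model `compile` is called on.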
I have modified the code slightly. I also ran into an OOM (out-of-memory) error, so I reduced my image size to (256, 256). My modified code is given below.
I get this error when I add a custom loss.
Input shape: (512, 512, 3)
model=models.unet_2d(input_size, filter_num, n_labels=1, stack_num_down=2, stack_num_up=2,
activation='ReLU', output_activation=None, batch_norm=True, pool='max', unpool='nearest',
backbone='VGG16', weights='imagenet', freeze_backbone=True, freeze_batch_norm=True, name='unet')
////////
selected_layers = ['block1_conv1', 'block2_conv2', 'block3_conv3']
selected_layer_weights = [0.65, 0.3, 0.05]
vgg = tf.keras.applications.VGG16(weights='imagenet', include_top=False, input_shape=input_size)
vgg.trainable = False
outputs = [vgg.get_layer(l).output for l in selected_layers]
model = tf.keras.Model(vgg.input, outputs)
@tf.function
def perceptual_loss(input_image, reconstruct_image):
    h1_list = model(input_image)
    h2_list = model(reconstruct_image)
///////
model.compile(loss=perceptual_loss,
              optimizer=keras.optimizers.SGD(learning_rate=1e-2),
              metrics=[tf.keras.metrics.MeanAbsoluteError(), tf.keras.metrics.MeanSquaredError()])
Pretrained weights: 'imagenet'
Backbone: VGG16
How can I resolve this?