keras.layers.Normalization crash with valid multi-dim mean and variance
Please use this example to replicate the problem:
```python
import keras
import numpy as np

in1 = keras.Input(shape=(2, 3, 4))
mean = np.array([[0., 1., 2., 3.],
                 [4., 5., 6., 7.]])
variance = np.array([[1., 1., 1., 1.],
                     [2., 2., 2., 2.]])
normLayer = keras.layers.Normalization(axis=(1, 3), mean=mean, variance=variance)
out1 = normLayer(in1)
print(out1)
```
Error message:
```
tensorflow.python.framework.errors_impl.InvalidArgumentError: {{function_node __wrapped__BroadcastTo_device_/job:localhost/replica:0/task:0/device:CPU:0}} Unable to broadcast tensor of shape [2,4] to tensor of shape [1,2,1,4] [Op:BroadcastTo]
```
I believe a tensor of shape [2, 4] can be reshaped to shape [1, 2, 1, 4], since 2x4 = 1x2x1x4. In Keras 3.7.0, keras.layers.Normalization worked well with these inputs, but starting with Keras 3.8.0 this error occurs. This appears to be a regression introduced in gh-20626. I think we can separate the two cases: one for a scalar (float) mean/variance, i.e., the case addressed in gh-20626, where broadcast_to is used, and one for array-valued mean/variance, where reshape is used as in the previous implementation. If this approach makes sense to you, I can make a PR.
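A minimal sketch of the branching I have in mind, assuming the layer has already computed the target broadcast shape (here (1, 2, 1, 4) for an input of shape (None, 2, 3, 4) with axis=(1, 3)); the helper name and structure are illustrative, not the actual Keras internals:

```python
import numpy as np

def _broadcast_stat(stat, broadcast_shape):
    # Hypothetical helper, not the actual Keras code path.
    # A scalar mean/variance (the gh-20626 case) has to be expanded
    # with broadcast_to; an array-valued stat such as shape (2, 4)
    # already has the right number of elements and can simply be
    # reshaped to the broadcast shape, as before Keras 3.8.0.
    stat = np.asarray(stat)
    if stat.ndim == 0:
        return np.broadcast_to(stat, broadcast_shape)
    return np.reshape(stat, broadcast_shape)

# Shapes from the reproduction above:
mean = np.array([[0., 1., 2., 3.],
                 [4., 5., 6., 7.]])               # shape (2, 4)
print(_broadcast_stat(mean, (1, 2, 1, 4)).shape)  # (1, 2, 1, 4)
print(_broadcast_stat(0.5, (1, 2, 1, 4)).shape)   # (1, 2, 1, 4)
```

Both branches end up with the same broadcast shape, so the downstream normalization arithmetic would be unchanged.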