keras.layers.Normalization crashes with multi-dim mean and variance #22065

@thanhlecongg

Description

keras.layers.Normalization crashes when given valid multi-dimensional mean and variance arguments.
The following example reproduces the problem:

import keras
import numpy as np

in1 = keras.Input(shape=(2, 3, 4))
mean = np.array([[0., 1., 2., 3.],
                 [4., 5., 6., 7.]])
variance = np.array([[1., 1., 1., 1.],
                     [2., 2., 2., 2.]])

normLayer = keras.layers.Normalization(axis=(1, 3), mean=mean, variance=variance)
out1 = normLayer(in1)
print(out1)

Error Message:

tensorflow.python.framework.errors_impl.InvalidArgumentError: {{function_node __wrapped__BroadcastTo_device_/job:localhost/replica:0/task:0/device:CPU:0}} Unable to broadcast tensor of shape [2,4] to tensor of shape [1,2,1,4] [Op:BroadcastTo]

Since 2 x 4 = 1 x 2 x 1 x 4, the tensor of shape [2, 4] can simply be reshaped to shape [1, 2, 1, 4]. In Keras 3.7.0, keras.layers.Normalization handled these inputs correctly; starting with Keras 3.8.0, the error above occurs. This appears to be a regression introduced in gh-20626. I suggest separating the two cases: use broadcast_to for the scalar (float) mean/variance, i.e., the case addressed by gh-20626, and use reshape for array-valued mean/variance, as in the previous implementation. If this approach makes sense to you, I can open a PR.
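The distinction matters because broadcasting and reshaping are not interchangeable here: NumPy-style broadcasting aligns trailing dimensions, so an array of shape (2, 4) cannot be broadcast to (1, 2, 1, 4) (the 2 would have to map onto the size-1 axis), whereas a reshape succeeds because the total element counts match. A minimal NumPy sketch of just that semantic difference, illustrative only and not the actual Keras internals:

```python
import numpy as np

a = np.arange(8.0).reshape(2, 4)

# broadcast_to fails: aligning trailing dims pits the input's
# size-2 axis against the target's size-1 axis, which is invalid.
try:
    np.broadcast_to(a, (1, 2, 1, 4))
except ValueError as e:
    print("broadcast_to failed:", e)

# reshape succeeds: 2 * 4 == 1 * 2 * 1 * 4, so the same data
# can be reinterpreted with the target shape directly.
b = a.reshape(1, 2, 1, 4)
print(b.shape)
```

This is why the pre-3.8.0 reshape-based path accepted the (2, 4) mean/variance above while the broadcast_to-based path rejects it.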
