diff --git a/arm_compute/core/QuantizationInfo.h b/arm_compute/core/QuantizationInfo.h index c63777e80e7..e0d3e048161 100644 --- a/arm_compute/core/QuantizationInfo.h +++ b/arm_compute/core/QuantizationInfo.h @@ -1,5 +1,5 @@ /* - * Copyright (c) 2019-2025 Arm Limited. + * Copyright (c) 2019-2026 Arm Limited. * * SPDX-License-Identifier: MIT * @@ -203,6 +203,17 @@ class QuantizationInfo */ inline bool operator==(const QuantizationInfo &lhs, const QuantizationInfo &rhs) { + const bool lhs_is_uniform = lhs.scale().size() < 2 && lhs.offset().size() < 2; + const bool rhs_is_uniform = rhs.scale().size() < 2 && rhs.offset().size() < 2; + + if (lhs_is_uniform && rhs_is_uniform) + { + const auto lhs_qinfo = lhs.uniform(); + const auto rhs_qinfo = rhs.uniform(); + + return (lhs_qinfo.scale == rhs_qinfo.scale) && (lhs_qinfo.offset == rhs_qinfo.offset); + } + return (lhs.scale() == rhs.scale()) && (lhs.offset() == rhs.offset()); } diff --git a/src/cpu/kernels/internal/CpuPool2dAssemblyWrapperKernel.cpp b/src/cpu/kernels/internal/CpuPool2dAssemblyWrapperKernel.cpp index c3da624b6a3..330ec726a3e 100644 --- a/src/cpu/kernels/internal/CpuPool2dAssemblyWrapperKernel.cpp +++ b/src/cpu/kernels/internal/CpuPool2dAssemblyWrapperKernel.cpp @@ -132,7 +132,7 @@ CpuPool2dAssemblyWrapperKernel::validate(const ITensorInfo *src, const ITensorIn const auto src_qinfo = src->quantization_info().uniform(); const auto dst_qinfo = dst->quantization_info().uniform(); - if (src_qinfo != dst_qinfo) + if (src->quantization_info() != dst->quantization_info()) { ARM_COMPUTE_RETURN_ERROR_ON_MSG(info.pool_type != PoolingType::MAX, "Assembly kernels only support differing src/dst quantization info for " diff --git a/tests/validation/UNIT/QuantizationInfo.cpp b/tests/validation/UNIT/QuantizationInfo.cpp new file mode 100644 index 00000000000..42f4ae858a0 --- /dev/null +++ b/tests/validation/UNIT/QuantizationInfo.cpp @@ -0,0 +1,80 @@ +/* + * Copyright (c) 2026 Arm Limited. + * + * SPDX-License-Identifier: MIT + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to + * deal in the Software without restriction, including without limitation the + * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or + * sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR + * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, + * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE + * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER + * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE + * SOFTWARE. + */ +#include "arm_compute/core/QuantizationInfo.h" + +#include "tests/framework/Asserts.h" +#include "tests/framework/Macros.h" + +namespace arm_compute +{ +namespace test +{ +namespace validation +{ +TEST_SUITE(UNIT) +TEST_SUITE(QuantizationInfo) + +TEST_CASE(EquivalentUniformImplicitAndExplicitZeroOffset, framework::DatasetMode::ALL) +{ + const arm_compute::QuantizationInfo implicit_zero(1.0f); + const arm_compute::QuantizationInfo explicit_zero(1.0f, 0); + + ARM_COMPUTE_EXPECT(implicit_zero == explicit_zero, framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(!(implicit_zero != explicit_zero), framework::LogLevel::ERRORS); +} + +TEST_CASE(DifferentUniformQInfoCompareDifferent, framework::DatasetMode::ALL) +{ + const arm_compute::QuantizationInfo unit_scale(1.0f); + const arm_compute::QuantizationInfo half_scale(0.5f, 0); + const arm_compute::QuantizationInfo offset_one(1.0f, 1); + const arm_compute::QuantizationInfo offset_zero(1.0f, 0); + + ARM_COMPUTE_EXPECT(unit_scale != half_scale, framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(offset_one != offset_zero, framework::LogLevel::ERRORS); +} + +TEST_CASE(PerChannelQInfoUsesStrictVectorEquality, framework::DatasetMode::ALL) +{ + const arm_compute::QuantizationInfo matching_per_channel(std::vector{1.0f, 2.0f}); + const arm_compute::QuantizationInfo same_per_channel(std::vector{1.0f, 2.0f}); + const arm_compute::QuantizationInfo different_per_channel(std::vector{1.0f, 3.0f}); + + ARM_COMPUTE_EXPECT(matching_per_channel == same_per_channel, framework::LogLevel::ERRORS); + ARM_COMPUTE_EXPECT(matching_per_channel != different_per_channel, framework::LogLevel::ERRORS); +} + +TEST_CASE(PerChannelQInfoDoesNotMatchUniformQInfo, framework::DatasetMode::ALL) +{ + const arm_compute::QuantizationInfo per_channel(std::vector{1.0f, 2.0f}); + const arm_compute::QuantizationInfo uniform(1.0f); + + ARM_COMPUTE_EXPECT(per_channel != uniform, framework::LogLevel::ERRORS); +} + +TEST_SUITE_END() // QuantizationInfo +TEST_SUITE_END() // UNIT +} // namespace validation +} // namespace test +} // namespace arm_compute