diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/GroupedTopNRowNumberBuilder.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/GroupedTopNRowNumberBuilder.java index 3118f2bde3ffd..172adc82bc9fb 100644 --- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/GroupedTopNRowNumberBuilder.java +++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/GroupedTopNRowNumberBuilder.java @@ -49,6 +49,8 @@ public class GroupedTopNRowNumberBuilder implements GroupedTopNBuilder { private final GroupedTopNRowNumberAccumulator groupedTopNRowNumberAccumulator; private final TsBlockWithPositionComparator comparator; + private int effectiveGroupCount = 0; + public GroupedTopNRowNumberBuilder( List sourceTypes, TsBlockWithPositionComparator comparator, @@ -77,10 +79,18 @@ public GroupedTopNRowNumberBuilder( @Override public void addTsBlock(TsBlock tsBlock) { - int[] groupIds = groupByHash.getGroupIds(tsBlock.getColumns(groupByChannels)); - int groupCount = groupByHash.getGroupCount(); + int[] groupIds; + if (groupByChannels.length == 0) { + groupIds = new int[tsBlock.getPositionCount()]; + if (tsBlock.getPositionCount() > 0) { + effectiveGroupCount = 1; + } + } else { + groupIds = groupByHash.getGroupIds(tsBlock.getColumns(groupByChannels)); + effectiveGroupCount = groupByHash.getGroupCount(); + } - processTsBlock(tsBlock, groupCount, groupIds); + processTsBlock(tsBlock, effectiveGroupCount, groupIds); } @Override @@ -120,7 +130,7 @@ private void processTsBlock(TsBlock newTsBlock, int groupCount, int[] groupIds) private class ResultIterator extends AbstractIterator { private final TsBlockBuilder tsBlockBuilder; - private final int groupIdCount = groupByHash.getGroupCount(); + private final int groupIdCount = effectiveGroupCount; private int currentGroupId = -1; private final LongBigArray rowIdOutput = new LongBigArray(); private long currentGroupSize; diff --git a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/execution/operator/process/ValuesOperatorTest.java b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/execution/operator/process/ValuesOperatorTest.java new file mode 100644 index 0000000000000..cb5d9ea5c8710 --- /dev/null +++ b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/execution/operator/process/ValuesOperatorTest.java @@ -0,0 +1,207 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.db.queryengine.execution.operator.process; + +import org.apache.iotdb.commons.concurrent.IoTDBThreadPoolFactory; +import org.apache.iotdb.db.queryengine.common.FragmentInstanceId; +import org.apache.iotdb.db.queryengine.common.PlanFragmentId; +import org.apache.iotdb.db.queryengine.common.QueryId; +import org.apache.iotdb.db.queryengine.execution.driver.DriverContext; +import org.apache.iotdb.db.queryengine.execution.fragment.FragmentInstanceContext; +import org.apache.iotdb.db.queryengine.execution.fragment.FragmentInstanceStateMachine; +import org.apache.iotdb.db.queryengine.plan.planner.plan.node.PlanNodeId; + +import com.google.common.collect.ImmutableList; +import org.apache.tsfile.block.column.ColumnBuilder; +import org.apache.tsfile.enums.TSDataType; +import org.apache.tsfile.read.common.block.TsBlock; +import org.apache.tsfile.read.common.block.TsBlockBuilder; +import org.apache.tsfile.read.common.block.column.RunLengthEncodedColumn; +import org.junit.Test; + +import java.util.Arrays; +import java.util.Collections; +import java.util.concurrent.ExecutorService; + +import static org.apache.iotdb.db.queryengine.execution.fragment.FragmentInstanceContext.createFragmentInstanceContext; +import static org.apache.iotdb.db.queryengine.execution.operator.source.relational.TableScanOperator.TIME_COLUMN_TEMPLATE; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +public class ValuesOperatorTest { + private static final ExecutorService instanceNotificationExecutor = + IoTDBThreadPoolFactory.newFixedThreadPool(1, "valuesOperator-test-instance-notification"); + + @Test + public void testEmptyValues() { + try (ValuesOperator operator = genValuesOperator(ImmutableList.of())) { + assertTrue(operator.isFinished()); + assertFalse(operator.hasNext()); + assertNull(operator.next()); + assertEquals(0, operator.calculateMaxPeekMemory()); + assertEquals(0, operator.calculateMaxReturnSize()); + assertEquals(0, operator.calculateRetainedSizeAfterCallingNext()); + } catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void testSingleTsBlock() { + int[] values = {10, 20, 30}; + TsBlock tsBlock = createIntTsBlock(values); + + try (ValuesOperator operator = genValuesOperator(Collections.singletonList(tsBlock))) { + assertFalse(operator.isFinished()); + assertTrue(operator.hasNext()); + + TsBlock result = operator.next(); + assertNotNull(result); + assertEquals(3, result.getPositionCount()); + for (int i = 0; i < values.length; i++) { + assertEquals(values[i], result.getColumn(0).getInt(i)); + } + + assertTrue(operator.isFinished()); + assertFalse(operator.hasNext()); + } catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void testMultipleTsBlocks() { + int[] values1 = {1, 2, 3}; + int[] values2 = {4, 5}; + int[] values3 = {6, 7, 8, 9}; + + TsBlock block1 = createIntTsBlock(values1); + TsBlock block2 = createIntTsBlock(values2); + TsBlock block3 = createIntTsBlock(values3); + + try (ValuesOperator operator = genValuesOperator(Arrays.asList(block1, block2, block3))) { + assertFalse(operator.isFinished()); + assertTrue(operator.hasNext()); + + // First block + TsBlock result1 = operator.next(); + assertNotNull(result1); + assertEquals(3, result1.getPositionCount()); + for (int i = 0; i < values1.length; i++) { + assertEquals(values1[i], result1.getColumn(0).getInt(i)); + } + + // Second block + assertFalse(operator.isFinished()); + assertTrue(operator.hasNext()); + TsBlock result2 = operator.next(); + assertNotNull(result2); + assertEquals(2, result2.getPositionCount()); + for (int i = 0; i < values2.length; i++) { + assertEquals(values2[i], result2.getColumn(0).getInt(i)); + } + + // Third block + assertFalse(operator.isFinished()); + assertTrue(operator.hasNext()); + TsBlock result3 = operator.next(); + assertNotNull(result3); + assertEquals(4, result3.getPositionCount()); + for (int i = 0; i < values3.length; i++) { + assertEquals(values3[i], result3.getColumn(0).getInt(i)); + } + + assertTrue(operator.isFinished()); + assertFalse(operator.hasNext()); + } catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void testRetainedSizeDecreases() { + int[] values1 = {1, 2, 3}; + int[] values2 = {4, 5, 6}; + + TsBlock block1 = createIntTsBlock(values1); + TsBlock block2 = createIntTsBlock(values2); + + try (ValuesOperator operator = genValuesOperator(Arrays.asList(block1, block2))) { + long initialRetained = operator.calculateRetainedSizeAfterCallingNext(); + + operator.next(); + long afterFirstRetained = operator.calculateRetainedSizeAfterCallingNext(); + assertTrue( + "Retained size should decrease after consuming a block", + afterFirstRetained < initialRetained); + + operator.next(); + long afterSecondRetained = operator.calculateRetainedSizeAfterCallingNext(); + assertEquals(0, afterSecondRetained); + } catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + @Test + public void testIsBlockedReturnsNotBlocked() { + try (ValuesOperator operator = genValuesOperator(ImmutableList.of())) { + assertTrue(operator.isBlocked().isDone()); + } catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + private TsBlock createIntTsBlock(int[] values) { + TsBlockBuilder builder = + new TsBlockBuilder(values.length, Collections.singletonList(TSDataType.INT32)); + ColumnBuilder columnBuilder = builder.getColumnBuilder(0); + for (int value : values) { + columnBuilder.writeInt(value); + } + builder.declarePositions(values.length); + return builder.build( + new RunLengthEncodedColumn(TIME_COLUMN_TEMPLATE, builder.getPositionCount())); + } + + private ValuesOperator genValuesOperator(java.util.List tsBlocks) { + QueryId queryId = new QueryId("stub_query"); + FragmentInstanceId instanceId = + new FragmentInstanceId(new PlanFragmentId(queryId, 0), "stub-instance"); + FragmentInstanceStateMachine stateMachine = + new FragmentInstanceStateMachine(instanceId, instanceNotificationExecutor); + FragmentInstanceContext fragmentInstanceContext = + createFragmentInstanceContext(instanceId, stateMachine); + DriverContext driverContext = new DriverContext(fragmentInstanceContext, 0); + PlanNodeId planNode = new PlanNodeId("1"); + driverContext.addOperatorContext(1, planNode, TreeLinearFillOperator.class.getSimpleName()); + + return new ValuesOperator(driverContext.getOperatorContexts().get(0), tsBlocks); + } +} diff --git a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/execution/operator/process/window/RowNumberOperatorTest.java b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/execution/operator/process/window/RowNumberOperatorTest.java new file mode 100644 index 0000000000000..cdf58eb7681c8 --- /dev/null +++ b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/execution/operator/process/window/RowNumberOperatorTest.java @@ -0,0 +1,359 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.db.queryengine.execution.operator.process.window; + +import org.apache.iotdb.commons.concurrent.IoTDBThreadPoolFactory; +import org.apache.iotdb.db.queryengine.common.FragmentInstanceId; +import org.apache.iotdb.db.queryengine.common.PlanFragmentId; +import org.apache.iotdb.db.queryengine.common.QueryId; +import org.apache.iotdb.db.queryengine.execution.driver.DriverContext; +import org.apache.iotdb.db.queryengine.execution.fragment.FragmentInstanceContext; +import org.apache.iotdb.db.queryengine.execution.fragment.FragmentInstanceStateMachine; +import org.apache.iotdb.db.queryengine.execution.operator.Operator; +import org.apache.iotdb.db.queryengine.execution.operator.OperatorContext; +import org.apache.iotdb.db.queryengine.execution.operator.process.TreeLinearFillOperator; +import org.apache.iotdb.db.queryengine.plan.planner.plan.node.PlanNodeId; + +import com.google.common.util.concurrent.ListenableFuture; +import org.apache.tsfile.common.conf.TSFileConfig; +import org.apache.tsfile.enums.TSDataType; +import org.apache.tsfile.read.common.block.TsBlock; +import org.apache.tsfile.read.common.block.TsBlockBuilder; +import org.apache.tsfile.read.common.block.column.RunLengthEncodedColumn; +import org.apache.tsfile.utils.Binary; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.Optional; +import java.util.concurrent.ExecutorService; + +import static org.apache.iotdb.db.queryengine.execution.fragment.FragmentInstanceContext.createFragmentInstanceContext; +import static org.apache.iotdb.db.queryengine.execution.operator.source.relational.TableScanOperator.TIME_COLUMN_TEMPLATE; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +public class RowNumberOperatorTest { + private static final ExecutorService instanceNotificationExecutor = + IoTDBThreadPoolFactory.newFixedThreadPool(1, "rowNumberOperator-test-instance-notification"); + + @Test + public void testRowNumberWithPartition() { + long[][] timeArray = {{1, 2, 3, 4, 5}}; + String[][] deviceArray = {{"d1", "d1", "d2", "d2", "d2"}}; + int[][] valueArray = {{10, 20, 30, 40, 50}}; + + long[] expectedTime = {1, 2, 3, 4, 5}; + String[] expectedDevice = {"d1", "d1", "d2", "d2", "d2"}; + int[] expectedValue = {10, 20, 30, 40, 50}; + long[] expectedRn = {1, 2, 1, 2, 3}; + + verifyRowNumberResults( + timeArray, + deviceArray, + valueArray, + Arrays.asList(1), + Optional.empty(), + expectedTime, + expectedDevice, + expectedValue, + expectedRn); + } + + @Test + public void testRowNumberWithoutPartition() { + long[][] timeArray = {{1, 2, 3, 4, 5}}; + String[][] deviceArray = {{"d1", "d1", "d2", "d2", "d2"}}; + int[][] valueArray = {{10, 20, 30, 40, 50}}; + + long[] expectedTime = {1, 2, 3, 4, 5}; + String[] expectedDevice = {"d1", "d1", "d2", "d2", "d2"}; + int[] expectedValue = {10, 20, 30, 40, 50}; + long[] expectedRn = {1, 2, 3, 4, 5}; + + verifyRowNumberResults( + timeArray, + deviceArray, + valueArray, + Collections.emptyList(), + Optional.empty(), + expectedTime, + expectedDevice, + expectedValue, + expectedRn); + } + + @Test + public void testRowNumberWithMaxRowsPerPartition() { + long[][] timeArray = {{1, 2, 3, 4, 5}}; + String[][] deviceArray = {{"d1", "d1", "d2", "d2", "d2"}}; + int[][] valueArray = {{10, 20, 30, 40, 50}}; + + // maxRowsPerPartition=2: d1 keeps 2, d2 keeps 2 (third row skipped) + long[] expectedTime = {1, 2, 3, 4}; + String[] expectedDevice = {"d1", "d1", "d2", "d2"}; + int[] expectedValue = {10, 20, 30, 40}; + long[] expectedRn = {1, 2, 1, 2}; + + verifyRowNumberResults( + timeArray, + deviceArray, + valueArray, + Arrays.asList(1), + Optional.of(2), + expectedTime, + expectedDevice, + expectedValue, + expectedRn); + } + + @Test + public void testRowNumberPartitionCrossMultiTsBlocks() { + long[][] timeArray = {{1, 2, 3}, {4, 5, 6, 7}}; + String[][] deviceArray = {{"d1", "d1", "d2"}, {"d2", "d2", "d3", "d3"}}; + int[][] valueArray = {{10, 20, 30}, {40, 50, 60, 70}}; + + long[] expectedTime = {1, 2, 3, 4, 5, 6, 7}; + String[] expectedDevice = {"d1", "d1", "d2", "d2", "d2", "d3", "d3"}; + int[] expectedValue = {10, 20, 30, 40, 50, 60, 70}; + long[] expectedRn = {1, 2, 1, 2, 3, 1, 2}; + + verifyRowNumberResults( + timeArray, + deviceArray, + valueArray, + Arrays.asList(1), + Optional.empty(), + expectedTime, + expectedDevice, + expectedValue, + expectedRn); + } + + @Test + public void testRowNumberWithEmptyInput() throws Exception { + long[][] timeArray = {}; + String[][] deviceArray = {}; + int[][] valueArray = {}; + + DriverContext driverContext = createDriverContext(); + Operator childOperator = new ChildOperator(timeArray, deviceArray, valueArray, driverContext); + + List inputDataTypes = + Arrays.asList(TSDataType.TIMESTAMP, TSDataType.TEXT, TSDataType.INT32); + List outputChannels = Arrays.asList(0, 1, 2); + + try (RowNumberOperator operator = + new RowNumberOperator( + driverContext.getOperatorContexts().get(0), + childOperator, + inputDataTypes, + outputChannels, + Collections.singletonList(1), + Optional.empty(), + 10)) { + assertTrue(operator.isFinished()); + assertFalse(operator.hasNext()); + } + } + + @Test + public void testRowNumberWithSingleRowPartitions() { + long[][] timeArray = {{1, 2, 3}}; + String[][] deviceArray = {{"d1", "d2", "d3"}}; + int[][] valueArray = {{10, 20, 30}}; + + long[] expectedTime = {1, 2, 3}; + String[] expectedDevice = {"d1", "d2", "d3"}; + int[] expectedValue = {10, 20, 30}; + long[] expectedRn = {1, 1, 1}; + + verifyRowNumberResults( + timeArray, + deviceArray, + valueArray, + Arrays.asList(1), + Optional.empty(), + expectedTime, + expectedDevice, + expectedValue, + expectedRn); + } + + private void verifyRowNumberResults( + long[][] timeArray, + String[][] deviceArray, + int[][] valueArray, + List partitionChannels, + Optional maxRowsPerPartition, + long[] expectedTime, + String[] expectedDevice, + int[] expectedValue, + long[] expectedRn) { + int count = 0; + try (RowNumberOperator operator = + genRowNumberOperator( + timeArray, deviceArray, valueArray, partitionChannels, maxRowsPerPartition)) { + ListenableFuture future = operator.isBlocked(); + future.get(); + while (!operator.isFinished() && operator.hasNext()) { + TsBlock tsBlock = operator.next(); + if (tsBlock != null && !tsBlock.isEmpty()) { + for (int i = 0; i < tsBlock.getPositionCount(); i++, count++) { + assertEquals(expectedTime[count], tsBlock.getColumn(0).getLong(i)); + assertEquals( + expectedDevice[count], + tsBlock.getColumn(1).getBinary(i).getStringValue(TSFileConfig.STRING_CHARSET)); + assertEquals(expectedValue[count], tsBlock.getColumn(2).getInt(i)); + assertEquals(expectedRn[count], tsBlock.getColumn(3).getLong(i)); + } + } + } + assertEquals(expectedTime.length, count); + } catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + private DriverContext createDriverContext() { + QueryId queryId = new QueryId("stub_query"); + FragmentInstanceId instanceId = + new FragmentInstanceId(new PlanFragmentId(queryId, 0), "stub-instance"); + FragmentInstanceStateMachine stateMachine = + new FragmentInstanceStateMachine(instanceId, instanceNotificationExecutor); + FragmentInstanceContext fragmentInstanceContext = + createFragmentInstanceContext(instanceId, stateMachine); + DriverContext driverContext = new DriverContext(fragmentInstanceContext, 0); + PlanNodeId planNode = new PlanNodeId("1"); + driverContext.addOperatorContext(1, planNode, TreeLinearFillOperator.class.getSimpleName()); + return driverContext; + } + + private RowNumberOperator genRowNumberOperator( + long[][] timeArray, + String[][] deviceArray, + int[][] valueArray, + List partitionChannels, + Optional maxRowsPerPartition) { + DriverContext driverContext = createDriverContext(); + + List inputDataTypes = + Arrays.asList(TSDataType.TIMESTAMP, TSDataType.TEXT, TSDataType.INT32); + List outputChannels = new ArrayList<>(); + for (int i = 0; i < inputDataTypes.size(); i++) { + outputChannels.add(i); + } + + Operator childOperator = new ChildOperator(timeArray, deviceArray, valueArray, driverContext); + return new RowNumberOperator( + driverContext.getOperatorContexts().get(0), + childOperator, + inputDataTypes, + outputChannels, + partitionChannels, + maxRowsPerPartition, + 10); + } + + static class ChildOperator implements Operator { + private int index; + private final long[][] timeArray; + private final String[][] deviceArray; + private final int[][] valueArray; + private final DriverContext driverContext; + + ChildOperator( + long[][] timeArray, + String[][] deviceArray, + int[][] valueArray, + DriverContext driverContext) { + this.timeArray = timeArray; + this.deviceArray = deviceArray; + this.valueArray = valueArray; + this.driverContext = driverContext; + this.index = 0; + } + + @Override + public OperatorContext getOperatorContext() { + return driverContext.getOperatorContexts().get(0); + } + + @Override + public TsBlock next() { + if (index >= timeArray.length) { + return null; + } + TsBlockBuilder builder = + new TsBlockBuilder( + timeArray[index].length, + Arrays.asList(TSDataType.TIMESTAMP, TSDataType.TEXT, TSDataType.INT32)); + for (int i = 0; i < timeArray[index].length; i++) { + builder.getColumnBuilder(0).writeLong(timeArray[index][i]); + builder + .getColumnBuilder(1) + .writeBinary(new Binary(deviceArray[index][i], TSFileConfig.STRING_CHARSET)); + builder.getColumnBuilder(2).writeInt(valueArray[index][i]); + } + builder.declarePositions(timeArray[index].length); + index++; + return builder.build( + new RunLengthEncodedColumn(TIME_COLUMN_TEMPLATE, builder.getPositionCount())); + } + + @Override + public boolean hasNext() { + return index < timeArray.length; + } + + @Override + public boolean isFinished() { + return index >= timeArray.length; + } + + @Override + public void close() {} + + @Override + public long calculateMaxPeekMemory() { + return 0; + } + + @Override + public long calculateMaxReturnSize() { + return 0; + } + + @Override + public long calculateRetainedSizeAfterCallingNext() { + return 0; + } + + @Override + public long ramBytesUsed() { + return 0; + } + } +} diff --git a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/execution/operator/process/window/TopKRankingOperatorTest.java b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/execution/operator/process/window/TopKRankingOperatorTest.java new file mode 100644 index 0000000000000..1bbf05f7679b9 --- /dev/null +++ b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/execution/operator/process/window/TopKRankingOperatorTest.java @@ -0,0 +1,455 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package org.apache.iotdb.db.queryengine.execution.operator.process.window; + +import org.apache.iotdb.commons.concurrent.IoTDBThreadPoolFactory; +import org.apache.iotdb.db.queryengine.common.FragmentInstanceId; +import org.apache.iotdb.db.queryengine.common.PlanFragmentId; +import org.apache.iotdb.db.queryengine.common.QueryId; +import org.apache.iotdb.db.queryengine.execution.driver.DriverContext; +import org.apache.iotdb.db.queryengine.execution.fragment.FragmentInstanceContext; +import org.apache.iotdb.db.queryengine.execution.fragment.FragmentInstanceStateMachine; +import org.apache.iotdb.db.queryengine.execution.operator.Operator; +import org.apache.iotdb.db.queryengine.execution.operator.OperatorContext; +import org.apache.iotdb.db.queryengine.execution.operator.process.TreeLinearFillOperator; +import org.apache.iotdb.db.queryengine.plan.planner.plan.node.PlanNodeId; +import org.apache.iotdb.db.queryengine.plan.relational.planner.SortOrder; +import org.apache.iotdb.db.queryengine.plan.relational.planner.node.TopKRankingNode; + +import org.apache.tsfile.common.conf.TSFileConfig; +import org.apache.tsfile.enums.TSDataType; +import org.apache.tsfile.read.common.block.TsBlock; +import org.apache.tsfile.read.common.block.TsBlockBuilder; +import org.apache.tsfile.read.common.block.column.RunLengthEncodedColumn; +import org.apache.tsfile.utils.Binary; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.concurrent.ExecutorService; + +import static org.apache.iotdb.db.queryengine.execution.fragment.FragmentInstanceContext.createFragmentInstanceContext; +import static org.apache.iotdb.db.queryengine.execution.operator.source.relational.TableScanOperator.TIME_COLUMN_TEMPLATE; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; + +public class TopKRankingOperatorTest { + private static final ExecutorService instanceNotificationExecutor = + IoTDBThreadPoolFactory.newFixedThreadPool( + 1, "topKRankingOperator-test-instance-notification"); + + @Test + public void testTopKWithPartition() { + // Input: 4 rows for d1, 3 rows for d2 + // Sort by value (column 2) ascending, top 2 per partition + long[][] timeArray = {{1, 2, 3, 4, 5, 6, 7}}; + String[][] deviceArray = {{"d1", "d1", "d1", "d1", "d2", "d2", "d2"}}; + int[][] valueArray = {{5, 3, 1, 4, 6, 2, 1}}; + + // Expected: top 2 per partition sorted by value ASC + // d1: value=1(rn=1), value=3(rn=2) + // d2: value=1(rn=1), value=2(rn=2) + Map> expectedByDevice = new HashMap<>(); + expectedByDevice.put("d1", Arrays.asList(new int[] {1, 1}, new int[] {3, 2})); + expectedByDevice.put("d2", Arrays.asList(new int[] {1, 1}, new int[] {2, 2})); + + verifyTopKResultsByPartition( + timeArray, + deviceArray, + valueArray, + Collections.singletonList(1), + Collections.singletonList(TSDataType.TEXT), + Collections.singletonList(2), + Collections.singletonList(SortOrder.ASC_NULLS_LAST), + 2, + false, + expectedByDevice, + 4); + } + + @Test + public void testTopKWithPartitionDescending() { + long[][] timeArray = {{1, 2, 3, 4, 5, 6}}; + String[][] deviceArray = {{"d1", "d1", "d1", "d2", "d2", "d2"}}; + int[][] valueArray = {{5, 3, 1, 6, 2, 4}}; + + // top 2 per partition sorted by value DESC + // d1: value=5(rn=1), value=3(rn=2) + // d2: value=6(rn=1), value=4(rn=2) + Map> expectedByDevice = new HashMap<>(); + expectedByDevice.put("d1", Arrays.asList(new int[] {5, 1}, new int[] {3, 2})); + expectedByDevice.put("d2", Arrays.asList(new int[] {6, 1}, new int[] {4, 2})); + + verifyTopKResultsByPartition( + timeArray, + deviceArray, + valueArray, + Collections.singletonList(1), + Collections.singletonList(TSDataType.TEXT), + Collections.singletonList(2), + Collections.singletonList(SortOrder.DESC_NULLS_LAST), + 2, + false, + expectedByDevice, + 4); + } + + @Test + public void testTopKWithoutPartition() { + // No partition: all rows in one group + long[][] timeArray = {{1, 2, 3, 4, 5}}; + String[][] deviceArray = {{"d1", "d1", "d2", "d2", "d2"}}; + int[][] valueArray = {{5, 3, 1, 4, 2}}; + + // top 3 globally sorted by value ASC: value=1(rn=1), value=2(rn=2), value=3(rn=3) + int[][] expectedValueAndRn = {{1, 1}, {2, 2}, {3, 3}}; + + verifyTopKResultsGlobal( + timeArray, + deviceArray, + valueArray, + Collections.emptyList(), + Collections.emptyList(), + Collections.singletonList(2), + Collections.singletonList(SortOrder.ASC_NULLS_LAST), + 3, + false, + expectedValueAndRn, + 3); + } + + @Test + public void testTopKWithMultipleTsBlocks() { + long[][] timeArray = {{1, 2, 3}, {4, 5}, {6, 7}}; + String[][] deviceArray = {{"d1", "d1", "d1"}, {"d2", "d2"}, {"d2", "d2"}}; + int[][] valueArray = {{5, 3, 1}, {6, 2}, {4, 1}}; + + // top 2 per partition sorted by value ASC + // d1: value=1(rn=1), value=3(rn=2) + // d2: value=1(rn=1), value=2(rn=2) + Map> expectedByDevice = new HashMap<>(); + expectedByDevice.put("d1", Arrays.asList(new int[] {1, 1}, new int[] {3, 2})); + expectedByDevice.put("d2", Arrays.asList(new int[] {1, 1}, new int[] {2, 2})); + + verifyTopKResultsByPartition( + timeArray, + deviceArray, + valueArray, + Collections.singletonList(1), + Collections.singletonList(TSDataType.TEXT), + Collections.singletonList(2), + Collections.singletonList(SortOrder.ASC_NULLS_LAST), + 2, + false, + expectedByDevice, + 4); + } + + @Test + public void testTopKWithTopOne() { + long[][] timeArray = {{1, 2, 3, 4}}; + String[][] deviceArray = {{"d1", "d1", "d2", "d2"}}; + int[][] valueArray = {{5, 3, 6, 2}}; + + // top 1 per partition sorted by value ASC + // d1: value=3(rn=1) + // d2: value=2(rn=1) + Map> expectedByDevice = new HashMap<>(); + expectedByDevice.put("d1", Collections.singletonList(new int[] {3, 1})); + expectedByDevice.put("d2", Collections.singletonList(new int[] {2, 1})); + + verifyTopKResultsByPartition( + timeArray, + deviceArray, + valueArray, + Collections.singletonList(1), + Collections.singletonList(TSDataType.TEXT), + Collections.singletonList(2), + Collections.singletonList(SortOrder.ASC_NULLS_LAST), + 1, + false, + expectedByDevice, + 2); + } + + /** + * Verifies top-K results grouped by partition (device). The output order between partitions is + * not guaranteed, so we group results by device and verify each partition independently. + */ + private void verifyTopKResultsByPartition( + long[][] timeArray, + String[][] deviceArray, + int[][] valueArray, + List partitionChannels, + List partitionTypes, + List sortChannels, + List sortOrders, + int maxRowCountPerPartition, + boolean partial, + Map> expectedByDevice, + int expectedTotalCount) { + + Map> actualByDevice = new HashMap<>(); + int count = 0; + + try (TopKRankingOperator operator = + genTopKRankingOperator( + timeArray, + deviceArray, + valueArray, + partitionChannels, + partitionTypes, + sortChannels, + sortOrders, + maxRowCountPerPartition, + partial)) { + while (!operator.isFinished()) { + if (operator.hasNext()) { + TsBlock tsBlock = operator.next(); + if (tsBlock != null && !tsBlock.isEmpty()) { + int numColumns = tsBlock.getValueColumnCount(); + for (int i = 0; i < tsBlock.getPositionCount(); i++, count++) { + String device = + tsBlock.getColumn(1).getBinary(i).getStringValue(TSFileConfig.STRING_CHARSET); + int value = tsBlock.getColumn(2).getInt(i); + long rowNumber = tsBlock.getColumn(numColumns - 1).getLong(i); + actualByDevice + .computeIfAbsent(device, k -> new ArrayList<>()) + .add(new int[] {value, (int) rowNumber}); + } + } + } + } + assertEquals(expectedTotalCount, count); + + for (Map.Entry> entry : expectedByDevice.entrySet()) { + String device = entry.getKey(); + List expectedRows = entry.getValue(); + List actualRows = actualByDevice.get(device); + + assertTrue("Missing partition for device: " + device, actualRows != null); + assertEquals( + "Row count mismatch for device " + device, expectedRows.size(), actualRows.size()); + for (int i = 0; i < expectedRows.size(); i++) { + assertEquals( + "Value mismatch at row " + i + " for device " + device, + expectedRows.get(i)[0], + actualRows.get(i)[0]); + assertEquals( + "Row number mismatch at row " + i + " for device " + device, + expectedRows.get(i)[1], + actualRows.get(i)[1]); + } + } + } catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + private void verifyTopKResultsGlobal( + long[][] timeArray, + String[][] deviceArray, + int[][] valueArray, + List partitionChannels, + List partitionTypes, + List sortChannels, + List sortOrders, + int maxRowCountPerPartition, + boolean partial, + int[][] expectedValueAndRn, + int expectedTotalCount) { + + List results = new ArrayList<>(); + int count = 0; + + try (TopKRankingOperator operator = + genTopKRankingOperator( + timeArray, + deviceArray, + valueArray, + partitionChannels, + partitionTypes, + sortChannels, + sortOrders, + maxRowCountPerPartition, + partial)) { + while (!operator.isFinished()) { + if (operator.hasNext()) { + TsBlock tsBlock = operator.next(); + if (tsBlock != null && !tsBlock.isEmpty()) { + int numColumns = tsBlock.getValueColumnCount(); + for (int i = 0; i < tsBlock.getPositionCount(); i++, count++) { + int value = tsBlock.getColumn(2).getInt(i); + long rowNumber = tsBlock.getColumn(numColumns - 1).getLong(i); + results.add(new int[] {value, (int) rowNumber}); + } + } + } + } + assertEquals(expectedTotalCount, count); + for (int i = 0; i < expectedValueAndRn.length; i++) { + assertEquals("Value mismatch at row " + i, expectedValueAndRn[i][0], results.get(i)[0]); + assertEquals( + "Row number mismatch at row " + i, expectedValueAndRn[i][1], results.get(i)[1]); + } + } catch (Exception e) { + e.printStackTrace(); + fail(e.getMessage()); + } + } + + private DriverContext createDriverContext() { + QueryId queryId = new QueryId("stub_query"); + FragmentInstanceId instanceId = + new FragmentInstanceId(new PlanFragmentId(queryId, 0), "stub-instance"); + FragmentInstanceStateMachine stateMachine = + new FragmentInstanceStateMachine(instanceId, instanceNotificationExecutor); + FragmentInstanceContext fragmentInstanceContext = + createFragmentInstanceContext(instanceId, stateMachine); + DriverContext driverContext = new DriverContext(fragmentInstanceContext, 0); + PlanNodeId planNode = new PlanNodeId("1"); + driverContext.addOperatorContext(1, planNode, TreeLinearFillOperator.class.getSimpleName()); + return driverContext; + } + + private TopKRankingOperator genTopKRankingOperator( + long[][] timeArray, + String[][] deviceArray, + int[][] valueArray, + List partitionChannels, + List partitionTypes, + List sortChannels, + List sortOrders, + int maxRowCountPerPartition, + boolean partial) { + DriverContext driverContext = createDriverContext(); + + List inputDataTypes = + Arrays.asList(TSDataType.TIMESTAMP, TSDataType.TEXT, TSDataType.INT32); + List outputChannels = new ArrayList<>(); + for (int i = 0; i < inputDataTypes.size(); i++) { + outputChannels.add(i); + } + + Operator childOperator = new ChildOperator(timeArray, deviceArray, valueArray, driverContext); + return new TopKRankingOperator( + driverContext.getOperatorContexts().get(0), + childOperator, + TopKRankingNode.RankingType.ROW_NUMBER, + inputDataTypes, + outputChannels, + partitionChannels, + partitionTypes, + sortChannels, + sortOrders, + maxRowCountPerPartition, + partial, + Optional.empty(), + 10, + Optional.empty()); + } + + static class ChildOperator implements Operator { + private int index; + private final long[][] timeArray; + private final String[][] deviceArray; + private final int[][] valueArray; + private final DriverContext driverContext; + + ChildOperator( + long[][] timeArray, + String[][] deviceArray, + int[][] valueArray, + DriverContext driverContext) { + this.timeArray = timeArray; + this.deviceArray = deviceArray; + this.valueArray = valueArray; + this.driverContext = driverContext; + this.index = 0; + } + + @Override + public OperatorContext getOperatorContext() { + return driverContext.getOperatorContexts().get(0); + } + + @Override + public TsBlock next() { + if (index >= timeArray.length) { + return null; + } + TsBlockBuilder builder = + new TsBlockBuilder( + timeArray[index].length, + Arrays.asList(TSDataType.TIMESTAMP, TSDataType.TEXT, TSDataType.INT32)); + for (int i = 0; i < timeArray[index].length; i++) { + builder.getColumnBuilder(0).writeLong(timeArray[index][i]); + builder + .getColumnBuilder(1) + .writeBinary(new Binary(deviceArray[index][i], TSFileConfig.STRING_CHARSET)); + builder.getColumnBuilder(2).writeInt(valueArray[index][i]); + } + builder.declarePositions(timeArray[index].length); + index++; + return builder.build( + new RunLengthEncodedColumn(TIME_COLUMN_TEMPLATE, builder.getPositionCount())); + } + + @Override + public boolean hasNext() { + return index < timeArray.length; + } + + @Override + public boolean isFinished() { + return index >= timeArray.length; + } + + @Override + public void close() {} + + @Override + public long calculateMaxPeekMemory() { + return 0; + } + + @Override + public long calculateMaxReturnSize() { + return 0; + } + + @Override + public long calculateRetainedSizeAfterCallingNext() { + return 0; + } + + @Override + public long ramBytesUsed() { + return 0; + } + } +}