diff --git a/integration-test/src/test/java/org/apache/iotdb/relational/it/db/it/IoTDBWindowFunction3IT.java b/integration-test/src/test/java/org/apache/iotdb/relational/it/db/it/IoTDBWindowFunction3IT.java
index d461f3a11fee6..6985f5b527e68 100644
--- a/integration-test/src/test/java/org/apache/iotdb/relational/it/db/it/IoTDBWindowFunction3IT.java
+++ b/integration-test/src/test/java/org/apache/iotdb/relational/it/db/it/IoTDBWindowFunction3IT.java
@@ -51,6 +51,17 @@ public class IoTDBWindowFunction3IT {
"insert into demo values (2021-01-01T09:10:00, 'd1', 1)",
"insert into demo values (2021-01-01T09:08:00, 'd2', 2)",
"insert into demo values (2021-01-01T09:15:00, 'd2', 4)",
+ "create table multi_tag (region string tag, plant string tag, temp double field)",
+ "insert into multi_tag values (2021-01-01T08:00:00, 'east', 'A', 10)",
+ "insert into multi_tag values (2021-01-01T09:00:00, 'east', 'A', 20)",
+ "insert into multi_tag values (2021-01-01T10:00:00, 'east', 'A', 15)",
+ "insert into multi_tag values (2021-01-01T11:00:00, 'east', 'A', 25)",
+ "insert into multi_tag values (2021-01-01T08:30:00, 'east', 'B', 30)",
+ "insert into multi_tag values (2021-01-01T09:30:00, 'east', 'B', 35)",
+ "insert into multi_tag values (2021-01-01T10:30:00, 'east', 'B', 32)",
+ "insert into multi_tag values (2021-01-01T07:00:00, 'west', 'C', 50)",
+ "insert into multi_tag values (2021-01-01T08:00:00, 'west', 'C', 55)",
+ "insert into multi_tag values (2021-01-01T09:00:00, 'west', 'C', 52)",
"FLUSH",
"CLEAR ATTRIBUTE CACHE",
};
@@ -121,8 +132,8 @@ public void testPushDownFilterIntoWindow() {
String[] expectedHeader = new String[] {"time", "device", "value", "rn"};
String[] retArray =
new String[] {
- "2021-01-01T09:10:00.000Z,d1,1.0,1,",
"2021-01-01T09:05:00.000Z,d1,3.0,2,",
+ "2021-01-01T09:10:00.000Z,d1,1.0,1,",
"2021-01-01T09:08:00.000Z,d2,2.0,1,",
"2021-01-01T09:15:00.000Z,d2,4.0,2,",
};
@@ -166,6 +177,84 @@ public void testReplaceWindowWithRowNumber() {
DATABASE_NAME);
}
+ @Test
+ public void testPushDownFilterIntoWindowWithRank() {
+ // Data: d1 values {3,5,3,1}, d2 values {2,4}
+ // rank(PARTITION BY device ORDER BY value):
+ // d1: 1.0→rank=1, 3.0→rank=2, 3.0→rank=2, 5.0→rank=4
+ // d2: 2.0→rank=1, 4.0→rank=2
+ // WHERE rk <= 2: keeps d1 rows with rank≤2 (3 rows due to tie), d2 all (2 rows)
+ String[] expectedHeader = new String[] {"time", "device", "value", "rk"};
+ String[] retArray =
+ new String[] {
+ "2021-01-01T09:05:00.000Z,d1,3.0,2,",
+ "2021-01-01T09:09:00.000Z,d1,3.0,2,",
+ "2021-01-01T09:10:00.000Z,d1,1.0,1,",
+ "2021-01-01T09:08:00.000Z,d2,2.0,1,",
+ "2021-01-01T09:15:00.000Z,d2,4.0,2,",
+ };
+ tableResultSetEqualTest(
+ "SELECT * FROM (SELECT *, rank() OVER (PARTITION BY device ORDER BY value) as rk FROM demo) WHERE rk <= 2 ORDER BY device, time",
+ expectedHeader,
+ retArray,
+ DATABASE_NAME);
+ }
+
+  @Test
+  public void testPushDownLimitIntoWindowWithRank() {
+    // There is NO predicate on rk here, so the LIMIT is not a per-partition rank filter and
+    // rows of any rank may appear in the output.
+    // rank(PARTITION BY device ORDER BY value) over d1 values {3,5,3,1}:
+    //   09:05 -> 2, 09:07 -> 4, 09:09 -> 2, 09:10 -> 1
+    // ORDER BY device, time LIMIT 2 keeps the first two d1 rows by time:
+    //   09:05 (rk=2) and 09:07 (rk=4).
+    String[] expectedHeader = new String[] {"time", "device", "value", "rk"};
+    String[] retArray =
+        new String[] {
+          "2021-01-01T09:05:00.000Z,d1,3.0,2,", "2021-01-01T09:07:00.000Z,d1,5.0,4,",
+        };
+    tableResultSetEqualTest(
+        "SELECT * FROM (SELECT *, rank() OVER (PARTITION BY device ORDER BY value) as rk FROM demo) ORDER BY device, time LIMIT 2",
+        expectedHeader,
+        retArray,
+        DATABASE_NAME);
+  }
+
+ @Test
+ public void testRankBasic() {
+ // Verifies rank computation: ties get the same rank, gaps after ties
+ String[] expectedHeader = new String[] {"time", "device", "value", "rk"};
+ String[] retArray =
+ new String[] {
+ "2021-01-01T09:05:00.000Z,d1,3.0,2,",
+ "2021-01-01T09:07:00.000Z,d1,5.0,4,",
+ "2021-01-01T09:09:00.000Z,d1,3.0,2,",
+ "2021-01-01T09:10:00.000Z,d1,1.0,1,",
+ "2021-01-01T09:08:00.000Z,d2,2.0,1,",
+ "2021-01-01T09:15:00.000Z,d2,4.0,2,",
+ };
+ tableResultSetEqualTest(
+ "SELECT *, rank() OVER (PARTITION BY device ORDER BY value) as rk FROM demo ORDER BY device, time",
+ expectedHeader,
+ retArray,
+ DATABASE_NAME);
+ }
+
+  @Test
+  public void testRankWithFilterEquals() {
+    // WHERE rk = 2 keeps only rows whose rank is exactly 2: both d1 rows with value=3
+    // (tied at rank 2) AND the d2 row with value=4 (rank 2 in its partition).
+    String[] expectedHeader = new String[] {"time", "device", "value", "rk"};
+    String[] retArray =
+        new String[] {
+          "2021-01-01T09:05:00.000Z,d1,3.0,2,",
+          "2021-01-01T09:09:00.000Z,d1,3.0,2,",
+          "2021-01-01T09:15:00.000Z,d2,4.0,2,",
+        };
+    tableResultSetEqualTest(
+        "SELECT * FROM (SELECT *, rank() OVER (PARTITION BY device ORDER BY value) as rk FROM demo) WHERE rk = 2 ORDER BY device, time",
+        expectedHeader,
+        retArray,
+        DATABASE_NAME);
+  }
+
@Test
public void testRemoveRedundantWindow() {
String[] expectedHeader = new String[] {"time", "device", "value", "rn"};
@@ -176,4 +265,116 @@ public void testRemoveRedundantWindow() {
retArray,
DATABASE_NAME);
}
+
+ @Test
+ public void testTopKRankingOrderByTimeAsc() {
+ // PARTITION BY all tags + ORDER BY time ASC triggers limit push-down to DeviceTableScanNode
+ // and streaming RowNumberOperator optimization.
+ String[] expectedHeader = new String[] {"time", "device", "value", "rn"};
+ String[] retArray =
+ new String[] {
+ "2021-01-01T09:05:00.000Z,d1,3.0,1,",
+ "2021-01-01T09:07:00.000Z,d1,5.0,2,",
+ "2021-01-01T09:08:00.000Z,d2,2.0,1,",
+ "2021-01-01T09:15:00.000Z,d2,4.0,2,",
+ };
+ tableResultSetEqualTest(
+ "SELECT * FROM (SELECT *, row_number() OVER (PARTITION BY device ORDER BY time ASC) as rn FROM demo) WHERE rn <= 2 ORDER BY device, time",
+ expectedHeader,
+ retArray,
+ DATABASE_NAME);
+ }
+
+ @Test
+ public void testTopKRankingOrderByTimeDesc() {
+ // ORDER BY time DESC: returns newest K rows per device
+ String[] expectedHeader = new String[] {"time", "device", "value", "rn"};
+ String[] retArray =
+ new String[] {
+ "2021-01-01T09:09:00.000Z,d1,3.0,2,",
+ "2021-01-01T09:10:00.000Z,d1,1.0,1,",
+ "2021-01-01T09:08:00.000Z,d2,2.0,2,",
+ "2021-01-01T09:15:00.000Z,d2,4.0,1,",
+ };
+ tableResultSetEqualTest(
+ "SELECT * FROM (SELECT *, row_number() OVER (PARTITION BY device ORDER BY time DESC) as rn FROM demo) WHERE rn <= 2 ORDER BY device, time",
+ expectedHeader,
+ retArray,
+ DATABASE_NAME);
+ }
+
+ @Test
+ public void testTopKRankingOrderByTimeLimit1() {
+ // rn <= 1: get exactly the oldest row per device
+ String[] expectedHeader = new String[] {"time", "device", "value", "rn"};
+ String[] retArray =
+ new String[] {
+ "2021-01-01T09:05:00.000Z,d1,3.0,1,", "2021-01-01T09:08:00.000Z,d2,2.0,1,",
+ };
+ tableResultSetEqualTest(
+ "SELECT * FROM (SELECT *, row_number() OVER (PARTITION BY device ORDER BY time ASC) as rn FROM demo) WHERE rn <= 1 ORDER BY device",
+ expectedHeader,
+ retArray,
+ DATABASE_NAME);
+ }
+
+ @Test
+ public void testTopKRankingOrderByTimeMultiTag() {
+ // Multi-tag table: PARTITION BY region, plant (all tags) ORDER BY time
+ String[] expectedHeader = new String[] {"time", "region", "plant", "temp", "rn"};
+ String[] retArray =
+ new String[] {
+ "2021-01-01T08:00:00.000Z,east,A,10.0,1,",
+ "2021-01-01T09:00:00.000Z,east,A,20.0,2,",
+ "2021-01-01T08:30:00.000Z,east,B,30.0,1,",
+ "2021-01-01T09:30:00.000Z,east,B,35.0,2,",
+ "2021-01-01T07:00:00.000Z,west,C,50.0,1,",
+ "2021-01-01T08:00:00.000Z,west,C,55.0,2,",
+ };
+ tableResultSetEqualTest(
+ "SELECT * FROM (SELECT *, row_number() OVER (PARTITION BY region, plant ORDER BY time ASC) as rn FROM multi_tag) WHERE rn <= 2 ORDER BY region, plant, time",
+ expectedHeader,
+ retArray,
+ DATABASE_NAME);
+ }
+
+ @Test
+ public void testTopKRankingOrderByTimeMultiTagDesc() {
+ // Multi-tag table: ORDER BY time DESC returns newest rows per device
+ String[] expectedHeader = new String[] {"time", "region", "plant", "temp", "rn"};
+ String[] retArray =
+ new String[] {
+ "2021-01-01T10:00:00.000Z,east,A,15.0,2,",
+ "2021-01-01T11:00:00.000Z,east,A,25.0,1,",
+ "2021-01-01T09:30:00.000Z,east,B,35.0,2,",
+ "2021-01-01T10:30:00.000Z,east,B,32.0,1,",
+ "2021-01-01T08:00:00.000Z,west,C,55.0,2,",
+ "2021-01-01T09:00:00.000Z,west,C,52.0,1,",
+ };
+ tableResultSetEqualTest(
+ "SELECT * FROM (SELECT *, row_number() OVER (PARTITION BY region, plant ORDER BY time DESC) as rn FROM multi_tag) WHERE rn <= 2 ORDER BY region, plant, time",
+ expectedHeader,
+ retArray,
+ DATABASE_NAME);
+ }
+
+ @Test
+ public void testTopKRankingOrderByTimeLimitExceedsRows() {
+ // rn <= 10 but d2 only has 2 rows - should return all available rows
+ String[] expectedHeader = new String[] {"time", "device", "value", "rn"};
+ String[] retArray =
+ new String[] {
+ "2021-01-01T09:05:00.000Z,d1,3.0,1,",
+ "2021-01-01T09:07:00.000Z,d1,5.0,2,",
+ "2021-01-01T09:09:00.000Z,d1,3.0,3,",
+ "2021-01-01T09:10:00.000Z,d1,1.0,4,",
+ "2021-01-01T09:08:00.000Z,d2,2.0,1,",
+ "2021-01-01T09:15:00.000Z,d2,4.0,2,",
+ };
+ tableResultSetEqualTest(
+ "SELECT * FROM (SELECT *, row_number() OVER (PARTITION BY device ORDER BY time ASC) as rn FROM demo) WHERE rn <= 10 ORDER BY device, time",
+ expectedHeader,
+ retArray,
+ DATABASE_NAME);
+ }
}
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/GroupedTopNRankAccumulator.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/GroupedTopNRankAccumulator.java
new file mode 100644
index 0000000000000..2007b2bb1bdf0
--- /dev/null
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/GroupedTopNRankAccumulator.java
@@ -0,0 +1,754 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.db.queryengine.execution.operator;
+
+import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.array.LongBigArray;
+import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.array.LongBigArrayFIFOQueue;
+import org.apache.iotdb.db.utils.HeapTraversal;
+
+import org.apache.tsfile.read.common.block.TsBlock;
+import org.apache.tsfile.utils.RamUsageEstimator;
+
+import java.util.function.LongConsumer;
+
+import static com.google.common.base.Preconditions.checkArgument;
+import static com.google.common.base.Preconditions.checkState;
+import static com.google.common.base.Verify.verify;
+import static java.util.Objects.requireNonNull;
+
+public class GroupedTopNRankAccumulator {
+ private static final long INSTANCE_SIZE =
+ RamUsageEstimator.shallowSizeOfInstance(GroupedTopNRankAccumulator.class);
+ private static final long UNKNOWN_INDEX = -1;
+ private static final long NULL_GROUP_ID = -1;
+
+ private final GroupIdToHeapBuffer groupIdToHeapBuffer = new GroupIdToHeapBuffer();
+ private final HeapNodeBuffer heapNodeBuffer = new HeapNodeBuffer();
+ private final PeerGroupBuffer peerGroupBuffer = new PeerGroupBuffer();
+ private final HeapTraversal heapTraversal = new HeapTraversal();
+
+ // Map from (Group ID, Row Value) to Heap Node Index where the value is stored
+ private final TopNPeerGroupLookup peerGroupLookup;
+
+ private final RowIdComparisonHashStrategy strategy;
+ private final int topN;
+ private final LongConsumer rowIdEvictionListener;
+
+ public GroupedTopNRankAccumulator(
+ RowIdComparisonHashStrategy strategy, int topN, LongConsumer rowIdEvictionListener) {
+ this.strategy = requireNonNull(strategy, "strategy is null");
+ this.peerGroupLookup = new TopNPeerGroupLookup(10_000, strategy, NULL_GROUP_ID, UNKNOWN_INDEX);
+ checkArgument(topN > 0, "topN must be greater than zero");
+ this.topN = topN;
+ this.rowIdEvictionListener =
+ requireNonNull(rowIdEvictionListener, "rowIdEvictionListener is null");
+ }
+
+ public long sizeOf() {
+ return INSTANCE_SIZE
+ + groupIdToHeapBuffer.sizeOf()
+ + heapNodeBuffer.sizeOf()
+ + peerGroupBuffer.sizeOf()
+ + heapTraversal.sizeOf()
+ + peerGroupLookup.sizeOf();
+ }
+
+ public int findFirstPositionToAdd(
+ TsBlock newTsBlock,
+ int groupCount,
+ int[] groupIds,
+ TsBlockWithPositionComparator comparator,
+ RowReferenceTsBlockManager tsBlockManager) {
+ int currentGroups = groupIdToHeapBuffer.getTotalGroups();
+ groupIdToHeapBuffer.allocateGroupIfNeeded(groupCount);
+
+ for (int position = 0; position < newTsBlock.getPositionCount(); position++) {
+ int groupId = groupIds[position];
+ if (groupId >= currentGroups || groupIdToHeapBuffer.getHeapValueCount(groupId) < topN) {
+ return position;
+ }
+ long heapRootNodeIndex = groupIdToHeapBuffer.getHeapRootNodeIndex(groupId);
+ if (heapRootNodeIndex == UNKNOWN_INDEX) {
+ return position;
+ }
+ long rightTsBlockRowId = peekRootRowIdByHeapNodeIndex(heapRootNodeIndex);
+ TsBlock rightTsBlock = tsBlockManager.getTsBlock(rightTsBlockRowId);
+ int rightPosition = tsBlockManager.getPosition(rightTsBlockRowId);
+ // If the current position is equal to or less than the current heap root index, then we may
+ // need to insert it
+ if (comparator.compareTo(newTsBlock, position, rightTsBlock, rightPosition) <= 0) {
+ return position;
+ }
+ }
+ return -1;
+ }
+
+ /**
+ * Add the specified row to this accumulator.
+ *
+ *
+   * <p>This may trigger row eviction callbacks if other rows have to be evicted to make space.
+ *
+ * @return true if this row was incorporated, false otherwise
+ */
+ public boolean add(int groupId, RowReference rowReference) {
+ // Insert to any existing peer groups first (heap nodes contain distinct values)
+ long peerHeapNodeIndex = peerGroupLookup.get(groupId, rowReference);
+ if (peerHeapNodeIndex != UNKNOWN_INDEX) {
+ directPeerGroupInsert(groupId, peerHeapNodeIndex, rowReference.allocateRowId());
+ if (calculateRootRank(groupId, groupIdToHeapBuffer.getHeapRootNodeIndex(groupId)) > topN) {
+ heapPop(groupId, rowIdEvictionListener);
+ }
+ // Return true because heapPop is guaranteed not to evict the newly inserted row (by
+ // definition of rank)
+ return true;
+ }
+
+ groupIdToHeapBuffer.allocateGroupIfNeeded(groupId);
+ if (groupIdToHeapBuffer.getHeapValueCount(groupId) < topN) {
+ // Always safe to insert if total number of values is still less than topN
+ long newPeerGroupIndex =
+ peerGroupBuffer.allocateNewNode(rowReference.allocateRowId(), UNKNOWN_INDEX);
+ heapInsert(groupId, newPeerGroupIndex, 1);
+ return true;
+ }
+ long heapRootNodeIndex = groupIdToHeapBuffer.getHeapRootNodeIndex(groupId);
+ if (rowReference.compareTo(strategy, peekRootRowIdByHeapNodeIndex(heapRootNodeIndex)) < 0) {
+ // Given that total number of values >= topN, we can only consider values that are less than
+ // the root (otherwise topN would be violated)
+ long newPeerGroupIndex =
+ peerGroupBuffer.allocateNewNode(rowReference.allocateRowId(), UNKNOWN_INDEX);
+ // Rank will increase by +1 after insertion, so only need to pop if root rank is already ==
+ // topN.
+ if (calculateRootRank(groupId, heapRootNodeIndex) < topN) {
+ heapInsert(groupId, newPeerGroupIndex, 1);
+ } else {
+ heapPopAndInsert(groupId, newPeerGroupIndex, 1, rowIdEvictionListener);
+ }
+ return true;
+ }
+ // Row cannot be accepted because the total number of values >= topN, and the row is greater
+ // than the root (meaning it's rank would be at least topN+1).
+ return false;
+ }
+
+ /**
+ * Drain the contents of this accumulator to the provided output row ID and ranking buffer.
+ *
+ *
+   * <p>Rows will be presented in increasing rank order. Draining will not trigger any row eviction
+ * callbacks. After this method completion, the Accumulator will contain zero rows for the
+ * specified groupId.
+ *
+ * @return number of rows deposited to the output buffers
+ */
+ public long drainTo(int groupId, LongBigArray rowIdOutput, LongBigArray rankingOutput) {
+ long valueCount = groupIdToHeapBuffer.getHeapValueCount(groupId);
+ rowIdOutput.ensureCapacity(valueCount);
+ rankingOutput.ensureCapacity(valueCount);
+
+ // Heap is inverted to output order, so insert back to front
+ long insertionIndex = valueCount - 1;
+ while (insertionIndex >= 0) {
+ long heapRootNodeIndex = groupIdToHeapBuffer.getHeapRootNodeIndex(groupId);
+ verify(heapRootNodeIndex != UNKNOWN_INDEX);
+
+ long peerGroupIndex = heapNodeBuffer.getPeerGroupIndex(heapRootNodeIndex);
+ verify(peerGroupIndex != UNKNOWN_INDEX, "Peer group should have at least one value");
+
+ long rank = calculateRootRank(groupId, heapRootNodeIndex);
+ do {
+ rowIdOutput.set(insertionIndex, peerGroupBuffer.getRowId(peerGroupIndex));
+ rankingOutput.set(insertionIndex, rank);
+ insertionIndex--;
+ peerGroupIndex = peerGroupBuffer.getNextPeerIndex(peerGroupIndex);
+ } while (peerGroupIndex != UNKNOWN_INDEX);
+
+ heapPop(groupId, null);
+ }
+ return valueCount;
+ }
+
+ /**
+ * Drain the contents of this accumulator to the provided output row ID.
+ *
+ *
Rows will be presented in increasing rank order. Draining will not trigger any row eviction
+ * callbacks. After this method completion, the Accumulator will contain zero rows for the
+ * specified groupId.
+ *
+ * @return number of rows deposited to the output buffer
+ */
+ public long drainTo(int groupId, LongBigArray rowIdOutput) {
+ long valueCount = groupIdToHeapBuffer.getHeapValueCount(groupId);
+ rowIdOutput.ensureCapacity(valueCount);
+
+ // Heap is inverted to output order, so insert back to front
+ long insertionIndex = valueCount - 1;
+ while (insertionIndex >= 0) {
+ long heapRootNodeIndex = groupIdToHeapBuffer.getHeapRootNodeIndex(groupId);
+ verify(heapRootNodeIndex != UNKNOWN_INDEX);
+
+ long peerGroupIndex = heapNodeBuffer.getPeerGroupIndex(heapRootNodeIndex);
+ verify(peerGroupIndex != UNKNOWN_INDEX, "Peer group should have at least one value");
+
+ do {
+ rowIdOutput.set(insertionIndex, peerGroupBuffer.getRowId(peerGroupIndex));
+ insertionIndex--;
+ peerGroupIndex = peerGroupBuffer.getNextPeerIndex(peerGroupIndex);
+ } while (peerGroupIndex != UNKNOWN_INDEX);
+
+ heapPop(groupId, null);
+ }
+ return valueCount;
+ }
+
+ private long calculateRootRank(int groupId, long heapRootIndex) {
+ long heapValueCount = groupIdToHeapBuffer.getHeapValueCount(groupId);
+ checkArgument(heapRootIndex != UNKNOWN_INDEX, "Group does not have a root");
+ long rootPeerGroupCount = heapNodeBuffer.getPeerGroupCount(heapRootIndex);
+ return heapValueCount - rootPeerGroupCount + 1;
+ }
+
+ private void directPeerGroupInsert(int groupId, long heapNodeIndex, long rowId) {
+ long existingPeerGroupIndex = heapNodeBuffer.getPeerGroupIndex(heapNodeIndex);
+ long newPeerGroupIndex = peerGroupBuffer.allocateNewNode(rowId, existingPeerGroupIndex);
+ heapNodeBuffer.setPeerGroupIndex(heapNodeIndex, newPeerGroupIndex);
+ heapNodeBuffer.incrementPeerGroupCount(heapNodeIndex);
+ groupIdToHeapBuffer.incrementHeapValueCount(groupId);
+ }
+
+ private long peekRootRowIdByHeapNodeIndex(long heapRootNodeIndex) {
+ checkArgument(heapRootNodeIndex != UNKNOWN_INDEX, "Group has nothing to peek");
+ return peerGroupBuffer.getRowId(heapNodeBuffer.getPeerGroupIndex(heapRootNodeIndex));
+ }
+
+ private long getChildIndex(long heapNodeIndex, HeapTraversal.Child child) {
+ return child == HeapTraversal.Child.LEFT
+ ? heapNodeBuffer.getLeftChildHeapIndex(heapNodeIndex)
+ : heapNodeBuffer.getRightChildHeapIndex(heapNodeIndex);
+ }
+
+ private void setChildIndex(long heapNodeIndex, HeapTraversal.Child child, long newChildIndex) {
+ if (child == HeapTraversal.Child.LEFT) {
+ heapNodeBuffer.setLeftChildHeapIndex(heapNodeIndex, newChildIndex);
+ } else {
+ heapNodeBuffer.setRightChildHeapIndex(heapNodeIndex, newChildIndex);
+ }
+ }
+
+ /**
+ * Pop the root node off the group ID's max heap.
+ *
+ * @param contextEvictionListener optional callback for the root node that gets popped off
+ */
+ private void heapPop(int groupId, LongConsumer contextEvictionListener) {
+ long heapRootNodeIndex = groupIdToHeapBuffer.getHeapRootNodeIndex(groupId);
+ checkArgument(heapRootNodeIndex != UNKNOWN_INDEX, "Group ID has an empty heap");
+
+ long lastHeapNodeIndex = heapDetachLastInsertionLeaf(groupId);
+ long lastPeerGroupIndex = heapNodeBuffer.getPeerGroupIndex(lastHeapNodeIndex);
+ long lastPeerGroupCount = heapNodeBuffer.getPeerGroupCount(lastHeapNodeIndex);
+
+ if (lastHeapNodeIndex == heapRootNodeIndex) {
+ // The root is the last node remaining
+ dropHeapNodePeerGroup(groupId, lastHeapNodeIndex, contextEvictionListener);
+ } else {
+ // Pop the root and insert the last peer group back into the heap to ensure a balanced tree
+ heapPopAndInsert(groupId, lastPeerGroupIndex, lastPeerGroupCount, contextEvictionListener);
+ }
+
+ // peerGroupLookup entry will be updated by definition of inserting the last peer group into a
+ // new node
+ heapNodeBuffer.deallocate(lastHeapNodeIndex);
+ }
+
+ /**
+ * Detaches (but does not deallocate) the leaf in the bottom right-most position in the heap.
+ *
+ *
+   * <p>Given the fixed insertion order, the bottom right-most leaf will correspond to the last leaf
+ * node inserted into the balanced heap.
+ *
+ * @return leaf node index that was detached from the heap
+ */
+ private long heapDetachLastInsertionLeaf(int groupId) {
+ long heapRootNodeIndex = groupIdToHeapBuffer.getHeapRootNodeIndex(groupId);
+ long heapSize = groupIdToHeapBuffer.getHeapSize(groupId);
+
+ long previousNodeIndex = UNKNOWN_INDEX;
+ HeapTraversal.Child childPosition = null;
+ long currentNodeIndex = heapRootNodeIndex;
+
+ heapTraversal.resetWithPathTo(heapSize);
+ while (!heapTraversal.isTarget()) {
+ previousNodeIndex = currentNodeIndex;
+ childPosition = heapTraversal.nextChild();
+ currentNodeIndex = getChildIndex(currentNodeIndex, childPosition);
+ verify(currentNodeIndex != UNKNOWN_INDEX, "Target node must exist");
+ }
+
+ // Detach the last insertion leaf node, but do not deallocate yet
+ if (previousNodeIndex == UNKNOWN_INDEX) {
+ // Last insertion leaf was the root node
+ groupIdToHeapBuffer.setHeapRootNodeIndex(groupId, UNKNOWN_INDEX);
+ groupIdToHeapBuffer.setHeapValueCount(groupId, 0);
+ groupIdToHeapBuffer.setHeapSize(groupId, 0);
+ } else {
+ setChildIndex(previousNodeIndex, childPosition, UNKNOWN_INDEX);
+ groupIdToHeapBuffer.addHeapValueCount(
+ groupId, -heapNodeBuffer.getPeerGroupCount(currentNodeIndex));
+ groupIdToHeapBuffer.addHeapSize(groupId, -1);
+ }
+
+ return currentNodeIndex;
+ }
+
+ /**
+ * Inserts a new row into the heap for the specified group ID.
+ *
+ *
+   * <p>The technique involves traversing the heap from the root to a new bottom left-priority leaf
+ * position, potentially swapping heap nodes along the way to find the proper insertion position
+ * for the new row. Insertions always fill the left child before the right, and fill up an entire
+ * heap level before moving to the next level.
+ */
+ private void heapInsert(int groupId, long newPeerGroupIndex, long newPeerGroupCount) {
+ long newCanonicalRowId = peerGroupBuffer.getRowId(newPeerGroupIndex);
+
+ long heapRootNodeIndex = groupIdToHeapBuffer.getHeapRootNodeIndex(groupId);
+ if (heapRootNodeIndex == UNKNOWN_INDEX) {
+ // Heap is currently empty, so this will be the first node
+ heapRootNodeIndex = heapNodeBuffer.allocateNewNode(newPeerGroupIndex, newPeerGroupCount);
+ verify(peerGroupLookup.put(groupId, newCanonicalRowId, heapRootNodeIndex) == UNKNOWN_INDEX);
+ groupIdToHeapBuffer.setHeapRootNodeIndex(groupId, heapRootNodeIndex);
+ groupIdToHeapBuffer.setHeapValueCount(groupId, newPeerGroupCount);
+ groupIdToHeapBuffer.setHeapSize(groupId, 1);
+ return;
+ }
+
+ long previousHeapNodeIndex = UNKNOWN_INDEX;
+ HeapTraversal.Child childPosition = null;
+ long currentHeapNodeIndex = heapRootNodeIndex;
+ boolean swapped = false;
+
+ groupIdToHeapBuffer.addHeapValueCount(groupId, newPeerGroupCount);
+ groupIdToHeapBuffer.incrementHeapSize(groupId);
+ heapTraversal.resetWithPathTo(groupIdToHeapBuffer.getHeapSize(groupId));
+ while (!heapTraversal.isTarget()) {
+ long peerGroupIndex = heapNodeBuffer.getPeerGroupIndex(currentHeapNodeIndex);
+ long currentCanonicalRowId = peerGroupBuffer.getRowId(peerGroupIndex);
+ // We can short-circuit the check if a parent has already been swapped because the new row to
+ // insert must
+ // be greater than all of it's children.
+ if (swapped || strategy.compare(newCanonicalRowId, currentCanonicalRowId) > 0) {
+ long peerGroupCount = heapNodeBuffer.getPeerGroupCount(currentHeapNodeIndex);
+
+ // Swap the peer groups
+ heapNodeBuffer.setPeerGroupIndex(currentHeapNodeIndex, newPeerGroupIndex);
+ heapNodeBuffer.setPeerGroupCount(currentHeapNodeIndex, newPeerGroupCount);
+ peerGroupLookup.put(groupId, newCanonicalRowId, currentHeapNodeIndex);
+
+ newPeerGroupIndex = peerGroupIndex;
+ newPeerGroupCount = peerGroupCount;
+ newCanonicalRowId = currentCanonicalRowId;
+ swapped = true;
+ }
+
+ previousHeapNodeIndex = currentHeapNodeIndex;
+ childPosition = heapTraversal.nextChild();
+ currentHeapNodeIndex = getChildIndex(currentHeapNodeIndex, childPosition);
+ }
+
+ verify(
+ previousHeapNodeIndex != UNKNOWN_INDEX && childPosition != null,
+ "heap must have at least one node before starting traversal");
+ verify(currentHeapNodeIndex == UNKNOWN_INDEX, "New child shouldn't exist yet");
+
+ long newHeapNodeIndex = heapNodeBuffer.allocateNewNode(newPeerGroupIndex, newPeerGroupCount);
+ peerGroupLookup.put(groupId, newCanonicalRowId, newHeapNodeIndex);
+
+ // Link the new child to the parent
+ setChildIndex(previousHeapNodeIndex, childPosition, newHeapNodeIndex);
+ }
+
+ /**
+ * Pop the root off the group ID's max heap and insert the new peer group.
+ *
+ *
+   * <p>These two operations are more efficient if performed together. The technique involves
+ * swapping the new row into the root position, and applying a heap down bubbling operation to
+ * heap-ify.
+ *
+ * @param contextEvictionListener optional callback for the root node that gets popped off
+ */
+ private void heapPopAndInsert(
+ int groupId,
+ long newPeerGroupIndex,
+ long newPeerGroupCount,
+ LongConsumer contextEvictionListener) {
+ long heapRootNodeIndex = groupIdToHeapBuffer.getHeapRootNodeIndex(groupId);
+ checkState(heapRootNodeIndex != UNKNOWN_INDEX, "popAndInsert() requires at least a root node");
+
+ // Clear contents of the root node to create a vacancy for the new peer group
+ groupIdToHeapBuffer.addHeapValueCount(
+ groupId, newPeerGroupCount - heapNodeBuffer.getPeerGroupCount(heapRootNodeIndex));
+ dropHeapNodePeerGroup(groupId, heapRootNodeIndex, contextEvictionListener);
+
+ long newCanonicalRowId = peerGroupBuffer.getRowId(newPeerGroupIndex);
+
+ long currentNodeIndex = heapRootNodeIndex;
+ while (true) {
+ long maxChildNodeIndex = heapNodeBuffer.getLeftChildHeapIndex(currentNodeIndex);
+ if (maxChildNodeIndex == UNKNOWN_INDEX) {
+ // Left is always inserted before right, so a missing left child means there can't be a
+ // right child,
+ // which means this must already be a leaf position.
+ break;
+ }
+ long maxChildPeerGroupIndex = heapNodeBuffer.getPeerGroupIndex(maxChildNodeIndex);
+ long maxChildCanonicalRowId = peerGroupBuffer.getRowId(maxChildPeerGroupIndex);
+
+ long rightChildNodeIndex = heapNodeBuffer.getRightChildHeapIndex(currentNodeIndex);
+ if (rightChildNodeIndex != UNKNOWN_INDEX) {
+ long rightChildPeerGroupIndex = heapNodeBuffer.getPeerGroupIndex(rightChildNodeIndex);
+ long rightChildCanonicalRowId = peerGroupBuffer.getRowId(rightChildPeerGroupIndex);
+ if (strategy.compare(rightChildCanonicalRowId, maxChildCanonicalRowId) > 0) {
+ maxChildNodeIndex = rightChildNodeIndex;
+ maxChildPeerGroupIndex = rightChildPeerGroupIndex;
+ maxChildCanonicalRowId = rightChildCanonicalRowId;
+ }
+ }
+
+ if (strategy.compare(newCanonicalRowId, maxChildCanonicalRowId) >= 0) {
+ // New row is greater than or equal to both children, so the heap invariant is satisfied by
+ // inserting the
+ // new row at this position
+ break;
+ }
+
+ // Swap the max child row value into the current node
+ heapNodeBuffer.setPeerGroupIndex(currentNodeIndex, maxChildPeerGroupIndex);
+ heapNodeBuffer.setPeerGroupCount(
+ currentNodeIndex, heapNodeBuffer.getPeerGroupCount(maxChildNodeIndex));
+ peerGroupLookup.put(groupId, maxChildCanonicalRowId, currentNodeIndex);
+
+ // Max child now has an unfilled vacancy, so continue processing with that as the current node
+ currentNodeIndex = maxChildNodeIndex;
+ }
+
+ heapNodeBuffer.setPeerGroupIndex(currentNodeIndex, newPeerGroupIndex);
+ heapNodeBuffer.setPeerGroupCount(currentNodeIndex, newPeerGroupCount);
+ peerGroupLookup.put(groupId, newCanonicalRowId, currentNodeIndex);
+ }
+
+ /**
+ * Deallocates all peer group associations for this heap node, leaving a structural husk with no
+ * contents. Assumes that any required group level metric changes are handled externally.
+ */
+ private void dropHeapNodePeerGroup(
+ int groupId, long heapNodeIndex, LongConsumer contextEvictionListener) {
+ long peerGroupIndex = heapNodeBuffer.getPeerGroupIndex(heapNodeIndex);
+ checkState(peerGroupIndex != UNKNOWN_INDEX, "Heap node must have at least one peer group");
+
+ long rowId = peerGroupBuffer.getRowId(peerGroupIndex);
+ long nextPeerIndex = peerGroupBuffer.getNextPeerIndex(peerGroupIndex);
+ peerGroupBuffer.deallocate(peerGroupIndex);
+ verify(peerGroupLookup.remove(groupId, rowId) == heapNodeIndex);
+
+ if (contextEvictionListener != null) {
+ contextEvictionListener.accept(rowId);
+ }
+
+ peerGroupIndex = nextPeerIndex;
+
+ while (peerGroupIndex != UNKNOWN_INDEX) {
+ rowId = peerGroupBuffer.getRowId(peerGroupIndex);
+ nextPeerIndex = peerGroupBuffer.getNextPeerIndex(peerGroupIndex);
+ peerGroupBuffer.deallocate(peerGroupIndex);
+
+ if (contextEvictionListener != null) {
+ contextEvictionListener.accept(rowId);
+ }
+
+ peerGroupIndex = nextPeerIndex;
+ }
+ }
+
+ /**
+ * Buffer abstracting a mapping from group ID to a heap. The group ID provides the index for all
+ * operations.
+ */
+ private static final class GroupIdToHeapBuffer {
+ private static final long INSTANCE_SIZE =
+ RamUsageEstimator.shallowSizeOfInstance(GroupIdToHeapBuffer.class);
+ private static final int METRICS_POSITIONS_PER_ENTRY = 2;
+ private static final int METRICS_HEAP_SIZE_OFFSET = 1;
+
+ /*
+ * Memory layout:
+ * [LONG] heapNodeIndex1,
+ * [LONG] heapNodeIndex2,
+ * ...
+ */
+ // Since we have a single element per group, this array is effectively indexed on group ID
+ private final LongBigArray heapIndexBuffer = new LongBigArray(UNKNOWN_INDEX);
+
+ /*
+ * Memory layout:
+ * [LONG] valueCount1, [LONG] heapSize1,
+ * [LONG] valueCount2, [LONG] heapSize2,
+ * ...
+ */
+ private final LongBigArray metricsBuffer = new LongBigArray(0);
+
+ private int totalGroups;
+
+ public void allocateGroupIfNeeded(int groupId) {
+ if (totalGroups > groupId) {
+ return;
+ }
+ // Group IDs generated by GroupByHash are always generated consecutively starting from 0, so
+ // observing a
+ // group ID N means groups [0, N] inclusive must exist.
+ totalGroups = groupId + 1;
+ heapIndexBuffer.ensureCapacity(totalGroups);
+ metricsBuffer.ensureCapacity((long) totalGroups * METRICS_POSITIONS_PER_ENTRY);
+ }
+
+ public int getTotalGroups() {
+ return totalGroups;
+ }
+
+ public long getHeapRootNodeIndex(int groupId) {
+ return heapIndexBuffer.get(groupId);
+ }
+
+ public void setHeapRootNodeIndex(int groupId, long heapNodeIndex) {
+ heapIndexBuffer.set(groupId, heapNodeIndex);
+ }
+
+ public long getHeapValueCount(int groupId) {
+ return metricsBuffer.get((long) groupId * METRICS_POSITIONS_PER_ENTRY);
+ }
+
+ public void setHeapValueCount(int groupId, long count) {
+ metricsBuffer.set((long) groupId * METRICS_POSITIONS_PER_ENTRY, count);
+ }
+
+ public void addHeapValueCount(int groupId, long delta) {
+ metricsBuffer.add((long) groupId * METRICS_POSITIONS_PER_ENTRY, delta);
+ }
+
+ public void incrementHeapValueCount(int groupId) {
+ metricsBuffer.increment((long) groupId * METRICS_POSITIONS_PER_ENTRY);
+ }
+
+ public long getHeapSize(int groupId) {
+ return metricsBuffer.get(
+ (long) groupId * METRICS_POSITIONS_PER_ENTRY + METRICS_HEAP_SIZE_OFFSET);
+ }
+
+ public void setHeapSize(int groupId, long size) {
+ metricsBuffer.set(
+ (long) groupId * METRICS_POSITIONS_PER_ENTRY + METRICS_HEAP_SIZE_OFFSET, size);
+ }
+
+ public void addHeapSize(int groupId, long delta) {
+ metricsBuffer.add(
+ (long) groupId * METRICS_POSITIONS_PER_ENTRY + METRICS_HEAP_SIZE_OFFSET, delta);
+ }
+
+ public void incrementHeapSize(int groupId) {
+ metricsBuffer.increment(
+ (long) groupId * METRICS_POSITIONS_PER_ENTRY + METRICS_HEAP_SIZE_OFFSET);
+ }
+
+ public long sizeOf() {
+ return INSTANCE_SIZE + heapIndexBuffer.sizeOf() + metricsBuffer.sizeOf();
+ }
+ }
+
+ /**
+ * Buffer abstracting storage of nodes in the heap. Nodes are referenced by their node index for
+ * operations.
+ */
+ private static final class HeapNodeBuffer {
+ private static final long INSTANCE_SIZE =
+ RamUsageEstimator.shallowSizeOfInstance(HeapNodeBuffer.class);
+ private static final int POSITIONS_PER_ENTRY = 4;
+ private static final int PEER_GROUP_COUNT_OFFSET = 1;
+ private static final int LEFT_CHILD_HEAP_INDEX_OFFSET = 2;
+ private static final int RIGHT_CHILD_HEAP_INDEX_OFFSET = 3;
+
+ /*
+ * Memory layout:
+ * [LONG] peerGroupIndex1, [LONG] peerGroupCount1, [LONG] leftChildNodeIndex1, [LONG] rightChildNodeIndex1,
+ * [LONG] peerGroupIndex2, [LONG] peerGroupCount2, [LONG] leftChildNodeIndex2, [LONG] rightChildNodeIndex2,
+ * ...
+ */
+ private final LongBigArray buffer = new LongBigArray();
+
+ private final LongBigArrayFIFOQueue emptySlots = new LongBigArrayFIFOQueue();
+
+ private long capacity;
+
+ /**
+ * Allocates storage for a new heap node.
+ *
+ * @return index referencing the node
+ */
+ public long allocateNewNode(long peerGroupIndex, long peerGroupCount) {
+ long newHeapIndex;
+ if (!emptySlots.isEmpty()) {
+ newHeapIndex = emptySlots.dequeueLong();
+ } else {
+ newHeapIndex = capacity;
+ capacity++;
+ buffer.ensureCapacity(capacity * POSITIONS_PER_ENTRY);
+ }
+
+ setPeerGroupIndex(newHeapIndex, peerGroupIndex);
+ setPeerGroupCount(newHeapIndex, peerGroupCount);
+ setLeftChildHeapIndex(newHeapIndex, UNKNOWN_INDEX);
+ setRightChildHeapIndex(newHeapIndex, UNKNOWN_INDEX);
+
+ return newHeapIndex;
+ }
+
+ public void deallocate(long index) {
+ emptySlots.enqueue(index);
+ }
+
+ public long getActiveNodeCount() {
+ return capacity - emptySlots.longSize();
+ }
+
+ public long getPeerGroupIndex(long index) {
+ return buffer.get(index * POSITIONS_PER_ENTRY);
+ }
+
+ public void setPeerGroupIndex(long index, long peerGroupIndex) {
+ buffer.set(index * POSITIONS_PER_ENTRY, peerGroupIndex);
+ }
+
+ public long getPeerGroupCount(long index) {
+ return buffer.get(index * POSITIONS_PER_ENTRY + PEER_GROUP_COUNT_OFFSET);
+ }
+
+ public void setPeerGroupCount(long index, long peerGroupCount) {
+ buffer.set(index * POSITIONS_PER_ENTRY + PEER_GROUP_COUNT_OFFSET, peerGroupCount);
+ }
+
+ public void incrementPeerGroupCount(long index) {
+ buffer.increment(index * POSITIONS_PER_ENTRY + PEER_GROUP_COUNT_OFFSET);
+ }
+
+ public void addPeerGroupCount(long index, long delta) {
+ buffer.add(index * POSITIONS_PER_ENTRY + PEER_GROUP_COUNT_OFFSET, delta);
+ }
+
+ public long getLeftChildHeapIndex(long index) {
+ return buffer.get(index * POSITIONS_PER_ENTRY + LEFT_CHILD_HEAP_INDEX_OFFSET);
+ }
+
+ public void setLeftChildHeapIndex(long index, long childHeapIndex) {
+ buffer.set(index * POSITIONS_PER_ENTRY + LEFT_CHILD_HEAP_INDEX_OFFSET, childHeapIndex);
+ }
+
+ public long getRightChildHeapIndex(long index) {
+ return buffer.get(index * POSITIONS_PER_ENTRY + RIGHT_CHILD_HEAP_INDEX_OFFSET);
+ }
+
+ public void setRightChildHeapIndex(long index, long childHeapIndex) {
+ buffer.set(index * POSITIONS_PER_ENTRY + RIGHT_CHILD_HEAP_INDEX_OFFSET, childHeapIndex);
+ }
+
+ public long sizeOf() {
+ return INSTANCE_SIZE + buffer.sizeOf() + emptySlots.sizeOf();
+ }
+ }
+
+ /**
+ * Buffer abstracting storage of peer groups as linked chains of matching values. Peer groups are
+ * referenced by their node index for operations.
+ */
+ private static final class PeerGroupBuffer {
+ private static final long INSTANCE_SIZE =
+ RamUsageEstimator.shallowSizeOfInstance(PeerGroupBuffer.class);
+ private static final int POSITIONS_PER_ENTRY = 2;
+ private static final int NEXT_PEER_INDEX_OFFSET = 1;
+
+ /*
+ * Memory layout:
+ * [LONG] rowId1, [LONG] nextPeerIndex1,
+ * [LONG] rowId2, [LONG] nextPeerIndex2,
+ * ...
+ */
+ private final LongBigArray buffer = new LongBigArray();
+
+ private final LongBigArrayFIFOQueue emptySlots = new LongBigArrayFIFOQueue();
+
+ private long capacity;
+
+ /**
+ * Allocates storage for a new peer group node.
+ *
+ * @return index referencing the node
+ */
+ public long allocateNewNode(long rowId, long nextPeerIndex) {
+ long newPeerIndex;
+ if (!emptySlots.isEmpty()) {
+ newPeerIndex = emptySlots.dequeueLong();
+ } else {
+ newPeerIndex = capacity;
+ capacity++;
+ buffer.ensureCapacity(capacity * POSITIONS_PER_ENTRY);
+ }
+
+ setRowId(newPeerIndex, rowId);
+ setNextPeerIndex(newPeerIndex, nextPeerIndex);
+
+ return newPeerIndex;
+ }
+
+ public void deallocate(long index) {
+ emptySlots.enqueue(index);
+ }
+
+ public long getActiveNodeCount() {
+ return capacity - emptySlots.longSize();
+ }
+
+ public long getRowId(long index) {
+ return buffer.get(index * POSITIONS_PER_ENTRY);
+ }
+
+ public void setRowId(long index, long rowId) {
+ buffer.set(index * POSITIONS_PER_ENTRY, rowId);
+ }
+
+ public long getNextPeerIndex(long index) {
+ return buffer.get(index * POSITIONS_PER_ENTRY + NEXT_PEER_INDEX_OFFSET);
+ }
+
+ public void setNextPeerIndex(long index, long nextPeerIndex) {
+ buffer.set(index * POSITIONS_PER_ENTRY + NEXT_PEER_INDEX_OFFSET, nextPeerIndex);
+ }
+
+ public long sizeOf() {
+ return INSTANCE_SIZE + buffer.sizeOf() + emptySlots.sizeOf();
+ }
+ }
+}
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/GroupedTopNRankBuilder.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/GroupedTopNRankBuilder.java
new file mode 100644
index 0000000000000..dcfc8dbcee14b
--- /dev/null
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/GroupedTopNRankBuilder.java
@@ -0,0 +1,219 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.db.queryengine.execution.operator;
+
+import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.array.LongBigArray;
+import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.hash.GroupByHash;
+
+import com.google.common.collect.AbstractIterator;
+import com.google.common.collect.ImmutableList;
+import org.apache.tsfile.block.column.Column;
+import org.apache.tsfile.block.column.ColumnBuilder;
+import org.apache.tsfile.enums.TSDataType;
+import org.apache.tsfile.read.common.block.TsBlock;
+import org.apache.tsfile.read.common.block.TsBlockBuilder;
+import org.apache.tsfile.read.common.block.column.RunLengthEncodedColumn;
+import org.apache.tsfile.utils.RamUsageEstimator;
+
+import java.util.Iterator;
+import java.util.List;
+
+import static org.apache.iotdb.db.queryengine.execution.operator.source.relational.AbstractTableScanOperator.TIME_COLUMN_TEMPLATE;
+
+/**
+ * Finds the top N rows by rank value for each group. Unlike row_number which assigns unique
+ * sequential numbers, rank assigns the same number to rows with equal sort key values (peers).
+ */
+public class GroupedTopNRankBuilder implements GroupedTopNBuilder {
+ private static final long INSTANCE_SIZE =
+ RamUsageEstimator.shallowSizeOfInstance(GroupedTopNRankBuilder.class);
+
+ private final List<TSDataType> sourceTypes;
+ private final boolean produceRanking;
+ private final int[] groupByChannels;
+ private final GroupByHash groupByHash;
+ private final TsBlockWithPositionComparator comparator;
+ private final RowReferenceTsBlockManager tsBlockManager = new RowReferenceTsBlockManager();
+ private final GroupedTopNRankAccumulator groupedTopNRankAccumulator;
+
+ private int effectiveGroupCount = 0;
+
+ public GroupedTopNRankBuilder(
+ List<TSDataType> sourceTypes,
+ TsBlockWithPositionComparator comparator,
+ TsBlockWithPositionEqualsAndHash equalsAndHash,
+ int topN,
+ boolean produceRanking,
+ int[] groupByChannels,
+ GroupByHash groupByHash) {
+ this.sourceTypes = sourceTypes;
+ this.produceRanking = produceRanking;
+ this.groupByChannels = groupByChannels;
+ this.groupByHash = groupByHash;
+ this.comparator = comparator;
+
+ this.groupedTopNRankAccumulator =
+ new GroupedTopNRankAccumulator(
+ new RowIdComparisonHashStrategy() {
+ @Override
+ public int compare(long leftRowId, long rightRowId) {
+ TsBlock leftTsBlock = tsBlockManager.getTsBlock(leftRowId);
+ int leftPosition = tsBlockManager.getPosition(leftRowId);
+ TsBlock rightTsBlock = tsBlockManager.getTsBlock(rightRowId);
+ int rightPosition = tsBlockManager.getPosition(rightRowId);
+ return comparator.compareTo(leftTsBlock, leftPosition, rightTsBlock, rightPosition);
+ }
+
+ @Override
+ public boolean equals(long leftRowId, long rightRowId) {
+ TsBlock leftTsBlock = tsBlockManager.getTsBlock(leftRowId);
+ int leftPosition = tsBlockManager.getPosition(leftRowId);
+ TsBlock rightTsBlock = tsBlockManager.getTsBlock(rightRowId);
+ int rightPosition = tsBlockManager.getPosition(rightRowId);
+ return equalsAndHash.equals(leftTsBlock, leftPosition, rightTsBlock, rightPosition);
+ }
+
+ @Override
+ public long hashCode(long rowId) {
+ TsBlock tsBlock = tsBlockManager.getTsBlock(rowId);
+ int position = tsBlockManager.getPosition(rowId);
+ return equalsAndHash.hashCode(tsBlock, position);
+ }
+ },
+ topN,
+ tsBlockManager::dereference);
+ }
+
+ @Override
+ public void addTsBlock(TsBlock tsBlock) {
+ int[] groupIds;
+ if (groupByChannels.length == 0) {
+ groupIds = new int[tsBlock.getPositionCount()];
+ if (tsBlock.getPositionCount() > 0) {
+ effectiveGroupCount = 1;
+ }
+ } else {
+ groupIds = groupByHash.getGroupIds(tsBlock.getColumns(groupByChannels));
+ effectiveGroupCount = groupByHash.getGroupCount();
+ }
+
+ processTsBlock(tsBlock, effectiveGroupCount, groupIds);
+ }
+
+ @Override
+ public Iterator<TsBlock> getResult() {
+ return new ResultIterator();
+ }
+
+ @Override
+ public long getEstimatedSizeInBytes() {
+ return INSTANCE_SIZE
+ + groupByHash.getEstimatedSize()
+ + tsBlockManager.sizeOf()
+ + groupedTopNRankAccumulator.sizeOf();
+ }
+
+ private void processTsBlock(TsBlock newTsBlock, int groupCount, int[] groupIds) {
+ int firstPositionToAdd =
+ groupedTopNRankAccumulator.findFirstPositionToAdd(
+ newTsBlock, groupCount, groupIds, comparator, tsBlockManager);
+ if (firstPositionToAdd < 0) {
+ return;
+ }
+
+ try (RowReferenceTsBlockManager.LoadCursor loadCursor =
+ tsBlockManager.add(newTsBlock, firstPositionToAdd)) {
+ for (int position = firstPositionToAdd;
+ position < newTsBlock.getPositionCount();
+ position++) {
+ int groupId = groupIds[position];
+ loadCursor.advance();
+ groupedTopNRankAccumulator.add(groupId, loadCursor);
+ }
+ }
+
+ tsBlockManager.compactIfNeeded();
+ }
+
+ private class ResultIterator extends AbstractIterator<TsBlock> {
+ private final TsBlockBuilder tsBlockBuilder;
+ private final int groupIdCount = effectiveGroupCount;
+ private int currentGroupId = -1;
+ private final LongBigArray rowIdOutput = new LongBigArray();
+ private final LongBigArray rankingOutput = new LongBigArray();
+ private long currentGroupSize;
+ private int currentIndexInGroup;
+
+ ResultIterator() {
+ ImmutableList.Builder<TSDataType> sourceTypesBuilders =
+ ImmutableList.<TSDataType>builder().addAll(sourceTypes);
+ if (produceRanking) {
+ sourceTypesBuilders.add(TSDataType.INT64);
+ }
+ tsBlockBuilder = new TsBlockBuilder(sourceTypesBuilders.build());
+ }
+
+ @Override
+ protected TsBlock computeNext() {
+ tsBlockBuilder.reset();
+ while (!tsBlockBuilder.isFull()) {
+ while (currentIndexInGroup >= currentGroupSize) {
+ if (currentGroupId + 1 >= groupIdCount) {
+ if (tsBlockBuilder.isEmpty()) {
+ return endOfData();
+ }
+ return tsBlockBuilder.build(
+ new RunLengthEncodedColumn(
+ TIME_COLUMN_TEMPLATE, tsBlockBuilder.getPositionCount()));
+ }
+ currentGroupId++;
+ currentGroupSize =
+ produceRanking
+ ? groupedTopNRankAccumulator.drainTo(currentGroupId, rowIdOutput, rankingOutput)
+ : groupedTopNRankAccumulator.drainTo(currentGroupId, rowIdOutput);
+ currentIndexInGroup = 0;
+ }
+
+ long rowId = rowIdOutput.get(currentIndexInGroup);
+ TsBlock tsBlock = tsBlockManager.getTsBlock(rowId);
+ int position = tsBlockManager.getPosition(rowId);
+ for (int i = 0; i < sourceTypes.size(); i++) {
+ ColumnBuilder builder = tsBlockBuilder.getColumnBuilder(i);
+ Column column = tsBlock.getColumn(i);
+ builder.write(column, position);
+ }
+ if (produceRanking) {
+ ColumnBuilder builder = tsBlockBuilder.getColumnBuilder(sourceTypes.size());
+ builder.writeLong(rankingOutput.get(currentIndexInGroup));
+ }
+ tsBlockBuilder.declarePosition();
+ currentIndexInGroup++;
+
+ tsBlockManager.dereference(rowId);
+ }
+
+ if (tsBlockBuilder.isEmpty()) {
+ return endOfData();
+ }
+ return tsBlockBuilder.build(
+ new RunLengthEncodedColumn(TIME_COLUMN_TEMPLATE, tsBlockBuilder.getPositionCount()));
+ }
+ }
+}
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/RowIdComparisonHashStrategy.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/RowIdComparisonHashStrategy.java
new file mode 100644
index 0000000000000..acb63ce57f87d
--- /dev/null
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/RowIdComparisonHashStrategy.java
@@ -0,0 +1,27 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.db.queryengine.execution.operator;
+
+public interface RowIdComparisonHashStrategy extends RowIdComparisonStrategy, RowIdHashStrategy {
+ @Override
+ default boolean equals(long leftRowId, long rightRowId) {
+ return compare(leftRowId, rightRowId) == 0;
+ }
+}
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/SimpleTsBlockWithPositionEqualsAndHash.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/SimpleTsBlockWithPositionEqualsAndHash.java
new file mode 100644
index 0000000000000..5b9aa6c3b4c75
--- /dev/null
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/SimpleTsBlockWithPositionEqualsAndHash.java
@@ -0,0 +1,129 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.db.queryengine.execution.operator;
+
+import org.apache.tsfile.block.column.Column;
+import org.apache.tsfile.enums.TSDataType;
+import org.apache.tsfile.read.common.block.TsBlock;
+
+import java.util.ArrayList;
+import java.util.List;
+
+/**
+ * Computes equality and hash based on specified column channels of a TsBlock. Used for peer group
+ * detection in RANK window functions.
+ */
+public class SimpleTsBlockWithPositionEqualsAndHash implements TsBlockWithPositionEqualsAndHash {
+ private final List<Integer> channels;
+ private final List<TSDataType> types;
+
+ public SimpleTsBlockWithPositionEqualsAndHash(List<TSDataType> allTypes, List<Integer> channels) {
+ this.channels = channels;
+ this.types = new ArrayList<>(channels.size());
+ for (int channel : channels) {
+ types.add(allTypes.get(channel));
+ }
+ }
+
+ @Override
+ public boolean equals(TsBlock left, int leftPosition, TsBlock right, int rightPosition) {
+ for (int i = 0; i < channels.size(); i++) {
+ int channel = channels.get(i);
+ Column leftColumn = left.getColumn(channel);
+ Column rightColumn = right.getColumn(channel);
+
+ boolean leftNull = leftColumn.isNull(leftPosition);
+ boolean rightNull = rightColumn.isNull(rightPosition);
+ if (leftNull != rightNull) {
+ return false;
+ }
+ if (leftNull) {
+ continue;
+ }
+
+ if (!valueEquals(leftColumn, leftPosition, rightColumn, rightPosition, types.get(i))) {
+ return false;
+ }
+ }
+ return true;
+ }
+
+ @Override
+ public long hashCode(TsBlock block, int position) {
+ long hash = 0;
+ for (int i = 0; i < channels.size(); i++) {
+ Column column = block.getColumn(channels.get(i));
+ hash = hash * 31 + valueHash(column, position, types.get(i));
+ }
+ return hash;
+ }
+
+ private static boolean valueEquals(
+ Column left, int leftPos, Column right, int rightPos, TSDataType type) {
+ switch (type) {
+ case INT32:
+ case DATE:
+ return left.getInt(leftPos) == right.getInt(rightPos);
+ case INT64:
+ case TIMESTAMP:
+ return left.getLong(leftPos) == right.getLong(rightPos);
+ case FLOAT:
+ return Float.compare(left.getFloat(leftPos), right.getFloat(rightPos)) == 0;
+ case DOUBLE:
+ return Double.compare(left.getDouble(leftPos), right.getDouble(rightPos)) == 0;
+ case BOOLEAN:
+ return left.getBoolean(leftPos) == right.getBoolean(rightPos);
+ case TEXT:
+ case BLOB:
+ case STRING:
+ return left.getBinary(leftPos).equals(right.getBinary(rightPos));
+ default:
+ throw new IllegalArgumentException("Unsupported type: " + type);
+ }
+ }
+
+ private static long valueHash(Column column, int position, TSDataType type) {
+ if (column.isNull(position)) {
+ return 0;
+ }
+ switch (type) {
+ case INT32:
+ case DATE:
+ return column.getInt(position);
+ case INT64:
+ case TIMESTAMP:
+ long v = column.getLong(position);
+ return v ^ (v >>> 32);
+ case FLOAT:
+ return Float.floatToIntBits(column.getFloat(position));
+ case DOUBLE:
+ long dv = Double.doubleToLongBits(column.getDouble(position));
+ return dv ^ (dv >>> 32);
+ case BOOLEAN:
+ return column.getBoolean(position) ? 1231 : 1237;
+ case TEXT:
+ case BLOB:
+ case STRING:
+ return column.getBinary(position).hashCode();
+ default:
+ throw new IllegalArgumentException("Unsupported type: " + type);
+ }
+ }
+}
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/TopNPeerGroupLookup.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/TopNPeerGroupLookup.java
new file mode 100644
index 0000000000000..60cd49a613b9e
--- /dev/null
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/TopNPeerGroupLookup.java
@@ -0,0 +1,401 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.db.queryengine.execution.operator;
+
+import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.array.LongBigArray;
+
+import org.apache.tsfile.utils.RamUsageEstimator;
+
+import static com.google.common.base.Preconditions.checkArgument;
+import static java.util.Objects.requireNonNull;
+
+/** Optimized hash table for streaming Top N peer group lookup operations. */
+// Note: this code was forked from fastutil (http://fastutil.di.unimi.it/)
+// Long2LongOpenCustomHashMap.
+// Copyright (C) 2002-2019 Sebastiano Vigna
+public class TopNPeerGroupLookup {
+ private static final long INSTANCE_SIZE =
+ RamUsageEstimator.shallowSizeOfInstance(TopNPeerGroupLookup.class);
+
+ /** The buffer containing key and value data. */
+ private Buffer buffer;
+
+ /** The mask for wrapping a position counter. */
+ private long mask;
+
+ /** The hash strategy. */
+ private final RowIdHashStrategy strategy;
+
+ /** The current allocated table size. */
+ private long tableSize;
+
+ /** Threshold after which we rehash. */
+ private long maxFill;
+
+ /** The acceptable load factor. */
+ private final float fillFactor;
+
+ /** Number of entries in the set. */
+ private long entryCount;
+
+ /**
+ * The value denoting unmapped group IDs. Since group IDs need to co-exist at all times with row
+ * IDs, we only need to use one of the two IDs to indicate that a slot is unused. Group IDs have
+ * been arbitrarily selected for that purpose.
+ */
+ private final long unmappedGroupId;
+
+ /** The default return value for {@code get()}, {@code put()} and {@code remove()}. */
+ private final long defaultReturnValue;
+
+ /**
+ * Standard hash table parameters are expected. {@code unmappedGroupId} specifies the internal
+ * marker value for unmapped group IDs.
+ */
+ public TopNPeerGroupLookup(
+ long expected,
+ float fillFactor,
+ RowIdHashStrategy strategy,
+ long unmappedGroupId,
+ long defaultReturnValue) {
+ checkArgument(expected >= 0, "The expected number of elements must be nonnegative");
+ checkArgument(
+ fillFactor > 0 && fillFactor <= 1,
+ "Load factor must be greater than 0 and smaller than or equal to 1");
+ this.fillFactor = fillFactor;
+ this.strategy = requireNonNull(strategy, "strategy is null");
+ this.unmappedGroupId = unmappedGroupId;
+ this.defaultReturnValue = defaultReturnValue;
+
+ tableSize = bigArraySize(expected, fillFactor);
+ mask = tableSize - 1;
+ maxFill = maxFill(tableSize, fillFactor);
+ buffer = new Buffer(tableSize, unmappedGroupId);
+ }
+
+ public TopNPeerGroupLookup(
+ long expected, RowIdHashStrategy strategy, long unmappedGroupId, long defaultReturnValue) {
+ this(expected, 0.75f, strategy, unmappedGroupId, defaultReturnValue);
+ }
+
+ /** Returns the size of this hash map in bytes. */
+ public long sizeOf() {
+ return INSTANCE_SIZE + buffer.sizeOf();
+ }
+
+ public long size() {
+ return entryCount;
+ }
+
+ public boolean isEmpty() {
+ return entryCount == 0;
+ }
+
+ public long get(long groupId, long rowId) {
+ checkArgument(groupId != unmappedGroupId, "Group ID cannot be the unmapped group ID");
+
+ long hash = hash(groupId, rowId);
+ long index = hash & mask;
+ if (buffer.isEmptySlot(index)) {
+ return defaultReturnValue;
+ }
+ if (hash == buffer.getPrecomputedHash(index) && equals(groupId, rowId, index)) {
+ return buffer.getValue(index);
+ }
+ // There's always an unused entry.
+ while (true) {
+ index = (index + 1) & mask;
+ if (buffer.isEmptySlot(index)) {
+ return defaultReturnValue;
+ }
+ if (hash == buffer.getPrecomputedHash(index) && equals(groupId, rowId, index)) {
+ return buffer.getValue(index);
+ }
+ }
+ }
+
+ public long get(long groupId, RowReference rowReference) {
+ checkArgument(groupId != unmappedGroupId, "Group ID cannot be the unmapped group ID");
+
+ long hash = hash(groupId, rowReference);
+ long index = hash & mask;
+ if (buffer.isEmptySlot(index)) {
+ return defaultReturnValue;
+ }
+ if (hash == buffer.getPrecomputedHash(index) && equals(groupId, rowReference, index)) {
+ return buffer.getValue(index);
+ }
+ // There's always an unused entry.
+ while (true) {
+ index = (index + 1) & mask;
+ if (buffer.isEmptySlot(index)) {
+ return defaultReturnValue;
+ }
+ if (hash == buffer.getPrecomputedHash(index) && equals(groupId, rowReference, index)) {
+ return buffer.getValue(index);
+ }
+ }
+ }
+
+ public long put(long groupId, long rowId, long value) {
+ checkArgument(groupId != unmappedGroupId, "Group ID cannot be the unmapped group ID");
+
+ long hash = hash(groupId, rowId);
+
+ long index = find(groupId, rowId, hash);
+ if (index < 0) {
+ insert(twosComplement(index), groupId, rowId, hash, value);
+ return defaultReturnValue;
+ }
+ long oldValue = buffer.getValue(index);
+ buffer.setValue(index, value);
+ return oldValue;
+ }
+
+ private long hash(long groupId, long rowId) {
+ return mix(groupId * 31 + strategy.hashCode(rowId));
+ }
+
+ private long hash(long groupId, RowReference rowReference) {
+ return mix(groupId * 31 + rowReference.hash(strategy));
+ }
+
+ private boolean equals(long groupId, long rowId, long index) {
+ return groupId == buffer.getGroupId(index) && strategy.equals(rowId, buffer.getRowId(index));
+ }
+
+ private boolean equals(long groupId, RowReference rowReference, long index) {
+ return groupId == buffer.getGroupId(index)
+ && rowReference.equals(strategy, buffer.getRowId(index));
+ }
+
+ private void insert(long index, long groupId, long rowId, long precomputedHash, long value) {
+ buffer.set(index, groupId, rowId, precomputedHash, value);
+ entryCount++;
+ if (entryCount > maxFill) {
+ rehash(bigArraySize(entryCount + 1, fillFactor));
+ }
+ }
+
+ /**
+ * Locate the index for the specified {@code groupId} and {@code rowId} key pair. If the index is
+ * unpopulated, then return the index as the two's complement value (which will be negative).
+ */
+ private long find(long groupId, long rowId, long precomputedHash) {
+ long index = precomputedHash & mask;
+ if (buffer.isEmptySlot(index)) {
+ return twosComplement(index);
+ }
+ if (precomputedHash == buffer.getPrecomputedHash(index) && equals(groupId, rowId, index)) {
+ return index;
+ }
+ // There's always an unused entry.
+ while (true) {
+ index = (index + 1) & mask;
+ if (buffer.isEmptySlot(index)) {
+ return twosComplement(index);
+ }
+ if (precomputedHash == buffer.getPrecomputedHash(index) && equals(groupId, rowId, index)) {
+ return index;
+ }
+ }
+ }
+
+ public long remove(long groupId, long rowId) {
+ checkArgument(groupId != unmappedGroupId, "Group ID cannot be the unmapped group ID");
+
+ long hash = hash(groupId, rowId);
+ long index = hash & mask;
+ if (buffer.isEmptySlot(index)) {
+ return defaultReturnValue;
+ }
+ if (hash == buffer.getPrecomputedHash(index) && equals(groupId, rowId, index)) {
+ return removeEntry(index);
+ }
+ while (true) {
+ index = (index + 1) & mask;
+ if (buffer.isEmptySlot(index)) {
+ return defaultReturnValue;
+ }
+ if (hash == buffer.getPrecomputedHash(index) && equals(groupId, rowId, index)) {
+ return removeEntry(index);
+ }
+ }
+ }
+
+ private long removeEntry(long index) {
+ long oldValue = buffer.getValue(index);
+ entryCount--;
+ shiftKeys(index);
+ return oldValue;
+ }
+
+ /**
+ * Shifts left entries with the specified hash code, starting at the specified index, and empties
+ * the resulting free entry.
+ *
+ * @param index a starting position.
+ */
+ private void shiftKeys(long index) {
+ // Shift entries with the same hash.
+ while (true) {
+ long currentHash;
+
+ long initialIndex = index;
+ index = ((index) + 1) & mask;
+ while (true) {
+ if (buffer.isEmptySlot(index)) {
+ buffer.clear(initialIndex);
+ return;
+ }
+ currentHash = buffer.getPrecomputedHash(index);
+ long slot = currentHash & mask;
+ // Yes, this is dense logic. See fastutil Long2LongOpenCustomHashMap#shiftKeys
+ // implementation.
+ if (initialIndex <= index
+ ? initialIndex >= slot || slot > index
+ : initialIndex >= slot && slot > index) {
+ break;
+ }
+ index = (index + 1) & mask;
+ }
+ buffer.set(
+ initialIndex,
+ buffer.getGroupId(index),
+ buffer.getRowId(index),
+ currentHash,
+ buffer.getValue(index));
+ }
+ }
+
+ private void rehash(long newTableSize) {
+ long newMask = newTableSize - 1; // Note that this is used by the hashing macro
+ Buffer newBuffer = new Buffer(newTableSize, unmappedGroupId);
+ long index = tableSize;
+ for (long i = entryCount; i > 0; i--) {
+ index--;
+ while (buffer.isEmptySlot(index)) {
+ index--;
+ }
+ long hash = buffer.getPrecomputedHash(index);
+ long newIndex = hash & newMask;
+ if (!newBuffer.isEmptySlot(newIndex)) {
+ newIndex = (newIndex + 1) & newMask;
+ while (!newBuffer.isEmptySlot(newIndex)) {
+ newIndex = (newIndex + 1) & newMask;
+ }
+ }
+ newBuffer.set(
+ newIndex, buffer.getGroupId(index), buffer.getRowId(index), hash, buffer.getValue(index));
+ }
+ tableSize = newTableSize;
+ mask = newMask;
+ maxFill = maxFill(tableSize, fillFactor);
+ buffer = newBuffer;
+ }
+
+ private static long twosComplement(long value) {
+ return -(value + 1);
+ }
+
+ private static class Buffer {
+ private static final long INSTANCE_SIZE = RamUsageEstimator.shallowSizeOfInstance(Buffer.class);
+
+ private static final int POSITIONS_PER_ENTRY = 4;
+ private static final int ROW_ID_OFFSET = 1;
+ private static final int PRECOMPUTED_HASH_OFFSET = 2;
+ private static final int VALUE_OFFSET = 3;
+
+ /*
+ * Memory layout:
+ * [LONG] groupId1, [LONG] rowId1, [LONG] precomputedHash1, [LONG] value1
+ * [LONG] groupId2, [LONG] rowId2, [LONG] precomputedHash2, [LONG] value2
+ * ...
+ */
+ private final LongBigArray buffer;
+ private final long unmappedGroupId;
+
+ public Buffer(long positions, long unmappedGroupId) {
+ buffer = new LongBigArray(unmappedGroupId);
+ buffer.ensureCapacity(positions * POSITIONS_PER_ENTRY);
+ this.unmappedGroupId = unmappedGroupId;
+ }
+
+ public void set(long index, long groupId, long rowId, long precomputedHash, long value) {
+ buffer.set(index * POSITIONS_PER_ENTRY, groupId);
+ buffer.set(index * POSITIONS_PER_ENTRY + ROW_ID_OFFSET, rowId);
+ buffer.set(index * POSITIONS_PER_ENTRY + PRECOMPUTED_HASH_OFFSET, precomputedHash);
+ buffer.set(index * POSITIONS_PER_ENTRY + VALUE_OFFSET, value);
+ }
+
+ public void clear(long index) {
+ // Since all fields of an index are set/unset together as a unit, we only need to choose one
+ // field to serve
+ // as a marker for empty slots. Group IDs have been arbitrarily selected for that purpose.
+ buffer.set(index * POSITIONS_PER_ENTRY, unmappedGroupId);
+ }
+
+ public boolean isEmptySlot(long index) {
+ return getGroupId(index) == unmappedGroupId;
+ }
+
+ public long getGroupId(long index) {
+ return buffer.get(index * POSITIONS_PER_ENTRY);
+ }
+
+ public long getRowId(long index) {
+ return buffer.get(index * POSITIONS_PER_ENTRY + ROW_ID_OFFSET);
+ }
+
+ public long getPrecomputedHash(long index) {
+ return buffer.get(index * POSITIONS_PER_ENTRY + PRECOMPUTED_HASH_OFFSET);
+ }
+
+ public long getValue(long index) {
+ return buffer.get(index * POSITIONS_PER_ENTRY + VALUE_OFFSET);
+ }
+
+ public void setValue(long index, long value) {
+ buffer.set(index * POSITIONS_PER_ENTRY + VALUE_OFFSET, value);
+ }
+
+ public long sizeOf() {
+ return INSTANCE_SIZE + buffer.sizeOf();
+ }
+ }
+
+ public static long maxFill(long n, float f) {
+ return Math.min((long) Math.ceil((double) ((float) n * f)), n - 1L);
+ }
+
+ public static long nextPowerOfTwo(long x) {
+ return 1L << (64 - Long.numberOfLeadingZeros(x - 1L));
+ }
+
+ public static long bigArraySize(long expected, float f) {
+ return nextPowerOfTwo((long) Math.ceil((double) ((float) expected / f)));
+ }
+
+ public static long mix(long x) {
+ long h = x * -7046029254386353131L;
+ h ^= h >>> 32;
+ return h ^ h >>> 16;
+ }
+}
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/TsBlockWithPositionEqualsAndHash.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/TsBlockWithPositionEqualsAndHash.java
new file mode 100644
index 0000000000000..1c76a90233ad2
--- /dev/null
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/TsBlockWithPositionEqualsAndHash.java
@@ -0,0 +1,28 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.db.queryengine.execution.operator;
+
+import org.apache.tsfile.read.common.block.TsBlock;
+
+public interface TsBlockWithPositionEqualsAndHash {
+ boolean equals(TsBlock left, int leftPosition, TsBlock right, int rightPosition);
+
+ long hashCode(TsBlock block, int position);
+}
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/window/TopKRankingOperator.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/window/TopKRankingOperator.java
index b94c546ac5497..c703fb423b9b4 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/window/TopKRankingOperator.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/process/window/TopKRankingOperator.java
@@ -21,11 +21,14 @@
import org.apache.iotdb.db.queryengine.execution.MemoryEstimationHelper;
import org.apache.iotdb.db.queryengine.execution.operator.GroupedTopNBuilder;
+import org.apache.iotdb.db.queryengine.execution.operator.GroupedTopNRankBuilder;
import org.apache.iotdb.db.queryengine.execution.operator.GroupedTopNRowNumberBuilder;
import org.apache.iotdb.db.queryengine.execution.operator.Operator;
import org.apache.iotdb.db.queryengine.execution.operator.OperatorContext;
import org.apache.iotdb.db.queryengine.execution.operator.SimpleTsBlockWithPositionComparator;
+import org.apache.iotdb.db.queryengine.execution.operator.SimpleTsBlockWithPositionEqualsAndHash;
import org.apache.iotdb.db.queryengine.execution.operator.TsBlockWithPositionComparator;
+import org.apache.iotdb.db.queryengine.execution.operator.TsBlockWithPositionEqualsAndHash;
import org.apache.iotdb.db.queryengine.execution.operator.process.ProcessOperator;
import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.UpdateMemory;
import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped.hash.GroupByHash;
@@ -150,16 +153,21 @@ private Supplier getGroupedTopNBuilderSupplier() {
groupByHashSupplier.get());
}
- // if (rankingType == TopKRankingNode.RankingType.RANK) {
- // Comparator comparator = new SimpleTsBlockWithPositionComparator(
- // sourceTypes, sortChannels, ascendingOrders);
- // return () -> new GroupedTopNRankBuilder(
- // sourceTypes,
- // comparator,
- // maxRankingPerPartition,
- // generateRanking,
- // groupByHashSupplier.get());
- // }
+ if (rankingType == TopKRankingNode.RankingType.RANK) {
+ TsBlockWithPositionComparator comparator =
+ new SimpleTsBlockWithPositionComparator(inputTypes, sortChannels, sortOrders);
+ TsBlockWithPositionEqualsAndHash equalsAndHash =
+ new SimpleTsBlockWithPositionEqualsAndHash(inputTypes, sortChannels);
+ return () ->
+ new GroupedTopNRankBuilder(
+ inputTypes,
+ comparator,
+ equalsAndHash,
+ maxRowCountPerPartition,
+ !partial,
+ partitionChannels.stream().mapToInt(Integer::intValue).toArray(),
+ groupByHashSupplier.get());
+ }
if (rankingType == TopKRankingNode.RankingType.DENSE_RANK) {
throw new UnsupportedOperationException("DENSE_RANK not yet implemented");
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/TableOperatorGenerator.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/TableOperatorGenerator.java
index 1bc788251a764..cd23122e4b2ff 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/TableOperatorGenerator.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/planner/TableOperatorGenerator.java
@@ -4323,6 +4323,35 @@ public Operator visitTopKRanking(TopKRankingNode node, LocalExecutionPlanContext
List partitionChannels = getChannelsForSymbols(partitionBySymbols, childLayout);
List inputDataTypes =
getOutputColumnTypes(node.getChild(), context.getTypeProvider());
+
+ ImmutableList.Builder outputChannels = ImmutableList.builder();
+ for (int i = 0; i < inputDataTypes.size(); i++) {
+ outputChannels.add(i);
+ }
+
+ // compute the layout of the output from the window operator
+ ImmutableMap.Builder outputMappings = ImmutableMap.builder();
+ outputMappings.putAll(childLayout);
+
+ if (!node.isPartial() || !partitionChannels.isEmpty()) {
+ int channel = inputDataTypes.size();
+ outputMappings.put(node.getRankingSymbol(), channel);
+ }
+
+ if (node.isDataPreSortedAndLimited()) {
+ // Data is already limited to K rows per partition and sorted by time.
+ // Use streaming RowNumberOperator (O(n) time, O(partitions) memory)
+ // instead of heap-based TopKRankingOperator (O(n log K) time, O(n) memory).
+ return new RowNumberOperator(
+ operatorContext,
+ child,
+ inputDataTypes,
+ outputChannels.build(),
+ partitionChannels,
+ Optional.of(node.getMaxRankingPerPartition()),
+ 10_000);
+ }
+
List partitionTypes =
partitionChannels.stream().map(inputDataTypes::get).collect(toImmutableList());
@@ -4341,21 +4370,6 @@ public Operator visitTopKRanking(TopKRankingNode node, LocalExecutionPlanContext
.collect(toImmutableList());
}
- ImmutableList.Builder outputChannels = ImmutableList.builder();
- for (int i = 0; i < inputDataTypes.size(); i++) {
- outputChannels.add(i);
- }
-
- // compute the layout of the output from the window operator
- ImmutableMap.Builder outputMappings = ImmutableMap.builder();
- outputMappings.putAll(childLayout);
-
- if (!node.isPartial() || !partitionChannels.isEmpty()) {
- // ranking function goes in the last channel
- int channel = inputDataTypes.size();
- outputMappings.put(node.getRankingSymbol(), channel);
- }
-
return new TopKRankingOperator(
operatorContext,
child,
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/distribute/TableDistributedPlanGenerator.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/distribute/TableDistributedPlanGenerator.java
index 7072b5f519f73..2f62ebc63f8a0 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/distribute/TableDistributedPlanGenerator.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/distribute/TableDistributedPlanGenerator.java
@@ -54,6 +54,7 @@
import org.apache.iotdb.db.queryengine.plan.relational.planner.SortOrder;
import org.apache.iotdb.db.queryengine.plan.relational.planner.Symbol;
import org.apache.iotdb.db.queryengine.plan.relational.planner.SymbolAllocator;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.SymbolsExtractor;
import org.apache.iotdb.db.queryengine.plan.relational.planner.node.AggregationNode;
import org.apache.iotdb.db.queryengine.plan.relational.planner.node.AggregationTableScanNode;
import org.apache.iotdb.db.queryengine.plan.relational.planner.node.AggregationTreeDeviceViewScanNode;
@@ -306,7 +307,11 @@ public List visitOffset(OffsetNode node, PlanContext context) {
@Override
public List visitProject(ProjectNode node, PlanContext context) {
+ Set savedParentRefs = context.getParentReferencedSymbols();
+ context.setParentReferencedSymbols(
+ SymbolsExtractor.extractUnique(node.getAssignments().getExpressions()));
List childrenNodes = node.getChild().accept(this, context);
+ context.setParentReferencedSymbols(savedParentRefs);
OrderingScheme childOrdering = nodeOrderingMap.get(childrenNodes.get(0).getPlanNodeId());
boolean containAllSortItem = false;
if (childOrdering != null) {
@@ -442,6 +447,33 @@ private boolean canTopKEliminated(OrderingScheme orderingScheme, long k, PlanNod
return false;
}
+ private boolean tryPushTopKRankingLimitToScan(
+ TopKRankingNode topKRankingNode, List children, OrderingScheme orderingScheme) {
+ List orderBy = orderingScheme.getOrderBy();
+ if (orderBy.size() != 1) {
+ return false;
+ }
+ Symbol orderSymbol = orderBy.get(0);
+ long limit = topKRankingNode.getMaxRankingPerPartition();
+ boolean pushed = false;
+
+ for (PlanNode child : children) {
+ if (child instanceof DeviceTableScanNode && !(child instanceof AggregationTableScanNode)) {
+ DeviceTableScanNode scanNode = (DeviceTableScanNode) child;
+ if (scanNode.isTimeColumn(orderSymbol)) {
+ scanNode.setPushLimitToEachDevice(true);
+ if (scanNode.getPushDownLimit() <= 0) {
+ scanNode.setPushDownLimit(limit);
+ } else {
+ scanNode.setPushDownLimit(Math.min(limit, scanNode.getPushDownLimit()));
+ }
+ pushed = true;
+ }
+ }
+ }
+ return pushed;
+ }
+
@Override
public List visitGroup(GroupNode node, PlanContext context) {
context.setExpectedOrderingScheme(node.getOrderingScheme());
@@ -568,7 +600,15 @@ public List visitStreamSort(StreamSortNode node, PlanContext context)
@Override
public List visitFilter(FilterNode node, PlanContext context) {
+ Set savedParentRefs = context.getParentReferencedSymbols();
+ if (savedParentRefs != null) {
+ ImmutableSet.Builder merged = ImmutableSet.builder();
+ merged.addAll(savedParentRefs);
+ merged.addAll(SymbolsExtractor.extractUnique(node.getPredicate()));
+ context.setParentReferencedSymbols(merged.build());
+ }
List childrenNodes = node.getChild().accept(this, context);
+ context.setParentReferencedSymbols(savedParentRefs);
OrderingScheme childOrdering = nodeOrderingMap.get(childrenNodes.get(0).getPlanNodeId());
if (childOrdering != null) {
nodeOrderingMap.put(node.getPlanNodeId(), childOrdering);
@@ -1863,6 +1903,38 @@ public List visitRowNumber(RowNumberNode node, PlanContext context) {
node.setChild(((SortNode) node.getChild()).getChild());
}
List childrenNodes = node.getChild().accept(this, context);
+
+ Set parentRefs = context.getParentReferencedSymbols();
+ if (parentRefs != null && !parentRefs.contains(node.getRowNumberSymbol())) {
+ // If maxRowCountPerPartition is set, push it as a per-device limit to each
+ // DeviceTableScanNode so that only the required number of rows are scanned.
+ node.getMaxRowCountPerPartition()
+ .ifPresent(
+ limit -> {
+ for (PlanNode child : childrenNodes) {
+ if (child instanceof DeviceTableScanNode
+ && !(child instanceof AggregationTableScanNode)) {
+ DeviceTableScanNode scanNode = (DeviceTableScanNode) child;
+ scanNode.setPushLimitToEachDevice(true);
+ if (scanNode.getPushDownLimit() <= 0) {
+ scanNode.setPushDownLimit(limit);
+ } else {
+ scanNode.setPushDownLimit(Math.min(limit, scanNode.getPushDownLimit()));
+ }
+ }
+ }
+ });
+ // Eliminate RowNumberNode entirely - return children directly.
+ if (childrenNodes.size() == 1 || canSplitPushDown) {
+ return childrenNodes;
+ } else {
+ CollectNode collectNode =
+ new CollectNode(queryId.genPlanNodeId(), childrenNodes.get(0).getOutputSymbols());
+ childrenNodes.forEach(collectNode::addChild);
+ return Collections.singletonList(collectNode);
+ }
+ }
+
if (childrenNodes.size() == 1) {
node.setChild(childrenNodes.get(0));
return Collections.singletonList(node);
@@ -1885,7 +1957,6 @@ public List visitTopKRanking(TopKRankingNode node, PlanContext context
nodeOrderingMap.put(node.getPlanNodeId(), orderingScheme.get());
}
- // TODO: per partition topk eliminate
checkArgument(
node.getChildren().size() == 1, "Size of TopKRankingNode can only be 1 in logical plan.");
boolean canSplitPushDown = node.getChild() instanceof GroupNode;
@@ -1894,12 +1965,24 @@ public List visitTopKRanking(TopKRankingNode node, PlanContext context
}
List childrenNodes = node.getChildren().get(0).accept(this, context);
if (canSplitPushDown) {
+ // visitGroup may return GroupNode-wrapped children (sort not eliminated) or bare
+ // DeviceTableScanNode (sort eliminated). Unwrap GroupNode/SortNode when present.
childrenNodes =
childrenNodes.stream()
- .map(child -> child.getChildren().get(0))
+ .map(child -> child instanceof SortNode ? child.getChildren().get(0) : child)
.collect(Collectors.toList());
}
+ if (canSplitPushDown && orderingScheme.isPresent()) {
+ if (tryPushTopKRankingLimitToScan(node, childrenNodes, orderingScheme.get())) {
+ node.setDataPreSortedAndLimited(true);
+ Set parentRefs = context.getParentReferencedSymbols();
+ if (parentRefs != null && !parentRefs.contains(node.getRankingSymbol())) {
+ return childrenNodes;
+ }
+ }
+ }
+
if (childrenNodes.size() == 1) {
node.setChild(childrenNodes.get(0));
return Collections.singletonList(node);
@@ -1958,6 +2041,7 @@ public static class PlanContext {
OrderingScheme expectedOrderingScheme;
TRegionReplicaSet mostUsedRegion;
boolean deviceCrossRegion;
+ Set parentReferencedSymbols;
public PlanContext() {
this.nodeDistributionMap = new HashMap<>();
@@ -1984,5 +2068,13 @@ public void setPushDownGrouping(boolean pushDownGrouping) {
public boolean isPushDownGrouping() {
return pushDownGrouping;
}
+
+ public Set getParentReferencedSymbols() {
+ return parentReferencedSymbols;
+ }
+
+ public void setParentReferencedSymbols(Set parentReferencedSymbols) {
+ this.parentReferencedSymbols = parentReferencedSymbols;
+ }
}
}
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/iterative/rule/PruneTopKRankingColumns.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/iterative/rule/PruneTopKRankingColumns.java
new file mode 100644
index 0000000000000..da49f3a04c6fe
--- /dev/null
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/iterative/rule/PruneTopKRankingColumns.java
@@ -0,0 +1,53 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule;
+
+import org.apache.iotdb.db.queryengine.plan.planner.plan.node.PlanNode;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.Symbol;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.node.TopKRankingNode;
+
+import com.google.common.collect.Streams;
+
+import java.util.Optional;
+import java.util.Set;
+
+import static com.google.common.collect.ImmutableSet.toImmutableSet;
+import static org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.Util.restrictChildOutputs;
+import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.topNRanking;
+
+public class PruneTopKRankingColumns extends ProjectOffPushDownRule {
+ public PruneTopKRankingColumns() {
+ super(topNRanking());
+ }
+
+ @Override
+ protected Optional pushDownProjectOff(
+ Context context, TopKRankingNode topNRankingNode, Set referencedOutputs) {
+ Set requiredInputs =
+ Streams.concat(
+ referencedOutputs.stream()
+ .filter(symbol -> !symbol.equals(topNRankingNode.getRankingSymbol())),
+ topNRankingNode.getSpecification().getPartitionBy().stream(),
+ topNRankingNode.getSpecification().getOrderingScheme().get().getOrderBy().stream())
+ .collect(toImmutableSet());
+
+ return restrictChildOutputs(context.getIdAllocator(), topNRankingNode, requiredInputs);
+ }
+}
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/iterative/rule/PushFilterIntoRowNumber.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/iterative/rule/PushFilterIntoRowNumber.java
new file mode 100644
index 0000000000000..2aa90ba5670f1
--- /dev/null
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/iterative/rule/PushFilterIntoRowNumber.java
@@ -0,0 +1,138 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule;
+
+import org.apache.iotdb.db.queryengine.plan.relational.planner.Symbol;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.Rule;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.node.FilterNode;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.node.RowNumberNode;
+import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.ComparisonExpression;
+import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Expression;
+import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Literal;
+import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.SymbolReference;
+import org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Capture;
+import org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Captures;
+import org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Pattern;
+
+import java.util.Optional;
+import java.util.OptionalInt;
+
+import static java.lang.Math.toIntExact;
+import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.filter;
+import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.rowNumber;
+import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.source;
+import static org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Capture.newCapture;
+
+/**
+ * Pushes a row-number upper-bound filter (e.g. {@code rn <= N}) into {@link RowNumberNode} by
+ * setting {@code maxRowCountPerPartition}. The filter is eliminated because the row-number node
+ * guarantees that no partition will emit more than {@code N} rows.
+ *
+ * Before:
+ *
+ *
+ * FilterNode(rn <= N)
+ * └── RowNumberNode(rowNumberSymbol=rn, maxRowCountPerPartition=empty)
+ *
+ *
+ * After:
+ *
+ *
+ * RowNumberNode(rowNumberSymbol=rn, maxRowCountPerPartition=N)
+ *
+ */
+public class PushFilterIntoRowNumber implements Rule {
+ private static final Capture CHILD = newCapture();
+
+ private static final Pattern PATTERN =
+ filter()
+ .with(
+ source()
+ .matching(
+ rowNumber()
+ .matching(
+ rowNumber -> !rowNumber.getMaxRowCountPerPartition().isPresent())
+ .capturedAs(CHILD)));
+
+ @Override
+ public Pattern getPattern() {
+ return PATTERN;
+ }
+
+ @Override
+ public Result apply(FilterNode node, Captures captures, Context context) {
+ RowNumberNode rowNumberNode = captures.get(CHILD);
+ Symbol rowNumberSymbol = rowNumberNode.getRowNumberSymbol();
+
+ OptionalInt upperBound = extractUpperBound(node.getPredicate(), rowNumberSymbol);
+ if (!upperBound.isPresent()) {
+ return Result.empty();
+ }
+
+ if (upperBound.getAsInt() <= 0) {
+ return Result.empty();
+ }
+
+ return Result.ofPlanNode(
+ new RowNumberNode(
+ rowNumberNode.getPlanNodeId(),
+ rowNumberNode.getChild(),
+ rowNumberNode.getPartitionBy(),
+ rowNumberNode.isOrderSensitive(),
+ rowNumberSymbol,
+ Optional.of(upperBound.getAsInt())));
+ }
+
+ private OptionalInt extractUpperBound(Expression predicate, Symbol rowNumberSymbol) {
+ if (!(predicate instanceof ComparisonExpression)) {
+ return OptionalInt.empty();
+ }
+
+ ComparisonExpression comparison = (ComparisonExpression) predicate;
+ Expression left = comparison.getLeft();
+ Expression right = comparison.getRight();
+
+ if (!(left instanceof SymbolReference) || !(right instanceof Literal)) {
+ return OptionalInt.empty();
+ }
+
+ SymbolReference symbolRef = (SymbolReference) left;
+ if (!symbolRef.getName().equals(rowNumberSymbol.getName())) {
+ return OptionalInt.empty();
+ }
+
+ Literal literal = (Literal) right;
+ Object value = literal.getTsValue();
+ if (!(value instanceof Number)) {
+ return OptionalInt.empty();
+ }
+
+ long constantValue = ((Number) value).longValue();
+
+ switch (comparison.getOperator()) {
+ case LESS_THAN:
+ return OptionalInt.of(toIntExact(constantValue - 1));
+ case LESS_THAN_OR_EQUAL:
+ return OptionalInt.of(toIntExact(constantValue));
+ default:
+ return OptionalInt.empty();
+ }
+ }
+}
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/iterative/rule/PushPredicateThroughProjectIntoRowNumber.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/iterative/rule/PushPredicateThroughProjectIntoRowNumber.java
new file mode 100644
index 0000000000000..f0dbae59778ff
--- /dev/null
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/iterative/rule/PushPredicateThroughProjectIntoRowNumber.java
@@ -0,0 +1,195 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule;
+
+import org.apache.iotdb.db.queryengine.plan.relational.planner.Symbol;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.Rule;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.node.FilterNode;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.node.ProjectNode;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.node.RowNumberNode;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.node.ValuesNode;
+import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.ComparisonExpression;
+import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Expression;
+import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Literal;
+import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.SymbolReference;
+import org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Capture;
+import org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Captures;
+import org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Pattern;
+
+import com.google.common.collect.ImmutableList;
+
+import java.util.Optional;
+import java.util.OptionalInt;
+
+import static java.lang.Math.toIntExact;
+import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.filter;
+import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.project;
+import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.rowNumber;
+import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.source;
+import static org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Capture.newCapture;
+
+/**
+ * Pushes a row-number upper-bound filter through an identity projection into {@link RowNumberNode}
+ * by setting {@code maxRowCountPerPartition}.
+ *
+ * Before:
+ *
+ *
+ * FilterNode(rn <= N)
+ * └── ProjectNode (identity)
+ * └── RowNumberNode(rowNumberSymbol=rn, maxRowCountPerPartition=empty)
+ *
+ *
+ * After (for LESS_THAN / LESS_THAN_OR_EQUAL — filter fully absorbed):
+ *
+ *
+ * ProjectNode (identity)
+ * └── RowNumberNode(rowNumberSymbol=rn, maxRowCountPerPartition=N)
+ *
+ *
+ * After (for EQUAL — filter kept):
+ *
+ *
+ * FilterNode(rn = N)
+ * └── ProjectNode (identity)
+ * └── RowNumberNode(rowNumberSymbol=rn, maxRowCountPerPartition=N)
+ *
+ */
+public class PushPredicateThroughProjectIntoRowNumber implements Rule {
+ private static final Capture PROJECT = newCapture();
+ private static final Capture ROW_NUMBER = newCapture();
+
+ private static final Pattern PATTERN =
+ filter()
+ .with(
+ source()
+ .matching(
+ project()
+ .matching(ProjectNode::isIdentity)
+ .capturedAs(PROJECT)
+ .with(
+ source()
+ .matching(
+ rowNumber()
+ .matching(
+ rn -> !rn.getMaxRowCountPerPartition().isPresent())
+ .capturedAs(ROW_NUMBER)))));
+
+ @Override
+ public Pattern getPattern() {
+ return PATTERN;
+ }
+
+ @Override
+ public Result apply(FilterNode filter, Captures captures, Context context) {
+ ProjectNode project = captures.get(PROJECT);
+ RowNumberNode rowNumberNode = captures.get(ROW_NUMBER);
+
+ Symbol rowNumberSymbol = rowNumberNode.getRowNumberSymbol();
+ if (!project.getAssignments().getSymbols().contains(rowNumberSymbol)) {
+ return Result.empty();
+ }
+
+ OptionalInt upperBound = extractUpperBound(filter.getPredicate(), rowNumberSymbol);
+ if (!upperBound.isPresent()) {
+ return Result.empty();
+ }
+ if (upperBound.getAsInt() <= 0) {
+ return Result.ofPlanNode(
+ new ValuesNode(filter.getPlanNodeId(), filter.getOutputSymbols(), ImmutableList.of()));
+ }
+
+ project =
+ (ProjectNode)
+ project.replaceChildren(
+ ImmutableList.of(
+ new RowNumberNode(
+ rowNumberNode.getPlanNodeId(),
+ rowNumberNode.getChild(),
+ rowNumberNode.getPartitionBy(),
+ rowNumberNode.isOrderSensitive(),
+ rowNumberSymbol,
+ Optional.of(upperBound.getAsInt()))));
+
+ if (needToKeepFilter(filter.getPredicate())) {
+ return Result.ofPlanNode(
+ new FilterNode(filter.getPlanNodeId(), project, filter.getPredicate()));
+ }
+ return Result.ofPlanNode(project);
+ }
+
+ private OptionalInt extractUpperBound(Expression predicate, Symbol rowNumberSymbol) {
+ if (!(predicate instanceof ComparisonExpression)) {
+ return OptionalInt.empty();
+ }
+
+ ComparisonExpression comparison = (ComparisonExpression) predicate;
+ Expression left = comparison.getLeft();
+ Expression right = comparison.getRight();
+
+ if (!(left instanceof SymbolReference) || !(right instanceof Literal)) {
+ return OptionalInt.empty();
+ }
+
+ SymbolReference symbolRef = (SymbolReference) left;
+ if (!symbolRef.getName().equals(rowNumberSymbol.getName())) {
+ return OptionalInt.empty();
+ }
+
+ Literal literal = (Literal) right;
+ Object value = literal.getTsValue();
+ if (!(value instanceof Number)) {
+ return OptionalInt.empty();
+ }
+
+ long constantValue = ((Number) value).longValue();
+
+ switch (comparison.getOperator()) {
+ case LESS_THAN:
+ return OptionalInt.of(toIntExact(constantValue - 1));
+ case LESS_THAN_OR_EQUAL:
+ case EQUAL:
+ return OptionalInt.of(toIntExact(constantValue));
+ default:
+ return OptionalInt.empty();
+ }
+ }
+
+ /**
+ * For {@code LESS_THAN} and {@code LESS_THAN_OR_EQUAL}, the RowNumberNode with
+ * maxRowCountPerPartition produces exactly the rows that satisfy the predicate (row numbers
+ * 1..N), so the filter can be removed. For {@code EQUAL} (e.g. {@code rn = 5}), the RowNumberNode
+ * produces rows 1..5 but only rows where {@code rn = 5} are wanted, so the filter must be kept.
+ */
+ private static boolean needToKeepFilter(Expression predicate) {
+ if (!(predicate instanceof ComparisonExpression)) {
+ return true;
+ }
+
+ ComparisonExpression comparison = (ComparisonExpression) predicate;
+ switch (comparison.getOperator()) {
+ case LESS_THAN:
+ case LESS_THAN_OR_EQUAL:
+ return false;
+ default:
+ return true;
+ }
+ }
+}
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/iterative/rule/PushPredicateThroughProjectIntoWindow.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/iterative/rule/PushPredicateThroughProjectIntoWindow.java
new file mode 100644
index 0000000000000..21c17f7dced98
--- /dev/null
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/iterative/rule/PushPredicateThroughProjectIntoWindow.java
@@ -0,0 +1,207 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule;
+
+import org.apache.iotdb.db.queryengine.plan.relational.planner.PlannerContext;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.Symbol;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.Rule;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.node.FilterNode;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.node.ProjectNode;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.node.TopKRankingNode;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.node.ValuesNode;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.node.WindowNode;
+import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.ComparisonExpression;
+import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Expression;
+import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.Literal;
+import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.SymbolReference;
+import org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Capture;
+import org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Captures;
+import org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Pattern;
+
+import com.google.common.collect.ImmutableList;
+
+import java.util.OptionalInt;
+
+import static com.google.common.collect.Iterables.getOnlyElement;
+import static java.lang.Math.toIntExact;
+import static java.util.Objects.requireNonNull;
+import static org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.Util.toTopNRankingType;
+import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.filter;
+import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.project;
+import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.source;
+import static org.apache.iotdb.db.queryengine.plan.relational.planner.node.Patterns.window;
+import static org.apache.iotdb.db.queryengine.plan.relational.utils.matching.Capture.newCapture;
+
+/**
+ * Converts a filter on a ranking function (e.g. {@code rn <= N}) into a {@link TopKRankingNode}
+ * when there is an identity projection between the filter and window node.
+ *
+ * Before:
+ *
+ *
+ * FilterNode(rn <= N)
+ * └── ProjectNode (identity)
+ * └── WindowNode(row_number/rank)
+ *
+ *
+ * After (for LESS_THAN / LESS_THAN_OR_EQUAL — filter fully absorbed):
+ *
+ *
+ * ProjectNode (identity)
+ * └── TopKRankingNode(maxRanking=N)
+ *
+ *
+ * After (for EQUAL — filter kept):
+ *
+ *
+ * FilterNode(rn = N)
+ * └── ProjectNode (identity)
+ * └── TopKRankingNode(maxRanking=N)
+ *
+ */
+public class PushPredicateThroughProjectIntoWindow implements Rule<FilterNode> {
+ private static final Capture<ProjectNode> PROJECT = newCapture();
+ private static final Capture<WindowNode> WINDOW = newCapture();
+
+ private final PlannerContext plannerContext;
+ private final Pattern<FilterNode> pattern;
+
+ public PushPredicateThroughProjectIntoWindow(PlannerContext plannerContext) {
+ this.plannerContext = requireNonNull(plannerContext, "plannerContext is null");
+ this.pattern =
+ filter()
+ .with(
+ source()
+ .matching(
+ project()
+ .matching(ProjectNode::isIdentity)
+ .capturedAs(PROJECT)
+ .with(
+ source()
+ .matching(
+ window()
+ .matching(
+ window -> toTopNRankingType(window).isPresent())
+ .capturedAs(WINDOW)))));
+ }
+
+ @Override
+ public Pattern<FilterNode> getPattern() {
+ return pattern;
+ }
+
+ @Override
+ public Result apply(FilterNode filter, Captures captures, Context context) {
+ ProjectNode project = captures.get(PROJECT);
+ WindowNode window = captures.get(WINDOW);
+
+ Symbol rankingSymbol = getOnlyElement(window.getWindowFunctions().keySet());
+ if (!project.getAssignments().getSymbols().contains(rankingSymbol)) {
+ return Result.empty();
+ }
+
+ OptionalInt upperBound = extractUpperBoundFromComparison(filter.getPredicate(), rankingSymbol);
+ if (!upperBound.isPresent()) {
+ return Result.empty();
+ }
+ if (upperBound.getAsInt() <= 0) {
+ return Result.ofPlanNode(
+ new ValuesNode(filter.getPlanNodeId(), filter.getOutputSymbols(), ImmutableList.of()));
+ }
+
+ TopKRankingNode.RankingType rankingType = toTopNRankingType(window).get();
+ project =
+ (ProjectNode)
+ project.replaceChildren(
+ ImmutableList.of(
+ new TopKRankingNode(
+ window.getPlanNodeId(),
+ window.getChild(),
+ window.getSpecification(),
+ rankingType,
+ rankingSymbol,
+ upperBound.getAsInt(),
+ false)));
+
+ if (needToKeepFilter(filter.getPredicate())) {
+ return Result.ofPlanNode(
+ new FilterNode(filter.getPlanNodeId(), project, filter.getPredicate()));
+ }
+ return Result.ofPlanNode(project);
+ }
+
+ private OptionalInt extractUpperBoundFromComparison(Expression predicate, Symbol rankingSymbol) {
+ if (!(predicate instanceof ComparisonExpression)) {
+ return OptionalInt.empty();
+ }
+
+ ComparisonExpression comparison = (ComparisonExpression) predicate;
+ Expression left = comparison.getLeft();
+ Expression right = comparison.getRight();
+
+ if (!(left instanceof SymbolReference) || !(right instanceof Literal)) {
+ return OptionalInt.empty();
+ }
+
+ SymbolReference symbolRef = (SymbolReference) left;
+ if (!symbolRef.getName().equals(rankingSymbol.getName())) {
+ return OptionalInt.empty();
+ }
+
+ Literal literal = (Literal) right;
+ Object value = literal.getTsValue();
+ if (!(value instanceof Number)) {
+ return OptionalInt.empty();
+ }
+
+ long constantValue = ((Number) value).longValue();
+
+ switch (comparison.getOperator()) {
+ case LESS_THAN:
+ return OptionalInt.of(toIntExact(constantValue - 1));
+ case LESS_THAN_OR_EQUAL:
+ case EQUAL:
+ return OptionalInt.of(toIntExact(constantValue));
+ default:
+ return OptionalInt.empty();
+ }
+ }
+
+ /**
+ * For {@code LESS_THAN} and {@code LESS_THAN_OR_EQUAL}, the TopKRankingNode produces exactly the
+ * rows that satisfy the predicate (ranking values 1..N), so the filter can be removed. For {@code
+ * EQUAL} (e.g. {@code rn = 5}), TopKRankingNode produces rows 1..5 but only rows where {@code rn
+ * = 5} are wanted, so the filter must be kept.
+ */
+ private static boolean needToKeepFilter(Expression predicate) {
+ if (!(predicate instanceof ComparisonExpression)) {
+ return true;
+ }
+
+ ComparisonExpression comparison = (ComparisonExpression) predicate;
+ switch (comparison.getOperator()) {
+ case LESS_THAN:
+ case LESS_THAN_OR_EQUAL:
+ return false;
+ default:
+ return true;
+ }
+ }
+}
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/node/TopKRankingNode.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/node/TopKRankingNode.java
index fedb7e1f2e172..3b48ec4852340 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/node/TopKRankingNode.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/node/TopKRankingNode.java
@@ -50,6 +50,11 @@ public enum RankingType {
private final int maxRankingPerPartition;
private final boolean partial;
+ // When true, the child scan already returns pre-sorted, pre-limited data (at most K rows per
+ // partition, ordered correctly). The operator can skip heap-based TopK selection and just assign
+ // sequential row numbers (streaming mode).
+ private boolean dataPreSortedAndLimited = false;
+
public TopKRankingNode(
PlanNodeId id,
DataOrganizationSpecification specification,
@@ -85,13 +90,16 @@ public TopKRankingNode(
@Override
public PlanNode clone() {
- return new TopKRankingNode(
- getPlanNodeId(),
- specification,
- rankingType,
- rankingSymbol,
- maxRankingPerPartition,
- partial);
+ TopKRankingNode topKRankingNode =
+ new TopKRankingNode(
+ getPlanNodeId(),
+ specification,
+ rankingType,
+ rankingSymbol,
+ maxRankingPerPartition,
+ partial);
+ topKRankingNode.setDataPreSortedAndLimited(dataPreSortedAndLimited);
+ return topKRankingNode;
}
@Override
@@ -119,6 +127,14 @@ public RankingType getRankingType() {
return rankingType;
}
+ public boolean isDataPreSortedAndLimited() {
+ return dataPreSortedAndLimited;
+ }
+
+ public void setDataPreSortedAndLimited(boolean dataPreSortedAndLimited) {
+ this.dataPreSortedAndLimited = dataPreSortedAndLimited;
+ }
+
@Override
 public List<String> getOutputColumnNames() {
throw new UnsupportedOperationException();
@@ -132,6 +148,7 @@ protected void serializeAttributes(ByteBuffer byteBuffer) {
Symbol.serialize(rankingSymbol, byteBuffer);
ReadWriteIOUtils.write(maxRankingPerPartition, byteBuffer);
ReadWriteIOUtils.write(partial, byteBuffer);
+ ReadWriteIOUtils.write(dataPreSortedAndLimited, byteBuffer);
}
@Override
@@ -142,6 +159,7 @@ protected void serializeAttributes(DataOutputStream stream) throws IOException {
Symbol.serialize(rankingSymbol, stream);
ReadWriteIOUtils.write(maxRankingPerPartition, stream);
ReadWriteIOUtils.write(partial, stream);
+ ReadWriteIOUtils.write(dataPreSortedAndLimited, stream);
}
public static TopKRankingNode deserialize(ByteBuffer byteBuffer) {
@@ -151,10 +169,14 @@ public static TopKRankingNode deserialize(ByteBuffer byteBuffer) {
Symbol rankingSymbol = Symbol.deserialize(byteBuffer);
int maxRankingPerPartition = ReadWriteIOUtils.readInt(byteBuffer);
boolean partial = ReadWriteIOUtils.readBoolean(byteBuffer);
+ boolean dataPreSortedAndLimited = ReadWriteIOUtils.readBoolean(byteBuffer);
PlanNodeId planNodeId = PlanNodeId.deserialize(byteBuffer);
- return new TopKRankingNode(
- planNodeId, specification, rankingType, rankingSymbol, maxRankingPerPartition, partial);
+ TopKRankingNode node =
+ new TopKRankingNode(
+ planNodeId, specification, rankingType, rankingSymbol, maxRankingPerPartition, partial);
+ node.setDataPreSortedAndLimited(dataPreSortedAndLimited);
+ return node;
}
@Override
@@ -167,14 +189,17 @@ public List getOutputSymbols() {
@Override
 public PlanNode replaceChildren(List<PlanNode> newChildren) {
- return new TopKRankingNode(
- id,
- Iterables.getOnlyElement(newChildren),
- specification,
- rankingType,
- rankingSymbol,
- maxRankingPerPartition,
- partial);
+ TopKRankingNode topKRankingNode =
+ new TopKRankingNode(
+ id,
+ Iterables.getOnlyElement(newChildren),
+ specification,
+ rankingType,
+ rankingSymbol,
+ maxRankingPerPartition,
+ partial);
+ topKRankingNode.setDataPreSortedAndLimited(dataPreSortedAndLimited);
+ return topKRankingNode;
}
@Override
@@ -198,7 +223,8 @@ public int hashCode() {
rankingType,
rankingSymbol,
maxRankingPerPartition,
- partial);
+ partial,
+ dataPreSortedAndLimited);
}
@Override
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/LogicalOptimizeFactory.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/LogicalOptimizeFactory.java
index 864b9a987abca..c5047f8046886 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/LogicalOptimizeFactory.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/LogicalOptimizeFactory.java
@@ -73,14 +73,18 @@
import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.PruneTableFunctionProcessorSourceColumns;
import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.PruneTableScanColumns;
import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.PruneTopKColumns;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.PruneTopKRankingColumns;
import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.PruneUnionColumns;
import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.PruneUnionSourceColumns;
import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.PruneWindowColumns;
import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.PushDownFilterIntoWindow;
import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.PushDownLimitIntoWindow;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.PushFilterIntoRowNumber;
import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.PushLimitThroughOffset;
import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.PushLimitThroughProject;
import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.PushLimitThroughUnion;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.PushPredicateThroughProjectIntoRowNumber;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.PushPredicateThroughProjectIntoWindow;
import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.PushProjectionThroughUnion;
import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.PushTopKThroughUnion;
import org.apache.iotdb.db.queryengine.plan.relational.planner.iterative.rule.RemoveDuplicateConditions;
@@ -153,6 +157,7 @@ public LogicalOptimizeFactory(PlannerContext plannerContext) {
new PruneTableFunctionProcessorSourceColumns(),
new PruneTableScanColumns(plannerContext.getMetadata()),
new PruneTopKColumns(),
+ new PruneTopKRankingColumns(),
new PruneWindowColumns(),
new PruneJoinColumns(),
new PruneJoinChildrenColumns(),
@@ -375,9 +380,14 @@ public LogicalOptimizeFactory(PlannerContext plannerContext) {
 ImmutableSet.<Rule<?>>builder()
.add(new PushDownLimitIntoWindow())
.add(new PushDownFilterIntoWindow(plannerContext))
+ .add(new PushPredicateThroughProjectIntoWindow(plannerContext))
.add(new ReplaceWindowWithRowNumber(metadata))
+ .add(new PushFilterIntoRowNumber())
+ .add(new PushPredicateThroughProjectIntoRowNumber())
.addAll(GatherAndMergeWindows.rules())
.build()),
+ inlineProjectionLimitFiltersOptimizer,
+ columnPruningOptimizer,
new TransformAggregationToStreamable(),
new PushAggregationIntoTableScan(),
new TransformSortToStreamSort(),
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/SortElimination.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/SortElimination.java
index bb276f07150b9..de0ef51c14883 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/SortElimination.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/SortElimination.java
@@ -27,8 +27,10 @@
import org.apache.iotdb.db.queryengine.plan.relational.planner.node.FillNode;
import org.apache.iotdb.db.queryengine.plan.relational.planner.node.GapFillNode;
import org.apache.iotdb.db.queryengine.plan.relational.planner.node.PatternRecognitionNode;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.node.RowNumberNode;
import org.apache.iotdb.db.queryengine.plan.relational.planner.node.SortNode;
import org.apache.iotdb.db.queryengine.plan.relational.planner.node.StreamSortNode;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.node.TopKRankingNode;
import org.apache.iotdb.db.queryengine.plan.relational.planner.node.ValueFillNode;
import org.apache.iotdb.db.queryengine.plan.relational.planner.node.WindowNode;
@@ -139,15 +141,36 @@ public PlanNode visitPatternRecognition(PatternRecognitionNode node, Context con
context.setCannotEliminateSort(true);
return newNode;
}
+
+ @Override
+ public PlanNode visitTopKRanking(TopKRankingNode node, Context context) {
+ PlanNode newNode = node.clone();
+ for (PlanNode child : node.getChildren()) {
+ newNode.addChild(child.accept(this, context));
+ }
+ context.setCannotEliminateSort(true);
+ return newNode;
+ }
+
+ @Override
+ public PlanNode visitRowNumber(RowNumberNode node, Context context) {
+ PlanNode newNode = node.clone();
+ for (PlanNode child : node.getChildren()) {
+ newNode.addChild(child.accept(this, context));
+ }
+ context.setCannotEliminateSort(true);
+ return newNode;
+ }
}
private static class Context {
private int totalDeviceEntrySize = 0;
- // There are 3 situations where sort cannot be eliminated
+ // There are 4 situations where sort cannot be eliminated
// 1. Query plan has linear fill, previous fill or gapfill
// 2. Query plan has window function and it has ordering scheme
// 3. Query plan has pattern recognition and it has ordering scheme
+ // 4. Query plan has row number node or topk ranking node
private boolean cannotEliminateSort = false;
private String timeColumnName = null;
diff --git a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/TransformSortToStreamSort.java b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/TransformSortToStreamSort.java
index 7eb6dfb81c97c..15da1d3908955 100644
--- a/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/TransformSortToStreamSort.java
+++ b/iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/optimizations/TransformSortToStreamSort.java
@@ -33,8 +33,10 @@
import org.apache.iotdb.db.queryengine.plan.relational.planner.node.DeviceTableScanNode;
import org.apache.iotdb.db.queryengine.plan.relational.planner.node.GroupNode;
import org.apache.iotdb.db.queryengine.plan.relational.planner.node.InformationSchemaTableScanNode;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.node.RowNumberNode;
import org.apache.iotdb.db.queryengine.plan.relational.planner.node.SortNode;
import org.apache.iotdb.db.queryengine.plan.relational.planner.node.StreamSortNode;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.node.TopKRankingNode;
import org.apache.iotdb.db.queryengine.plan.relational.planner.node.UnionNode;
import java.util.Map;
@@ -163,6 +165,18 @@ public PlanNode visitAggregationTableScan(AggregationTableScanNode node, Context
return visitTableScan(node, context);
}
+ @Override
+ public PlanNode visitTopKRanking(TopKRankingNode node, Context context) {
+ context.setCanTransform(false);
+ return visitPlan(node, context);
+ }
+
+ @Override
+ public PlanNode visitRowNumber(RowNumberNode node, Context context) {
+ context.setCanTransform(false);
+ return visitPlan(node, context);
+ }
+
@Override
public PlanNode visitUnion(UnionNode node, Context context) {
context.setCanTransform(false);
diff --git a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/execution/operator/process/window/TopKRankingOperatorTest.java b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/execution/operator/process/window/TopKRankingOperatorTest.java
index 1bbf05f7679b9..934cc9e276b40 100644
--- a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/execution/operator/process/window/TopKRankingOperatorTest.java
+++ b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/execution/operator/process/window/TopKRankingOperatorTest.java
@@ -195,6 +195,165 @@ public void testTopKWithTopOne() {
2);
}
+ // ==================== RANK Tests ====================
+
+ @Test
+ public void testRankWithPartitionAndTies() {
+ // d1: values [5, 3, 3, 1], d2: values [6, 2, 2]
+ // topN=2 ASC → d1 keeps rank≤2: 1(r=1),3(r=2),3(r=2); d2 keeps rank≤2: 2(r=1),2(r=1)
+ long[][] timeArray = {{1, 2, 3, 4, 5, 6, 7}};
+ String[][] deviceArray = {{"d1", "d1", "d1", "d1", "d2", "d2", "d2"}};
+ int[][] valueArray = {{5, 3, 3, 1, 6, 2, 2}};
+
+ Map<String, List<int[]>> expectedByDevice = new HashMap<>();
+ expectedByDevice.put("d1", Arrays.asList(new int[] {1, 1}, new int[] {3, 2}, new int[] {3, 2}));
+ expectedByDevice.put("d2", Arrays.asList(new int[] {2, 1}, new int[] {2, 1}));
+
+ verifyTopKResultsByPartition(
+ timeArray,
+ deviceArray,
+ valueArray,
+ Collections.singletonList(1),
+ Collections.singletonList(TSDataType.TEXT),
+ Collections.singletonList(2),
+ Collections.singletonList(SortOrder.ASC_NULLS_LAST),
+ 2,
+ false,
+ TopKRankingNode.RankingType.RANK,
+ expectedByDevice,
+ 5);
+ }
+
+ @Test
+ public void testRankWithPartitionDescendingAndTies() {
+ // d1: values [5, 3, 3, 1] DESC → 5(r=1),3(r=2),3(r=2),1(r=4) → keep rank≤2
+ // d2: values [6, 2, 4] DESC → 6(r=1),4(r=2),2(r=3) → keep rank≤2
+ long[][] timeArray = {{1, 2, 3, 4, 5, 6}};
+ String[][] deviceArray = {{"d1", "d1", "d1", "d1", "d2", "d2"}};
+ int[][] valueArray = {{5, 3, 3, 1, 6, 4}};
+
+ Map<String, List<int[]>> expectedByDevice = new HashMap<>();
+ expectedByDevice.put("d1", Arrays.asList(new int[] {5, 1}, new int[] {3, 2}, new int[] {3, 2}));
+ expectedByDevice.put("d2", Arrays.asList(new int[] {6, 1}, new int[] {4, 2}));
+
+ verifyTopKResultsByPartition(
+ timeArray,
+ deviceArray,
+ valueArray,
+ Collections.singletonList(1),
+ Collections.singletonList(TSDataType.TEXT),
+ Collections.singletonList(2),
+ Collections.singletonList(SortOrder.DESC_NULLS_LAST),
+ 2,
+ false,
+ TopKRankingNode.RankingType.RANK,
+ expectedByDevice,
+ 5);
+ }
+
+ @Test
+ public void testRankWithoutPartitionAndTies() {
+ // Global: values [5, 3, 1, 3, 2] ASC → 1(r=1),2(r=2),3(r=3),3(r=3),5(r=5) → keep rank≤3
+ long[][] timeArray = {{1, 2, 3, 4, 5}};
+ String[][] deviceArray = {{"d1", "d1", "d2", "d2", "d2"}};
+ int[][] valueArray = {{5, 3, 1, 3, 2}};
+
+ int[][] expectedValueAndRank = {{1, 1}, {2, 2}, {3, 3}, {3, 3}};
+
+ verifyTopKResultsGlobal(
+ timeArray,
+ deviceArray,
+ valueArray,
+ Collections.emptyList(),
+ Collections.emptyList(),
+ Collections.singletonList(2),
+ Collections.singletonList(SortOrder.ASC_NULLS_LAST),
+ 3,
+ false,
+ TopKRankingNode.RankingType.RANK,
+ expectedValueAndRank,
+ 4);
+ }
+
+ @Test
+ public void testRankWithMultipleTsBlocksAndTies() {
+ // Same data as testRankWithPartitionAndTies, split across blocks
+ long[][] timeArray = {{1, 2, 3}, {4, 5}, {6, 7}};
+ String[][] deviceArray = {{"d1", "d1", "d1"}, {"d1", "d2"}, {"d2", "d2"}};
+ int[][] valueArray = {{5, 3, 3}, {1, 6}, {2, 2}};
+
+ Map<String, List<int[]>> expectedByDevice = new HashMap<>();
+ expectedByDevice.put("d1", Arrays.asList(new int[] {1, 1}, new int[] {3, 2}, new int[] {3, 2}));
+ expectedByDevice.put("d2", Arrays.asList(new int[] {2, 1}, new int[] {2, 1}));
+
+ verifyTopKResultsByPartition(
+ timeArray,
+ deviceArray,
+ valueArray,
+ Collections.singletonList(1),
+ Collections.singletonList(TSDataType.TEXT),
+ Collections.singletonList(2),
+ Collections.singletonList(SortOrder.ASC_NULLS_LAST),
+ 2,
+ false,
+ TopKRankingNode.RankingType.RANK,
+ expectedByDevice,
+ 5);
+ }
+
+ @Test
+ public void testRankTopOneWithTies() {
+ // d1: values [5, 3], d2: values [2, 2]
+ // topN=1 ASC → d1: 3(r=1); d2: 2(r=1),2(r=1) (ties at rank 1 are all kept)
+ long[][] timeArray = {{1, 2, 3, 4}};
+ String[][] deviceArray = {{"d1", "d1", "d2", "d2"}};
+ int[][] valueArray = {{5, 3, 2, 2}};
+
+ Map<String, List<int[]>> expectedByDevice = new HashMap<>();
+ expectedByDevice.put("d1", Collections.singletonList(new int[] {3, 1}));
+ expectedByDevice.put("d2", Arrays.asList(new int[] {2, 1}, new int[] {2, 1}));
+
+ verifyTopKResultsByPartition(
+ timeArray,
+ deviceArray,
+ valueArray,
+ Collections.singletonList(1),
+ Collections.singletonList(TSDataType.TEXT),
+ Collections.singletonList(2),
+ Collections.singletonList(SortOrder.ASC_NULLS_LAST),
+ 1,
+ false,
+ TopKRankingNode.RankingType.RANK,
+ expectedByDevice,
+ 3);
+ }
+
+ @Test
+ public void testRankNoTiesBehavesLikeRowNumber() {
+ // When no ties, rank should produce the same results as row_number
+ long[][] timeArray = {{1, 2, 3, 4, 5, 6, 7}};
+ String[][] deviceArray = {{"d1", "d1", "d1", "d1", "d2", "d2", "d2"}};
+ int[][] valueArray = {{5, 3, 1, 4, 6, 2, 7}};
+
+ Map<String, List<int[]>> expectedByDevice = new HashMap<>();
+ expectedByDevice.put("d1", Arrays.asList(new int[] {1, 1}, new int[] {3, 2}));
+ expectedByDevice.put("d2", Arrays.asList(new int[] {2, 1}, new int[] {6, 2}));
+
+ verifyTopKResultsByPartition(
+ timeArray,
+ deviceArray,
+ valueArray,
+ Collections.singletonList(1),
+ Collections.singletonList(TSDataType.TEXT),
+ Collections.singletonList(2),
+ Collections.singletonList(SortOrder.ASC_NULLS_LAST),
+ 2,
+ false,
+ TopKRankingNode.RankingType.RANK,
+ expectedByDevice,
+ 4);
+ }
+
/**
* Verifies top-K results grouped by partition (device). The output order between partitions is
* not guaranteed, so we group results by device and verify each partition independently.
@@ -211,6 +370,34 @@ private void verifyTopKResultsByPartition(
boolean partial,
 Map<String, List<int[]>> expectedByDevice,
int expectedTotalCount) {
+ verifyTopKResultsByPartition(
+ timeArray,
+ deviceArray,
+ valueArray,
+ partitionChannels,
+ partitionTypes,
+ sortChannels,
+ sortOrders,
+ maxRowCountPerPartition,
+ partial,
+ TopKRankingNode.RankingType.ROW_NUMBER,
+ expectedByDevice,
+ expectedTotalCount);
+ }
+
+ private void verifyTopKResultsByPartition(
+ long[][] timeArray,
+ String[][] deviceArray,
+ int[][] valueArray,
+ List<Integer> partitionChannels,
+ List<TSDataType> partitionTypes,
+ List<Integer> sortChannels,
+ List<SortOrder> sortOrders,
+ int maxRowCountPerPartition,
+ boolean partial,
+ TopKRankingNode.RankingType rankingType,
+ Map<String, List<int[]>> expectedByDevice,
+ int expectedTotalCount) {
 Map<String, List<int[]>> actualByDevice = new HashMap<>();
int count = 0;
@@ -225,7 +412,8 @@ private void verifyTopKResultsByPartition(
sortChannels,
sortOrders,
maxRowCountPerPartition,
- partial)) {
+ partial,
+ rankingType)) {
while (!operator.isFinished()) {
if (operator.hasNext()) {
TsBlock tsBlock = operator.next();
@@ -282,7 +470,34 @@ private void verifyTopKResultsGlobal(
boolean partial,
int[][] expectedValueAndRn,
int expectedTotalCount) {
+ verifyTopKResultsGlobal(
+ timeArray,
+ deviceArray,
+ valueArray,
+ partitionChannels,
+ partitionTypes,
+ sortChannels,
+ sortOrders,
+ maxRowCountPerPartition,
+ partial,
+ TopKRankingNode.RankingType.ROW_NUMBER,
+ expectedValueAndRn,
+ expectedTotalCount);
+ }
+ private void verifyTopKResultsGlobal(
+ long[][] timeArray,
+ String[][] deviceArray,
+ int[][] valueArray,
+ List<Integer> partitionChannels,
+ List<TSDataType> partitionTypes,
+ List<Integer> sortChannels,
+ List<SortOrder> sortOrders,
+ int maxRowCountPerPartition,
+ boolean partial,
+ TopKRankingNode.RankingType rankingType,
+ int[][] expectedValueAndRn,
+ int expectedTotalCount) {
 List<int[]> results = new ArrayList<>();
int count = 0;
@@ -296,7 +511,8 @@ private void verifyTopKResultsGlobal(
sortChannels,
sortOrders,
maxRowCountPerPartition,
- partial)) {
+ partial,
+ rankingType)) {
while (!operator.isFinished()) {
if (operator.hasNext()) {
TsBlock tsBlock = operator.next();
@@ -346,6 +562,30 @@ private TopKRankingOperator genTopKRankingOperator(
 List<SortOrder> sortOrders,
int maxRowCountPerPartition,
boolean partial) {
+ return genTopKRankingOperator(
+ timeArray,
+ deviceArray,
+ valueArray,
+ partitionChannels,
+ partitionTypes,
+ sortChannels,
+ sortOrders,
+ maxRowCountPerPartition,
+ partial,
+ TopKRankingNode.RankingType.ROW_NUMBER);
+ }
+
+ private TopKRankingOperator genTopKRankingOperator(
+ long[][] timeArray,
+ String[][] deviceArray,
+ int[][] valueArray,
+ List<Integer> partitionChannels,
+ List<TSDataType> partitionTypes,
+ List<Integer> sortChannels,
+ List<SortOrder> sortOrders,
+ int maxRowCountPerPartition,
+ boolean partial,
+ TopKRankingNode.RankingType rankingType) {
DriverContext driverContext = createDriverContext();
 List<TSDataType> inputDataTypes =
@@ -359,7 +599,7 @@ private TopKRankingOperator genTopKRankingOperator(
return new TopKRankingOperator(
driverContext.getOperatorContexts().get(0),
childOperator,
- TopKRankingNode.RankingType.ROW_NUMBER,
+ rankingType,
inputDataTypes,
outputChannels,
partitionChannels,
diff --git a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/planner/WindowFunctionOptimizationTest.java b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/planner/WindowFunctionOptimizationTest.java
index e31f2f7e58065..8dec0068e9506 100644
--- a/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/planner/WindowFunctionOptimizationTest.java
+++ b/iotdb-core/datanode/src/test/java/org/apache/iotdb/db/queryengine/plan/relational/planner/WindowFunctionOptimizationTest.java
@@ -20,7 +20,12 @@
package org.apache.iotdb.db.queryengine.plan.relational.planner;
import org.apache.iotdb.db.queryengine.plan.planner.plan.LogicalQueryPlan;
+import org.apache.iotdb.db.queryengine.plan.planner.plan.node.PlanNode;
import org.apache.iotdb.db.queryengine.plan.relational.planner.assertions.PlanMatchPattern;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.node.DeviceTableScanNode;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.node.ProjectNode;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.node.RowNumberNode;
+import org.apache.iotdb.db.queryengine.plan.relational.planner.node.TopKRankingNode;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
@@ -29,8 +34,10 @@
import static org.apache.iotdb.db.queryengine.plan.relational.planner.assertions.PlanAssert.assertPlan;
import static org.apache.iotdb.db.queryengine.plan.relational.planner.assertions.PlanMatchPattern.collect;
import static org.apache.iotdb.db.queryengine.plan.relational.planner.assertions.PlanMatchPattern.exchange;
+import static org.apache.iotdb.db.queryengine.plan.relational.planner.assertions.PlanMatchPattern.filter;
import static org.apache.iotdb.db.queryengine.plan.relational.planner.assertions.PlanMatchPattern.group;
import static org.apache.iotdb.db.queryengine.plan.relational.planner.assertions.PlanMatchPattern.limit;
+import static org.apache.iotdb.db.queryengine.plan.relational.planner.assertions.PlanMatchPattern.mergeSort;
import static org.apache.iotdb.db.queryengine.plan.relational.planner.assertions.PlanMatchPattern.output;
import static org.apache.iotdb.db.queryengine.plan.relational.planner.assertions.PlanMatchPattern.project;
import static org.apache.iotdb.db.queryengine.plan.relational.planner.assertions.PlanMatchPattern.rowNumber;
@@ -38,6 +45,10 @@
import static org.apache.iotdb.db.queryengine.plan.relational.planner.assertions.PlanMatchPattern.tableScan;
import static org.apache.iotdb.db.queryengine.plan.relational.planner.assertions.PlanMatchPattern.topKRanking;
import static org.apache.iotdb.db.queryengine.plan.relational.planner.assertions.PlanMatchPattern.window;
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
public class WindowFunctionOptimizationTest {
@Test
@@ -122,7 +133,7 @@ public void testTopKRankingPushDown() {
PlanTester planTester = new PlanTester();
String sql =
- "SELECT * FROM (SELECT *, row_number() OVER (PARTITION BY tag1, tag2, tag3 ORDER BY s1) as rn FROM table1) WHERE rn <= 2";
+ "SELECT * FROM (SELECT *, rank() OVER (PARTITION BY tag1, tag2, tag3 ORDER BY s1) as rn FROM table1) WHERE rn <= 2";
LogicalQueryPlan logicalQueryPlan = planTester.createPlan(sql);
PlanMatchPattern tableScan = tableScan("testdb.table1");
@@ -153,7 +164,9 @@ public void testTopKRankingPushDown() {
planTester.getFragmentPlan(0), output((collect(exchange(), exchange(), exchange()))));
assertPlan(planTester.getFragmentPlan(1), topKRanking(tableScan));
assertPlan(planTester.getFragmentPlan(2), topKRanking(tableScan));
- assertPlan(planTester.getFragmentPlan(3), topKRanking(sort(tableScan)));
+ assertPlan(planTester.getFragmentPlan(3), topKRanking(mergeSort(exchange(), exchange())));
+ assertPlan(planTester.getFragmentPlan(4), sort(tableScan));
+ assertPlan(planTester.getFragmentPlan(5), sort(tableScan));
}
@Test
@@ -161,7 +174,7 @@ public void testPushDownFilterIntoWindow() {
PlanTester planTester = new PlanTester();
String sql =
- "SELECT * FROM (SELECT *, row_number() OVER (PARTITION BY tag1 ORDER BY s1) as rn FROM table1) WHERE rn <= 2";
+ "SELECT * FROM (SELECT *, rank() OVER (PARTITION BY tag1 ORDER BY s1) as rn FROM table1) WHERE rn <= 2";
LogicalQueryPlan logicalQueryPlan = planTester.createPlan(sql);
PlanMatchPattern tableScan = tableScan("testdb.table1");
@@ -297,4 +310,280 @@ public void testRowNumberPushDown() {
assertPlan(planTester.getFragmentPlan(1), rowNumber(tableScan));
assertPlan(planTester.getFragmentPlan(2), rowNumber(tableScan));
}
+
+ @Test
+ public void testTopKRankingOrderByTimeLimitPushDown() {
+ PlanTester planTester = new PlanTester();
+
+ String sql =
+ "SELECT * FROM (SELECT *, row_number() OVER (PARTITION BY tag1, tag2, tag3 ORDER BY time) as rn FROM table1) WHERE rn <= 2";
+ LogicalQueryPlan logicalQueryPlan = planTester.createPlan(sql);
+ PlanMatchPattern tableScan = tableScan("testdb.table1");
+
+ // Logical plan: OutputNode -> TopKRankingNode -> GroupNode -> TableScanNode
+ assertPlan(logicalQueryPlan, output(topKRanking(group(tableScan))));
+
+ // Distributed plan: TopKRankingNode pushed down to each partition with limit push-down.
+ // Fragment 0: OutputNode -> CollectNode -> ExchangeNodes
+ assertPlan(planTester.getFragmentPlan(0), output(collect(exchange(), exchange(), exchange())));
+
+ // Worker fragments: TopKRankingNode -> DeviceTableScanNode
+ // Verify limit is pushed to DeviceTableScanNode and TopKRankingNode is marked for streaming.
+ for (int i = 1; i <= 2; i++) {
+ PlanNode fragmentRoot = planTester.getFragmentPlan(i);
+ assertTrue(
+ "Fragment " + i + " root should be TopKRankingNode",
+ fragmentRoot instanceof TopKRankingNode);
+ TopKRankingNode topKNode = (TopKRankingNode) fragmentRoot;
+ assertTrue(
+ "TopKRankingNode should be marked as dataPreSortedAndLimited",
+ topKNode.isDataPreSortedAndLimited());
+
+ PlanNode scanChild = topKNode.getChild();
+ assertNotNull("TopKRankingNode should have a child", scanChild);
+ assertTrue("Child should be DeviceTableScanNode", scanChild instanceof DeviceTableScanNode);
+ DeviceTableScanNode dts = (DeviceTableScanNode) scanChild;
+ assertTrue("pushLimitToEachDevice should be true", dts.isPushLimitToEachDevice());
+ assertEquals("pushDownLimit should be 2", 2, dts.getPushDownLimit());
+ }
+ }
+
+ @Test
+ public void testTopKRankingEliminatedWhenRankSymbolNotOutput() {
+ PlanTester planTester = new PlanTester();
+
+ String sql =
+ "SELECT tag1, s1 FROM (SELECT *, row_number() OVER (PARTITION BY tag1, tag2, tag3 ORDER BY time) as rn FROM table1) WHERE rn <= 2";
+ LogicalQueryPlan logicalQueryPlan = planTester.createPlan(sql);
+ PlanMatchPattern tableScan = tableScan("testdb.table1");
+
+ // Logical plan: OutputNode -> ProjectNode -> TopKRankingNode -> GroupNode -> TableScanNode
+ assertPlan(logicalQueryPlan, output(project(topKRanking(group(tableScan)))));
+
+ // Distributed plan: TopKRankingNode eliminated since rn is not in the output.
+ // Limit is pushed to DeviceTableScanNode.
+ // Fragment 0: OutputNode -> CollectNode -> ExchangeNodes
+ assertPlan(planTester.getFragmentPlan(0), output(collect(exchange(), exchange(), exchange())));
+
+ // Worker fragments: ProjectNode -> DeviceTableScanNode (no TopKRankingNode)
+ for (int i = 1; i <= 2; i++) {
+ PlanNode fragmentRoot = planTester.getFragmentPlan(i);
+ assertFalse(
+ "Fragment " + i + " root should NOT be TopKRankingNode",
+ fragmentRoot instanceof TopKRankingNode);
+ assertPlan(planTester.getFragmentPlan(i), project(tableScan));
+
+ assertTrue(
+ "Fragment " + i + " root should be ProjectNode", fragmentRoot instanceof ProjectNode);
+ PlanNode scanChild = fragmentRoot.getChildren().get(0);
+ assertTrue("Child should be DeviceTableScanNode", scanChild instanceof DeviceTableScanNode);
+ DeviceTableScanNode dts = (DeviceTableScanNode) scanChild;
+ assertTrue("pushLimitToEachDevice should be true", dts.isPushLimitToEachDevice());
+ assertEquals("pushDownLimit should be 2", 2, dts.getPushDownLimit());
+ }
+ }
+
+ @Test
+ public void testTopKRankingKeptWhenRankSymbolIsOutput() {
+ PlanTester planTester = new PlanTester();
+
+ // Same query but SELECT * includes rn - TopKRankingNode should NOT be eliminated
+ String sql =
+ "SELECT * FROM (SELECT *, row_number() OVER (PARTITION BY tag1, tag2, tag3 ORDER BY time) as rn FROM table1) WHERE rn <= 2";
+ LogicalQueryPlan logicalQueryPlan = planTester.createPlan(sql);
+ PlanMatchPattern tableScan = tableScan("testdb.table1");
+
+ assertPlan(logicalQueryPlan, output(topKRanking(group(tableScan))));
+
+ // Worker fragments should still have TopKRankingNode
+ for (int i = 1; i <= 2; i++) {
+ PlanNode fragmentRoot = planTester.getFragmentPlan(i);
+ assertTrue(
+ "Fragment " + i + " root should be TopKRankingNode",
+ fragmentRoot instanceof TopKRankingNode);
+ }
+ }
+
+ @Test
+ public void testRowNumberEliminatedWhenRowNumberNotOutput() {
+ PlanTester planTester = new PlanTester();
+
+ // RowNumber with all IDs as partition - row number not in output
+ String sql =
+ "SELECT tag1, s1 FROM (SELECT *, row_number() OVER (PARTITION BY tag1, tag2, tag3) as rn FROM table1)";
+ LogicalQueryPlan logicalQueryPlan = planTester.createPlan(sql);
+ PlanMatchPattern tableScan = tableScan("testdb.table1");
+
+ // Logical plan: row_number is pruned at the window level by PruneWindowColumns
+ // since rn is not referenced anywhere. The plan should not contain RowNumberNode.
+ assertPlan(logicalQueryPlan, output(project(group(tableScan))));
+ }
+
+ @Test
+ public void testRowNumberPushDownWhenRowNumberIsOutput() {
+ PlanTester planTester = new PlanTester();
+
+ // RowNumber with all IDs as partition and rn IS referenced in output
+ String sql =
+ "SELECT tag1, s1, rn FROM (SELECT *, row_number() OVER (PARTITION BY tag1, tag2, tag3) as rn FROM table1)";
+ LogicalQueryPlan logicalQueryPlan = planTester.createPlan(sql);
+ PlanMatchPattern tableScan = tableScan("testdb.table1");
+
+ // Logical plan: OutputNode -> ProjectNode -> RowNumberNode -> GroupNode -> TableScanNode
+ // (project is inlined since it selects a subset including rn)
+ assertPlan(logicalQueryPlan, output(project(rowNumber(group(tableScan)))));
+
+ // RowNumberNode is pushed down to each partition (not eliminated, since rn IS in the output)
+ assertPlan(planTester.getFragmentPlan(0), output(collect(exchange(), exchange(), exchange())));
+ assertPlan(planTester.getFragmentPlan(1), project(rowNumber(tableScan)));
+ assertPlan(planTester.getFragmentPlan(2), project(rowNumber(tableScan)));
+ }
+
+ @Test
+ public void testRowNumberWithMaxCountEliminatedWhenRowNumberNotOutput() {
+ PlanTester planTester = new PlanTester();
+
+ // rn <= 2 pushes the limit into RowNumberNode (maxRowCountPerPartition=2), and since rn is
+ // not in the outer SELECT, the RowNumberNode is eliminated in the distributed plan.
+ String sql =
+ "SELECT tag1, s1 FROM (SELECT *, row_number() OVER (PARTITION BY tag1, tag2, tag3) as rn FROM table1) WHERE rn <= 2";
+ LogicalQueryPlan logicalQueryPlan = planTester.createPlan(sql);
+ PlanMatchPattern tableScan = tableScan("testdb.table1");
+
+ // Logical plan: PushFilterIntoRowNumber absorbs rn<=2, leaving RowNumberNode with maxRowCount=2
+ // No filter remains above RowNumberNode.
+ assertPlan(logicalQueryPlan, output(project(rowNumber(group(tableScan)))));
+
+ // Distributed plan: RowNumberNode eliminated since rn is not in the output.
+ // Limit (maxRowCountPerPartition=2) is pushed to each DeviceTableScanNode.
+ assertPlan(planTester.getFragmentPlan(0), output(collect(exchange(), exchange(), exchange())));
+
+ // Worker fragments: ProjectNode -> DeviceTableScanNode (no RowNumberNode)
+ for (int i = 1; i <= 2; i++) {
+ PlanNode fragmentRoot = planTester.getFragmentPlan(i);
+ assertFalse(
+ "Fragment " + i + " root should NOT be RowNumberNode",
+ fragmentRoot instanceof RowNumberNode);
+ assertPlan(planTester.getFragmentPlan(i), project(tableScan));
+
+ assertTrue(
+ "Fragment " + i + " root should be ProjectNode", fragmentRoot instanceof ProjectNode);
+ PlanNode scanChild = fragmentRoot.getChildren().get(0);
+ assertTrue("Child should be DeviceTableScanNode", scanChild instanceof DeviceTableScanNode);
+ DeviceTableScanNode dts = (DeviceTableScanNode) scanChild;
+ assertTrue("pushLimitToEachDevice should be true", dts.isPushLimitToEachDevice());
+ assertEquals("pushDownLimit should be 2", 2, dts.getPushDownLimit());
+ }
+ }
+
+ @Test
+ public void testRowNumberWithMaxCountKeptWhenRowNumberIsOutput() {
+ PlanTester planTester = new PlanTester();
+
+ // Same query but SELECT * includes rn - RowNumberNode should NOT be eliminated
+ String sql =
+ "SELECT * FROM (SELECT *, row_number() OVER (PARTITION BY tag1, tag2, tag3) as rn FROM table1) WHERE rn <= 2";
+ LogicalQueryPlan logicalQueryPlan = planTester.createPlan(sql);
+ PlanMatchPattern tableScan = tableScan("testdb.table1");
+
+ // Logical plan: RowNumberNode with maxRowCount=2 (filter absorbed), no outer project removing
+ // rn
+ assertPlan(logicalQueryPlan, output(rowNumber(group(tableScan))));
+
+ // Worker fragments should still have RowNumberNode since rn IS in the output
+ for (int i = 1; i <= 2; i++) {
+ PlanNode fragmentRoot = planTester.getFragmentPlan(i);
+ assertTrue(
+ "Fragment " + i + " root should be RowNumberNode", fragmentRoot instanceof RowNumberNode);
+ }
+ }
+
+ @Test
+ public void testTopKRankingWithEqualPredicate() {
+ PlanTester planTester = new PlanTester();
+
+ String sql =
+ "SELECT * FROM (SELECT *, row_number() OVER (PARTITION BY tag1 ORDER BY s1) as rn FROM table1) WHERE rn = 2";
+ LogicalQueryPlan logicalQueryPlan = planTester.createPlan(sql);
+ PlanMatchPattern tableScan = tableScan("testdb.table1");
+
+ // TopKRanking created with maxRanking=2, but filter(rn = 2) is kept because
+ // ranking values 1..2 do not all satisfy rn = 2
+ /*
+ * └──OutputNode
+ * └──FilterNode(rn = 2)
+ * └──TopKRankingNode
+ * └──SortNode
+ * └──TableScanNode
+ */
+ assertPlan(logicalQueryPlan, output(filter(topKRanking(sort(tableScan)))));
+ }
+
+ @Test
+ public void testTopKRankingWithEqualPredicateAllPartitions() {
+ PlanTester planTester = new PlanTester();
+
+ String sql =
+ "SELECT * FROM (SELECT *, row_number() OVER (PARTITION BY tag1, tag2, tag3 ORDER BY s1) as rn FROM table1) WHERE rn = 2";
+ LogicalQueryPlan logicalQueryPlan = planTester.createPlan(sql);
+ PlanMatchPattern tableScan = tableScan("testdb.table1");
+
+ // TopKRanking created with maxRanking=2, filter(rn = 2) is kept
+ /*
+ * └──OutputNode
+ * └──FilterNode(rn = 2)
+ * └──TopKRankingNode
+ * └──GroupNode
+ * └──TableScanNode
+ */
+ assertPlan(logicalQueryPlan, output(filter(topKRanking(group(tableScan)))));
+
+ // Distributed plan: TopKRanking and filter pushed down
+ assertPlan(planTester.getFragmentPlan(0), output(collect(exchange(), exchange(), exchange())));
+ assertPlan(planTester.getFragmentPlan(1), filter(topKRanking(tableScan)));
+ assertPlan(planTester.getFragmentPlan(2), filter(topKRanking(tableScan)));
+ }
+
+ @Test
+ public void testTopKRankingWithLessThanPredicate() {
+ PlanTester planTester = new PlanTester();
+
+ // rn < 3 is equivalent to rn <= 2, so the filter should be fully absorbed
+ String sql =
+ "SELECT * FROM (SELECT *, row_number() OVER (PARTITION BY tag1 ORDER BY s1) as rn FROM table1) WHERE rn < 3";
+ LogicalQueryPlan logicalQueryPlan = planTester.createPlan(sql);
+ PlanMatchPattern tableScan = tableScan("testdb.table1");
+
+ // Filter absorbed into TopKRankingNode (maxRanking=2)
+ /*
+ * └──OutputNode
+ * └──TopKRankingNode
+ * └──SortNode
+ * └──TableScanNode
+ */
+ assertPlan(logicalQueryPlan, output(topKRanking(sort(tableScan))));
+ }
+
+ @Test
+ public void testTopKRankingWithEqualPredicateColumnPruned() {
+ PlanTester planTester = new PlanTester();
+
+ // rn = 2 with rn not in output: filter kept, rn pruned by project
+ String sql =
+ "SELECT tag1, s1 FROM (SELECT *, row_number() OVER (PARTITION BY tag1, tag2, tag3 ORDER BY s1) as rn FROM table1) WHERE rn = 2";
+ LogicalQueryPlan logicalQueryPlan = planTester.createPlan(sql);
+ PlanMatchPattern tableScan = tableScan("testdb.table1");
+
+ // Filter(rn = 2) kept, project prunes rn from output
+ /*
+ * └──OutputNode
+ * └──ProjectNode
+ * └──FilterNode(rn = 2)
+ * └──ProjectNode
+ * └──TopKRankingNode
+ * └──GroupNode
+ * └──TableScanNode
+ */
+ assertPlan(logicalQueryPlan, output(project(filter(project(topKRanking(group(tableScan)))))));
+ }
}