Skip to content

Commit

Permalink
support partial limit push down
Browse files Browse the repository at this point in the history
  • Loading branch information
gnehil committed Jan 10, 2025
1 parent 20b1228 commit 3dcc42b
Show file tree
Hide file tree
Showing 8 changed files with 63 additions and 17 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ public class DorisReaderPartition implements Serializable {
private final String opaquedQueryPlan;
private final String[] readColumns;
private final String[] filters;
private final Integer limit;
private final DorisConfig config;

public DorisReaderPartition(String database, String table, Backend backend, Long[] tablets, String opaquedQueryPlan, String[] readColumns, String[] filters, DorisConfig config) {
Expand All @@ -42,6 +43,19 @@ public DorisReaderPartition(String database, String table, Backend backend, Long
this.opaquedQueryPlan = opaquedQueryPlan;
this.readColumns = readColumns;
this.filters = filters;
this.limit = -1;
this.config = config;
}

/**
 * Creates a reader partition that carries a row limit alongside the scan description.
 *
 * @param database         name of the database that owns the table
 * @param table            name of the table to read
 * @param backend          backend this partition is read from
 * @param tablets          ids of the tablets covered by this partition
 * @param opaquedQueryPlan opaque query plan string obtained from the frontend
 * @param readColumns      columns to read
 * @param filters          filter expressions; readers join them with {@code AND}
 * @param limit            maximum number of rows to read; readers only apply it when it is
 *                         greater than 0. {@code null} is treated as -1 (no limit)
 * @param config           connector configuration
 */
public DorisReaderPartition(String database, String table, Backend backend, Long[] tablets, String opaquedQueryPlan, String[] readColumns, String[] filters, Integer limit, DorisConfig config) {
    this.database = database;
    this.table = table;
    this.backend = backend;
    this.tablets = tablets;
    this.opaquedQueryPlan = opaquedQueryPlan;
    this.readColumns = readColumns;
    this.filters = filters;
    // Callers auto-unbox getLimit() (e.g. `partition.getLimit() > 0`), which would throw a
    // NullPointerException for a null limit. Normalise null to -1, the same "no limit"
    // value the constructor without a limit parameter assigns.
    this.limit = limit != null ? limit : Integer.valueOf(-1);
    this.config = config;
}

Expand Down Expand Up @@ -78,6 +92,10 @@ public String[] getFilters() {
return filters;
}

/**
 * Returns the row limit attached to this partition.
 *
 * @return the limit passed at construction, or -1 when the partition was created through the
 *         constructor that takes no limit argument; readers only apply values greater than 0
 */
public Integer getLimit() {
return limit;
}

@Override
public boolean equals(Object o) {
if (o == null || getClass() != o.getClass()) return false;
Expand All @@ -89,11 +107,12 @@ public boolean equals(Object o) {
&& Objects.equals(opaquedQueryPlan, that.opaquedQueryPlan)
&& Objects.deepEquals(readColumns, that.readColumns)
&& Objects.deepEquals(filters, that.filters)
&& Objects.equals(limit, that.limit)
&& Objects.equals(config, that.config);
}

@Override
public int hashCode() {
return Objects.hash(database, table, backend, Arrays.hashCode(tablets), opaquedQueryPlan, Arrays.hashCode(readColumns), Arrays.hashCode(filters), config);
return Objects.hash(database, table, backend, Arrays.hashCode(tablets), opaquedQueryPlan, Arrays.hashCode(readColumns), Arrays.hashCode(filters), limit, config);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ public abstract class AbstractThriftReader extends DorisReader {

private final Thread asyncThread;

private int readCount = 0;

protected AbstractThriftReader(DorisReaderPartition partition) throws Exception {
super(partition);
this.frontend = new DorisFrontendClient(config);
Expand Down Expand Up @@ -132,6 +134,9 @@ private void runAsync() throws DorisException, InterruptedException {

@Override
public boolean hasNext() throws DorisException {
if (partition.getLimit() > 0 && readCount >= partition.getLimit()) {
return false;
}
boolean hasNext = false;
if (isAsync && asyncThread != null && asyncThread.isAlive()) {
if (rowBatch == null || !rowBatch.hasNext()) {
Expand Down Expand Up @@ -186,6 +191,9 @@ public Object next() throws DorisException {
if (!hasNext()) {
throw new RuntimeException("No more elements");
}
if (partition.getLimit() > 0) {
readCount++;
}
return rowBatch.next().toArray();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,7 +149,8 @@ protected String generateQuerySql(DorisReaderPartition partition) throws OptionR
String fullTableName = config.getValue(DorisOptions.DORIS_TABLE_IDENTIFIER);
String tablets = String.format("TABLET(%s)", StringUtils.join(partition.getTablets(), ","));
String predicates = partition.getFilters().length == 0 ? "" : " WHERE " + String.join(" AND ", partition.getFilters());
return String.format("SELECT %s FROM %s %s%s", columns, fullTableName, tablets, predicates);
String limit = partition.getLimit() > 0 ? " LIMIT " + partition.getLimit() : "";
return String.format("SELECT %s FROM %s %s%s%s", columns, fullTableName, tablets, predicates, limit);
}

protected Schema processDorisSchema(DorisReaderPartition partition, final Schema originSchema) throws Exception {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
import org.apache.doris.spark.rest.models.QueryPlan;
import org.apache.doris.spark.rest.models.Schema;
import org.apache.doris.spark.util.DorisDialects;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.util.ArrayList;
import java.util.Arrays;
Expand All @@ -38,6 +40,8 @@

public class ReaderPartitionGenerator {

private static final Logger LOG = LoggerFactory.getLogger(ReaderPartitionGenerator.class);

/*
* for spark 2
*/
Expand All @@ -51,14 +55,14 @@ public static DorisReaderPartition[] generatePartitions(DorisConfig config) thro
}
String[] filters = config.contains(DorisOptions.DORIS_FILTER_QUERY) ?
config.getValue(DorisOptions.DORIS_FILTER_QUERY).split("\\.") : new String[0];
return generatePartitions(config, originReadCols, filters);
return generatePartitions(config, originReadCols, filters, -1);
}

/*
* for spark 3
*/
public static DorisReaderPartition[] generatePartitions(DorisConfig config,
String[] fields, String[] filters) throws Exception {
String[] fields, String[] filters, Integer limit) throws Exception {
DorisFrontendClient frontend = new DorisFrontendClient(config);
String fullTableName = config.getValue(DorisOptions.DORIS_TABLE_IDENTIFIER);
String[] tableParts = fullTableName.split("\\.");
Expand All @@ -69,13 +73,15 @@ public static DorisReaderPartition[] generatePartitions(DorisConfig config,
originReadCols = frontend.getTableAllColumns(db, table);
}
String[] finalReadColumns = getFinalReadColumns(config, frontend, db, table, originReadCols);
String sql = "SELECT " + String.join(",", finalReadColumns) + " FROM `" + db + "`.`" + table + "`" +
(filters.length == 0 ? "" : " WHERE " + String.join(" AND ", filters));
String finalReadColumnString = String.join(",", finalReadColumns);
String finalWhereClauseString = filters.length == 0 ? "" : " WHERE " + String.join(" AND ", filters);
String sql = "SELECT " + finalReadColumnString + " FROM `" + db + "`.`" + table + "`" + finalWhereClauseString;
LOG.info("get query plan for table " + db + "." + table + ", sql: " + sql);
QueryPlan queryPlan = frontend.getQueryPlan(db, table, sql);
Map<String, List<Long>> beToTablets = mappingBeToTablets(queryPlan);
int maxTabletSize = config.getValue(DorisOptions.DORIS_TABLET_SIZE);
return distributeTabletsToPartitions(db, table, beToTablets, queryPlan.getOpaqued_query_plan(), maxTabletSize,
finalReadColumns, filters, config);
finalReadColumns, filters, config, limit);
}

@VisibleForTesting
Expand Down Expand Up @@ -106,7 +112,7 @@ private static DorisReaderPartition[] distributeTabletsToPartitions(String datab
Map<String, List<Long>> beToTablets,
String opaquedQueryPlan, int maxTabletSize,
String[] readColumns, String[] predicates,
DorisConfig config) {
DorisConfig config, Integer limit) {
List<DorisReaderPartition> partitions = new ArrayList<>();
beToTablets.forEach((backendStr, tabletIds) -> {
List<Long> distinctTablets = new ArrayList<>(new HashSet<>(tabletIds));
Expand All @@ -115,7 +121,7 @@ private static DorisReaderPartition[] distributeTabletsToPartitions(String datab
Long[] tablets = distinctTablets.subList(offset, Math.min(offset + maxTabletSize, distinctTablets.size())).toArray(new Long[0]);
offset += maxTabletSize;
partitions.add(new DorisReaderPartition(database, table, new Backend(backendStr), tablets,
opaquedQueryPlan, readColumns, predicates, config));
opaquedQueryPlan, readColumns, predicates, limit, config));
}
});
return partitions.toArray(new DorisReaderPartition[0]);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ abstract class AbstractDorisScan(config: DorisConfig, schema: StructType) extend
override def toBatch: Batch = this

override def planInputPartitions(): Array[InputPartition] = {
ReaderPartitionGenerator.generatePartitions(config, schema.names, compiledFilters()).map(toInputPartition)
ReaderPartitionGenerator.generatePartitions(config, schema.names, compiledFilters(), getLimit).map(toInputPartition)
}


Expand All @@ -44,10 +44,12 @@ abstract class AbstractDorisScan(config: DorisConfig, schema: StructType) extend
}

private def toInputPartition(rp: DorisReaderPartition): DorisInputPartition =
DorisInputPartition(rp.getDatabase, rp.getTable, rp.getBackend, rp.getTablets.map(_.toLong), rp.getOpaquedQueryPlan, rp.getReadColumns, rp.getFilters)
DorisInputPartition(rp.getDatabase, rp.getTable, rp.getBackend, rp.getTablets.map(_.toLong), rp.getOpaquedQueryPlan, rp.getReadColumns, rp.getFilters, rp.getLimit)

/**
 * Filter expressions, as strings, that `planInputPartitions` hands to
 * `ReaderPartitionGenerator.generatePartitions`. Concrete scans supply the implementation.
 */
protected def compiledFilters(): Array[String]

/**
 * Row limit forwarded to `ReaderPartitionGenerator.generatePartitions` when planning input
 * partitions. Defaults to -1; subclasses override it to supply a pushed-down limit.
 */
protected def getLimit: Int = -1

}

case class DorisInputPartition(database: String, table: String, backend: Backend, tablets: Array[Long], opaquedQueryPlan: String, readCols: Array[String], predicates: Array[String]) extends InputPartition
case class DorisInputPartition(database: String, table: String, backend: Backend, tablets: Array[Long], opaquedQueryPlan: String, readCols: Array[String], predicates: Array[String], limit: Int = -1) extends InputPartition
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ class DorisPartitionReader(inputPartition: InputPartition, schema: StructType, m
private implicit def toReaderPartition(inputPart: DorisInputPartition): DorisReaderPartition = {
val tablets = inputPart.tablets.map(java.lang.Long.valueOf)
new DorisReaderPartition(inputPart.database, inputPart.table, inputPart.backend, tablets,
inputPart.opaquedQueryPlan, inputPart.readCols, inputPart.predicates, config)
inputPart.opaquedQueryPlan, inputPart.readCols, inputPart.predicates, inputPart.limit, config)
}

private lazy val reader: DorisReader = {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,17 +20,20 @@ package org.apache.doris.spark.read
import org.apache.doris.spark.config.{DorisConfig, DorisOptions}
import org.apache.doris.spark.read.expression.V2ExpressionBuilder
import org.apache.spark.sql.connector.expressions.filter.Predicate
import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownV2Filters}
import org.apache.spark.sql.connector.read.{Scan, SupportsPushDownLimit, SupportsPushDownV2Filters}
import org.apache.spark.sql.types.StructType

class DorisScanBuilder(config: DorisConfig, schema: StructType) extends DorisScanBuilderBase(config, schema)
with SupportsPushDownV2Filters {
with SupportsPushDownV2Filters
with SupportsPushDownLimit {

private var pushDownPredicates: Array[Predicate] = Array[Predicate]()

private val expressionBuilder = new V2ExpressionBuilder(config.getValue(DorisOptions.DORIS_FILTER_QUERY_IN_MAX_COUNT))

override def build(): Scan = new DorisScanV2(config, schema, pushDownPredicates)
private var limitSize: Int = -1

override def build(): Scan = new DorisScanV2(config, schema, pushDownPredicates, limitSize)

override def pushPredicates(predicates: Array[Predicate]): Array[Predicate] = {
val (pushed, unsupported) = predicates.partition(predicate => {
Expand All @@ -42,4 +45,9 @@ class DorisScanBuilder(config: DorisConfig, schema: StructType) extends DorisSca

override def pushedPredicates(): Array[Predicate] = pushDownPredicates

/**
 * Records the limit offered for push-down so that `build()` can pass it on to `DorisScanV2`.
 *
 * NOTE(review): presumably relies on Spark treating the limit as partially pushed (so Spark
 * still applies the final limit on top of the scan) -- confirm against `SupportsPushDownLimit`.
 *
 * @param i the limit value to remember
 * @return always `true`
 */
override def pushLimit(i: Int): Boolean = {
limitSize = i
true
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,12 @@ import org.apache.spark.internal.Logging
import org.apache.spark.sql.connector.expressions.filter.Predicate
import org.apache.spark.sql.types.StructType

class DorisScanV2(config: DorisConfig, schema: StructType, filters: Array[Predicate]) extends AbstractDorisScan(config, schema) with Logging {
class DorisScanV2(config: DorisConfig, schema: StructType, filters: Array[Predicate], limit: Int) extends AbstractDorisScan(config, schema) with Logging {
/**
 * Compiles the pushed-down predicates into filter strings.
 *
 * A fresh `V2ExpressionBuilder` is created on every call, configured with the
 * `DORIS_FILTER_QUERY_IN_MAX_COUNT` option. Predicates for which the builder returns `null`
 * are dropped from the result.
 *
 * @return the compiled filter strings, in the same order as the predicates they came from
 */
override protected def compiledFilters(): Array[String] = {
  val inValueLengthLimit = config.getValue(DorisOptions.DORIS_FILTER_QUERY_IN_MAX_COUNT)
  val v2ExpressionBuilder = new V2ExpressionBuilder(inValueLengthLimit)
  // Option(...) turns a null build result into None, and flatMap discards it in the same
  // pass -- no Option.get and no intermediate arrays.
  filters.flatMap(predicate => Option[String](v2ExpressionBuilder.build(predicate)))
}

/** Supplies the `limit` constructor argument of this scan as the limit used for partition planning. */
override protected def getLimit: Int = limit
}

0 comments on commit 3dcc42b

Please sign in to comment.