Skip to content

Commit

Permalink
[fix](connector) fix arrow deserialize issue due to data being incons…
Browse files Browse the repository at this point in the history
…istent with column order (apache#256)
  • Loading branch information
gnehil authored Jan 9, 2025
1 parent 38a60ff commit 20b1228
Show file tree
Hide file tree
Showing 5 changed files with 36 additions and 33 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,11 @@ public QueryPlan getQueryPlan(String database, String table, String sql) throws
throw new DorisException();
}
String entity = EntityUtils.toString(response.getEntity());
return MAPPER.readValue(extractEntity(entity, "data").traverse(), QueryPlan.class);
JsonNode dataJsonNode = extractEntity(entity, "data");
if (dataJsonNode.get("exception") != null) {
throw new DorisException("query plan failed, exception: " + dataJsonNode.get("exception").asText());
}
return MAPPER.readValue(dataJsonNode.traverse(), QueryPlan.class);
} catch (Exception e) {
throw new RuntimeException("query plan request failed", e);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -88,8 +88,9 @@ protected AbstractThriftReader(DorisReaderPartition partition) throws Exception
this.contextId = scanOpenResult.getContextId();
Schema schema = getDorisSchema();
this.dorisSchema = processDorisSchema(partition, schema);
logger.debug("origin thrift read Schema: " + schema + ", processed schema: " + dorisSchema);

if (logger.isDebugEnabled()) {
logger.debug("origin thrift read Schema: " + schema + ", processed schema: " + dorisSchema);
}
if (isAsync) {
int blockingQueueSize = config.getValue(DorisOptions.DORIS_DESERIALIZE_QUEUE_SIZE);
this.rowBatchQueue = new ArrayBlockingQueue<>(blockingQueueSize);
Expand Down Expand Up @@ -241,22 +242,21 @@ protected Schema processDorisSchema(DorisReaderPartition partition, final Schema
Schema tableSchema = frontend.getTableSchema(partition.getDatabase(), partition.getTable());
Map<String, Field> fieldTypeMap = tableSchema.getProperties().stream()
.collect(Collectors.toMap(Field::getName, Function.identity()));
Map<String, Field> scanTypeMap = originSchema.getProperties().stream()
.collect(Collectors.toMap(Field::getName, Function.identity()));
String[] readColumns = partition.getReadColumns();
List<Field> newFieldList = new ArrayList<>();
int offset = 0;
for (int i = 0; i < readColumns.length; i++) {
String readColumn = readColumns[i];
if (!fieldTypeMap.containsKey(readColumn) && readColumn.contains(" AS ")) {
for (String readColumn : readColumns) {
if (readColumn.contains(" AS ")) {
int asIdx = readColumn.indexOf(" AS ");
String realColumn = readColumn.substring(asIdx + 4).trim().replaceAll("`", "");
if (fieldTypeMap.containsKey(realColumn)
if (fieldTypeMap.containsKey(realColumn) && scanTypeMap.containsKey(realColumn)
&& ("BITMAP".equalsIgnoreCase(fieldTypeMap.get(realColumn).getType())
|| "HLL".equalsIgnoreCase(fieldTypeMap.get(realColumn).getType()))) {
newFieldList.add(new Field(realColumn, TPrimitiveType.VARCHAR.name(), null, 0, 0, null));
offset++;
}
} else {
newFieldList.add(originSchema.getProperties().get(i + offset));
newFieldList.add(scanTypeMap.get(readColumn.trim().replaceAll("`", "")));
}
}
processedSchema.setProperties(newFieldList);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,7 @@ public void convertArrowToRowBatch() throws DorisException {
FieldVector curFieldVector = fieldVectors.get(col);
MinorType mt = curFieldVector.getMinorType();

final String colName = schema.get(col).getName();
final String currentType = schema.get(col).getType();
switch (currentType) {
case "NULL_TYPE":
Expand All @@ -203,7 +204,7 @@ public void convertArrowToRowBatch() throws DorisException {
break;
case "BOOLEAN":
Preconditions.checkArgument(mt.equals(MinorType.BIT),
typeMismatchMessage(currentType, mt));
typeMismatchMessage(colName, currentType, mt));
BitVector bitVector = (BitVector) curFieldVector;
for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
Object fieldValue = bitVector.isNull(rowIndex) ? null : bitVector.get(rowIndex) != 0;
Expand All @@ -212,7 +213,7 @@ public void convertArrowToRowBatch() throws DorisException {
break;
case "TINYINT":
Preconditions.checkArgument(mt.equals(MinorType.TINYINT),
typeMismatchMessage(currentType, mt));
typeMismatchMessage(colName, currentType, mt));
TinyIntVector tinyIntVector = (TinyIntVector) curFieldVector;
for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
Object fieldValue = tinyIntVector.isNull(rowIndex) ? null : tinyIntVector.get(rowIndex);
Expand All @@ -221,7 +222,7 @@ public void convertArrowToRowBatch() throws DorisException {
break;
case "SMALLINT":
Preconditions.checkArgument(mt.equals(MinorType.SMALLINT),
typeMismatchMessage(currentType, mt));
typeMismatchMessage(colName, currentType, mt));
SmallIntVector smallIntVector = (SmallIntVector) curFieldVector;
for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
Object fieldValue = smallIntVector.isNull(rowIndex) ? null : smallIntVector.get(rowIndex);
Expand All @@ -230,7 +231,7 @@ public void convertArrowToRowBatch() throws DorisException {
break;
case "INT":
Preconditions.checkArgument(mt.equals(MinorType.INT),
typeMismatchMessage(currentType, mt));
typeMismatchMessage(colName, currentType, mt));
IntVector intVector = (IntVector) curFieldVector;
for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
Object fieldValue = intVector.isNull(rowIndex) ? null : intVector.get(rowIndex);
Expand All @@ -239,7 +240,7 @@ public void convertArrowToRowBatch() throws DorisException {
break;
case "BIGINT":
Preconditions.checkArgument(mt.equals(MinorType.BIGINT),
typeMismatchMessage(currentType, mt));
typeMismatchMessage(colName, currentType, mt));
BigIntVector bigIntVector = (BigIntVector) curFieldVector;
for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
Object fieldValue = bigIntVector.isNull(rowIndex) ? null : bigIntVector.get(rowIndex);
Expand All @@ -248,7 +249,7 @@ public void convertArrowToRowBatch() throws DorisException {
break;
case "LARGEINT":
Preconditions.checkArgument(mt.equals(MinorType.FIXEDSIZEBINARY) ||
mt.equals(MinorType.VARCHAR), typeMismatchMessage(currentType, mt));
mt.equals(MinorType.VARCHAR), typeMismatchMessage(colName, currentType, mt));
if (mt.equals(MinorType.FIXEDSIZEBINARY)) {
FixedSizeBinaryVector largeIntVector = (FixedSizeBinaryVector) curFieldVector;
for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
Expand Down Expand Up @@ -276,7 +277,7 @@ public void convertArrowToRowBatch() throws DorisException {
break;
case "IPV4":
Preconditions.checkArgument(mt.equals(MinorType.UINT4) || mt.equals(MinorType.INT),
typeMismatchMessage(currentType, mt));
typeMismatchMessage(colName, currentType, mt));
BaseIntVector ipv4Vector;
if (mt.equals(MinorType.INT)) {
ipv4Vector = (IntVector) curFieldVector;
Expand All @@ -291,7 +292,7 @@ public void convertArrowToRowBatch() throws DorisException {
break;
case "FLOAT":
Preconditions.checkArgument(mt.equals(MinorType.FLOAT4),
typeMismatchMessage(currentType, mt));
typeMismatchMessage(colName, currentType, mt));
Float4Vector float4Vector = (Float4Vector) curFieldVector;
for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
Object fieldValue = float4Vector.isNull(rowIndex) ? null : float4Vector.get(rowIndex);
Expand All @@ -301,7 +302,7 @@ public void convertArrowToRowBatch() throws DorisException {
case "TIME":
case "DOUBLE":
Preconditions.checkArgument(mt.equals(MinorType.FLOAT8),
typeMismatchMessage(currentType, mt));
typeMismatchMessage(colName, currentType, mt));
Float8Vector float8Vector = (Float8Vector) curFieldVector;
for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
Object fieldValue = float8Vector.isNull(rowIndex) ? null : float8Vector.get(rowIndex);
Expand All @@ -310,7 +311,7 @@ public void convertArrowToRowBatch() throws DorisException {
break;
case "BINARY":
Preconditions.checkArgument(mt.equals(MinorType.VARBINARY),
typeMismatchMessage(currentType, mt));
typeMismatchMessage(colName, currentType, mt));
VarBinaryVector varBinaryVector = (VarBinaryVector) curFieldVector;
for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
Object fieldValue = varBinaryVector.isNull(rowIndex) ? null : varBinaryVector.get(rowIndex);
Expand All @@ -319,7 +320,7 @@ public void convertArrowToRowBatch() throws DorisException {
break;
case "DECIMAL":
Preconditions.checkArgument(mt.equals(MinorType.VARCHAR),
typeMismatchMessage(currentType, mt));
typeMismatchMessage(colName, currentType, mt));
VarCharVector varCharVectorForDecimal = (VarCharVector) curFieldVector;
for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
if (varCharVectorForDecimal.isNull(rowIndex)) {
Expand All @@ -343,7 +344,7 @@ public void convertArrowToRowBatch() throws DorisException {
case "DECIMAL64":
case "DECIMAL128I":
Preconditions.checkArgument(mt.equals(MinorType.DECIMAL),
typeMismatchMessage(currentType, mt));
typeMismatchMessage(colName, currentType, mt));
DecimalVector decimalVector = (DecimalVector) curFieldVector;
for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
if (decimalVector.isNull(rowIndex)) {
Expand All @@ -357,7 +358,7 @@ public void convertArrowToRowBatch() throws DorisException {
case "DATE":
case "DATEV2":
Preconditions.checkArgument(mt.equals(MinorType.VARCHAR)
|| mt.equals(MinorType.DATEDAY), typeMismatchMessage(currentType, mt));
|| mt.equals(MinorType.DATEDAY), typeMismatchMessage(colName, currentType, mt));
if (mt.equals(MinorType.VARCHAR)) {
VarCharVector date = (VarCharVector) curFieldVector;
for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
Expand Down Expand Up @@ -417,7 +418,7 @@ public void convertArrowToRowBatch() throws DorisException {
case "JSONB":
case "VARIANT":
Preconditions.checkArgument(mt.equals(MinorType.VARCHAR),
typeMismatchMessage(currentType, mt));
typeMismatchMessage(colName, currentType, mt));
VarCharVector varCharVector = (VarCharVector) curFieldVector;
for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
if (varCharVector.isNull(rowIndex)) {
Expand All @@ -430,7 +431,7 @@ public void convertArrowToRowBatch() throws DorisException {
break;
case "IPV6":
Preconditions.checkArgument(mt.equals(MinorType.VARCHAR),
typeMismatchMessage(currentType, mt));
typeMismatchMessage(colName, currentType, mt));
VarCharVector ipv6VarcharVector = (VarCharVector) curFieldVector;
for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
if (ipv6VarcharVector.isNull(rowIndex)) {
Expand All @@ -444,7 +445,7 @@ public void convertArrowToRowBatch() throws DorisException {
break;
case "ARRAY":
Preconditions.checkArgument(mt.equals(MinorType.LIST),
typeMismatchMessage(currentType, mt));
typeMismatchMessage(colName, currentType, mt));
ListVector listVector = (ListVector) curFieldVector;
for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
if (listVector.isNull(rowIndex)) {
Expand All @@ -457,7 +458,7 @@ public void convertArrowToRowBatch() throws DorisException {
break;
case "MAP":
Preconditions.checkArgument(mt.equals(MinorType.MAP),
typeMismatchMessage(currentType, mt));
typeMismatchMessage(colName, currentType, mt));
MapVector mapVector = (MapVector) curFieldVector;
UnionMapReader reader = mapVector.getReader();
for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
Expand All @@ -476,7 +477,7 @@ public void convertArrowToRowBatch() throws DorisException {
break;
case "STRUCT":
Preconditions.checkArgument(mt.equals(MinorType.STRUCT),
typeMismatchMessage(currentType, mt));
typeMismatchMessage(colName, currentType, mt));
StructVector structVector = (StructVector) curFieldVector;
for (int rowIndex = 0; rowIndex < rowCountInOneBatch; rowIndex++) {
if (structVector.isNull(rowIndex)) {
Expand Down Expand Up @@ -508,9 +509,9 @@ public List<Object> next() {
return rowBatch.get(offsetInRowBatch++).getCols();
}

private String typeMismatchMessage(final String sparkType, final MinorType arrowType) {
final String messageTemplate = "Spark type is %1$s, but arrow type is %2$s.";
return String.format(messageTemplate, sparkType, arrowType.name());
private String typeMismatchMessage(final String columnName, final String sparkType, final MinorType arrowType) {
final String messageTemplate = "Spark type for column %1$s is %2$s, but arrow type is %3$s.";
return String.format(messageTemplate, columnName, sparkType, arrowType.name());
}

public int getReadRowCount() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@ private[spark] class ScalaDorisRowRDDIterator(context: TaskContext,
extends AbstractDorisRDDIterator[Row](context, partition) {

override def initReader(config: DorisConfig): Unit = {
config.setProperty(DorisOptions.DORIS_READ_FIELDS, schema.map(f => s"`${f.name}`").mkString(","))
config.getValue(DorisOptions.READ_MODE).toLowerCase match {
case "thrift" => config.setProperty(DorisOptions.DORIS_VALUE_READER_CLASS, classOf[DorisRowThriftReader].getName)
case "arrow" => config.setProperty(DorisOptions.DORIS_VALUE_READER_CLASS, classOf[DorisRowFlightSqlReader].getName)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@ object SchemaConvertors {
def convertToSchema(tscanColumnDescs: Seq[TScanColumnDesc]): Schema = {
val schema = new Schema(tscanColumnDescs.length)
tscanColumnDescs.foreach(desc => {
// println(desc.getName + " " + desc.getType.name())
schema.put(new Field(desc.getName, desc.getType.name, "", 0, 0, ""))
})
schema
Expand Down

0 comments on commit 20b1228

Please sign in to comment.