Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[improve] improve datasource class load #141

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
import org.apache.commons.lang3.StringUtils;
import org.apache.commons.lang3.tuple.Pair;

import lombok.NonNull;
import lombok.SneakyThrows;
import lombok.extern.slf4j.Slf4j;

import java.io.File;
Expand All @@ -43,24 +45,25 @@
import java.util.Map;
import java.util.Set;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Supplier;

import static com.google.common.base.Preconditions.checkNotNull;

@Slf4j
public abstract class AbstractDataSourceClient implements DataSourceService {
private static final String ST_WEB_BASEDIR_PATH = "ST_WEB_BASEDIR_PATH";
// private ClassLoader datasourceClassLoader; // thradlocal
private ThreadLocal<ClassLoader> datasourceClassLoader = new ThreadLocal<>();
private final ThreadLocal<ClassLoader> datasourceClassLoader = new ThreadLocal<>();

private Map<String, DataSourcePluginInfo> supportedDataSourceInfo = new HashMap<>();
private final Map<String, DataSourcePluginInfo> supportedDataSourceInfo = new HashMap<>();

private Map<String, Integer> supportedDataSourceIndex = new HashMap<>();
private final Map<String, Integer> supportedDataSourceIndex = new HashMap<>();

protected List<DataSourcePluginInfo> supportedDataSources = new ArrayList<>();
private final List<DataSourcePluginInfo> supportedDataSources = new ArrayList<>();

private List<DataSourceChannel> dataSourceChannels = new ArrayList<>();
private final List<DataSourceChannel> dataSourceChannels = new ArrayList<>();

private Map<String, DataSourceChannel> classLoaderChannel = new HashMap<>();
private final Map<String, DataSourceChannel> classLoaderChannel = new HashMap<>();

protected AbstractDataSourceClient() {
AtomicInteger dataSourceIndex = new AtomicInteger();
Expand Down Expand Up @@ -99,7 +102,7 @@ protected AbstractDataSourceClient() {
.get(pluginName.toUpperCase())
.toString());
} catch (Exception e) {
log.warn("datasource " + pluginName + "is error" + ExceptionUtils.getMessage(e));
log.warn("datasource " + pluginName + " is error " + ExceptionUtils.getMessage(e));
}
Thread.currentThread().setContextClassLoader(contextClassLoader);
}
Expand All @@ -108,15 +111,19 @@ protected AbstractDataSourceClient() {
}
}

public Boolean isVirtualTableDatasource(String pluginName) {
log.info("pluginName: {}", pluginName);
return supportedDataSourceInfo.get(pluginName.toUpperCase()).getSupportVirtualTables();
}

@Override
public Boolean checkDataSourceConnectivity(
String pluginName, Map<String, String> dataSourceParams) {
updateClassLoader(pluginName);
boolean isConnect =
getDataSourceChannel(pluginName)
.checkDataSourceConnectivity(pluginName, dataSourceParams);
classLoaderRestore();
return isConnect;
return executeByCustomerClassLoader(
pluginName,
() ->
getDataSourceChannel(pluginName)
.checkDataSourceConnectivity(pluginName, dataSourceParams));
}

@Override
Expand All @@ -126,27 +133,28 @@ public List<DataSourcePluginInfo> listAllDataSources() {

protected DataSourceChannel getDataSourceChannel(String pluginName) {
checkNotNull(pluginName, "pluginName cannot be null");

// Integer index = supportedDataSourceIndex.get(pluginName.toUpperCase());
// if (index == null) {
// throw new DataSourceSDKException(
// "The %s plugin is not supported or plugin not exist.", pluginName);
// }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If these codes is useless, please delete them instead of commenting them.

return DatasourceLoadConfig.classLoaderChannel.get(pluginName.toUpperCase());
}

@Override
public OptionRule queryDataSourceFieldByName(String pluginName) {
updateClassLoader(pluginName);
OptionRule dataSourceOptions =
getDataSourceChannel(pluginName).getDataSourceOptions(pluginName);
classLoaderRestore();
return dataSourceOptions;
return executeByCustomerClassLoader(
pluginName,
() -> getDataSourceChannel(pluginName).getDataSourceOptions(pluginName));
}

@Override
public OptionRule queryMetadataFieldByName(String pluginName) {
updateClassLoader(pluginName);
OptionRule datasourceMetadataFieldsByDataSourceName =
getDataSourceChannel(pluginName)
.getDatasourceMetadataFieldsByDataSourceName(pluginName);
classLoaderRestore();
return datasourceMetadataFieldsByDataSourceName;
return executeByCustomerClassLoader(
pluginName,
() ->
getDataSourceChannel(pluginName)
.getDatasourceMetadataFieldsByDataSourceName(pluginName));
}

@Override
Expand All @@ -155,21 +163,18 @@ public List<String> getTables(
String databaseName,
Map<String, String> requestParams,
Map<String, String> options) {
updateClassLoader(pluginName);
List<String> tables =
getDataSourceChannel(pluginName)
.getTables(pluginName, requestParams, databaseName, options);
classLoaderRestore();
return tables;
return executeByCustomerClassLoader(
pluginName,
() ->
getDataSourceChannel(pluginName)
.getTables(pluginName, requestParams, databaseName, options));
}

@Override
public List<String> getDatabases(String pluginName, Map<String, String> requestParams) {
updateClassLoader(pluginName);
List<String> databases =
getDataSourceChannel(pluginName).getDatabases(pluginName, requestParams);
classLoaderRestore();
return databases;
return executeByCustomerClassLoader(
pluginName,
() -> getDataSourceChannel(pluginName).getDatabases(pluginName, requestParams));
}

@Override
Expand All @@ -178,12 +183,12 @@ public List<TableField> getTableFields(
Map<String, String> requestParams,
String databaseName,
String tableName) {
updateClassLoader(pluginName);
List<TableField> tableFields =
getDataSourceChannel(pluginName)
.getTableFields(pluginName, requestParams, databaseName, tableName);
classLoaderRestore();
return tableFields;
return executeByCustomerClassLoader(
pluginName,
() ->
getDataSourceChannel(pluginName)
.getTableFields(
pluginName, requestParams, databaseName, tableName));
}

@Override
Expand All @@ -192,12 +197,12 @@ public Map<String, List<TableField>> getTableFields(
Map<String, String> requestParams,
String databaseName,
List<String> tableNames) {
updateClassLoader(pluginName);
Map<String, List<TableField>> tableFields =
getDataSourceChannel(pluginName)
.getTableFields(pluginName, requestParams, databaseName, tableNames);
classLoaderRestore();
return tableFields;
return executeByCustomerClassLoader(
pluginName,
() ->
getDataSourceChannel(pluginName)
.getTableFields(
pluginName, requestParams, databaseName, tableNames));
}

@Override
Expand All @@ -207,30 +212,37 @@ public Pair<String, String> getTableSyncMaxValue(
String databaseName,
String tableName,
String updateFieldType) {
updateClassLoader(pluginName);
Pair<String, String> tableSyncMaxValue =
getDataSourceChannel(pluginName)
.getTableSyncMaxValue(
pluginName,
requestParams,
databaseName,
tableName,
updateFieldType);
classLoaderRestore();
return tableSyncMaxValue;
return executeByCustomerClassLoader(
pluginName,
() ->
getDataSourceChannel(pluginName)
.getTableSyncMaxValue(
pluginName,
requestParams,
databaseName,
tableName,
updateFieldType));
}

private ClassLoader getCustomClassloader(String pluginName) {
String getenv = System.getenv(ST_WEB_BASEDIR_PATH);
log.info("ST_WEB_BASEDIR_PATH is : " + getenv);
String libPath = StringUtils.isEmpty(getenv) ? "/datasource" : (getenv + "/datasource");

// String libPath = "/root/apache-seatunnel-web-2.4.7-WS-SNAPSHOT/datasource/";
File jarDirectory = new File(libPath);
File[] jarFiles =
jarDirectory.listFiles(
(dir, name) -> {
String pluginUpperCase = pluginName.toUpperCase();
String nameLowerCase = name.toLowerCase();
String pluginJar =
DatasourceLoadConfig.classLoaderJarName.get(pluginUpperCase);
if (StringUtils.isEmpty(pluginJar)) {
log.warn(
"classLoaderJarName get pluginUpperCase jar name is null : {} ",
pluginUpperCase);
}
if (pluginUpperCase.equals("KAFKA")) {
return !nameLowerCase.contains("kingbase")
&& nameLowerCase.startsWith(
Expand Down Expand Up @@ -294,10 +306,25 @@ private void classLoaderRestore() {

@Override
public Connection getConnection(String pluginName, Map<String, String> requestParams) {
updateClassLoader(pluginName);
Connection connection =
getDataSourceChannel(pluginName).getConnection(pluginName, requestParams);
classLoaderRestore();
return connection;
return executeByCustomerClassLoader(
pluginName,
() -> getDataSourceChannel(pluginName).getConnection(pluginName, requestParams));
}

/**
* Execute the given {@code Callable} within the {@link ClassLoader} of the current thread.
*
* @param supplier
* @param <T>
* @return
*/
@SneakyThrows
private <T> T executeByCustomerClassLoader(String pluginName, @NonNull Supplier<T> supplier) {
try {
updateClassLoader(pluginName);
return supplier.get();
} finally {
classLoaderRestore();
}
}
}