diff --git a/Makefile.arm64 b/Makefile.arm64
index fccc0d0d0f..46e4baefc4 100644
--- a/Makefile.arm64
+++ b/Makefile.arm64
@@ -30,6 +30,11 @@ FCOMMON_OPT += -march=armv8-a+sve
 endif
 endif
 
+ifeq ($(CORE), ARMV9SME)
+CCOMMON_OPT += -march=armv9-a+sve2+sme
+FCOMMON_OPT += -march=armv9-a+sve2
+endif
+
 ifeq ($(CORE), CORTEXA53)
 CCOMMON_OPT += -march=armv8-a -mtune=cortex-a53
 ifneq ($(F_COMPILER), NAG)
diff --git a/Makefile.system b/Makefile.system
index 29ea819f13..14830eb4e2 100644
--- a/Makefile.system
+++ b/Makefile.system
@@ -420,6 +420,7 @@ ifeq ($(ARCH), arm64)
 export MACOSX_DEPLOYMENT_TARGET=11.0
 ifeq ($(C_COMPILER), GCC)
 export NO_SVE = 1
+export NO_SME = 1
 endif
 else
 export MACOSX_DEPLOYMENT_TARGET=10.8
@@ -709,6 +710,9 @@ DYNAMIC_CORE += NEOVERSEN2
 DYNAMIC_CORE += ARMV8SVE
 DYNAMIC_CORE += A64FX
 endif
+ifneq ($(NO_SME), 1)
+DYNAMIC_CORE += ARMV9SME
+endif
 DYNAMIC_CORE += THUNDERX
 DYNAMIC_CORE += THUNDERX2T99
 DYNAMIC_CORE += TSV110
@@ -1474,6 +1478,10 @@ ifeq ($(NO_SVE), 1)
 CCOMMON_OPT += -DNO_SVE
 endif
 
+ifeq ($(NO_SME), 1)
+CCOMMON_OPT += -DNO_SME
+endif
+
 ifdef SMP
 CCOMMON_OPT += -DSMP_SERVER
diff --git a/TargetList.txt b/TargetList.txt
index 25eeddfb00..232e12ffa6 100644
--- a/TargetList.txt
+++ b/TargetList.txt
@@ -111,6 +111,7 @@ THUNDERX3T110
 VORTEX
 A64FX
 ARMV8SVE
+ARMV9SME
 FT2000
 
 9.System Z:
diff --git a/cmake/arch.cmake b/cmake/arch.cmake
index 27ba6f8727..ec91a2d598 100644
--- a/cmake/arch.cmake
+++ b/cmake/arch.cmake
@@ -44,9 +44,21 @@ endif ()
 
 if (DYNAMIC_ARCH)
   if (ARM64)
-    set(DYNAMIC_CORE ARMV8 CORTEXA53 CORTEXA57 THUNDERX THUNDERX2T99 TSV110 EMAG8180 NEOVERSEN1 THUNDERX3T110)
-    if (${CMAKE_C_COMPILER_VERSION} VERSION_GREATER 9.99)
-      set(DYNAMIC_CORE ${DYNAMIC_CORE} NEOVERSEV1 NEOVERSEN2 ARMV8SVE A64FX)
+    set(DYNAMIC_CORE ARMV8 CORTEXA53 CORTEXA57 THUNDERX THUNDERX2T99 TSV110 EMAG8180 NEOVERSEN1 THUNDERX3T110)
+    if (${CMAKE_C_COMPILER_ID} STREQUAL "GNU")
+      if (${CMAKE_C_COMPILER_VERSION} VERSION_GREATER_EQUAL 10) # SVE ACLE supported in GCC >= 10
+        set(DYNAMIC_CORE ${DYNAMIC_CORE} NEOVERSEV1 NEOVERSEN2 ARMV8SVE A64FX)
+      endif ()
+      if (${CMAKE_C_COMPILER_VERSION} VERSION_GREATER_EQUAL 14) # SME ACLE supported in GCC >= 14
+        set(DYNAMIC_CORE ${DYNAMIC_CORE} ARMV9SME)
+      endif ()
+    elseif (${CMAKE_C_COMPILER_ID} MATCHES "Clang")
+      if (${CMAKE_C_COMPILER_VERSION} VERSION_GREATER_EQUAL 11) # SVE ACLE supported in LLVM >= 11
+        set(DYNAMIC_CORE ${DYNAMIC_CORE} NEOVERSEV1 NEOVERSEN2 ARMV8SVE A64FX)
+      endif ()
+      if (${CMAKE_C_COMPILER_VERSION} VERSION_GREATER_EQUAL 19) # SME ACLE supported in LLVM >= 19
+        set(DYNAMIC_CORE ${DYNAMIC_CORE} ARMV9SME)
+      endif ()
     endif ()
     if (DYNAMIC_LIST)
       set(DYNAMIC_CORE ARMV8 ${DYNAMIC_LIST})
diff --git a/cmake/cc.cmake b/cmake/cc.cmake
index 775239e1cd..5e9c5a8c42 100644
--- a/cmake/cc.cmake
+++ b/cmake/cc.cmake
@@ -238,6 +238,12 @@ if (${CORE} STREQUAL ARMV8SVE)
   endif ()
 endif ()
 
+if (${CORE} STREQUAL ARMV9SME)
+  if (NOT DYNAMIC_ARCH)
+    set (CCOMMON_OPT "${CCOMMON_OPT} -march=armv9-a+sme")
+  endif ()
+endif ()
+
 if (${CORE} STREQUAL CORTEXA510)
   if (NOT DYNAMIC_ARCH)
     set (CCOMMON_OPT "${CCOMMON_OPT} -march=armv8-a+sve")
diff --git a/cmake/prebuild.cmake b/cmake/prebuild.cmake
index 53a78d782f..f6ca73b7b6 100644
--- a/cmake/prebuild.cmake
+++ b/cmake/prebuild.cmake
@@ -1014,7 +1014,7 @@ endif ()
     set(ZGEMM_UNROLL_M 4)
     set(ZGEMM_UNROLL_N 4)
     set(SYMV_P 16)
-  elseif ("${TCORE}" STREQUAL "NEOVERSEN2")
+  elseif ("${TCORE}" STREQUAL "NEOVERSEN2" OR "${TCORE}" STREQUAL "ARMV9SME")
     file(APPEND ${TARGET_CONF_TEMP}
       "#define L1_CODE_SIZE\t65536\n"
      "#define L1_CODE_LINESIZE\t64\n"
diff --git a/cmake/system.cmake b/cmake/system.cmake
index 6b891ca0ef..6e8055e924 100644
--- a/cmake/system.cmake
+++ b/cmake/system.cmake
@@ -310,6 +310,9 @@ if (${TARGET} STREQUAL NEOVERSEV1)
       set (KERNEL_DEFINITIONS "${KERNEL_DEFINITIONS} -march=armv8.2-a+sve")
     endif()
   endif()
+  if (${TARGET} STREQUAL ARMV9SME)
+    set (KERNEL_DEFINITIONS "${KERNEL_DEFINITIONS} -march=armv9-a+sme -O3")
+  endif()
   if (${TARGET} STREQUAL A64FX)
     if (${CMAKE_C_COMPILER_ID} STREQUAL "PGI" AND NOT NO_SVE)
       set (KERNEL_DEFINITIONS "${KERNEL_DEFINITIONS} -Msve-intrinsics -march=armv8.2-a+sve -mtune=a64fx")
@@ -382,6 +385,8 @@ if (NEED_PIC)
   if (NOT NOFORTRAN)
     if (${F_COMPILER} STREQUAL "SUN")
       set(FCOMMON_OPT "${FCOMMON_OPT} -pic")
+    elseif (${F_COMPILER} STREQUAL "NAGFOR")
+      set(FCOMMON_OPT "${FCOMMON_OPT} -PIC")
     else ()
       set(FCOMMON_OPT "${FCOMMON_OPT} -fPIC")
     endif ()
@@ -640,17 +645,17 @@ if (${CMAKE_SYSTEM_NAME} STREQUAL "Windows")
 endif ()
 
 if (CMAKE_Fortran_COMPILER)
-if ("${F_COMPILER}" STREQUAL "NAG" OR "${F_COMPILER}" STREQUAL "CRAY" OR CMAKE_Fortran_COMPILER_ID MATCHES "LLVMFlang.*")
-  set(FILTER_FLAGS "-msse3;-mssse3;-msse4.1;-mavx;-mavx2,-mskylake-avx512")
-  if (CMAKE_Fortran_COMPILER_ID MATCHES "LLVMFlang.*")
-message(STATUS "removing fortran flags")
-    set(FILTER_FLAGS "${FILTER_FLAGS};-m32;-m64")
+  if ("${F_COMPILER}" STREQUAL "NAGFOR" OR "${F_COMPILER}" STREQUAL "CRAY" OR CMAKE_Fortran_COMPILER_ID MATCHES "LLVMFlang.*")
+    set(FILTER_FLAGS "-msse3;-mssse3;-msse4.1;-mavx;-mavx2,-mskylake-avx512")
+    if (CMAKE_Fortran_COMPILER_ID MATCHES "LLVMFlang.*")
+      message(STATUS "removing fortran flags")
+      set(FILTER_FLAGS "${FILTER_FLAGS};-m32;-m64")
+    endif ()
+    foreach (FILTER_FLAG ${FILTER_FLAGS})
+      string(REPLACE ${FILTER_FLAG} "" LAPACK_FFLAGS ${LAPACK_FFLAGS})
+      string(REPLACE ${FILTER_FLAG} "" LAPACK_FPFLAGS ${LAPACK_FPFLAGS})
+    endforeach ()
   endif ()
-  foreach (FILTER_FLAG ${FILTER_FLAGS})
-    string(REPLACE ${FILTER_FLAG} "" LAPACK_FFLAGS ${LAPACK_FFLAGS})
-    string(REPLACE ${FILTER_FLAG} "" LAPACK_FPFLAGS ${LAPACK_FPFLAGS})
-  endforeach ()
-endif ()
 endif ()
 
 if ("${F_COMPILER}" STREQUAL "GFORTRAN")
diff --git a/common.h b/common.h
index b8bac1ad27..766b89cf74 100644
--- a/common.h
+++ b/common.h
@@ -696,6 +696,7 @@ void gotoblas_profile_init(void);
 void gotoblas_profile_quit(void);
 
 int support_avx512(void);
+int support_sme1(void);
 
 #ifdef USE_OPENMP
diff --git a/common_arm64.h b/common_arm64.h
index 595a01995a..5856898a2b 100644
--- a/common_arm64.h
+++ b/common_arm64.h
@@ -175,7 +175,7 @@ static inline int blas_quickdivide(blasint x, blasint y){
 #define HUGE_PAGESIZE ( 4 << 20)
 
 #ifndef BUFFERSIZE
-#if defined(NEOVERSEN1) || defined(NEOVERSEN2) || defined(NEOVERSEV1) || defined(A64FX) || defined(ARMV8SVE)
+#if defined(NEOVERSEN1) || defined(NEOVERSEN2) || defined(NEOVERSEV1) || defined(A64FX) || defined(ARMV8SVE) || defined(ARMV9SME)
 #define BUFFER_SIZE (32 << 22)
 #else
 #define BUFFER_SIZE (32 << 20)
diff --git a/common_param.h b/common_param.h
index c082d248e8..2914a8963e 100644
--- a/common_param.h
+++ b/common_param.h
@@ -221,6 +221,11 @@ BLASLONG (*ismin_k) (BLASLONG, float *, BLASLONG);
   void (*sgemm_direct) (BLASLONG, BLASLONG, BLASLONG, float *, BLASLONG , float *, BLASLONG , float * , BLASLONG);
   int (*sgemm_direct_performant) (BLASLONG M, BLASLONG N, BLASLONG K);
 #endif
+#ifdef ARCH_ARM64
+  void (*sgemm_direct) (BLASLONG, BLASLONG, BLASLONG, float *, BLASLONG , float *, BLASLONG , float * , BLASLONG);
+  int (*sgemm_direct_performant) (BLASLONG M, BLASLONG N, BLASLONG K);
+#endif
+
   int (*sgemm_kernel )(BLASLONG, BLASLONG, BLASLONG, float, float *, float *, float *, BLASLONG);
   int (*sgemm_beta )(BLASLONG, BLASLONG, BLASLONG, float, float *, BLASLONG, float *, BLASLONG, float *, BLASLONG);
diff --git a/common_s.h b/common_s.h
index fdd80b62f6..af9d940ae1 100644
--- a/common_s.h
+++ b/common_s.h
@@ -213,9 +213,9 @@
 #ifdef ARCH_X86_64
 #define SGEMM_DIRECT_PERFORMANT gotoblas -> sgemm_direct_performant
 #define SGEMM_DIRECT gotoblas -> sgemm_direct
-#else
+#elif defined(ARCH_ARM64)
 #define SGEMM_DIRECT_PERFORMANT sgemm_direct_performant
-#define SGEMM_DIRECT sgemm_direct
+#define SGEMM_DIRECT gotoblas -> sgemm_direct
 #endif
 
 #define SGEMM_ONCOPY gotoblas -> sgemm_oncopy
diff --git a/driver/others/dynamic_arm64.c b/driver/others/dynamic_arm64.c
index dc88d816fb..8e2963b5a3 100644
--- a/driver/others/dynamic_arm64.c
+++ b/driver/others/dynamic_arm64.c
@@ -115,6 +115,11 @@ extern gotoblas_t gotoblas_ARMV8SVE;
 #else
 #define gotoblas_ARMV8SVE gotoblas_ARMV8
 #endif
+#ifdef DYN_ARMV9SME
+extern gotoblas_t gotoblas_ARMV9SME;
+#else
+#define gotoblas_ARMV9SME gotoblas_ARMV8
+#endif
 #ifdef DYN_CORTEX_A55
 extern gotoblas_t gotoblas_CORTEXA55;
 #else
@@ -148,6 +153,13 @@ extern gotoblas_t gotoblas_A64FX;
 #define gotoblas_ARMV8SVE gotoblas_ARMV8
 #define gotoblas_A64FX gotoblas_ARMV8
 #endif
+
+#ifndef NO_SME
+extern gotoblas_t gotoblas_ARMV9SME;
+#else
+#define gotoblas_ARMV9SME gotoblas_ARMV8SVE
+#endif
+
 extern gotoblas_t gotoblas_THUNDERX3T110;
 #endif
 #define gotoblas_NEOVERSEV2 gotoblas_NEOVERSEV1
@@ -393,6 +405,13 @@ static gotoblas_t *get_coretype(void) {
     snprintf(coremsg, 128, "Unknown CPU model - implementer %x part %x\n",implementer,part);
     openblas_warning(1, coremsg);
   }
+
+#if !defined(NO_SME) && defined(HWCAP2_SME)
+  if ((getauxval(AT_HWCAP2) & HWCAP2_SME)) {
+    return &gotoblas_ARMV9SME;
+  }
+#endif
+
 #ifndef NO_SVE
   if ((getauxval(AT_HWCAP) & HWCAP_SVE)) {
     return &gotoblas_ARMV8SVE;
@@ -443,3 +462,15 @@ void gotoblas_dynamic_init(void) {
 void gotoblas_dynamic_quit(void) {
   gotoblas = NULL;
 }
+
+int support_sme1(void) {
+  int ret = 0;
+
+#if (defined OS_LINUX || defined OS_ANDROID) && defined(HWCAP2_SME)
+  if (getauxval(AT_HWCAP2) & HWCAP2_SME) {
+    ret = 1;
+  }
+#endif
+  return ret;
+}
diff --git a/getarch.c b/getarch.c
index 826dd1ce0a..b51c3ed643 100644
--- a/getarch.c
+++ b/getarch.c
@@ -1289,6 +1289,19 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 #define CORENAME "ARMV8SVE"
 #endif
 
+#ifdef FORCE_ARMV9SME
+#define FORCE
+#define ARCHITECTURE "ARM64"
+#define SUBARCHITECTURE "ARMV9SME"
+#define SUBDIRNAME "arm64"
+#define ARCHCONFIG "-DARMV9SME " \
+       "-DL1_DATA_SIZE=32768 -DL1_DATA_LINESIZE=64 " \
+       "-DL2_SIZE=262144 -DL2_LINESIZE=64 " \
+       "-DDTB_DEFAULT_ENTRIES=64 -DDTB_SIZE=4096 -DL2_ASSOCIATIVE=32 " \
+       "-DHAVE_VFPV4 -DHAVE_VFPV3 -DHAVE_VFP -DHAVE_NEON -DHAVE_SVE -DHAVE_SME -DARMV8 -DARMV9"
+#define LIBNAME "armv9sme"
+#define CORENAME "ARMV9SME"
+#endif
 
 #ifdef FORCE_ARMV8
 #define FORCE
diff --git a/interface/gemm.c b/interface/gemm.c
index 576e94593c..a19588f493 100644
--- a/interface/gemm.c
+++ b/interface/gemm.c
@@ -45,6 +45,7 @@
 #include "functable.h"
 #endif
 
+
 #ifndef COMPLEX
 #define SMP_THRESHOLD_MIN 65536.0
 #ifdef XDOUBLE
@@ -85,6 +86,7 @@
 #define GEMM_MULTITHREAD_THRESHOLD 4
 #endif
 
+
 static int (*gemm[])(blas_arg_t *, BLASLONG *, BLASLONG *, IFLOAT *, IFLOAT *, BLASLONG) = {
 #ifndef GEMM3M
   GEMM_NN, GEMM_TN, GEMM_RN, GEMM_CN,
@@ -347,17 +349,26 @@ void CNAME(enum CBLAS_ORDER order, enum CBLAS_TRANSPOSE TransA, enum CBLAS_TRANS
   int nodes;
 #endif
 
+
   PRINT_DEBUG_CNAME;
 
 #if !defined(COMPLEX) && !defined(DOUBLE) && !defined(BFLOAT16) && defined(USE_SGEMM_KERNEL_DIRECT)
-#ifdef DYNAMIC_ARCH
+#if defined(DYNAMIC_ARCH) && defined(ARCH_X86_64)
   if (support_avx512() )
-#endif
+
   if (beta == 0 && alpha == 1.0 && order == CblasRowMajor && TransA == CblasNoTrans && TransB == CblasNoTrans && SGEMM_DIRECT_PERFORMANT(m,n,k)) {
     SGEMM_DIRECT(m, n, k, a, lda, b, ldb, c, ldc);
     return;
   }
-
+#endif
+#if defined(DYNAMIC_ARCH) && defined(ARCH_ARM64)
+  if (support_sme1()){
+    if (beta == 0 && alpha == 1.0 && order == CblasRowMajor && TransA == CblasNoTrans && TransB == CblasNoTrans) {
+      SGEMM_DIRECT(m, n, k, a, lda, b, ldb, c, ldc);
+      return;
+    }
+  }
+#endif
 #endif
 
 #ifndef COMPLEX
diff --git a/kernel/Makefile b/kernel/Makefile
index 3f9afd3fa1..454968797f 100644
--- a/kernel/Makefile
+++ b/kernel/Makefile
@@ -24,7 +24,11 @@ ifdef NO_AVX2
 AVX2OPT=
 endif
 
+
 ifdef TARGET_CORE
+ifeq ($(TARGET_CORE), ARMV9SME)
+override CFLAGS += -march=armv9-a+sve2+sme
+endif
 ifeq ($(TARGET_CORE), SAPPHIRERAPIDS)
 override CFLAGS += -DBUILD_KERNEL -DTABLE_NAME=gotoblas_$(TARGET_CORE)
 ifeq (1, $(filter 1,$(GCCVERSIONGTEQ11) $(CLANGVERSIONGTEQ12)))
diff --git a/kernel/Makefile.L3 b/kernel/Makefile.L3
index ed1c74ecff..d00028d12b 100644
--- a/kernel/Makefile.L3
+++ b/kernel/Makefile.L3
@@ -24,6 +24,7 @@ endif
 
 ifeq ($(ARCH), arm64)
 USE_TRMM = 1
+USE_DIRECT_SGEMM = 1
 endif
 
 ifeq ($(ARCH), riscv64)
@@ -95,9 +96,16 @@ endif
 
 ifdef USE_DIRECT_SGEMM
 ifndef SGEMMDIRECTKERNEL
+ifeq ($(ARCH), x86_64)
 SGEMMDIRECTKERNEL = sgemm_direct_skylakex.c
 SGEMMDIRECTPERFORMANT = sgemm_direct_performant.c
 endif
+ifeq ($(ARCH), arm64)
+ifdef HAVE_SME
+SGEMMDIRECTKERNEL = sgemm_direct_arm64_sme1.c
+endif
+endif
+endif
 endif
 
 ifeq ($(BUILD_BFLOAT16), 1)
@@ -128,9 +136,19 @@ SKERNELOBJS += \
 	$(SGEMMONCOPYOBJ) $(SGEMMOTCOPYOBJ)
 
 ifdef USE_DIRECT_SGEMM
+ifeq ($(ARCH), x86_64)
 SKERNELOBJS += \
 	sgemm_direct$(TSUFFIX).$(SUFFIX) \
-	sgemm_direct_performant$(TSUFFIX).$(SUFFIX)
+	sgemm_direct_performant$(TSUFFIX).$(SUFFIX)
+endif
+ifeq ($(ARCH), arm64)
+ifdef HAVE_SME
+SKERNELOBJS += \
+	sgemm_direct.$(SUFFIX) \
+	sgemm_direct_sme1.$(SUFFIX) \
+	sgemm_direct_sme1_preprocess.$(SUFFIX)
+endif
+endif
 endif
 endif
 
@@ -809,11 +827,23 @@ else
 endif
 
 ifdef USE_DIRECT_SGEMM
+ifeq ($(ARCH), x86_64)
 $(KDIR)sgemm_direct_performant$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMMDIRECTPERFORMANT)
 	$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@
 $(KDIR)sgemm_direct$(TSUFFIX).$(SUFFIX) : $(KERNELDIR)/$(SGEMMDIRECTKERNEL)
 	$(CC) $(CFLAGS) -c -UDOUBLE -UCOMPLEX $< -o $@
 endif
+ifeq ($(ARCH), arm64)
+ifdef HAVE_SME
+$(KDIR)sgemm_direct_sme1.$(SUFFIX) : $(KERNELDIR)/sgemm_direct_sme1.S
+	$(CC) $(CFLAGS) -c $(KERNELDIR)/sgemm_direct_sme1.S -UDOUBLE -UCOMPLEX -o $@
+$(KDIR)sgemm_direct_sme1_preprocess.$(SUFFIX) : $(KERNELDIR)/sgemm_direct_sme1_preprocess.S
+	$(CC) $(CFLAGS) -c $(KERNELDIR)/sgemm_direct_sme1_preprocess.S -UDOUBLE -UCOMPLEX -o $@
+$(KDIR)sgemm_direct.$(SUFFIX) : $(KERNELDIR)/sgemm_direct_arm64_sme1.c
+	$(CC) $(CFLAGS) -c $(KERNELDIR)/sgemm_direct_arm64_sme1.c -UDOUBLE -UCOMPLEX -o $@
+endif
+endif
+endif
 
 ifeq ($(BUILD_BFLOAT16), 1)
diff --git a/kernel/arm64/KERNEL.ARMV9SME b/kernel/arm64/KERNEL.ARMV9SME
new file mode 100644
index 0000000000..dc333d8298
--- /dev/null
+++ b/kernel/arm64/KERNEL.ARMV9SME
@@ -0,0 +1,3 @@
+include $(KERNELDIR)/KERNEL.ARMV8SVE
+
+
diff --git a/kernel/arm64/sgemm_direct_arm64_sme1.c b/kernel/arm64/sgemm_direct_arm64_sme1.c
new file mode 100644
index 0000000000..ed4ad93a24
--- /dev/null
+++ b/kernel/arm64/sgemm_direct_arm64_sme1.c
@@ -0,0 +1,62 @@
+/*
+   Copyright (c) 2025 Qualcomm Innovation Center, Inc. All rights reserved.
+   SPDX-License-Identifier: BSD-3-Clause-Clear
+*/
+
+#include "common.h"
+#include <stdlib.h>
+#include <stdint.h>
+#include <math.h>
+
+#if defined(HAVE_SME)
+
+/* Function prototypes */
+void sgemm_direct_sme1_preprocess(uint64_t nbr, uint64_t nbc,\
+                                  const float * restrict a, float * a_mod);
+void sgemm_direct_sme1_2VLx2VL(uint64_t m, uint64_t k, uint64_t n,\
+                               const float * matLeft,\
+                               const float * restrict matRight,\
+                               float * restrict matResult);
+
+/* Function definitions */
+uint64_t sve_cntw() {
+  uint64_t cnt;
+  asm volatile(
+    "rdsvl  %[res], #1\n"
+    "lsr    %[res], %[res], #2\n"
+    : [res] "=r" (cnt) ::
+  );
+  return cnt;
+}
+
+/*void sgemm_kernel_direct (BLASLONG M, BLASLONG N, BLASLONG K,\
+       float * __restrict A, BLASLONG strideA, float * __restrict B,\
+       BLASLONG strideB , float * __restrict R, BLASLONG strideR)
+*/
+void CNAME (BLASLONG M, BLASLONG N, BLASLONG K, float * __restrict A,\
+            BLASLONG strideA, float * __restrict B, BLASLONG strideB ,\
+            float * __restrict R, BLASLONG strideR){
+
+  uint64_t m_mod, vl_elms;
+
+  vl_elms = sve_cntw();
+
+  /* Round M up to a multiple of SVLs (SVE vector length in FP32 words) */
+  m_mod = ceil((double)M/(double)vl_elms) * vl_elms;
+
+  float *A_mod = (float *) malloc(m_mod*K*sizeof(float));
+
+  /* Pre-process the left matrix to make it suitable for
+     matrix sum of outer-product calculation
+  */
+  sgemm_direct_sme1_preprocess(M, K, A, A_mod);
+
+  /* Calculate C = A*B */
+  sgemm_direct_sme1_2VLx2VL(M, K, N, A_mod, B, R);
+
+  free(A_mod);
+}
+#else
+
+void CNAME (BLASLONG M, BLASLONG N, BLASLONG K, float * __restrict A, BLASLONG strideA, float * __restrict B, BLASLONG strideB , float * __restrict R, BLASLONG strideR)
+{}
+#endif
diff --git a/kernel/arm64/sgemm_direct_sme1.S b/kernel/arm64/sgemm_direct_sme1.S
new file mode 100644
index 0000000000..8c0a173f3d
--- /dev/null
+++ b/kernel/arm64/sgemm_direct_sme1.S
@@ -0,0 +1,228 @@
+/*
+   Copyright (c) 2025 Qualcomm Innovation Center, Inc. All rights reserved.
+   SPDX-License-Identifier: BSD-3-Clause-Clear
+*/
+
+/*--------------------------------------------------------------------------
+ * SME1 based Matrix multiplication code for FP32 input matrices to FP32
+ * output matrix
+ * C = A*B
+ * A: Left input matrix of dimension M x K
+ * B: Right input matrix of dimension K x N
+ * C: Result matrix of dimension M x N
+ *
+ * Usage of function:
+ * sgemm_direct_sme1_2VLx2VL( uint64_t M , uint64_t K, uint64_t N,\
+       const float * restrict A_base,\
+       const float * restrict B_base,\
+       float * restrict C_base);
+----------------------------------------------------------------------------*/
+
+#define M       x0  //M dimension
+#define K       x1  //K dimension
+#define N       x2  //N dimension
+#define A_base  x3  //Pointer to left matrix(A)
+#define B_base  x4  //Pointer to right matrix(B)
+#define C_base  x5  //Pointer to result matrix(C)
+#define Aptr    x6  //Pointer to traverse A
+#define Aptr_end x7 //Pointer to end of row of A
+#define Cptr    x8  //Pointer to traverse C
+#define Cptr0   x9  //2nd Pointer to traverse C
+#define Cptr1   x10 //3rd Pointer to traverse C
+#define Bptr    x11 //Pointer to traverse B
+#define Bptr0   x12 //2nd Pointer to traverse B
+#define N_exit  x14 //Exit condition for N loop
+#define K_exit  x15 //Exit condition for K loop
+#define M_cntr  x16 //M loop counter
+#define C1      x17 //Constant1: N*(SVLs+1); SVLs = No. of 32-bit elements
+#define C2      x18 //Constant2: N + SVLs
+#define C3      x19 //Constant3: K*SVLs + SVLs
+#define C4      x20 //Constant4: SVLs-2
+#define C5      x21 //Constant5: K*SVLs
+#define C6      x22 //Constant6: N*SVLs
+
+	.text
+	.global sgemm_direct_sme1_2VLx2VL
+
+sgemm_direct_sme1_2VLx2VL:
+
+	stp x19, x20, [sp, #-48]!
+	stp x21, x22, [sp, #16]
+	stp x23, x24, [sp, #32]
+
+	smstart
+
+	cntw C4                         //SVLs
+	mul  C5, C4, K                  //K*SVLs
+	mul  C6, C4, N                  //N*SVLs
+	add  C1, C6, N                  //N*SVLs + N
+	add  N_exit, B_base, N, lsl #2  //N_Loop exit condition
+	mov  M_cntr, #0
+	add  C2, N, C4                  //N + SVLs
+	add  C3, C5, C4                 //K*SVLs + SVLs
+	whilelt p2.s, M_cntr, M         //Tile 0,1 predicate (M dimension)
+	sub  w20, w20, #2               //SVLs-2
+
+.M_Loop:
+	incw M_cntr
+	whilelt p3.s, M_cntr, M         //Tile 2,3 predicate (M dimension)
+	mov  Bptr, B_base               //B_base
+	mov  Cptr, C_base               //C_base
+	whilelt p0.b, Bptr, N_exit      //Tile 0/2 predicate (N dimension)
+
+.N_Loop:
+	mov  Aptr, A_base               //Aptr = A_base
+	mov  Bptr0, Bptr                //Bptr0 = Bptr
+	mov  Cptr0, Cptr                //Cptr0 = Cptr
+	addvl Cptr1, Cptr, #1           //Cptr1 = Cptr + SVLb
+	addvl Bptr, Bptr, #1
+	whilelt p1.b, Bptr, N_exit      //Tile 1,3 predicate (N dimension)
+	add  Aptr_end, A_base, C5, lsl #2 //A_base + K*SVLs
+	addvl K_exit, Aptr_end, #-1     //Exit condition for K loop
+	//Load 1st vector from Aptr
+	ld1w {z1.s}, p2/z, [Aptr]
+	zero {za}
+	// Load 1st vector from Bptr
+	ld1w {z2.s}, p0/z, [Bptr0]
+	// ZA0 += 1st Aptr vector OP 1st Bptr vector
+	fmopa za0.s, p2/m, p0/m, z1.s, z2.s
+	// Load 2nd vector from Aptr
+	ld1w {z5.s}, p3/z, [Aptr, C5, lsl #2]
+	// Aptr += SVLb
+	addvl Aptr, Aptr, #1
+
+.K_Loop:
+	// ZA2 += 2nd Aptr vector OP 1st Bptr vector
+	fmopa za2.s, p3/m, p0/m, z5.s, z2.s
+	// Load 2nd vector from Bptr
+	ld1w {z3.s}, p1/z, [Bptr0, #1, MUL VL]
+	// ZA1 += 1st Aptr vector OP 2nd Bptr vector
+	fmopa za1.s, p2/m, p1/m, z1.s, z3.s
+	// Load next 1st vector from Aptr
+	ld1w {z0.s}, p2/z, [Aptr]
+	// ZA3 += 2nd Aptr vector OP 2nd Bptr vector
+	fmopa za3.s, p3/m, p1/m, z5.s, z3.s
+	cmp  K, #2
+	b.le process_K_less_than_equal_2
+	// Load next 1st vector from Bptr
+	ld1w {z6.s}, p0/z, [Bptr0, N, lsl #2]
+	// ZA0 += 1st Aptr vector OP 1st Bptr vector
+	fmopa za0.s, p2/m, p0/m, z0.s, z6.s
+	// Load next 2nd vector from Aptr
+	ld1w {z4.s}, p3/z, [Aptr, C5, lsl #2]
+	// ZA2 += 2nd Aptr vector OP 1st Bptr vector
+	fmopa za2.s, p3/m, p0/m, z4.s, z6.s
+	// Load next 2nd vector from Bptr
+	ld1w {z7.s}, p1/z, [Bptr0, C2, lsl #2]
+	// Bptr0 += 2*ldb FP32 elms [Bytes]
+	add  Bptr0, Bptr0, N, lsl #3
+	// ZA1 += 1st Aptr vector OP 2nd Bptr vector
+	fmopa za1.s, p2/m, p1/m, z0.s, z7.s
+	// Load next 1st vector from Aptr
+	ld1w {z1.s}, p2/z, [Aptr, #1, MUL VL]
+	// ZA3 += 2nd Aptr vector OP 2nd Bptr vector
+	fmopa za3.s, p3/m, p1/m, z4.s, z7.s
+	// Load next 1st vector from Bptr
+	ld1w {z2.s}, p0/z, [Bptr0]
+	// ZA0 += 1st Aptr vector OP 1st Bptr vector
+	fmopa za0.s, p2/m, p0/m, z1.s, z2.s
+	// Load next 2nd vector from Aptr
+	ld1w {z5.s}, p3/z, [Aptr, C3, lsl #2]
+	// Aptr += 2*SVLb [Bytes]
+	addvl Aptr, Aptr, #2
+	cmp  Aptr, K_exit
+	b.mi .K_Loop
+	// ZA2 += 2nd Aptr vector OP 1st Bptr vector
+	fmopa za2.s, p3/m, p0/m, z5.s, z2.s
+	// Load next 2nd vector from Bptr
+	ld1w {z3.s}, p1/z, [Bptr0, #1, MUL VL]
+	// ZA1 += 1st Aptr vector OP 2nd Bptr vector
+	fmopa za1.s, p2/m, p1/m, z1.s, z3.s
+	// ZA3 += 2nd Aptr vector OP 2nd Bptr vector
+	fmopa za3.s, p3/m, p1/m, z5.s, z3.s
+
+process_K_less_than_equal_2:
+	// Bptr0 += ldb FP32 elements
+	add  Bptr0, Bptr0, N, lsl #2
+	cmp  Aptr, Aptr_end
+	b.pl .Ktail_end
+
+.Ktail_start:
+	ld1w {z1.s}, p2/z, [Aptr]
+	ld1w {z2.s}, p0/z, [Bptr0]
+	ld1w {z3.s}, p1/z, [Bptr0, #1, MUL VL]
+	fmopa za0.s, p2/m, p0/m, z1.s, z2.s
+	ld1w {z5.s}, p3/z, [Aptr, C5, lsl #2]
+	fmopa za2.s, p3/m, p0/m, z5.s, z2.s
+	fmopa za1.s, p2/m, p1/m, z1.s, z3.s
+	fmopa za3.s, p3/m, p1/m, z5.s, z3.s
+
+.Ktail_end:
+	mov  w13, #0
+	psel p4, p0, p2.s[w13, 0]
+	psel p5, p1, p2.s[w13, 0]
+	psel p6, p0, p3.s[w13, 0]
+	psel p7, p1, p3.s[w13, 0]
+	// Store to Cptr0
+	st1w {za0h.s[w13, #0]}, p4, [Cptr0]
+	// Store to Cptr1
+	st1w {za1h.s[w13, #0]}, p5, [Cptr1]
+	// Store to Cptr0 + N*SVLs
+	st1w {za2h.s[w13, #0]}, p6, [Cptr0, C6, lsl #2]
+	// Store to Cptr1 + N*SVLs
+	st1w {za3h.s[w13, #0]}, p7, [Cptr1, C6, lsl #2]
+
+.Loop_store_ZA:
+	psel p4, p0, p2.s[w13, 1]
+	psel p5, p1, p2.s[w13, 1]
+	psel p6, p0, p3.s[w13, 1]
+	psel p7, p1, p3.s[w13, 1]
+	// Store to Cptr0 + N
+	st1w {za0h.s[w13, #1]}, p4, [Cptr0, N, lsl #2]
+	// Store to Cptr1 + N
+	st1w {za1h.s[w13, #1]}, p5, [Cptr1, N, lsl #2]
+	// Store to Cptr0 + N*(SVLs+1)
+	st1w {za2h.s[w13, #1]}, p6, [Cptr0, C1, lsl #2]
+	// Store to Cptr1 + N*(SVLs+1)
+	st1w {za3h.s[w13, #1]}, p7, [Cptr1, C1, lsl #2]
+
+	add  Cptr0, Cptr0, N, lsl #3    //Cptr0 += 2*N FP32 elements
+	add  Cptr1, Cptr1, N, lsl #3    //Cptr1 += 2*N FP32 elements
+	add  w13, w13, #2
+
+	psel p4, p0, p2.s[w13, 0]
+	psel p5, p1, p2.s[w13, 0]
+	psel p6, p0, p3.s[w13, 0]
+	psel p7, p1, p3.s[w13, 0]
+	st1w {za0h.s[w13, #0]}, p4, [Cptr0]
+	st1w {za1h.s[w13, #0]}, p5, [Cptr1]
+	st1w {za2h.s[w13, #0]}, p6, [Cptr0, C6, lsl #2]
+	st1w {za3h.s[w13, #0]}, p7, [Cptr1, C6, lsl #2]
+	cmp  w13, w20
+	b.mi .Loop_store_ZA
+	psel p4, p0, p2.s[w13, 1]
+	psel p5, p1, p2.s[w13, 1]
+	psel p6, p0, p3.s[w13, 1]
+	psel p7, p1, p3.s[w13, 1]
+	st1w {za0h.s[w13, #1]}, p4, [Cptr0, N, lsl #2]
+	st1w {za1h.s[w13, #1]}, p5, [Cptr1, N, lsl #2]
+	st1w {za2h.s[w13, #1]}, p6, [Cptr0, C1, lsl #2]
+	st1w {za3h.s[w13, #1]}, p7, [Cptr1, C1, lsl #2]
+	addvl Cptr, Cptr, #2
+	addvl Bptr, Bptr, #1
+	whilelt p0.b, Bptr, N_exit      //1st Tile predicate (N dimension)
+	b.first .N_Loop
+	add  A_base, A_base, C5, lsl #3 //A_base += 2*K*SVLs FP32 elements
+	add  C_base, C_base, C6, lsl #3 //C_base += 2*N*SVLs FP32 elements
+	incw M_cntr
+	whilelt p2.s, M_cntr, M         //1st Tile predicate (M dimension)
+	b.first .M_Loop
+
+	smstop
+
+	ldp x23, x24, [sp, #32]
+	ldp x21, x22, [sp, #16]
+	ldp x19, x20, [sp], #48
+
+	ret
+
diff --git a/kernel/arm64/sgemm_direct_sme1_preprocess.S b/kernel/arm64/sgemm_direct_sme1_preprocess.S
new file mode 100644
index 0000000000..fa13620751
--- /dev/null
+++ b/kernel/arm64/sgemm_direct_sme1_preprocess.S
@@ -0,0 +1,133 @@
+/*
+   Copyright (c) 2025 Qualcomm Innovation Center, Inc. All rights reserved.
+   SPDX-License-Identifier: BSD-3-Clause-Clear
+*/
+
+/*----------------------------------------------------------------------------
+ * This function is used to re-arrange the elements of the input matrix to
+ * make it suitable for matrix outer-product computation using SME for matrix
+ * multiplication. It should be used to pre-process the left matrix (A) in the
+ * matrix multiplication (C = A*B) using sgemm_direct_sme1_2VLx2VL()
+ *
+ * The pre-processing transposes a block of SVLs rows of the input matrix and
+ * stores it contiguously. The same is applied to the remaining blocks of SVLs
+ * rows. The last block of SVLs rows is zero-padded to SVLs rows if needed.
+ *
+ * Usage of function:
+ * sgemm_direct_sme1_preprocess(uint64_t nrow, uint64_t ncol, \
+ *                              const float * restrict mat, float * mat_mod);
+ *
+ ----------------------------------------------------------------------------*/
+
+#define nrow            x0  //Number of rows of input matrix
+#define ncol            x1  //Number of columns of input matrix
+#define mat             x2  //Input matrix base address
+#define mat_mod         x3  //Output matrix (re-arranged matrix) base address
+#define mat_mod_ptr     x4  //Pointer to output matrix
+#define mat_ptr0        x5  //Pointer to input matrix
+#define mat_ptr1        x6  //2nd pointer to input matrix
+#define outer_loop_cntr x7  //Outer loop counter
+#define inner_loop_exit x8  //Inner loop exit condition
+#define C1              x9  //Constant1: SVLs - No. of 32-bit elements
+#define C2              x10 //Constant2: 3*SVLs
+#define C3              x11 //Constant3: ncol*SVLs
+#define C4              x13 //Constant4: 2*SVLs
+#define C5              x14 //Constant5: 2*ncol
+#define C6              x15 //Constant6: 3*ncol
+
+	.text
+	.global sgemm_direct_sme1_preprocess
+
+sgemm_direct_sme1_preprocess:
+
+	stp x19, x20, [sp, #-48]!
+	stp x21, x22, [sp, #16]
+	stp x23, x24, [sp, #32]
+
+	smstart
+
+	cntw C1                 //SVLs
+	mul  C3, C1, ncol       //SVLs*ncol
+	lsl  C5, ncol, #1       //2*ncol
+	add  C6, C5, ncol       //3*ncol
+	cnth C4                 //2*SVLs
+	add  C2, C1, C1, lsl #1 //3*SVLs
+
+	mov  outer_loop_cntr, #0
+	//Tile predicate (M dimension)
+	whilelt p0.s, outer_loop_cntr, nrow
+	//Predicate for stores
+	ptrue p9.s
+
+.M_Loop:
+	mov  mat_ptr0, mat                      //Load base address of mat
+	mov  mat_mod_ptr, mat_mod               //a_mod store base address
+	add  inner_loop_exit, mat, ncol, lsl #2 //Exit condition for inner loop
+	whilelt p8.b, mat_ptr0, inner_loop_exit //Tile predicate (K dimension)
+
+.Loop_process:
+	mov  mat_ptr1, mat_ptr0
+	//Load_to_tile loop counter
+	mov  w12, #0
+
+.Load_to_tile:
+	psel p2, p8, p0.s[w12, 0]
+	psel p3, p8, p0.s[w12, 1]
+	psel p4, p8, p0.s[w12, 2]
+	psel p5, p8, p0.s[w12, 3]
+	//Load 1st row from mat_ptr1
+	ld1w {za0h.s[w12, #0]}, p2/z, [mat_ptr1]
+	//Load 2nd row from mat_ptr1 + ncol
+	ld1w {za0h.s[w12, #1]}, p3/z, [mat_ptr1, ncol, lsl #2]
+	//Load 3rd row from mat_ptr1 + 2*ncol
+	ld1w {za0h.s[w12, #2]}, p4/z, [mat_ptr1, C5, lsl #2]
+	//Load 4th row from mat_ptr1 + 3*ncol
+	ld1w {za0h.s[w12, #3]}, p5/z, [mat_ptr1, C6, lsl #2]
+	//mat_ptr1 += 4*ncol FP32 elements
+	add  mat_ptr1, mat_ptr1, ncol, lsl #4
+	//Increment counter
+	add  w12, w12, #4
+	cmp  w12, w9
+	b.mi .Load_to_tile
+	//Store_from_tile loop counter
+	mov  w12, #0
+
+.Store_from_tile:
+	psel p2, p9, p8.s[w12, 0]
+	psel p3, p9, p8.s[w12, 1]
+	psel p4, p9, p8.s[w12, 2]
+	psel p5, p9, p8.s[w12, 3]
+	//Store 1st col to mat_mod
+	st1w {za0v.s[w12, #0]}, p2, [mat_mod_ptr]
+	//Store 2nd col to mat_mod + SVLs
+	st1w {za0v.s[w12, #1]}, p3, [mat_mod_ptr, C1, lsl #2]
+	//Store 3rd col to mat_mod + 2*SVLs
+	st1w {za0v.s[w12, #2]}, p4, [mat_mod_ptr, C4, lsl #2]
+	//Store 4th col to mat_mod + 3*SVLs
+	st1w {za0v.s[w12, #3]}, p5, [mat_mod_ptr, C2, lsl #2]
+
+	addvl mat_mod_ptr, mat_mod_ptr, #4 //mat_mod_ptr += 4*SVLb
+	add  w12, w12, #4                  //Increment counter
+	cmp  w12, w9
+	b.mi .Store_from_tile
+
+	addvl mat_ptr0, mat_ptr0, #1       //mat_ptr0 += SVLb
+	whilelt p8.b, mat_ptr0, inner_loop_exit
+	b.first .Loop_process
+
+	add  mat_mod, mat_mod, C3, lsl #2  //mat_mod += SVLs*ncol FP32 elements
+	add  mat, mat, C3, lsl #2          //mat += SVLs*ncol FP32 elements
+	incw outer_loop_cntr
+
+	whilelt p0.s, outer_loop_cntr, nrow
+	b.first .M_Loop
+
+	smstop
+
+	ldp x23, x24, [sp, #32]
+	ldp x21, x22, [sp, #16]
+	ldp x19, x20, [sp], #48
+
+	ret
+
diff --git a/kernel/setparam-ref.c b/kernel/setparam-ref.c
index fa61a209e1..a7dae781d5 100644
--- a/kernel/setparam-ref.c
+++ b/kernel/setparam-ref.c
@@ -179,6 +179,10 @@ gotoblas_t TABLE_NAME = {
   sgemm_directTS, sgemm_direct_performantTS,
 #endif
+#ifdef ARCH_ARM64
+  sgemm_direct,
+  NULL,
+#endif
   sgemm_kernelTS, sgemm_betaTS,
 #if SGEMM_DEFAULT_UNROLL_M != SGEMM_DEFAULT_UNROLL_N
diff --git a/param.h b/param.h
index fee9195d02..51ebcbabbe 100644
--- a/param.h
+++ b/param.h
@@ -3303,6 +3303,8 @@ USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 #define GEMM_DEFAULT_OFFSET_A 0
 #define GEMM_DEFAULT_OFFSET_B 0
 
+
+
 #ifdef _WIN64
 /* Use explicit casting for win64 as LLP64 datamodel is used */
 #define GEMM_DEFAULT_ALIGN (BLASULONG)0x03fffUL
@@ -3667,7 +3669,7 @@ Until then, just keep it different than DGEMM_DEFAULT_UNROLL_N to keep copy rout
 #define CGEMM_DEFAULT_R 4096
 #define ZGEMM_DEFAULT_R 4096
 
-#elif defined(ARMV8SVE) || defined(ARMV9) || defined(CORTEXA510)|| defined(CORTEXA710) || defined(CORTEXX2) // 128-bit SVE
+#elif defined(ARMV8SVE) || defined(ARMV9SME) || defined(ARMV9) || defined(CORTEXA510)|| defined(CORTEXA710) || defined(CORTEXX2) // 128-bit SVE
 
 #if defined(XDOUBLE) || defined(DOUBLE)
 #define SWITCH_RATIO 8
@@ -3738,6 +3740,10 @@ Until then, just keep it different than DGEMM_DEFAULT_UNROLL_N to keep copy rout
 
 #endif /* ARMv8 */
 
+#if defined(ARMV9SME) /* ARMv9 SME */
+#define USE_SGEMM_KERNEL_DIRECT 1
+#endif /* ARMv9 SME */
+
 #if defined(ARMV5)
 #define SNUMOPT 2
 #define DNUMOPT 2
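
Note (illustrative, not part of the patch): per the dispatch added in interface/gemm.c, the new SME1 direct path is only taken for single-precision, row-major, non-transposed GEMM with alpha == 1.0 and beta == 0, and only when support_sme1() reports SME at runtime; all other calls fall through to the regular kernels. A minimal standalone caller sketch under those assumptions, using the standard cblas interface (matrix sizes are arbitrary and deliberately not multiples of the vector length):

/* build (hypothetical paths): gcc test_sme_sgemm.c -o test_sme_sgemm -lopenblas */
#include <stdio.h>
#include <stdlib.h>
#include <cblas.h>

int main(void) {
    const int m = 127, n = 65, k = 33;
    float *a = malloc(sizeof(float) * m * k);   /* A: m x k, row-major */
    float *b = malloc(sizeof(float) * k * n);   /* B: k x n, row-major */
    float *c = malloc(sizeof(float) * m * n);   /* C: m x n, row-major */
    if (!a || !b || !c) return 1;
    for (int i = 0; i < m * k; i++) a[i] = (float)(i % 7) - 3.0f;
    for (int i = 0; i < k * n; i++) b[i] = (float)(i % 5) - 2.0f;

    /* alpha == 1, beta == 0, CblasRowMajor, CblasNoTrans on both operands:
       exactly the conditions under which SGEMM_DIRECT is invoked on an
       SME-capable core; otherwise the call is served by the usual path. */
    cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans,
                m, n, k, 1.0f, a, k, b, n, 0.0f, c, n);

    printf("c[0] = %f, c[m*n-1] = %f\n", c[0], c[m * n - 1]);
    free(a); free(b); free(c);
    return 0;
}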