# Locate the MUSA SDK root, in order of preference:
#   1. the MUSA_PATH environment variable, if it is set and names an existing path;
#   2. /opt/musa, if it exists;
#   3. /usr/local/musa as the final default (its existence is not checked here).
# The environment value is quoted so that an unset/empty variable yields an
# explicit empty-string test instead of collapsing the condition to
# `if (NOT EXISTS)`, and so a path containing ';' is not split into a list.
if (DEFINED ENV{MUSA_PATH} AND EXISTS "$ENV{MUSA_PATH}")
    set(MUSA_PATH "$ENV{MUSA_PATH}")
elseif (EXISTS "/opt/musa")
    set(MUSA_PATH "/opt/musa")
else()
    set(MUSA_PATH "/usr/local/musa")
endif()
 10
 11set(CMAKE_C_COMPILER "${MUSA_PATH}/bin/clang")
 12set(CMAKE_C_EXTENSIONS OFF)
 13set(CMAKE_CXX_COMPILER "${MUSA_PATH}/bin/clang++")
 14set(CMAKE_CXX_EXTENSIONS OFF)
 15
 16list(APPEND CMAKE_MODULE_PATH "${MUSA_PATH}/cmake")
 17
 18find_package(MUSAToolkit)
 19
 20if (MUSAToolkit_FOUND)
 21    message(STATUS "MUSA Toolkit found")
 22
 23    if (NOT DEFINED MUSA_ARCHITECTURES)
 24        set(MUSA_ARCHITECTURES "21;22;31")
 25    endif()
 26    message(STATUS "Using MUSA architectures: ${MUSA_ARCHITECTURES}")
 27
 28    file(GLOB   GGML_HEADERS_MUSA "../ggml-cuda/*.cuh")
 29    list(APPEND GGML_HEADERS_MUSA "../../include/ggml-cuda.h")
 30    list(APPEND GGML_HEADERS_MUSA "../ggml-musa/mudnn.cuh")
 31
 32    file(GLOB   GGML_SOURCES_MUSA "../ggml-cuda/*.cu")
 33    file(GLOB   SRCS "../ggml-cuda/template-instances/fattn-tile*.cu")
 34    list(APPEND GGML_SOURCES_MUSA ${SRCS})
 35    file(GLOB   SRCS "../ggml-cuda/template-instances/fattn-mma*.cu")
 36    list(APPEND GGML_SOURCES_MUSA ${SRCS})
 37    file(GLOB   SRCS "../ggml-cuda/template-instances/mmq*.cu")
 38    list(APPEND GGML_SOURCES_MUSA ${SRCS})
 39
 40    if (GGML_MUSA_MUDNN_COPY)
 41        file(GLOB   SRCS "../ggml-musa/*.cu")
 42        list(APPEND GGML_SOURCES_MUSA ${SRCS})
 43        add_compile_definitions(GGML_MUSA_MUDNN_COPY)
 44    endif()
 45
 46    if (GGML_CUDA_FA_ALL_QUANTS)
 47        file(GLOB   SRCS "../ggml-cuda/template-instances/fattn-vec*.cu")
 48        list(APPEND GGML_SOURCES_MUSA ${SRCS})
 49        add_compile_definitions(GGML_CUDA_FA_ALL_QUANTS)
 50    else()
 51        file(GLOB   SRCS "../ggml-cuda/template-instances/fattn-vec*q4_0-q4_0.cu")
 52        list(APPEND GGML_SOURCES_MUSA ${SRCS})
 53        file(GLOB   SRCS "../ggml-cuda/template-instances/fattn-vec*q8_0-q8_0.cu")
 54        list(APPEND GGML_SOURCES_MUSA ${SRCS})
 55        file(GLOB   SRCS "../ggml-cuda/template-instances/fattn-vec*f16-f16.cu")
 56        list(APPEND GGML_SOURCES_MUSA ${SRCS})
 57    endif()
 58
 59    set_source_files_properties(${GGML_SOURCES_MUSA} PROPERTIES LANGUAGE CXX)
 60    foreach(SOURCE ${GGML_SOURCES_MUSA})
 61        set(COMPILE_FLAGS "-Od3 -fno-strict-aliasing -ffast-math -fsigned-char -x musa -mtgpu -fmusa-flush-denormals-to-zero")
 62        foreach(ARCH ${MUSA_ARCHITECTURES})
 63            set(COMPILE_FLAGS "${COMPILE_FLAGS} --cuda-gpu-arch=mp_${ARCH}")
 64        endforeach()
 65        set_property(SOURCE ${SOURCE} PROPERTY COMPILE_FLAGS ${COMPILE_FLAGS})
 66    endforeach()
 67
 68    ggml_add_backend_library(ggml-musa
 69                             ${GGML_HEADERS_MUSA}
 70                             ${GGML_SOURCES_MUSA}
 71                            )
 72
 73    # TODO: do not use CUDA definitions for MUSA
 74    if (NOT GGML_BACKEND_DL)
 75        target_compile_definitions(ggml PUBLIC GGML_USE_CUDA)
 76    endif()
 77
 78    add_compile_definitions(GGML_USE_MUSA)
 79    add_compile_definitions(GGML_CUDA_PEER_MAX_BATCH_SIZE=${GGML_CUDA_PEER_MAX_BATCH_SIZE})
 80
 81    if (GGML_MUSA_GRAPHS)
 82        add_compile_definitions(GGML_MUSA_GRAPHS)
 83    endif()
 84
 85    if (GGML_CUDA_FORCE_MMQ)
 86        add_compile_definitions(GGML_CUDA_FORCE_MMQ)
 87    endif()
 88
 89    if (GGML_CUDA_FORCE_CUBLAS)
 90        add_compile_definitions(GGML_CUDA_FORCE_CUBLAS)
 91    endif()
 92
 93    if (GGML_CUDA_NO_VMM)
 94        add_compile_definitions(GGML_CUDA_NO_VMM)
 95    endif()
 96
 97    if (NOT GGML_CUDA_FA)
 98        add_compile_definitions(GGML_CUDA_NO_FA)
 99    endif()
100
101    if (GGML_CUDA_NO_PEER_COPY)
102        add_compile_definitions(GGML_CUDA_NO_PEER_COPY)
103    endif()
104
105    if (GGML_STATIC)
106        target_link_libraries(ggml-musa PRIVATE MUSA::musart_static MUSA::mublas_static)
107        # TODO: mudnn has not provided static libraries yet
108        # if (GGML_MUSA_MUDNN_COPY)
109        #     target_link_libraries(ggml-musa PRIVATE mudnn_static)
110        # endif()
111    else()
112        target_link_libraries(ggml-musa PRIVATE MUSA::musart MUSA::mublas)
113        if (GGML_MUSA_MUDNN_COPY)
114            target_link_libraries(ggml-musa PRIVATE mudnn)
115        endif()
116    endif()
117
118    if (GGML_CUDA_NO_VMM)
119        # No VMM requested, no need to link directly with the musa driver lib (libmusa.so)
120    else()
121        target_link_libraries(ggml-musa PRIVATE MUSA::musa_driver)
122    endif()
123else()
124    message(FATAL_ERROR "MUSA Toolkit not found")
125endif()