Merge pull request #292750 from CertainLach/torchaudio-rocm

torchaudio: add rocm support
This commit is contained in:
Aleksana 2024-05-18 13:21:13 +08:00 committed by GitHub
commit 51d92d050b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 52 additions and 2 deletions

View File

@ -495,7 +495,7 @@ in buildPythonPackage rec {
requiredSystemFeatures = [ "big-parallel" ];
passthru = {
inherit cudaSupport cudaPackages;
inherit cudaSupport cudaPackages rocmSupport rocmPackages;
# At least for 1.10.2 `torch.fft` is unavailable unless BLAS provider is MKL. This attribute allows for easy detection of its availability.
blasProvider = blas.provider;
# To help debug when a package is broken due to CUDA support

View File

@ -9,10 +9,46 @@
, pybind11
, sox
, torch
, cudaSupport ? torch.cudaSupport
, cudaPackages
, rocmSupport ? torch.rocmSupport
, rocmPackages
, gpuTargets ? []
}:
let
# TODO: Reuse one defined in torch?
# Some of those dependencies are probbly not required,
# but it breaks when the store path is different between torch and torchaudio
rocmtoolkit_joined = symlinkJoin {
name = "rocm-merged";
paths = with rocmPackages; [
rocm-core clr rccl miopen miopengemm rocrand rocblas
rocsparse hipsparse rocthrust rocprim hipcub roctracer
rocfft rocsolver hipfft hipsolver hipblas
rocminfo rocm-thunk rocm-comgr rocm-device-libs
rocm-runtime clr.icd hipify
];
# Fix `setuptools` not being found
postBuild = ''
rm -rf $out/nix-support
'';
};
# Only used for ROCm
gpuTargetString = lib.strings.concatStringsSep ";" (
if gpuTargets != [ ] then
# If gpuTargets is specified, it always takes priority.
gpuTargets
else if rocmSupport then
rocmPackages.clr.gpuTargets
else
throw "No GPU targets specified"
);
in
buildPythonPackage rec {
pname = "torchaudio";
version = "2.3.0";
@ -33,6 +69,11 @@ buildPythonPackage rec {
substituteInPlace setup.py \
--replace 'print(" --- Initializing submodules")' "return" \
--replace "_fetch_archives(_parse_sources())" "pass"
''
+ lib.optionalString rocmSupport ''
# There is no .info/version-dev, only .info/version
substituteInPlace cmake/LoadHIP.cmake \
--replace "/.info/version-dev" "/.info/version"
'';
env = {
@ -55,7 +96,11 @@ buildPythonPackage rec {
ninja
] ++ lib.optionals cudaSupport [
cudaPackages.cuda_nvcc
];
] ++ lib.optionals rocmSupport (with rocmPackages; [
clr
rocblas
hipblas
]);
buildInputs = [
ffmpeg-full
@ -73,6 +118,11 @@ buildPythonPackage rec {
BUILD_RNNT=0;
BUILD_CTC_DECODER=0;
preConfigure = lib.optionalString rocmSupport ''
export ROCM_PATH=${rocmtoolkit_joined}
export PYTORCH_ROCM_ARCH="${gpuTargetString}"
'';
dontUseCmakeConfigure = true;
doCheck = false; # requires sox backend