Merge pull request #292750 from CertainLach/torchaudio-rocm
torchaudio: add rocm support
This commit is contained in:
commit
51d92d050b
|
@ -495,7 +495,7 @@ in buildPythonPackage rec {
|
|||
requiredSystemFeatures = [ "big-parallel" ];
|
||||
|
||||
passthru = {
|
||||
inherit cudaSupport cudaPackages;
|
||||
inherit cudaSupport cudaPackages rocmSupport rocmPackages;
|
||||
# At least for 1.10.2 `torch.fft` is unavailable unless BLAS provider is MKL. This attribute allows for easy detection of its availability.
|
||||
blasProvider = blas.provider;
|
||||
# To help debug when a package is broken due to CUDA support
|
||||
|
|
|
@ -9,10 +9,46 @@
|
|||
, pybind11
|
||||
, sox
|
||||
, torch
|
||||
|
||||
, cudaSupport ? torch.cudaSupport
|
||||
, cudaPackages
|
||||
, rocmSupport ? torch.rocmSupport
|
||||
, rocmPackages
|
||||
|
||||
, gpuTargets ? []
|
||||
}:
|
||||
|
||||
let
|
||||
# TODO: Reuse one defined in torch?
|
||||
# Some of those dependencies are probbly not required,
|
||||
# but it breaks when the store path is different between torch and torchaudio
|
||||
rocmtoolkit_joined = symlinkJoin {
|
||||
name = "rocm-merged";
|
||||
|
||||
paths = with rocmPackages; [
|
||||
rocm-core clr rccl miopen miopengemm rocrand rocblas
|
||||
rocsparse hipsparse rocthrust rocprim hipcub roctracer
|
||||
rocfft rocsolver hipfft hipsolver hipblas
|
||||
rocminfo rocm-thunk rocm-comgr rocm-device-libs
|
||||
rocm-runtime clr.icd hipify
|
||||
];
|
||||
|
||||
# Fix `setuptools` not being found
|
||||
postBuild = ''
|
||||
rm -rf $out/nix-support
|
||||
'';
|
||||
};
|
||||
# Only used for ROCm
|
||||
gpuTargetString = lib.strings.concatStringsSep ";" (
|
||||
if gpuTargets != [ ] then
|
||||
# If gpuTargets is specified, it always takes priority.
|
||||
gpuTargets
|
||||
else if rocmSupport then
|
||||
rocmPackages.clr.gpuTargets
|
||||
else
|
||||
throw "No GPU targets specified"
|
||||
);
|
||||
in
|
||||
buildPythonPackage rec {
|
||||
pname = "torchaudio";
|
||||
version = "2.3.0";
|
||||
|
@ -33,6 +69,11 @@ buildPythonPackage rec {
|
|||
substituteInPlace setup.py \
|
||||
--replace 'print(" --- Initializing submodules")' "return" \
|
||||
--replace "_fetch_archives(_parse_sources())" "pass"
|
||||
''
|
||||
+ lib.optionalString rocmSupport ''
|
||||
# There is no .info/version-dev, only .info/version
|
||||
substituteInPlace cmake/LoadHIP.cmake \
|
||||
--replace "/.info/version-dev" "/.info/version"
|
||||
'';
|
||||
|
||||
env = {
|
||||
|
@ -55,7 +96,11 @@ buildPythonPackage rec {
|
|||
ninja
|
||||
] ++ lib.optionals cudaSupport [
|
||||
cudaPackages.cuda_nvcc
|
||||
];
|
||||
] ++ lib.optionals rocmSupport (with rocmPackages; [
|
||||
clr
|
||||
rocblas
|
||||
hipblas
|
||||
]);
|
||||
|
||||
buildInputs = [
|
||||
ffmpeg-full
|
||||
|
@ -73,6 +118,11 @@ buildPythonPackage rec {
|
|||
BUILD_RNNT=0;
|
||||
BUILD_CTC_DECODER=0;
|
||||
|
||||
preConfigure = lib.optionalString rocmSupport ''
|
||||
export ROCM_PATH=${rocmtoolkit_joined}
|
||||
export PYTORCH_ROCM_ARCH="${gpuTargetString}"
|
||||
'';
|
||||
|
||||
dontUseCmakeConfigure = true;
|
||||
|
||||
doCheck = false; # requires sox backend
|
||||
|
|
Loading…
Reference in New Issue
Block a user