diff --git a/Directory.Build.props b/Directory.Build.props index 313c39566..8248f291d 100644 --- a/Directory.Build.props +++ b/Directory.Build.props @@ -5,6 +5,7 @@ + Debug Debug;Release <_DefaultArchitecture>$([System.Runtime.InteropServices.RuntimeInformation]::OSArchitecture.ToString().ToLower()) @@ -169,6 +170,9 @@ $(DefineContants);DEBUG false + + $(DefineContants);CUDA_TOOLKIT_FOUND + true diff --git a/Directory.Build.targets b/Directory.Build.targets index cc6bb3d4e..cd48c04ee 100644 --- a/Directory.Build.targets +++ b/Directory.Build.targets @@ -94,6 +94,7 @@ + diff --git a/pkg/pack.proj b/pkg/pack.proj index 55474fdc6..b29502f45 100644 --- a/pkg/pack.proj +++ b/pkg/pack.proj @@ -1,6 +1,7 @@ - + + @@ -31,5 +32,4 @@ - diff --git a/src/Native/LibTorchSharp/CMakeLists.txt b/src/Native/LibTorchSharp/CMakeLists.txt index 8e5e1e38a..fb332de5b 100644 --- a/src/Native/LibTorchSharp/CMakeLists.txt +++ b/src/Native/LibTorchSharp/CMakeLists.txt @@ -1,5 +1,12 @@ project(LibTorchSharp) +find_package(CUDA QUIET) +if(CUDA_FOUND) + include_directories(${CUDA_INCLUDE_DIRS}) + link_directories(${CUDA_LIBRARY_DIRS}) + add_compile_definitions(TORCHSHARP_CUDA_TOOLKIT_FOUND) +endif() + if(APPLE AND NOT LIBTORCH_ARCH STREQUAL "arm64") include_directories("/usr/local/include" "/usr/local/opt/llvm/include") link_directories("/usr/local/lib" "/usr/local/opt/llvm/lib") @@ -72,6 +79,10 @@ include_directories(${TORCH_INCLUDE_DIRS}) add_library(LibTorchSharp SHARED ${SOURCES} ${RESOURCES}) +if(CUDA_FOUND) +target_link_libraries(LibTorchSharp ${CUDA_LIBRARIES}) +endif() + target_link_libraries(LibTorchSharp ${TORCH_LIBRARIES}) set_property(TARGET LibTorchSharp PROPERTY CXX_STANDARD 14) diff --git a/src/Native/build.proj b/src/Native/build.proj index 6dbbc70a9..1f67671a7 100644 --- a/src/Native/build.proj +++ b/src/Native/build.proj @@ -44,17 +44,20 @@ - + $(NativeConfiguration) $(TargetArchitecture) --libtorchpath $(LibTorchCmakePath) - + + + $(NativeConfiguration) $(TargetArchitecture) --libtorchpath $(CustomLibTorchFullPath) + - + throw new NotImplementedException(); + bool density = false)=> throw new NotImplementedException(); // https://pytorch.org/docs/stable/generated/torch.histogram [Obsolete("not implemented", true)] - static Tensor histogram( - Tensor input, + static Tensor histogram(Tensor input, long[] bins, (float min, float max)? range = null, Tensor? weight = null, bool density = false) - => throw new NotImplementedException(); + { + throw new NotImplementedException(); + } + // https://pytorch.org/docs/stable/generated/torch.histogram [Obsolete("not implemented", true)] - static Tensor histogram( - Tensor input, + static Tensor histogram(Tensor input, Tensor[] bins, (float min, float max)? range = null, Tensor? weight = null, bool density = false) - => throw new NotImplementedException(); + { + throw new NotImplementedException(); + } // https://pytorch.org/docs/stable/generated/torch.histogramdd [Obsolete("not implemented", true)] diff --git a/src/TorchSharp/TorchSharp.csproj b/src/TorchSharp/TorchSharp.csproj index 2d227e2c8..09be939a3 100644 --- a/src/TorchSharp/TorchSharp.csproj +++ b/src/TorchSharp/TorchSharp.csproj @@ -1,4 +1,4 @@ - + @@ -11,6 +11,7 @@ false false $(DefineConstants);LIBTORCH_$(LibTorchPackageVersion.Replace('.', '_'));CUDA_$(CudaVersionDot.Replace('.', '_')) + @@ -61,7 +62,8 @@ - + + @@ -70,8 +72,10 @@ - - + + + + diff --git a/test/Directory.Build.props b/test/Directory.Build.props index de003c15a..72877ac85 100644 --- a/test/Directory.Build.props +++ b/test/Directory.Build.props @@ -6,7 +6,7 @@ $(TargetFrameworks);net48 false true - +