diff --git a/Directory.Build.props b/Directory.Build.props
index 313c39566..8248f291d 100644
--- a/Directory.Build.props
+++ b/Directory.Build.props
@@ -5,6 +5,7 @@
+
Debug
Debug;Release
<_DefaultArchitecture>$([System.Runtime.InteropServices.RuntimeInformation]::OSArchitecture.ToString().ToLower())
@@ -169,6 +170,9 @@
$(DefineContants);DEBUG
false
+
+ $(DefineContants);CUDA_TOOLKIT_FOUND
+
true
diff --git a/Directory.Build.targets b/Directory.Build.targets
index cc6bb3d4e..cd48c04ee 100644
--- a/Directory.Build.targets
+++ b/Directory.Build.targets
@@ -94,6 +94,7 @@
+
diff --git a/pkg/pack.proj b/pkg/pack.proj
index 55474fdc6..b29502f45 100644
--- a/pkg/pack.proj
+++ b/pkg/pack.proj
@@ -1,6 +1,7 @@
-
+
+
@@ -31,5 +32,4 @@
-
diff --git a/src/Native/LibTorchSharp/CMakeLists.txt b/src/Native/LibTorchSharp/CMakeLists.txt
index 8e5e1e38a..fb332de5b 100644
--- a/src/Native/LibTorchSharp/CMakeLists.txt
+++ b/src/Native/LibTorchSharp/CMakeLists.txt
@@ -1,5 +1,12 @@
project(LibTorchSharp)
+find_package(CUDA QUIET)
+if(CUDA_FOUND)
+ include_directories(${CUDA_INCLUDE_DIRS})
+ link_directories(${CUDA_LIBRARY_DIRS})
+ add_compile_definitions(TORCHSHARP_CUDA_TOOLKIT_FOUND)
+endif()
+
if(APPLE AND NOT LIBTORCH_ARCH STREQUAL "arm64")
include_directories("/usr/local/include" "/usr/local/opt/llvm/include")
link_directories("/usr/local/lib" "/usr/local/opt/llvm/lib")
@@ -72,6 +79,10 @@ include_directories(${TORCH_INCLUDE_DIRS})
add_library(LibTorchSharp SHARED ${SOURCES} ${RESOURCES})
+if(CUDA_FOUND)
+target_link_libraries(LibTorchSharp ${CUDA_LIBRARIES})
+endif()
+
target_link_libraries(LibTorchSharp ${TORCH_LIBRARIES})
set_property(TARGET LibTorchSharp PROPERTY CXX_STANDARD 14)
diff --git a/src/Native/build.proj b/src/Native/build.proj
index 6dbbc70a9..1f67671a7 100644
--- a/src/Native/build.proj
+++ b/src/Native/build.proj
@@ -44,17 +44,20 @@
-
+
$(NativeConfiguration) $(TargetArchitecture) --libtorchpath $(LibTorchCmakePath)
-
+
+
+ $(NativeConfiguration) $(TargetArchitecture) --libtorchpath $(CustomLibTorchFullPath)
+
-
+
throw new NotImplementedException();
+ bool density = false)=> throw new NotImplementedException();
// https://pytorch.org/docs/stable/generated/torch.histogram
[Obsolete("not implemented", true)]
- static Tensor histogram(
- Tensor input,
+ static Tensor histogram(Tensor input,
long[] bins,
(float min, float max)? range = null,
Tensor? weight = null,
bool density = false)
- => throw new NotImplementedException();
+ {
+ throw new NotImplementedException();
+ }
+
// https://pytorch.org/docs/stable/generated/torch.histogram
[Obsolete("not implemented", true)]
- static Tensor histogram(
- Tensor input,
+ static Tensor histogram(Tensor input,
Tensor[] bins,
(float min, float max)? range = null,
Tensor? weight = null,
bool density = false)
- => throw new NotImplementedException();
+ {
+ throw new NotImplementedException();
+ }
// https://pytorch.org/docs/stable/generated/torch.histogramdd
[Obsolete("not implemented", true)]
diff --git a/src/TorchSharp/TorchSharp.csproj b/src/TorchSharp/TorchSharp.csproj
index 2d227e2c8..09be939a3 100644
--- a/src/TorchSharp/TorchSharp.csproj
+++ b/src/TorchSharp/TorchSharp.csproj
@@ -1,4 +1,4 @@
-
+
@@ -11,6 +11,7 @@
false
false
$(DefineConstants);LIBTORCH_$(LibTorchPackageVersion.Replace('.', '_'));CUDA_$(CudaVersionDot.Replace('.', '_'))
+
@@ -61,7 +62,8 @@
-
+
+
@@ -70,8 +72,10 @@
-
-
+
+
+
+
diff --git a/test/Directory.Build.props b/test/Directory.Build.props
index de003c15a..72877ac85 100644
--- a/test/Directory.Build.props
+++ b/test/Directory.Build.props
@@ -6,7 +6,7 @@
$(TargetFrameworks);net48
false
true
-
+