diff --git a/CMakeLists.txt b/CMakeLists.txt
index 8e88b77..a0d9fe9 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -13,14 +13,12 @@ if(CCACHE_PROGRAM)
endif()
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake")
+list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/utils")
+list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/dependencies")
set(CMAKE_CXX_STANDARD 20)
set(CMAKE_CXX_STANDARD_REQUIRED ON)
-# The CUDA standard is still C++14 to enable interopability with
-# slightly older and still well-supported versions of CUDA/nvcc
-# (e.g. CUDA < 11). This will be bumped to 17 once CUDA 11 is
-# required.
-set(CMAKE_CUDA_STANDARD 14)
+set(CMAKE_CUDA_STANDARD 20)
set(CMAKE_CUDA_STANDARD_REQUIRED ON)
# no modules in this library
diff --git a/Folder.DotSettings b/Folder.DotSettings
new file mode 100644
index 0000000..ea9fb4b
--- /dev/null
+++ b/Folder.DotSettings
@@ -0,0 +1,6 @@
+
+ <NamingElement Priority="6" Title="Parameters"><Descriptor Static="Indeterminate" Constexpr="Indeterminate" Const="Indeterminate" Volatile="Indeterminate" Accessibility="NOT_APPLICABLE"><type Name="function parameter" /><type Name="lambda parameter" /></Descriptor><Policy Inspect="True" WarnAboutPrefixesAndSuffixes="False" Prefix="" Suffix="" Style="aaBb"><ExtraRule Prefix="_" Suffix="" Style="aaBb" /></Policy></NamingElement>
+ <NamingElement Priority="16" Title="Other constants"><Descriptor Static="True" Constexpr="Indeterminate" Const="True" Volatile="Indeterminate" Accessibility="NOT_APPLICABLE"><type Name="class field" /><type Name="local variable" /><type Name="struct field" /></Descriptor><Policy Inspect="True" WarnAboutPrefixesAndSuffixes="False" Prefix="" Suffix="" Style="AA_BB"><ExtraRule Prefix="" Suffix="" Style="aa_bb" /></Policy></NamingElement>
+ <NamingElement Priority="15" Title="Enumerators"><Descriptor Static="Indeterminate" Constexpr="Indeterminate" Const="Indeterminate" Volatile="Indeterminate" Accessibility="NOT_APPLICABLE"><type Name="scoped enumerator" /><type Name="unscoped enumerator" /></Descriptor><Policy Inspect="True" WarnAboutPrefixesAndSuffixes="False" Prefix="" Suffix="" Style="AA_BB"><ExtraRule Prefix="" Suffix="" Style="aa_bb" /></Policy></NamingElement>
+ <NamingElement Priority="3" Title="Enums"><Descriptor Static="Indeterminate" Constexpr="Indeterminate" Const="Indeterminate" Volatile="Indeterminate" Accessibility="NOT_APPLICABLE"><type Name="enum" /></Descriptor><Policy Inspect="True" WarnAboutPrefixesAndSuffixes="False" Prefix="" Suffix="" Style="AaBb_AaBb"><ExtraRule Prefix="" Suffix="" Style="aa_bb" /></Policy></NamingElement>
+ True
\ No newline at end of file
diff --git a/cmake/dependencies/FindCUDNN.cmake b/cmake/dependencies/FindCUDNN.cmake
index a150310..fd77eea 100644
--- a/cmake/dependencies/FindCUDNN.cmake
+++ b/cmake/dependencies/FindCUDNN.cmake
@@ -76,4 +76,8 @@ if(CUDNN_FOUND)
endif()
endif()
+if (CUDNN_FOUND AND CUDNN_VERSION VERSION_LESS "8.0")
+ message(FATAL_ERROR "Flashlight requires cuDNN >= 8.0, found ${CUDNN_VERSION}")
+endif()
+
mark_as_advanced(CUDNN_ROOT CUDNN_INCLUDE_DIR CUDNN_LIBRARY CUDNN_VERSION)
diff --git a/cmake/utils/flashlightConfig.cmake.in b/cmake/utils/flashlightConfig.cmake.in
index 2bf2550..9d73423 100644
--- a/cmake/utils/flashlightConfig.cmake.in
+++ b/cmake/utils/flashlightConfig.cmake.in
@@ -49,7 +49,7 @@ if (@FL_BUILD_STANDALONE@)
endif()
if (@FL_USE_CUDA@)
if (@FL_USE_CUDNN@)
- find_dependency(CUDNN 7.1)
+ find_dependency(CUDNN 8)
endif()
if (@FL_BUILD_DISTRIBUTED@)
find_dependency(NCCL)
diff --git a/cmake/utils/fm_target_utilities.cmake b/cmake/utils/fm_target_utilities.cmake
index c3ad681..f5310d7 100644
--- a/cmake/utils/fm_target_utilities.cmake
+++ b/cmake/utils/fm_target_utilities.cmake
@@ -53,10 +53,17 @@ function(fm_glob OUT_VAR)
set(GLOB_PATTERNS ${ARG_PATTERNS})
endif()
- if(GLOB_PATTERNS)
+ # Normalize paths to prevent CONFIGURE_DEPENDS cache mismatch issues on Windows
+ set(NORMALIZED_PATTERNS "")
+ foreach(PATTERN IN LISTS GLOB_PATTERNS)
+ cmake_path(ABSOLUTE_PATH PATTERN NORMALIZE OUTPUT_VARIABLE NORMALIZED)
+ list(APPEND NORMALIZED_PATTERNS "${NORMALIZED}")
+ endforeach()
+
+ if(NORMALIZED_PATTERNS)
file(GLOB_RECURSE FOUND_FILES
CONFIGURE_DEPENDS
- ${GLOB_PATTERNS}
+ ${NORMALIZED_PATTERNS}
)
set(${OUT_VAR} ${${OUT_VAR}} ${FOUND_FILES} PARENT_SCOPE)
endif()
diff --git a/flashlight/fl/autograd/Functions.cpp b/flashlight/fl/autograd/Functions.cpp
index 70b7a9e..3eac143 100644
--- a/flashlight/fl/autograd/Functions.cpp
+++ b/flashlight/fl/autograd/Functions.cpp
@@ -1,8 +1,8 @@
/*
- * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * SPDX-License-Identifier: MIT
*
- * This source code is licensed under the MIT license found in the
- * LICENSE file in the root directory of this source tree.
+ * Original code: Copyright (c) Meta Platforms, Inc. (see FLASHLIGHT_LICENSE)
+ * Modifications: Copyright (c) 2026 Lukas Thomann (see LICENSE)
*/
#include
@@ -24,7 +24,7 @@
namespace fl {
namespace detail {
- Tensor tileAs(const Tensor& input, const Shape& rdims) {
+ Tensor tileAs(Tensor const& input, Shape const& rdims) {
// Scalar tensor
if(input.ndim() == 0)
return tile(input, rdims);
@@ -36,7 +36,7 @@ namespace detail {
if(rdims[i] % idimsSize != 0) {
std::stringstream ss;
ss << "Invalid dims for tileAs for input dims " << idims
- << " to output dims " << rdims;
+ << " to output dims " << rdims;
throw std::invalid_argument(ss.str());
}
dims[i] = rdims[i] / idimsSize;
@@ -44,19 +44,19 @@ namespace detail {
return tile(input, dims);
}
- Tensor sumAs(const Tensor& input, const Shape& rdims) {
+ Tensor sumAs(Tensor const& input, Shape const& rdims) {
Shape idims = input.shape();
auto result = input;
for(int i = 0; i < input.ndim(); i++)
if(i + 1 > rdims.ndim() || idims[i] != rdims[i])
result = fl::sum(result, {i}, /* keepDims = */ true);
- return fl::reshape(result.astype(input.type()), rdims);
+ return fl::reshape(result.asType(input.type()), rdims);
}
Shape expandedShapeFromReducedDims(
- const Tensor& input,
- const std::vector& axes,
+ Tensor const& input,
+ std::vector const& axes,
bool keepDims /* = false */
) {
// Fast path - tensor already retained its shape
@@ -72,7 +72,7 @@ namespace detail {
unsigned inputIdx = 0;
for(unsigned i = 0; i < preNDims; ++i) {
if(i == axes[axesIdx])
- // This dim was reduced over, leave as 1 in the new shape
+ // This dim was reduced over, leave as 1 in the new shape
axesIdx++;
else {
// Dim wasn't reduced over - add the shape from the new tensor
@@ -83,10 +83,10 @@ namespace detail {
return newShape;
}
-// TODO: remove these/use a simple template
+  // TODO: collapse these two expandFromReduction overloads into a single template
Variable expandFromReduction(
- const Variable& input,
- const std::vector& axes,
+ Variable const& input,
+ std::vector const& axes,
bool keepDims
) {
return moddims(
@@ -96,8 +96,8 @@ namespace detail {
}
Tensor expandFromReduction(
- const Tensor& input,
- const std::vector& axes,
+ Tensor const& input,
+ std::vector const& axes,
bool keepDims
) {
auto o = expandedShapeFromReducedDims(input, axes, keepDims);
@@ -107,75 +107,87 @@ namespace detail {
);
}
- bool areVariableTypesEqual(const Variable& a, const Variable& b) { return a.type() == b.type(); }
+ bool areVariableTypesEqual(Variable const& a, Variable const& b) { return a.type() == b.type(); }
} // namespace detail
-Variable operator+(const Variable& lhs, const Variable& rhs) {
+Variable operator+(Variable const& lhs, Variable const& rhs) {
FL_VARIABLE_DTYPES_MATCH_CHECK(lhs, rhs);
auto result = lhs.tensor() + rhs.tensor();
- auto gradFunc = [](std::vector& inputs,
- const Variable& gradOutput) {
- inputs[0].addGrad(Variable(gradOutput.tensor(), false));
- inputs[1].addGrad(Variable(gradOutput.tensor(), false));
- };
+ auto gradFunc = [](
+ std::vector& inputs,
+ Variable const& gradOutput
+ ) {
+ inputs[0].addGrad(Variable(gradOutput.tensor(), false));
+ inputs[1].addGrad(Variable(gradOutput.tensor(), false));
+ };
return Variable(result, {lhs.withoutData(), rhs.withoutData()}, gradFunc);
}
-Variable operator+(const Variable& lhs, const double& rhsVal) {
- auto result = (lhs.tensor() + rhsVal).astype(lhs.type());
- auto gradFunc = [](std::vector& inputs,
- const Variable& gradOutput) {
- inputs[0].addGrad(Variable(gradOutput.tensor(), false));
- };
+Variable operator+(Variable const& lhs, double const& rhsVal) {
+ auto result = (lhs.tensor() + rhsVal).asType(lhs.type());
+ auto gradFunc = [](
+ std::vector& inputs,
+ Variable const& gradOutput
+ ) {
+ inputs[0].addGrad(Variable(gradOutput.tensor(), false));
+ };
return Variable(result, {lhs.withoutData()}, gradFunc);
}
-Variable operator+(const double& lhsVal, const Variable& rhs) { return rhs + lhsVal; }
+Variable operator+(double const& lhsVal, Variable const& rhs) { return rhs + lhsVal; }
-Variable operator-(const Variable& lhs, const Variable& rhs) {
+Variable operator-(Variable const& lhs, Variable const& rhs) {
FL_VARIABLE_DTYPES_MATCH_CHECK(lhs, rhs);
auto result = lhs.tensor() - rhs.tensor();
- auto gradFunc = [](std::vector& inputs,
- const Variable& gradOutput) {
- inputs[0].addGrad(Variable(gradOutput.tensor(), false));
- inputs[1].addGrad(Variable(negate(gradOutput).tensor(), false));
- };
+ auto gradFunc = [](
+ std::vector& inputs,
+ Variable const& gradOutput
+ ) {
+ inputs[0].addGrad(Variable(gradOutput.tensor(), false));
+ inputs[1].addGrad(Variable(negate(gradOutput).tensor(), false));
+ };
return Variable(result, {lhs.withoutData(), rhs.withoutData()}, gradFunc);
}
-Variable operator-(const Variable& lhs, const double& rhsVal) {
- auto result = (lhs.tensor() - rhsVal).astype(lhs.type());
- auto gradFunc = [](std::vector& inputs,
- const Variable& gradOutput) {
- inputs[0].addGrad(Variable(gradOutput.tensor(), false));
- };
+Variable operator-(Variable const& lhs, double const& rhsVal) {
+ auto result = (lhs.tensor() - rhsVal).asType(lhs.type());
+ auto gradFunc = [](
+ std::vector& inputs,
+ Variable const& gradOutput
+ ) {
+ inputs[0].addGrad(Variable(gradOutput.tensor(), false));
+ };
return Variable(result, {lhs.withoutData()}, gradFunc);
}
-Variable operator-(const double& lhsVal, const Variable& rhs) {
- auto result = (lhsVal - rhs.tensor()).astype(rhs.type());
- auto gradFunc = [](std::vector& inputs,
- const Variable& gradOutput) {
- inputs[0].addGrad(Variable(negate(gradOutput).tensor(), false));
- };
+Variable operator-(double const& lhsVal, Variable const& rhs) {
+ auto result = (lhsVal - rhs.tensor()).asType(rhs.type());
+ auto gradFunc = [](
+ std::vector& inputs,
+ Variable const& gradOutput
+ ) {
+ inputs[0].addGrad(Variable(negate(gradOutput).tensor(), false));
+ };
return Variable(result, {rhs.withoutData()}, gradFunc);
}
-Variable operator*(const Variable& lhs, const Variable& rhs) {
+Variable operator*(Variable const& lhs, Variable const& rhs) {
FL_VARIABLE_DTYPES_MATCH_CHECK(lhs, rhs);
auto result = lhs.tensor() * rhs.tensor();
- auto gradFunc = [](std::vector& inputs,
- const Variable& gradOutput) {
- if(inputs[0].isCalcGrad())
- inputs[0].addGrad(
- Variable(gradOutput.tensor() * inputs[1].tensor(), false)
- );
- if(inputs[1].isCalcGrad())
- inputs[1].addGrad(
- Variable(gradOutput.tensor() * inputs[0].tensor(), false)
- );
- };
+ auto gradFunc = [](
+ std::vector& inputs,
+ Variable const& gradOutput
+ ) {
+ if(inputs[0].isCalcGrad())
+ inputs[0].addGrad(
+ Variable(gradOutput.tensor() * inputs[1].tensor(), false)
+ );
+ if(inputs[1].isCalcGrad())
+ inputs[1].addGrad(
+ Variable(gradOutput.tensor() * inputs[0].tensor(), false)
+ );
+ };
return Variable(
result,
{
@@ -186,34 +198,35 @@ Variable operator*(const Variable& lhs, const Variable& rhs) {
);
}
-Variable operator*(const Variable& lhs, const double& rhsVal) {
- auto result = (lhs.tensor() * rhsVal).astype(lhs.type());
- auto gradFunc =
- [rhsVal](std::vector& inputs, const Variable& gradOutput) {
- inputs[0].addGrad(Variable(gradOutput.tensor() * rhsVal, false));
- };
+Variable operator*(Variable const& lhs, double const& rhsVal) {
+ auto result = (lhs.tensor() * rhsVal).asType(lhs.type());
+ auto gradFunc = [rhsVal](std::vector& inputs, Variable const& gradOutput) {
+ inputs[0].addGrad(Variable(gradOutput.tensor() * rhsVal, false));
+ };
return Variable(result, {lhs.withoutData()}, gradFunc);
}
-Variable operator*(const double& lhsVal, const Variable& rhs) { return rhs * lhsVal; }
+Variable operator*(double const& lhsVal, Variable const& rhs) { return rhs * lhsVal; }
-Variable operator/(const Variable& lhs, const Variable& rhs) {
+Variable operator/(Variable const& lhs, Variable const& rhs) {
FL_VARIABLE_DTYPES_MATCH_CHECK(lhs, rhs);
auto result = lhs.tensor() / rhs.tensor();
- auto gradFunc = [](std::vector& inputs,
- const Variable& gradOutput) {
- auto inputs1rec = reciprocal(inputs[1]);
- auto gradInput0 = gradOutput * inputs1rec;
- if(inputs[0].isCalcGrad())
- inputs[0].addGrad(Variable(gradInput0.tensor(), false));
- if(inputs[1].isCalcGrad())
- inputs[1].addGrad(
- Variable(
- (gradInput0 * negate(inputs[0]) * inputs1rec).tensor(),
- false
- )
- );
- };
+ auto gradFunc = [](
+ std::vector& inputs,
+ Variable const& gradOutput
+ ) {
+ auto inputs1rec = reciprocal(inputs[1]);
+ auto gradInput0 = gradOutput * inputs1rec;
+ if(inputs[0].isCalcGrad())
+ inputs[0].addGrad(Variable(gradInput0.tensor(), false));
+ if(inputs[1].isCalcGrad())
+ inputs[1].addGrad(
+ Variable(
+ (gradInput0 * negate(inputs[0]) * inputs1rec).tensor(),
+ false
+ )
+ );
+ };
return Variable(
result,
{rhs.isCalcGrad() ? lhs : lhs.withoutData(), rhs},
@@ -221,368 +234,395 @@ Variable operator/(const Variable& lhs, const Variable& rhs) {
);
}
-Variable operator/(const Variable& lhs, const double& rhsVal) {
- auto result = (lhs.tensor() / rhsVal).astype(lhs.type());
+Variable operator/(Variable const& lhs, double const& rhsVal) {
+ auto result = (lhs.tensor() / rhsVal).asType(lhs.type());
auto gradFunc =
- [rhsVal](std::vector& inputs, const Variable& gradOutput) {
- inputs[0].addGrad(Variable((gradOutput / rhsVal).tensor(), false));
- };
+ [rhsVal](std::vector& inputs, Variable const& gradOutput) {
+ inputs[0].addGrad(Variable((gradOutput / rhsVal).tensor(), false));
+ };
return Variable(result, {lhs.withoutData()}, gradFunc);
}
-Variable operator/(const double& lhsVal, const Variable& rhs) {
- auto result = (lhsVal / rhs.tensor()).astype(rhs.type());
+Variable operator/(double const& lhsVal, Variable const& rhs) {
+ auto result = (lhsVal / rhs.tensor()).asType(rhs.type());
auto gradFunc = [lhsVal](
std::vector& inputs,
- const Variable& gradOutput) {
- inputs[0].addGrad(
- Variable(
- (gradOutput * (-lhsVal) / (inputs[0] * inputs[0])).tensor(),
- false
- )
- );
- };
+ Variable const& gradOutput
+ ) {
+ inputs[0].addGrad(
+ Variable(
+ (gradOutput * (-lhsVal) / (inputs[0] * inputs[0])).tensor(),
+ false
+ )
+ );
+ };
return Variable(result, {rhs}, gradFunc);
}
-Variable operator>(const Variable& lhs, const Variable& rhs) {
+Variable operator>(Variable const& lhs, Variable const& rhs) {
FL_VARIABLE_DTYPES_MATCH_CHECK(lhs, rhs);
auto result = lhs.tensor() > rhs.tensor();
return Variable(result, false);
}
-Variable operator>(const Variable& lhs, const double& rhsVal) {
- auto result = (lhs.tensor() > rhsVal).astype(lhs.type());
+Variable operator>(Variable const& lhs, double const& rhsVal) {
+ auto result = (lhs.tensor() > rhsVal).asType(lhs.type());
return Variable(result, false);
}
-Variable operator>(const double& lhsVal, const Variable& rhs) {
- auto result = (lhsVal > rhs.tensor()).astype(rhs.type());
+Variable operator>(double const& lhsVal, Variable const& rhs) {
+ auto result = (lhsVal > rhs.tensor()).asType(rhs.type());
return Variable(result, false);
}
-Variable operator<(const Variable& lhs, const Variable& rhs) {
+Variable operator<(Variable const& lhs, Variable const& rhs) {
FL_VARIABLE_DTYPES_MATCH_CHECK(lhs, rhs);
auto result = lhs.tensor() < rhs.tensor();
return Variable(result, false);
}
-Variable operator<(const Variable& lhs, const double& rhsVal) {
- auto result = (lhs.tensor() < rhsVal).astype(lhs.type());
+Variable operator<(Variable const& lhs, double const& rhsVal) {
+ auto result = (lhs.tensor() < rhsVal).asType(lhs.type());
return Variable(result, false);
}
-Variable operator<(const double& lhsVal, const Variable& rhs) {
- auto result = (lhsVal < rhs.tensor()).astype(rhs.type());
+Variable operator<(double const& lhsVal, Variable const& rhs) {
+ auto result = (lhsVal < rhs.tensor()).asType(rhs.type());
return Variable(result, false);
}
-Variable operator>=(const Variable& lhs, const Variable& rhs) {
+Variable operator>=(Variable const& lhs, Variable const& rhs) {
FL_VARIABLE_DTYPES_MATCH_CHECK(lhs, rhs);
auto result = lhs.tensor() >= rhs.tensor();
return Variable(result, false);
}
-Variable operator>=(const Variable& lhs, const double& rhsVal) {
- auto result = (lhs.tensor() >= rhsVal).astype(lhs.type());
+Variable operator>=(Variable const& lhs, double const& rhsVal) {
+ auto result = (lhs.tensor() >= rhsVal).asType(lhs.type());
return Variable(result, false);
}
-Variable operator>=(const double& lhsVal, const Variable& rhs) {
- auto result = (lhsVal >= rhs.tensor()).astype(rhs.type());
+Variable operator>=(double const& lhsVal, Variable const& rhs) {
+ auto result = (lhsVal >= rhs.tensor()).asType(rhs.type());
return Variable(result, false);
}
-Variable operator<=(const Variable& lhs, const Variable& rhs) {
+Variable operator<=(Variable const& lhs, Variable const& rhs) {
FL_VARIABLE_DTYPES_MATCH_CHECK(lhs, rhs);
auto result = lhs.tensor() <= rhs.tensor();
return Variable(result, false);
}
-Variable operator<=(const Variable& lhs, const double& rhsVal) {
- auto result = (lhs.tensor() <= rhsVal).astype(lhs.type());
+Variable operator<=(Variable const& lhs, double const& rhsVal) {
+ auto result = (lhs.tensor() <= rhsVal).asType(lhs.type());
return Variable(result, false);
}
-Variable operator<=(const double& lhsVal, const Variable& rhs) {
- auto result = (lhsVal <= rhs.tensor()).astype(rhs.type());
+Variable operator<=(double const& lhsVal, Variable const& rhs) {
+ auto result = (lhsVal <= rhs.tensor()).asType(rhs.type());
return Variable(result, false);
}
-Variable operator&&(const Variable& lhs, const Variable& rhs) {
+Variable operator&&(Variable const& lhs, Variable const& rhs) {
FL_VARIABLE_DTYPES_MATCH_CHECK(lhs, rhs);
auto result = lhs.tensor() && rhs.tensor();
return Variable(result, false);
}
-Variable operator!(const Variable& input) {
- auto result = (!input.tensor()).astype(input.type());
+Variable operator!(Variable const& input) {
+ auto result = (!input.tensor()).asType(input.type());
return Variable(result, false);
}
-Variable max(const Variable& lhs, const Variable& rhs) {
+Variable max(Variable const& lhs, Variable const& rhs) {
FL_VARIABLE_DTYPES_MATCH_CHECK(lhs, rhs);
auto result = fl::maximum(lhs.tensor(), rhs.tensor());
- auto gradFunc = [](std::vector& inputs,
- const Variable& gradOutput) {
- auto mask = Variable(
- (inputs[0].tensor() > inputs[1].tensor()).astype(gradOutput.type()),
- false
- );
- inputs[0].addGrad(Variable((mask * gradOutput).tensor(), false));
- inputs[1].addGrad(Variable((!mask * gradOutput).tensor(), false));
- };
+ auto gradFunc = [](
+ std::vector& inputs,
+ Variable const& gradOutput
+ ) {
+ auto mask = Variable(
+ (inputs[0].tensor() > inputs[1].tensor()).asType(gradOutput.type()),
+ false
+ );
+ inputs[0].addGrad(Variable((mask * gradOutput).tensor(), false));
+ inputs[1].addGrad(Variable((!mask * gradOutput).tensor(), false));
+ };
return Variable(result, {lhs, rhs}, gradFunc);
}
-Variable max(const Variable& lhs, const double& rhsVal) {
- auto result = fl::maximum(lhs.tensor(), rhsVal).astype(lhs.type());
+Variable max(Variable const& lhs, double const& rhsVal) {
+ auto result = fl::maximum(lhs.tensor(), rhsVal).asType(lhs.type());
auto gradFunc =
- [rhsVal](std::vector& inputs, const Variable& gradOutput) {
- auto mask = Variable(
- (inputs[0].tensor() > rhsVal).astype(gradOutput.type()),
- false
- );
- inputs[0].addGrad(Variable((mask * gradOutput).tensor(), false));
- };
+ [rhsVal](std::vector& inputs, Variable const& gradOutput) {
+ auto mask = Variable(
+ (inputs[0].tensor() > rhsVal).asType(gradOutput.type()),
+ false
+ );
+ inputs[0].addGrad(Variable((mask * gradOutput).tensor(), false));
+ };
return Variable(result, {lhs}, gradFunc);
}
-Variable max(const double& lhsVal, const Variable& rhs) { return max(rhs, lhsVal); }
+Variable max(double const& lhsVal, Variable const& rhs) { return max(rhs, lhsVal); }
-Variable min(const Variable& lhs, const Variable& rhs) {
+Variable min(Variable const& lhs, Variable const& rhs) {
FL_VARIABLE_DTYPES_MATCH_CHECK(lhs, rhs);
auto result = fl::minimum(lhs.tensor(), rhs.tensor());
- auto gradFunc = [](std::vector& inputs,
- const Variable& gradOutput) {
- auto mask = Variable(
- (inputs[0].tensor() < inputs[1].tensor()).astype(gradOutput.type()),
- false
- );
- inputs[0].addGrad(Variable((mask * gradOutput).tensor(), false));
- inputs[1].addGrad(Variable((!mask * gradOutput).tensor(), false));
- };
+ auto gradFunc = [](
+ std::vector& inputs,
+ Variable const& gradOutput
+ ) {
+ auto mask = Variable(
+ (inputs[0].tensor() < inputs[1].tensor()).asType(gradOutput.type()),
+ false
+ );
+ inputs[0].addGrad(Variable((mask * gradOutput).tensor(), false));
+ inputs[1].addGrad(Variable((!mask * gradOutput).tensor(), false));
+ };
return Variable(result, {lhs, rhs}, gradFunc);
}
-Variable min(const Variable& lhs, const double& rhsVal) {
- auto result = fl::minimum(lhs.tensor(), rhsVal).astype(lhs.type());
+Variable min(Variable const& lhs, double const& rhsVal) {
+ auto result = fl::minimum(lhs.tensor(), rhsVal).asType(lhs.type());
auto gradFunc =
- [rhsVal](std::vector& inputs, const Variable& gradOutput) {
- auto mask = Variable(
- (inputs[0].tensor() < rhsVal).astype(gradOutput.type()),
- false
- );
- inputs[0].addGrad(Variable((mask * gradOutput).tensor(), false));
- };
+ [rhsVal](std::vector& inputs, Variable const& gradOutput) {
+ auto mask = Variable(
+ (inputs[0].tensor() < rhsVal).asType(gradOutput.type()),
+ false
+ );
+ inputs[0].addGrad(Variable((mask * gradOutput).tensor(), false));
+ };
return Variable(result, {lhs}, gradFunc);
}
-Variable min(const double& lhsVal, const Variable& rhs) { return min(rhs, lhsVal); }
+Variable min(double const& lhsVal, Variable const& rhs) { return min(rhs, lhsVal); }
-Variable negate(const Variable& input) {
- auto result = (0.0 - input.tensor()).astype(input.type());
- auto gradFunc = [](std::vector& inputs,
- const Variable& gradOutput) {
- inputs[0].addGrad(Variable(negate(gradOutput).tensor(), false));
- };
+Variable negate(Variable const& input) {
+ auto result = (0.0 - input.tensor()).asType(input.type());
+ auto gradFunc = [](
+ std::vector& inputs,
+ Variable const& gradOutput
+ ) {
+ inputs[0].addGrad(Variable(negate(gradOutput).tensor(), false));
+ };
return Variable(result, {input.withoutData()}, gradFunc);
}
-Variable reciprocal(const Variable& input) {
+Variable reciprocal(Variable const& input) {
auto result = 1.0 / FL_ADJUST_INPUT_TYPE(input.tensor());
- auto gradFunc = [](std::vector& inputs,
- const Variable& gradOutput) {
- auto res = reciprocal(inputs[0]);
- inputs[0].addGrad(
- Variable((negate(gradOutput) * res * res).tensor(), false)
- );
- };
+ auto gradFunc = [](
+ std::vector& inputs,
+ Variable const& gradOutput
+ ) {
+ auto res = reciprocal(inputs[0]);
+ inputs[0].addGrad(
+ Variable((negate(gradOutput) * res * res).tensor(), false)
+ );
+ };
return Variable(result, {input}, gradFunc);
}
-Variable exp(const Variable& input) {
+Variable exp(Variable const& input) {
auto result = fl::exp(FL_ADJUST_INPUT_TYPE(input.tensor()));
- auto gradFunc = [](std::vector& inputs,
- const Variable& gradOutput) {
- inputs[0].addGrad(
- Variable(gradOutput.tensor() * fl::exp(inputs[0].tensor()), false)
- );
- };
+ auto gradFunc = [](
+ std::vector& inputs,
+ Variable const& gradOutput
+ ) {
+ inputs[0].addGrad(
+ Variable(gradOutput.tensor() * fl::exp(inputs[0].tensor()), false)
+ );
+ };
return Variable(result, {input}, gradFunc);
}
-Variable log(const Variable& input) {
+Variable log(Variable const& input) {
auto result = fl::log(FL_ADJUST_INPUT_TYPE(input.tensor()));
- auto gradFunc = [](std::vector& inputs,
- const Variable& gradOutput) {
- inputs[0].addGrad(
- Variable((gradOutput.tensor() / inputs[0].tensor()), false)
- );
- };
+ auto gradFunc = [](
+ std::vector& inputs,
+ Variable const& gradOutput
+ ) {
+ inputs[0].addGrad(
+ Variable((gradOutput.tensor() / inputs[0].tensor()), false)
+ );
+ };
return Variable(result, {input}, gradFunc);
}
-Variable log1p(const Variable& input) {
+Variable log1p(Variable const& input) {
auto result = fl::log1p(FL_ADJUST_INPUT_TYPE(input.tensor()));
- auto gradFunc = [](std::vector& inputs,
- const Variable& gradOutput) {
- inputs[0].addGrad(
- Variable((gradOutput.tensor() / (1.0 + inputs[0].tensor())), false)
- );
- };
+ auto gradFunc = [](
+ std::vector& inputs,
+ Variable const& gradOutput
+ ) {
+ inputs[0].addGrad(
+ Variable((gradOutput.tensor() / (1.0 + inputs[0].tensor())), false)
+ );
+ };
return Variable(result, {input}, gradFunc);
}
-Variable pow(const Variable& input, double p) {
+Variable pow(Variable const& input, double p) {
auto result = fl::power(FL_ADJUST_INPUT_TYPE(input.tensor()), p);
- auto gradFunc = [p](std::vector& inputs,
- const Variable& gradOutput) {
- Tensor grad =
- p * fl::power(inputs[0].tensor(), p - 1) * gradOutput.tensor();
- inputs[0].addGrad(Variable(grad, false));
- };
+ auto gradFunc = [p](
+ std::vector& inputs,
+ Variable const& gradOutput
+ ) {
+ Tensor grad =
+ p * fl::power(inputs[0].tensor(), p - 1) * gradOutput.tensor();
+ inputs[0].addGrad(Variable(grad, false));
+ };
return Variable(result, {input}, gradFunc);
}
-Variable sin(const Variable& input) {
+Variable sin(Variable const& input) {
auto result = fl::sin(input.tensor());
- auto gradFunc = [](std::vector& inputs,
- const Variable& gradOutput) {
- inputs[0].addGrad(
- Variable((gradOutput.tensor() * cos(inputs[0].tensor())), false)
- );
- };
+ auto gradFunc = [](
+ std::vector& inputs,
+ Variable const& gradOutput
+ ) {
+ inputs[0].addGrad(
+ Variable((gradOutput.tensor() * cos(inputs[0].tensor())), false)
+ );
+ };
return Variable(result, {input}, gradFunc);
}
-Variable cos(const Variable& input) {
+Variable cos(Variable const& input) {
auto result = fl::cos(input.tensor());
- auto gradFunc = [](std::vector& inputs,
- const Variable& gradOutput) {
- inputs[0].addGrad(
- Variable(
- (gradOutput.tensor() * negative(sin(inputs[0].tensor()))),
- false
- )
- );
- };
+ auto gradFunc = [](
+ std::vector& inputs,
+ Variable const& gradOutput
+ ) {
+ inputs[0].addGrad(
+ Variable(
+ (gradOutput.tensor() * negative(sin(inputs[0].tensor()))),
+ false
+ )
+ );
+ };
return Variable(result, {input}, gradFunc);
}
-Variable tanh(const Variable& input) {
+Variable tanh(Variable const& input) {
auto result = fl::tanh(input.tensor());
auto gradFunc =
- [result](std::vector& inputs, const Variable& gradOutput) {
- auto grad =
- Variable((1.0 - result * result) * gradOutput.tensor(), false);
- inputs[0].addGrad(Variable(grad.tensor(), false));
- };
+ [result](std::vector& inputs, Variable const& gradOutput) {
+ auto grad =
+ Variable((1.0 - result * result) * gradOutput.tensor(), false);
+ inputs[0].addGrad(Variable(grad.tensor(), false));
+ };
return Variable(result, {input.withoutData()}, gradFunc);
}
-Variable clamp(const Variable& input, const double lo, const double hi) {
+Variable clamp(Variable const& input, double const lo, double const hi) {
auto result = fl::clip(input.tensor(), lo, hi);
auto gradFunc = [lo, hi, result](
std::vector& inputs,
- const Variable& gradOutput) {
- Tensor gradMask = gradOutput.tensor();
- gradMask = fl::where((result > lo) && (result < hi), gradMask, 0);
- inputs[0].addGrad(Variable(gradMask, false));
- };
+ Variable const& gradOutput
+ ) {
+ Tensor gradMask = gradOutput.tensor();
+ gradMask = fl::where((result > lo) && (result < hi), gradMask, 0);
+ inputs[0].addGrad(Variable(gradMask, false));
+ };
return Variable(result, {input.withoutData()}, gradFunc);
}
-Variable sqrt(const Variable& input) {
+Variable sqrt(Variable const& input) {
auto result = fl::sqrt(input.tensor());
auto gradFunc = [result](
std::vector& inputs,
- const Variable& gradOutput) {
- auto output = Variable(result, false);
- inputs[0].addGrad(Variable((gradOutput / (2 * output)).tensor(), false));
- };
+ Variable const& gradOutput
+ ) {
+ auto output = Variable(result, false);
+ inputs[0].addGrad(Variable((gradOutput / (2 * output)).tensor(), false));
+ };
return Variable(result, {input.withoutData()}, gradFunc);
}
-Variable sigmoid(const Variable& input) {
+Variable sigmoid(Variable const& input) {
auto result = fl::sigmoid(input.tensor());
auto gradFunc =
- [result](std::vector& inputs, const Variable& gradOutput) {
- auto grad = gradOutput.tensor() * result * (1 - result);
- inputs[0].addGrad(Variable(grad, false));
- };
+ [result](std::vector& inputs, Variable const& gradOutput) {
+ auto grad = gradOutput.tensor() * result * (1 - result);
+ inputs[0].addGrad(Variable(grad, false));
+ };
return Variable(result, {input.withoutData()}, gradFunc);
}
-Variable swish(const Variable& input, double beta) { return input * sigmoid(beta * input); }
+Variable swish(Variable const& input, double beta) { return input * sigmoid(beta * input); }
-Variable erf(const Variable& input) {
+Variable erf(Variable const& input) {
auto result = fl::erf(FL_ADJUST_INPUT_TYPE(input.tensor()));
- auto gradFunc = [](std::vector& inputs,
- const Variable& gradOutput) {
- auto x = inputs[0].tensor();
- auto grad = gradOutput.tensor() * 2 / std::sqrt(M_PI) * fl::exp(-(x * x));
- inputs[0].addGrad(Variable(grad, false));
- };
+ auto gradFunc = [](
+ std::vector& inputs,
+ Variable const& gradOutput
+ ) {
+ auto x = inputs[0].tensor();
+ auto grad = gradOutput.tensor() * 2 / std::sqrt(M_PI) * fl::exp(-(x * x));
+ inputs[0].addGrad(Variable(grad, false));
+ };
return Variable(result, {input}, gradFunc);
}
-Variable transpose(const Variable& input, const Shape& dims /* = {} */) {
+Variable transpose(Variable const& input, Shape const& dims /* = {} */) {
auto result = fl::transpose(input.tensor(), dims);
auto gradFunc = [inputDims = input.shape(), ndim = input.ndim(), dims](
std::vector& inputs,
- const Variable& gradOutput) {
- Shape reverseShape = dims;
-
- if(dims.ndim()) {
- // Reverse vec if transposing all dims (empty arg)
- auto dVec = dims.get();
- std::reverse(dVec.begin(), dVec.end());
- reverseShape = Shape(dVec);
- }
+ Variable const& gradOutput
+ ) {
+ Shape reverseShape = dims;
- for(unsigned i = 0; i < reverseShape.ndim(); ++i)
- reverseShape[dims[i]] = i;
+ if(dims.ndim()) {
+ // Reverse vec if transposing all dims (empty arg)
+ auto dVec = dims.get();
+ std::reverse(dVec.begin(), dVec.end());
+ reverseShape = Shape(dVec);
+ }
- inputs[0].addGrad(
- Variable(fl::transpose(gradOutput.tensor(), reverseShape), false)
- );
- };
+ for(unsigned i = 0; i < reverseShape.ndim(); ++i)
+ reverseShape[dims[i]] = i;
+
+ inputs[0].addGrad(
+ Variable(fl::transpose(gradOutput.tensor(), reverseShape), false)
+ );
+ };
return Variable(result, {input.withoutData()}, gradFunc);
}
-Variable tileAs(const Variable& input, const Shape& rdims) {
+Variable tileAs(Variable const& input, Shape const& rdims) {
auto result = detail::tileAs(input.tensor(), rdims);
Shape inDims = input.shape();
auto gradFunc = [inDims](
std::vector& inputs,
- const Variable& gradOutput) {
- inputs[0].addGrad(
- Variable(
- sumAs(gradOutput, inDims).tensor().astype(inputs[0].type()),
- false
- )
- );
- };
+ Variable const& gradOutput
+ ) {
+ inputs[0].addGrad(
+ Variable(
+ sumAs(gradOutput, inDims).tensor().asType(inputs[0].type()),
+ false
+ )
+ );
+ };
return Variable(result, {input.withoutData()}, gradFunc);
}
-Variable tileAs(const Variable& input, const Variable& reference) { return tileAs(input, reference.shape()); }
+Variable tileAs(Variable const& input, Variable const& reference) { return tileAs(input, reference.shape()); }
-Variable sumAs(const Variable& input, const Shape& rdims) {
+Variable sumAs(Variable const& input, Shape const& rdims) {
auto result = detail::sumAs(FL_ADJUST_INPUT_TYPE(input.tensor()), rdims);
auto idims = input.tensor().shape();
auto gradFunc =
- [idims](std::vector& inputs, const Variable& gradOutput) {
- inputs[0].addGrad(Variable(tileAs(gradOutput, idims).tensor(), false));
- };
+ [idims](std::vector& inputs, Variable const& gradOutput) {
+ inputs[0].addGrad(Variable(tileAs(gradOutput, idims).tensor(), false));
+ };
return Variable(result, {input.withoutData()}, gradFunc);
}
-Variable sumAs(const Variable& input, const Variable& reference) { return sumAs(input, reference.shape()); }
+Variable sumAs(Variable const& input, Variable const& reference) { return sumAs(input, reference.shape()); }
-Variable concatenate(const std::vector& concatInputs, int dim) {
+Variable concatenate(std::vector const& concatInputs, int dim) {
if(concatInputs.empty())
throw std::invalid_argument("cannot concatenate zero variables");
@@ -620,7 +660,7 @@ Variable concatenate(const std::vector& concatInputs, int dim) {
Tensor result(dims, concatInputs[0].type());
std::vector slice(numDims, fl::span);
int start = 0;
- for(const auto& input : concatInputs) {
+ for(auto const& input : concatInputs) {
slice[dim] = fl::range({start, start + input.dim(dim)});
result(slice) = input.tensor();
start += input.dim(dim);
@@ -629,38 +669,39 @@ Variable concatenate(const std::vector& concatInputs, int dim) {
std::vector inputsNoData;
std::vector inDims;
- for(const auto& in : concatInputs) {
+ for(auto const& in : concatInputs) {
inputsNoData.push_back(in.withoutData());
inDims.push_back(in.shape());
}
auto gradFunc = [dim, inDims, numDims](
std::vector& inputs,
- const Variable& gradOutput) {
- std::vector sx(numDims, fl::span);
- int s = 0;
- for(size_t i = 0; i < inputs.size(); ++i) {
- sx[dim] = fl::range(s, s + inDims[i][dim]);
- inputs[i].addGrad(Variable(gradOutput.tensor()(sx), false));
- s += inDims[i][dim];
- }
- };
+ Variable const& gradOutput
+ ) {
+ std::vector sx(numDims, fl::span);
+ int s = 0;
+ for(size_t i = 0; i < inputs.size(); ++i) {
+ sx[dim] = fl::range(s, s + inDims[i][dim]);
+ inputs[i].addGrad(Variable(gradOutput.tensor()(sx), false));
+ s += inDims[i][dim];
+ }
+ };
return Variable(result, inputsNoData, gradFunc);
}
-std::vector split(const Variable& input, long splitSize, int dim) {
+std::vector split(Variable const& input, int64_t splitSize, int dim) {
if(splitSize <= 0)
throw std::invalid_argument("split size must be a positive integer");
auto dimSize = input.dim(dim);
- std::vector splitSizes(dimSize / splitSize, splitSize);
+ std::vector splitSizes(dimSize / splitSize, splitSize);
if(dimSize % splitSize > 0)
splitSizes.push_back(dimSize % splitSize);
return split(input, splitSizes, dim);
}
-std::vector split(const Variable& input, const std::vector& splitSizes, int dim) {
+std::vector split(Variable const& input, std::vector const& splitSizes, int dim) {
if(dim >= input.ndim())
throw std::invalid_argument(
"split: passed dim is larger than the number of dimensions "
@@ -685,24 +726,24 @@ std::vector split(const Variable& input, const std::vector& spli
return outputs;
}
-Variable tile(const Variable& input, const Shape& dims) {
+Variable tile(Variable const& input, Shape const& dims) {
Tensor result = fl::tile(input.tensor(), dims);
Shape idims = input.shape();
auto gradFunc =
- [idims](std::vector& inputs, const Variable& gradOutput) {
- inputs[0].addGrad(
- Variable(
- sumAs(gradOutput, idims).tensor().astype(inputs[0].type()),
- false
- )
- );
- };
+ [idims](std::vector& inputs, Variable const& gradOutput) {
+ inputs[0].addGrad(
+ Variable(
+ sumAs(gradOutput, idims).tensor().asType(inputs[0].type()),
+ false
+ )
+ );
+ };
return Variable(result, {input.withoutData()}, gradFunc);
}
Variable sum(
- const Variable& input,
- const std::vector& axes,
+ Variable const& input,
+ std::vector const& axes,
bool keepDims /* = false*/
) {
auto result = FL_ADJUST_INPUT_TYPE(input.tensor());
@@ -711,23 +752,24 @@ Variable sum(
Shape indims = input.shape();
auto gradFunc = [indims, axes, keepDims](
std::vector& inputs,
- const Variable& gradOutput) {
- inputs[0].addGrad(
- Variable(
- detail::tileAs(
- detail::expandFromReduction(gradOutput.tensor(), axes, keepDims),
- indims
- ),
- false
- )
- );
- };
- return Variable(result.astype(input.type()), {input.withoutData()}, gradFunc);
+ Variable const& gradOutput
+ ) {
+ inputs[0].addGrad(
+ Variable(
+ detail::tileAs(
+ detail::expandFromReduction(gradOutput.tensor(), axes, keepDims),
+ indims
+ ),
+ false
+ )
+ );
+ };
+ return Variable(result.asType(input.type()), {input.withoutData()}, gradFunc);
}
Variable mean(
- const Variable& input,
- const std::vector& axes,
+ Variable const& input,
+ std::vector const& axes,
bool keepDims /* = false*/
) {
auto result = FL_ADJUST_INPUT_TYPE(input.tensor());
@@ -736,38 +778,39 @@ Variable mean(
Shape idims = input.shape();
auto gradFunc = [idims, axes, keepDims](
std::vector& inputs,
- const Variable& gradOutput) {
- Shape odims = gradOutput.shape();
- Dim count = 1;
- for(int i = 0; i < idims.ndim(); i++) {
- Dim odimSize = i + 1 > odims.ndim() ? 1 : odims[i];
- count *= idims[i] / odimSize;
- }
- auto grad =
+ Variable const& gradOutput
+ ) {
+ Shape odims = gradOutput.shape();
+ Dim count = 1;
+ for(int i = 0; i < idims.ndim(); i++) {
+ Dim odimSize = i + 1 > odims.ndim() ? 1 : odims[i];
+ count *= idims[i] / odimSize;
+ }
+ auto grad =
+ detail::tileAs(
+ detail::expandFromReduction(gradOutput.tensor(), axes, keepDims),
+ idims
+ )
+ / count;
+ inputs[0].addGrad(
+ Variable(
detail::tileAs(
detail::expandFromReduction(gradOutput.tensor(), axes, keepDims),
idims
)
- / count;
- inputs[0].addGrad(
- Variable(
- detail::tileAs(
- detail::expandFromReduction(gradOutput.tensor(), axes, keepDims),
- idims
- )
- / count,
- false
- )
- );
- };
+ / count,
+ false
+ )
+ );
+ };
return Variable(result, {input.withoutData()}, gradFunc);
}
Variable var(
- const Variable& in,
- const std::vector& axes,
- const bool isbiased /* = false */,
+ Variable const& in,
+ std::vector const& axes,
+ bool const isbiased /* = false */,
bool keepDims /* = false*/
) {
Tensor input = FL_ADJUST_INPUT_TYPE(in.tensor());
@@ -785,30 +828,30 @@ Variable var(
result = val * (result - n * avg * avg);
auto gradFunc =
- [val, axes](std::vector& inputs, const Variable& gradOutput) {
- Shape expandedDims = inputs[0].shape();
- Shape tileDims = inputs[0].shape();
- for(auto ax : axes) {
- tileDims[ax] = inputs[0].dim(ax);
- expandedDims[ax] = 1;
- }
+ [val, axes](std::vector& inputs, Variable const& gradOutput) {
+ Shape expandedDims = inputs[0].shape();
+ Shape tileDims = inputs[0].shape();
+ for(auto ax : axes) {
+ tileDims[ax] = inputs[0].dim(ax);
+ expandedDims[ax] = 1;
+ }
- inputs[0].addGrad(
- Variable(
- ((2 * val * tileAs(moddims(gradOutput, expandedDims), tileDims))
- * (inputs[0]
- - tileAs(moddims(mean(inputs[0], axes), expandedDims), tileDims)))
- .tensor(),
- false
- )
- );
- };
+ inputs[0].addGrad(
+ Variable(
+ ((2 * val * tileAs(moddims(gradOutput, expandedDims), tileDims))
+ * (inputs[0]
+ - tileAs(moddims(mean(inputs[0], axes), expandedDims), tileDims)))
+ .tensor(),
+ false
+ )
+ );
+ };
return Variable(result, {in}, gradFunc);
}
Variable norm(
- const Variable& input,
- const std::vector& axes,
+ Variable const& input,
+ std::vector const& axes,
double p /* = 2 */,
bool keepDims /* = false */
) {
@@ -823,25 +866,26 @@ Variable norm(
auto gradFunc = [sumap, p, axes, keepDims](
std::vector& inputs,
- const Variable& gradOutput) {
- // correct, but less precise: auto gvar = Variable(fl::power(result, p - 1),
- // false);
- auto gvar = Variable(fl::power(sumap, 1 - 1 / p), false);
- auto normGrad =
- (inputs[0].tensor() * fl::pow(fl::abs(inputs[0]), p - 2).tensor()
- * detail::tileAs(
- detail::expandFromReduction(gradOutput.tensor(), axes, keepDims)
- / gvar.tensor(),
- inputs[0].shape()
- ));
- inputs[0].addGrad(Variable(normGrad, false));
- };
+ Variable const& gradOutput
+ ) {
+ // correct, but less precise: auto gvar = Variable(fl::power(result, p - 1),
+ // false);
+ auto gvar = Variable(fl::power(sumap, 1 - 1 / p), false);
+ auto normGrad =
+ (inputs[0].tensor() * fl::pow(fl::abs(inputs[0]), p - 2).tensor()
+ * detail::tileAs(
+ detail::expandFromReduction(gradOutput.tensor(), axes, keepDims)
+ / gvar.tensor(),
+ inputs[0].shape()
+ ));
+ inputs[0].addGrad(Variable(normGrad, false));
+ };
return Variable(result, {input}, gradFunc);
}
Variable normalize(
- const Variable& in,
- const std::vector& axes,
+ Variable const& in,
+ std::vector const& axes,
double p /* = 2 */,
double eps /* = 1e-12 */
) {
@@ -851,7 +895,7 @@ Variable normalize(
return input / tileAs(invscale, input);
}
-Variable matmul(const Variable& lhs, const Variable& rhs) {
+Variable matmul(Variable const& lhs, Variable const& rhs) {
FL_VARIABLE_DTYPES_MATCH_CHECK(lhs, rhs);
// lhs:Input[0] -- [M, N]
// rhs:Input[1] -- [N, K]
@@ -859,50 +903,55 @@ Variable matmul(const Variable& lhs, const Variable& rhs) {
// -- matmul([M, N], [N, K]) -- [M, K]
// result:gradOutput -- [M, K]
auto result = fl::matmul(lhs.tensor(), rhs.tensor());
- auto gradFunc = [](std::vector& inputs,
- const Variable& gradOutput) {
- if(inputs[0].isCalcGrad()) {
- Tensor _lhs = gradOutput.tensor();
- if(_lhs.ndim() == 1)
- _lhs = fl::reshape(_lhs, {1, _lhs.dim(0)});
- Tensor _rhs = inputs[1].tensor();
- if(_rhs.ndim() == 1)
- _rhs = fl::reshape(_rhs, {_rhs.dim(0), 1});
-
- // matmulNT(gradOutput, inputs[1])
- // -- matmulNT([M, K], [N, K])
- // -- matmul([M, K], [K, N]) -- [M, K]
- auto val = fl::matmul(
- _lhs,
- _rhs,
- /* lhsProp = */ MatrixProperty::None,
- /* rhsProp = */ MatrixProperty::Transpose
- );
- inputs[0].addGrad(Variable(detail::sumAs(val, inputs[0].shape()), false));
- }
- if(inputs[1].isCalcGrad()) {
- Tensor _lhs = inputs[0].tensor();
- if(_lhs.ndim() == 1)
- _lhs = fl::reshape(_lhs, {1, _lhs.dim(0)});
- Tensor _rhs = gradOutput.tensor();
- if(_rhs.ndim() == 1)
- _rhs = fl::reshape(_rhs, {_rhs.dim(0), 1});
-
- // matmulTN(inputs[0], gradOutput)
- // -- matmulTN([M, N], [M, K])
- // -- matmul([N, M], [M, K]) -- [N, K]
- auto val = fl::matmul(
- _lhs,
- _rhs,
- /* lhsProp = */ MatrixProperty::Transpose
- );
- inputs[1].addGrad(Variable(detail::sumAs(val, inputs[1].shape()), false));
- }
- };
+ auto gradFunc = [](
+ std::vector& inputs,
+ Variable const& gradOutput
+ ) {
+ if(inputs[0].isCalcGrad()) {
+ Tensor _lhs = gradOutput.tensor();
+ if(_lhs.ndim() == 1)
+ _lhs = fl::reshape(_lhs, {1, _lhs.dim(0)});
+ Tensor _rhs = inputs[1].tensor();
+ if(_rhs.ndim() == 1)
+ _rhs = fl::reshape(_rhs, {_rhs.dim(0), 1});
+
+ // matmulNT(gradOutput, inputs[1])
+ // -- matmulNT([M, K], [N, K])
+ // -- matmul([M, K], [K, N]) -- [M, K]
+ auto val = fl::matmul(
+ _lhs,
+ _rhs,
+ /* lhsProp = */
+ MatrixProperty::None,
+ /* rhsProp = */
+ MatrixProperty::Transpose
+ );
+ inputs[0].addGrad(Variable(detail::sumAs(val, inputs[0].shape()), false));
+ }
+ if(inputs[1].isCalcGrad()) {
+ Tensor _lhs = inputs[0].tensor();
+ if(_lhs.ndim() == 1)
+ _lhs = fl::reshape(_lhs, {1, _lhs.dim(0)});
+ Tensor _rhs = gradOutput.tensor();
+ if(_rhs.ndim() == 1)
+ _rhs = fl::reshape(_rhs, {_rhs.dim(0), 1});
+
+ // matmulTN(inputs[0], gradOutput)
+ // -- matmulTN([M, N], [M, K])
+ // -- matmul([N, M], [M, K]) -- [N, K]
+ auto val = fl::matmul(
+ _lhs,
+ _rhs,
+ /* lhsProp = */
+ MatrixProperty::Transpose
+ );
+ inputs[1].addGrad(Variable(detail::sumAs(val, inputs[1].shape()), false));
+ }
+ };
return Variable(result, {lhs, rhs}, gradFunc);
}
-Variable matmulTN(const Variable& lhs, const Variable& rhs) {
+Variable matmulTN(Variable const& lhs, Variable const& rhs) {
FL_VARIABLE_DTYPES_MATCH_CHECK(lhs, rhs);
// lhs:Input[0] -- [N, M]
// rhs:Input[1] -- [N, K]
@@ -912,34 +961,39 @@ Variable matmulTN(const Variable& lhs, const Variable& rhs) {
// result:gradOutput -- [M, K]
auto result = fl::matmul(
lhs.tensor(),
- rhs.tensor(), /* lhsProp = */
+ rhs.tensor(),
+ /* lhsProp = */
MatrixProperty::Transpose
);
- auto gradFunc = [](std::vector& inputs,
- const Variable& gradOutput) {
- if(inputs[0].isCalcGrad()) {
- // matmulNT(inputs[1], gradOutput)
- // -- matmulNT([N, K], [M, K])
- // -- matmul([N, K], [K, M]) -- [N, M]
- auto val = fl::matmul(
- inputs[1].tensor(),
- gradOutput.tensor(),
- /* lhsProp = */ MatrixProperty::None,
- /* rhsProp = */ MatrixProperty::Transpose
- );
- inputs[0].addGrad(Variable(detail::sumAs(val, inputs[0].shape()), false));
- }
- if(inputs[1].isCalcGrad()) {
- // matmul(inputs[0], gradOutput)
- // -- matmulNT([N, M], [M, K]) -- [N, K]
- auto val = fl::matmul(inputs[0].tensor(), gradOutput.tensor());
- inputs[1].addGrad(Variable(detail::sumAs(val, inputs[1].shape()), false));
- }
- };
+ auto gradFunc = [](
+ std::vector& inputs,
+ Variable const& gradOutput
+ ) {
+ if(inputs[0].isCalcGrad()) {
+ // matmulNT(inputs[1], gradOutput)
+ // -- matmulNT([N, K], [M, K])
+ // -- matmul([N, K], [K, M]) -- [N, M]
+ auto val = fl::matmul(
+ inputs[1].tensor(),
+ gradOutput.tensor(),
+ /* lhsProp = */
+ MatrixProperty::None,
+ /* rhsProp = */
+ MatrixProperty::Transpose
+ );
+ inputs[0].addGrad(Variable(detail::sumAs(val, inputs[0].shape()), false));
+ }
+ if(inputs[1].isCalcGrad()) {
+ // matmul(inputs[0], gradOutput)
+ // -- matmulNT([N, M], [M, K]) -- [N, K]
+ auto val = fl::matmul(inputs[0].tensor(), gradOutput.tensor());
+ inputs[1].addGrad(Variable(detail::sumAs(val, inputs[1].shape()), false));
+ }
+ };
return Variable(result, {lhs, rhs}, gradFunc);
}
-Variable matmulNT(const Variable& lhs, const Variable& rhs) {
+Variable matmulNT(Variable const& lhs, Variable const& rhs) {
FL_VARIABLE_DTYPES_MATCH_CHECK(lhs, rhs);
// lhs:Input[0] -- [M, N]
// rhs:Input[1] -- [K, N]
@@ -950,54 +1004,61 @@ Variable matmulNT(const Variable& lhs, const Variable& rhs) {
auto result = fl::matmul(
lhs.tensor(),
rhs.tensor(),
- /* lhsProp = */ MatrixProperty::None,
- /* rhsProp = */ MatrixProperty::Transpose
+ /* lhsProp = */
+ MatrixProperty::None,
+ /* rhsProp = */
+ MatrixProperty::Transpose
);
- auto gradFunc = [](std::vector& inputs,
- const Variable& gradOutput) {
- if(inputs[0].isCalcGrad()) {
- // matmul(gradOutput, inputs[1])
- // -- matmul([M, K], [K, N]) -- [M, N]
- auto val = fl::matmul(gradOutput.tensor(), inputs[1].tensor());
- inputs[0].addGrad(Variable(detail::sumAs(val, inputs[0].shape()), false));
- }
- if(inputs[1].isCalcGrad()) {
- // matmulTN(gradOutput, inputs[0])
- // -- matmulTN([M, K], [M, N])
- // -- matmul([K, M], [M, N]) -- [K, N]
- auto val = fl::matmul(
- gradOutput.tensor(),
- inputs[0].tensor(),
- /* lhsProp = */ MatrixProperty::Transpose
- );
- inputs[1].addGrad(Variable(detail::sumAs(val, inputs[1].shape()), false));
- }
- };
+ auto gradFunc = [](
+ std::vector& inputs,
+ Variable const& gradOutput
+ ) {
+ if(inputs[0].isCalcGrad()) {
+ // matmul(gradOutput, inputs[1])
+ // -- matmul([M, K], [K, N]) -- [M, N]
+ auto val = fl::matmul(gradOutput.tensor(), inputs[1].tensor());
+ inputs[0].addGrad(Variable(detail::sumAs(val, inputs[0].shape()), false));
+ }
+ if(inputs[1].isCalcGrad()) {
+ // matmulTN(gradOutput, inputs[0])
+ // -- matmulTN([M, K], [M, N])
+ // -- matmul([K, M], [M, N]) -- [K, N]
+ auto val = fl::matmul(
+ gradOutput.tensor(),
+ inputs[0].tensor(),
+ /* lhsProp = */
+ MatrixProperty::Transpose
+ );
+ inputs[1].addGrad(Variable(detail::sumAs(val, inputs[1].shape()), false));
+ }
+ };
return Variable(result, {lhs, rhs}, gradFunc);
}
-Variable abs(const Variable& input) {
+Variable abs(Variable const& input) {
auto result = fl::abs(input.tensor());
- auto gradFunc = [](std::vector& inputs,
- const Variable& gradOutput) {
- // Convert it into -1, 0, 1
- auto sign = fl::sign(inputs[0].tensor());
- inputs[0].addGrad(Variable((sign * gradOutput.tensor()), false));
- };
+ auto gradFunc = [](
+ std::vector& inputs,
+ Variable const& gradOutput
+ ) {
+ // Convert it into -1, 0, 1
+ auto sign = fl::sign(inputs[0].tensor());
+ inputs[0].addGrad(Variable((sign * gradOutput.tensor()), false));
+ };
return Variable(result, {input}, gradFunc);
}
-Variable flat(const Variable& input) {
+Variable flat(Variable const& input) {
auto result = input.tensor().flatten();
Shape idims = input.shape();
auto gradFunc =
- [idims](std::vector& inputs, const Variable& gradOutput) {
- inputs[0].addGrad(Variable(reshape(gradOutput.tensor(), idims), false));
- };
+ [idims](std::vector& inputs, Variable const& gradOutput) {
+ inputs[0].addGrad(Variable(reshape(gradOutput.tensor(), idims), false));
+ };
return Variable(result, {input.withoutData()}, gradFunc);
}
-Variable moddims(const Variable& input, const Shape& dims) {
+Variable moddims(Variable const& input, Shape const& dims) {
if(input.ndim() == 0)
return input;
Shape inferDims = dims;
@@ -1036,13 +1097,14 @@ Variable moddims(const Variable& input, const Shape& dims) {
Shape inDims = input.shape();
auto gradFunc = [inDims](
std::vector& inputs,
- const Variable& gradOutput) {
- inputs[0].addGrad(Variable(moddims(gradOutput, inDims).tensor(), false));
- };
+ Variable const& gradOutput
+ ) {
+ inputs[0].addGrad(Variable(moddims(gradOutput, inDims).tensor(), false));
+ };
return Variable(result, {input.withoutData()}, gradFunc);
}
-Variable softmax(const Variable& input, const int dim) {
+Variable softmax(Variable const& input, int const dim) {
Tensor inputArr = FL_ADJUST_INPUT_TYPE(input.tensor());
auto maxvals = amax(inputArr, {dim}, /* keepDims = */ true);
Shape tiledims(std::vector(input.ndim(), 1));
@@ -1055,17 +1117,18 @@ Variable softmax(const Variable& input, const int dim) {
fl::eval(result);
auto gradFunc = [dim, tiledims, result](
std::vector& inputs,
- const Variable& gradOutput) {
- auto rbyg = gradOutput.tensor() * result;
- auto gradSm = rbyg
- - result
- * fl::tile(fl::sum(rbyg, {dim}, /* keepDims = */ true), tiledims);
- inputs[0].addGrad(Variable(gradSm.astype(inputs[0].type()), false));
- };
+ Variable const& gradOutput
+ ) {
+ auto rbyg = gradOutput.tensor() * result;
+ auto gradSm = rbyg
+ - result
+ * fl::tile(fl::sum(rbyg, {dim}, /* keepDims = */ true), tiledims);
+ inputs[0].addGrad(Variable(gradSm.asType(inputs[0].type()), false));
+ };
return Variable(result, {input.withoutData()}, gradFunc);
}
-Variable logSoftmax(const Variable& input, const int dim) {
+Variable logSoftmax(Variable const& input, int const dim) {
Tensor inputArr = FL_ADJUST_INPUT_TYPE(input.tensor());
auto maxvals = amax(inputArr, {dim}, /* keepDims = */ true);
// TODO{fl::Tensor}{rewrite}
@@ -1077,7 +1140,8 @@ Variable logSoftmax(const Variable& input, const int dim) {
fl::sum(
fl::exp(inputArr - fl::tile(maxvals, tiledims)),
{dim},
- /* keepDims = */ true
+ /* keepDims = */
+ true
)
)
+ maxvals,
@@ -1087,28 +1151,29 @@ Variable logSoftmax(const Variable& input, const int dim) {
fl::eval(result);
auto gradFunc = [dim, tiledims, result](
std::vector& inputs,
- const Variable& gradOutput) {
- auto gradLsm = gradOutput.tensor()
- - fl::exp(result)
- * fl::tile(
- fl::sum(gradOutput.tensor(), {dim}, /* keepDims = */ true),
- tiledims
- );
- inputs[0].addGrad(Variable(gradLsm.astype(inputs[0].type()), false));
- };
+ Variable const& gradOutput
+ ) {
+ auto gradLsm = gradOutput.tensor()
+ - fl::exp(result)
+ * fl::tile(
+ fl::sum(gradOutput.tensor(), {dim}, /* keepDims = */ true),
+ tiledims
+ );
+ inputs[0].addGrad(Variable(gradLsm.asType(inputs[0].type()), false));
+ };
return Variable(result, {input.withoutData()}, gradFunc);
}
-Variable binaryCrossEntropy(const Variable& inputs, const Variable& targets) {
- auto targetsTyped = targets.astype(inputs.type());
+Variable binaryCrossEntropy(Variable const& inputs, Variable const& targets) {
+ auto targetsTyped = targets.asType(inputs.type());
return negate(
targetsTyped * log(inputs) + (1 - targetsTyped) * log(1 - inputs)
);
}
Variable categoricalCrossEntropy(
- const Variable& in,
- const Variable& targets,
+ Variable const& in,
+ Variable const& targets,
ReduceMode reduction /* =ReduceMode::MEAN */,
int ignoreIndex /* = -1 */
) {
@@ -1129,7 +1194,7 @@ Variable categoricalCrossEntropy(
int C = input.dim(0);
int X = targets.elements();
if(
- fl::any(
+ fl::any_of(
((targets.tensor() < 0) || (targets.tensor() >= C))
&& (targets.tensor() != ignoreIndex)
)
@@ -1143,7 +1208,7 @@ Variable categoricalCrossEntropy(
auto x = fl::reshape(input.tensor(), Shape({C, X}));
auto y = fl::reshape(targets.tensor(), Shape({1, X}));
- auto A = fl::arange(Shape({C, X}));
+ auto A = fl::arrange(Shape({C, X}));
auto B = fl::tile(y, Shape({C}));
auto mask = -(A == B); // [C X]
@@ -1155,12 +1220,15 @@ Variable categoricalCrossEntropy(
Tensor denominator;
if(reduction == ReduceMode::NONE) {
result = fl::reshape(result, targets.shape()); // [X1 X2 X3]
- } else if(reduction == ReduceMode::MEAN) {
- denominator = fl::sum((!ignoreMask).astype(fl::dtype::s32), {0});
+ }
+ else if(reduction == ReduceMode::MEAN) {
+ denominator = fl::sum((!ignoreMask).asType(fl::dtype::s32), {0});
result = fl::sum(result, {0}) / denominator; // [1]
- } else if(reduction == ReduceMode::SUM) {
+ }
+ else if(reduction == ReduceMode::SUM) {
result = fl::sum(result, {0}); // [1]
- } else
+ }
+ else
throw std::invalid_argument(
"unknown reduction method for categorical cross entropy"
);
@@ -1168,28 +1236,29 @@ Variable categoricalCrossEntropy(
auto inputDims = input.shape();
auto gradFunc = [C, X, mask, ignoreMask, denominator, reduction, inputDims](
std::vector& inputs,
- const Variable& gradOutput) {
- Tensor grad = gradOutput.tensor();
- if(reduction == ReduceMode::NONE)
- grad = fl::reshape(grad, {X});
- else if(reduction == ReduceMode::MEAN)
- grad = fl::tile(grad / denominator, {X});
- else if(reduction == ReduceMode::SUM)
- grad = fl::tile(grad, {X});
- // [1 X]
- grad(ignoreMask) = 0.;
- grad = fl::reshape(grad, {1, X});
- grad = fl::tile(grad, {C}) * mask;
- inputs[0].addGrad(Variable(fl::reshape(grad, inputDims), false));
- };
+ Variable const& gradOutput
+ ) {
+ Tensor grad = gradOutput.tensor();
+ if(reduction == ReduceMode::NONE)
+ grad = fl::reshape(grad, {X});
+ else if(reduction == ReduceMode::MEAN)
+ grad = fl::tile(grad / denominator, {X});
+ else if(reduction == ReduceMode::SUM)
+ grad = fl::tile(grad, {X});
+ // [1 X]
+ grad(ignoreMask) = 0.;
+ grad = fl::reshape(grad, {1, X});
+ grad = fl::tile(grad, {C}) * mask;
+ inputs[0].addGrad(Variable(fl::reshape(grad, inputDims), false));
+ };
return Variable(result, {input.withoutData(), targets}, gradFunc);
}
Variable weightedCategoricalCrossEntropy(
- const Variable& input,
- const Variable& targets,
- const Variable& weight,
+ Variable const& input,
+ Variable const& targets,
+ Variable const& weight,
int ignoreIndex /* = -1 */
) {
// input -- [C, X1, X2, X3]
@@ -1213,7 +1282,7 @@ Variable weightedCategoricalCrossEntropy(
int C = input.dim(0);
int X = targets.elements();
if(
- fl::any((targets.tensor() < 0) || (targets.tensor() >= C))
+ fl::any_of((targets.tensor() < 0) || (targets.tensor() >= C))
.scalar()
)
throw std::invalid_argument(
@@ -1224,7 +1293,7 @@ Variable weightedCategoricalCrossEntropy(
auto x = fl::reshape(input.tensor(), {C, X});
auto y = fl::reshape(targets.tensor(), {1, X});
- auto A = fl::arange({C, X});
+ auto A = fl::arrange({C, X});
auto B = fl::tile(y, {C});
auto mask = -(A == B); // [C X]
@@ -1234,29 +1303,30 @@ Variable weightedCategoricalCrossEntropy(
auto result = mask * x;
result = result * weight.tensor();
- auto ignoreMask = (y != ignoreIndex).astype(fl::dtype::s32); // [1, X]
+ auto ignoreMask = (y != ignoreIndex).asType(fl::dtype::s32); // [1, X]
result = ignoreMask * fl::sum(result, {0}, /* keepDims = */ true); // [1, X]
result = fl::sum(result, {1}, /* keepDims = */ true) / denominator.tensor();
auto inputDims = input.shape();
auto gradFunc = [C, X, mask, ignoreMask, denominator, inputDims](
std::vector& inputs,
- const Variable& gradOutput) {
- auto grad = gradOutput.tensor();
- grad = fl::tile(grad / denominator.tensor(), {1, X});
-
- auto weightTensor = inputs[2].tensor();
- grad *= ignoreMask;
- grad = fl::tile(grad, {C}) * mask;
- grad = fl::reshape(grad, inputDims);
- grad = grad * weightTensor;
- inputs[0].addGrad(Variable(fl::reshape(grad, inputDims), false));
- };
+ Variable const& gradOutput
+ ) {
+ auto grad = gradOutput.tensor();
+ grad = fl::tile(grad / denominator.tensor(), {1, X});
+
+ auto weightTensor = inputs[2].tensor();
+ grad *= ignoreMask;
+ grad = fl::tile(grad, {C}) * mask;
+ grad = fl::reshape(grad, inputDims);
+ grad = grad * weightTensor;
+ inputs[0].addGrad(Variable(fl::reshape(grad, inputDims), false));
+ };
return Variable(result, {input.withoutData(), targets, weight}, gradFunc);
}
-Variable reorder(const Variable& input, const Shape& shape) {
+Variable reorder(Variable const& input, Shape const& shape) {
auto result = fl::transpose(input.tensor(), shape);
if(!result.isContiguous())
result = result.asContiguousTensor();
@@ -1268,24 +1338,24 @@ Variable reorder(const Variable& input, const Shape& shape) {
std::sort(dimGrad.begin(), dimGrad.end());
auto gradFunc =
- [dimGrad](std::vector& inputs, const Variable& gradOutput) {
- Shape reordered(std::vector(dimGrad.size()));
- for(unsigned i = 0; i < dimGrad.size(); ++i)
- reordered[i] = dimGrad[i].second;
+ [dimGrad](std::vector& inputs, Variable const& gradOutput) {
+ Shape reordered(std::vector(dimGrad.size()));
+ for(unsigned i = 0; i < dimGrad.size(); ++i)
+ reordered[i] = dimGrad[i].second;
- inputs[0].addGrad(
- Variable(fl::transpose(gradOutput.tensor(), reordered), false)
- );
- };
+ inputs[0].addGrad(
+ Variable(fl::transpose(gradOutput.tensor(), reordered), false)
+ );
+ };
return Variable(result, {input.withoutData()}, gradFunc);
}
-Variable linear(const Variable& input, const Variable& weight) {
- auto dummyBias = Variable(Tensor().astype(input.type()), false);
+Variable linear(Variable const& input, Variable const& weight) {
+ auto dummyBias = Variable(Tensor().asType(input.type()), false);
return linear(input, weight, dummyBias);
}
-Variable linear(const Variable& in, const Variable& wt, const Variable& bs) {
+Variable linear(Variable const& in, Variable const& wt, Variable const& bs) {
FL_VARIABLE_DTYPES_MATCH_CHECK(in, wt, bs);
auto input = FL_ADJUST_INPUT_TYPE(in);
auto weight = FL_ADJUST_INPUT_TYPE(wt);
@@ -1307,42 +1377,43 @@ Variable linear(const Variable& in, const Variable& wt, const Variable& bs) {
auto gradFunc = [hasBias](
std::vector& inputs,
- const Variable& gradOutput) {
- auto& in = inputs[0];
- auto& wt = inputs[1];
- Tensor wtTensor = wt.tensor();
- Tensor gradOutputTensor = gradOutput.tensor();
-
- auto nframes = in.elements() / in.dim(0);
-
- if(hasBias && inputs[2].isCalcGrad()) {
- auto& bs = inputs[2];
- auto biasGrad = sumAs(gradOutput, bs).tensor();
- bs.addGrad(Variable(biasGrad, false));
- }
- if(in.isCalcGrad()) {
- Shape to2dout({wtTensor.dim(0), nframes});
- auto inGrad =
- moddims(matmulTN(wt, moddims(gradOutput, to2dout)), in.shape())
- .tensor();
- in.addGrad(Variable(inGrad, false));
- }
- if(wt.isCalcGrad()) {
- Shape to2din({wtTensor.dim(1), nframes});
- Shape to2dout({wtTensor.dim(0), nframes});
- auto wtGrad =
- matmulNT(moddims(gradOutput, to2dout), moddims(in, to2din)).tensor();
- wt.addGrad(Variable(wtGrad, false));
- }
- };
+ Variable const& gradOutput
+ ) {
+ auto& in = inputs[0];
+ auto& wt = inputs[1];
+ Tensor wtTensor = wt.tensor();
+ Tensor gradOutputTensor = gradOutput.tensor();
+
+ auto nframes = in.elements() / in.dim(0);
+
+ if(hasBias && inputs[2].isCalcGrad()) {
+ auto& bs = inputs[2];
+ auto biasGrad = sumAs(gradOutput, bs).tensor();
+ bs.addGrad(Variable(biasGrad, false));
+ }
+ if(in.isCalcGrad()) {
+ Shape to2dout({wtTensor.dim(0), nframes});
+ auto inGrad =
+ moddims(matmulTN(wt, moddims(gradOutput, to2dout)), in.shape())
+ .tensor();
+ in.addGrad(Variable(inGrad, false));
+ }
+ if(wt.isCalcGrad()) {
+ Shape to2din({wtTensor.dim(1), nframes});
+ Shape to2dout({wtTensor.dim(0), nframes});
+ auto wtGrad =
+ matmulNT(moddims(gradOutput, to2dout), moddims(in, to2din)).tensor();
+ wt.addGrad(Variable(wtGrad, false));
+ }
+ };
if(hasBias)
return Variable(output, {input, weight, bias}, gradFunc);
return Variable(output, {input, weight}, gradFunc);
}
Variable conv2d(
- const Variable& input,
- const Variable& weights,
+ Variable const& input,
+ Variable const& weights,
int sx,
int sy,
int px,
@@ -1369,9 +1440,9 @@ Variable conv2d(
}
Variable conv2d(
- const Variable& in,
- const Variable& wt,
- const Variable& bs,
+ Variable const& in,
+ Variable const& wt,
+ Variable const& bs,
int sx,
int sy,
int px,
@@ -1407,103 +1478,105 @@ Variable conv2d(
auto gradFunc =
[sx, sy, px, py, dx, dy, hasBias, groups, benchmarks, payload](
- std::vector& inputs, const Variable& gradOutput) {
- // Create benchmarks if needed
- auto& autogradExtension =
- inputs[0].tensor().backend().getExtension();
-
- std::shared_ptr dataBench;
- std::shared_ptr filterBench;
- std::shared_ptr biasBench;
- if(benchmarks && DynamicBenchmark::getBenchmarkMode()) {
- if(!benchmarks->bwdFilterBenchmark) {
- benchmarks->bwdFilterBenchmark =
- autogradExtension.createBenchmarkOptions();
- filterBench = benchmarks->bwdFilterBenchmark;
- }
- if(!benchmarks->bwdDataBenchmark) {
- benchmarks->bwdDataBenchmark =
- autogradExtension.createBenchmarkOptions();
- dataBench = benchmarks->bwdDataBenchmark;
- }
- if(!benchmarks->bwdBiasBenchmark) {
- benchmarks->bwdBiasBenchmark =
- autogradExtension.createBenchmarkOptions();
- biasBench = benchmarks->bwdBiasBenchmark;
- }
+ std::vector& inputs,
+ Variable const& gradOutput
+ ) {
+ // Create benchmarks if needed
+ auto& autogradExtension =
+ inputs[0].tensor().backend().getExtension();
+
+ std::shared_ptr dataBench;
+ std::shared_ptr filterBench;
+ std::shared_ptr biasBench;
+ if(benchmarks && DynamicBenchmark::getBenchmarkMode()) {
+ if(!benchmarks->bwdFilterBenchmark) {
+ benchmarks->bwdFilterBenchmark =
+ autogradExtension.createBenchmarkOptions();
+ filterBench = benchmarks->bwdFilterBenchmark;
}
-
- // Bias gradients
- Tensor bs;
- const bool computeBiasGrad =
- inputs.size() > 2 && inputs[2].isCalcGrad();
- if(hasBias && computeBiasGrad) {
- bs = inputs[2].tensor();
- // auto biasGrad =
- // bs.backend().getExtension().conv2dBackwardBias(
- // gradOutput.tensor(), bs, biasBench, payload);
-
- // inputs[2].addGrad(Variable(biasGrad, false)); // bias
+ if(!benchmarks->bwdDataBenchmark) {
+ benchmarks->bwdDataBenchmark =
+ autogradExtension.createBenchmarkOptions();
+ dataBench = benchmarks->bwdDataBenchmark;
}
-
- auto& in = inputs[0].tensor();
- auto& wt = inputs[1].tensor();
-
- // Data (input) gradients
- if(inputs[0].isCalcGrad()) {
- auto dataGrad =
- in.backend().getExtension().conv2dBackwardData(
- gradOutput.tensor(),
- in,
- wt,
- sx,
- sy,
- px,
- py,
- dx,
- dy,
- groups,
- dataBench,
- payload
- );
-
- inputs[0].addGrad(Variable(dataGrad, false)); // input/data
+ if(!benchmarks->bwdBiasBenchmark) {
+ benchmarks->bwdBiasBenchmark =
+ autogradExtension.createBenchmarkOptions();
+ biasBench = benchmarks->bwdBiasBenchmark;
}
+ }
+
+ // Bias gradients
+ Tensor bs;
+ bool const computeBiasGrad =
+ inputs.size() > 2 && inputs[2].isCalcGrad();
+ if(hasBias && computeBiasGrad) {
+ bs = inputs[2].tensor();
+ // auto biasGrad =
+ // bs.backend().getExtension().conv2dBackwardBias(
+ // gradOutput.tensor(), bs, biasBench, payload);
+
+ // inputs[2].addGrad(Variable(biasGrad, false)); // bias
+ }
+
+ auto& in = inputs[0].tensor();
+ auto& wt = inputs[1].tensor();
+
+ // Data (input) gradients
+ if(inputs[0].isCalcGrad()) {
+ auto dataGrad =
+ in.backend().getExtension().conv2dBackwardData(
+ gradOutput.tensor(),
+ in,
+ wt,
+ sx,
+ sy,
+ px,
+ py,
+ dx,
+ dy,
+ groups,
+ dataBench,
+ payload
+ );
- // Filter (weight) and bias gradients
- if(inputs[1].isCalcGrad() || computeBiasGrad) {
- auto [filterGrad, biasGrad] = wt.backend()
- .getExtension()
- .conv2dBackwardFilterBias(
- gradOutput.tensor(),
- in,
- wt,
- bs,
- sx,
- sy,
- px,
- py,
- dx,
- dy,
- groups,
- filterBench,
- biasBench,
- payload
- );
- if(inputs[1].isCalcGrad()) {
- inputs[1].addGrad(Variable(filterGrad, false)); // filter/weight
- }
- if(computeBiasGrad)
- inputs[2].addGrad(Variable(biasGrad, false));
+ inputs[0].addGrad(Variable(dataGrad, false)); // input/data
+ }
+
+ // Filter (weight) and bias gradients
+ if(inputs[1].isCalcGrad() || computeBiasGrad) {
+ auto [filterGrad, biasGrad] = wt.backend()
+ .getExtension()
+ .conv2dBackwardFilterBias(
+ gradOutput.tensor(),
+ in,
+ wt,
+ bs,
+ sx,
+ sy,
+ px,
+ py,
+ dx,
+ dy,
+ groups,
+ filterBench,
+ biasBench,
+ payload
+ );
+ if(inputs[1].isCalcGrad()) {
+ inputs[1].addGrad(Variable(filterGrad, false)); // filter/weight
}
- };
+ if(computeBiasGrad)
+ inputs[2].addGrad(Variable(biasGrad, false));
+ }
+ };
if(hasBias)
return Variable(output, {input, weights, bias}, gradFunc);
return Variable(output, {input, weights}, gradFunc);
}
Variable pool2d(
- const Variable& input,
+ Variable const& input,
int wx,
int wy,
int sx,
@@ -1518,40 +1591,41 @@ Variable pool2d(
auto gradFunc = [wx, wy, sx, sy, px, py, mode, output, payload](
std::vector<Variable>& inputs,
- const Variable& gradOutput) {
- auto& in = inputs[0];
- if(!in.isCalcGrad())
- return;
+ Variable const& gradOutput
+ ) {
+ auto& in = inputs[0];
+ if(!in.isCalcGrad())
+ return;
- in.addGrad(
- Variable(
- in.tensor().backend().getExtension().pool2dBackward(
- gradOutput.tensor(),
- in.tensor(),
- output,
- wx,
- wy,
- sx,
- sy,
- px,
- py,
- mode,
- payload
- ),
- false
- )
- );
- };
+ in.addGrad(
+ Variable(
+ in.tensor().backend().getExtension().pool2dBackward(
+ gradOutput.tensor(),
+ in.tensor(),
+ output,
+ wx,
+ wy,
+ sx,
+ sy,
+ px,
+ py,
+ mode,
+ payload
+ ),
+ false
+ )
+ );
+ };
return Variable(output, {input}, gradFunc);
}
Variable batchnorm(
- const Variable& _input,
- const Variable& weight,
- const Variable& bias,
+ Variable const& _input,
+ Variable const& weight,
+ Variable const& bias,
Variable& runningMean,
Variable& runningVar,
- const std::vector& axes,
+ std::vector const& axes,
bool train,
double momentum,
double epsilon
@@ -1581,41 +1655,41 @@ Variable batchnorm(
train,
axes,
epsilon,
- payload](std::vector<Variable>& inputs, const Variable& _gradOutput) {
- auto& in = inputs[0];
- auto& wt = inputs[1];
- auto& bs = inputs[2];
-
- auto gradOutput = detail::adjustInputType(_gradOutput, "batchnorm");
-
- if(!in.isCalcGrad() && !wt.isCalcGrad() && !bs.isCalcGrad())
- return;
-
- auto [gradIn, gradWt, gradBs] =
- in.tensor()
- .backend()
- .getExtension()
- .batchnormBackward(
- gradOutput.tensor(),
- saveMean,
- saveVar,
- detail::adjustInputType(in.tensor(), "batchnorm"),
- wt.tensor(),
- axes,
- train,
- epsilon,
- payload
- );
-
- in.addGrad(Variable(gradIn.astype(in.type()), false));
- wt.addGrad(Variable(gradWt.astype(wt.type()), false));
- if(!bs.isEmpty())
- bs.addGrad(Variable(gradBs.astype(bs.type()), false));
- };
+ payload](std::vector<Variable>& inputs, Variable const& _gradOutput) {
+ auto& in = inputs[0];
+ auto& wt = inputs[1];
+ auto& bs = inputs[2];
+
+ auto gradOutput = detail::adjustInputType(_gradOutput, "batchnorm");
+
+ if(!in.isCalcGrad() && !wt.isCalcGrad() && !bs.isCalcGrad())
+ return;
+
+ auto [gradIn, gradWt, gradBs] =
+ in.tensor()
+ .backend()
+ .getExtension()
+ .batchnormBackward(
+ gradOutput.tensor(),
+ saveMean,
+ saveVar,
+ detail::adjustInputType(in.tensor(), "batchnorm"),
+ wt.tensor(),
+ axes,
+ train,
+ epsilon,
+ payload
+ );
+
+ in.addGrad(Variable(gradIn.asType(in.type()), false));
+ wt.addGrad(Variable(gradWt.asType(wt.type()), false));
+ if(!bs.isEmpty())
+ bs.addGrad(Variable(gradBs.asType(bs.type()), false));
+ };
return Variable(output, {input, weight, bias}, gradFunc);
}
-Variable gatedlinearunit(const Variable& input, const int dim) {
+Variable gatedlinearunit(Variable const& input, int const dim) {
if(dim >= input.ndim())
throw std::invalid_argument(
"gatedlinearunit - passed dim is great than the "
@@ -1643,21 +1717,22 @@ Variable gatedlinearunit(const Variable& input, const int dim) {
auto gradFunc = [fhalf, shalf, fhalfout, shalfout, inDims, inType](
std::vector<Variable>& inputs,
- const Variable& gradOutput) {
- auto gradGlu = Tensor(inDims, inType);
- gradGlu(fhalf) = shalfout * gradOutput.tensor();
- gradGlu(shalf) =
- shalfout * (1.0 - shalfout) * fhalfout * gradOutput.tensor();
- inputs[0].addGrad(Variable(gradGlu, false));
- };
+ Variable const& gradOutput
+ ) {
+ auto gradGlu = Tensor(inDims, inType);
+ gradGlu(fhalf) = shalfout * gradOutput.tensor();
+ gradGlu(shalf) =
+ shalfout * (1.0 - shalfout) * fhalfout * gradOutput.tensor();
+ inputs[0].addGrad(Variable(gradGlu, false));
+ };
return Variable(fhalfout * shalfout, {input.withoutData()}, gradFunc);
}
std::tuple<Variable, Variable, Variable> rnn(
- const Variable& input,
- const Variable& hiddenState,
- const Variable& cellState,
- const Variable& weights,
+ Variable const& input,
+ Variable const& hiddenState,
+ Variable const& cellState,
+ Variable const& weights,
int hiddenSize,
int numLayers,
RnnMode mode,
@@ -1691,63 +1766,65 @@ std::tuple rnn(
dropProb,
gradData,
payload](
- std::vector<Variable>& inputs,
- const Variable& /* gradOutput */) {
- auto& input = inputs[0];
- auto& hiddenState = inputs[1];
- auto& cellState = inputs[2];
- auto& weights = inputs[3];
-
- if(
- !(input.isCalcGrad() || hiddenState.isCalcGrad()
+ std::vector<Variable>& inputs,
+ Variable const& /* gradOutput */
+
+ ) {
+ auto& input = inputs[0];
+ auto& hiddenState = inputs[1];
+ auto& cellState = inputs[2];
+ auto& weights = inputs[3];
+
+ if(
+ !(input.isCalcGrad() || hiddenState.isCalcGrad()
|| cellState.isCalcGrad() || weights.isCalcGrad())
- )
- return;
-
- auto [dy, dhy, dcy, dweights] =
- input.tensor().backend().getExtension().rnnBackward(
- input.tensor(),
- hiddenState.tensor(),
- cellState.tensor(),
- weights.tensor(),
- gradData,
- output,
- numLayers,
- hiddenSize,
- mode,
- bidirectional,
- dropProb,
- payload
- );
+ )
+ return;
+
+ auto [dy, dhy, dcy, dweights] =
+ input.tensor().backend().getExtension().rnnBackward(
+ input.tensor(),
+ hiddenState.tensor(),
+ cellState.tensor(),
+ weights.tensor(),
+ gradData,
+ output,
+ numLayers,
+ hiddenSize,
+ mode,
+ bidirectional,
+ dropProb,
+ payload
+ );
- input.addGrad(Variable(dy.astype(input.type()), false));
- hiddenState.addGrad(Variable(dhy.astype(hiddenState.type()), false));
- cellState.addGrad(Variable(dcy.astype(cellState.type()), false));
- weights.addGrad(Variable(dweights.astype(weights.type()), false));
- };
+ input.addGrad(Variable(dy.asType(input.type()), false));
+ hiddenState.addGrad(Variable(dhy.asType(hiddenState.type()), false));
+ cellState.addGrad(Variable(dcy.asType(cellState.type()), false));
+ weights.addGrad(Variable(dweights.asType(weights.type()), false));
+ };
Variable dummy(Tensor(), {input, hiddenState, cellState, weights}, gradFunc);
auto dyGradFunc =
- [gradData](std::vector<Variable>& inputs, const Variable& gradOutput) {
- if(!inputs[0].isGradAvailable())
- inputs[0].addGrad(Variable(Tensor(), false));
- gradData->dy = gradOutput.tensor().asContiguousTensor();
- };
+ [gradData](std::vector<Variable>& inputs, Variable const& gradOutput) {
+ if(!inputs[0].isGradAvailable())
+ inputs[0].addGrad(Variable(Tensor(), false));
+ gradData->dy = gradOutput.tensor().asContiguousTensor();
+ };
auto dhyGradFunc =
- [gradData](std::vector<Variable>& inputs, const Variable& gradOutput) {
- if(!inputs[0].isGradAvailable())
- inputs[0].addGrad(Variable(Tensor(), false));
- gradData->dhy = gradOutput.tensor().asContiguousTensor();
- };
+ [gradData](std::vector<Variable>& inputs, Variable const& gradOutput) {
+ if(!inputs[0].isGradAvailable())
+ inputs[0].addGrad(Variable(Tensor(), false));
+ gradData->dhy = gradOutput.tensor().asContiguousTensor();
+ };
auto dcyGradFunc =
- [gradData](std::vector<Variable>& inputs, const Variable& gradOutput) {
- if(!inputs[0].isGradAvailable())
- inputs[0].addGrad(Variable(Tensor(), false));
- gradData->dcy = gradOutput.tensor().asContiguousTensor();
- };
+ [gradData](std::vector<Variable>& inputs, Variable const& gradOutput) {
+ if(!inputs[0].isGradAvailable())
+ inputs[0].addGrad(Variable(Tensor(), false));
+ gradData->dcy = gradOutput.tensor().asContiguousTensor();
+ };
Variable yv(output, {dummy}, dyGradFunc); // output
Variable hyv(hiddenOut, {dummy}, dhyGradFunc); // hidden state output
@@ -1755,63 +1832,67 @@ std::tuple rnn(
return std::make_tuple(yv, hyv, cyv);
}
-Variable embedding(const Variable& input, const Variable& embeddings) {
+Variable embedding(Variable const& input, Variable const& embeddings) {
// TODO{fl::Tensor}{4-dims} - relax this
if(input.ndim() >= 4)
- throw std::invalid_argument("embedding input must have 3 or fewer dims");
+ throw std::invalid_argument{"embedding input must have 3 or fewer dims"};
- auto idxs = input.tensor().flatten();
+ auto const idxs = input.tensor().flatten();
auto inDims = input.shape();
std::vector<Dim> rDims(input.ndim() + 1);
rDims[0] = embeddings.dim(0);
- for(unsigned i = 1; i < input.ndim() + 1; i++)
+ for(Dim i = 1; i < input.ndim() + 1; i++)
rDims[i] = inDims[i - 1];
- Shape resultDims(rDims);
- Tensor result = fl::reshape(embeddings.tensor()(fl::span, idxs), resultDims);
-
- auto gradFunc = [](std::vector<Variable>& inputs,
- const Variable& gradOutput) {
- auto& w = inputs[1];
- if(!w.isCalcGrad())
- return;
-
- auto ip = inputs[0].tensor().flatten();
- unsigned size = ip.elements();
- auto deltas = fl::reshape(gradOutput.tensor(), {w.dim(0), size});
-
- // Sparse Tensor
- auto sp = Tensor(
- ip.elements(),
- w.dim(1),
- fl::full({size}, 1, deltas.type()),
- fl::arange({size + 1}, 0, fl::dtype::s32),
- ip.astype(fl::dtype::s32),
- fl::StorageType::CSR
- );
-
- auto grad = transpose(
- fl::matmul(
- sp,
- transpose(deltas), /* lhsProp = */
- MatrixProperty::Transpose
- )
- );
- w.addGrad(Variable(grad, false));
- };
-
- return Variable(result, {input, embeddings}, gradFunc);
+
+ Shape const resultDims{rDims};
+ auto const result = fl::reshape(embeddings.tensor()(fl::span, idxs), resultDims);
+
+ auto grad_func = [](
+ std::vector<Variable>& inputs,
+ Variable const& gradOutput
+ ) {
+ auto& w = inputs[1];
+ if(!w.isCalcGrad())
+ return;
+
+ auto const ip = inputs[0].tensor().flatten();
+ auto size = static_cast(ip.elements());
+ auto const deltas = fl::reshape(gradOutput.tensor(), {w.dim(0), size});
+
+ // Sparse Tensor
+ auto const sp = Tensor{
+ size,
+ w.dim(1),
+ fl::full({size}, 1.0, deltas.type()),
+ fl::arange({size + 1}, 0, fl::dtype::s32),
+ ip.asType(fl::dtype::s32),
+ fl::StorageType::CSR
+ };
+
+ auto const grad = transpose(
+ fl::matmul(
+ sp,
+ transpose(deltas),
+ /* lhsProp = */
+ MatrixProperty::Transpose
+ )
+ );
+ w.addGrad(Variable{grad, false});
+ };
+
+ return Variable{result, {input, embeddings}, grad_func};
}
Variable padding(
- const Variable& input,
+ Variable const& input,
std::vector<std::pair<int, int>> pad,
double val
) {
if(pad.size() > input.ndim())
- throw std::invalid_argument(
+ throw std::invalid_argument{
"padding: number of padding dimensions exceeds number "
"of input dimensions"
- );
+ };
Shape opDims = input.shape();
std::vector<fl::Index> inSeq(input.ndim(), fl::span);
@@ -1823,33 +1904,34 @@ Variable padding(
result(inSeq) = input.tensor();
auto gradFunc =
- [inSeq](std::vector<Variable>& inputs, const Variable& gradOutput) {
- inputs[0].addGrad(Variable(gradOutput.tensor()(inSeq), false));
- };
+ [inSeq](std::vector<Variable>& inputs, Variable const& gradOutput) {
+ inputs[0].addGrad(Variable(gradOutput.tensor()(inSeq), false));
+ };
return Variable(result, {input.withoutData()}, gradFunc);
}
-Variable dropout(const Variable& input, double p) {
+Variable dropout(Variable const& input, double p) {
if(p > 0.0) {
auto mask = Variable(
- (fl::rand(input.shape(), input.type()) > p).astype(input.type()),
+ (fl::rand(input.shape(), input.type()) > p).asType(input.type()),
false
);
return 1.0 / (1.0 - p) * mask * input;
- } else
+ }
+ else
return input;
}
-Variable relu(const Variable& input) { return max(input, 0.0); }
+Variable relu(Variable const& input) { return max(input, 0.0); }
-Variable gelu(const Variable& in) {
+Variable gelu(Variable const& in) {
auto input = FL_ADJUST_INPUT_TYPE(in);
return 0.5 * input
- * (1.0
- + fl::tanh(0.7978845608 * (input + 0.044715 * input * input * input)));
+ * (1.0
+ + fl::tanh(0.7978845608 * (input + 0.044715 * input * input * input)));
}
-fl::Variable relativePositionEmbeddingRotate(const fl::Variable& input) {
+fl::Variable relativePositionEmbeddingRotate(fl::Variable const& input) {
if(input.ndim() != 3)
throw std::invalid_argument(
"relativePositionEmbeddingRotate - "
@@ -1870,31 +1952,32 @@ fl::Variable relativePositionEmbeddingRotate(const fl::Variable& input) {
data = fl::reshape(data, {d0 + d1 - 1, d1, d2});
auto gradFunc = [d0, d1, d2](
std::vector<Variable>& inputs,
- const fl::Variable& gradOutput) {
- auto gradData = gradOutput.tensor();
- gradData = fl::reshape(gradData, {(d0 + d1 - 1) * d1, 1, d2});
- gradData = fl::concatenate(
- 0,
- gradData,
- fl::full({d1, 1, d2}, 0.0, gradData.type())
- );
- gradData = reshape(gradData, {d0 + d1, d1, d2});
- gradData = Variable(gradData, false)(fl::range(0, d0)).tensor();
- inputs[0].addGrad(fl::Variable(gradData, false));
- };
+ fl::Variable const& gradOutput
+ ) {
+ auto gradData = gradOutput.tensor();
+ gradData = fl::reshape(gradData, {(d0 + d1 - 1) * d1, 1, d2});
+ gradData = fl::concatenate(
+ 0,
+ gradData,
+ fl::full({d1, 1, d2}, 0.0, gradData.type())
+ );
+ gradData = reshape(gradData, {d0 + d1, d1, d2});
+ gradData = Variable(gradData, false)(fl::range(0, d0)).tensor();
+ inputs[0].addGrad(fl::Variable(gradData, false));
+ };
return fl::Variable(data, {input}, gradFunc);
}
fl::Variable multiheadAttention(
- const fl::Variable& query,
- const fl::Variable& key,
- const fl::Variable& value,
- const fl::Variable& posEmb,
- const fl::Variable& mask,
- const fl::Variable& padMask,
- const int32_t nHeads,
- const double pDropout,
- const int32_t offset /* = 0 */
+ fl::Variable const& query,
+ fl::Variable const& key,
+ fl::Variable const& value,
+ fl::Variable const& posEmb,
+ fl::Variable const& mask,
+ fl::Variable const& padMask,
+ int32_t const nHeads,
+ double const pDropout,
+ int32_t const offset /* = 0 */
) {
if(query.ndim() != 3)
throw std::invalid_argument(
@@ -1925,12 +2008,12 @@ fl::Variable multiheadAttention(
if(!posEmb.isEmpty()) {
int n = posEmb.dim(0) / 2 - offset;
auto pscores =
- relativePositionEmbeddingRotate(matmulNT(posEmb.astype(q.type()), q));
+ relativePositionEmbeddingRotate(matmulNT(posEmb.asType(q.type()), q));
scores =
scores + transpose(pscores(fl::range(n, n + k.dim(0))), {1, 0, 2});
}
if(!mask.isEmpty())
- scores = scores + tileAs(mask.astype(scores.type()), scores);
+ scores = scores + tileAs(mask.asType(scores.type()), scores);
if(!padMask.isEmpty()) {
if(padMask.dim(0) != query.dim(0))
throw std::invalid_argument(
@@ -1941,13 +2024,13 @@ fl::Variable multiheadAttention(
tileAs(padMaskTile, {padMask.dim(0), padMask.dim(0), nHeads, bsz});
scores = scores
+ moddims(
- padMaskTile.astype(scores.type()),
+ padMaskTile.asType(scores.type()),
{padMask.dim(0), padMask.dim(0), nHeads * bsz}
);
}
auto attn = dropout(softmax(scores, 1), pDropout);
- auto result = matmul(attn.astype(v.type()), v);
+ auto result = matmul(attn.asType(v.type()), v);
result = moddims(result, {-1, headDim * nHeads, bsz});
return result;
}
diff --git a/flashlight/fl/autograd/Functions.h b/flashlight/fl/autograd/Functions.h
index b2d23a0..e6423ab 100644
--- a/flashlight/fl/autograd/Functions.h
+++ b/flashlight/fl/autograd/Functions.h
@@ -71,11 +71,11 @@ namespace detail {
&& optimLevel != OptimLevel::DEFAULT
)
// Not in the excluded list - cast to f16
- res = in.astype(fl::dtype::f16);
+ res = in.asType(fl::dtype::f16);
else {
// Upcast to f32 only if we have an f16 input - otherwise, leave as is
if(in.type() == fl::dtype::f16)
- res = in.astype(fl::dtype::f32);
+ res = in.asType(fl::dtype::f32);
else
res = in;
}
@@ -449,7 +449,7 @@ FL_API Variable concatenate(const std::vector& concatInputs, int dim);
* divisible, last chunk of smaller splitSize will be included.
* @param dim dimension along which to split the Variable
*/
-FL_API std::vector<Variable> split(const Variable& input, long splitSize, int dim);
+FL_API std::vector<Variable> split(const Variable& input, int64_t splitSize, int dim);
/**
* Splits a Variable into smaller chunks.
@@ -458,7 +458,7 @@ FL_API std::vector split(const Variable& input, long splitSize, int di
* @param splitSizes vector of integers specifying the sizes for each split
* @param dim dimension along which to split the Variable
*/
-FL_API std::vector<Variable> split(const Variable& input, const std::vector<long>& splitSizes, int dim);
+FL_API std::vector<Variable> split(const Variable& input, std::vector<int64_t> const& splitSizes, int dim);
/**
* Repeats the tensor `input` along specific dimensions. The number of
diff --git a/flashlight/fl/autograd/Variable.cpp b/flashlight/fl/autograd/Variable.cpp
index bd1fb6a..08e1914 100644
--- a/flashlight/fl/autograd/Variable.cpp
+++ b/flashlight/fl/autograd/Variable.cpp
@@ -95,13 +95,13 @@ Variable Variable::copy() const {
return Variable(sharedData_->data, sharedGrad_->calcGrad);
}
-Variable Variable::astype(fl::dtype newType) const {
- auto output = tensor().astype(newType);
+Variable Variable::asType(fl::dtype newType) const {
+ auto output = tensor().asType(newType);
auto gradFunc = [](std::vector<Variable>& inputs,
const Variable& gradOutput) {
auto& input = inputs[0];
// Cast the grad output to match the type of the input's grad
- input.addGrad(Variable(gradOutput.tensor().astype(input.type()), false));
+ input.addGrad(Variable(gradOutput.tensor().asType(input.type()), false));
};
return Variable(output, {this->withoutData()}, gradFunc);
}
diff --git a/flashlight/fl/autograd/Variable.h b/flashlight/fl/autograd/Variable.h
index 60fb040..2ce498d 100644
--- a/flashlight/fl/autograd/Variable.h
+++ b/flashlight/fl/autograd/Variable.h
@@ -128,7 +128,12 @@ class FL_API Variable {
*
* @return returns the casted variable.
*/
- Variable astype(fl::dtype type) const;
+ Variable asType(fl::dtype type) const;
+
+ /**
+ * @deprecated use @ref Variable::asType(fl::dtype) const instead
+ */
+ Variable astype(fl::dtype type) const { return asType(type); }
/**
* @return a reference to the underlying gradient Variable.
@@ -207,25 +212,19 @@ class FL_API Variable {
* Must eventually be freed manually via `free` or a related call.
*/
template <typename T>
- T* host() const {
- return tensor().host<T>();
- }
+ T* host() const { return tensor().host<T>(); }
/**
* Copies the array to the existing host pointer `ptr`
*/
template <typename T>
- void host(T* ptr) const {
- tensor().host(ptr);
- }
+ void host(T* ptr) const { tensor().host(ptr); }
/**
* Get the first element of the array as a scalar
*/
template <typename T>
- T scalar() const {
- return tensor().scalar<T>();
- }
+ T scalar() const { return tensor().scalar<T>(); }
/**
* Remove the gradient stored by the Variable
diff --git a/flashlight/fl/autograd/tensor/backend/cudnn/BatchNorm.cpp b/flashlight/fl/autograd/tensor/backend/cudnn/BatchNorm.cpp
index 9f6b315..bb4e7ad 100644
--- a/flashlight/fl/autograd/tensor/backend/cudnn/BatchNorm.cpp
+++ b/flashlight/fl/autograd/tensor/backend/cudnn/BatchNorm.cpp
@@ -1,8 +1,8 @@
/*
- * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * SPDX-License-Identifier: MIT
*
- * This source code is licensed under the MIT license found in the
- * LICENSE file in the root directory of this source tree.
+ * Original code: Copyright (c) Meta Platforms, Inc. (see FLASHLIGHT_LICENSE)
+ * Modifications: Copyright (c) 2026 Lukas Thomann (see LICENSE)
*/
#include "flashlight/fl/autograd/tensor/backend/cudnn/CudnnAutogradExtension.h"
@@ -48,15 +48,15 @@ namespace {
if(minAxis == 0) {
modeOut = CUDNN_BATCHNORM_PER_ACTIVATION;
- inDescDimsOut = Shape(
+ inDescDimsOut = Shape{
{
1,
1,
nfeatures,
- static_cast(input.elements() / nfeatures)
+ static_cast(input.elements() / nfeatures)
}
- );
- wtDescDimsOut = Shape({1, 1, nfeatures});
+ };
+ wtDescDimsOut = Shape{1, 1, nfeatures};
} else {
modeOut = CUDNN_BATCHNORM_SPATIAL;
#if CUDNN_VERSION >= 7003
@@ -67,15 +67,15 @@ namespace {
int batchsz = 1;
for(int i = maxAxis + 1; i < input.ndim(); ++i)
batchsz *= input.dim(i);
- inDescDimsOut = Shape(
+ inDescDimsOut = Shape{
{
1,
- static_cast(input.elements() / (nfeatures * batchsz)),
+ static_cast(input.elements() / (nfeatures * batchsz)),
nfeatures,
batchsz,
}
- );
- wtDescDimsOut = Shape({1, 1, nfeatures});
+ };
+ wtDescDimsOut = Shape{1, 1, nfeatures};
}
}
@@ -101,7 +101,7 @@ Tensor CudnnAutogradExtension::batchnorm(
);
FL_TENSOR_DTYPES_MATCH_CHECK(weight, bias, runningMean, runningVar);
- auto output = Tensor(input.shape(), input.type());
+ auto output = Tensor{input.shape(), input.type()};
cudnnBatchNormMode_t mode;
Shape inDescDims, wtDescDims;
@@ -115,15 +115,15 @@ Tensor CudnnAutogradExtension::batchnorm(
// Weight, bias, and running mean/var arrays can't be fp16 (must be fp32)
Tensor weightArray = weight.isEmpty()
? fl::full(wtDescDims, 1.0, fl::dtype::f32)
- : weight.astype(fl::dtype::f32);
+ : weight.asType(fl::dtype::f32);
Tensor biasArray = bias.isEmpty() ? fl::full(wtDescDims, 0.0, fl::dtype::f32)
- : bias.astype(fl::dtype::f32);
+ : bias.asType(fl::dtype::f32);
fl::dtype scalarsType =
input.type() == fl::dtype::f16 ? fl::dtype::f32 : input.type();
- auto inDesc = TensorDescriptor(input.type(), inDescDims);
- auto wtDesc = TensorDescriptor(weightArray.type(), wtDescDims);
+ auto inDesc = TensorDescriptor{input.type(), inDescDims};
+ auto wtDesc = TensorDescriptor{weightArray.type(), wtDescDims};
{
DevicePtr inRaw(input);
@@ -140,8 +140,8 @@ Tensor CudnnAutogradExtension::batchnorm(
);
if(train) {
- saveMean = Tensor({wtDescDims[2]}, scalarsType);
- saveVar = Tensor({wtDescDims[2]}, scalarsType);
+ saveMean = Tensor{{wtDescDims[2]}, scalarsType};
+ saveVar = Tensor{{wtDescDims[2]}, scalarsType};
DevicePtr saveMeanRaw(saveMean);
DevicePtr saveVarRaw(saveVar);
@@ -153,11 +153,11 @@ Tensor CudnnAutogradExtension::batchnorm(
mode,
kOne(scalarsType),
kZero(scalarsType),
- inDesc.descriptor,
+ inDesc.get(),
inRaw.get(),
- inDesc.descriptor,
+ inDesc.get(),
outRaw.get(),
- wtDesc.descriptor,
+ wtDesc.get(),
wtRaw.get(),
bsRaw.get(),
momentum,
@@ -175,11 +175,11 @@ Tensor CudnnAutogradExtension::batchnorm(
mode,
kOne(scalarsType),
kZero(scalarsType),
- inDesc.descriptor,
+ inDesc.get(),
inRaw.get(),
- inDesc.descriptor,
+ inDesc.get(),
outRaw.get(),
- wtDesc.descriptor,
+ wtDesc.get(),
wtRaw.get(),
bsRaw.get(),
runMeanRaw.get(),
@@ -223,13 +223,13 @@ std::tuple CudnnAutogradExtension::batchnormBackward(
const void* one1 = kOne(scalarsType);
const void* zero0 = kZero(scalarsType);
- auto iDesc = TensorDescriptor(input.type(), inDescDims);
- auto wDesc = TensorDescriptor(wt.type(), wtDescDims);
+ auto iDesc = TensorDescriptor{input.type(), inDescDims};
+ auto wDesc = TensorDescriptor{wt.type(), wtDescDims};
// CuDNN doesn't support calculating only the gradients
// required for batchnorm
- auto gradIn = Tensor(input.shape(), input.type());
- auto gradWt = Tensor(wt.shape(), wt.type());
- auto gradBs = Tensor(wt.shape(), wt.type());
+ auto gradIn = Tensor{input.shape(), input.type()};
+ auto gradWt = Tensor{wt.shape(), wt.type()};
+ auto gradBs = Tensor{wt.shape(), wt.type()};
{
DevicePtr iRaw(input);
DevicePtr wRaw(wt);
@@ -257,13 +257,13 @@ std::tuple CudnnAutogradExtension::batchnormBackward(
zero0,
one1,
zero0,
- iDesc.descriptor,
+ iDesc.get(),
iRaw.get(),
- iDesc.descriptor,
+ iDesc.get(),
gradOpRaw.get(),
- iDesc.descriptor,
+ iDesc.get(),
gradInRaw.get(),
- wDesc.descriptor,
+ wDesc.get(),
wRaw.get(),
gradWtRaw.get(),
gradBsRaw.get(),
diff --git a/flashlight/fl/autograd/tensor/backend/cudnn/CMakeLists.txt b/flashlight/fl/autograd/tensor/backend/cudnn/CMakeLists.txt
index 49660c9..0bb7eba 100644
--- a/flashlight/fl/autograd/tensor/backend/cudnn/CMakeLists.txt
+++ b/flashlight/fl/autograd/tensor/backend/cudnn/CMakeLists.txt
@@ -8,6 +8,8 @@ target_sources(
${CMAKE_CURRENT_LIST_DIR}/Conv2D.cpp
${CMAKE_CURRENT_LIST_DIR}/CudnnUtils.h
${CMAKE_CURRENT_LIST_DIR}/CudnnUtils.cpp
+ ${CMAKE_CURRENT_LIST_DIR}/CudnnRnnUtils.h
+ ${CMAKE_CURRENT_LIST_DIR}/CudnnRnnUtils.cpp
${CMAKE_CURRENT_LIST_DIR}/Pool2D.cpp
${CMAKE_CURRENT_LIST_DIR}/RNN.cpp
)
diff --git a/flashlight/fl/autograd/tensor/backend/cudnn/Conv2D.cpp b/flashlight/fl/autograd/tensor/backend/cudnn/Conv2D.cpp
index bb89e61..e6560e5 100644
--- a/flashlight/fl/autograd/tensor/backend/cudnn/Conv2D.cpp
+++ b/flashlight/fl/autograd/tensor/backend/cudnn/Conv2D.cpp
@@ -1,8 +1,8 @@
/*
- * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * SPDX-License-Identifier: MIT
*
- * This source code is licensed under the MIT license found in the
- * LICENSE file in the root directory of this source tree.
+ * Original code: Copyright (c) Meta Platforms, Inc. (see FLASHLIGHT_LICENSE)
+ * Modifications: Copyright (c) 2026 Lukas Thomann (see LICENSE)
*/
#include "flashlight/fl/autograd/tensor/backend/cudnn/CudnnAutogradExtension.h"
@@ -270,7 +270,7 @@ namespace {
) {
CUDNN_CHECK_ERR(
cudnnSetConvolutionMathType(
- cDesc.descriptor,
+ cDesc.get(),
kKernelModesToCudnnMathType.at(kernelOptions->currentOption())
)
);
@@ -280,13 +280,13 @@ namespace {
if(input.type() == fl::dtype::f16)
CUDNN_CHECK_ERR(
cudnnSetConvolutionMathType(
- cDesc.descriptor,
+ cDesc.get(),
CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION
)
);
else
CUDNN_CHECK_ERR(
- cudnnSetConvolutionMathType(cDesc.descriptor, CUDNN_DEFAULT_MATH)
+ cudnnSetConvolutionMathType(cDesc.get(), CUDNN_DEFAULT_MATH)
);
}
@@ -314,42 +314,42 @@ Tensor CudnnAutogradExtension::conv2d(
auto hasBias = bias.elements() > 0;
- auto inDesc = TensorDescriptor(input);
- auto wtDesc = FilterDescriptor(weights);
- auto convDesc = ConvDescriptor(input.type(), px, py, sx, sy, dx, dy, groups);
+ auto inDesc = TensorDescriptor{input};
+ auto wtDesc = FilterDescriptor{weights};
+ auto convDesc = ConvDescriptor{input.type(), px, py, sx, sy, dx, dy, groups};
if(input.type() == fl::dtype::f16)
CUDNN_CHECK_ERR(
cudnnSetConvolutionMathType(
- convDesc.descriptor,
+ convDesc.get(),
CUDNN_TENSOR_OP_MATH_ALLOW_CONVERSION
)
);
else
CUDNN_CHECK_ERR(
- cudnnSetConvolutionMathType(convDesc.descriptor, CUDNN_DEFAULT_MATH)
+ cudnnSetConvolutionMathType(convDesc.get(), CUDNN_DEFAULT_MATH)
);
std::array<int, 4> odims;
CUDNN_CHECK_ERR(
cudnnGetConvolutionNdForwardOutputDim(
- convDesc.descriptor,
- inDesc.descriptor,
- wtDesc.descriptor,
+ convDesc.get(),
+ inDesc.get(),
+ wtDesc.get(),
4,
odims.data()
)
);
- auto output = Tensor({odims[3], odims[2], odims[1], odims[0]}, input.type());
- auto outDesc = TensorDescriptor(output);
+ auto output = Tensor{{odims[3], odims[2], odims[1], odims[0]}, input.type()};
+ auto outDesc = TensorDescriptor{output};
auto handle = getCudnnHandle();
const auto& cudnnStream = getCudnnStream();
auto fwdAlgoBestPerf = getFwdAlgo(
- inDesc.descriptor,
- wtDesc.descriptor,
- convDesc.descriptor,
- outDesc.descriptor,
+ inDesc.get(),
+ wtDesc.get(),
+ convDesc.get(),
+ outDesc.get(),
input.type()
);
@@ -357,22 +357,22 @@ Tensor CudnnAutogradExtension::conv2d(
try {
wspace =
- Tensor({static_cast(fwdAlgoBestPerf.memory)}, fl::dtype::b8);
+ Tensor{{static_cast(fwdAlgoBestPerf.memory)}, fl::dtype::b8};
} catch(const std::exception&) {
fwdAlgoBestPerf.algo = kFwdDefaultAlgo;
CUDNN_CHECK_ERR(
cudnnGetConvolutionForwardWorkspaceSize(
handle,
- inDesc.descriptor,
- wtDesc.descriptor,
- convDesc.descriptor,
- outDesc.descriptor,
+ inDesc.get(),
+ wtDesc.get(),
+ convDesc.get(),
+ outDesc.get(),
fwdAlgoBestPerf.algo,
&fwdAlgoBestPerf.memory
)
);
wspace =
- Tensor({static_cast(fwdAlgoBestPerf.memory)}, fl::dtype::b8);
+ Tensor{{static_cast(fwdAlgoBestPerf.memory)}, fl::dtype::b8};
}
{
DevicePtr inPtr(input);
@@ -390,22 +390,22 @@ Tensor CudnnAutogradExtension::conv2d(
cudnnConvolutionForward(
handle,
one,
- inDesc.descriptor,
+ inDesc.get(),
inPtr.get(),
- wtDesc.descriptor,
+ wtDesc.get(),
wtPtr.get(),
- convDesc.descriptor,
+ convDesc.get(),
fwdAlgoBestPerf.algo,
wspacePtr.get(),
fwdAlgoBestPerf.memory,
zero,
- outDesc.descriptor,
+ outDesc.get(),
outPtr.get()
)
);
if(hasBias) {
- auto bsDesc = TensorDescriptor(bias);
+ auto bsDesc = TensorDescriptor{bias};
DevicePtr bsPtr(bias);
// ensure cudnn compute stream waits on stream of bias tensor
relativeSync(cudnnStream, {bias});
@@ -413,10 +413,10 @@ Tensor CudnnAutogradExtension::conv2d(
cudnnAddTensor(
handle,
one,
- bsDesc.descriptor,
+ bsDesc.get(),
bsPtr.get(),
one,
- outDesc.descriptor,
+ outDesc.get(),
outPtr.get()
)
);
@@ -453,10 +453,10 @@ Tensor CudnnAutogradExtension::conv2dBackwardData(
// benchmarking suggests input or weight casting should occur, these
// descriptors may not be used/new ones with the correct types will be
// used instead.
- auto iDesc = TensorDescriptor(input);
- auto wDesc = FilterDescriptor(weight);
- auto cDesc = ConvDescriptor(input.type(), px, py, sx, sy, dx, dy, groups);
- auto oDesc = TensorDescriptor(gradOutput);
+ auto iDesc = TensorDescriptor{input};
+ auto wDesc = FilterDescriptor{weight};
+ auto cDesc = ConvDescriptor{input.type(), px, py, sx, sy, dx, dy, groups};
+ auto oDesc = TensorDescriptor{gradOutput};
setDefaultMathType(cDesc, input);
@@ -481,40 +481,40 @@ Tensor CudnnAutogradExtension::conv2dBackwardData(
relativeSync(cudnnStream, {wtTensor});
bool isStrided = (dx * dy) > 1;
auto bwdDataAlgoBestPerf = getBwdDataAlgo(
- iDesc.descriptor,
- wDesc.descriptor,
- cDesc.descriptor,
- oDesc.descriptor,
+ iDesc.get(),
+ wDesc.get(),
+ cDesc.get(),
+ oDesc.get(),
isStrided,
inTensor.type()
);
Tensor ws;
try {
- ws = Tensor(
- {static_cast(bwdDataAlgoBestPerf.memory)},
+ ws = Tensor{
+ {static_cast(bwdDataAlgoBestPerf.memory)},
fl::dtype::b8
- );
+ };
} catch(const std::exception&) {
bwdDataAlgoBestPerf.algo = kBwdDataDefaultAlgo;
CUDNN_CHECK_ERR(
cudnnGetConvolutionBackwardDataWorkspaceSize(
hndl,
- wDesc.descriptor,
- oDesc.descriptor,
- cDesc.descriptor,
- iDesc.descriptor,
+ wDesc.get(),
+ oDesc.get(),
+ cDesc.get(),
+ iDesc.get(),
bwdDataAlgoBestPerf.algo,
&bwdDataAlgoBestPerf.memory
)
);
- ws = Tensor(
- {static_cast(bwdDataAlgoBestPerf.memory)},
+ ws = Tensor{
+ {static_cast(bwdDataAlgoBestPerf.memory)},
fl::dtype::b8
- );
+ };
}
- auto gradInput = Tensor(inTensor.shape(), inTensor.type());
+ auto gradInput = Tensor{inTensor.shape(), inTensor.type()};
{
DevicePtr gradInputPtr(gradInput);
DevicePtr gradResultPtr(gradOutputTensor);
@@ -525,16 +525,16 @@ Tensor CudnnAutogradExtension::conv2dBackwardData(
cudnnConvolutionBackwardData(
hndl,
oneg,
- wDesc.descriptor,
+ wDesc.get(),
wPtr.get(),
- oDesc.descriptor,
+ oDesc.get(),
gradResultPtr.get(),
- cDesc.descriptor,
+ cDesc.get(),
bwdDataAlgoBestPerf.algo,
wsPtr.get(),
bwdDataAlgoBestPerf.memory,
zerog,
- iDesc.descriptor,
+ iDesc.get(),
gradInputPtr.get()
)
);
@@ -570,18 +570,18 @@ Tensor CudnnAutogradExtension::conv2dBackwardData(
&wtTensorF32,
&gradOutput,
&gradOutputTensorF32]() {
- inTensorF32 = input.astype(fl::dtype::f32);
- wtTensorF32 = weight.astype(fl::dtype::f32);
- gradOutputTensorF32 = gradOutput.astype(fl::dtype::f32);
+ inTensorF32 = input.asType(fl::dtype::f32);
+ wtTensorF32 = weight.asType(fl::dtype::f32);
+ gradOutputTensorF32 = gradOutput.asType(fl::dtype::f32);
},
/* incrementCount = */ false
);
- auto iDescF32 = TensorDescriptor(inTensorF32);
- auto wDescF32 = FilterDescriptor(wtTensorF32);
+ auto iDescF32 = TensorDescriptor{inTensorF32};
+ auto wDescF32 = FilterDescriptor{wtTensorF32};
auto cDescF32 =
- ConvDescriptor(fl::dtype::f32, px, py, sx, sy, dx, dy, groups);
- auto oDescF32 = TensorDescriptor(gradOutputTensorF32);
+ ConvDescriptor{fl::dtype::f32, px, py, sx, sy, dx, dy, groups};
+ auto oDescF32 = TensorDescriptor{gradOutputTensorF32};
// core bwd data computation
dataGradBenchmark->audit(
[&dataGradOut,
@@ -671,10 +671,10 @@ std::pair CudnnAutogradExtension::conv2dBackwardFilterBias(
// benchmarking suggests input or weight casting should occur, these
// descriptors may not be used/new ones with the correct types will be
// used instead.
- auto iDesc = TensorDescriptor(input);
- auto wDesc = FilterDescriptor(weight);
- auto cDesc = ConvDescriptor(input.type(), px, py, sx, sy, dx, dy, groups);
- auto oDesc = TensorDescriptor(gradOutput);
+ auto iDesc = TensorDescriptor{input};
+ auto wDesc = FilterDescriptor{weight};
+ auto cDesc = ConvDescriptor{input.type(), px, py, sx, sy, dx, dy, groups};
+ auto oDesc = TensorDescriptor{gradOutput};
setDefaultMathType(cDesc, input);
@@ -699,39 +699,39 @@ std::pair CudnnAutogradExtension::conv2dBackwardFilterBias(
// ensure cudnn compute stream waits on stream of input tensor
relativeSync(cudnnStream, {inTensor});
auto bwdFilterAlgoBestPerf = getBwdFilterAlgo(
- iDesc.descriptor,
- wDesc.descriptor,
- cDesc.descriptor,
- oDesc.descriptor,
+ iDesc.get(),
+ wDesc.get(),
+ cDesc.get(),
+ oDesc.get(),
inTensor.type()
);
Tensor ws;
try {
- ws = Tensor(
- {static_cast(bwdFilterAlgoBestPerf.memory)},
+ ws = Tensor{
+ {static_cast(bwdFilterAlgoBestPerf.memory)},
fl::dtype::b8
- );
+ };
} catch(const std::exception&) {
bwdFilterAlgoBestPerf.algo = kBwdFilterDefaultAlgo;
CUDNN_CHECK_ERR(
cudnnGetConvolutionBackwardFilterWorkspaceSize(
hndl,
- iDesc.descriptor,
- oDesc.descriptor,
- cDesc.descriptor,
- wDesc.descriptor,
+ iDesc.get(),
+ oDesc.get(),
+ cDesc.get(),
+ wDesc.get(),
bwdFilterAlgoBestPerf.algo,
&bwdFilterAlgoBestPerf.memory
)
);
- ws = Tensor(
- {static_cast(bwdFilterAlgoBestPerf.memory)},
+ ws = Tensor{
+ {static_cast(bwdFilterAlgoBestPerf.memory)},
fl::dtype::b8
- );
+ };
}
- auto gradWeight = Tensor(wtTensor.shape(), wtTensor.type());
+ auto gradWeight = Tensor{wtTensor.shape(), wtTensor.type()};
{
DevicePtr gradWeightPtr(gradWeight);
DevicePtr gradResultPtr(gradOutputTensor);
@@ -742,16 +742,16 @@ std::pair CudnnAutogradExtension::conv2dBackwardFilterBias(
cudnnConvolutionBackwardFilter(
hndl,
oneg,
- iDesc.descriptor,
+ iDesc.get(),
iPtr.get(),
- oDesc.descriptor,
+ oDesc.get(),
gradResultPtr.get(),
- cDesc.descriptor,
+ cDesc.get(),
bwdFilterAlgoBestPerf.algo,
wsPtr.get(),
bwdFilterAlgoBestPerf.memory,
zerog,
- wDesc.descriptor,
+ wDesc.get(),
gradWeightPtr.get()
)
);
@@ -787,18 +787,18 @@ std::pair CudnnAutogradExtension::conv2dBackwardFilterBias(
&wtTensorF32,
&gradOutput,
&gradOutputTensorF32]() {
- inTensorF32 = input.astype(fl::dtype::f32);
- wtTensorF32 = weight.astype(fl::dtype::f32);
- gradOutputTensorF32 = gradOutput.astype(fl::dtype::f32);
+ inTensorF32 = input.asType(fl::dtype::f32);
+ wtTensorF32 = weight.asType(fl::dtype::f32);
+ gradOutputTensorF32 = gradOutput.asType(fl::dtype::f32);
},
/* incrementCount = */ false
);
- auto iDescF32 = TensorDescriptor(inTensorF32);
- auto wDescF32 = FilterDescriptor(wtTensorF32);
+ auto iDescF32 = TensorDescriptor{inTensorF32};
+ auto wDescF32 = FilterDescriptor{wtTensorF32};
auto cDescF32 =
- ConvDescriptor(fl::dtype::f32, px, py, sx, sy, dx, dy, groups);
- auto oDescF32 = TensorDescriptor(gradOutputTensorF32);
+ ConvDescriptor{fl::dtype::f32, px, py, sx, sy, dx, dy, groups};
+ auto oDescF32 = TensorDescriptor{gradOutputTensorF32};
// core bwd data computation
filterGradBenchmark->audit(
[&filterGradOut,
@@ -860,21 +860,21 @@ std::pair CudnnAutogradExtension::conv2dBackwardFilterBias(
const Tensor& bsTensor,
const Tensor& gradOutput,
const TensorDescriptor& oDesc) -> Tensor {
- auto gradBias = Tensor(bsTensor.shape(), bsTensor.type());
+ auto gradBias = Tensor{bsTensor.shape(), bsTensor.type()};
{
DevicePtr gradBiasPtr(gradBias);
DevicePtr gradResultPtr(gradOutput);
// ensure cudnn compute stream waits on gradient tensor streams
relativeSync(cudnnStream, {gradOutput, gradBias});
- auto bDesc = TensorDescriptor(bsTensor);
+ auto bDesc = TensorDescriptor{bsTensor};
CUDNN_CHECK_ERR(
cudnnConvolutionBackwardBias(
hndl,
oneg,
- oDesc.descriptor,
+ oDesc.get(),
gradResultPtr.get(),
zerog,
- bDesc.descriptor,
+ bDesc.get(),
gradBiasPtr.get()
)
);
@@ -906,12 +906,12 @@ std::pair CudnnAutogradExtension::conv2dBackwardFilterBias(
// Time cast bias and grad output if benchmarking
biasGradBenchmark->audit(
[&bias, &gradOutput, &biasF32, &gradOutputF32]() {
- biasF32 = bias.astype(fl::dtype::f32);
- gradOutputF32 = gradOutput.astype(fl::dtype::f32);
+ biasF32 = bias.asType(fl::dtype::f32);
+ gradOutputF32 = gradOutput.asType(fl::dtype::f32);
},
/* incrementCount = */ false
);
- auto oDescF32 = TensorDescriptor(gradOutputF32);
+ auto oDescF32 = TensorDescriptor{gradOutputF32};
// Perform bias gradient computation
biasGradBenchmark->audit(
[&biasGradOut,
diff --git a/flashlight/fl/autograd/tensor/backend/cudnn/CudnnAutogradExtension.cpp b/flashlight/fl/autograd/tensor/backend/cudnn/CudnnAutogradExtension.cpp
index 305a6cc..eaac140 100644
--- a/flashlight/fl/autograd/tensor/backend/cudnn/CudnnAutogradExtension.cpp
+++ b/flashlight/fl/autograd/tensor/backend/cudnn/CudnnAutogradExtension.cpp
@@ -16,13 +16,11 @@ namespace fl {
std::shared_ptr<fl::DynamicBenchmark> CudnnAutogradExtension::createBenchmarkOptions() {
 return std::make_shared<fl::DynamicBenchmark>(
 std::make_shared<fl::DynamicBenchmarkOptions<KernelMode>>(
- std::vector<KernelMode>(
- {
- KernelMode::F32,
- KernelMode::F32_ALLOW_CONVERSION,
- KernelMode::F16
- }
- ),
+ std::vector<KernelMode>{
+ KernelMode::F32,
+ KernelMode::F32_ALLOW_CONVERSION,
+ KernelMode::F16
+ },
fl::kDynamicBenchmarkDefaultCount
)
);
diff --git a/flashlight/fl/autograd/tensor/backend/cudnn/CudnnAutogradExtension.h b/flashlight/fl/autograd/tensor/backend/cudnn/CudnnAutogradExtension.h
index b960c30..2edfd89 100644
--- a/flashlight/fl/autograd/tensor/backend/cudnn/CudnnAutogradExtension.h
+++ b/flashlight/fl/autograd/tensor/backend/cudnn/CudnnAutogradExtension.h
@@ -1,8 +1,8 @@
/*
- * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * SPDX-License-Identifier: MIT
*
- * This source code is licensed under the MIT license found in the
- * LICENSE file in the root directory of this source tree.
+ * Original code: Copyright (c) Meta Platforms, Inc. (see FLASHLIGHT_LICENSE)
+ * Modifications: Copyright (c) 2026 Lukas Thomann (see LICENSE)
*/
#pragma once
@@ -19,96 +19,96 @@ class CudnnAutogradExtension : public AutogradExtension {
// TODO(jacobkahn): implement getCudnnHandle
public:
- bool isDataTypeSupported(const fl::dtype& dtype) const override;
+ bool isDataTypeSupported(fl::dtype const& dtype) const override;
- enum class KernelMode {F32 = 0, F32_ALLOW_CONVERSION = 1, F16 = 2};
+ enum class KernelMode { F32 = 0, F32_ALLOW_CONVERSION = 1, F16 = 2 };
std::shared_ptr createBenchmarkOptions() override;
/**************************** Forward ****************************/
Tensor conv2d(
- const Tensor& input,
- const Tensor& weights,
- const Tensor& bias,
- const int sx,
- const int sy,
- const int px,
- const int py,
- const int dx,
- const int dy,
- const int groups,
+ Tensor const& input,
+ Tensor const& weights,
+ Tensor const& bias,
+ int sx,
+ int sy,
+ int px,
+ int py,
+ int dx,
+ int dy,
+ int groups,
std::shared_ptr payload
) override;
Tensor pool2d(
- const Tensor& input,
- const int wx,
- const int wy,
- const int sx,
- const int sy,
- const int px,
- const int py,
- const PoolingMode mode,
+ Tensor const& input,
+ int wx,
+ int wy,
+ int sx,
+ int sy,
+ int px,
+ int py,
+ PoolingMode mode,
std::shared_ptr payload
) override;
Tensor batchnorm(
Tensor& saveMean,
Tensor& saveVar,
- const Tensor& input,
- const Tensor& weight,
- const Tensor& bias,
+ Tensor const& input,
+ Tensor const& weight,
+ Tensor const& bias,
Tensor& runningMean,
Tensor& runningVar,
- const std::vector& axes,
- const bool train,
- const double momentum,
- const double epsilon,
+ std::vector const& axes,
+ bool train,
+ double momentum,
+ double epsilon,
std::shared_ptr payload
) override;
std::tuple rnn(
- const Tensor& input,
- const Tensor& hiddenState,
- const Tensor& cellState,
- const Tensor& weights,
- const int hiddenSize,
- const int numLayers,
- const RnnMode mode,
- const bool bidirectional,
- const float dropout,
- std::shared_ptr payload
+ Tensor const& input,
+ Tensor const& hiddenState,
+ Tensor const& cellState,
+ Tensor const& weights,
+ int hiddenSize,
+ int numLayers,
+ RnnMode mode,
+ bool bidirectional,
+ float dropProb,
+ std::shared_ptr autogradPayload
) override;
/**************************** Backward ****************************/
// ]----- Convolution
Tensor conv2dBackwardData(
- const Tensor& gradOutput,
- const Tensor& input,
- const Tensor& weight,
- const int sx,
- const int sy,
- const int px,
- const int py,
- const int dx,
- const int dy,
- const int groups,
+ Tensor const& gradOutput,
+ Tensor const& input,
+ Tensor const& weight,
+ int sx,
+ int sy,
+ int px,
+ int py,
+ int dx,
+ int dy,
+ int groups,
std::shared_ptr dataGradBenchmark,
std::shared_ptr payload
) override;
std::pair conv2dBackwardFilterBias(
- const Tensor& gradOutput,
- const Tensor& input,
- const Tensor& weights,
- const Tensor& bias,
- const int sx,
- const int sy,
- const int px,
- const int py,
- const int dx,
- const int dy,
- const int groups,
+ Tensor const& gradOutput,
+ Tensor const& input,
+ Tensor const& weights,
+ Tensor const& bias,
+ int sx,
+ int sy,
+ int px,
+ int py,
+ int dx,
+ int dy,
+ int groups,
std::shared_ptr filterBench,
std::shared_ptr biasBench,
std::shared_ptr autogradPayload
@@ -116,47 +116,59 @@ class CudnnAutogradExtension : public AutogradExtension {
// ]----- pool2D
Tensor pool2dBackward(
- const Tensor& gradOutput,
- const Tensor& input,
- const Tensor& poolOutput,
- const int wx,
- const int wy,
- const int sx,
- const int sy,
- const int px,
- const int py,
- const PoolingMode mode,
+ Tensor const& gradOutput,
+ Tensor const& input,
+ Tensor const& poolOutput,
+ int wx,
+ int wy,
+ int sx,
+ int sy,
+ int px,
+ int py,
+ PoolingMode mode,
std::shared_ptr payload
) override;
// ]----- batchnorm
std::tuple batchnormBackward(
- const Tensor& gradOutput,
- const Tensor& saveMean,
- const Tensor& saveVar,
- const Tensor& input,
- const Tensor& weight,
- const std::vector& axes,
- const bool train,
- const float epsilon,
+ Tensor const& gradOutput,
+ Tensor const& saveMean,
+ Tensor const& saveVar,
+ Tensor const& input,
+ Tensor const& weight,
+ std::vector const& axes,
+ bool train,
+ float epsilon,
std::shared_ptr payload
) override;
// ]----- rnn
std::tuple rnnBackward(
- const Tensor& input,
- const Tensor& hiddenState,
- const Tensor& cellState,
- const Tensor& weights,
- const std::shared_ptr gradData,
- const Tensor& output,
- const int numLayers,
- const int hiddenSize,
- const RnnMode mode,
- const bool bidirectional,
- const float dropProb,
- std::shared_ptr payload
+ Tensor const& input,
+ Tensor const& hiddenState,
+ Tensor const& cellState,
+ Tensor const& weights,
+ std::shared_ptr gradData,
+ Tensor const& output,
+ int numLayers,
+ int hiddenSize,
+ RnnMode mode,
+ bool bidirectional,
+ float dropProb,
+ std::shared_ptr autogradPayload
) override;
+
+private:
+
+ static void checkHiddenStateDims(int hiddenSize, Tensor const& hiddenState, int batchSize, int totalLayers);
+ static void checkCellStateDims(
+ int hiddenSize,
+ RnnMode mode,
+ Tensor const& cellState,
+ int batchSize,
+ int totalLayers
+ );
+
};
} // namespace fl
diff --git a/flashlight/fl/autograd/tensor/backend/cudnn/CudnnRnnUtils.cpp b/flashlight/fl/autograd/tensor/backend/cudnn/CudnnRnnUtils.cpp
new file mode 100644
index 0000000..c64c234
--- /dev/null
+++ b/flashlight/fl/autograd/tensor/backend/cudnn/CudnnRnnUtils.cpp
@@ -0,0 +1,305 @@
+#include "CudnnRnnUtils.h"
+
+#include "flashlight/fl/common/DevicePtr.h"
+#include "flashlight/fl/tensor/Compute.h"
+
+
+namespace fl {
+namespace {
+ struct temp_space_sizes {
+ size_t size;
+ size_t reserveSize;
+ };
+
+ temp_space_sizes rnn_temp_space_sizes(
+ cudnnHandle_t handle,
+ RNNDescriptor const& rnnDescriptor,
+ RNNDataDescriptor const& xDescriptor,
+ cudnnForwardMode_t mode
+ ) {
+ temp_space_sizes sizes{};
+
+ CUDNN_CHECK_ERR(
+ cudnnGetRNNTempSpaceSizes(
+ handle,
+ rnnDescriptor.get(),
+ mode,
+ xDescriptor.get(),
+ &sizes.size,
+ &sizes.reserveSize
+ )
+ );
+
+ return sizes;
+ }
+
+ size_t rnn_weight_space_size(
+ cudnnHandle_t handle,
+ RNNDescriptor const& rnnDescriptor
+ ) {
+ size_t size = 0;
+
+ CUDNN_CHECK_ERR(
+ cudnnGetRNNWeightSpaceSize(handle, rnnDescriptor.get(), &size)
+ );
+ return size;
+ }
+
+ std::optional<Tensor> create_dev_seq_lengths(int batchSize, int seqLength) {
+ // see cuDNN docs for cudnnRNNForward as explanation
+#if CUDNN_VERSION >= 8901
+ return std::nullopt;
+#else
+ return fl::full({batchSize}, seqLength, fl::dtype::s32);
+#endif
+ }
+
+}
+}
+
+namespace fl {
+void cudnn_rnn_forward(
+ int batchSize,
+ int seqLength,
+ bool train,
+ RNNDescriptor const& rnnDesc,
+ Tensor const& x,
+ Tensor const& y,
+ Tensor const& weights,
+ TensorDescriptor const& cxDesc,
+ TensorDescriptor const& hxDesc,
+ Tensor const& hy,
+ Tensor const& cy,
+ Tensor const& hiddenState,
+ Tensor const& cellState,
+ Tensor& reserveSpace
+) {
+ RNNDataDescriptor xDesc{x.type(), x.shape()};
+ RNNDataDescriptor yDesc{y.type(), y.shape()};
+
+ auto handle = getCudnnHandle();
+
+ size_t weightSpaceSize = rnn_weight_space_size(handle, rnnDesc);
+
+ if(weightSpaceSize != weights.bytes())
+ throw std::invalid_argument("invalid # of parameters or wrong input shape for RNN");
+
+ auto const forwardMode = train ? CUDNN_FWD_MODE_TRAINING : CUDNN_FWD_MODE_INFERENCE;
+
+ auto [workspaceSize, reserveSize] = rnn_temp_space_sizes(handle, rnnDesc, xDesc, forwardMode);
+
+ Tensor workspace({static_cast(workspaceSize)}, fl::dtype::b8);
+ // Space must be reused between forward and backward for cuDNN
+
+ reserveSpace = Tensor{{static_cast(reserveSize)}, fl::dtype::b8};
+
+ auto devSeqLengths = create_dev_seq_lengths(batchSize, seqLength);
+
+ auto const& cudnnStream = getCudnnStream();
+
+ {
+ auto contiguousX = x.asContiguousTensor();
+ auto contiguousWeights = weights.asContiguousTensor();
+ DevicePtr xRaw(contiguousX);
+ DevicePtr hxRaw(hiddenState);
+ DevicePtr cxRaw(cellState);
+ DevicePtr weightSpaceRaw(contiguousWeights);
+ DevicePtr yRaw(y);
+ DevicePtr hyRaw(hy);
+ DevicePtr cyRaw(cy);
+ DevicePtr workspaceRaw(workspace);
+ DevicePtr reserveSpaceRaw(reserveSpace);
+
+ std::optional<DevicePtr> devSeqLengthsRaw{};
+
+ if(devSeqLengths)
+ devSeqLengthsRaw.emplace(*devSeqLengths);
+
+ // ensure cudnn compute stream waits on input/output tensor streams
+
+ std::vector waits{
+ contiguousX,
+ hiddenState,
+ cellState,
+ contiguousWeights,
+ y,
+ hy,
+ cy,
+ workspace,
+ reserveSpace,
+ };
+ if(devSeqLengths)
+ waits.push_back(*devSeqLengths);
+
+ relativeSync(cudnnStream, waits);
+
+
+ CUDNN_CHECK_ERR(
+ cudnnRNNForward(
+ handle,
+ rnnDesc.get(),
+ forwardMode,
+ devSeqLengthsRaw ? devSeqLengthsRaw->getAs<int32_t>() : nullptr,
+
+ xDesc.get(),
+ xRaw.get(),
+ yDesc.get(),
+ yRaw.get(),
+
+ hxDesc.get(),
+ hxRaw.get(),
+ hyRaw.get(),
+ cxDesc.get(),
+ cxRaw.get(),
+ cyRaw.get(),
+
+ weightSpaceSize,
+ weightSpaceRaw.get(),
+
+ workspaceSize,
+ workspaceRaw.get(),
+
+ reserveSize,
+ reserveSpaceRaw.get()
+ )
+ );
+ }
+
+ // ensure output tensor streams wait on cudnn compute stream
+ relativeSync({y, hy, cy}, cudnnStream);
+}
+
+void cudnn_rnn_backward(
+ int batchSize,
+ int seqLength,
+ RNNDescriptor const& rnnDesc,
+
+ Tensor const& x,
+ Tensor const& y,
+ Tensor const& dy,
+ Tensor const& weights,
+ TensorDescriptor const& cxDesc,
+ TensorDescriptor const& hxDesc,
+ Tensor const& dhy,
+ Tensor const& dcy,
+ Tensor const& hiddenState,
+ Tensor const& cellState,
+ Tensor const& dx,
+ Tensor const& dhx,
+ Tensor const& dcx,
+ Tensor const& dw,
+
+ Tensor const& reserveSpace
+) {
+ auto handle = getCudnnHandle();
+ auto const& cudnnStream = getCudnnStream();
+
+ RNNDataDescriptor xDesc{x.type(), x.shape()};
+ RNNDataDescriptor yDesc{y.type(), y.shape()};
+
+ size_t weightSpaceSize = rnn_weight_space_size(handle, rnnDesc);
+ auto [workspaceSize, reserveSize] = rnn_temp_space_sizes(handle, rnnDesc, xDesc, CUDNN_FWD_MODE_TRAINING);
+
+ Tensor workspace({static_cast(workspaceSize)}, fl::dtype::b8);
+
+ auto devSeqLengths = create_dev_seq_lengths(batchSize, seqLength);
+
+ std::vector waits = {y, workspace, reserveSpace};
+ if(devSeqLengths)
+ waits.push_back(*devSeqLengths);
+
+ // ensure cudnn compute stream waits on input/output tensor streams
+ relativeSync(cudnnStream, waits);
+
+ DevicePtr yRaw(y);
+ DevicePtr workspaceRaw(workspace);
+ DevicePtr reserveSpaceRaw(reserveSpace);
+
+ std::optional<DevicePtr> devSeqLengthsRaw{};
+ if(devSeqLengths)
+ devSeqLengthsRaw.emplace(*devSeqLengths);
+
+ {
+ DevicePtr dyRaw(dy); // Has to be set to 0 if empty
+ DevicePtr dhyRaw(dhy);
+ DevicePtr dcyRaw(dcy);
+
+ DevicePtr wRaw(weights);
+
+ DevicePtr hxRaw(hiddenState);
+ DevicePtr cxRaw(cellState);
+
+ DevicePtr dxRaw(dx);
+ DevicePtr dhxRaw(dhx);
+ DevicePtr dcxRaw(dcx);
+
+ // ensure cudnn compute stream waits on input/output tensor streams
+ relativeSync(
+ cudnnStream,
+ {dy, dhy, dcy, weights, hiddenState, cellState, dx, dhx, dcx}
+ );
+
+ /* We need to update reserveSpace even if we just want the
+ * weight gradients. */
+ CUDNN_CHECK_ERR(
+ cudnnRNNBackwardData_v8(
+ handle,
+ rnnDesc.get(),
+ devSeqLengthsRaw ? devSeqLengthsRaw->getAs<int32_t>() : nullptr,
+ yDesc.get(),
+ yRaw.get(),
+ dyRaw.get(),
+ xDesc.get(),
+ dxRaw.get(),
+ hxDesc.get(),
+ hxRaw.get(),
+ dhyRaw.get(),
+ dhxRaw.get(),
+ cxDesc.get(),
+ cxRaw.get(),
+ dcyRaw.get(),
+ dcxRaw.get(),
+ weightSpaceSize,
+ wRaw.get(),
+ workspaceSize,
+ workspaceRaw.get(),
+ reserveSpace.bytes(),
+ reserveSpaceRaw.get()
+ )
+ );
+ }
+
+ {
+ DevicePtr xRaw(x);
+ DevicePtr dwRaw(dw);
+ DevicePtr hxRaw(hiddenState);
+
+ // ensure cudnn compute stream waits on input/output tensor streams
+ relativeSync(cudnnStream, {x, dw, hiddenState});
+
+ CUDNN_CHECK_ERR(
+ cudnnRNNBackwardWeights_v8(
+ handle,
+ rnnDesc.get(),
+ CUDNN_WGRAD_MODE_ADD,
+ devSeqLengthsRaw ? devSeqLengthsRaw->getAs<int32_t>() : nullptr,
+ xDesc.get(),
+ xRaw.get(),
+ hxDesc.get(),
+ hxRaw.get(),
+ yDesc.get(),
+ yRaw.get(),
+ weightSpaceSize,
+ dwRaw.get(),
+ workspaceSize,
+ workspaceRaw.get(),
+ reserveSpace.bytes(),
+ reserveSpaceRaw.get()
+ )
+ );
+ }
+
+ // ensure output tensor streams wait on cudnn compute stream
+ relativeSync({dx, dhx, dcx, dw}, cudnnStream);
+}
+}
diff --git a/flashlight/fl/autograd/tensor/backend/cudnn/CudnnRnnUtils.h b/flashlight/fl/autograd/tensor/backend/cudnn/CudnnRnnUtils.h
new file mode 100644
index 0000000..3ed5b07
--- /dev/null
+++ b/flashlight/fl/autograd/tensor/backend/cudnn/CudnnRnnUtils.h
@@ -0,0 +1,45 @@
+#pragma once
+#include "CudnnUtils.h"
+
+namespace fl {
+void cudnn_rnn_forward(
+ int batchSize,
+ int seqLength,
+ bool train,
+ RNNDescriptor const& rnnDesc,
+
+ Tensor const& x,
+ Tensor const& y,
+ Tensor const& weights,
+ TensorDescriptor const& cxDesc,
+ TensorDescriptor const& hxDesc,
+ Tensor const& hy,
+ Tensor const& cy,
+ Tensor const& hiddenState,
+ Tensor const& cellState,
+
+ Tensor& reserveSpace // out
+);
+void cudnn_rnn_backward(
+ int batchSize,
+ int seqLength,
+ RNNDescriptor const& rnnDesc,
+
+ Tensor const& x,
+ Tensor const& y,
+ Tensor const& dy,
+ Tensor const& weights,
+ TensorDescriptor const& cxDesc,
+ TensorDescriptor const& hxDesc,
+ Tensor const& dhy,
+ Tensor const& dcy,
+ Tensor const& hiddenState,
+ Tensor const& cellState,
+ Tensor const& dx,
+ Tensor const& dhx,
+ Tensor const& dcx,
+ Tensor const& dw,
+
+ Tensor const& reserveSpace
+);
+}
diff --git a/flashlight/fl/autograd/tensor/backend/cudnn/CudnnUtils.cpp b/flashlight/fl/autograd/tensor/backend/cudnn/CudnnUtils.cpp
index 82cadcb..4a20900 100644
--- a/flashlight/fl/autograd/tensor/backend/cudnn/CudnnUtils.cpp
+++ b/flashlight/fl/autograd/tensor/backend/cudnn/CudnnUtils.cpp
@@ -1,8 +1,8 @@
/*
- * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * SPDX-License-Identifier: MIT
*
- * This source code is licensed under the MIT license found in the
- * LICENSE file in the root directory of this source tree.
+ * Original code: Copyright (c) Meta Platforms, Inc. (see FLASHLIGHT_LICENSE)
+ * Modifications: Copyright (c) 2026 Lukas Thomann (see LICENSE)
*/
#include "flashlight/fl/autograd/tensor/backend/cudnn/CudnnUtils.h"
@@ -25,16 +25,16 @@ struct DeviceHandle {
std::shared_ptr stream;
explicit DeviceHandle(std::shared_ptr _stream) : cudnnHandle(nullptr),
- stream(_stream) {
+ stream(_stream) {
CUDNN_CHECK_ERR(cudnnCreate(&cudnnHandle));
CUDNN_CHECK_ERR(cudnnSetStream(cudnnHandle, stream->handle()));
}
~DeviceHandle() {
if(cudnnHandle) {
-// See https://git.io/fNQnM - sometimes, at exit, the CUDA context
-// (or something) is already destroyed by the time a handle gets destroyed
-// because of an issue with the destruction order.
+ // See https://git.io/fNQnM - sometimes, at exit, the CUDA context
+ // (or something) is already destroyed by the time a handle gets destroyed
+ // because of an issue with the destruction order.
#ifdef NO_CUDNN_DESTROY_HANDLE
#else
CUDNN_CHECK_ERR(cudnnDestroy(cudnnHandle));
@@ -43,16 +43,16 @@ struct DeviceHandle {
}
};
-const float kFloatZero = 0.0;
-const float kFloatOne = 1.0;
+constexpr float kFloatZero = 0.0;
+constexpr float kFloatOne = 1.0;
-const double kDoubleZero = 0.0;
-const double kDoubleOne = 1.0;
+constexpr double kDoubleZero = 0.0;
+constexpr double kDoubleOne = 1.0;
// TODO: move this to CudnnAutogradExtension if we make it a singleton
std::unordered_map handles;
-const DeviceHandle& getActiveDeviceHandle() {
+DeviceHandle const& getActiveDeviceHandle() {
auto& manager = fl::DeviceManager::getInstance();
auto& cudaDevice =
manager.getActiveDevice(fl::DeviceType::CUDA).impl();
@@ -88,58 +88,43 @@ namespace fl {
void cudnnCheckErr(cudnnStatus_t status) {
if(status == CUDNN_STATUS_SUCCESS)
return;
- const char* err = cudnnGetErrorString(status);
+ char const* err = cudnnGetErrorString(status);
switch(status) {
- case CUDNN_STATUS_BAD_PARAM:
- throw std::invalid_argument(err);
- default:
- throw std::runtime_error(err);
+ case CUDNN_STATUS_BAD_PARAM: throw std::invalid_argument(err);
+ default: throw std::runtime_error(err);
}
}
-cudnnDataType_t cudnnMapToType(const fl::dtype& t) {
+cudnnDataType_t cudnnMapToType(fl::dtype const& t) {
switch(t) {
- case fl::dtype::f16:
- return CUDNN_DATA_HALF;
- case fl::dtype::f32:
- return CUDNN_DATA_FLOAT;
- case fl::dtype::f64:
- return CUDNN_DATA_DOUBLE;
- default:
- throw std::invalid_argument("unsupported data type for cuDNN");
+ case fl::dtype::f16: return CUDNN_DATA_HALF;
+ case fl::dtype::f32: return CUDNN_DATA_FLOAT;
+ case fl::dtype::f64: return CUDNN_DATA_DOUBLE;
+ default: throw std::invalid_argument("unsupported data type for cuDNN");
}
}
-cudnnPoolingMode_t cudnnMapToPoolingMode(const PoolingMode mode) {
+cudnnPoolingMode_t cudnnMapToPoolingMode(PoolingMode const mode) {
switch(mode) {
- case PoolingMode::MAX:
- return CUDNN_POOLING_MAX;
- case PoolingMode::AVG_INCLUDE_PADDING:
- return CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING;
- case PoolingMode::AVG_EXCLUDE_PADDING:
- return CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING;
- default:
- throw std::invalid_argument("unsupported pooling mode for cuDNN");
+ case PoolingMode::MAX: return CUDNN_POOLING_MAX;
+ case PoolingMode::AVG_INCLUDE_PADDING: return CUDNN_POOLING_AVERAGE_COUNT_INCLUDE_PADDING;
+ case PoolingMode::AVG_EXCLUDE_PADDING: return CUDNN_POOLING_AVERAGE_COUNT_EXCLUDE_PADDING;
+ default: throw std::invalid_argument("unsupported pooling mode for cuDNN");
}
}
-cudnnRNNMode_t cudnnMapToRNNMode(const RnnMode mode) {
+cudnnRNNMode_t cudnnMapToRNNMode(RnnMode const mode) {
switch(mode) {
- case RnnMode::RELU:
- return CUDNN_RNN_RELU;
- case RnnMode::TANH:
- return CUDNN_RNN_TANH;
- case RnnMode::LSTM:
- return CUDNN_LSTM;
- case RnnMode::GRU:
- return CUDNN_GRU;
- default:
- throw std::invalid_argument("unsupported RNN mode for cuDNN");
+ case RnnMode::RELU: return CUDNN_RNN_RELU;
+ case RnnMode::TANH: return CUDNN_RNN_TANH;
+ case RnnMode::LSTM: return CUDNN_LSTM;
+ case RnnMode::GRU: return CUDNN_GRU;
+ default: throw std::invalid_argument("unsupported RNN mode for cuDNN");
}
}
-TensorDescriptor::TensorDescriptor(const fl::dtype type, const Shape& flDims) {
- CUDNN_CHECK_ERR(cudnnCreateTensorDescriptor(&descriptor));
+TensorDescriptor::TensorDescriptor(fl::dtype const type, Shape const& flDims) {
+ CUDNN_CHECK_ERR(cudnnCreateTensorDescriptor(&_handle));
cudnnDataType_t cudnntype = cudnnMapToType(type);
std::array dims = {1, 1, 1, 1};
@@ -156,7 +141,7 @@ TensorDescriptor::TensorDescriptor(const fl::dtype type, const Shape& flDims) {
CUDNN_CHECK_ERR(
cudnnSetTensorNdDescriptor(
- descriptor,
+ _handle,
cudnntype,
dims.size(),
dims.data(),
@@ -165,8 +150,8 @@ TensorDescriptor::TensorDescriptor(const fl::dtype type, const Shape& flDims) {
);
}
-TensorDescriptor::TensorDescriptor(const Tensor& input) {
- CUDNN_CHECK_ERR(cudnnCreateTensorDescriptor(&descriptor));
+TensorDescriptor::TensorDescriptor(Tensor const& input) {
+ CUDNN_CHECK_ERR(cudnnCreateTensorDescriptor(&_handle));
cudnnDataType_t cudnntype = cudnnMapToType(input.type());
auto flStrides = input.strides();
@@ -185,7 +170,7 @@ TensorDescriptor::TensorDescriptor(const Tensor& input) {
CUDNN_CHECK_ERR(
cudnnSetTensorNdDescriptor(
- descriptor /* descriptor handle */,
+ _handle /* descriptor handle */,
cudnntype /* = dataType */,
4,
dims.data(),
@@ -194,21 +179,19 @@ TensorDescriptor::TensorDescriptor(const Tensor& input) {
);
}
-TensorDescriptor::~TensorDescriptor() {
- CUDNN_CHECK_ERR(cudnnDestroyTensorDescriptor(descriptor));
-}
+TensorDescriptor::~TensorDescriptor() { CUDNN_CHECK_ERR(cudnnDestroyTensorDescriptor(_handle)); }
TensorDescriptorArray::TensorDescriptorArray(
int size,
- const fl::dtype type,
- const Shape& dims
+ fl::dtype const type,
+ Shape const& dims
) {
- desc_vec.reserve(size);
+ _descVec.reserve(size);
for(int i = 0; i < size; i++) {
- desc_vec.emplace_back(type, dims);
- desc_raw_vec.push_back(desc_vec.back().descriptor);
+ _descVec.emplace_back(type, dims);
+ _descRawVec.push_back(_descVec.back().get());
}
- descriptors = desc_raw_vec.data();
+ descriptors = _descRawVec.data();
}
TensorDescriptorArray::~TensorDescriptorArray() = default;
@@ -222,7 +205,7 @@ PoolingDescriptor::PoolingDescriptor(
int py,
PoolingMode mode
) {
- CUDNN_CHECK_ERR(cudnnCreatePoolingDescriptor(&descriptor));
+ CUDNN_CHECK_ERR(cudnnCreatePoolingDescriptor(&_handle));
std::array