diff --git a/python/paddle/nn/layer/norm.py b/python/paddle/nn/layer/norm.py
index 27d0c672d2490a..efc2c601e51697 100644
--- a/python/paddle/nn/layer/norm.py
+++ b/python/paddle/nn/layer/norm.py
@@ -22,6 +22,7 @@
 
 from paddle import _C_ops, in_dynamic_mode, pir_utils
 from paddle.device import get_all_custom_device_type
+from paddle.utils.decorator_utils import param_one_alias
 
 from ...base import dygraph_utils
 from ...base.data_feeder import check_variable_and_dtype
@@ -51,6 +52,7 @@
         DataLayoutND,
         DTypeLike,
         ParamAttrLike,
+        PlaceLike,
         ShapeLike,
     )
 
@@ -589,13 +591,34 @@ class LayerNorm(Layer):
             which is expected to be of that specific size.
         epsilon(float, optional): The small value added to the variance to prevent
             division by zero. Default: 1e-05.
+            alias: ``eps``.
+        elementwise_affine(bool, optional): Whether to apply element-wise affine transformation
+            (i.e., learnable scale and bias). If set to ``False``, both the scale (:math:`g`) and
+            bias (:math:`b`) parameters will be disabled, regardless of the settings of `weight_attr`
+            and `bias_attr`. This parameter acts as a master switch. Defaults to True.
+            **Note: This argument must be passed as a keyword argument.**
+        bias(bool, optional): Whether to include a learnable bias term in the layer. This setting
+            only takes effect when `elementwise_affine` is ``True``. If set to ``False``, no bias
+            parameter will be created, even if `bias_attr` is specified. Defaults to True.
+            **Note: This argument must be passed as a keyword argument.**
         weight_attr(ParamAttr|bool|None, optional): The parameter attribute for the learnable
-            gain :math:`g`. If False, weight is None. If is None, a default :code:`ParamAttr` would be added as scale. The
-            :attr:`param_attr` is initialized as 1 if it is added. Default: None. For more information, please refer to :ref:`api_paddle_ParamAttr` .
+            gain :math:`g` (scale). This setting only takes effect when `elementwise_affine` is ``True``.
+            - If set to ``False``, no gain parameter will be created.
+            - If set to ``None`` or ``True``, a default :code:`ParamAttr` will be used, and the
+              parameter will be initialized to 1.
+            - If set to a custom :code:`ParamAttr` object, it will be used to configure the parameter.
+            Default: None.
+            **Note: This argument must be passed as a keyword argument.**
         bias_attr(ParamAttr|bool|None, optional): The parameter attribute for the learnable
-            bias :math:`b`. If is False, bias is None. If is None, a default :code:`ParamAttr` would be added as bias. The
-            :attr:`bias_attr` is initialized as 0 if it is added. Default: None. For more information, please refer to :ref:`api_paddle_ParamAttr` .
+            bias :math:`b`. This setting only takes effect when both `elementwise_affine` and `bias` are ``True``.
+            - If set to ``False``, no bias parameter will be created.
+            - If set to ``None`` or ``True``, a default :code:`ParamAttr` will be used, and the
+              parameter will be initialized to 0.
+            - If set to a custom :code:`ParamAttr` object, it will be used to configure the parameter.
+            Default: None.
+            **Note: This argument must be passed as a keyword argument.**
         name(str|None, optional): Name for the LayerNorm, default is None. For more information, please refer to :ref:`api_guide_Name` .
+            **Note: This argument must be passed as a keyword argument.**
 
     Shape:
         - x: 2-D, 3-D, 4-D or 5-D tensor.
@@ -629,10 +652,16 @@ class LayerNorm(Layer):
     weight: Tensor | None
     bias: Tensor | None
 
+    @param_one_alias(["epsilon", "eps"])
     def __init__(
         self,
         normalized_shape: int | Sequence[int],
         epsilon: float = 1e-5,
+        *,
+        elementwise_affine: bool = True,
+        bias: bool = True,
+        device: PlaceLike | None = None,
+        dtype: DTypeLike | None = None,
         weight_attr: bool | ParamAttr | None = None,
         bias_attr: bool | ParamAttr | None = None,
         name: str | None = None,
@@ -643,6 +672,17 @@ def __init__(
 
         self._normalized_shape = list(normalized_shape)
         self._epsilon = epsilon
+        self._device = device
+        self._dtype = (
+            self._helper.get_default_dtype() if dtype is None else dtype
+        )
+
+        if not elementwise_affine:
+            weight_attr = False
+            bias_attr = False
+        elif not bias:
+            bias_attr = False
+
         self._weight_attr = weight_attr
         self._bias_attr = bias_attr
         param_shape = [np.prod(self._normalized_shape)]
@@ -652,15 +692,22 @@ def __init__(
         else:
             self.weight = self.create_parameter(
                 attr=self._weight_attr,
+                dtype=self._dtype,
                 shape=param_shape,
                 default_initializer=Constant(1.0),
+                device=self._device,
             )
 
         if bias_attr is False:
             self.bias = None
         else:
             self.bias = self.create_parameter(
-                attr=self._bias_attr, shape=param_shape, is_bias=True
+                attr=self._bias_attr,
+                dtype=self._dtype,
+                shape=param_shape,
+                default_initializer=Constant(0.0),
+                device=self._device,
+                is_bias=True,
             )
 
     def forward(self, input: Tensor) -> Tensor:
diff --git a/test/legacy_test/test_layer_norm_op_v2.py b/test/legacy_test/test_layer_norm_op_v2.py
index a8bfec46252114..6cffb67e0bfe16 100644
--- a/test/legacy_test/test_layer_norm_op_v2.py
+++ b/test/legacy_test/test_layer_norm_op_v2.py
@@ -16,6 +16,7 @@
 
 import numpy as np
 from op_test import get_places
+from utils import static_guard
 
 import paddle
 from paddle import base
@@ -159,6 +160,304 @@ def compute_v4(x):
         )
 
+
+class TestLayerNormParamDygraph(unittest.TestCase):
+    def setUp(self):
+        paddle.disable_static()
+        self.normalized_shape = [6]
+        self.x_shape = [2, 4, 4, 6]
+        self.places = get_places()
+
+    def _run_test_on_places(self, test_func):
+        """Helper to run the test function on all places."""
+        for p in self.places:
+            with base.dygraph.guard(p):
+                test_func(p)
+
+    def test_elementwise_affine_false(self):
+        """test that when elementwise_affine=False, weight and bias parameters are not created."""
+
+        def run_test(p):
+            layer = paddle.nn.LayerNorm(
+                normalized_shape=self.normalized_shape, elementwise_affine=False
+            )
+            assert layer.weight is None
+            assert layer.bias is None
+
+            x_tensor = paddle.randn(self.x_shape)
+            out = layer(x_tensor)
+            assert out.shape == self.x_shape
+
+        self._run_test_on_places(run_test)
+
+    def test_elementwise_affine_true(self):
+        """test that when elementwise_affine=True and attr=None, parameters are created with default initialization."""
+
+        def run_test(p):
+            layer = paddle.nn.LayerNorm(
+                normalized_shape=self.normalized_shape,
+                elementwise_affine=True,
+            )
+            assert layer.weight is not None
+            assert layer.bias is not None
+
+            expected_weight = paddle.ones(self.normalized_shape)
+            expected_bias = paddle.zeros(self.normalized_shape)
+
+            np.testing.assert_allclose(
+                layer.weight.numpy(), expected_weight.numpy()
+            )
+            np.testing.assert_allclose(
+                layer.bias.numpy(), expected_bias.numpy()
+            )
+
+        self._run_test_on_places(run_test)
+
+    def test_bias_false(self):
+        """test that when bias=False, the bias parameter is disabled even if elementwise_affine=True."""
+
+        def run_test(p):
+            layer = paddle.nn.LayerNorm(
+                normalized_shape=self.normalized_shape,
+                elementwise_affine=True,
+                bias=False,
+            )
+            assert layer.weight is not None
+            assert layer.bias is None
+
+        self._run_test_on_places(run_test)
+
+    def test_weight_and_bias_false(self):
+        """test that when weight_attr=False and bias_attr=False, both parameters are disabled."""
+
+        def run_test(p):
+            layer = paddle.nn.LayerNorm(
+                normalized_shape=self.normalized_shape,
+                elementwise_affine=True,
+                weight_attr=False,
+                bias_attr=False,
+            )
+            assert layer.weight is None
+            assert layer.bias is None
+
+        self._run_test_on_places(run_test)
+
+    def test_alias(self):
+        """test parameter alias epsilon/eps"""
+
+        def run_test(p):
+            layer_epsilon = paddle.nn.LayerNorm(
+                normalized_shape=self.normalized_shape,
+                elementwise_affine=True,
+                epsilon=1e-5,
+            )
+            layer_eps = paddle.nn.LayerNorm(
+                normalized_shape=self.normalized_shape,
+                elementwise_affine=True,
+                eps=1e-5,
+            )
+
+            x_tensor = paddle.randn(self.x_shape)
+            out_epsilon = layer_epsilon(x_tensor)
+            out_eps = layer_eps(x_tensor)
+
+            np.testing.assert_array_equal(out_epsilon.numpy(), out_eps.numpy())
+
+        self._run_test_on_places(run_test)
+
+    def test_errors(self):
+        """test for errors."""
+
+        def run_test(p):
+            with self.assertRaises(ValueError):
+                layer_norm = paddle.nn.LayerNorm(self.normalized_shape)
+                x1 = np.random.random([3, *self.normalized_shape]).astype(
+                    'float32'
+                )
+                layer_norm(x1)
+
+            with self.assertRaises(TypeError):
+                paddle.nn.LayerNorm(
+                    self.normalized_shape, 1e-5, None, None, "name"
+                )
+
+            with self.assertRaises(TypeError):
+                paddle.nn.LayerNorm(
+                    self.normalized_shape, 1e-5, False, "cpu", paddle.float32
+                )
+
+        self._run_test_on_places(run_test)
+
+
+class TestLayerNormParamStatic(unittest.TestCase):
+    def setUp(self):
+        paddle.enable_static()
+        self.normalized_shape = [6]
+        self.x_shape = [2, 4, 4, 6]
+        self.places = get_places()
+
+    def test_static_elementwise_affine_false(self):
+        """test elementwise_affine=False in static graph mode."""
+        for p in self.places:
+            with static_guard():
+                main = base.Program()
+                start = base.Program()
+                with (
+                    base.unique_name.guard(),
+                    base.program_guard(main, start),
+                ):
+                    layer = paddle.nn.LayerNorm(
+                        normalized_shape=self.normalized_shape,
+                        elementwise_affine=False,
+                    )
+                    x = paddle.static.data(
+                        name='x', shape=self.x_shape, dtype='float32'
+                    )
+                    out = layer(x)
+
+                exe = base.Executor(p)
+                exe.run(start)
+                input_np = np.random.randn(*self.x_shape).astype('float32')
+                result = exe.run(main, feed={'x': input_np}, fetch_list=[out])[
+                    0
+                ]
+
+                assert result.shape == tuple(self.x_shape)
+
+    def test_static_elementwise_affine_true(self):
+        """test elementwise_affine=True with default init in static graph mode."""
+        for p in self.places:
+            with static_guard():
+                main = base.Program()
+                start = base.Program()
+                with (
+                    base.unique_name.guard(),
+                    base.program_guard(main, start),
+                ):
+                    layer = paddle.nn.LayerNorm(
+                        normalized_shape=self.normalized_shape,
+                        elementwise_affine=True,
+                    )
+
+                exe = base.Executor(p)
+                exe.run(start)
+                weight_np, bias_np = exe.run(
+                    main, fetch_list=[layer.weight, layer.bias]
+                )
+
+                assert weight_np is not None
+                assert bias_np is not None
+
+                expected_weight = np.ones(self.normalized_shape)
+                expected_bias = np.zeros(self.normalized_shape)
+
+                np.testing.assert_allclose(weight_np, expected_weight)
+                np.testing.assert_allclose(bias_np, expected_bias)
+
+    def test_static_bias_false(self):
+        """test bias=False in static graph mode."""
+        for p in self.places:
+            with static_guard():
+                main = base.Program()
+                start = base.Program()
+                with (
+                    base.unique_name.guard(),
+                    base.program_guard(main, start),
+                ):
+                    layer = paddle.nn.LayerNorm(
+                        normalized_shape=self.normalized_shape,
+                        elementwise_affine=True,
+                        bias=False,
+                    )
+                    assert layer.bias is None
+
+                exe = base.Executor(p)
+                exe.run(start)
+                weight_np = exe.run(main, fetch_list=[layer.weight])[0]
+                assert weight_np is not None
+                assert weight_np.shape == tuple(self.normalized_shape)
+
+    def test_static_weight_and_bias_false(self):
+        """test weight_attr=False and bias_attr=False in static graph mode."""
+        for p in self.places:
+            with static_guard():
+                main = base.Program()
+                start = base.Program()
+                with (
+                    base.unique_name.guard(),
+                    base.program_guard(main, start),
+                ):
+                    layer = paddle.nn.LayerNorm(
+                        normalized_shape=self.normalized_shape,
+                        elementwise_affine=True,
+                        weight_attr=False,
+                        bias_attr=False,
+                    )
+                    assert layer.weight is None
+                    assert layer.bias is None
+
+    def test_static_alias(self):
+        """test parameter alias epsilon/eps in static graph mode."""
+        for p in self.places:
+            with static_guard():
+                main = base.Program()
+                start = base.Program()
+                with (
+                    base.unique_name.guard(),
+                    base.program_guard(main, start),
+                ):
+                    layer_epsilon = paddle.nn.LayerNorm(
+                        normalized_shape=self.normalized_shape,
+                        elementwise_affine=True,
+                        epsilon=1e-5,
+                    )
+                    layer_eps = paddle.nn.LayerNorm(
+                        normalized_shape=self.normalized_shape,
+                        elementwise_affine=True,
+                        eps=1e-5,
+                    )
+
+                    x = paddle.static.data(
+                        name='x', shape=self.x_shape, dtype='float32'
+                    )
+                    out_epsilon = layer_epsilon(x)
+                    out_eps = layer_eps(x)
+
+                exe = base.Executor(p)
+                exe.run(start)
+                input_np = np.random.randn(*self.x_shape).astype('float32')
+                out_eps_val, out_epsilon_val = exe.run(
+                    main,
+                    feed={'x': input_np},
+                    fetch_list=[out_eps, out_epsilon],
+                )
+
+                np.testing.assert_array_equal(out_epsilon_val, out_eps_val)
+
+    def test_static_errors(self):
+        """test errors in static graph mode."""
+        for p in self.places:
+            with static_guard():
+                main = base.Program()
+                start = base.Program()
+                with (
+                    base.unique_name.guard(),
+                    base.program_guard(main, start),
+                ):
+                    with self.assertRaises(TypeError):
+                        paddle.nn.LayerNorm(
+                            self.normalized_shape, 1e-5, None, None, "name"
+                        )
+
+                    with self.assertRaises(TypeError):
+                        paddle.nn.LayerNorm(
+                            self.normalized_shape,
+                            1e-5,
+                            False,
+                            "cpu",
+                            paddle.float32,
+                        )
+
+
 if __name__ == '__main__':
     paddle.enable_static()
     unittest.main()
diff --git a/test/legacy_test/test_layer_norm_op_v2_init.py b/test/legacy_test/test_layer_norm_op_v2_init.py
new file mode 100644
index 00000000000000..e0e9ee560523a2
--- /dev/null
+++ b/test/legacy_test/test_layer_norm_op_v2_init.py
@@ -0,0 +1,93 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+
+import numpy as np
+from utils import dygraph_guard, static_guard
+
+import paddle
+from paddle import base
+
+
+class TestLayerNormParamInit(unittest.TestCase):
+    def setUp(self):
+        self.normalized_shape = [6]
+        self.x_shape = [2, 4, 4, 6]
+
+    def test_dygraph(self):
+        """test custom initialization using weight_attr and bias_attr."""
+        paddle.disable_static()
+        with dygraph_guard():
+            weight_val = 2.5
+            bias_val = -1.0
+            weight_initializer = paddle.nn.initializer.Constant(
+                value=weight_val
+            )
+            bias_initializer = paddle.nn.initializer.Constant(value=bias_val)
+
+            layer = paddle.nn.LayerNorm(
+                normalized_shape=self.normalized_shape,
+                elementwise_affine=True,
+                weight_attr=weight_initializer,
+                bias_attr=bias_initializer,
+            )
+
+            expected_weight = np.full(self.normalized_shape, weight_val)
+            expected_bias = np.full(self.normalized_shape, bias_val)
+
+            np.testing.assert_allclose(layer.weight.numpy(), expected_weight)
+            np.testing.assert_allclose(layer.bias.numpy(), expected_bias)
+
+    def test_static(self):
+        """test custom initialization in static graph mode."""
+        paddle.enable_static()
+        with static_guard():
+            main = base.Program()
+            start = base.Program()
+            with (
+                base.unique_name.guard(),
+                base.program_guard(main, start),
+            ):
+                weight_val = 2.5
+                bias_val = -1.0
+                weight_initializer = paddle.nn.initializer.Constant(
+                    value=weight_val
+                )
+                bias_initializer = paddle.nn.initializer.Constant(
+                    value=bias_val
+                )
+
+                layer = paddle.nn.LayerNorm(
+                    normalized_shape=self.normalized_shape,
+                    elementwise_affine=True,
+                    weight_attr=weight_initializer,
+                    bias_attr=bias_initializer,
+                )
+
+            exe = base.Executor()
+            exe.run(start)
+            weight_np, bias_np = exe.run(
+                main, fetch_list=[layer.weight, layer.bias]
+            )
+
+            expected_weight = np.full(self.normalized_shape, weight_val)
+            expected_bias = np.full(self.normalized_shape, bias_val)
+
+            np.testing.assert_allclose(weight_np, expected_weight)
+            np.testing.assert_allclose(bias_np, expected_bias)
+
+
+if __name__ == '__main__':
+    unittest.main()
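
Illustrative usage (not part of the patch): a minimal dygraph sketch of the keyword-only arguments this diff adds to paddle.nn.LayerNorm, mirroring what the new tests assert. It assumes a Paddle build that includes this change; the shapes and epsilon value are arbitrary.

import paddle

# `eps` is accepted as an alias for `epsilon` via @param_one_alias;
# `bias=False` keeps the learnable scale but drops the learnable bias.
ln = paddle.nn.LayerNorm([6], eps=1e-6, elementwise_affine=True, bias=False)
assert ln.weight is not None and ln.bias is None

# elementwise_affine=False disables both parameters, regardless of
# weight_attr / bias_attr.
ln_plain = paddle.nn.LayerNorm([6], elementwise_affine=False)
assert ln_plain.weight is None and ln_plain.bias is None

x = paddle.randn([2, 4, 4, 6])
out = ln(x)  # normalizes over the trailing dimension; shape is preserved
assert out.shape == [2, 4, 4, 6]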