@@ -29,11 +29,21 @@ def get_sinusoidal_embeddings(
2929 """Returns the positional encoding (same as Tensor2Tensor).
3030
3131 Args:
32- timesteps: a 1-D Tensor of N indices, one per batch element.
33- These may be fractional.
34- embedding_dim: The number of output channels.
35- min_timescale: The smallest time unit (should probably be 0.0).
36- max_timescale: The largest time unit.
32+ timesteps (`jnp.ndarray` of shape `(N,)`):
33+ A 1-D array of N indices, one per batch element. These may be fractional.
34+ embedding_dim (`int`):
35+ The number of output channels.
36+ freq_shift (`float`, *optional*, defaults to `1`):
37+ Shift applied to the frequency scaling of the embeddings.
38+ min_timescale (`float`, *optional*, defaults to `1`):
39+ The smallest time unit used in the sinusoidal calculation (should probably be 0.0).
40+ max_timescale (`float`, *optional*, defaults to `1.0e4`):
41+ The largest time unit used in the sinusoidal calculation.
42+ flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
43+ Whether to flip the order of sinusoidal components to cosine first.
44+ scale (`float`, *optional*, defaults to `1.0`):
45+ A scaling factor applied to the positional embeddings.
46+
3747 Returns:
3848 a Tensor of timing signals [N, num_channels]
3949 """
@@ -61,9 +71,9 @@ class FlaxTimestepEmbedding(nn.Module):
6171
6272 Args:
6373 time_embed_dim (`int`, *optional*, defaults to `32`):
64- Time step embedding dimension
65- dtype (:obj: `jnp.dtype`, *optional*, defaults to jnp.float32):
66- Parameters `dtype`
74+ Time step embedding dimension.
75+ dtype (`jnp.dtype`, *optional*, defaults to ` jnp.float32` ):
76+ The data type for the embedding parameters.
6777 """
6878
6979 time_embed_dim : int = 32
@@ -83,7 +93,11 @@ class FlaxTimesteps(nn.Module):
8393
8494 Args:
8595 dim (`int`, *optional*, defaults to `32`):
86- Time step embedding dimension
96+ Time step embedding dimension.
97+ flip_sin_to_cos (`bool`, *optional*, defaults to `False`):
98+ Whether to flip the sinusoidal function from sine to cosine.
99+ freq_shift (`float`, *optional*, defaults to `1`):
100+ Frequency shift applied to the sinusoidal embeddings.
87101 """
88102
89103 dim : int = 32
0 commit comments