@@ -545,27 +545,49 @@ def test_random_dirichlet(parameter, size):
545545
546546
547547def test_random_choice ():
548- # Elements are picked at equal frequency
549- num_samples = 10000
548+ # `replace=True` and `p is None`
550549 rng = shared (np .random .RandomState (123 ))
551- g = pt .random .choice (np .arange (4 ), size = num_samples , rng = rng )
550+ g = pt .random .choice (np .arange (4 ), size = 10_000 , rng = rng )
551+ g_fn = compile_random_function ([], g , mode = jax_mode )
552+ samples = g_fn ()
553+ assert samples .shape == (10_000 ,)
554+ # Elements are picked at equal frequency
555+ np .testing .assert_allclose (np .mean (samples == 3 ), 0.25 , 2 )
556+
557+ # `replace=True` and `p is not None`
558+ rng = shared (np .random .default_rng (123 ))
559+ g = pt .random .choice (4 , p = np .array ([0.0 , 0.5 , 0.0 , 0.5 ]), size = (5 , 2 ), rng = rng )
552560 g_fn = compile_random_function ([], g , mode = jax_mode )
553561 samples = g_fn ()
554- np .testing .assert_allclose (np .sum (samples == 3 ) / num_samples , 0.25 , 2 )
562+ assert samples .shape == (5 , 2 )
563+ # Only odd numbers are picked
564+ assert np .all (samples % 2 == 1 )
555565
556- # `replace=False` produces unique results
566+ # `replace=False` and `p is None`
557567 rng = shared (np .random .RandomState (123 ))
558- g = pt .random .choice (np .arange (100 ), replace = False , size = 99 , rng = rng )
568+ g = pt .random .choice (np .arange (100 ), replace = False , size = ( 2 , 49 ) , rng = rng )
559569 g_fn = compile_random_function ([], g , mode = jax_mode )
560570 samples = g_fn ()
561- assert len (np .unique (samples )) == 99
571+ assert samples .shape == (2 , 49 )
572+ # Elements are unique
573+ assert len (np .unique (samples )) == 98
562574
563- # We can pass an array with probabilities
575+ # `replace=False` and `p is not None`
564576 rng = shared (np .random .RandomState (123 ))
565- g = pt .random .choice (np .arange (3 ), p = np .array ([1.0 , 0.0 , 0.0 ]), size = 10 , rng = rng )
577+ g = pt .random .choice (
578+ 8 ,
579+ p = np .array ([0.25 , 0 , 0.25 , 0 , 0.25 , 0 , 0.25 , 0 ]),
580+ size = 3 ,
581+ rng = rng ,
582+ replace = False ,
583+ )
566584 g_fn = compile_random_function ([], g , mode = jax_mode )
567585 samples = g_fn ()
568- np .testing .assert_allclose (samples , np .zeros (10 ))
586+ assert samples .shape == (3 ,)
587+ # Elements are unique
588+ assert len (np .unique (samples )) == 3
589+ # Only even numbers are picked
590+ assert np .all (samples % 2 == 0 )
569591
570592
571593def test_random_categorical ():
0 commit comments