1- import re
2-
31import numpy as np
42import pytest
53import scipy .stats as stats
2220from pytensor .link .jax .dispatch .random import numpyro_available # noqa: E402
2321
2422
23+ def random_function (* args , ** kwargs ):
24+ with pytest .warns (
25+ UserWarning , match = r"The RandomType SharedVariables \[.+\] will not be used"
26+ ):
27+ return function (* args , ** kwargs )
28+
29+
2530def test_random_RandomStream ():
2631 """Two successive calls of a compiled graph using `RandomStream` should
2732 return different values.
@@ -30,11 +35,7 @@ def test_random_RandomStream():
3035 srng = RandomStream (seed = 123 )
3136 out = srng .normal () - srng .normal ()
3237
33- with pytest .warns (
34- UserWarning ,
35- match = r"The RandomType SharedVariables \[.+\] will not be used" ,
36- ):
37- fn = function ([], out , mode = jax_mode )
38+ fn = random_function ([], out , mode = jax_mode )
3839 jax_res_1 = fn ()
3940 jax_res_2 = fn ()
4041
@@ -47,13 +48,7 @@ def test_random_updates(rng_ctor):
4748 rng = shared (original_value , name = "original_rng" , borrow = False )
4849 next_rng , x = at .random .normal (name = "x" , rng = rng ).owner .outputs
4950
50- with pytest .warns (
51- UserWarning ,
52- match = re .escape (
53- "The RandomType SharedVariables [original_rng] will not be used"
54- ),
55- ):
56- f = pytensor .function ([], [x ], updates = {rng : next_rng }, mode = jax_mode )
51+ f = random_function ([], [x ], updates = {rng : next_rng }, mode = jax_mode )
5752 assert f () != f ()
5853
5954 # Check that original rng variable content was not overwritten when calling jax_typify
@@ -83,17 +78,14 @@ def test_random_updates_input_storage_order():
8378
8479 # This function replaces inp by input_shared in the update expression
8580 # This is what caused the RNG to appear later than inp_shared in the input_storage
86- with pytest .warns (
87- UserWarning ,
88- match = r"The RandomType SharedVariables \[.+\] will not be used" ,
89- ):
90- fn = pytensor .function (
91- inputs = [],
92- outputs = [],
93- updates = {inp_shared : inp_update },
94- givens = {inp : inp_shared },
95- mode = "JAX" ,
96- )
81+
82+ fn = random_function (
83+ inputs = [],
84+ outputs = [],
85+ updates = {inp_shared : inp_update },
86+ givens = {inp : inp_shared },
87+ mode = "JAX" ,
88+ )
9789 fn ()
9890 np .testing .assert_allclose (inp_shared .get_value (), 5 , rtol = 1e-3 )
9991 fn ()
@@ -457,7 +449,7 @@ def test_random_RandomVariable(rv_op, dist_params, base_size, cdf_name, params_c
457449 else :
458450 rng = shared (np .random .RandomState (29402 ))
459451 g = rv_op (* dist_params , size = (10_000 ,) + base_size , rng = rng )
460- g_fn = function (dist_params , g , mode = jax_mode )
452+ g_fn = random_function (dist_params , g , mode = jax_mode )
461453 samples = g_fn (
462454 * [
463455 i .tag .test_value
@@ -481,7 +473,7 @@ def test_random_RandomVariable(rv_op, dist_params, base_size, cdf_name, params_c
481473def test_random_bernoulli (size ):
482474 rng = shared (np .random .RandomState (123 ))
483475 g = at .random .bernoulli (0.5 , size = (1000 ,) + size , rng = rng )
484- g_fn = function ([], g , mode = jax_mode )
476+ g_fn = random_function ([], g , mode = jax_mode )
485477 samples = g_fn ()
486478 np .testing .assert_allclose (samples .mean (axis = 0 ), 0.5 , 1 )
487479
@@ -492,7 +484,7 @@ def test_random_mvnormal():
492484 mu = np .ones (4 )
493485 cov = np .eye (4 )
494486 g = at .random .multivariate_normal (mu , cov , size = (10000 ,), rng = rng )
495- g_fn = function ([], g , mode = jax_mode )
487+ g_fn = random_function ([], g , mode = jax_mode )
496488 samples = g_fn ()
497489 np .testing .assert_allclose (samples .mean (axis = 0 ), mu , atol = 0.1 )
498490
@@ -507,7 +499,7 @@ def test_random_mvnormal():
507499def test_random_dirichlet (parameter , size ):
508500 rng = shared (np .random .RandomState (123 ))
509501 g = at .random .dirichlet (parameter , size = (1000 ,) + size , rng = rng )
510- g_fn = function ([], g , mode = jax_mode )
502+ g_fn = random_function ([], g , mode = jax_mode )
511503 samples = g_fn ()
512504 np .testing .assert_allclose (samples .mean (axis = 0 ), 0.5 , 1 )
513505
@@ -517,29 +509,29 @@ def test_random_choice():
517509 num_samples = 10000
518510 rng = shared (np .random .RandomState (123 ))
519511 g = at .random .choice (np .arange (4 ), size = num_samples , rng = rng )
520- g_fn = function ([], g , mode = jax_mode )
512+ g_fn = random_function ([], g , mode = jax_mode )
521513 samples = g_fn ()
522514 np .testing .assert_allclose (np .sum (samples == 3 ) / num_samples , 0.25 , 2 )
523515
524516 # `replace=False` produces unique results
525517 rng = shared (np .random .RandomState (123 ))
526518 g = at .random .choice (np .arange (100 ), replace = False , size = 99 , rng = rng )
527- g_fn = function ([], g , mode = jax_mode )
519+ g_fn = random_function ([], g , mode = jax_mode )
528520 samples = g_fn ()
529521 assert len (np .unique (samples )) == 99
530522
531523 # We can pass an array with probabilities
532524 rng = shared (np .random .RandomState (123 ))
533525 g = at .random .choice (np .arange (3 ), p = np .array ([1.0 , 0.0 , 0.0 ]), size = 10 , rng = rng )
534- g_fn = function ([], g , mode = jax_mode )
526+ g_fn = random_function ([], g , mode = jax_mode )
535527 samples = g_fn ()
536528 np .testing .assert_allclose (samples , np .zeros (10 ))
537529
538530
539531def test_random_categorical ():
540532 rng = shared (np .random .RandomState (123 ))
541533 g = at .random .categorical (0.25 * np .ones (4 ), size = (10000 , 4 ), rng = rng )
542- g_fn = function ([], g , mode = jax_mode )
534+ g_fn = random_function ([], g , mode = jax_mode )
543535 samples = g_fn ()
544536 np .testing .assert_allclose (samples .mean (axis = 0 ), 6 / 4 , 1 )
545537
@@ -548,7 +540,7 @@ def test_random_permutation():
548540 array = np .arange (4 )
549541 rng = shared (np .random .RandomState (123 ))
550542 g = at .random .permutation (array , rng = rng )
551- g_fn = function ([], g , mode = jax_mode )
543+ g_fn = random_function ([], g , mode = jax_mode )
552544 permuted = g_fn ()
553545 with pytest .raises (AssertionError ):
554546 np .testing .assert_allclose (array , permuted )
@@ -558,7 +550,7 @@ def test_random_geometric():
558550 rng = shared (np .random .RandomState (123 ))
559551 p = np .array ([0.3 , 0.7 ])
560552 g = at .random .geometric (p , size = (10_000 , 2 ), rng = rng )
561- g_fn = function ([], g , mode = jax_mode )
553+ g_fn = random_function ([], g , mode = jax_mode )
562554 samples = g_fn ()
563555 np .testing .assert_allclose (samples .mean (axis = 0 ), 1 / p , rtol = 0.1 )
564556 np .testing .assert_allclose (samples .std (axis = 0 ), np .sqrt ((1 - p ) / p ** 2 ), rtol = 0.1 )
@@ -569,7 +561,7 @@ def test_negative_binomial():
569561 n = np .array ([10 , 40 ])
570562 p = np .array ([0.3 , 0.7 ])
571563 g = at .random .negative_binomial (n , p , size = (10_000 , 2 ), rng = rng )
572- g_fn = function ([], g , mode = jax_mode )
564+ g_fn = random_function ([], g , mode = jax_mode )
573565 samples = g_fn ()
574566 np .testing .assert_allclose (samples .mean (axis = 0 ), n * (1 - p ) / p , rtol = 0.1 )
575567 np .testing .assert_allclose (
@@ -583,7 +575,7 @@ def test_binomial():
583575 n = np .array ([10 , 40 ])
584576 p = np .array ([0.3 , 0.7 ])
585577 g = at .random .binomial (n , p , size = (10_000 , 2 ), rng = rng )
586- g_fn = function ([], g , mode = jax_mode )
578+ g_fn = random_function ([], g , mode = jax_mode )
587579 samples = g_fn ()
588580 np .testing .assert_allclose (samples .mean (axis = 0 ), n * p , rtol = 0.1 )
589581 np .testing .assert_allclose (samples .std (axis = 0 ), np .sqrt (n * p * (1 - p )), rtol = 0.1 )
@@ -598,7 +590,7 @@ def test_beta_binomial():
598590 a = np .array ([1.5 , 13 ])
599591 b = np .array ([0.5 , 9 ])
600592 g = at .random .betabinom (n , a , b , size = (10_000 , 2 ), rng = rng )
601- g_fn = function ([], g , mode = jax_mode )
593+ g_fn = random_function ([], g , mode = jax_mode )
602594 samples = g_fn ()
603595 np .testing .assert_allclose (samples .mean (axis = 0 ), n * a / (a + b ), rtol = 0.1 )
604596 np .testing .assert_allclose (
@@ -616,7 +608,7 @@ def test_multinomial():
616608 n = np .array ([10 , 40 ])
617609 p = np .array ([[0.3 , 0.7 , 0.0 ], [0.1 , 0.4 , 0.5 ]])
618610 g = at .random .multinomial (n , p , size = (10_000 , 2 ), rng = rng )
619- g_fn = function ([], g , mode = jax_mode )
611+ g_fn = random_function ([], g , mode = jax_mode )
620612 samples = g_fn ()
621613 np .testing .assert_allclose (samples .mean (axis = 0 ), n [..., None ] * p , rtol = 0.1 )
622614 np .testing .assert_allclose (
@@ -632,7 +624,7 @@ def test_vonmises_mu_outside_circle():
632624 mu = np .array ([- 30 , 40 ])
633625 kappa = np .array ([100 , 10 ])
634626 g = at .random .vonmises (mu , kappa , size = (10_000 , 2 ), rng = rng )
635- g_fn = function ([], g , mode = jax_mode )
627+ g_fn = random_function ([], g , mode = jax_mode )
636628 samples = g_fn ()
637629 np .testing .assert_allclose (
638630 samples .mean (axis = 0 ), (mu + np .pi ) % (2.0 * np .pi ) - np .pi , rtol = 0.1
@@ -678,7 +670,10 @@ def rng_fn(cls, rng, size):
678670 fgraph = FunctionGraph ([out .owner .inputs [0 ]], [out ], clone = False )
679671
680672 with pytest .raises (NotImplementedError ):
681- compare_jax_and_py (fgraph , [])
673+ with pytest .warns (
674+ UserWarning , match = r"The RandomType SharedVariables \[.+\] will not be used"
675+ ):
676+ compare_jax_and_py (fgraph , [])
682677
683678
684679def test_random_custom_implementation ():
@@ -709,7 +704,10 @@ def sample_fn(rng, size, dtype, *parameters):
709704 rng = shared (np .random .RandomState (123 ))
710705 out = nonexistentrv (rng = rng )
711706 fgraph = FunctionGraph ([out .owner .inputs [0 ]], [out ], clone = False )
712- compare_jax_and_py (fgraph , [])
707+ with pytest .warns (
708+ UserWarning , match = r"The RandomType SharedVariables \[.+\] will not be used"
709+ ):
710+ compare_jax_and_py (fgraph , [])
713711
714712
715713def test_random_concrete_shape ():
@@ -726,19 +724,15 @@ def test_random_concrete_shape():
726724 rng = shared (np .random .RandomState (123 ))
727725 x_at = at .dmatrix ()
728726 out = at .random .normal (0 , 1 , size = x_at .shape , rng = rng )
729- jax_fn = function ([x_at ], out , mode = jax_mode )
727+ jax_fn = random_function ([x_at ], out , mode = jax_mode )
730728 assert jax_fn (np .ones ((2 , 3 ))).shape == (2 , 3 )
731729
732730
733731def test_random_concrete_shape_from_param ():
734732 rng = shared (np .random .RandomState (123 ))
735733 x_at = at .dmatrix ()
736734 out = at .random .normal (x_at , 1 , rng = rng )
737- with pytest .warns (
738- UserWarning ,
739- match = "The RandomType SharedVariables \[.+\] will not be used"
740- ):
741- jax_fn = function ([x_at ], out , mode = jax_mode )
735+ jax_fn = random_function ([x_at ], out , mode = jax_mode )
742736 assert jax_fn (np .ones ((2 , 3 ))).shape == (2 , 3 )
743737
744738
@@ -757,7 +751,7 @@ def test_random_concrete_shape_subtensor():
757751 rng = shared (np .random .RandomState (123 ))
758752 x_at = at .dmatrix ()
759753 out = at .random .normal (0 , 1 , size = x_at .shape [1 ], rng = rng )
760- jax_fn = function ([x_at ], out , mode = jax_mode )
754+ jax_fn = random_function ([x_at ], out , mode = jax_mode )
761755 assert jax_fn (np .ones ((2 , 3 ))).shape == (3 ,)
762756
763757
@@ -773,7 +767,7 @@ def test_random_concrete_shape_subtensor_tuple():
773767 rng = shared (np .random .RandomState (123 ))
774768 x_at = at .dmatrix ()
775769 out = at .random .normal (0 , 1 , size = (x_at .shape [0 ],), rng = rng )
776- jax_fn = function ([x_at ], out , mode = jax_mode )
770+ jax_fn = random_function ([x_at ], out , mode = jax_mode )
777771 assert jax_fn (np .ones ((2 , 3 ))).shape == (2 ,)
778772
779773
@@ -784,5 +778,5 @@ def test_random_concrete_shape_graph_input():
784778 rng = shared (np .random .RandomState (123 ))
785779 size_at = at .scalar ()
786780 out = at .random .normal (0 , 1 , size = size_at , rng = rng )
787- jax_fn = function ([size_at ], out , mode = jax_mode )
781+ jax_fn = random_function ([size_at ], out , mode = jax_mode )
788782 assert jax_fn (10 ).shape == (10 ,)
0 commit comments