diff --git a/csa/elementary.py b/csa/elementary.py index 4d848b8..a08c785 100644 --- a/csa/elementary.py +++ b/csa/elementary.py @@ -28,10 +28,26 @@ # Connection-Set constructor # -def cset (mask, *valueSets): +def cset(mask, *valueSets): + """ + Construct a ConnectionSet from a mask and optional value sets. + + Parameters + ---------- + mask : Mask + A mask defining the connectivity pattern. + *valueSets : ValueSet + Optional value sets associated with the connection set. + + Returns + ------- + ConnectionSet or Mask + If value sets are provided, returns a ConnectionSet object. + Otherwise, returns the mask unchanged. + """ if valueSets: - c = _cs.ExplicitCSet (mask, *valueSets) - return _cs.ConnectionSet (c) + c = _cs.ExplicitCSet(mask, *valueSets) + return _cs.ConnectionSet(c) else: return mask @@ -63,15 +79,45 @@ def vset (obj): # Intervals # -def ival (beg, end): - return _iset.IntervalSet ((beg, end)) +def ival(beg, end): + """ + Create an IntervalSet representing a closed interval. + + Parameters + ---------- + beg : int + Beginning of the interval. + end : int + End of the interval. + + Returns + ------- + IntervalSet + An interval set containing the specified range. + """ + return _iset.IntervalSet((beg, end)) N = _iset.N # Cartesian product # -def cross (set0, set1): - return _cs.intervalSetMask (set0, set1) +def cross(set0, set1): + """ + Compute the Cartesian product mask between two interval sets. + + Parameters + ---------- + set0 : IntervalSet + The first interval set. + set1 : IntervalSet + The second interval set. + + Returns + ------- + Mask + A mask representing the Cartesian product of the two sets. + """ + return _cs.intervalSetMask(set0, set1) # Elementary masks # diff --git a/csa/plot.py b/csa/plot.py index dd2fecc..1f55ae9 100644 --- a/csa/plot.py +++ b/csa/plot.py @@ -43,7 +43,7 @@ def show (cset, N0 = 30, N1 = None): a = _numpy.zeros ((N0, N1)) for (i, j) in elementary.cross (range (N0), range (N1)) * cset: a[i,j] += 1.0 - _plt.imshow (a, interpolation='nearest', vmin = 0.0, vmax = 1.0) + _plt.imshow(a, interpolation='nearest', vmin=0.0, vmax=1.0, cmap=_plt.cm.gray) _plt.show () def gplotsel2d (g, cset, source = elementary.N, target = elementary.N, N0 = 900, N1 = None, value = None, range=[], lines = True):