diff --git a/examples/textbook/relax_gas_and_stars.py b/examples/textbook/relax_gas_and_stars.py index 083548d1a4..e73ec5b891 100644 --- a/examples/textbook/relax_gas_and_stars.py +++ b/examples/textbook/relax_gas_and_stars.py @@ -1,11 +1,20 @@ import numpy import pickle -from amuse.lab import * + +from amuse.lab import ( + units, nbody_system, + new_plummer_model, write_set_to_file, + new_salpeter_mass_distribution, +) from amuse.community.fastkick.interface import FastKick -from amuse.ext.relax_sph import relax +from amuse.community.fi.interface import Fi +from amuse.ext.relax_sph import relax, monitor_energy from amuse.ext.spherical_model import new_gas_plummer_distribution from amuse.community.fractalcluster.interface import new_fractal_cluster_model + +from prepare_figure import * + ###BOOKLISTSTART1### def check_energy_conservation(system, i_step, time, n_steps): unit = units.J @@ -163,7 +172,6 @@ def make_map(sph,N=100,L=1): def plot_hydro_and_stars(hydro, stars): x_label = "x [pc]" y_label = "y [pc]" - from prepare_figure import * fig = single_frame(x_label, y_label, logx=False, logy=False, xsize=12, ysize=12) diff --git a/src/amuse/datamodel/particle_attributes.py b/src/amuse/datamodel/particle_attributes.py index 9a169c9300..0d00943d74 100644 --- a/src/amuse/datamodel/particle_attributes.py +++ b/src/amuse/datamodel/particle_attributes.py @@ -923,7 +923,7 @@ def distances_squared(particles, other_particles): return (dxdydz**2).sum(-1) -def nearest_neighbour(particles, neighbours=None, max_array_length=10000000): +def nearest_neighbour(particles, neighbours=None, self_search=False, max_array_length=10000000): """ Returns the nearest neighbour of each particle in this set. If the 'neighbours' particle set is supplied, the search is performed on the neighbours set, for @@ -931,6 +931,7 @@ def nearest_neighbour(particles, neighbours=None, max_array_length=10000000): set is searched. :argument neighbours: the particle set in which to search for the nearest neighbour (optional) + :argument self_search: if True, the nearest neighbour can be the particle itself (default False) >>> from amuse.datamodel import Particles >>> particles = Particles(3) @@ -964,7 +965,7 @@ def nearest_neighbour(particles, neighbours=None, max_array_length=10000000): ) for indices in indices_in_each_batch: distances_squared = particles[indices].distances_squared(other_particles) - if neighbours is None: + if not self_search and neighbours is None: diagonal_indices = (numpy.arange(len(indices)), indices) distances_squared.number[ diagonal_indices @@ -973,11 +974,11 @@ def nearest_neighbour(particles, neighbours=None, max_array_length=10000000): return other_particles[numpy.concatenate(neighbour_indices)] distances_squared = particles.distances_squared(other_particles) - if neighbours is None: + if not self_search and neighbours is None: # can't be your own neighbour diagonal_indices = numpy.diag_indices(len(particles)) distances_squared.number[ diagonal_indices - ] = numpy.inf # can't be your own neighbour + ] = numpy.inf return other_particles[distances_squared.argmin(axis=1)] diff --git a/src/tests/core_tests/test_particle_attributes.py b/src/tests/core_tests/test_particle_attributes.py index a34f09a0f7..112c95e45f 100644 --- a/src/tests/core_tests/test_particle_attributes.py +++ b/src/tests/core_tests/test_particle_attributes.py @@ -173,6 +173,7 @@ def test11(self): particles.z = 0.0 | units.m self.assertEqual(particles.nearest_neighbour()[0], particles[1]) self.assertEqual(particles.nearest_neighbour()[1:].key, particles[:-1].key) + self.assertEqual(particles.nearest_neighbour(self_search=True).key, particles.key) neighbours = Particles(3) neighbours.x = [1.0, 10.0, 100.0] | units.m