diff --git a/src/amuse/couple/bridge.py b/src/amuse/couple/bridge.py index cc56c458bb..06c87a98f7 100644 --- a/src/amuse/couple/bridge.py +++ b/src/amuse/couple/bridge.py @@ -295,7 +295,7 @@ def get_gravity_at_point(self,radius,x,y,z): class GravityCodeInField(object): - def __init__(self, code, field_codes, do_sync=True, verbose=False, radius_is_eps=False, h_smooth_is_eps=False, zero_smoothing=False): + def __init__(self, code, field_codes, do_sync=True, verbose=False, radius_is_eps=False, h_smooth_is_eps=False, zero_smoothing=False, use_velocity=False): """ verbose indicates whether to output some run info """ @@ -328,6 +328,8 @@ def __init__(self, code, field_codes, do_sync=True, verbose=False, radius_is_eps except CoreException: # hasattr will fail with an exception self.zero_smoothing = True + self.use_velocity=use_velocity + def evolve_model(self,tend,timestep=None): """ @@ -487,12 +489,23 @@ def get_potential_energy_in_field_code(self, particles, field_code): return (pot*particles.mass).sum() / 2 def kick_with_field_code(self, particles, field_code, dt): - ax,ay,az=field_code.get_gravity_at_point( - self._softening_lengths(particles), - particles.x, - particles.y, - particles.z - ) + if self.use_velocity: + ax,ay,az=field_code.get_gravity_at_point( + self._softening_lengths(particles), + particles.x, + particles.y, + particles.z, + particles.vx, + particles.vy, + particles.vz + ) + else: + ax,ay,az=field_code.get_gravity_at_point( + self._softening_lengths(particles), + particles.x, + particles.y, + particles.z + ) self.update_velocities(particles, dt, ax, ay, az) def update_velocities(self,particles, dt, ax, ay, az): @@ -711,3 +724,52 @@ def kick_codes(self,dt): de += x.kick(dt) self.kick_energy += de + + +class VelocityDependentBridge(Bridge): + def __init__(self, timestep=None, verbose=False, use_threading=True, method=None): + """ + verbose indicates whether to output some run info + """ + self.codes=[] + self.use_velocity=[] + self.time=quantities.zero + self.verbose=verbose + self.timestep=timestep + self.kick_energy = quantities.zero + self.use_threading = use_threading + self.time_offsets = dict() + self.method=method + self.channels = datamodel.Channels() + + def add_system(self, interface, partners=set(), do_sync=True, + radius_is_eps=False, h_smooth_is_eps=False, zero_smoothing=False,use_velocity=False): + """ + add a system to bridge integrator + """ + + if hasattr(interface, "particles"): + code = GravityCodeInField(interface, partners, do_sync, self.verbose, + radius_is_eps, h_smooth_is_eps, zero_smoothing, use_velocity) + self.add_code(code) + else: + if len(partners): + raise Exception("You added a code without particles, but with partners, this is not supported!") + self.add_code(interface) + + self.use_velocity.append(use_velocity) + + def get_gravity_at_point(self,radius,x,y,z,**kwargs): + ax=quantities.zero + ay=quantities.zero + az=quantities.zero + for i,code in enumerate(self.codes): + if self.use_velocity[i]: + vx,vy,vz=kwargs.get('vx'),kwargs.get('vy'),kwargs.get('vz') + _ax,_ay,_az=code.get_gravity_at_point(radius,x,y,z,vx=vx,vy=vy,vz=vz) + else: + _ax,_ay,_az=code.get_gravity_at_point(radius,x,y,z) + ax=ax+_ax + ay=ay+_ay + az=az+_az + return ax,ay,az