Skip to content

Commit

Permalink
Fix wsrelay connection leak
Browse files Browse the repository at this point in the history
- when re-establishing the connection to the db, close the old connection
- re-initialize WebSocketRelayManager when restarting asyncio.run
- log and ignore errors in cleanup_offline_host (this might come back to bite us)
- clean up the connection when WebSocketRelayManager crashes
  • Loading branch information
TheRealHaoLiu committed Apr 16, 2024
1 parent 199507c commit 9515ae5
Show file tree
Hide file tree
Showing 2 changed files with 72 additions and 49 deletions.
3 changes: 1 addition & 2 deletions awx/main/management/commands/run_wsrelay.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,11 +165,10 @@ def handle(self, *arg, **options):
return

WebsocketsMetricsServer().start()
websocket_relay_manager = WebSocketRelayManager()

while True:
try:
asyncio.run(websocket_relay_manager.run())
asyncio.run(WebSocketRelayManager().run())
except KeyboardInterrupt:
logger.info('Shutting down Websocket Relayer')
break
Expand Down
118 changes: 71 additions & 47 deletions awx/main/wsrelay.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,6 +285,8 @@ async def cleanup_offline_host(self, hostname):
except asyncio.CancelledError:
# Handle the case where the task was already cancelled by the time we got here.
pass
except Exception as e:
logger.warning(f"Failed to cancel relay connection for {hostname}: {e}")

del self.relay_connections[hostname]

Expand All @@ -295,6 +297,8 @@ async def cleanup_offline_host(self, hostname):
self.stats_mgr.delete_remote_host_stats(hostname)
except KeyError:
pass
except Exception as e:
logger.warning(f"Failed to delete stats for {hostname}: {e}")

async def run(self):
event_loop = asyncio.get_running_loop()
Expand All @@ -316,57 +320,77 @@ async def run(self):

task = None

# Establishes a websocket connection to /websocket/relay on all API servers
while True:
if not task or task.done():
try:
async_conn = await psycopg.AsyncConnection.connect(
dbname=database_conf['NAME'],
host=database_conf['HOST'],
user=database_conf['USER'],
port=database_conf['PORT'],
**database_conf.get("OPTIONS", {}),
)
await async_conn.set_autocommit(True)

task = event_loop.create_task(self.on_ws_heartbeat(async_conn), name="on_ws_heartbeat")
logger.info("Creating `on_ws_heartbeat` task in event loop.")
# Managing the async_conn here so that we can close it if we need to restart the connection
async_conn = None

except Exception as e:
logger.warning(f"Failed to connect to database for pg_notify: {e}")

future_remote_hosts = self.known_hosts.keys()
current_remote_hosts = self.relay_connections.keys()
deleted_remote_hosts = set(current_remote_hosts) - set(future_remote_hosts)
new_remote_hosts = set(future_remote_hosts) - set(current_remote_hosts)

# This loop handles if we get an advertisement from a host we already know about but
# the advertisement has a different IP than we are currently connected to.
for hostname, address in self.known_hosts.items():
if hostname not in self.relay_connections:
# We've picked up a new hostname that we don't know about yet.
continue
# Establishes a websocket connection to /websocket/relay on all API servers
try:
while True:
if not task or task.done():
try:
# Try to close the connection if it's open
if async_conn:
try:
await async_conn.close()
except Exception as e:
logger.warning(f"Failed to close connection to database for pg_notify: {e}")

# and re-establish the connection
async_conn = await psycopg.AsyncConnection.connect(
dbname=database_conf['NAME'],
host=database_conf['HOST'],
user=database_conf['USER'],
port=database_conf['PORT'],
**database_conf.get("OPTIONS", {}),
)
await async_conn.set_autocommit(True)

# before creating the task that uses the connection
task = event_loop.create_task(self.on_ws_heartbeat(async_conn), name="on_ws_heartbeat")
logger.info("Creating `on_ws_heartbeat` task in event loop.")

except Exception as e:
logger.warning(f"Failed to connect to database for pg_notify: {e}")

future_remote_hosts = self.known_hosts.keys()
current_remote_hosts = self.relay_connections.keys()
deleted_remote_hosts = set(current_remote_hosts) - set(future_remote_hosts)
new_remote_hosts = set(future_remote_hosts) - set(current_remote_hosts)

# This loop handles if we get an advertisement from a host we already know about but
# the advertisement has a different IP than we are currently connected to.
for hostname, address in self.known_hosts.items():
if hostname not in self.relay_connections:
# We've picked up a new hostname that we don't know about yet.
continue

if address != self.relay_connections[hostname].remote_host:
deleted_remote_hosts.add(hostname)
new_remote_hosts.add(hostname)
if address != self.relay_connections[hostname].remote_host:
deleted_remote_hosts.add(hostname)
new_remote_hosts.add(hostname)

# Delete any hosts with closed connections
for hostname, relay_conn in self.relay_connections.items():
if not relay_conn.connected:
deleted_remote_hosts.add(hostname)
# Delete any hosts with closed connections
for hostname, relay_conn in self.relay_connections.items():
if not relay_conn.connected:
deleted_remote_hosts.add(hostname)

if deleted_remote_hosts:
logger.info(f"Removing {deleted_remote_hosts} from websocket broadcast list")
await asyncio.gather(*[self.cleanup_offline_host(h) for h in deleted_remote_hosts])
if deleted_remote_hosts:
logger.info(f"Removing {deleted_remote_hosts} from websocket broadcast list")
await asyncio.gather(*[self.cleanup_offline_host(h) for h in deleted_remote_hosts]) # <- problem

if new_remote_hosts:
logger.info(f"Adding {new_remote_hosts} to websocket broadcast list")
if new_remote_hosts:
logger.info(f"Adding {new_remote_hosts} to websocket broadcast list")

for h in new_remote_hosts:
stats = self.stats_mgr.new_remote_host_stats(h)
relay_connection = WebsocketRelayConnection(name=self.local_hostname, stats=stats, remote_host=self.known_hosts[h])
relay_connection.start()
self.relay_connections[h] = relay_connection
for h in new_remote_hosts:
stats = self.stats_mgr.new_remote_host_stats(h)
relay_connection = WebsocketRelayConnection(name=self.local_hostname, stats=stats, remote_host=self.known_hosts[h])
relay_connection.start()
self.relay_connections[h] = relay_connection

await asyncio.sleep(settings.BROADCAST_WEBSOCKET_NEW_INSTANCE_POLL_RATE_SECONDS)
await asyncio.sleep(settings.BROADCAST_WEBSOCKET_NEW_INSTANCE_POLL_RATE_SECONDS)
finally:
if async_conn:
logger.info("Shutting down db connection for wsrelay.")
try:
await async_conn.close()
except Exception as e:
logger.info(f"Failed to close connection to database for pg_notify: {e}")

0 comments on commit 9515ae5

Please sign in to comment.