Avoid resource leaks when a dblink connection fails.
authorTom Lane <tgl@sss.pgh.pa.us>
Thu, 29 May 2025 14:39:55 +0000 (10:39 -0400)
committerTom Lane <tgl@sss.pgh.pa.us>
Thu, 29 May 2025 14:39:55 +0000 (10:39 -0400)
If we hit out-of-memory between creating the PGconn and inserting
it into dblink's hashtable, we'd lose track of the PGconn, which
is quite bad since it represents a live connection to a remote DB.
Fix by rearranging things so that we create the hashtable entry
first.

Also reduce the number of states we have to deal with by getting rid
of the separately-allocated remoteConn object, instead allocating it
in-line in the hashtable entries.  (That incidentally removes a
session-lifespan memory leak observed in the regression tests.)

There is an apparently-irreducible remaining OOM hazard, which
is that if the connection fails at the libpq level (ie it's
CONNECTION_BAD) then we have to pstrdup the PGconn's error message
before we can release it, and theoretically that could fail.  However,
in such cases we're only leaking memory not a live remote connection,
so I'm not convinced that it's worth sweating over.

This is a pretty low-probability failure mode of course, but losing
a live connection seems bad enough to justify back-patching.

Author: Tom Lane <tgl@sss.pgh.pa.us>
Reviewed-by: Matheus Alcantara <matheusssilv97@gmail.com>
Discussion: https://postgr.es/m/1346940.1748381911@sss.pgh.pa.us
Backpatch-through: 13

contrib/dblink/dblink.c

index 98d4e3d7dac4cc969acd0b365163a3960fc69928..8a0b112a7ff294db0bfc9e1c0d3d056e381b1f11 100644 (file)
@@ -105,7 +105,7 @@ static PGresult *storeQueryResult(volatile storeInfo *sinfo, PGconn *conn, const
 static void storeRow(volatile storeInfo *sinfo, PGresult *res, bool first);
 static remoteConn *getConnectionByName(const char *name);
 static HTAB *createConnHash(void);
-static void createNewConnection(const char *name, remoteConn *rconn);
+static remoteConn *createNewConnection(const char *name);
 static void deleteConnection(const char *name);
 static char **get_pkey_attnames(Relation rel, int16 *indnkeyatts);
 static char **get_text_array_contents(ArrayType *array, int *numitems);
@@ -119,7 +119,8 @@ static Relation get_rel_from_relname(text *relname_text, LOCKMODE lockmode, AclM
 static char *generate_relation_name(Relation rel);
 static void dblink_connstr_check(const char *connstr);
 static bool dblink_connstr_has_pw(const char *connstr);
-static void dblink_security_check(PGconn *conn, remoteConn *rconn, const char *connstr);
+static void dblink_security_check(PGconn *conn, const char *connname,
+                                 const char *connstr);
 static void dblink_res_error(PGconn *conn, const char *conname, PGresult *res,
                             bool fail, const char *fmt,...) pg_attribute_printf(5, 6);
 static char *get_connect_string(const char *servername);
@@ -147,16 +148,22 @@ static uint32 dblink_we_get_conn = 0;
 static uint32 dblink_we_get_result = 0;
 
 /*
- * Following is list that holds multiple remote connections.
+ * Following is hash that holds multiple remote connections.
  * Calling convention of each dblink function changes to accept
- * connection name as the first parameter. The connection list is
+ * connection name as the first parameter. The connection hash is
  * much like ecpg e.g. a mapping between a name and a PGconn object.
+ *
+ * To avoid potentially leaking a PGconn object in case of out-of-memory
+ * errors, we first create the hash entry, then open the PGconn.
+ * Hence, a hash entry whose rconn.conn pointer is NULL must be
+ * understood as a leftover from a failed create; it should be ignored
+ * by lookup operations, and silently replaced by create operations.
  */
 
 typedef struct remoteConnHashEnt
 {
    char        name[NAMEDATALEN];
-   remoteConn *rconn;
+   remoteConn  rconn;
 } remoteConnHashEnt;
 
 /* initial number of connection hashes */
@@ -233,7 +240,7 @@ dblink_get_conn(char *conname_or_str,
                     errmsg("could not establish connection"),
                     errdetail_internal("%s", msg)));
        }
-       dblink_security_check(conn, rconn, connstr);
+       dblink_security_check(conn, NULL, connstr);
        if (PQclientEncoding(conn) != GetDatabaseEncoding())
            PQsetClientEncoding(conn, GetDatabaseEncodingName());
        freeconn = true;
@@ -296,15 +303,6 @@ dblink_connect(PG_FUNCTION_ARGS)
    else if (PG_NARGS() == 1)
        conname_or_str = text_to_cstring(PG_GETARG_TEXT_PP(0));
 
-   if (connname)
-   {
-       rconn = (remoteConn *) MemoryContextAlloc(TopMemoryContext,
-                                                 sizeof(remoteConn));
-       rconn->conn = NULL;
-       rconn->openCursorCount = 0;
-       rconn->newXactForCursor = false;
-   }
-
    /* first check for valid foreign data server */
    connstr = get_connect_string(conname_or_str);
    if (connstr == NULL)
@@ -317,6 +315,13 @@ dblink_connect(PG_FUNCTION_ARGS)
    if (dblink_we_connect == 0)
        dblink_we_connect = WaitEventExtensionNew("DblinkConnect");
 
+   /* if we need a hashtable entry, make that first, since it might fail */
+   if (connname)
+   {
+       rconn = createNewConnection(connname);
+       Assert(rconn->conn == NULL);
+   }
+
    /* OK to make connection */
    conn = libpqsrv_connect(connstr, dblink_we_connect);
 
@@ -324,8 +329,8 @@ dblink_connect(PG_FUNCTION_ARGS)
    {
        msg = pchomp(PQerrorMessage(conn));
        libpqsrv_disconnect(conn);
-       if (rconn)
-           pfree(rconn);
+       if (connname)
+           deleteConnection(connname);
 
        ereport(ERROR,
                (errcode(ERRCODE_SQLCLIENT_UNABLE_TO_ESTABLISH_SQLCONNECTION),
@@ -334,16 +339,16 @@ dblink_connect(PG_FUNCTION_ARGS)
    }
 
    /* check password actually used if not superuser */
-   dblink_security_check(conn, rconn, connstr);
+   dblink_security_check(conn, connname, connstr);
 
    /* attempt to set client encoding to match server encoding, if needed */
    if (PQclientEncoding(conn) != GetDatabaseEncoding())
        PQsetClientEncoding(conn, GetDatabaseEncodingName());
 
+   /* all OK, save away the conn */
    if (connname)
    {
        rconn->conn = conn;
-       createNewConnection(connname, rconn);
    }
    else
    {
@@ -383,10 +388,7 @@ dblink_disconnect(PG_FUNCTION_ARGS)
 
    libpqsrv_disconnect(conn);
    if (rconn)
-   {
        deleteConnection(conname);
-       pfree(rconn);
-   }
    else
        pconn->conn = NULL;
 
@@ -1304,6 +1306,9 @@ dblink_get_connections(PG_FUNCTION_ARGS)
        hash_seq_init(&status, remoteConnHash);
        while ((hentry = (remoteConnHashEnt *) hash_seq_search(&status)) != NULL)
        {
+           /* ignore it if it's not an open connection */
+           if (hentry->rconn.conn == NULL)
+               continue;
            /* stash away current value */
            astate = accumArrayResult(astate,
                                      CStringGetTextDatum(hentry->name),
@@ -2539,8 +2544,8 @@ getConnectionByName(const char *name)
    hentry = (remoteConnHashEnt *) hash_search(remoteConnHash,
                                               key, HASH_FIND, NULL);
 
-   if (hentry)
-       return hentry->rconn;
+   if (hentry && hentry->rconn.conn != NULL)
+       return &hentry->rconn;
 
    return NULL;
 }
@@ -2557,8 +2562,8 @@ createConnHash(void)
                       HASH_ELEM | HASH_STRINGS);
 }
 
-static void
-createNewConnection(const char *name, remoteConn *rconn)
+static remoteConn *
+createNewConnection(const char *name)
 {
    remoteConnHashEnt *hentry;
    bool        found;
@@ -2572,17 +2577,15 @@ createNewConnection(const char *name, remoteConn *rconn)
    hentry = (remoteConnHashEnt *) hash_search(remoteConnHash, key,
                                               HASH_ENTER, &found);
 
-   if (found)
-   {
-       libpqsrv_disconnect(rconn->conn);
-       pfree(rconn);
-
+   if (found && hentry->rconn.conn != NULL)
        ereport(ERROR,
                (errcode(ERRCODE_DUPLICATE_OBJECT),
                 errmsg("duplicate connection name")));
-   }
 
-   hentry->rconn = rconn;
+   /* New, or reusable, so initialize the rconn struct to zeroes */
+   memset(&hentry->rconn, 0, sizeof(remoteConn));
+
+   return &hentry->rconn;
 }
 
 static void
@@ -2671,9 +2674,12 @@ dblink_connstr_has_required_scram_options(const char *connstr)
  * We need to make sure that the connection made used credentials
  * which were provided by the user, so check what credentials were
  * used to connect and then make sure that they came from the user.
+ *
+ * On failure, we close "conn" and also delete the hashtable entry
+ * identified by "connname" (if that's not NULL).
  */
 static void
-dblink_security_check(PGconn *conn, remoteConn *rconn, const char *connstr)
+dblink_security_check(PGconn *conn, const char *connname, const char *connstr)
 {
    /* Superuser bypasses security check */
    if (superuser())
@@ -2703,8 +2709,8 @@ dblink_security_check(PGconn *conn, remoteConn *rconn, const char *connstr)
 
    /* Otherwise, fail out */
    libpqsrv_disconnect(conn);
-   if (rconn)
-       pfree(rconn);
+   if (connname)
+       deleteConnection(connname);
 
    ereport(ERROR,
            (errcode(ERRCODE_S_R_E_PROHIBITED_SQL_STATEMENT_ATTEMPTED),