diff --git a/oshmem/mca/spml/ucx/spml_ucx.c b/oshmem/mca/spml/ucx/spml_ucx.c index 1b2fca7d6f1..8296d2fea1f 100644 --- a/oshmem/mca/spml/ucx/spml_ucx.c +++ b/oshmem/mca/spml/ucx/spml_ucx.c @@ -634,16 +634,16 @@ int mca_spml_ucx_clear_put_op_mask(mca_spml_ucx_ctx_t *ctx) int mca_spml_ucx_add_procs(oshmem_group_t* group, size_t nprocs) { int rc = OSHMEM_ERROR; - int my_rank = oshmem_my_proc_id(); size_t ucp_workers = mca_spml_ucx.ucp_workers; unsigned int *wk_roffs = NULL; unsigned int *wk_rsizes = NULL; char *wk_raddrs = NULL; - size_t i, w, n; + size_t i, j, w, n, temp; ucs_status_t err; ucp_address_t **wk_local_addr; unsigned int *wk_addr_len; ucp_ep_params_t ep_params; + int *indices; wk_local_addr = calloc(mca_spml_ucx.ucp_workers, sizeof(ucp_address_t *)); wk_addr_len = calloc(mca_spml_ucx.ucp_workers, sizeof(size_t)); @@ -691,15 +691,32 @@ int mca_spml_ucx_add_procs(oshmem_group_t* group, size_t nprocs) } } + indices = malloc(nprocs * sizeof(int)); + if (!indices) { + goto error; + } + + for (i = 0; i < nprocs; i++) { + indices[i] = i; + } + + srand((unsigned int)time(NULL)); + /* Get the EP connection requests for all the processes from modex */ - for (n = 0; n < nprocs; ++n) { - i = (my_rank + n) % nprocs; + for (i = nprocs - 1; i >= 0; --i) { + /* Fisher-Yates shuffle algorithm */ + if (i > 0) { + j = rand() % (i + 1); + temp = indices[i]; + indices[i] = indices[j]; + indices[j] = temp; + } ep_params.field_mask = UCP_EP_PARAM_FIELD_REMOTE_ADDRESS; - ep_params.address = (ucp_address_t *)mca_spml_ucx.remote_addrs_tbl[0][i]; + ep_params.address = (ucp_address_t *) mca_spml_ucx.remote_addrs_tbl[0][indices[i]]; err = ucp_ep_create(mca_spml_ucx_ctx_default.ucp_worker[0], &ep_params, - &mca_spml_ucx_ctx_default.ucp_peers[i].ucp_conn); + &mca_spml_ucx_ctx_default.ucp_peers[indices[i]].ucp_conn); if (UCS_OK != err) { SPML_UCX_ERROR("ucp_ep_create(proc=%zu/%zu) failed: %s", n, nprocs, ucs_status_string(err)); @@ -707,7 +724,7 @@ int mca_spml_ucx_add_procs(oshmem_group_t* group, size_t nprocs) } /* Initialize mkeys as NULL for all processes */ - mca_spml_ucx_peer_mkey_cache_init(&mca_spml_ucx_ctx_default, i); + mca_spml_ucx_peer_mkey_cache_init(&mca_spml_ucx_ctx_default, indices[i]); } for (i = 0; i < mca_spml_ucx.ucp_workers; i++) { @@ -719,6 +736,7 @@ int mca_spml_ucx_add_procs(oshmem_group_t* group, size_t nprocs) free(wk_roffs); free(wk_addr_len); free(wk_local_addr); + free(indices); SPML_UCX_VERBOSE(50, "*** ADDED PROCS ***"); @@ -753,6 +771,7 @@ int mca_spml_ucx_add_procs(oshmem_group_t* group, size_t nprocs) free(wk_raddrs); free(wk_rsizes); free(wk_roffs); + free(indices); error: free(wk_addr_len); free(wk_local_addr); @@ -1025,7 +1044,7 @@ static inline void _ctx_remove(mca_spml_ucx_ctx_array_t *array, mca_spml_ucx_ctx opal_atomic_wmb (); } - + static int mca_spml_ucx_ctx_create_common(long options, mca_spml_ucx_ctx_t **ucx_ctx_p) { ucp_worker_params_t params; @@ -1044,7 +1063,7 @@ static int mca_spml_ucx_ctx_create_common(long options, mca_spml_ucx_ctx_t **ucx ucx_ctx->ucp_worker = calloc(1, sizeof(ucp_worker_h)); ucx_ctx->ucp_workers = 1; ucx_ctx->synchronized_quiet = mca_spml_ucx_ctx_default.synchronized_quiet; - ucx_ctx->strong_sync = mca_spml_ucx_ctx_default.strong_sync; + ucx_ctx->strong_sync = mca_spml_ucx_ctx_default.strong_sync; params.field_mask = UCP_WORKER_PARAM_FIELD_THREAD_MODE; if (oshmem_mpi_thread_provided == SHMEM_THREAD_SINGLE ||