Indicate that lttng-ust 2.13+ is required
[deliverable/lttng-ust-mpi.git] / test-mpi.c
1 /*
2 * SPDX-License-Identifier: MIT
3 *
4 * SPDX-FileCopyrightText: 2023 Olivier Dion <odion@efficios.com>
5 */
6
#include <assert.h>
#include <errno.h>
#include <inttypes.h>
#include <limits.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

#include <mpi.h>
13
/*
 * Return the sum of the first VALUES_COUNT elements of VALUES.
 *
 * The accumulator is uint64_t (not size_t) so the result is computed with
 * the full width of the return type even where size_t is 32 bits; the sum
 * wraps modulo 2^64 like any uint64_t arithmetic.  VALUES may be NULL when
 * VALUES_COUNT is 0.
 */
static uint64_t sum_of(const uint64_t *values, size_t values_count)
{
	uint64_t acc = 0;

	for (size_t k = 0; k < values_count; ++k) {
		acc += values[k];
	}
	return acc;
}
22
/*
 * Print the command-line synopsis to stderr and terminate the process
 * with EXIT_FAILURE.  Never returns.
 *
 * Declared with (void) so it has a real prototype; an empty parameter
 * list is an old-style, prototype-less declaration before C23.
 */
static void usage(void)
{
	fprintf(stderr, "Usage: test-mpi N\n");
	exit(EXIT_FAILURE);
}
28
/*
 * Allocate an array of UPTO 64-bit values holding 1, 2, ..., UPTO.
 *
 * The caller owns the returned buffer and must free() it.  The process
 * exits with EXIT_FAILURE if the byte size would overflow size_t or if
 * malloc fails.  For UPTO == 0 the result may be NULL and must not be
 * dereferenced (free() on it is still fine).
 */
static uint64_t *allocate_values(size_t upto)
{
	uint64_t *values;

	/* Guard the multiplication below against size_t wrap-around. */
	if (upto > SIZE_MAX / sizeof(*values)) {
		fprintf(stderr, "allocate_values: %zu values is too many\n", upto);
		exit(EXIT_FAILURE);
	}

	values = malloc(sizeof(*values) * upto);
	/* malloc(0) may legitimately return NULL; only a non-empty request can fail. */
	if (values == NULL && upto > 0) {
		fprintf(stderr, "allocate_values: out of memory\n");
		exit(EXIT_FAILURE);
	}

	for (size_t k = 0; k < upto; ++k) {
		values[k] = k + 1;
	}
	return values;
}
37
38 static void send_values(int target, uint64_t *values,
39 size_t values_count,
40 MPI_Request *request)
41 {
42 MPI_Isend(values, values_count, MPI_UINT64_T,
43 target, 0, MPI_COMM_WORLD, request);
44 }
45
46 static void recv_answer(int target, uint64_t *value,
47 MPI_Request *request)
48 {
49 MPI_Irecv(value, 1, MPI_UINT64_T,
50 target, 0, MPI_COMM_WORLD, request);
51 }
52
53 static void send_answer(uint64_t value)
54 {
55 MPI_Send(&value, 1, MPI_UINT64_T,
56 0, 0, MPI_COMM_WORLD);
57 }
58
59 static uint64_t *recv_values(size_t chunk_size)
60 {
61 uint64_t *values = (uint64_t*)malloc(sizeof(uint64_t) * chunk_size);
62 MPI_Recv(values, chunk_size, MPI_UINT64_T,
63 0, 0, MPI_COMM_WORLD, MPI_STATUS_IGNORE);
64 return values;
65 }
66
67 int main(int argc, char *argv[])
68 {
69 int rank;
70 int size;
71 long long upto;
72 uint64_t *values;
73
74 if (argc < 2) {
75 usage();
76 }
77
78 upto = atoll(argv[1]);
79
80 if (upto <= 0) {
81 fprintf(stderr, "N must be greater than 0\n");
82 exit(EXIT_FAILURE);
83 }
84
85 MPI_Init(&argc, &argv);
86
87 MPI_Comm_set_errhandler(MPI_COMM_WORLD,
88 MPI_ERRORS_RETURN);
89
90 MPI_Comm_rank(MPI_COMM_WORLD, &rank);
91 MPI_Comm_size(MPI_COMM_WORLD, &size);
92
93 size_t chunk_size;
94 size_t rest;
95 uint64_t total;
96
97 if (size > 1) {
98 chunk_size = upto / (size - 1);
99 rest = upto % (size - 1);
100 } else {
101 chunk_size = 0;
102 rest = upto;
103 }
104
105 if (rank == 0) {
106 uint64_t sums[size];
107 MPI_Request requests[size - 1];
108
109 values = allocate_values(upto);
110
111 for (int k=1; k<size; ++k) {
112 send_values(k,
113 values + (chunk_size * (k - 1)),
114 chunk_size,
115 &requests[k-1]);
116 }
117
118 sums[0] = sum_of(values + chunk_size * (size - 1),
119 rest);
120
121 MPI_Waitall(size - 1, requests, MPI_STATUS_IGNORE);
122
123 for (int k=1; k<size; ++k) {
124 recv_answer(k, &sums[k], &requests[k-1]);
125 }
126
127 MPI_Waitall(size - 1, requests, MPI_STATUS_IGNORE);
128
129 total = sum_of(sums, size);
130 } else {
131 send_answer(sum_of(recv_values(chunk_size),
132 chunk_size));
133 }
134
135 MPI_Finalize();
136
137 if (rank == 0){
138 assert(total ==
139 (((uint64_t)upto * ((uint64_t)upto + 1U)) >> 1U));
140 }
141
142 return 0;
143 }
This page took 0.031909 seconds and 4 git commands to generate.