Correctly handle MPI_STATUS{ES}_IGNORE as valid values for
MPI_Status* arguments (as opposed to segfaulting :-)



git-svn-id: svn://svn.valgrind.org/valgrind/trunk@10926 a5019735-40e9-0310-863c-91ae7b9d1cf9
diff --git a/mpi/libmpiwrap.c b/mpi/libmpiwrap.c
index c0f41b8..1519fd7 100644
--- a/mpi/libmpiwrap.c
+++ b/mpi/libmpiwrap.c
@@ -57,6 +57,32 @@
    without prior written permission.
 */
 
+/* Handling of MPI_STATUS{ES}_IGNORE for MPI_Status* arguments.
+
+   The MPI-2 spec allows many functions that take an MPI_Status*
+   purely as an out parameter to accept the constants MPI_STATUS_IGNORE
+   or MPI_STATUSES_IGNORE there instead, if the caller does not care
+   about the status.  See the MPI-2 spec sec 4.5.1 ("Passing
+   MPI_STATUS_IGNORE for Status").  (mpi2-report.pdf, 1615898 bytes,
+   md5=694a5efe2fd291eecf7e8c9875b5f43f).
+
+   This library handles such cases by allocating a fake MPI_Status
+   object (on the stack) or an array thereof (on the heap), and
+   passing that onwards instead.  From the outside the caller sees no
+   difference.  Unfortunately the simpler approach of merely detecting
+   and handling these special cases at a lower level does not work,
+   because we need to use information returned in MPI_Status*
+   arguments to paint result buffers, even if the caller doesn't
+   supply a real MPI_Status object.
+
+   Eg, MPI_Recv.  We can't paint the result buffer without knowing how
+   many items arrived; but we can't find that out without passing a
+   real MPI_Status object to the (real) MPI_Recv call.  Hence, if the
+   caller did not supply one, we have no option but to use a temporary
+   stack-allocated one for the inner call.  Ditto, more indirectly
+   (via maybe_complete) for nonblocking receives and the various
+   associated wait/test calls. */
+
 
 /*------------------------------------------------------------*/
 /*--- includes                                             ---*/
@@ -103,6 +129,17 @@
 #endif
 
 
+/* Define HAVE_MPI_STATUS_IGNORE iff MPI_STATUS{ES}_IGNORE can occur:
+   MPI_VERSION >= 2, or both symbols are defined as macros anyway. */
+#if MPI_VERSION >= 2 \
+    || (defined(MPI_STATUS_IGNORE) && defined(MPI_STATUSES_IGNORE))
+#  undef HAVE_MPI_STATUS_IGNORE
+#  define HAVE_MPI_STATUS_IGNORE 1
+#else
+#  undef HAVE_MPI_STATUS_IGNORE
+#endif
+
+
 /*------------------------------------------------------------*/
 /*--- Decls                                                ---*/
 /*------------------------------------------------------------*/
@@ -401,6 +438,20 @@
    return r1 == r2;
 }
 
+/* Return True iff 'status' is one of the sentinels MPI_STATUS_IGNORE
+   or MPI_STATUSES_IGNORE, i.e. the caller supplied no real MPI_Status
+   storage.  When HAVE_MPI_STATUS_IGNORE is not defined (MPI-1.x
+   without backported 2.x symbols) this always returns False. */
+static __inline__
+Bool isMSI ( MPI_Status* status )
+{
+#  if defined(HAVE_MPI_STATUS_IGNORE)
+   return status == MPI_STATUSES_IGNORE || status == MPI_STATUS_IGNORE;
+#  else
+   return False;
+#  endif
+}
+
 /* Get the 'extent' of a type.  Note, as per the MPI spec this
    includes whatever padding would be required when using 'ty' in an
    array. */
@@ -1045,10 +1096,13 @@
                            int source, int tag, 
                            MPI_Comm comm, MPI_Status *status)
 {
-   OrigFn fn;
-   int    err, recv_count = 0;
+   OrigFn     fn;
+   int        err, recv_count = 0;
+   MPI_Status fake_status;
    VALGRIND_GET_ORIG_FN(fn);
    before("Recv");
+   if (isMSI(status))
+      status = &fake_status;
    check_mem_is_addressable(buf, count, datatype);
    check_mem_is_addressable_untyped(status, sizeof(*status));
    CALL_FN_W_7W(err, fn, buf,count,datatype,source,tag,comm,status);
@@ -1386,10 +1440,13 @@
                             MPI_Status* status )
 {
    MPI_Request  request_before;
+   MPI_Status   fake_status;
    OrigFn       fn;
    int          err;
    VALGRIND_GET_ORIG_FN(fn);
    before("Wait");
+   if (isMSI(status))
+      status = &fake_status;
    check_mem_is_addressable_untyped(status, sizeof(MPI_Status));
    check_mem_is_defined_untyped(request, sizeof(MPI_Request));
    request_before = *request;
@@ -1410,10 +1467,13 @@
                                MPI_Status* status )
 {
    MPI_Request* requests_before = NULL;
+   MPI_Status   fake_status;
    OrigFn       fn;
    int          err, i;
    VALGRIND_GET_ORIG_FN(fn);
    before("Waitany");
+   if (isMSI(status))
+      status = &fake_status;
    if (0) fprintf(stderr, "Waitany: %d\n", count);
    check_mem_is_addressable_untyped(index, sizeof(int));
    check_mem_is_addressable_untyped(status, sizeof(MPI_Status));
@@ -1441,9 +1501,14 @@
    MPI_Request* requests_before = NULL;
    OrigFn       fn;
    int          err, i;
+   Bool         free_sta = False;
    VALGRIND_GET_ORIG_FN(fn);
    before("Waitall");
    if (0) fprintf(stderr, "Waitall: %d\n", count);
+   if (isMSI(statuses)) {
+      free_sta = True;
+      statuses = malloc( (count < 0 ? 0 : count) * sizeof(MPI_Status) );
+   }
    for (i = 0; i < count; i++) {
       check_mem_is_addressable_untyped(&statuses[i], sizeof(MPI_Status));
       check_mem_is_defined_untyped(&requests[i], sizeof(MPI_Request));
@@ -1462,6 +1527,8 @@
    }
    if (requests_before)
       free(requests_before);
+   if (free_sta)
+      free(statuses);
    after("Waitall", err);
    return err;
 }
@@ -1472,10 +1539,13 @@
                             MPI_Status* status )
 {
    MPI_Request  request_before;
+   MPI_Status   fake_status;
    OrigFn       fn;
    int          err;
    VALGRIND_GET_ORIG_FN(fn);
    before("Test");
+   if (isMSI(status))
+      status = &fake_status;
    check_mem_is_addressable_untyped(status, sizeof(MPI_Status));
    check_mem_is_addressable_untyped(flag, sizeof(int));
    check_mem_is_defined_untyped(request, sizeof(MPI_Request));
@@ -1498,9 +1568,14 @@
    MPI_Request* requests_before = NULL;
    OrigFn       fn;
    int          err, i;
+   Bool         free_sta = False;
    VALGRIND_GET_ORIG_FN(fn);
    before("Testall");
    if (0) fprintf(stderr, "Testall: %d\n", count);
+   if (isMSI(statuses)) {
+      free_sta = True;
+      statuses = malloc( (count < 0 ? 0 : count) * sizeof(MPI_Status) );
+   }
    check_mem_is_addressable_untyped(flag, sizeof(int));
    for (i = 0; i < count; i++) {
       check_mem_is_addressable_untyped(&statuses[i], sizeof(MPI_Status));
@@ -1516,11 +1591,14 @@
       for (i = 0; i < count; i++) {
          maybe_complete(e_i_s, requests_before[i], requests[i], 
                                &statuses[i]);
-         make_mem_defined_if_addressable_untyped(&statuses[i], sizeof(MPI_Status));
+         make_mem_defined_if_addressable_untyped(&statuses[i],
+                                                 sizeof(MPI_Status));
       }
    }
    if (requests_before)
       free(requests_before);
+   if (free_sta)
+      free(statuses);
    after("Testall", err);
    return err;
 }
@@ -1533,10 +1611,13 @@
                              MPI_Comm comm, 
                              int* flag, MPI_Status* status)
 {
-   OrigFn fn;
-   int    err;
+   MPI_Status fake_status;
+   OrigFn     fn;
+   int        err;
    VALGRIND_GET_ORIG_FN(fn);
    before("Iprobe");
+   if (isMSI(status))
+      status = &fake_status;
    check_mem_is_addressable_untyped(flag, sizeof(*flag));
    check_mem_is_addressable_untyped(status, sizeof(*status));
    CALL_FN_W_5W(err, fn, source,tag,comm,flag,status);
@@ -1555,10 +1636,13 @@
 int WRAPPER_FOR(PMPI_Probe)(int source, int tag,
                             MPI_Comm comm, MPI_Status* status)
 {
-   OrigFn fn;
-   int    err;
+   MPI_Status fake_status;
+   OrigFn     fn;
+   int        err;
    VALGRIND_GET_ORIG_FN(fn);
    before("Probe");
+   if (isMSI(status))
+      status = &fake_status;
    check_mem_is_addressable_untyped(status, sizeof(*status));
    CALL_FN_W_WWWW(err, fn, source,tag,comm,status);
    make_mem_defined_if_addressable_if_success_untyped(err, status, sizeof(*status));
@@ -1606,12 +1690,16 @@
        int source, int recvtag,
        MPI_Comm comm,  MPI_Status *status)
 {
-   OrigFn fn;
-   int    err, recvcount_actual = 0;
+   MPI_Status fake_status;
+   OrigFn     fn;
+   int        err, recvcount_actual = 0;
    VALGRIND_GET_ORIG_FN(fn);
    before("Sendrecv");
+   if (isMSI(status))
+      status = &fake_status;
    check_mem_is_defined(sendbuf, sendcount, sendtype);
    check_mem_is_addressable(recvbuf, recvcount, recvtype);
+   check_mem_is_addressable_untyped(status, sizeof(*status));
    CALL_FN_W_12W(err, fn, sendbuf,sendcount,sendtype,dest,sendtag,
                           recvbuf,recvcount,recvtype,source,recvtag,
                           comm,status);