Actual source code: mpiaijkok.kokkos.cxx

  1: #include <petscvec_kokkos.hpp>
  2: #include <petscpkg_version.h>
  3: #include <petscsf.h>
  4: #include <petsc/private/sfimpl.h>
  5: #include <../src/mat/impls/aij/mpi/mpiaij.h>
  6: #include <../src/mat/impls/aij/mpi/kokkos/mpiaijkok.hpp>
  7: #include <KokkosSparse_spadd.hpp>

  9: PetscErrorCode MatAssemblyEnd_MPIAIJKokkos(Mat A, MatAssemblyType mode)
 10: {
 11:   Mat_SeqAIJKokkos *aijkok;
 12:   Mat_MPIAIJ       *mpiaij = (Mat_MPIAIJ *)A->data;

 14:   PetscFunctionBegin;
 15:   PetscCall(MatAssemblyEnd_MPIAIJ(A, mode));
 16:   /* E.g., MatCreateSubMatrix() calls MatCreateMPIAIJWithSeqAIJ(comm,A,B,..), which creates Bnew of SEQAIJ and destroys B of SEQAIJKOKKOS.
 17:      Thus we finalize A/B/lvec's type in MatAssemblyEnd() to handle various cases.
 18:    */
 19:   if (mode == MAT_FINAL_ASSEMBLY) {
 20:     PetscCall(MatSetType(mpiaij->A, MATSEQAIJKOKKOS));
 21:     PetscCall(MatSetType(mpiaij->B, MATSEQAIJKOKKOS));
 22:     PetscCall(VecSetType(mpiaij->lvec, VECSEQKOKKOS));
 23:   }
 24:   aijkok = static_cast<Mat_SeqAIJKokkos *>(((Mat_MPIAIJ *)A->data)->A->spptr); /* Access spptr after MatAssemblyEnd_MPIAIJ(), which might have deleted old spptr */
 25:   if (aijkok && aijkok->device_mat_d.data()) {
 26:     A->offloadmask = PETSC_OFFLOAD_GPU; // in GPU mode, no going back. MatSetValues checks this
 27:   }

 29:   PetscFunctionReturn(PETSC_SUCCESS);
 30: }

 32: PetscErrorCode MatMPIAIJSetPreallocation_MPIAIJKokkos(Mat mat, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[])
 33: {
 34:   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;

 36:   PetscFunctionBegin;
 37:   PetscCall(PetscLayoutSetUp(mat->rmap));
 38:   PetscCall(PetscLayoutSetUp(mat->cmap));
 39: #if defined(PETSC_USE_DEBUG)
 40:   if (d_nnz) {
 41:     PetscInt i;
 42:     for (i = 0; i < mat->rmap->n; i++) PetscCheck(d_nnz[i] >= 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "d_nnz cannot be less than 0: local row %" PetscInt_FMT " value %" PetscInt_FMT, i, d_nnz[i]);
 43:   }
 44:   if (o_nnz) {
 45:     PetscInt i;
 46:     for (i = 0; i < mat->rmap->n; i++) PetscCheck(o_nnz[i] >= 0, PETSC_COMM_SELF, PETSC_ERR_ARG_OUTOFRANGE, "o_nnz cannot be less than 0: local row %" PetscInt_FMT " value %" PetscInt_FMT, i, o_nnz[i]);
 47:   }
 48: #endif
 49: #if defined(PETSC_USE_CTABLE)
 50:   PetscCall(PetscHMapIDestroy(&mpiaij->colmap));
 51: #else
 52:   PetscCall(PetscFree(mpiaij->colmap));
 53: #endif
 54:   PetscCall(PetscFree(mpiaij->garray));
 55:   PetscCall(VecDestroy(&mpiaij->lvec));
 56:   PetscCall(VecScatterDestroy(&mpiaij->Mvctx));
 57:   /* Because B will have been resized, we simply destroy it and create a new one each time */
 58:   PetscCall(MatDestroy(&mpiaij->B));

 60:   if (!mpiaij->A) {
 61:     PetscCall(MatCreate(PETSC_COMM_SELF, &mpiaij->A));
 62:     PetscCall(MatSetSizes(mpiaij->A, mat->rmap->n, mat->cmap->n, mat->rmap->n, mat->cmap->n));
 63:   }
 64:   if (!mpiaij->B) {
 65:     PetscMPIInt size;
 66:     PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)mat), &size));
 67:     PetscCall(MatCreate(PETSC_COMM_SELF, &mpiaij->B));
 68:     PetscCall(MatSetSizes(mpiaij->B, mat->rmap->n, size > 1 ? mat->cmap->N : 0, mat->rmap->n, size > 1 ? mat->cmap->N : 0));
 69:   }
 70:   PetscCall(MatSetType(mpiaij->A, MATSEQAIJKOKKOS));
 71:   PetscCall(MatSetType(mpiaij->B, MATSEQAIJKOKKOS));
 72:   PetscCall(MatSeqAIJSetPreallocation(mpiaij->A, d_nz, d_nnz));
 73:   PetscCall(MatSeqAIJSetPreallocation(mpiaij->B, o_nz, o_nnz));
 74:   mat->preallocated = PETSC_TRUE;
 75:   PetscFunctionReturn(PETSC_SUCCESS);
 76: }
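/* Illustration (not part of this source file): a minimal usage sketch that exercises the two routines
   above. Sizes and nonzero counts are hypothetical.

     Mat A;
     PetscCall(MatCreate(PETSC_COMM_WORLD, &A));
     PetscCall(MatSetSizes(A, 10, 10, PETSC_DECIDE, PETSC_DECIDE));
     PetscCall(MatSetType(A, MATMPIAIJKOKKOS));
     PetscCall(MatMPIAIJSetPreallocation(A, 5, NULL, 2, NULL)); // dispatches to MatMPIAIJSetPreallocation_MPIAIJKokkos
     // ... MatSetValues() ...
     PetscCall(MatAssemblyBegin(A, MAT_FINAL_ASSEMBLY));
     PetscCall(MatAssemblyEnd(A, MAT_FINAL_ASSEMBLY)); // ends up in MatAssemblyEnd_MPIAIJKokkos above
     PetscCall(MatDestroy(&A));
*/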

 78: PetscErrorCode MatMult_MPIAIJKokkos(Mat mat, Vec xx, Vec yy)
 79: {
 80:   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
 81:   PetscInt    nt;

 83:   PetscFunctionBegin;
 84:   PetscCall(VecGetLocalSize(xx, &nt));
 85:   PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt);
 86:   PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
 87:   PetscCall((*mpiaij->A->ops->mult)(mpiaij->A, xx, yy));
 88:   PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
 89:   PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, yy, yy));
 90:   PetscFunctionReturn(PETSC_SUCCESS);
 91: }

 93: PetscErrorCode MatMultAdd_MPIAIJKokkos(Mat mat, Vec xx, Vec yy, Vec zz)
 94: {
 95:   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
 96:   PetscInt    nt;

 98:   PetscFunctionBegin;
 99:   PetscCall(VecGetLocalSize(xx, &nt));
100:   PetscCheck(nt == mat->cmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->cmap->n, nt);
101:   PetscCall(VecScatterBegin(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
102:   PetscCall((*mpiaij->A->ops->multadd)(mpiaij->A, xx, yy, zz));
103:   PetscCall(VecScatterEnd(mpiaij->Mvctx, xx, mpiaij->lvec, INSERT_VALUES, SCATTER_FORWARD));
104:   PetscCall((*mpiaij->B->ops->multadd)(mpiaij->B, mpiaij->lvec, zz, zz));
105:   PetscFunctionReturn(PETSC_SUCCESS);
106: }

108: PetscErrorCode MatMultTranspose_MPIAIJKokkos(Mat mat, Vec xx, Vec yy)
109: {
110:   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)mat->data;
111:   PetscInt    nt;

113:   PetscFunctionBegin;
114:   PetscCall(VecGetLocalSize(xx, &nt));
115:   PetscCheck(nt == mat->rmap->n, PETSC_COMM_SELF, PETSC_ERR_ARG_SIZ, "Incompatible partition of mat (%" PetscInt_FMT ") and xx (%" PetscInt_FMT ")", mat->rmap->n, nt);
116:   PetscCall((*mpiaij->B->ops->multtranspose)(mpiaij->B, xx, mpiaij->lvec));
117:   PetscCall((*mpiaij->A->ops->multtranspose)(mpiaij->A, xx, yy));
118:   PetscCall(VecScatterBegin(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE));
119:   PetscCall(VecScatterEnd(mpiaij->Mvctx, mpiaij->lvec, yy, ADD_VALUES, SCATTER_REVERSE));
120:   PetscFunctionReturn(PETSC_SUCCESS);
121: }
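/* In matrix form, the three routines above compute, with x_ghost = lvec = the Mvctx-scatter of x,
     MatMult:          y = Ad*x + Ao*x_ghost
     MatMultAdd:       z = Ad*x + y + Ao*x_ghost
     MatMultTranspose: y = Ad^T*x + (reverse scatter-add of Ao^T*x)
   where Ad = mpiaij->A is the diag block and Ao = mpiaij->B is the offdiag block. The forward scatter
   is overlapped with the local product against Ad.
*/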

123: /* Merge the "A, B" matrices of mat into a matrix C.  mat's type is MPIAIJKOKKOS. C's type is MATSEQAIJKOKKOS.
124:    A is put before B. C's size would be A->rmap->n by (A->cmap->n + B->cmap->n).
125:    C still uses local column ids. Their corresponding global column ids are returned in glob.
126: */
127: PetscErrorCode MatMPIAIJGetLocalMatMerge_MPIAIJKokkos(Mat mat, MatReuse reuse, IS *glob, Mat *C)
128: {
129:   Mat             Ad, Ao;
130:   const PetscInt *cmap;

132:   PetscFunctionBegin;
133:   PetscCall(MatMPIAIJGetSeqAIJ(mat, &Ad, &Ao, &cmap));
134:   PetscCall(MatSeqAIJKokkosMergeMats(Ad, Ao, reuse, C));
135:   if (glob) {
136:     PetscInt cst, i, dn, on, *gidx;
137:     PetscCall(MatGetLocalSize(Ad, NULL, &dn));
138:     PetscCall(MatGetLocalSize(Ao, NULL, &on));
139:     PetscCall(MatGetOwnershipRangeColumn(mat, &cst, NULL));
140:     PetscCall(PetscMalloc1(dn + on, &gidx));
141:     for (i = 0; i < dn; i++) gidx[i] = cst + i;
142:     for (i = 0; i < on; i++) gidx[i + dn] = cmap[i];
143:     PetscCall(ISCreateGeneral(PetscObjectComm((PetscObject)Ad), dn + on, gidx, PETSC_OWN_POINTER, glob));
144:   }
145:   PetscFunctionReturn(PETSC_SUCCESS);
146: }
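/* Illustration (not part of this source file): a minimal calling sketch for the routine above via its
   public interface MatMPIAIJGetLocalMatMerge(), with A being a MATMPIAIJKOKKOS matrix.

     Mat C;    // sequential merged matrix: A's diag columns first, then its offdiag columns, with local col ids
     IS  glob; // global col ids corresponding to C's local col ids
     PetscCall(MatMPIAIJGetLocalMatMerge(A, MAT_INITIAL_MATRIX, &glob, &C));
     // ... use C and glob ...
     PetscCall(ISDestroy(&glob));
     PetscCall(MatDestroy(&C));
*/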

148: /* Structs used in matrix product C=AB, C=A^tB and C=B^tAB */
149: struct MatMatStruct {
150:   MatRowMapKokkosView Cdstart; /* Used to split sequential matrix into petsc's A, B format */
151:   PetscSF             sf;      /* SF to send/recv matrix entries */
152:   MatScalarKokkosView abuf;    /* buf of mat values in send/recv */
153:   Mat                 C1, C2, B_local;
154:   KokkosCsrMatrix     C1_global, C2_global, C_global;
155:   KernelHandle        kh;
156:   MatMatStruct() noexcept : sf(nullptr), C1(nullptr), C2(nullptr), B_local(nullptr) { }

158:   ~MatMatStruct()
159:   {
160:     PetscFunctionBegin;
161:     PetscCallAbort(PETSC_COMM_SELF, MatDestroy(&C1));
162:     PetscCallAbort(PETSC_COMM_SELF, MatDestroy(&C2));
163:     PetscCallAbort(PETSC_COMM_SELF, MatDestroy(&B_local));
164:     PetscCallAbort(PETSC_COMM_SELF, PetscSFDestroy(&sf));
165:     kh.destroy_spadd_handle();
166:     PetscFunctionReturnVoid();
167:   }
168: };

170: struct MatMatStruct_AB : public MatMatStruct {
171:   MatColIdxKokkosView rows{};
172:   MatRowMapKokkosView rowoffset{};
173:   Mat                 B_other{}, C_petsc{}; /* SEQAIJKOKKOS matrices. TODO: have a better var name than C_petsc */
 174:   MatColIdxKokkosView B_NzDiagLeft;         // For each row, the number of nonzeros of B left of the diagonal block; used to recover the unsplit B (i.e., the local mat)

176:   ~MatMatStruct_AB() noexcept
177:   {
178:     PetscFunctionBegin;
179:     PetscCallAbort(PETSC_COMM_SELF, MatDestroy(&B_other));
180:     PetscCallAbort(PETSC_COMM_SELF, MatDestroy(&C_petsc));
181:     PetscFunctionReturnVoid();
182:   }
183: };

185: struct MatMatStruct_AtB : public MatMatStruct {
186:   MatRowMapKokkosView srcrowoffset, dstrowoffset;
187: };

189: struct MatProductData_MPIAIJKokkos {
190:   MatMatStruct_AB  *mmAB     = nullptr;
191:   MatMatStruct_AtB *mmAtB    = nullptr;
192:   PetscBool         reusesym = PETSC_FALSE;

194:   ~MatProductData_MPIAIJKokkos()
195:   {
196:     delete mmAB;
197:     delete mmAtB;
198:   }
199: };

201: static PetscErrorCode MatProductDataDestroy_MPIAIJKokkos(void *data)
202: {
203:   PetscFunctionBegin;
204:   PetscCallCXX(delete static_cast<MatProductData_MPIAIJKokkos *>(data));
205:   PetscFunctionReturn(PETSC_SUCCESS);
206: }

208: /* MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds - Get a KokkosCsrMatrix from a MATSEQAIJKOKKOS matrix

210:    Input Parameters:
211: +  A       - the MATSEQAIJKOKKOS matrix
212: .  N       - new column size for the returned Kokkos matrix
213: -  l2g     - a map that maps old col ids to new col ids

215:    Output Parameters:
216: .  csrmat  - the Kokkos matrix, which has the same row size as A, shares a, i but not j with A.
217:  */
218: static PetscErrorCode MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds(Mat A, PetscInt N, const ConstMatColIdxKokkosView &l2g, KokkosCsrMatrix &csrmat)
219: {
220:   KokkosCsrMatrix    &orig = static_cast<Mat_SeqAIJKokkos *>(A->spptr)->csrmat;
221:   MatColIdxKokkosView jg("jg", orig.nnz()); /* New j array for csrmat */

223:   PetscFunctionBegin;
224:   PetscCallCXX(Kokkos::parallel_for(
225:     orig.nnz(), KOKKOS_LAMBDA(const PetscInt i) { jg(i) = l2g(orig.graph.entries(i)); }));
226:   PetscCallCXX(csrmat = KokkosCsrMatrix("csrmat", orig.numRows(), N, orig.nnz(), orig.values, orig.graph.row_map, jg));
227:   PetscFunctionReturn(PETSC_SUCCESS);
228: }
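/* Illustration with hypothetical numbers: if a row of A has local col ids j = {0, 2, 3} and
   l2g = {4, 7, 9, 10}, the returned csrmat keeps A's values and row map but gets global col ids
   jg = {l2g(0), l2g(2), l2g(3)} = {4, 9, 10}, and its column size becomes N.
*/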

230: /* MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices - Set the diag and offdiag matrices of a MATMPIAIJKOKKOS matrix.
231:    It is similar to MatCreateMPIAIJWithSplitArrays.

233:   Input Parameters:
234: +  mat   - the MATMPIAIJKOKKOS matrix, which should have its type and layout set, but should not have its diag, offdiag matrices set
235: .  A     - the diag matrix using local col ids
236: -  B     - the offdiag matrix using global col ids

238:   Output Parameters:
239: .  mat   - the updated MATMPIAIJKOKKOS matrix
240: */
241: static PetscErrorCode MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(Mat mat, Mat A, Mat B)
242: {
243:   Mat_MPIAIJ       *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
244:   PetscInt          m, n, M, N, Am, An, Bm, Bn;
245:   Mat_SeqAIJKokkos *bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);

247:   PetscFunctionBegin;
248:   PetscCall(MatGetSize(mat, &M, &N));
249:   PetscCall(MatGetLocalSize(mat, &m, &n));
250:   PetscCall(MatGetLocalSize(A, &Am, &An));
251:   PetscCall(MatGetLocalSize(B, &Bm, &Bn));

253:   PetscCheck(m == Am && m == Bm, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of rows do not match");
254:   PetscCheck(n == An, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local number of columns do not match");
255:   PetscCheck(N == Bn, PETSC_COMM_SELF, PETSC_ERR_PLIB, "global number of columns do not match");
256:   PetscCheck(!mpiaij->A && !mpiaij->B, PETSC_COMM_SELF, PETSC_ERR_PLIB, "A, B of the MPIAIJ matrix are not empty");
257:   mpiaij->A = A;
258:   mpiaij->B = B;

260:   mat->preallocated     = PETSC_TRUE;
261:   mat->nooffprocentries = PETSC_TRUE; /* See MatAssemblyBegin_MPIAIJ. In effect, making MatAssemblyBegin a nop */

263:   PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_TRUE));
264:   PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY));
265:   /* MatAssemblyEnd is critical here. It sets mat->offloadmask according to A and B's, and
266:     also gets mpiaij->B compacted, with its col ids and size reduced
267:   */
268:   PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY));
269:   PetscCall(MatSetOption(mat, MAT_NO_OFF_PROC_ENTRIES, PETSC_FALSE));
270:   PetscCall(MatSetOption(mat, MAT_NEW_NONZERO_LOCATION_ERR, PETSC_TRUE));

272:   /* Update bkok with new local col ids (stored on host) and size */
273:   bkok->j_dual.modify_host();
274:   bkok->j_dual.sync_device();
275:   bkok->SetColSize(mpiaij->B->cmap->n);
276:   PetscFunctionReturn(PETSC_SUCCESS);
277: }
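/* Illustration (hypothetical sizes): on a rank owning rows [rstart,rend) and columns [cstart,cend) of mat,
   A is m x (cend-cstart) with local col ids and B is m x N with global col ids. After the MatAssemblyEnd()
   above, B has been compacted: its col ids are local indices into garray[] and its col size has shrunk to
   the number of distinct off-diagonal columns actually present, which is why bkok's j and col size are
   refreshed at the end of the routine.
*/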

279: /* MatSeqAIJKokkosBcast - Bcast rows of a SEQAIJKOKKOS matrix (B) to form a SEQAIJKOKKOS matrix (C).

281:    It is essentially the MPIAIJKOKKOS counterpart of MatGetBrowsOfAoCols_MPIAIJ, but supports device and uses PetscSF.
282:    In the given ownerSF, leaves correspond to rows in C, and roots correspond to rows in B. Roots may connect to multiple leaves.
283:    If C's j-th row is connected to a root identified by PetscSFNode (k,i), then we will bcast the i-th row of B on rank k
284:    to the j-th row of C. ownerSF's leaves must be contiguous (in other words, as if ilocal=NULL was used to set its graph).

286:    Collective

288:    Input Parameters:
289: +   B       - the SEQAIJKOKKOS matrix, using local col ids
290: .   reuse   - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
291: .   N       - global col ids are in range of [0,N). N must be the same across ranks (nonsignificant in MAT_REUSE_MATRIX)
292: .   l2g     - a map mapping B's local col ids to global ones (nonsignificant in MAT_REUSE_MATRIX)
293: .   ownerSF - the ownership SF (nonsignificant in MAT_REUSE_MATRIX)

295:    Input/Output Parameters (out when reuse = MAT_INITIAL_MATRIX, inout when reuse = MAT_REUSE_MATRIX)
296: +   bcastSF   - the SF used to bcast rows of B. This plain SF does buffer (abuf) to buffer (Ca) send/recv. In this SF, vertices are nonzeros.
297: .   abuf      - buffer for sending matrix values
298: .   rows      - array containing indices of (local) rows that this rank needs to bcast to others. Each receiver rank has a chunk in rows[].
299:                 Values in rows[] might have repeats, which simply indicates a row will be bcast'ed to multiple neighbors.
300: .   rowoffset - for each row rows[i], rowoffset[i] is the offset in abuf[] at which that row's values are copied
301: -   C         - the SEQAIJKOKKOS matrix made of the bcast'ed rows, using local col ids.
302: */
303: static PetscErrorCode MatSeqAIJKokkosBcast(Mat B, MatReuse reuse, PetscInt N, const ConstMatColIdxKokkosView &l2g, PetscSF ownerSF, PetscSF &bcastSF, MatScalarKokkosView &abuf, MatColIdxKokkosView &rows, MatRowMapKokkosView &rowoffset, Mat &C)
304: {
305:   Mat_SeqAIJKokkos *bkok, *ckok;

307:   PetscFunctionBegin;
308:   PetscCall(MatSeqAIJKokkosSyncDevice(B)); /* Make sure B->spptr is accessible */
309:   bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);

311:   if (reuse == MAT_REUSE_MATRIX) {
312:     ckok = static_cast<Mat_SeqAIJKokkos *>(C->spptr);

314:     const auto &Ba = bkok->a_dual.view_device();
315:     const auto &Bi = bkok->i_dual.view_device();
316:     const auto &Ca = ckok->a_dual.view_device();

318:     /* Copy Ba to abuf */
319:     Kokkos::parallel_for(
320:       Kokkos::TeamPolicy<>(rows.extent(0), Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
321:         PetscInt i    = t.league_rank(); /* rows[i] is r-th row of B */
322:         PetscInt r    = rows(i);
323:         PetscInt base = rowoffset(i); /* Copy r-th row of B to this offset in abuf[] */
324:         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, Bi(r + 1) - Bi(r)), [&](PetscInt k) { abuf(base + k) = Ba(Bi(r) + k); });
325:       });

327:     /* Send abuf to Ca through bcastSF and then mark C as updated on device */
328:     PetscCall(PetscSFBcastBegin(bcastSF, MPIU_SCALAR, abuf.data(), Ca.data(), MPI_REPLACE)); /* TODO: get memtype for abuf */
329:     PetscCall(PetscSFBcastEnd(bcastSF, MPIU_SCALAR, abuf.data(), Ca.data(), MPI_REPLACE));
330:     ckok->a_dual.modify_device();
331:   } else if (reuse == MAT_INITIAL_MATRIX) {
332:     MPI_Comm    comm;
333:     PetscMPIInt tag;
334:     PetscInt    k, Cm, Cn, Cnnz, *Ci_h, nroots, nleaves;

336:     PetscCall(PetscObjectGetComm((PetscObject)ownerSF, &comm));
337:     PetscCall(PetscSFGetGraph(ownerSF, &nroots, &nleaves, NULL, NULL));
338:     Cm = nleaves; /* row size of C */
339:     Cn = N;       /* col size of C, which initially uses global ids, so we can safely set its col size as N */

341:     /* Get row lens (nz) of B's rows for later fast query */
342:     PetscInt       *Browlens;
343:     const PetscInt *tmp = bkok->i_host_data();
344:     PetscCall(PetscMalloc1(nroots, &Browlens));
345:     for (k = 0; k < nroots; k++) Browlens[k] = tmp[k + 1] - tmp[k];

347:     /* By ownerSF, each proc gets lens of rows of C */
348:     MatRowMapKokkosDualView Ci("i", Cm + 1); /* C's rowmap */
349:     Ci_h    = Ci.view_host().data();
350:     Ci_h[0] = 0;
351:     PetscCall(PetscSFBcastWithMemTypeBegin(ownerSF, MPIU_INT, PETSC_MEMTYPE_HOST, Browlens, PETSC_MEMTYPE_HOST, &Ci_h[1], MPI_REPLACE));
352:     PetscCall(PetscSFBcastEnd(ownerSF, MPIU_INT, Browlens, &Ci_h[1], MPI_REPLACE));
353:     for (k = 1; k < Cm + 1; k++) Ci_h[k] += Ci_h[k - 1]; /* Convert lens to CSR */
354:     Cnnz = Ci_h[Cm];
355:     Ci.modify_host();
356:     Ci.sync_device();

358:     /* With the newly known Cnnz, we are able to allocate (j, a) for C on host & device */
359:     MatColIdxKokkosDualView Cj("j", Cnnz);
360:     MatScalarKokkosDualView Ca("a", Cnnz);

362:     /* Now build the bcastSF to fill Ca, Cj. This plain SF only does (contiguous) buffer to buffer send/recv */
363:     const PetscMPIInt *iranks, *ranks;
364:     const PetscInt    *ioffset, *irootloc, *roffset;
365:     PetscInt           i, j, niranks, nranks, *sdisp, *rdisp, *rowptr;
366:     MPI_Request       *reqs;

368:     PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &irootloc));                      /* irootloc[] contains indices of rows I need to send to each receiver */
369:     PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, NULL /*rmine*/, NULL /*rremote*/)); /* recv info */

371:     /* figure out offsets at the send buffer, to build the SF
372:       sdisp[]  - stores offsets of nonzeros (in abuf or jbuf, see later) I need to send, per receiver.
373:       rowptr[] - stores offsets for data of each row in abuf

375:       rdisp[]  - to receive sdisp[]
376:     */
377:     PetscCall(PetscMalloc3(niranks + 1, &sdisp, nranks, &rdisp, niranks + nranks, &reqs));
378:     MatRowMapKokkosViewHost rowptr_h("rowptr_h", ioffset[niranks] + 1); /* Let Kokkos do the allocation, so that we can do an easy mirror later */
379:     rowptr = rowptr_h.data();

381:     sdisp[0]  = 0;
382:     rowptr[0] = 0;
383:     for (i = 0; i < niranks; i++) { /* for each receiver */
384:       PetscInt len, nz = 0;
385:       for (j = ioffset[i]; j < ioffset[i + 1]; j++) { /* for each row to this receiver */
386:         len           = Browlens[irootloc[j]];
387:         rowptr[j + 1] = rowptr[j] + len;
388:         nz += len;
389:       }
390:       sdisp[i + 1] = sdisp[i] + nz;
391:     }
392:     PetscCall(PetscCommGetNewTag(comm, &tag));
393:     for (i = 0; i < nranks; i++) PetscCallMPI(MPI_Irecv(&rdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]));
394:     for (i = 0; i < niranks; i++) PetscCallMPI(MPI_Isend(&sdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]));
395:     PetscCallMPI(MPI_Waitall(niranks + nranks, reqs, MPI_STATUSES_IGNORE));

397:     PetscInt     nleaves2 = Cnnz;           /* leaves are the nonzeros I will receive */
398:     PetscInt     nroots2  = sdisp[niranks]; /* roots are the nonzeros (in abuf) I will send */
399:     PetscSFNode *iremote;
400:     PetscCall(PetscMalloc1(nleaves2, &iremote));
401:     for (i = 0; i < nranks; i++) { /* for each sender */
402:       k = 0;
403:       for (j = Ci_h[roffset[i]]; j < Ci_h[roffset[i + 1]]; j++) {
404:         iremote[j].rank  = ranks[i];
405:         iremote[j].index = rdisp[i] + k;
406:         k++;
407:       }
408:     }
409:     /* TODO: we should extend PetscSF APIs for this buffer-to-buffer send/recv */
410:     PetscCall(PetscSFCreate(comm, &bcastSF));
411:     PetscCall(PetscSFSetGraph(bcastSF, nroots2, nleaves2, NULL /*ilocal*/, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER));

413:     /* Extract selected rows of B, and copy their (a, j) into abuf[] and jbuf[], with j converted
414:       from local to global. Then use bcastSF to fill Ca, Cj.
415:     */
416:     ConstMatColIdxKokkosViewHost rows_h(irootloc, ioffset[niranks]); /* irootloc[] stores indices of rows I need to send */
417:     MatColIdxKokkosView          rows("rows", ioffset[niranks]);
418:     Kokkos::deep_copy(rows, rows_h); /* Use deep copy since irootloc is managed by PetscSF and we want 'rows' to be standalone */

420:     rowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), rowptr_h); /* If no device, rowoffset will be an alias to rowptr_h */

422:     MatColIdxKokkosView jbuf("jbuf", sdisp[niranks]);   /* send buf for (global) col ids */
423:     abuf = MatScalarKokkosView("abuf", sdisp[niranks]); /* send buf for mat values */

425:     const auto &Ba = bkok->a_dual.view_device();
426:     const auto &Bi = bkok->i_dual.view_device();
427:     const auto &Bj = bkok->j_dual.view_device();

429:     /* Copy Ba, Bj to abuf, jbuf, changing col ids from local to global */
430:     Kokkos::parallel_for(
431:       Kokkos::TeamPolicy<>(rows.extent(0), Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
432:         PetscInt i    = t.league_rank(); /* rows[i] is r-th row of B */
433:         PetscInt r    = rows(i);
434:         PetscInt base = rowoffset(i); /* Copy r-th row of B to this offset in abuf[] */
435:         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, Bi(r + 1) - Bi(r)), [&](PetscInt k) {
436:           abuf(base + k) = Ba(Bi(r) + k);
437:           jbuf(base + k) = l2g(Bj(Bi(r) + k));
438:         });
439:       });

441:     /* Send abuf & jbuf to fill Ca, Cj */
442:     PetscCall(PetscSFBcastBegin(bcastSF, MPIU_INT, jbuf.data(), Cj.view_device().data(), MPI_REPLACE));
443:     PetscCall(PetscSFBcastBegin(bcastSF, MPIU_SCALAR, abuf.data(), Ca.view_device().data(), MPI_REPLACE));
444:     PetscCall(PetscSFBcastEnd(bcastSF, MPIU_INT, jbuf.data(), Cj.view_device().data(), MPI_REPLACE));
445:     PetscCall(PetscSFBcastEnd(bcastSF, MPIU_SCALAR, abuf.data(), Ca.view_device().data(), MPI_REPLACE));
446:     Cj.modify_device(); /* Mark Cj, Ca modified on device, but only sync Cj since we might not need Ca on host at all */
447:     Cj.sync_host();
448:     Ca.modify_device();

450:     /* Construct C with Ca, Ci, Cj */
451:     auto ckok = new Mat_SeqAIJKokkos(Cm, Cn, Cnnz, Ci, Cj, Ca);
452:     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF, ckok, &C));
453:     PetscCall(PetscFree3(sdisp, rdisp, reqs));
454:     PetscCall(PetscFree(Browlens));
455:   } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse);
456:   PetscFunctionReturn(PETSC_SUCCESS);
457: }
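/* Illustration of MatSeqAIJKokkosBcast's ownerSF semantics (hypothetical 2-rank example): suppose rank 1
   needs rows {0, 2} of rank 0's B and row {1} of its own B. Then rank 1 has three contiguous leaves with
   remotes {(0,0), (0,2), (1,1)}, and C on rank 1 gets three rows: row 0 of B on rank 0, row 2 of B on
   rank 0, and row 1 of B on rank 1, in that order. A root (a row of B) may feed multiple leaves, i.e.,
   be bcast'ed to several ranks.
*/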

459: /* MatSeqAIJKokkosReduce - Reduce rows of a SEQAIJKOKKOS matrix (A) to form a Kokkos Csr matrix (C)

461:   It is the reverse of MatSeqAIJKokkosBcast in some sense.

463:   Think of each row of A as a leaf; the given ownerSF then specifies the roots of the leaves. Roots may connect to multiple leaves.
464:   In this routine, we reduce (i.e., concatenate) leaves (rows) at their roots to form potentially longer rows in C. Such rows might
465:   contain repeats, which does not matter since they will be summed up by other routines. C's row size will be nroots of ownerSF.

467:   Input Parameters:
468: +  A        - the SEQAIJKOKKOS matrix to be reduced
469: .  reuse    - either MAT_INITIAL_MATRIX or MAT_REUSE_MATRIX
470: .  local    - true if A uses local col ids; false if A is already in global col ids.
471: .  N        - if local, N is A's global col size
472: .  l2g      - if local, a map mapping A's local col ids to global ones, which are in range of [0,N).
473: -  ownerSF  - the SF specifies ownership (root) of rows in A

475:   Output Parameters:
476: +  reduceSF    - the SF to reduce A's rows to contiguous buffers at the receiver side
477: .  abuf         - a contiguous buffer to receive A's rows sent to this proc. Suppose there are 'nrows' such rows.
478: .  srcrowoffset - offset array of size nrows+1. Each entry is the corresponding row's offset in abuf[]. srcrowoffset[i+1]-srcrowoffset[i] is row i's len.
479: .  dstrowoffset - offset array of size nrows. Each entry is the corresponding row's offset in Ca[], i.e., C's 'a' array. Row i, i+1 in abuf[] may go to
480:                   unrelated places in Ca, so dstrowoffset is not in CSR-like format, unlike srcrowoffset.
481: -  C            - the matrix made up by rows sent to me from other ranks, using global col ids

483:    TODO: we can even have MatSeqAIJKokkosReduceBegin/End to provide opportunity for callers to overlap comp./comm. when reuse = MAT_REUSE_MATRIX.
484:  */
485: static PetscErrorCode MatSeqAIJKokkosReduce(Mat A, MatReuse reuse, PetscBool local, PetscInt N, const ConstMatColIdxKokkosView &l2g, PetscSF ownerSF, PetscSF &reduceSF, MatScalarKokkosView &abuf, MatRowMapKokkosView &srcrowoffset, MatRowMapKokkosView &dstrowoffset, KokkosCsrMatrix &C)
486: {
487:   PetscInt          i, r, Am, An, Annz, Cnnz, nrows;
488:   const PetscInt   *Ai;
489:   Mat_SeqAIJKokkos *akok;

491:   PetscFunctionBegin;
492:   PetscCall(MatSeqAIJKokkosSyncDevice(A)); /* So that A's latest data is on device */
493:   PetscCall(MatGetSize(A, &Am, &An));
494:   Ai   = static_cast<Mat_SeqAIJ *>(A->data)->i;
495:   akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
496:   Annz = Ai[Am];

498:   if (reuse == MAT_REUSE_MATRIX) {
499:     /* Send Aa to abuf */
500:     PetscCall(PetscSFReduceBegin(reduceSF, MPIU_SCALAR, akok->a_device_data(), abuf.data(), MPI_REPLACE));
501:     PetscCall(PetscSFReduceEnd(reduceSF, MPIU_SCALAR, akok->a_device_data(), abuf.data(), MPI_REPLACE));

503:     /* Copy abuf to Ca */
504:     const MatScalarKokkosView &Ca = C.values;
505:     nrows                         = dstrowoffset.extent(0); /* Not srcrowoffset[] since it has an extra entry for CSR */
506:     Kokkos::parallel_for(
507:       Kokkos::TeamPolicy<>(nrows, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
508:         PetscInt i   = t.league_rank();
509:         PetscInt src = srcrowoffset(i), dst = dstrowoffset(i);
510:         PetscInt len = srcrowoffset(i + 1) - srcrowoffset(i);
511:         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, len), [&](PetscInt k) { Ca(dst + k) = abuf(src + k); });
512:       });
513:   } else if (reuse == MAT_INITIAL_MATRIX) {
514:     MPI_Comm     comm;
515:     MPI_Request *reqs;
516:     PetscMPIInt  tag;
517:     PetscInt     Cm;

519:     PetscCall(PetscObjectGetComm((PetscObject)ownerSF, &comm));
520:     PetscCall(PetscCommGetNewTag(comm, &tag));

522:     PetscInt           niranks, nranks, nroots, nleaves;
523:     const PetscMPIInt *iranks, *ranks;
524:     const PetscInt    *ioffset, *rows, *roffset; /* rows[] contains local indices of rows scattered to me from others. ioffset[] is a CSR on rows[] */
525:     PetscCall(PetscSFSetUp(ownerSF));
526:     PetscCall(PetscSFGetLeafRanks(ownerSF, &niranks, &iranks, &ioffset, &rows));                          /* recv info: iranks[] will send rows to me */
527:     PetscCall(PetscSFGetRootRanks(ownerSF, &nranks, &ranks, &roffset, NULL /*rmine*/, NULL /*rremote*/)); /* send info */
528:     PetscCall(PetscSFGetGraph(ownerSF, &nroots, &nleaves, NULL, NULL));
529:     PetscCheck(nleaves == Am, PETSC_COMM_SELF, PETSC_ERR_PLIB, "ownerSF's nleaves(%" PetscInt_FMT ") != row size of A(%" PetscInt_FMT ")", nleaves, Am);
530:     Cm    = nroots;
531:     nrows = ioffset[niranks]; /* # of rows to be received. Might receive same row (each is partial) from different senders */

533:     /* Tell owners the length of each row I will send */
534:     PetscInt               *srowlens;                              /* send buf of row lens */
535:     MatRowMapKokkosViewHost rrowlens_h("rrowlens_h", nrows + 1); /* recv buf of row lens. +1 to make CSR later. Memory might be passed to other views */
536:     PetscInt               *rrowlens = rrowlens_h.data();

538:     PetscCall(PetscMalloc2(Am, &srowlens, niranks + nranks, &reqs));
539:     for (i = 0; i < Am; i++) srowlens[i] = Ai[i + 1] - Ai[i];
540:     rrowlens[0] = 0;
541:     rrowlens++; /* shift the pointer to make the following expression more readable */
542:     for (i = 0; i < niranks; i++) PetscCallMPI(MPI_Irecv(&rrowlens[ioffset[i]], ioffset[i + 1] - ioffset[i], MPIU_INT, iranks[i], tag, comm, &reqs[i]));
543:     for (i = 0; i < nranks; i++) PetscCallMPI(MPI_Isend(&srowlens[roffset[i]], roffset[i + 1] - roffset[i], MPIU_INT, ranks[i], tag, comm, &reqs[niranks + i]));
544:     PetscCallMPI(MPI_Waitall(niranks + nranks, reqs, MPI_STATUSES_IGNORE));

546:     /* Owner builds Ci on host by histogramming rrowlens[] */
547:     MatRowMapKokkosViewHost Ci_h("i", Cm + 1);
548:     Kokkos::deep_copy(Ci_h, 0); /* Zero Ci */
549:     MatRowMapType *Ci_ptr = Ci_h.data();

551:     for (i = 0; i < nrows; i++) {
552:       r = rows[i]; /* local row id of i-th received row */
553: #if defined(PETSC_USE_DEBUG)
554:       PetscCheck(r >= 0 && r < Cm, PETSC_COMM_SELF, PETSC_ERR_PLIB, "local row id (%" PetscInt_FMT ") is out of range [0,%" PetscInt_FMT ")", r, Cm);
555: #endif
556:       Ci_ptr[r + 1] += rrowlens[i]; /* add to length of row r in C */
557:     }
558:     for (i = 0; i < Cm; i++) Ci_ptr[i + 1] += Ci_ptr[i]; /* to CSR format */
559:     Cnnz = Ci_ptr[Cm];

561:     /* For each received row, compute src & dst offsets in memory copying (from recv bufs abuf, jbuf to Ca, Cj) */
562:     MatRowMapKokkosViewHost dstrowoffset_h("dstrowoffset_h", nrows);
563:     PetscInt               *dstrowoffset_hptr = dstrowoffset_h.data();
564:     PetscInt               *currowlens; /* Current row lens. They are temp accumulators for row lens in C, to help build dstrowoffset */

566:     PetscCall(PetscCalloc1(Cm, &currowlens));           /* Init with zero, to be added to */
567:     for (i = 0; i < nrows; i++) {                       /* for each row I receive */
568:       r                    = rows[i];                   /* row id in C */
569:       dstrowoffset_hptr[i] = Ci_ptr[r] + currowlens[r]; /* dst offset of the new place for each recv'ed row in Ca/Cj */
570:       currowlens[r] += rrowlens[i];                     /* accumulate to length of row r in C */
571:     }
572:     PetscCall(PetscFree(currowlens));

574:     rrowlens--;
575:     for (i = 0; i < nrows; i++) rrowlens[i + 1] += rrowlens[i]; /* Change rrowlens[] to CSR format */
576:     dstrowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), dstrowoffset_h);
577:     srcrowoffset = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), rrowlens_h); /* src offset of each recv'ed row in abuf/jbuf */

579:     /* Build the reduceSF, which performs buffer to buffer send/recv */
580:     PetscInt *sdisp, *rdisp; /* buffer to send offsets of roots, and buffer to recv them */
581:     PetscCall(PetscMalloc2(niranks, &sdisp, nranks, &rdisp));
582:     for (i = 0; i < niranks; i++) sdisp[i] = rrowlens[ioffset[i]];
583:     for (i = 0; i < nranks; i++) PetscCallMPI(MPI_Irecv(&rdisp[i], 1, MPIU_INT, ranks[i], tag, comm, &reqs[i]));
584:     for (i = 0; i < niranks; i++) PetscCallMPI(MPI_Isend(&sdisp[i], 1, MPIU_INT, iranks[i], tag, comm, &reqs[nranks + i]));
585:     PetscCallMPI(MPI_Waitall(niranks + nranks, reqs, MPI_STATUSES_IGNORE));

587:     /* Nonzeros in abuf/jbuf are roots and those in A are leaves */
588:     PetscInt     nroots2 = Cnnz, nleaves2 = Annz;
589:     PetscSFNode *iremote;
590:     PetscCall(PetscMalloc1(nleaves2, &iremote)); /* no free, since memory will be given to reduceSF */
591:     for (i = 0; i < nranks; i++) {
592:       PetscInt rootbase = rdisp[i];                      /* root offset at this root rank */
593:       PetscInt leafbase = Ai[roffset[i]];                /* leaf base */
594:       PetscInt nz       = Ai[roffset[i + 1]] - leafbase; /* I will send nz nonzeros to this root rank */
595:       for (PetscInt k = 0; k < nz; k++) {
596:         iremote[leafbase + k].rank  = ranks[i];
597:         iremote[leafbase + k].index = rootbase + k;
598:       }
599:     }
600:     PetscCall(PetscSFCreate(comm, &reduceSF));
601:     PetscCall(PetscSFSetGraph(reduceSF, nroots2, nleaves2, NULL, PETSC_OWN_POINTER, iremote, PETSC_OWN_POINTER));
602:     PetscCall(PetscFree2(sdisp, rdisp));

604:     /* Reduce Aa, Ajg to abuf and jbuf */

606:     /* If A uses local col ids, convert them to global ones before sending */
607:     MatColIdxKokkosView Ajg;
608:     if (local) {
609:       Ajg                           = MatColIdxKokkosView("j", Annz);
610:       const MatColIdxKokkosView &Aj = akok->j_dual.view_device();
611:       Kokkos::parallel_for(
612:         Annz, KOKKOS_LAMBDA(const PetscInt i) { Ajg(i) = l2g(Aj(i)); });
613:     } else {
614:       Ajg = akok->j_dual.view_device(); /* no data copy, just take a reference */
615:     }

617:     MatColIdxKokkosView jbuf("jbuf", Cnnz);
618:     abuf = MatScalarKokkosView("abuf", Cnnz);
619:     PetscCall(PetscSFReduceBegin(reduceSF, MPIU_INT, Ajg.data(), jbuf.data(), MPI_REPLACE));
620:     PetscCall(PetscSFReduceEnd(reduceSF, MPIU_INT, Ajg.data(), jbuf.data(), MPI_REPLACE));
621:     PetscCall(PetscSFReduceBegin(reduceSF, MPIU_SCALAR, akok->a_device_data(), abuf.data(), MPI_REPLACE));
622:     PetscCall(PetscSFReduceEnd(reduceSF, MPIU_SCALAR, akok->a_device_data(), abuf.data(), MPI_REPLACE));

624:     /* Copy data from abuf, jbuf to Ca, Cj */
625:     MatRowMapKokkosView Ci = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), Ci_h); /* Ci is an alias of Ci_h if no device */
626:     MatColIdxKokkosView Cj("j", Cnnz);
627:     MatScalarKokkosView Ca("a", Cnnz);

629:     Kokkos::parallel_for(
630:       Kokkos::TeamPolicy<>(nrows, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
631:         PetscInt i   = t.league_rank();
632:         PetscInt src = srcrowoffset(i), dst = dstrowoffset(i);
633:         PetscInt len = srcrowoffset(i + 1) - srcrowoffset(i);
634:         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, len), [&](PetscInt k) {
635:           Ca(dst + k) = abuf(src + k);
636:           Cj(dst + k) = jbuf(src + k);
637:         });
638:       });

640:     /* Build C with Ca, Ci, Cj */
641:     C = KokkosCsrMatrix("csrmat", Cm, N, Cnnz, Ca, Ci, Cj);
642:     PetscCall(PetscFree2(srowlens, reqs));
643:   } else SETERRQ(PETSC_COMM_SELF, PETSC_ERR_PLIB, "Unsupported MatReuse enum %d", reuse);
644:   PetscFunctionReturn(PETSC_SUCCESS);
645: }
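/* Illustration of MatSeqAIJKokkosReduce (hypothetical 2-rank example): suppose rows 0 and 1 of A on rank 1
   both have root (0,5) in ownerSF, i.e., they belong to row 5 of C on rank 0. Their nonzeros are then sent
   to rank 0 and concatenated into row 5 of C, possibly with repeated col ids, which later additions will
   sum up. C has nroots (of ownerSF) rows on each rank.
*/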

647: /* MatSetMPIAIJKokkosWithGlobalCSRMatrix - Set the diag and offdiag parts of a `MATMPIAIJKOKKOS` matrix by splitting a KokkosCsrMatrix

649:   Input Parameters:
650: +  C        - the `MATMPIAIJKOKKOS` matrix, of size m,n,M,N
651: .  reuse    - indicate whether the matrix has called this function before
652: .  csrmat   - the KokkosCsrMatrix, of size m,N
653: -  Cdstart  - when reuse == `MAT_REUSE_MATRIX`, it is an input parameter. For each row in csrmat, it stores the position of the first
654:               entry of the diag block of C in csrmat's j array. E.g., if row i has col ids = {0, 3, 4, 5, 7, 9} and the first diag
655:               entry is 5, then Cdstart[i] = 3.

657:   Output Parameters:
658: +  C        - the updated `MATMPIAIJKOKKOS` matrix
659: -  Cdstart - when reuse == `MAT_INITIAL_MATRIX`, it is an output parameter

661:   Note:
662:    Between calls with `MAT_INITIAL_MATRIX` or `MAT_REUSE_MATRIX`, csrmat must have the same nonzero pattern

664: .seealso: [](chapter_matrices), `Mat`, `MATMPIAIJKOKKOS`
665:  */
666: static PetscErrorCode MatSetMPIAIJKokkosWithGlobalCSRMatrix(Mat C, MatReuse reuse, const KokkosCsrMatrix &csrmat, MatRowMapKokkosView &Cdstart)
667: {
668:   const MatScalarKokkosView      &Ca = csrmat.values;
669:   const ConstMatRowMapKokkosView &Ci = csrmat.graph.row_map;
670:   PetscInt                        m, n, N;

672:   PetscFunctionBegin;
673:   PetscCall(MatGetLocalSize(C, &m, &n));
674:   PetscCall(MatGetSize(C, NULL, &N));

676:   if (reuse == MAT_REUSE_MATRIX) {
677:     Mat_MPIAIJ                *mpiaij = static_cast<Mat_MPIAIJ *>(C->data);
678:     Mat_SeqAIJKokkos          *akok   = static_cast<Mat_SeqAIJKokkos *>(mpiaij->A->spptr);
679:     Mat_SeqAIJKokkos          *bkok   = static_cast<Mat_SeqAIJKokkos *>(mpiaij->B->spptr);
680:     const MatScalarKokkosView &Cda = akok->a_dual.view_device(), Coa = bkok->a_dual.view_device();
681:     const MatRowMapKokkosView &Cdi = akok->i_dual.view_device(), Coi = bkok->i_dual.view_device();

683:     /* Fill 'a' of Cd and Co on device */
684:     Kokkos::parallel_for(
685:       Kokkos::TeamPolicy<>(m, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
686:         PetscInt i       = t.league_rank();     /* row i */
687:         PetscInt clen    = Ci(i + 1) - Ci(i);   /* len of row i of C */
688:         PetscInt cdlen   = Cdi(i + 1) - Cdi(i); /* len of row i of Cd */
689:         PetscInt cdstart = Cdstart(i);          /* [start, end) of row i of Cd in C */
690:         PetscInt cdend   = cdstart + cdlen;
691:         /* [0, clen) is cut into three blocks: [0, cdstart), [cdstart, cdend), [cdend, clen) */
692:         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, clen), [&](PetscInt k) {
693:           if (k < cdstart) { /* k in [0, cdstart) */
694:             Coa(Coi(i) + k) = Ca(Ci(i) + k);
695:           } else if (k < cdend) { /* k in [cdstart, cdend) */
696:             Cda(Cdi(i) + (k - cdstart)) = Ca(Ci(i) + k);
697:           } else { /* k in [cdend, clen) */
698:             Coa(Coi(i) + k - cdlen) = Ca(Ci(i) + k);
699:           }
700:         });
701:       });

703:     akok->a_dual.modify_device();
704:     bkok->a_dual.modify_device();
705:   } else if (reuse == MAT_INITIAL_MATRIX) {
706:     Mat                        Cd, Co;
707:     const MatColIdxKokkosView &Cj = csrmat.graph.entries;
708:     MatRowMapKokkosDualView    Cdi_dual("i", m + 1), Coi_dual("i", m + 1);
709:     MatRowMapKokkosView        Cdi = Cdi_dual.view_device(), Coi = Coi_dual.view_device();
710:     PetscInt                   cstart, cend;

712:     /* Note that each row of C is sorted by col ids. We want to find out how to cut each row into three blocks:
713:        left of the diag block, the diag block, and right of the diag block. The diag block has col ids in [cstart,cend).
714:        Suppose a row of C has len nonzeros, indexed by [0, len). We want to know two indices: cdstart and cdend,
715:        such that the three blocks are [0,cdstart), [cdstart,cdend), [cdend,len). The following code equivalently
716:        stores values of cdstart and cdend-cdstart (aka Cdi[]) instead.
717:      */
718:     Cdstart = MatRowMapKokkosView("Cdstart", m);
719:     PetscCall(PetscLayoutGetRange(C->cmap, &cstart, &cend)); /* Not MatGetOwnershipRangeColumn() since C has not been preallocated yet */

721:     /* I could use RangePolicy and one thread per row. But since each thread essentially does binary search, threads in a
722:       CUDA warp would completely diverge. So I use TeamPolicy with a team size 1.
723:      */
724:     Kokkos::parallel_for(
725:       Kokkos::TeamPolicy<>(m, 1), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
726:         Kokkos::single(Kokkos::PerTeam(t), [=]() {                               /* Only one thread works in a team */
727:                                                    PetscInt i = t.league_rank(); /* row i */
728:                                                    PetscInt j, first, count, step;

730:                                                    if (i == 0) { /* Set the first entry of the i arrays to zero on device, to be used in CSR */
731:                                                      Cdi(0) = 0;
732:                                                      Coi(0) = 0;
733:                                                    }

735:                                                    /* Do std::lower_bound(Ci(i),Ci(i+1),cstart) on Cj[]. We use j as the iterator. lower_bound() returns
736:           in 'first' the first iterator with a value >= cstart, or last iterator if no such element is found.
737:         */
738:                                                    count = Ci(i + 1) - Ci(i);
739:                                                    first = Ci(i);
740:                                                    while (count > 0) {
741:                                                      j    = first;
742:                                                      step = count / 2;
743:                                                      j += step;
744:                                                      if (Cj(j) < cstart) {
745:                                                        first = ++j;
746:                                                        count -= step + 1;
747:                                                      } else count = step;
748:                                                    }
749:                                                    Cdstart(i) = first - Ci(i); /* 'first' is the while-loop's output */

751:                                                    /* Do std::lower_bound(first,Ci(i+1),cend) on Cj[] */
752:                                                    count = Ci(i + 1) - first;
753:                                                    while (count > 0) {
754:                                                      j    = first;
755:                                                      step = count / 2;
756:                                                      j += step;
757:                                                      if (Cj(j) < cend) {
758:                                                        first = ++j;
759:                                                        count -= step + 1;
760:                                                      } else count = step;
761:                                                    }
762:                                                    Cdi(i + 1) = first - (Ci(i) + Cdstart(i));     /* 'first' is the while-loop's output */
763:                                                    Coi(i + 1) = (Ci(i + 1) - Ci(i)) - Cdi(i + 1); /* Co's row len = C's row len - Cd's row len */
764:         });
765:       });

767:     /* Convert row lens in Cdi[], Coi[] to CSR format using inclusive scan, e.g., changing [0,1,2,3] into [0,1,3,6] */
768:     Kokkos::parallel_scan(
769:       m + 1, KOKKOS_LAMBDA(const PetscInt i, PetscInt &update, const bool final) {
770:         update += Cdi(i);
771:         if (final) Cdi(i) = update;
772:       });
773:     Kokkos::parallel_scan(
774:       m + 1, KOKKOS_LAMBDA(const PetscInt i, PetscInt &update, const bool final) {
775:         update += Coi(i);
776:         if (final) Coi(i) = update;
777:       });

779:     /* Get Cdi, Coi on host (it is not a waste, since we do need them on host in
780:        MatCreateSeqAIJKokkosWithCSRMatrix() below), then get nnz of Cd and Co.
781:     */
782:     Cdi_dual.modify_device();
783:     Coi_dual.modify_device();
784:     Cdi_dual.sync_host();
785:     Coi_dual.sync_host();
786:     PetscInt Cd_nnz = Cdi_dual.view_host().data()[m];
787:     PetscInt Co_nnz = Coi_dual.view_host().data()[m];

789:     /* With nnz, allocate a, j for Cd and Co */
790:     MatColIdxKokkosDualView Cdj_dual("j", Cd_nnz), Coj_dual("j", Co_nnz);
791:     MatScalarKokkosDualView Cda_dual("a", Cd_nnz), Coa_dual("a", Co_nnz);

793:     /* Fill a, j of Cd and Co on device */
794:     MatColIdxKokkosView Cdj = Cdj_dual.view_device(), Coj = Coj_dual.view_device();
795:     MatScalarKokkosView Cda = Cda_dual.view_device(), Coa = Coa_dual.view_device();

797:     Kokkos::parallel_for(
798:       Kokkos::TeamPolicy<>(m, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
799:         PetscInt i       = t.league_rank();     /* row i */
800:         PetscInt clen    = Ci(i + 1) - Ci(i);   /* len of row i of C */
801:         PetscInt cdlen   = Cdi(i + 1) - Cdi(i); /* len of row i of Cd */
802:         PetscInt cdstart = Cdstart(i);          /* [start, end) of row i of Cd in C */
803:         PetscInt cdend   = cdstart + cdlen;
804:         /* [0, clen) is cut into three blocks: [0, cdstart), [cdstart, cdend), [cdend, clen) */
805:         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, clen), [&](PetscInt k) {
806:           if (k < cdstart) { /* k in [0, cdstart) */
807:             Coa(Coi(i) + k) = Ca(Ci(i) + k);
808:             Coj(Coi(i) + k) = Cj(Ci(i) + k);
809:           } else if (k < cdend) { /* k in [cdstart, cdend) */
810:             Cda(Cdi(i) + (k - cdstart)) = Ca(Ci(i) + k);
811:             Cdj(Cdi(i) + (k - cdstart)) = Cj(Ci(i) + k) - cstart; /* Use local col ids in Cdj */
812:           } else {                                                /* k in [cdend, clen) */
813:             Coa(Coi(i) + k - cdlen) = Ca(Ci(i) + k);
814:             Coj(Coi(i) + k - cdlen) = Cj(Ci(i) + k);
815:           }
816:         });
817:       });

819:     Cdj_dual.modify_device();
820:     Cda_dual.modify_device();
821:     Coj_dual.modify_device();
822:     Coa_dual.modify_device();
823:     /* With a, i, j for Cd and Co, finally build Cd, Co and then C. Their offloadmask will be set in their respective MatAssemblyEnd */
824:     auto cdkok = new Mat_SeqAIJKokkos(m, n, Cd_nnz, Cdi_dual, Cdj_dual, Cda_dual);
825:     auto cokok = new Mat_SeqAIJKokkos(m, N, Co_nnz, Coi_dual, Coj_dual, Coa_dual);
826:     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF, cdkok, &Cd));
827:     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF, cokok, &Co));
828:     PetscCall(MatSetMPIAIJKokkosWithSplitSeqAIJKokkosMatrices(C, Cd, Co)); /* Coj will be converted to local ids within */
829:   }
830:   PetscFunctionReturn(PETSC_SUCCESS);
831: }
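/* Worked example of the three-block split above (hypothetical column ownership): let row i of csrmat have
   col ids {0, 3, 4, 5, 7, 9} and let this rank own columns [5, 8). The two binary searches give
   Cdstart(i) = 3, cutting the row into the left offdiag block {0, 3, 4}, the diag block {5, 7} (stored in
   Cd with local col ids {0, 2}), and the right offdiag block {9}. Hence Cd's row i has length 2 and Co's
   row i has length 4.
*/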

833: /* MatSeqAIJCompactOutExtraColumns_SeqAIJKokkos - Compact a SEQAIJKOKKOS matrix's global col ids.

835:   It is similar to MatSeqAIJCompactOutExtraColumns_SeqAIJ, but it applies to SEQAIJKOKKOS and returns the l2g map in a Kokkos view.

837:   Input Parameter:
838: .  C   - the SEQAIJKOKKOS matrix, using global col ids

840:   Output Parameters:
841: +  C   - the updated matrix, with col ids compacted to local ones (i.e., indices into l2g) and col size reduced accordingly
842: -  l2g - a Kokkos view mapping C's new local col ids to the original global col ids

844:   Note:
845:   the input matrix's col ids and col size will be changed.
846: */
851: static PetscErrorCode MatSeqAIJCompactOutExtraColumns_SeqAIJKokkos(Mat C, MatColIdxKokkosView &l2g)
852: {
853:   Mat_SeqAIJKokkos      *ckok;
854:   ISLocalToGlobalMapping l2gmap;
855:   const PetscInt        *garray;
856:   PetscInt               sz;

858:   PetscFunctionBegin;
859:   /* Compact C's global col ids and col size. We do this since we guess KK might be more memory scalable with local col ids */
860:   PetscCall(MatSeqAIJCompactOutExtraColumns_SeqAIJ(C, &l2gmap));
861:   ckok = static_cast<Mat_SeqAIJKokkos *>(C->spptr);
862:   ckok->j_dual.modify_host(); /* C's j was modified on host; we need to sync it to device */
863:   ckok->j_dual.sync_device();
864:   ckok->SetColSize(C->cmap->n); /* Update col size of the csrmat in spptr */

866:   /* Build l2g -- the local to global mapping of C's cols */
867:   PetscCall(ISLocalToGlobalMappingGetIndices(l2gmap, &garray));
868:   PetscCall(ISLocalToGlobalMappingGetSize(l2gmap, &sz));
869:   PetscCheck(C->cmap->n == sz, PETSC_COMM_SELF, PETSC_ERR_PLIB, "matrix column size(%" PetscInt_FMT ") != l2g mapping size(%" PetscInt_FMT ")", C->cmap->n, sz);

871:   ConstMatColIdxKokkosViewHost tmp(garray, sz);
872:   l2g = MatColIdxKokkosView("l2g", sz);
873:   Kokkos::deep_copy(l2g, tmp);

875:   PetscCall(ISLocalToGlobalMappingRestoreIndices(l2gmap, &garray));
876:   PetscCall(ISLocalToGlobalMappingDestroy(&l2gmap));
877:   PetscFunctionReturn(PETSC_SUCCESS);
878: }
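/* Illustration with hypothetical numbers: if C's rows reference only global cols {3, 8, 11} out of [0, N),
   the compaction renumbers them to local cols {0, 1, 2}, shrinks C's col size to 3, and returns
   l2g = {3, 8, 11}, so a global col id can be recovered as l2g(local id).
*/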

880: #if PETSC_PKG_KOKKOS_KERNELS_VERSION_GE(3, 7, 99)
881: static PetscErrorCode MatMPIAIJGetLocalMat_MPIAIJKokkos(Mat mat, MatReuse reuse, MatMatStruct_AB *mm, Mat *C)
882: {
883:   Mat                 A, B;
884:   const PetscInt     *garray;
885:   Mat_SeqAIJ         *aseq, *bseq;
886:   Mat_SeqAIJKokkos   *akok, *bkok, *ckok;
887:   MatScalarKokkosView aa, ba, ca;
888:   MatRowMapKokkosView ai, bi, ci;
889:   MatColIdxKokkosView aj, bj, cj;
890:   PetscInt            m, nnz;

892:   PetscFunctionBegin;
893:   PetscCall(MatMPIAIJGetSeqAIJ(mat, &A, &B, &garray));
894:   PetscCheckTypeName(A, MATSEQAIJKOKKOS);
895:   PetscCheckTypeName(B, MATSEQAIJKOKKOS);
896:   PetscCheck(reuse != MAT_INPLACE_MATRIX, PETSC_COMM_SELF, PETSC_ERR_SUP, "MAT_INPLACE_MATRIX not supported");
897:   PetscCall(MatSeqAIJKokkosSyncDevice(A));
898:   PetscCall(MatSeqAIJKokkosSyncDevice(B));
899:   aseq = static_cast<Mat_SeqAIJ *>(A->data);
900:   bseq = static_cast<Mat_SeqAIJ *>(B->data);
901:   akok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
902:   bkok = static_cast<Mat_SeqAIJKokkos *>(B->spptr);
903:   aa   = akok->a_dual.view_device();
904:   ai   = akok->i_dual.view_device();
905:   ba   = bkok->a_dual.view_device();
906:   bi   = bkok->i_dual.view_device();
907:   m    = A->rmap->n; /* M and nnz of C */
908:   nnz  = aseq->nz + bseq->nz;
909:   if (reuse == MAT_INITIAL_MATRIX) {
910:     aj           = akok->j_dual.view_device();
911:     bj           = bkok->j_dual.view_device();
912:     auto ca_dual = MatScalarKokkosDualView("a", nnz);
913:     auto ci_dual = MatRowMapKokkosDualView("i", m + 1);
914:     auto cj_dual = MatColIdxKokkosDualView("j", nnz);
915:     ca           = ca_dual.view_device();
916:     ci           = ci_dual.view_device();
917:     cj           = cj_dual.view_device();

919:     // For each row of B, find number of nonzeros on the left of the diagonal block (i.e., A).
920:     // The result is stored in mm->B_NzDiagLeft for reuse in the numeric phase
921:     MatColIdxKokkosViewHost NzLeft("NzLeft", m);
922:     const MatRowMapType    *rowptr = bkok->i_host_data();
923:     const MatColIdxType    *colidx = bkok->j_host_data();
924:     MatColIdxType          *nzleft = NzLeft.data();
925:     const MatColIdxType     cstart = mat->cmap->rstart; // start of global column indices of A; used to split B
926:     for (PetscInt i = 0; i < m; i++) {
927:       const MatColIdxType *first, *last, *it;
928:       PetscInt             count, step;

930:       // Basically, std::lower_bound(first,last,cstart), but need to map columns from local to global with garray[]
931:       first = colidx + rowptr[i];
932:       last  = colidx + rowptr[i + 1];
933:       count = last - first;
934:       while (count > 0) {
935:         it   = first;
936:         step = count / 2;
937:         it += step;
938:         if (garray[*it] < cstart) {
939:           first = ++it;
940:           count -= step + 1;
941:         } else count = step;
942:       }
943:       nzleft[i] = first - (colidx + rowptr[i]);
944:     }
945:     auto B_NzDiagLeft = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), NzLeft); // copy to device

947:     auto tmp = MatColIdxKokkosViewHost(const_cast<MatColIdxType *>(garray), B->cmap->n);
948:     auto l2g = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), tmp); // copy garray to device

950:     // Shuffle A and B in parallel using Kokkos hierarchical parallelism
951:     Kokkos::parallel_for(
952:       Kokkos::TeamPolicy<>(m, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
953:         PetscInt i    = t.league_rank(); /* row i */
954:         PetscInt disp = ai(i) + bi(i), alen = ai(i + 1) - ai(i), blen = bi(i + 1) - bi(i);
955:         PetscInt nzleft = B_NzDiagLeft(i);

957:         Kokkos::single(Kokkos::PerTeam(t), [=]() {
958:           ci(i) = disp;
959:           if (i == m - 1) ci(m) = ai(m) + bi(m);
960:         });

962:         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, alen + blen), [&](PetscInt k) {
963:           if (k < nzleft) { // portion of B that is on left of A
964:             ca(disp + k) = ba(bi(i) + k);
965:             cj(disp + k) = l2g(bj(bi(i) + k));
966:           } else if (k < nzleft + alen) { // diag A
967:             ca(disp + k) = aa(ai(i) + k - nzleft);
968:             cj(disp + k) = aj(ai(i) + k - nzleft) + cstart; // add the shift to convert local to global.
969:           } else {                                          // portion of B that is on right of A
970:             ca(disp + k) = ba(bi(i) + k - alen);
971:             cj(disp + k) = l2g(bj(bi(i) + k - alen));
972:           }
973:         });
974:       });
975:     ca_dual.modify_device();
976:     ci_dual.modify_device();
977:     cj_dual.modify_device();
978:     PetscCallCXX(ckok = new Mat_SeqAIJKokkos(m, mat->cmap->N, nnz, ci_dual, cj_dual, ca_dual));
979:     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF, ckok, C));
980:     mm->B_NzDiagLeft = B_NzDiagLeft;
981:   } else if (reuse == MAT_REUSE_MATRIX) {
983:     PetscCheckTypeName(*C, MATSEQAIJKOKKOS);
984:     ckok               = static_cast<Mat_SeqAIJKokkos *>((*C)->spptr);
985:     ca                 = ckok->a_dual.view_device();
986:     auto &B_NzDiagLeft = mm->B_NzDiagLeft;

988:     Kokkos::parallel_for(
989:       Kokkos::TeamPolicy<>(m, Kokkos::AUTO()), KOKKOS_LAMBDA(const KokkosTeamMemberType &t) {
990:         PetscInt i    = t.league_rank(); // row i
991:         PetscInt disp = ai(i) + bi(i), alen = ai(i + 1) - ai(i), blen = bi(i + 1) - bi(i);
992:         PetscInt nzleft = B_NzDiagLeft(i);

994:         Kokkos::parallel_for(Kokkos::TeamThreadRange(t, alen + blen), [&](PetscInt k) {
995:           if (k < nzleft) { // portion of B that is on left of A
996:             ca(disp + k) = ba(bi(i) + k);
997:           } else if (k < nzleft + alen) { // diag A
998:             ca(disp + k) = aa(ai(i) + k - nzleft);
999:           } else { // portion of B that is on right of A
1000:             ca(disp + k) = ba(bi(i) + k - alen);
1001:           }
1002:         });
1003:       });

1005:     PetscCall(MatSeqAIJKokkosModifyDevice(*C));
1006:   }
1007:   PetscFunctionReturn(PETSC_SUCCESS);
1008: }
1009: #endif
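/* Illustration of the row interleaving in MatMPIAIJGetLocalMat_MPIAIJKokkos (hypothetical numbers): on a
   rank owning global columns [4, 8), let row i of the diag block A have local col ids {0, 2} (global
   {4, 6}) and let row i of the offdiag block B have col ids whose garray[] values are {1, 9}. Then
   nzleft = 1, and row i of the merged local matrix gets global col ids {1, 4, 6, 9}: B's left part, then
   A shifted by cstart, then B's right part.
*/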

1011: /* MatProductSymbolic_MPIAIJKokkos_AB - AB flavor of MatProductSymbolic_MPIAIJKokkos

1013:   Input Parameters:
1014: +  product  - Mat_Product which carried out the computation. Passed in to access info about this mat product.
1015: .  A        - an MPIAIJKOKKOS matrix
1016: .  B        - an MPIAIJKOKKOS matrix
1017: -  mm       - a struct used to stash intermediate data when computing AB. Persist from symbolic to numeric operations.

1019:   Note: The local part of the result C is stored as mm->C_global, which is of type KokkosCsrMatrix and uses global col ids.
1020: */
1021: static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AB(Mat_Product *product, Mat A, Mat B, MatMatStruct_AB *mm)
1022: {
1023:   Mat_MPIAIJ              *a  = static_cast<Mat_MPIAIJ *>(A->data);
1024:   Mat                      Ad = a->A, Ao = a->B; /* diag and offdiag of A */
1025:   IS                       glob = NULL;
1026:   const PetscInt          *garray;
1027:   PetscInt                 N = B->cmap->N, sz;
1028:   ConstMatColIdxKokkosView l2g1; /* two temp maps mapping local col ids to global ones */
1029:   MatColIdxKokkosView      l2g2;
1030:   Mat                      C1, C2; /* intermediate matrices */

1032:   PetscFunctionBegin;
1033: #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(3, 7, 99)
1034:   /* C1 = Ad * B_local. B_local is a matrix obtained by merging Bd and Bo, and uses local col ids */
1035:   PetscCall(MatMPIAIJGetLocalMatMerge(B, MAT_INITIAL_MATRIX, &glob, &mm->B_local));
1036: #else
1037:   PetscCall(MatMPIAIJGetLocalMat_MPIAIJKokkos(B, MAT_INITIAL_MATRIX, mm, &mm->B_local));
1038:   PetscCall(ISCreateStride(MPI_COMM_SELF, N, 0, 1, &glob));
1039: #endif

1041:   PetscCall(MatProductCreate(Ad, mm->B_local, NULL, &C1));
1042:   PetscCall(MatProductSetType(C1, MATPRODUCT_AB));
1043:   PetscCall(MatProductSetFill(C1, product->fill));
1044:   C1->product->api_user = product->api_user;
1045:   PetscCall(MatProductSetFromOptions(C1));
1046:   PetscUseTypeMethod(C1, productsymbolic);

1048:   PetscCall(ISGetIndices(glob, &garray));
1049:   PetscCall(ISGetSize(glob, &sz));
1050:   const auto &tmp = ConstMatColIdxKokkosViewHost(garray, sz);                       /* wrap garray as a view */
1051:   l2g1            = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), tmp); /* may be just an alias of tmp, so we restore garray only at the very end */
1052:   PetscCall(MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds(C1, N, l2g1, mm->C1_global));

1054:   /* C2 = Ao * B_other. B_other is a matrix consisting of needed rows of B gathered from other procs */
1055:   PetscCall(MatSeqAIJKokkosBcast(mm->B_local, MAT_INITIAL_MATRIX, N, l2g1, a->Mvctx, mm->sf, mm->abuf, mm->rows, mm->rowoffset, mm->B_other));

1057:   /* Compact B_other to use local ids, since KokkosKernels spgemm is presumably more memory-scalable that way; the compaction could be skipped to simplify the code */
1058:   PetscCall(MatSeqAIJCompactOutExtraColumns_SeqAIJKokkos(mm->B_other, l2g2));
1059:   PetscCall(MatProductCreate(Ao, mm->B_other, NULL, &C2));
1060:   PetscCall(MatProductSetType(C2, MATPRODUCT_AB));
1061:   PetscCall(MatProductSetFill(C2, product->fill));
1062:   C2->product->api_user = product->api_user;
1063:   PetscCall(MatProductSetFromOptions(C2));
1064:   PetscUseTypeMethod(C2, productsymbolic);
1065:   PetscCall(MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds(C2, N, l2g2, mm->C2_global));

1067:   /* C = C1 + C2.  We actually use their global col ids versions in adding */
1068:   mm->kh.create_spadd_handle(false); /* Input C1, C2 are NOT sorted, since B_local, B_other are not */
1069:   KokkosSparse::spadd_symbolic(&mm->kh, mm->C1_global, mm->C2_global, mm->C_global);
1070:   /* Have to do numeric since spadd_symbolic does not really populate column indices of the result matrix */
1071:   KokkosSparse::spadd_numeric(&mm->kh, (MatScalarType)1.0, mm->C1_global, (MatScalarType)1.0, mm->C2_global, mm->C_global);

1073:   mm->C1 = C1;
1074:   mm->C2 = C2;
1075:   PetscCall(ISRestoreIndices(glob, &garray));
1076:   PetscCall(ISDestroy(&glob));
1077:   PetscFunctionReturn(PETSC_SUCCESS);
1078: }

1080: /* MatProductSymbolic_MPIAIJKokkos_AtB - A^tB flavor of MatProductSymbolic_MPIAIJKokkos

1082:   Input Parameters:
1083: +  product  - the Mat_Product that carries out the computation; passed in to access info about this mat product.
1084: .  A        - an MPIAIJKOKKOS matrix
1085: .  B        - a SEQAIJKOKKOS matrix. It works as if A^t is multiplied by a parallel matrix made up of Bs on each rank.
1086: .  localB   - Does B use local col ids? If false, then B is already in global col ids.
1087: .  N        - col size of the "parallel B matrix". It implies B's global col ids are in the range [0,N), with N the same across the communicator.
1088: .  l2g      - If localB, then l2g maps B's local col ids to global ones.
1089: -  mm       - a struct used to stash intermediate data in AtB

1091:   Note: The local part of the result C is stored as mm->C_global, which is of type KokkosCsrMatrix and uses global col ids.
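
  Sketch of the per-rank decomposition performed below:
    C1 = Ad^t * B               contributes to rows of C owned by this rank
    C2 = Ao^t * B               contributes to rows of C owned by other ranks, hence the MatSeqAIJKokkosReduce()
    C  = C1_global + C2_global  added with KokkosSparse::spadd on the global col ids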
1092: */
1093: static PetscErrorCode MatProductSymbolic_MPIAIJKokkos_AtB(Mat_Product *product, Mat A, Mat B, PetscBool localB, PetscInt N, const ConstMatColIdxKokkosView &l2g, MatMatStruct_AtB *mm)
1094: {
1095:   Mat_MPIAIJ *a  = static_cast<Mat_MPIAIJ *>(A->data);
1096:   Mat         Ad = a->A, Ao = a->B; /* diag and offdiag of A */
1097:   Mat         C1, C2;               /* intermediate matrices */

1099:   PetscFunctionBegin;
1100:   /* C1 = Ad^t * B */
1101:   PetscCall(MatProductCreate(Ad, B, NULL, &C1));
1102:   PetscCall(MatProductSetType(C1, MATPRODUCT_AtB));
1103:   PetscCall(MatProductSetFill(C1, product->fill));
1104:   C1->product->api_user = product->api_user;
1105:   PetscCall(MatProductSetFromOptions(C1));
1106:   PetscUseTypeMethod(C1, productsymbolic);

1108:   if (localB) PetscCall(MatSeqAIJKokkosGetCSRMatrixWithGlobalColumnIds(C1, N, l2g, mm->C1_global));
1109:   else mm->C1_global = static_cast<Mat_SeqAIJKokkos *>(C1->spptr)->csrmat; /* the csrmat already uses global col ids */

1111:   /* C2 = Ao^t * B */
1112:   PetscCall(MatProductCreate(Ao, B, NULL, &C2));
1113:   PetscCall(MatProductSetType(C2, MATPRODUCT_AtB));
1114:   PetscCall(MatProductSetFill(C2, product->fill));
1115:   C2->product->api_user = product->api_user;
1116:   PetscCall(MatProductSetFromOptions(C2));
1117:   PetscUseTypeMethod(C2, productsymbolic);

1119:   PetscCall(MatSeqAIJKokkosReduce(C2, MAT_INITIAL_MATRIX, localB, N, l2g, a->Mvctx, mm->sf, mm->abuf, mm->srcrowoffset, mm->dstrowoffset, mm->C2_global));

1121:   mm->kh.create_spadd_handle(false); /* Input C1, C2 are NOT sorted, since B may be not */
1122:   KokkosSparse::spadd_symbolic(&mm->kh, mm->C1_global, mm->C2_global, mm->C_global);
1123:   /* Have to do numeric since spadd_symbolic does not really populate column indices of the result matrix */
1124:   KokkosSparse::spadd_numeric(&mm->kh, (MatScalarType)1.0, mm->C1_global, (MatScalarType)1.0, mm->C2_global, mm->C_global);
1125:   mm->C1 = C1;
1126:   mm->C2 = C2;
1127:   PetscFunctionReturn(PETSC_SUCCESS);
1128: }

1130: PetscErrorCode MatProductNumeric_MPIAIJKokkos(Mat C)
1131: {
1132:   Mat_Product                 *product = C->product;
1133:   MatProductType               ptype;
1134:   MatProductData_MPIAIJKokkos *mmdata;
1135:   MatMatStruct                *mm = NULL;
1136:   MatMatStruct_AB             *ab;
1137:   MatMatStruct_AtB            *atb;
1138:   Mat                          A, B, Ad, Ao, Bd, Bo;
1139:   const MatScalarType          one = 1.0; /* Do not use the literal 1.0 directly, to avoid a wrong template instantiation in KokkosSparse::spadd_numeric */

1141:   PetscFunctionBegin;
1142:   MatCheckProduct(C, 1);
1143:   mmdata = static_cast<MatProductData_MPIAIJKokkos *>(product->data);
1144:   ptype  = product->type;
1145:   A      = product->A;
1146:   B      = product->B;
1147:   Ad     = static_cast<Mat_MPIAIJ *>(A->data)->A;
1148:   Ao     = static_cast<Mat_MPIAIJ *>(A->data)->B;
1149:   Bd     = static_cast<Mat_MPIAIJ *>(B->data)->A;
1150:   Bo     = static_cast<Mat_MPIAIJ *>(B->data)->B;

1152:   if (mmdata->reusesym) {           /* We reached here through e.g., MatMatMult(A,B,MAT_INITIAL_MATRIX,..,C), where symbolic/numeric are combined */
1153:     mmdata->reusesym = PETSC_FALSE; /* So that next time when user calls MatMatMult(E,F,MAT_REUSE_MATRIX,..,C), we still do numeric  */
1154:     ab               = mmdata->mmAB;
1155:     atb              = mmdata->mmAtB;
1156:     if (ab) {
1157:       static_cast<MatProductData_SeqAIJKokkos *>(ab->C1->product->data)->reusesym = PETSC_FALSE;
1158:       static_cast<MatProductData_SeqAIJKokkos *>(ab->C2->product->data)->reusesym = PETSC_FALSE;
1159:     }
1160:     if (atb) {
1161:       static_cast<MatProductData_SeqAIJKokkos *>(atb->C1->product->data)->reusesym = PETSC_FALSE;
1162:       static_cast<MatProductData_SeqAIJKokkos *>(atb->C2->product->data)->reusesym = PETSC_FALSE;
1163:     }
1164:     PetscFunctionReturn(PETSC_SUCCESS);
1165:   }

1167:   if (ptype == MATPRODUCT_AB) {
1168:     ab = mmdata->mmAB;
1169:     /* C1 = Ad * B_local */
1170:     PetscCheck(ab->C1->ops->productnumeric && ab->C2->ops->productnumeric, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Missing numeric op for MATPRODUCT_AB");
1171: #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(3, 7, 99)
1172:     PetscCall(MatMPIAIJGetLocalMatMerge(B, MAT_REUSE_MATRIX, NULL /* glob */, &ab->B_local));
1173: #else
1174:     PetscCall(MatMPIAIJGetLocalMat_MPIAIJKokkos(B, MAT_REUSE_MATRIX, ab, &ab->B_local));
1175: #endif

1177:     PetscCheck(ab->C1->product->B == ab->B_local, PETSC_COMM_SELF, PETSC_ERR_PLIB, "In MATPRODUCT_AB, internal mat product matrix C1->B has unexpectedly changed");
1178:     if (ab->C1->product->A != Ad) PetscCall(MatProductReplaceMats(Ad, NULL, NULL, ab->C1));
1179:     PetscCall((*ab->C1->ops->productnumeric)(ab->C1));
1180:     PetscCall(MatSeqAIJKokkosBcast(ab->B_local, MAT_REUSE_MATRIX, 0 /* N */, MatColIdxKokkosView() /*l2g*/, NULL /*ownerSF*/, ab->sf, ab->abuf, ab->rows, ab->rowoffset, ab->B_other));
1181:     /* C2 = Ao * B_other */
1182:     PetscCheck(ab->C2->product->B == ab->B_other, PETSC_COMM_SELF, PETSC_ERR_PLIB, "In MATPRODUCT_AB, internal mat product matrix C2->B has unexpectedly changed");
1183:     if (ab->C2->product->A != Ao) PetscCall(MatProductReplaceMats(Ao, NULL, NULL, ab->C2));
1184:     PetscCall((*ab->C2->ops->productnumeric)(ab->C2));
1185:     /* C = C1_global + C2_global */
1186:     KokkosSparse::spadd_numeric(&ab->kh, one, ab->C1_global, one, ab->C2_global, ab->C_global);
1187:     mm = static_cast<MatMatStruct *>(ab);
1188:   } else if (ptype == MATPRODUCT_AtB) {
1189:     atb = mmdata->mmAtB;
1190:     PetscCheck(atb->C1->ops->productnumeric && atb->C2->ops->productnumeric, PETSC_COMM_SELF, PETSC_ERR_PLIB, "Missing numeric op for MATPRODUCT_AtB");
1191:     /* C1 = Ad^t * B_local */
1192:     PetscCall(MatMPIAIJGetLocalMatMerge(B, MAT_REUSE_MATRIX, NULL /* glob */, &atb->B_local));
1193:     PetscCheck(atb->C1->product->B == atb->B_local, PETSC_COMM_SELF, PETSC_ERR_PLIB, "In MATPRODUCT_AtB, internal mat product matrix C1->B has unexpectedly changed");
1194:     if (atb->C1->product->A != Ad) PetscCall(MatProductReplaceMats(Ad, NULL, NULL, atb->C1));
1195:     PetscCall((*atb->C1->ops->productnumeric)(atb->C1));

1197:     /* C2 = Ao^t * B_local */
1198:     PetscCheck(atb->C2->product->B == atb->B_local, PETSC_COMM_SELF, PETSC_ERR_PLIB, "In MATPRODUCT_AtB, internal mat product matrix C2->B has unexpectedly changed");
1199:     if (atb->C2->product->A != Ao) PetscCall(MatProductReplaceMats(Ao, NULL, NULL, atb->C2));
1200:     PetscCall((*atb->C2->ops->productnumeric)(atb->C2));
1201:     /* Form C2_global */
1202:     PetscCall(MatSeqAIJKokkosReduce(atb->C2, MAT_REUSE_MATRIX, PETSC_TRUE, 0 /* N */, MatColIdxKokkosView() /*l2g*/, NULL /*ownerSF*/, atb->sf, atb->abuf, atb->srcrowoffset, atb->dstrowoffset, atb->C2_global));
1203:     /* C = C1_global + C2_global */
1204:     KokkosSparse::spadd_numeric(&atb->kh, one, atb->C1_global, one, atb->C2_global, atb->C_global);
1205:     mm = static_cast<MatMatStruct *>(atb);
1206:   } else if (ptype == MATPRODUCT_PtAP) { /* BtAB */
1207:     ab = mmdata->mmAB;
1208: #if PETSC_PKG_KOKKOS_KERNELS_VERSION_LT(3, 7, 99)
1209:     PetscCall(MatMPIAIJGetLocalMatMerge(B, MAT_REUSE_MATRIX, NULL /* glob */, &ab->B_local));
1210: #else
1211:     PetscCall(MatMPIAIJGetLocalMat_MPIAIJKokkos(B, MAT_REUSE_MATRIX, ab, &ab->B_local));
1212: #endif
1213:     /* ab->C1 = Ad * B_local */
1214:     PetscCheck(ab->C1->product->B == ab->B_local, PETSC_COMM_SELF, PETSC_ERR_PLIB, "In MATPRODUCT_PtAP, internal mat product matrix ab->C1->B has unexpectedly changed");
1215:     if (ab->C1->product->A != Ad) PetscCall(MatProductReplaceMats(Ad, NULL, NULL, ab->C1));
1216:     PetscCall((*ab->C1->ops->productnumeric)(ab->C1));
1217:     PetscCall(MatSeqAIJKokkosBcast(ab->B_local, MAT_REUSE_MATRIX, 0 /* N */, MatColIdxKokkosView() /*l2g*/, NULL /*ownerSF*/, ab->sf, ab->abuf, ab->rows, ab->rowoffset, ab->B_other));
1218:     /* ab->C2 = Ao * B_other */
1219:     if (ab->C2->product->A != Ao) PetscCall(MatProductReplaceMats(Ao, NULL, NULL, ab->C2));
1220:     PetscCall((*ab->C2->ops->productnumeric)(ab->C2)); /* C2 = Ao * B_other */
1221:     KokkosSparse::spadd_numeric(&ab->kh, one, ab->C1_global, one, ab->C2_global, ab->C_global);

1223:     /* atb->C1 = Bd^t * ab->C_petsc */
1224:     atb = mmdata->mmAtB;
1225:     PetscCheck(atb->C1->product->B == ab->C_petsc, PETSC_COMM_SELF, PETSC_ERR_PLIB, "In MATPRODUCT_PtAP, internal mat product matrix atb->C1->B has unexpectedly changed");
1226:     if (atb->C1->product->A != Bd) PetscCall(MatProductReplaceMats(Bd, NULL, NULL, atb->C1));
1227:     PetscCall((*atb->C1->ops->productnumeric)(atb->C1));
1228:     /* atb->C2 = Bo^t * ab->C_petsc */
1229:     if (atb->C2->product->A != Bo) PetscCall(MatProductReplaceMats(Bo, NULL, NULL, atb->C2));
1230:     PetscCall((*atb->C2->ops->productnumeric)(atb->C2));
1231:     PetscCall(MatSeqAIJKokkosReduce(atb->C2, MAT_REUSE_MATRIX, PETSC_FALSE, 0 /* N */, MatColIdxKokkosView() /*l2g*/, NULL /* ownerSF */, atb->sf, atb->abuf, atb->srcrowoffset, atb->dstrowoffset, atb->C2_global));
1232:     KokkosSparse::spadd_numeric(&atb->kh, one, atb->C1_global, one, atb->C2_global, atb->C_global);
1233:     mm = static_cast<MatMatStruct *>(atb);
1234:   }
1235:   /* Split C_global to form C */
1236:   PetscCall(MatSetMPIAIJKokkosWithGlobalCSRMatrix(C, MAT_REUSE_MATRIX, mm->C_global, mm->Cdstart));
1237:   PetscFunctionReturn(PETSC_SUCCESS);
1238: }

1240: PetscErrorCode MatProductSymbolic_MPIAIJKokkos(Mat C)
1241: {
1242:   Mat                          A, B;
1243:   Mat_Product                 *product = C->product;
1244:   MatProductType               ptype;
1245:   MatProductData_MPIAIJKokkos *mmdata;
1246:   MatMatStruct                *mm   = NULL;
1247:   IS                           glob = NULL;
1248:   const PetscInt              *garray;
1249:   PetscInt                     m, n, M, N, sz;
1250:   ConstMatColIdxKokkosView     l2g; /* map local col ids to global ones */

1252:   PetscFunctionBegin;
1253:   MatCheckProduct(C, 1);
1254:   PetscCheck(!product->data, PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Product data not empty");
1255:   ptype = product->type;
1256:   A     = product->A;
1257:   B     = product->B;

1259:   switch (ptype) {
1260:   case MATPRODUCT_AB:
1261:     m = A->rmap->n;
1262:     n = B->cmap->n;
1263:     M = A->rmap->N;
1264:     N = B->cmap->N;
1265:     break;
1266:   case MATPRODUCT_AtB:
1267:     m = A->cmap->n;
1268:     n = B->cmap->n;
1269:     M = A->cmap->N;
1270:     N = B->cmap->N;
1271:     break;
1272:   case MATPRODUCT_PtAP:
1273:     m = B->cmap->n;
1274:     n = B->cmap->n;
1275:     M = B->cmap->N;
1276:     N = B->cmap->N;
1277:     break; /* BtAB */
1278:   default:
1279:     SETERRQ(PetscObjectComm((PetscObject)C), PETSC_ERR_PLIB, "Not for product type %s", MatProductTypes[ptype]);
1280:   }

1282:   PetscCall(MatSetSizes(C, m, n, M, N));
1283:   PetscCall(PetscLayoutSetUp(C->rmap));
1284:   PetscCall(PetscLayoutSetUp(C->cmap));
1285:   PetscCall(MatSetType(C, ((PetscObject)A)->type_name));

1287:   mmdata           = new MatProductData_MPIAIJKokkos();
1288:   mmdata->reusesym = product->api_user;

1290:   if (ptype == MATPRODUCT_AB) {
1291:     mmdata->mmAB = new MatMatStruct_AB();
1292:     PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, mmdata->mmAB));
1293:     mm = static_cast<MatMatStruct *>(mmdata->mmAB);
1294:   } else if (ptype == MATPRODUCT_AtB) {
1295:     mmdata->mmAtB = new MatMatStruct_AtB();
1296:     auto atb      = mmdata->mmAtB;
1297:     PetscCall(MatMPIAIJGetLocalMatMerge(B, MAT_INITIAL_MATRIX, &glob, &atb->B_local));
1298:     PetscCall(ISGetIndices(glob, &garray));
1299:     PetscCall(ISGetSize(glob, &sz));
1300:     l2g = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), ConstMatColIdxKokkosViewHost(garray, sz));
1301:     PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, A, atb->B_local, PETSC_TRUE, N, l2g, atb));
1302:     PetscCall(ISRestoreIndices(glob, &garray));
1303:     PetscCall(ISDestroy(&glob));
1304:     mm = static_cast<MatMatStruct *>(atb);
1305:   } else if (ptype == MATPRODUCT_PtAP) {    /* BtAB */
1306:     mmdata->mmAB  = new MatMatStruct_AB();  /* tmp=A*B */
1307:     mmdata->mmAtB = new MatMatStruct_AtB(); /* C=B^t*tmp */
1308:     auto ab       = mmdata->mmAB;
1309:     auto atb      = mmdata->mmAtB;
1310:     PetscCall(MatProductSymbolic_MPIAIJKokkos_AB(product, A, B, ab));
1311:     auto tmp = new Mat_SeqAIJKokkos(ab->C_global); /* Memory will be owned by ab->C_petsc */
1312:     PetscCall(MatCreateSeqAIJKokkosWithCSRMatrix(PETSC_COMM_SELF, tmp, &ab->C_petsc));
1313:     PetscCall(MatProductSymbolic_MPIAIJKokkos_AtB(product, B, ab->C_petsc, PETSC_FALSE, N, l2g /*not used*/, atb));
1314:     mm = static_cast<MatMatStruct *>(atb);
1315:   }
1316:   /* Split the C_global into petsc A, B format */
1317:   PetscCall(MatSetMPIAIJKokkosWithGlobalCSRMatrix(C, MAT_INITIAL_MATRIX, mm->C_global, mm->Cdstart));
1318:   C->product->data       = mmdata;
1319:   C->product->destroy    = MatProductDataDestroy_MPIAIJKokkos;
1320:   C->ops->productnumeric = MatProductNumeric_MPIAIJKokkos;
1321:   PetscFunctionReturn(PETSC_SUCCESS);
1322: }

1324: PETSC_INTERN PetscErrorCode MatProductSetFromOptions_MPIAIJKokkos(Mat mat)
1325: {
1326:   Mat_Product *product = mat->product;
1327:   PetscBool    match   = PETSC_FALSE;
1328:   PetscBool    usecpu  = PETSC_FALSE;

1330:   PetscFunctionBegin;
1331:   MatCheckProduct(mat, 1);
1332:   if (!product->A->boundtocpu && !product->B->boundtocpu) PetscCall(PetscObjectTypeCompare((PetscObject)product->B, ((PetscObject)product->A)->type_name, &match));
1333:   if (match) { /* we can always fall back to the CPU if requested */
1334:     switch (product->type) {
1335:     case MATPRODUCT_AB:
1336:       if (product->api_user) {
1337:         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatMatMult", "Mat");
1338:         PetscCall(PetscOptionsBool("-matmatmult_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
1339:         PetscOptionsEnd();
1340:       } else {
1341:         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AB", "Mat");
1342:         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatMatMult", usecpu, &usecpu, NULL));
1343:         PetscOptionsEnd();
1344:       }
1345:       break;
1346:     case MATPRODUCT_AtB:
1347:       if (product->api_user) {
1348:         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatTransposeMatMult", "Mat");
1349:         PetscCall(PetscOptionsBool("-mattransposematmult_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL));
1350:         PetscOptionsEnd();
1351:       } else {
1352:         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_AtB", "Mat");
1353:         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatTransposeMatMult", usecpu, &usecpu, NULL));
1354:         PetscOptionsEnd();
1355:       }
1356:       break;
1357:     case MATPRODUCT_PtAP:
1358:       if (product->api_user) {
1359:         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatPtAP", "Mat");
1360:         PetscCall(PetscOptionsBool("-matptap_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL));
1361:         PetscOptionsEnd();
1362:       } else {
1363:         PetscOptionsBegin(PetscObjectComm((PetscObject)mat), ((PetscObject)mat)->prefix, "MatProduct_PtAP", "Mat");
1364:         PetscCall(PetscOptionsBool("-mat_product_algorithm_backend_cpu", "Use CPU code", "MatPtAP", usecpu, &usecpu, NULL));
1365:         PetscOptionsEnd();
1366:       }
1367:       break;
1368:     default:
1369:       break;
1370:     }
1371:     match = (PetscBool)!usecpu;
1372:   }
1373:   if (match) {
1374:     switch (product->type) {
1375:     case MATPRODUCT_AB:
1376:     case MATPRODUCT_AtB:
1377:     case MATPRODUCT_PtAP:
1378:       mat->ops->productsymbolic = MatProductSymbolic_MPIAIJKokkos;
1379:       break;
1380:     default:
1381:       break;
1382:     }
1383:   }
1384:   /* fallback to MPIAIJ ops */
1385:   if (!mat->ops->productsymbolic) PetscCall(MatProductSetFromOptions_MPIAIJ(mat));
1386:   PetscFunctionReturn(PETSC_SUCCESS);
1387: }
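
/* Runtime-selection sketch (the options are the ones parsed above; the executable name is illustrative):
     ./app -mat_type aijkokkos                                    # use the Kokkos-based product kernels (default when operand types match)
     ./app -mat_type aijkokkos -matmatmult_backend_cpu            # fall back to the CPU MPIAIJ code for MatMatMult()
     ./app -mat_type aijkokkos -mat_product_algorithm_backend_cpu # same fallback, for the MatProduct interface
   When the CPU fallback is chosen, or the operand types do not match, MatProductSetFromOptions_MPIAIJ() supplies the ops. */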

1389: static PetscErrorCode MatSetPreallocationCOO_MPIAIJKokkos(Mat mat, PetscCount coo_n, PetscInt coo_i[], PetscInt coo_j[])
1390: {
1391:   Mat_MPIAIJ       *mpiaij = (Mat_MPIAIJ *)mat->data;
1392:   Mat_MPIAIJKokkos *mpikok;

1394:   PetscFunctionBegin;
1395:   PetscCall(MatSetPreallocationCOO_MPIAIJ(mat, coo_n, coo_i, coo_j)); /* mpiaij->A,B's type is set to seqaijkokkos */
1396:   mat->preallocated = PETSC_TRUE;
1397:   PetscCall(MatAssemblyBegin(mat, MAT_FINAL_ASSEMBLY));
1398:   PetscCall(MatAssemblyEnd(mat, MAT_FINAL_ASSEMBLY));
1399:   PetscCall(MatZeroEntries(mat));
1400:   mpikok = static_cast<Mat_MPIAIJKokkos *>(mpiaij->spptr);
1401:   delete mpikok;
1402:   mpiaij->spptr = new Mat_MPIAIJKokkos(mpiaij);
1403:   PetscFunctionReturn(PETSC_SUCCESS);
1404: }

1406: static PetscErrorCode MatSetValuesCOO_MPIAIJKokkos(Mat mat, const PetscScalar v[], InsertMode imode)
1407: {
1408:   Mat_MPIAIJ                 *mpiaij = static_cast<Mat_MPIAIJ *>(mat->data);
1409:   Mat_MPIAIJKokkos           *mpikok = static_cast<Mat_MPIAIJKokkos *>(mpiaij->spptr);
1410:   Mat                         A = mpiaij->A, B = mpiaij->B;
1411:   PetscCount                  Annz = mpiaij->Annz, Annz2 = mpiaij->Annz2, Bnnz = mpiaij->Bnnz, Bnnz2 = mpiaij->Bnnz2;
1412:   MatScalarKokkosView         Aa, Ba;
1413:   MatScalarKokkosView         v1;
1414:   MatScalarKokkosView        &vsend  = mpikok->sendbuf_d;
1415:   const MatScalarKokkosView  &v2     = mpikok->recvbuf_d;
1416:   const PetscCountKokkosView &Ajmap1 = mpikok->Ajmap1_d, Ajmap2 = mpikok->Ajmap2_d, Aimap2 = mpikok->Aimap2_d;
1417:   const PetscCountKokkosView &Bjmap1 = mpikok->Bjmap1_d, Bjmap2 = mpikok->Bjmap2_d, Bimap2 = mpikok->Bimap2_d;
1418:   const PetscCountKokkosView &Aperm1 = mpikok->Aperm1_d, Aperm2 = mpikok->Aperm2_d, Bperm1 = mpikok->Bperm1_d, Bperm2 = mpikok->Bperm2_d;
1419:   const PetscCountKokkosView &Cperm1 = mpikok->Cperm1_d;
1420:   PetscMemType                memtype;

1422:   PetscFunctionBegin;
1423:   PetscCall(PetscGetMemType(v, &memtype)); /* Return PETSC_MEMTYPE_HOST when v is NULL */
1424:   if (PetscMemTypeHost(memtype)) {         /* If the user gave v[] on the host, copy it to the device */
1425:     v1 = Kokkos::create_mirror_view_and_copy(DefaultMemorySpace(), MatScalarKokkosViewHost((PetscScalar *)v, mpiaij->coo_n));
1426:   } else {
1427:     v1 = MatScalarKokkosView((PetscScalar *)v, mpiaij->coo_n); /* Directly use v[]'s memory */
1428:   }

1430:   if (imode == INSERT_VALUES) {
1431:     PetscCall(MatSeqAIJGetKokkosViewWrite(A, &Aa)); /* write matrix values */
1432:     PetscCall(MatSeqAIJGetKokkosViewWrite(B, &Ba));
1433:   } else {
1434:     PetscCall(MatSeqAIJGetKokkosView(A, &Aa)); /* read & write matrix values */
1435:     PetscCall(MatSeqAIJGetKokkosView(B, &Ba));
1436:   }

1438:   /* Pack entries to be sent to remote */
1439:   Kokkos::parallel_for(
1440:     vsend.extent(0), KOKKOS_LAMBDA(const PetscCount i) { vsend(i) = v1(Cperm1(i)); });

1442:   /* Send remote entries to their owner and overlap the communication with local computation */
1443:   PetscCall(PetscSFReduceWithMemTypeBegin(mpiaij->coo_sf, MPIU_SCALAR, PETSC_MEMTYPE_KOKKOS, vsend.data(), PETSC_MEMTYPE_KOKKOS, v2.data(), MPI_REPLACE));
1444:   /* Add local entries to A and B in one kernel */
1445:   Kokkos::parallel_for(
1446:     Annz + Bnnz, KOKKOS_LAMBDA(PetscCount i) {
1447:       PetscScalar sum = 0.0;
1448:       if (i < Annz) {
1449:         for (PetscCount k = Ajmap1(i); k < Ajmap1(i + 1); k++) sum += v1(Aperm1(k));
1450:         Aa(i) = (imode == INSERT_VALUES ? 0.0 : Aa(i)) + sum;
1451:       } else {
1452:         i -= Annz;
1453:         for (PetscCount k = Bjmap1(i); k < Bjmap1(i + 1); k++) sum += v1(Bperm1(k));
1454:         Ba(i) = (imode == INSERT_VALUES ? 0.0 : Ba(i)) + sum;
1455:       }
1456:     });
1457:   PetscCall(PetscSFReduceEnd(mpiaij->coo_sf, MPIU_SCALAR, vsend.data(), v2.data(), MPI_REPLACE));

1459:   /* Add received remote entries to A and B in one kernel */
1460:   Kokkos::parallel_for(
1461:     Annz2 + Bnnz2, KOKKOS_LAMBDA(PetscCount i) {
1462:       if (i < Annz2) {
1463:         for (PetscCount k = Ajmap2(i); k < Ajmap2(i + 1); k++) Aa(Aimap2(i)) += v2(Aperm2(k));
1464:       } else {
1465:         i -= Annz2;
1466:         for (PetscCount k = Bjmap2(i); k < Bjmap2(i + 1); k++) Ba(Bimap2(i)) += v2(Bperm2(k));
1467:       }
1468:     });

1470:   if (imode == INSERT_VALUES) {
1471:     PetscCall(MatSeqAIJRestoreKokkosViewWrite(A, &Aa)); /* Increase A & B's state etc. */
1472:     PetscCall(MatSeqAIJRestoreKokkosViewWrite(B, &Ba));
1473:   } else {
1474:     PetscCall(MatSeqAIJRestoreKokkosView(A, &Aa));
1475:     PetscCall(MatSeqAIJRestoreKokkosView(B, &Ba));
1476:   }
1477:   PetscFunctionReturn(PETSC_SUCCESS);
1478: }
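
/* A minimal COO assembly sketch (user-level; indices and values are illustrative) that exercises the two
   routines above when the matrix type is MATMPIAIJKOKKOS:
     PetscInt    coo_i[] = {0, 0, 1};         // global row indices of the COO entries
     PetscInt    coo_j[] = {0, 2, 1};         // global column indices
     PetscScalar coo_v[] = {1.0, 2.0, 3.0};   // values; may also reside in device memory
     PetscCall(MatSetPreallocationCOO(A, 3, coo_i, coo_j)); // analyzes the pattern and builds the communication plan once
     PetscCall(MatSetValuesCOO(A, coo_v, ADD_VALUES));      // may be called repeatedly with new values
*/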

1480: PetscErrorCode MatDestroy_MPIAIJKokkos(Mat A)
1481: {
1482:   Mat_MPIAIJ *mpiaij = (Mat_MPIAIJ *)A->data;

1484:   PetscFunctionBegin;
1485:   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJSetPreallocation_C", NULL));
1486:   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatMPIAIJGetLocalMatMerge_C", NULL));
1487:   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetPreallocationCOO_C", NULL));
1488:   PetscCall(PetscObjectComposeFunction((PetscObject)A, "MatSetValuesCOO_C", NULL));
1489:   delete (Mat_MPIAIJKokkos *)mpiaij->spptr;
1490:   PetscCall(MatDestroy_MPIAIJ(A));
1491:   PetscFunctionReturn(PETSC_SUCCESS);
1492: }

1494: PETSC_INTERN PetscErrorCode MatConvert_MPIAIJ_MPIAIJKokkos(Mat A, MatType mtype, MatReuse reuse, Mat *newmat)
1495: {
1496:   Mat         B;
1497:   Mat_MPIAIJ *a;

1499:   PetscFunctionBegin;
1500:   if (reuse == MAT_INITIAL_MATRIX) {
1501:     PetscCall(MatDuplicate(A, MAT_COPY_VALUES, newmat));
1502:   } else if (reuse == MAT_REUSE_MATRIX) {
1503:     PetscCall(MatCopy(A, *newmat, SAME_NONZERO_PATTERN));
1504:   }
1505:   B = *newmat;

1507:   B->boundtocpu = PETSC_FALSE;
1508:   PetscCall(PetscFree(B->defaultvectype));
1509:   PetscCall(PetscStrallocpy(VECKOKKOS, &B->defaultvectype));
1510:   PetscCall(PetscObjectChangeTypeName((PetscObject)B, MATMPIAIJKOKKOS));

1512:   a = static_cast<Mat_MPIAIJ *>(A->data);
1513:   if (a->A) PetscCall(MatSetType(a->A, MATSEQAIJKOKKOS));
1514:   if (a->B) PetscCall(MatSetType(a->B, MATSEQAIJKOKKOS));
1515:   if (a->lvec) PetscCall(VecSetType(a->lvec, VECSEQKOKKOS));

1517:   B->ops->assemblyend           = MatAssemblyEnd_MPIAIJKokkos;
1518:   B->ops->mult                  = MatMult_MPIAIJKokkos;
1519:   B->ops->multadd               = MatMultAdd_MPIAIJKokkos;
1520:   B->ops->multtranspose         = MatMultTranspose_MPIAIJKokkos;
1521:   B->ops->productsetfromoptions = MatProductSetFromOptions_MPIAIJKokkos;
1522:   B->ops->destroy               = MatDestroy_MPIAIJKokkos;

1524:   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJSetPreallocation_C", MatMPIAIJSetPreallocation_MPIAIJKokkos));
1525:   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatMPIAIJGetLocalMatMerge_C", MatMPIAIJGetLocalMatMerge_MPIAIJKokkos));
1526:   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetPreallocationCOO_C", MatSetPreallocationCOO_MPIAIJKokkos));
1527:   PetscCall(PetscObjectComposeFunction((PetscObject)B, "MatSetValuesCOO_C", MatSetValuesCOO_MPIAIJKokkos));
1528:   PetscFunctionReturn(PETSC_SUCCESS);
1529: }
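
/* Usage sketch: an existing AIJ matrix can be converted so that the Kokkos operations registered above take
   effect, either in place
     PetscCall(MatConvert(A, MATMPIAIJKOKKOS, MAT_INPLACE_MATRIX, &A));
   or out of place with MAT_INITIAL_MATRIX into a new Mat. */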
1530: /*MC
1531:    MATAIJKOKKOS, MATMPIAIJKOKKOS - "aijkokkos", "mpiaijkokkos" - a matrix type to be used for CSR sparse matrices with Kokkos

1533:    A matrix type using the Kokkos-Kernels CrsMatrix for portability across different device types

1535:    Options Database Key:
1536: .  -mat_type aijkokkos - sets the matrix type to `MATAIJKOKKOS`
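
   Example Usage (a minimal sketch; `comm` and the sizes are supplied by the caller):
.vb
   PetscCall(MatCreate(comm, &A));
   PetscCall(MatSetSizes(A, m, n, M, N));
   PetscCall(MatSetType(A, MATAIJKOKKOS)); // or MatSetFromOptions(A) combined with -mat_type aijkokkos
.ve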

1538:   Level: beginner

1540: .seealso: [](chapter_matrices), `Mat`, `MatCreateAIJKokkos()`, `MATSEQAIJKOKKOS`, `MATSEQAIJ`, `MATMPIAIJ`
1541: M*/
1542: PETSC_EXTERN PetscErrorCode MatCreate_MPIAIJKokkos(Mat A)
1543: {
1544:   PetscFunctionBegin;
1545:   PetscCall(PetscKokkosInitializeCheck());
1546:   PetscCall(MatCreate_MPIAIJ(A));
1547:   PetscCall(MatConvert_MPIAIJ_MPIAIJKokkos(A, MATMPIAIJKOKKOS, MAT_INPLACE_MATRIX, &A));
1548:   PetscFunctionReturn(PETSC_SUCCESS);
1549: }

1551: /*@C
1552:    MatCreateAIJKokkos - Creates a sparse matrix in `MATAIJKOKKOS` (compressed row) format
1553:    (the default parallel PETSc format).  This matrix will ultimately be pushed down
1554:    to Kokkos for calculations. For good matrix
1555:    assembly performance the user should preallocate the matrix storage by setting
1556:    the parameters `d_nz`/`o_nz` (or the arrays `d_nnz`/`o_nnz`).

1558:    Collective

1560:    Input Parameters:
1561: +  comm - MPI communicator
1562: .  m - number of local rows (or `PETSC_DECIDE`)
1563: .  n - number of local columns (or `PETSC_DECIDE`)
1564: .  M - number of global rows (or `PETSC_DETERMINE`)
1565: .  N - number of global columns (or `PETSC_DETERMINE`)
1566: .  d_nz - number of nonzeros per row in the diagonal portion of the local submatrix (same for all local rows)
.  d_nnz - array containing the number of nonzeros in the various rows of the diagonal portion of the local submatrix
           (possibly different for each row) or `NULL`
.  o_nz - number of nonzeros per row in the off-diagonal portion of the local submatrix (same for all local rows)
-  o_nnz - array containing the number of nonzeros in the various rows of the off-diagonal portion of the local submatrix
           (possibly different for each row) or `NULL`

1568:    Output Parameter:
1569: .  A - the matrix

1571:    Level: intermediate

1573:    Notes:
1574:    It is recommended that one use the `MatCreate()`, `MatSetType()` and/or `MatSetFromOptions()`,
1575:    MatXXXXSetPreallocation() paradigm instead of this routine directly.
1576:    [MatXXXXSetPreallocation() is, for example, `MatSeqAIJSetPreallocation()`]

1578:    If `d_nnz` (`o_nnz`) is given then `d_nz` (`o_nz`) is ignored

1580:    The AIJ format (also called compressed row storage) is fully compatible with standard Fortran
1581:    storage.  That is, the stored row and column indices can begin at
1582:    either one (as in Fortran) or zero.

1584:    Specify the preallocated storage with either `d_nz`/`o_nz` or `d_nnz`/`o_nnz` (not both).
1585:    Set `d_nz` = `o_nz` = `PETSC_DEFAULT` and `d_nnz` = `o_nnz` = `NULL` for PETSc to control dynamic memory
1586:    allocation.  For large problems you MUST preallocate memory or you
1587:    will get TERRIBLE performance, see the users' manual chapter on matrices.

1589:    By default, this format uses inodes (identical nodes) when possible, to
1590:    improve numerical efficiency of matrix-vector products and solves. We
1591:    search for consecutive rows with the same nonzero structure, thereby
1592:    reusing matrix information to achieve increased efficiency.
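
   Example Usage (a minimal sketch; `m` and `n` are local sizes chosen by the caller):
.vb
   Mat A;
   PetscCall(MatCreateAIJKokkos(PETSC_COMM_WORLD, m, n, PETSC_DETERMINE, PETSC_DETERMINE, 5, NULL, 2, NULL, &A));
   // then set values and assemble as usual with MatSetValues()/MatAssemblyBegin()/MatAssemblyEnd()
.ve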

1594:    Developer Note:
1595:    Parts of this manual page were adapted from the manual page of the sequential constructor `MatCreateSeqAIJ()`

1597: .seealso: [](chapter_matrices), `Mat`, `MATAIJKOKKOS`, `MATSEQAIJKOKKOS`, `MATMPIAIJKOKKOS`, `MatCreate()`, `MatCreateAIJ()`, `MatSetValues()`,
1598:           `MatSeqAIJSetColumnIndices()`, `MatCreateSeqAIJWithArrays()`
1599: @*/
1600: PetscErrorCode MatCreateAIJKokkos(MPI_Comm comm, PetscInt m, PetscInt n, PetscInt M, PetscInt N, PetscInt d_nz, const PetscInt d_nnz[], PetscInt o_nz, const PetscInt o_nnz[], Mat *A)
1601: {
1602:   PetscMPIInt size;

1604:   PetscFunctionBegin;
1605:   PetscCall(MatCreate(comm, A));
1606:   PetscCall(MatSetSizes(*A, m, n, M, N));
1607:   PetscCallMPI(MPI_Comm_size(comm, &size));
1608:   if (size > 1) {
1609:     PetscCall(MatSetType(*A, MATMPIAIJKOKKOS));
1610:     PetscCall(MatMPIAIJSetPreallocation(*A, d_nz, d_nnz, o_nz, o_nnz));
1611:   } else {
1612:     PetscCall(MatSetType(*A, MATSEQAIJKOKKOS));
1613:     PetscCall(MatSeqAIJSetPreallocation(*A, d_nz, d_nnz));
1614:   }
1615:   PetscFunctionReturn(PETSC_SUCCESS);
1616: }

1618: // Get a GPU pointer to a stripped-down Mat. Works for both Seq and MPI Mat.
1619: PetscErrorCode MatKokkosGetDeviceMatWrite(Mat A, PetscSplitCSRDataStructure *B)
1620: {
1621:   PetscMPIInt                size, rank;
1622:   MPI_Comm                   comm;
1623:   PetscSplitCSRDataStructure d_mat = NULL;

1625:   PetscFunctionBegin;
1626:   PetscCall(PetscObjectGetComm((PetscObject)A, &comm));
1627:   PetscCallMPI(MPI_Comm_size(comm, &size));
1628:   PetscCallMPI(MPI_Comm_rank(comm, &rank));
1629:   if (size == 1) {
1630:     PetscCall(MatSeqAIJKokkosGetDeviceMat(A, &d_mat));
1631:     PetscCall(MatSeqAIJKokkosModifyDevice(A)); /* Since we are going to modify matrix values on device */
1632:   } else {
1633:     Mat_MPIAIJ *aij = (Mat_MPIAIJ *)A->data;
1634:     PetscCall(MatSeqAIJKokkosGetDeviceMat(aij->A, &d_mat));
1635:     PetscCall(MatSeqAIJKokkosModifyDevice(aij->A));
1636:     PetscCall(MatSeqAIJKokkosModifyDevice(aij->B));
1637:     PetscCheck(A->nooffprocentries || aij->donotstash, PetscObjectComm((PetscObject)A), PETSC_ERR_SUP, "Device assembly does not currently support offproc values insertion. Use MatSetOption(A,MAT_NO_OFF_PROC_ENTRIES,PETSC_TRUE) or MatSetOption(A,MAT_IGNORE_OFF_PROC_ENTRIES,PETSC_TRUE)");
1638:   }
1639:   // act like MatSetValues because this routine is not called on the host
1640:   if (A->assembled) {
1641:     if (A->was_assembled) PetscCall(PetscInfo(A, "Assembled more than once already\n"));
1642:     A->was_assembled = PETSC_TRUE; // normally done lazily in MatAssemblyEnd(), which we no longer call; AIJ's AssemblyEnd does it -- is it still needed here?
1643:   } else {
1644:     PetscCall(PetscInfo(A, "Warning: matrix is not assembled (assembled = %d)\n", (int)A->assembled));
1645:   }
1646:   if (!d_mat) {
1647:     struct _n_SplitCSRMat h_mat; /* host container */
1648:     Mat_SeqAIJKokkos     *aijkokA;
1649:     Mat_SeqAIJ           *jaca;
1650:     PetscInt              n = A->rmap->n, nnz;
1651:     Mat                   Amat;
1652:     PetscInt             *colmap;

1654:     /* create and copy h_mat */
1655:     h_mat.M = A->cmap->N; // use for debug build
1656:     PetscCall(PetscInfo(A, "Create device matrix in Kokkos\n"));
1657:     if (size == 1) {
1658:       Amat            = A;
1659:       jaca            = (Mat_SeqAIJ *)A->data;
1660:       h_mat.rstart    = 0;
1661:       h_mat.rend      = A->rmap->n;
1662:       h_mat.cstart    = 0;
1663:       h_mat.cend      = A->cmap->n;
1664:       h_mat.offdiag.i = h_mat.offdiag.j = NULL;
1665:       h_mat.offdiag.a                   = NULL;
1666:       aijkokA                           = static_cast<Mat_SeqAIJKokkos *>(A->spptr);
1667:     } else {
1668:       Mat_MPIAIJ       *aij  = (Mat_MPIAIJ *)A->data;
1669:       Mat_SeqAIJ       *jacb = (Mat_SeqAIJ *)aij->B->data;
1670:       PetscInt          ii;
1671:       Mat_SeqAIJKokkos *aijkokB;

1673:       Amat    = aij->A;
1674:       aijkokA = static_cast<Mat_SeqAIJKokkos *>(aij->A->spptr);
1675:       aijkokB = static_cast<Mat_SeqAIJKokkos *>(aij->B->spptr);
1676:       jaca    = (Mat_SeqAIJ *)aij->A->data;
1677:       PetscCheck(!aij->B->cmap->n || aij->garray, comm, PETSC_ERR_PLIB, "MPIAIJ Matrix was assembled but is missing garray");
1678:       PetscCheck(aij->B->rmap->n == aij->A->rmap->n, comm, PETSC_ERR_SUP, "Only support aij->B->rmap->n == aij->A->rmap->n");
1679:       aij->donotstash          = PETSC_TRUE;
1680:       aij->A->nooffprocentries = aij->B->nooffprocentries = A->nooffprocentries = PETSC_TRUE;
1681:       jaca->nonew = jacb->nonew = PETSC_TRUE; // no more disassembly
1682:       PetscCall(PetscCalloc1(A->cmap->N, &colmap));
1683:       for (ii = 0; ii < aij->B->cmap->n; ii++) colmap[aij->garray[ii]] = ii + 1;
1684:       // allocate B copy data
1685:       h_mat.rstart = A->rmap->rstart;
1686:       h_mat.rend   = A->rmap->rend;
1687:       h_mat.cstart = A->cmap->rstart;
1688:       h_mat.cend   = A->cmap->rend;
1689:       nnz          = jacb->i[n];
1690:       if (jacb->compressedrow.use) {
1691:         const Kokkos::View<PetscInt *, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> h_i_k(jacb->i, n + 1);
1692:         aijkokB->i_uncompressed_d = Kokkos::View<PetscInt *>(Kokkos::create_mirror(DefaultMemorySpace(), h_i_k));
1693:         Kokkos::deep_copy(aijkokB->i_uncompressed_d, h_i_k);
1694:         h_mat.offdiag.i = aijkokB->i_uncompressed_d.data();
1695:       } else {
1696:         h_mat.offdiag.i = aijkokB->i_device_data();
1697:       }
1698:       h_mat.offdiag.j = aijkokB->j_device_data();
1699:       h_mat.offdiag.a = aijkokB->a_device_data();
1700:       {
1701:         Kokkos::View<PetscInt *, Kokkos::HostSpace, Kokkos::MemoryTraits<Kokkos::Unmanaged>> h_colmap_k(colmap, A->cmap->N);
1702:         aijkokB->colmap_d = Kokkos::View<PetscInt *>(Kokkos::create_mirror(DefaultMemorySpace(), h_colmap_k));
1703:         Kokkos::deep_copy(aijkokB->colmap_d, h_colmap_k);
1704:         h_mat.colmap = aijkokB->colmap_d.data();
1705:         PetscCall(PetscFree(colmap));
1706:       }
1707:       h_mat.offdiag.ignorezeroentries = jacb->ignorezeroentries;
1708:       h_mat.offdiag.n                 = n;
1709:     }
1710:     // allocate A copy data
1711:     nnz                          = jaca->i[n];
1712:     h_mat.diag.n                 = n;
1713:     h_mat.diag.ignorezeroentries = jaca->ignorezeroentries;
1714:     PetscCallMPI(MPI_Comm_rank(comm, &h_mat.rank));
1715:     PetscCheck(!jaca->compressedrow.use, PETSC_COMM_SELF, PETSC_ERR_PLIB, "A does not support compressed row (todo)");
1716:     h_mat.diag.i = aijkokA->i_device_data();
1717:     h_mat.diag.j = aijkokA->j_device_data();
1718:     h_mat.diag.a = aijkokA->a_device_data();
1719:     // copy pointers and metadata to device
1720:     PetscCall(MatSeqAIJKokkosSetDeviceMat(Amat, &h_mat));
1721:     PetscCall(MatSeqAIJKokkosGetDeviceMat(Amat, &d_mat));
1722:     PetscCall(PetscInfo(A, "Create device Mat n=%" PetscInt_FMT " nnz=%" PetscInt_FMT "\n", h_mat.diag.n, nnz));
1723:   }
1724:   *B           = d_mat;       // return it, set it in Mat, and set it up
1725:   A->assembled = PETSC_FALSE; // ready to write with MatSetValues; this is normally done lazily in MatSetValues
1726:   PetscFunctionReturn(PETSC_SUCCESS);
1727: }
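
/* Typical host-side pattern for device assembly with the split-CSR handle returned above (a hedged sketch;
   the Kokkos kernel that inserts entries through d_mat, e.g. via MatSetValuesDevice(), is application code):
     PetscSplitCSRDataStructure d_mat;
     PetscCall(MatSetOption(A, MAT_NO_OFF_PROC_ENTRIES, PETSC_TRUE)); // off-process insertion is not supported here
     PetscCall(MatKokkosGetDeviceMatWrite(A, &d_mat));
     // ... launch a device kernel that sets matrix entries through d_mat ...
     PetscCall(MatAssemblyBegin(A, MAT_FINAL_ASSEMBLY));
     PetscCall(MatAssemblyEnd(A, MAT_FINAL_ASSEMBLY));
*/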

1729: PETSC_INTERN PetscErrorCode MatSeqAIJKokkosGetOffloadMask(Mat A, const char **mask)
1730: {
1731:   Mat_SeqAIJKokkos *aijkok = static_cast<Mat_SeqAIJKokkos *>(A->spptr);

1733:   PetscFunctionBegin;
1734:   if (!aijkok) *mask = "AIJKOK_UNALLOCATED";
1735:   else if (aijkok->a_dual.need_sync_host()) *mask = "PETSC_OFFLOAD_GPU";
1736:   else if (aijkok->a_dual.need_sync_device()) *mask = "PETSC_OFFLOAD_CPU";
1737:   else *mask = "PETSC_OFFLOAD_BOTH";
1738:   PetscFunctionReturn(PETSC_SUCCESS);
1739: }

1741: PETSC_INTERN PetscErrorCode MatAIJKokkosPrintOffloadMask(Mat A)
1742: {
1743:   PetscMPIInt size;
1744:   Mat         Ad, Ao;
1745:   const char *amask, *bmask;

1747:   PetscFunctionBegin;
1748:   PetscCallMPI(MPI_Comm_size(PetscObjectComm((PetscObject)A), &size));

1750:   if (size == 1) {
1751:     PetscCall(MatSeqAIJKokkosGetOffloadMask(A, &amask));
1752:     PetscCall(PetscPrintf(PETSC_COMM_SELF, "%s\n", amask));
1753:   } else {
1754:     Ad = ((Mat_MPIAIJ *)A->data)->A;
1755:     Ao = ((Mat_MPIAIJ *)A->data)->B;
1756:     PetscCall(MatSeqAIJKokkosGetOffloadMask(Ad, &amask));
1757:     PetscCall(MatSeqAIJKokkosGetOffloadMask(Ao, &bmask));
1758:     PetscCall(PetscPrintf(PETSC_COMM_SELF, "Diag : Off-diag = %s : %s\n", amask, bmask));
1759:   }
1760:   PetscFunctionReturn(PETSC_SUCCESS);
1761: }
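
/* Debugging sketch: print which memory space currently holds up-to-date matrix values, e.g.
     PetscCall(MatAIJKokkosPrintOffloadMask(A)); // parallel output looks like "Diag : Off-diag = PETSC_OFFLOAD_GPU : PETSC_OFFLOAD_BOTH"
*/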