Actual source code: sfkok.kokkos.cxx

  1: #include <../src/vec/is/sf/impls/basic/sfpack.h>

  3: #include <Kokkos_Core.hpp>

  5: using DeviceExecutionSpace = Kokkos::DefaultExecutionSpace;
  6: using DeviceMemorySpace    = typename DeviceExecutionSpace::memory_space;
  7: using HostMemorySpace      = Kokkos::HostSpace;

  9: typedef Kokkos::View<char*,DeviceMemorySpace>       deviceBuffer_t;
 10: typedef Kokkos::View<char*,HostMemorySpace>         HostBuffer_t;

 12: typedef Kokkos::View<const char*,DeviceMemorySpace> deviceConstBuffer_t;
 13: typedef Kokkos::View<const char*,HostMemorySpace>   HostConstBuffer_t;

 15: /*====================================================================================*/
 16: /*                             Regular operations                           */
 17: /*====================================================================================*/
 18: template<typename Type> struct Insert{KOKKOS_INLINE_FUNCTION Type operator()(Type& x,Type y) const {Type old = x; x  = y;             return old;}};
 19: template<typename Type> struct Add   {KOKKOS_INLINE_FUNCTION Type operator()(Type& x,Type y) const {Type old = x; x += y;             return old;}};
 20: template<typename Type> struct Mult  {KOKKOS_INLINE_FUNCTION Type operator()(Type& x,Type y) const {Type old = x; x *= y;             return old;}};
 21: template<typename Type> struct Min   {KOKKOS_INLINE_FUNCTION Type operator()(Type& x,Type y) const {Type old = x; x  = PetscMin(x,y); return old;}};
 22: template<typename Type> struct Max   {KOKKOS_INLINE_FUNCTION Type operator()(Type& x,Type y) const {Type old = x; x  = PetscMax(x,y); return old;}};
 23: template<typename Type> struct LAND  {KOKKOS_INLINE_FUNCTION Type operator()(Type& x,Type y) const {Type old = x; x  = x && y;        return old;}};
 24: template<typename Type> struct LOR   {KOKKOS_INLINE_FUNCTION Type operator()(Type& x,Type y) const {Type old = x; x  = x || y;        return old;}};
 25: template<typename Type> struct LXOR  {KOKKOS_INLINE_FUNCTION Type operator()(Type& x,Type y) const {Type old = x; x  = !x != !y;      return old;}};
 26: template<typename Type> struct BAND  {KOKKOS_INLINE_FUNCTION Type operator()(Type& x,Type y) const {Type old = x; x  = x & y;         return old;}};
 27: template<typename Type> struct BOR   {KOKKOS_INLINE_FUNCTION Type operator()(Type& x,Type y) const {Type old = x; x  = x | y;         return old;}};
 28: template<typename Type> struct BXOR  {KOKKOS_INLINE_FUNCTION Type operator()(Type& x,Type y) const {Type old = x; x  = x ^ y;         return old;}};
 29: template<typename PairType> struct Minloc {
 30:   KOKKOS_INLINE_FUNCTION PairType operator()(PairType& x,PairType y) const {
 31:     PairType old = x;
 32:     if (y.first < x.first) x = y;
 33:     else if (y.first == x.first) x.second = PetscMin(x.second,y.second);
 34:     return old;
 35:   }
 36: };
 37: template<typename PairType> struct Maxloc {
 38:   KOKKOS_INLINE_FUNCTION PairType operator()(PairType& x,PairType y) const {
 39:     PairType old = x;
 40:     if (y.first > x.first) x = y;
 41:     else if (y.first == x.first) x.second = PetscMin(x.second,y.second); /* See MPI MAXLOC */
 42:     return old;
 43:   }
 44: };

 46: /*====================================================================================*/
 47: /*                             Atomic operations                            */
 48: /*====================================================================================*/
 49: template<typename Type> struct AtomicInsert  {KOKKOS_INLINE_FUNCTION void operator()(Type& x,Type y) const {Kokkos::atomic_assign(&x,y);}};
 50: template<typename Type> struct AtomicAdd     {KOKKOS_INLINE_FUNCTION void operator()(Type& x,Type y) const {Kokkos::atomic_add(&x,y);}};
 51: template<typename Type> struct AtomicBAND    {KOKKOS_INLINE_FUNCTION void operator()(Type& x,Type y) const {Kokkos::atomic_and(&x,y);}};
 52: template<typename Type> struct AtomicBOR     {KOKKOS_INLINE_FUNCTION void operator()(Type& x,Type y) const {Kokkos::atomic_or (&x,y);}};
 53: template<typename Type> struct AtomicBXOR    {KOKKOS_INLINE_FUNCTION void operator()(Type& x,Type y) const {Kokkos::atomic_fetch_xor(&x,y);}};
 54: template<typename Type> struct AtomicLAND    {KOKKOS_INLINE_FUNCTION void operator()(Type& x,Type y) const {const Type zero=0,one=~0; Kokkos::atomic_and(&x,y?one:zero);}};
 55: template<typename Type> struct AtomicLOR     {KOKKOS_INLINE_FUNCTION void operator()(Type& x,Type y) const {const Type zero=0,one=1;  Kokkos::atomic_or (&x,y?one:zero);}};
 56: template<typename Type> struct AtomicMult    {KOKKOS_INLINE_FUNCTION void operator()(Type& x,Type y) const {Kokkos::atomic_fetch_mul(&x,y);}};
 57: template<typename Type> struct AtomicMin     {KOKKOS_INLINE_FUNCTION void operator()(Type& x,Type y) const {Kokkos::atomic_fetch_min(&x,y);}};
 58: template<typename Type> struct AtomicMax     {KOKKOS_INLINE_FUNCTION void operator()(Type& x,Type y) const {Kokkos::atomic_fetch_max(&x,y);}};
 59: /* TODO: struct AtomicLXOR  */
 60: template<typename Type> struct AtomicFetchAdd{KOKKOS_INLINE_FUNCTION Type operator()(Type& x,Type y) const {return Kokkos::atomic_fetch_add(&x,y);}};

 62: /* Map a thread id to an index in root/leaf space through a series of 3D subdomains. See PetscSFPackOpt. */
 63: static KOKKOS_INLINE_FUNCTION PetscInt MapTidToIndex(const PetscInt *opt,PetscInt tid)
 64: {
 65:   PetscInt        i,j,k,m,n,r;
 66:   const PetscInt  *offset,*start,*dx,*dy,*X,*Y;

 68:   n      = opt[0];
 69:   offset = opt + 1;
 70:   start  = opt + n + 2;
 71:   dx     = opt + 2*n + 2;
 72:   dy     = opt + 3*n + 2;
 73:   X      = opt + 5*n + 2;
 74:   Y      = opt + 6*n + 2;
 75:   for (r=0; r<n; r++) {if (tid < offset[r+1]) break;}
 76:   m = (tid - offset[r]);
 77:   k = m/(dx[r]*dy[r]);
 78:   j = (m - k*dx[r]*dy[r])/dx[r];
 79:   i = m - k*dx[r]*dy[r] - j*dx[r];

 81:   return (start[r] + k*X[r]*Y[r] + j*X[r] + i);
 82: }

 84: /*====================================================================================*/
 85: /*  Wrappers for Pack/Unpack/Scatter kernels. Function pointers are stored in 'link'         */
 86: /*====================================================================================*/

/* Suppose the user calls PetscSFReduce(sf,unit,...) and <unit> is an MPI datatype made of 16 PetscReals. Then
   <Type> is PetscReal, the primitive type we operate on;
   <bs>   is 16, i.e., <unit> contains 16 primitives of type <Type>;
   <BS>   is 8, the maximal SIMD width with which we try to vectorize operations on <unit>;
   <EQ>   is 0, which is (bs == BS ? 1 : 0).

  If instead <unit> has 8 PetscReals, then bs=8, BS=8, EQ=1, which makes MBS below a compile-time constant.
  For the common case in VecScatter, bs=1, BS=1, EQ=1, MBS=1, so the inner for-loops below can be fully unrolled.
*/
 97: template<typename Type,PetscInt BS,PetscInt EQ>
 98: static PetscErrorCode Pack(PetscSFLink link,PetscInt count,PetscInt start,PetscSFPackOpt opt,const PetscInt *idx,const void *data_,void *buf_)
 99: {
100:   const PetscInt          *iopt = opt ? opt->array : NULL;
101:   const PetscInt          M = EQ ? 1 : link->bs/BS, MBS=M*BS; /* If EQ, then MBS will be a compile-time const */
102:   const Type              *data = static_cast<const Type*>(data_);
103:   Type                    *buf = static_cast<Type*>(buf_);
104:   DeviceExecutionSpace    exec;

106:   Kokkos::parallel_for(Kokkos::RangePolicy<DeviceExecutionSpace>(exec,0,count),KOKKOS_LAMBDA(PetscInt tid) {
107:     /* iopt != NULL ==> idx == NULL, i.e., the indices have patterns but not contiguous;
108:        iopt == NULL && idx == NULL ==> the indices are contiguous;
109:      */
110:     PetscInt t = (iopt? MapTidToIndex(iopt,tid) : (idx? idx[tid] : start+tid))*MBS;
111:     PetscInt s = tid*MBS;
112:     for (int i=0; i<MBS; i++) buf[s+i] = data[t+i];
113:   });
114:   return 0;
115: }

117: template<typename Type,class Op,PetscInt BS,PetscInt EQ>
118: static PetscErrorCode UnpackAndOp(PetscSFLink link,PetscInt count,PetscInt start,PetscSFPackOpt opt,const PetscInt *idx,void *data_,const void *buf_)
119: {
120:   Op                      op;
121:   const PetscInt          *iopt = opt ? opt->array : NULL;
122:   const PetscInt          M = EQ ? 1 : link->bs/BS, MBS=M*BS;
123:   Type                    *data = static_cast<Type*>(data_);
124:   const Type              *buf = static_cast<const Type*>(buf_);
125:   DeviceExecutionSpace    exec;

127:   Kokkos::parallel_for(Kokkos::RangePolicy<DeviceExecutionSpace>(exec,0,count),KOKKOS_LAMBDA(PetscInt tid) {
128:     PetscInt t = (iopt? MapTidToIndex(iopt,tid) : (idx? idx[tid] : start+tid))*MBS;
129:     PetscInt s = tid*MBS;
130:     for (int i=0; i<MBS; i++) op(data[t+i],buf[s+i]);
131:   });
132:   return 0;
133: }

135: template<typename Type,class Op,PetscInt BS,PetscInt EQ>
136: static PetscErrorCode FetchAndOp(PetscSFLink link,PetscInt count,PetscInt start,PetscSFPackOpt opt,const PetscInt *idx,void *data,void *buf)
137: {
138:   Op                      op;
139:   const PetscInt          *ropt = opt ? opt->array : NULL;
140:   const PetscInt          M = EQ ? 1 : link->bs/BS, MBS=M*BS;
141:   Type                    *rootdata = static_cast<Type*>(data),*leafbuf=static_cast<Type*>(buf);
142:   DeviceExecutionSpace    exec;

144:   Kokkos::parallel_for(Kokkos::RangePolicy<DeviceExecutionSpace>(exec,0,count),KOKKOS_LAMBDA(PetscInt tid) {
145:     PetscInt r = (ropt? MapTidToIndex(ropt,tid) : (idx? idx[tid] : start+tid))*MBS;
146:     PetscInt l = tid*MBS;
147:     for (int i=0; i<MBS; i++) leafbuf[l+i] = op(rootdata[r+i],leafbuf[l+i]);
148:   });
149:   return 0;
150: }

152: template<typename Type,class Op,PetscInt BS,PetscInt EQ>
153: static PetscErrorCode ScatterAndOp(PetscSFLink link,PetscInt count,PetscInt srcStart,PetscSFPackOpt srcOpt,const PetscInt *srcIdx,const void *src_,PetscInt dstStart,PetscSFPackOpt dstOpt,const PetscInt *dstIdx,void *dst_)
154: {
155:   PetscInt                srcx=0,srcy=0,srcX=0,srcY=0,dstx=0,dsty=0,dstX=0,dstY=0;
156:   const PetscInt          M = (EQ) ? 1 : link->bs/BS, MBS=M*BS;
157:   const Type              *src = static_cast<const Type*>(src_);
158:   Type                    *dst = static_cast<Type*>(dst_);
159:   DeviceExecutionSpace    exec;

161:   /* The 3D shape of source subdomain may be different than that of the destination, which makes it difficult to use CUDA 3D grid and block */
162:   if (srcOpt)       {srcx = srcOpt->dx[0]; srcy = srcOpt->dy[0]; srcX = srcOpt->X[0]; srcY = srcOpt->Y[0]; srcStart = srcOpt->start[0]; srcIdx = NULL;}
163:   else if (!srcIdx) {srcx = srcX = count; srcy = srcY = 1;}

165:   if (dstOpt)       {dstx = dstOpt->dx[0]; dsty = dstOpt->dy[0]; dstX = dstOpt->X[0]; dstY = dstOpt->Y[0]; dstStart = dstOpt->start[0]; dstIdx = NULL;}
166:   else if (!dstIdx) {dstx = dstX = count; dsty = dstY = 1;}

168:   Kokkos::parallel_for(Kokkos::RangePolicy<DeviceExecutionSpace>(exec,0,count),KOKKOS_LAMBDA(PetscInt tid) {
169:     PetscInt i,j,k,s,t;
170:     Op       op;
171:     if (!srcIdx) { /* src is in 3D */
172:       k = tid/(srcx*srcy);
173:       j = (tid - k*srcx*srcy)/srcx;
174:       i = tid - k*srcx*srcy - j*srcx;
175:       s = srcStart + k*srcX*srcY + j*srcX + i;
176:     } else { /* src is contiguous */
177:       s = srcIdx[tid];
178:     }

180:     if (!dstIdx) { /* 3D */
181:       k = tid/(dstx*dsty);
182:       j = (tid - k*dstx*dsty)/dstx;
183:       i = tid - k*dstx*dsty - j*dstx;
184:       t = dstStart + k*dstX*dstY + j*dstX + i;
185:     } else { /* contiguous */
186:       t = dstIdx[tid];
187:     }

189:     s *= MBS;
190:     t *= MBS;
191:     for (i=0; i<MBS; i++) op(dst[t+i],src[s+i]);
192:   });
193:   return 0;
194: }

196: /* Specialization for Insert since we may use memcpy */
197: template<typename Type,PetscInt BS,PetscInt EQ>
198: static PetscErrorCode ScatterAndInsert(PetscSFLink link,PetscInt count,PetscInt srcStart,PetscSFPackOpt srcOpt,const PetscInt *srcIdx,const void *src_,PetscInt dstStart,PetscSFPackOpt dstOpt,const PetscInt *dstIdx,void *dst_)
199: {
200:   const Type              *src = static_cast<const Type*>(src_);
201:   Type                    *dst = static_cast<Type*>(dst_);
202:   DeviceExecutionSpace    exec;

204:   if (!count) return 0;
205:   /*src and dst are contiguous */
206:   if ((!srcOpt && !srcIdx) && (!dstOpt && !dstIdx) && src != dst) {
207:     size_t sz = count*link->unitbytes;
208:     deviceBuffer_t      dbuf(reinterpret_cast<char*>(dst+dstStart*link->bs),sz);
209:     deviceConstBuffer_t sbuf(reinterpret_cast<const char*>(src+srcStart*link->bs),sz);
210:     Kokkos::deep_copy(exec,dbuf,sbuf);
211:   } else {
212:     ScatterAndOp<Type,Insert<Type>,BS,EQ>(link,count,srcStart,srcOpt,srcIdx,src,dstStart,dstOpt,dstIdx,dst);
213:   }
214:   return 0;
215: }

217: template<typename Type,class Op,PetscInt BS,PetscInt EQ>
218: static PetscErrorCode FetchAndOpLocal(PetscSFLink link,PetscInt count,PetscInt rootstart,PetscSFPackOpt rootopt,const PetscInt *rootidx,void *rootdata_,PetscInt leafstart,PetscSFPackOpt leafopt,const PetscInt *leafidx,const void *leafdata_,void *leafupdate_)
219: {
220:   Op                      op;
221:   const PetscInt          M = (EQ) ? 1 : link->bs/BS, MBS = M*BS;
222:   const PetscInt          *ropt = rootopt ? rootopt->array : NULL;
223:   const PetscInt          *lopt = leafopt ? leafopt->array : NULL;
224:   Type                    *rootdata = static_cast<Type*>(rootdata_),*leafupdate = static_cast<Type*>(leafupdate_);
225:   const Type              *leafdata = static_cast<const Type*>(leafdata_);
226:   DeviceExecutionSpace    exec;

228:   Kokkos::parallel_for(Kokkos::RangePolicy<DeviceExecutionSpace>(exec,0,count),KOKKOS_LAMBDA(PetscInt tid) {
229:     PetscInt r = (ropt? MapTidToIndex(ropt,tid) : (rootidx? rootidx[tid] : rootstart+tid))*MBS;
230:     PetscInt l = (lopt? MapTidToIndex(lopt,tid) : (leafidx? leafidx[tid] : leafstart+tid))*MBS;
231:     for (int i=0; i<MBS; i++) leafupdate[l+i] = op(rootdata[r+i],leafdata[l+i]);
232:   });
233:   return 0;
234: }

236: /*====================================================================================*/
237: /*  Init various types and instantiate pack/unpack function pointers                  */
238: /*====================================================================================*/
239: template<typename Type,PetscInt BS,PetscInt EQ>
240: static void PackInit_RealType(PetscSFLink link)
241: {
242:   /* Pack/unpack for remote communication */
243:   link->d_Pack              = Pack<Type,BS,EQ>;
244:   link->d_UnpackAndInsert   = UnpackAndOp<Type,Insert<Type>,BS,EQ>;
245:   link->d_UnpackAndAdd      = UnpackAndOp<Type,Add<Type>   ,BS,EQ>;
246:   link->d_UnpackAndMult     = UnpackAndOp<Type,Mult<Type>  ,BS,EQ>;
247:   link->d_UnpackAndMin      = UnpackAndOp<Type,Min<Type>   ,BS,EQ>;
248:   link->d_UnpackAndMax      = UnpackAndOp<Type,Max<Type>   ,BS,EQ>;
249:   link->d_FetchAndAdd       = FetchAndOp <Type,Add<Type>   ,BS,EQ>;
250:   /* Scatter for local communication */
251:   link->d_ScatterAndInsert  = ScatterAndInsert<Type,BS,EQ>; /* Has special optimizations */
252:   link->d_ScatterAndAdd     = ScatterAndOp<Type,Add<Type>    ,BS,EQ>;
253:   link->d_ScatterAndMult    = ScatterAndOp<Type,Mult<Type>   ,BS,EQ>;
254:   link->d_ScatterAndMin     = ScatterAndOp<Type,Min<Type>    ,BS,EQ>;
255:   link->d_ScatterAndMax     = ScatterAndOp<Type,Max<Type>    ,BS,EQ>;
256:   link->d_FetchAndAddLocal  = FetchAndOpLocal<Type,Add <Type>,BS,EQ>;
257:   /* Atomic versions when there are data-race possibilities */
258:   link->da_UnpackAndInsert  = UnpackAndOp<Type,AtomicInsert<Type>  ,BS,EQ>;
259:   link->da_UnpackAndAdd     = UnpackAndOp<Type,AtomicAdd<Type>     ,BS,EQ>;
260:   link->da_UnpackAndMult    = UnpackAndOp<Type,AtomicMult<Type>    ,BS,EQ>;
261:   link->da_UnpackAndMin     = UnpackAndOp<Type,AtomicMin<Type>     ,BS,EQ>;
262:   link->da_UnpackAndMax     = UnpackAndOp<Type,AtomicMax<Type>     ,BS,EQ>;
263:   link->da_FetchAndAdd      = FetchAndOp <Type,AtomicFetchAdd<Type>,BS,EQ>;

265:   link->da_ScatterAndInsert = ScatterAndOp<Type,AtomicInsert<Type>,BS,EQ>;
266:   link->da_ScatterAndAdd    = ScatterAndOp<Type,AtomicAdd<Type>   ,BS,EQ>;
267:   link->da_ScatterAndMult   = ScatterAndOp<Type,AtomicMult<Type>  ,BS,EQ>;
268:   link->da_ScatterAndMin    = ScatterAndOp<Type,AtomicMin<Type>   ,BS,EQ>;
269:   link->da_ScatterAndMax    = ScatterAndOp<Type,AtomicMax<Type>   ,BS,EQ>;
270:   link->da_FetchAndAddLocal = FetchAndOpLocal<Type,AtomicFetchAdd<Type>,BS,EQ>;
271: }

273: template<typename Type,PetscInt BS,PetscInt EQ>
274: static void PackInit_IntegerType(PetscSFLink link)
275: {
276:   link->d_Pack              = Pack<Type,BS,EQ>;
277:   link->d_UnpackAndInsert   = UnpackAndOp<Type,Insert<Type> ,BS,EQ>;
278:   link->d_UnpackAndAdd      = UnpackAndOp<Type,Add<Type>    ,BS,EQ>;
279:   link->d_UnpackAndMult     = UnpackAndOp<Type,Mult<Type>   ,BS,EQ>;
280:   link->d_UnpackAndMin      = UnpackAndOp<Type,Min<Type>    ,BS,EQ>;
281:   link->d_UnpackAndMax      = UnpackAndOp<Type,Max<Type>    ,BS,EQ>;
282:   link->d_UnpackAndLAND     = UnpackAndOp<Type,LAND<Type>   ,BS,EQ>;
283:   link->d_UnpackAndLOR      = UnpackAndOp<Type,LOR<Type>    ,BS,EQ>;
284:   link->d_UnpackAndLXOR     = UnpackAndOp<Type,LXOR<Type>   ,BS,EQ>;
285:   link->d_UnpackAndBAND     = UnpackAndOp<Type,BAND<Type>   ,BS,EQ>;
286:   link->d_UnpackAndBOR      = UnpackAndOp<Type,BOR<Type>    ,BS,EQ>;
287:   link->d_UnpackAndBXOR     = UnpackAndOp<Type,BXOR<Type>   ,BS,EQ>;
288:   link->d_FetchAndAdd       = FetchAndOp <Type,Add<Type>    ,BS,EQ>;

290:   link->d_ScatterAndInsert  = ScatterAndInsert<Type,BS,EQ>;
291:   link->d_ScatterAndAdd     = ScatterAndOp<Type,Add<Type>   ,BS,EQ>;
292:   link->d_ScatterAndMult    = ScatterAndOp<Type,Mult<Type>  ,BS,EQ>;
293:   link->d_ScatterAndMin     = ScatterAndOp<Type,Min<Type>   ,BS,EQ>;
294:   link->d_ScatterAndMax     = ScatterAndOp<Type,Max<Type>   ,BS,EQ>;
295:   link->d_ScatterAndLAND    = ScatterAndOp<Type,LAND<Type>  ,BS,EQ>;
296:   link->d_ScatterAndLOR     = ScatterAndOp<Type,LOR<Type>   ,BS,EQ>;
297:   link->d_ScatterAndLXOR    = ScatterAndOp<Type,LXOR<Type>  ,BS,EQ>;
298:   link->d_ScatterAndBAND    = ScatterAndOp<Type,BAND<Type>  ,BS,EQ>;
299:   link->d_ScatterAndBOR     = ScatterAndOp<Type,BOR<Type>   ,BS,EQ>;
300:   link->d_ScatterAndBXOR    = ScatterAndOp<Type,BXOR<Type>  ,BS,EQ>;
301:   link->d_FetchAndAddLocal  = FetchAndOpLocal<Type,Add<Type>,BS,EQ>;

303:   link->da_UnpackAndInsert  = UnpackAndOp<Type,AtomicInsert<Type>,BS,EQ>;
304:   link->da_UnpackAndAdd     = UnpackAndOp<Type,AtomicAdd<Type>   ,BS,EQ>;
305:   link->da_UnpackAndMult    = UnpackAndOp<Type,AtomicMult<Type>  ,BS,EQ>;
306:   link->da_UnpackAndMin     = UnpackAndOp<Type,AtomicMin<Type>   ,BS,EQ>;
307:   link->da_UnpackAndMax     = UnpackAndOp<Type,AtomicMax<Type>   ,BS,EQ>;
308:   link->da_UnpackAndLAND    = UnpackAndOp<Type,AtomicLAND<Type>  ,BS,EQ>;
309:   link->da_UnpackAndLOR     = UnpackAndOp<Type,AtomicLOR<Type>   ,BS,EQ>;
310:   link->da_UnpackAndBAND    = UnpackAndOp<Type,AtomicBAND<Type>  ,BS,EQ>;
311:   link->da_UnpackAndBOR     = UnpackAndOp<Type,AtomicBOR<Type>   ,BS,EQ>;
312:   link->da_UnpackAndBXOR    = UnpackAndOp<Type,AtomicBXOR<Type>  ,BS,EQ>;
313:   link->da_FetchAndAdd      = FetchAndOp <Type,AtomicFetchAdd<Type>,BS,EQ>;

315:   link->da_ScatterAndInsert = ScatterAndOp<Type,AtomicInsert<Type>,BS,EQ>;
316:   link->da_ScatterAndAdd    = ScatterAndOp<Type,AtomicAdd<Type>   ,BS,EQ>;
317:   link->da_ScatterAndMult   = ScatterAndOp<Type,AtomicMult<Type>  ,BS,EQ>;
318:   link->da_ScatterAndMin    = ScatterAndOp<Type,AtomicMin<Type>   ,BS,EQ>;
319:   link->da_ScatterAndMax    = ScatterAndOp<Type,AtomicMax<Type>   ,BS,EQ>;
320:   link->da_ScatterAndLAND   = ScatterAndOp<Type,AtomicLAND<Type>  ,BS,EQ>;
321:   link->da_ScatterAndLOR    = ScatterAndOp<Type,AtomicLOR<Type>   ,BS,EQ>;
322:   link->da_ScatterAndBAND   = ScatterAndOp<Type,AtomicBAND<Type>  ,BS,EQ>;
323:   link->da_ScatterAndBOR    = ScatterAndOp<Type,AtomicBOR<Type>   ,BS,EQ>;
324:   link->da_ScatterAndBXOR   = ScatterAndOp<Type,AtomicBXOR<Type>  ,BS,EQ>;
325:   link->da_FetchAndAddLocal = FetchAndOpLocal<Type,AtomicFetchAdd<Type>,BS,EQ>;
326: }

#if defined(PETSC_HAVE_COMPLEX)
/*
  Install the device kernels for complex types: Insert, Add, Mult and FetchAndAdd (no ordering-based ops).
  d_* pointers are the plain versions; da_* pointers are the atomic versions.
*/
template<typename Type,PetscInt BS,PetscInt EQ>
static void PackInit_ComplexType(PetscSFLink link)
{
  link->d_Pack              = Pack<Type,BS,EQ>;

  /* Insert */
  link->d_UnpackAndInsert   = UnpackAndOp<Type,Insert<Type>,BS,EQ>;
  link->d_ScatterAndInsert  = ScatterAndInsert<Type,BS,EQ>;
  link->da_UnpackAndInsert  = UnpackAndOp<Type,AtomicInsert<Type>,BS,EQ>;
  link->da_ScatterAndInsert = ScatterAndOp<Type,AtomicInsert<Type>,BS,EQ>;

  /* Add */
  link->d_UnpackAndAdd      = UnpackAndOp<Type,Add<Type>,BS,EQ>;
  link->d_ScatterAndAdd     = ScatterAndOp<Type,Add<Type>,BS,EQ>;
  link->da_UnpackAndAdd     = UnpackAndOp<Type,AtomicAdd<Type>,BS,EQ>;
  link->da_ScatterAndAdd    = ScatterAndOp<Type,AtomicAdd<Type>,BS,EQ>;

  /* Mult */
  link->d_UnpackAndMult     = UnpackAndOp<Type,Mult<Type>,BS,EQ>;
  link->d_ScatterAndMult    = ScatterAndOp<Type,Mult<Type>,BS,EQ>;
  link->da_UnpackAndMult    = UnpackAndOp<Type,AtomicMult<Type>,BS,EQ>;
  link->da_ScatterAndMult   = ScatterAndOp<Type,AtomicMult<Type>,BS,EQ>;

  /* FetchAndAdd (remote and local) */
  link->d_FetchAndAdd       = FetchAndOp<Type,Add<Type>,BS,EQ>;
  link->d_FetchAndAddLocal  = FetchAndOpLocal<Type,Add<Type>,BS,EQ>;
  link->da_FetchAndAdd      = FetchAndOp<Type,AtomicFetchAdd<Type>,BS,EQ>;
  link->da_FetchAndAddLocal = FetchAndOpLocal<Type,AtomicFetchAdd<Type>,BS,EQ>;
}
#endif

355: template<typename Type>
356: static void PackInit_PairType(PetscSFLink link)
357: {
358:   link->d_Pack             = Pack<Type,1,1>;
359:   link->d_UnpackAndInsert  = UnpackAndOp<Type,Insert<Type>,1,1>;
360:   link->d_UnpackAndMaxloc  = UnpackAndOp<Type,Maxloc<Type>,1,1>;
361:   link->d_UnpackAndMinloc  = UnpackAndOp<Type,Minloc<Type>,1,1>;

363:   link->d_ScatterAndInsert = ScatterAndOp<Type,Insert<Type>,1,1>;
364:   link->d_ScatterAndMaxloc = ScatterAndOp<Type,Maxloc<Type>,1,1>;
365:   link->d_ScatterAndMinloc = ScatterAndOp<Type,Minloc<Type>,1,1>;
366:   /* Atomics for pair types are not implemented yet */
367: }

369: template<typename Type,PetscInt BS,PetscInt EQ>
370: static void PackInit_DumbType(PetscSFLink link)
371: {
372:   link->d_Pack             = Pack<Type,BS,EQ>;
373:   link->d_UnpackAndInsert  = UnpackAndOp<Type,Insert<Type>,BS,EQ>;
374:   link->d_ScatterAndInsert = ScatterAndInsert<Type,BS,EQ>;
375:   /* Atomics for dumb types are not implemented yet */
376: }

/*
  Kokkos::DefaultExecutionSpace(stream) is a reference-counted pointer object. It has a bug that prevents
  repeatedly creating and destroying the object. SF's original design gave each SFLink its own stream (NULL or
  not), and hence its own execution space object. The bug prevents us from destroying multiple SFLinks that
  hold a NULL stream and the default execution space object. To avoid memory leaks, SF_Kokkos only supports
  the NULL stream, which is also PETSc's default scheme. SF_Kokkos does not do its own new/delete; it just
  uses Kokkos::DefaultExecutionSpace(), which is a singleton object in Kokkos.
*/
387: /*
388: static PetscErrorCode PetscSFLinkDestroy_Kokkos(PetscSFLink link)
389: {
390:   return 0;
391: }
392: */

394: /* Some device-specific utilities */
395: static PetscErrorCode PetscSFLinkSyncDevice_Kokkos(PetscSFLink PETSC_UNUSED link)
396: {
397:   Kokkos::fence();
398:   return 0;
399: }

401: static PetscErrorCode PetscSFLinkSyncStream_Kokkos(PetscSFLink PETSC_UNUSED link)
402: {
403:   DeviceExecutionSpace    exec;
404:   exec.fence();
405:   return 0;
406: }

408: static PetscErrorCode PetscSFLinkMemcpy_Kokkos(PetscSFLink PETSC_UNUSED link,PetscMemType dstmtype,void* dst,PetscMemType srcmtype,const void*src,size_t n)
409: {
410:   DeviceExecutionSpace    exec;

412:   if (!n) return 0;
413:   if (PetscMemTypeHost(dstmtype) && PetscMemTypeHost(srcmtype)) {
414:     PetscMemcpy(dst,src,n);
415:   } else {
416:     if (PetscMemTypeDevice(dstmtype) && PetscMemTypeHost(srcmtype)) {
417:       deviceBuffer_t       dbuf(static_cast<char*>(dst),n);
418:       HostConstBuffer_t    sbuf(static_cast<const char*>(src),n);
419:       Kokkos::deep_copy(exec,dbuf,sbuf);
420:       PetscLogCpuToGpu(n);
421:     } else if (PetscMemTypeHost(dstmtype) && PetscMemTypeDevice(srcmtype)) {
422:       HostBuffer_t         dbuf(static_cast<char*>(dst),n);
423:       deviceConstBuffer_t  sbuf(static_cast<const char*>(src),n);
424:       Kokkos::deep_copy(exec,dbuf,sbuf);
425:       PetscLogGpuToCpu(n);
426:     } else if (PetscMemTypeDevice(dstmtype) && PetscMemTypeDevice(srcmtype)) {
427:       deviceBuffer_t       dbuf(static_cast<char*>(dst),n);
428:       deviceConstBuffer_t  sbuf(static_cast<const char*>(src),n);
429:       Kokkos::deep_copy(exec,dbuf,sbuf);
430:     }
431:   }
432:   return 0;
433: }

435: PetscErrorCode PetscSFMalloc_Kokkos(PetscMemType mtype,size_t size,void** ptr)
436: {
437:   if (PetscMemTypeHost(mtype)) PetscMalloc(size,ptr);
438:   else if (PetscMemTypeDevice(mtype)) {
439:     if (!PetscKokkosInitialized) PetscKokkosInitializeCheck();
440:     *ptr = Kokkos::kokkos_malloc<DeviceMemorySpace>(size);
441:   } else SETERRQ(PETSC_COMM_SELF,PETSC_ERR_ARG_WRONG,"Wrong PetscMemType %d", (int)mtype);
442:   return 0;
443: }

445: PetscErrorCode PetscSFFree_Kokkos(PetscMemType mtype,void* ptr)
446: {
447:   if (PetscMemTypeHost(mtype)) PetscFree(ptr);
448:   else if (PetscMemTypeDevice(mtype)) {Kokkos::kokkos_free<DeviceMemorySpace>(ptr);}
449:   else SETERRQ(PETSC_COMM_SELF,PETSC_ERR_ARG_WRONG,"Wrong PetscMemType %d",(int)mtype);
450:   return 0;
451: }

453: /* Destructor when the link uses MPI for communication */
454: static PetscErrorCode PetscSFLinkDestroy_Kokkos(PetscSF sf,PetscSFLink link)
455: {
456:   for (int i=PETSCSF_LOCAL; i<=PETSCSF_REMOTE; i++) {
457:     PetscSFFree(sf,PETSC_MEMTYPE_DEVICE,link->rootbuf_alloc[i][PETSC_MEMTYPE_DEVICE]);
458:     PetscSFFree(sf,PETSC_MEMTYPE_DEVICE,link->leafbuf_alloc[i][PETSC_MEMTYPE_DEVICE]);
459:   }
460:   return 0;
461: }

463: /* Some fields of link are initialized by PetscSFPackSetUp_Host. This routine only does what needed on device */
464: PetscErrorCode PetscSFLinkSetUp_Kokkos(PetscSF PETSC_UNUSED sf,PetscSFLink link,MPI_Datatype unit)
465: {
466:   PetscInt           nSignedChar=0,nUnsignedChar=0,nInt=0,nPetscInt=0,nPetscReal=0;
467:   PetscBool          is2Int,is2PetscInt;
468: #if defined(PETSC_HAVE_COMPLEX)
469:   PetscInt           nPetscComplex=0;
470: #endif

472:   if (link->deviceinited) return 0;
473:   PetscKokkosInitializeCheck();
474:   MPIPetsc_Type_compare_contig(unit,MPI_SIGNED_CHAR,  &nSignedChar);
475:   MPIPetsc_Type_compare_contig(unit,MPI_UNSIGNED_CHAR,&nUnsignedChar);
476:   /* MPI_CHAR is treated below as a dumb type that does not support reduction according to MPI standard */
477:   MPIPetsc_Type_compare_contig(unit,MPI_INT,  &nInt);
478:   MPIPetsc_Type_compare_contig(unit,MPIU_INT, &nPetscInt);
479:   MPIPetsc_Type_compare_contig(unit,MPIU_REAL,&nPetscReal);
480: #if defined(PETSC_HAVE_COMPLEX)
481:   MPIPetsc_Type_compare_contig(unit,MPIU_COMPLEX,&nPetscComplex);
482: #endif
483:   MPIPetsc_Type_compare(unit,MPI_2INT,&is2Int);
484:   MPIPetsc_Type_compare(unit,MPIU_2INT,&is2PetscInt);

486:   if (is2Int) {
487:     PackInit_PairType<Kokkos::pair<int,int>>(link);
488:   } else if (is2PetscInt) { /* TODO: when is2PetscInt and nPetscInt=2, we don't know which path to take. The two paths support different ops. */
489:     PackInit_PairType<Kokkos::pair<PetscInt,PetscInt>>(link);
490:   } else if (nPetscReal) {
491:    #if !defined(PETSC_HAVE_DEVICE)  /* Skip the unimportant stuff to speed up SF device compilation time */
492:     if      (nPetscReal == 8) PackInit_RealType<PetscReal,8,1>(link); else if (nPetscReal%8 == 0) PackInit_RealType<PetscReal,8,0>(link);
493:     else if (nPetscReal == 4) PackInit_RealType<PetscReal,4,1>(link); else if (nPetscReal%4 == 0) PackInit_RealType<PetscReal,4,0>(link);
494:     else if (nPetscReal == 2) PackInit_RealType<PetscReal,2,1>(link); else if (nPetscReal%2 == 0) PackInit_RealType<PetscReal,2,0>(link);
495:     else if (nPetscReal == 1) PackInit_RealType<PetscReal,1,1>(link); else if (nPetscReal%1 == 0)
496:    #endif
497:     PackInit_RealType<PetscReal,1,0>(link);
498:   } else if (nPetscInt && sizeof(PetscInt) == sizeof(llint)) {
499:    #if !defined(PETSC_HAVE_DEVICE)
500:     if      (nPetscInt == 8) PackInit_IntegerType<llint,8,1>(link); else if (nPetscInt%8 == 0) PackInit_IntegerType<llint,8,0>(link);
501:     else if (nPetscInt == 4) PackInit_IntegerType<llint,4,1>(link); else if (nPetscInt%4 == 0) PackInit_IntegerType<llint,4,0>(link);
502:     else if (nPetscInt == 2) PackInit_IntegerType<llint,2,1>(link); else if (nPetscInt%2 == 0) PackInit_IntegerType<llint,2,0>(link);
503:     else if (nPetscInt == 1) PackInit_IntegerType<llint,1,1>(link); else if (nPetscInt%1 == 0)
504:    #endif
505:     PackInit_IntegerType<llint,1,0>(link);
506:   } else if (nInt) {
507:    #if !defined(PETSC_HAVE_DEVICE)
508:     if      (nInt == 8) PackInit_IntegerType<int,8,1>(link); else if (nInt%8 == 0) PackInit_IntegerType<int,8,0>(link);
509:     else if (nInt == 4) PackInit_IntegerType<int,4,1>(link); else if (nInt%4 == 0) PackInit_IntegerType<int,4,0>(link);
510:     else if (nInt == 2) PackInit_IntegerType<int,2,1>(link); else if (nInt%2 == 0) PackInit_IntegerType<int,2,0>(link);
511:     else if (nInt == 1) PackInit_IntegerType<int,1,1>(link); else if (nInt%1 == 0)
512:    #endif
513:     PackInit_IntegerType<int,1,0>(link);
514:   } else if (nSignedChar) {
515:    #if !defined(PETSC_HAVE_DEVICE)
516:     if      (nSignedChar == 8) PackInit_IntegerType<char,8,1>(link); else if (nSignedChar%8 == 0) PackInit_IntegerType<char,8,0>(link);
517:     else if (nSignedChar == 4) PackInit_IntegerType<char,4,1>(link); else if (nSignedChar%4 == 0) PackInit_IntegerType<char,4,0>(link);
518:     else if (nSignedChar == 2) PackInit_IntegerType<char,2,1>(link); else if (nSignedChar%2 == 0) PackInit_IntegerType<char,2,0>(link);
519:     else if (nSignedChar == 1) PackInit_IntegerType<char,1,1>(link); else if (nSignedChar%1 == 0)
520:    #endif
521:     PackInit_IntegerType<char,1,0>(link);
522:   }  else if (nUnsignedChar) {
523:    #if !defined(PETSC_HAVE_DEVICE)
524:     if      (nUnsignedChar == 8) PackInit_IntegerType<unsigned char,8,1>(link); else if (nUnsignedChar%8 == 0) PackInit_IntegerType<unsigned char,8,0>(link);
525:     else if (nUnsignedChar == 4) PackInit_IntegerType<unsigned char,4,1>(link); else if (nUnsignedChar%4 == 0) PackInit_IntegerType<unsigned char,4,0>(link);
526:     else if (nUnsignedChar == 2) PackInit_IntegerType<unsigned char,2,1>(link); else if (nUnsignedChar%2 == 0) PackInit_IntegerType<unsigned char,2,0>(link);
527:     else if (nUnsignedChar == 1) PackInit_IntegerType<unsigned char,1,1>(link); else if (nUnsignedChar%1 == 0)
528:    #endif
529:     PackInit_IntegerType<unsigned char,1,0>(link);
530: #if defined(PETSC_HAVE_COMPLEX)
531:   } else if (nPetscComplex) {
532:    #if !defined(PETSC_HAVE_DEVICE)
533:     if      (nPetscComplex == 8) PackInit_ComplexType<Kokkos::complex<PetscReal>,8,1>(link); else if (nPetscComplex%8 == 0) PackInit_ComplexType<Kokkos::complex<PetscReal>,8,0>(link);
534:     else if (nPetscComplex == 4) PackInit_ComplexType<Kokkos::complex<PetscReal>,4,1>(link); else if (nPetscComplex%4 == 0) PackInit_ComplexType<Kokkos::complex<PetscReal>,4,0>(link);
535:     else if (nPetscComplex == 2) PackInit_ComplexType<Kokkos::complex<PetscReal>,2,1>(link); else if (nPetscComplex%2 == 0) PackInit_ComplexType<Kokkos::complex<PetscReal>,2,0>(link);
536:     else if (nPetscComplex == 1) PackInit_ComplexType<Kokkos::complex<PetscReal>,1,1>(link); else if (nPetscComplex%1 == 0)
537:    #endif
538:     PackInit_ComplexType<Kokkos::complex<PetscReal>,1,0>(link);
539: #endif
540:   } else {
541:     MPI_Aint lb,nbyte;
542:     MPI_Type_get_extent(unit,&lb,&nbyte);
544:     if (nbyte % sizeof(int)) { /* If the type size is not multiple of int */
545:      #if !defined(PETSC_HAVE_DEVICE)
546:       if      (nbyte == 4) PackInit_DumbType<char,4,1>(link); else if (nbyte%4 == 0) PackInit_DumbType<char,4,0>(link);
547:       else if (nbyte == 2) PackInit_DumbType<char,2,1>(link); else if (nbyte%2 == 0) PackInit_DumbType<char,2,0>(link);
548:       else if (nbyte == 1) PackInit_DumbType<char,1,1>(link); else if (nbyte%1 == 0)
549:      #endif
550:       PackInit_DumbType<char,1,0>(link);
551:     } else {
552:       nInt = nbyte / sizeof(int);
553:      #if !defined(PETSC_HAVE_DEVICE)
554:       if      (nInt == 8) PackInit_DumbType<int,8,1>(link); else if (nInt%8 == 0) PackInit_DumbType<int,8,0>(link);
555:       else if (nInt == 4) PackInit_DumbType<int,4,1>(link); else if (nInt%4 == 0) PackInit_DumbType<int,4,0>(link);
556:       else if (nInt == 2) PackInit_DumbType<int,2,1>(link); else if (nInt%2 == 0) PackInit_DumbType<int,2,0>(link);
557:       else if (nInt == 1) PackInit_DumbType<int,1,1>(link); else if (nInt%1 == 0)
558:      #endif
559:       PackInit_DumbType<int,1,0>(link);
560:     }
561:   }

563:   link->SyncDevice   = PetscSFLinkSyncDevice_Kokkos;
564:   link->SyncStream   = PetscSFLinkSyncStream_Kokkos;
565:   link->Memcpy       = PetscSFLinkMemcpy_Kokkos;
566:   link->Destroy      = PetscSFLinkDestroy_Kokkos;
567:   link->deviceinited = PETSC_TRUE;
568:   return 0;
569: }