Actual source code: xxt.c

  2: /*************************************xxt.c************************************
  3: Module Name: xxt
  4: Module Info:

  6: author:  Henry M. Tufo III
  7: e-mail:  hmt@asci.uchicago.edu
  8: contact:
  9: +--------------------------------+--------------------------------+
 10: |MCS Division - Building 221     |Department of Computer Science  |
 11: |Argonne National Laboratory     |Ryerson 152                     |
 12: |9700 S. Cass Avenue             |The University of Chicago       |
 13: |Argonne, IL  60439              |Chicago, IL  60637              |
 14: |(630) 252-5354/5986 ph/fx       |(773) 702-6019/8487 ph/fx       |
 15: +--------------------------------+--------------------------------+

 17: Last Modification: 3.20.01
 18: **************************************xxt.c***********************************/


 21: /*************************************xxt.c************************************
 22: NOTES ON USAGE: 

 24: **************************************xxt.c***********************************/
 25:  #include petsc.h
 26: #include <stdio.h>
 27: #include <stdlib.h>
 28: #include <limits.h>
 29: #include <float.h>
 30: #include <math.h>


 33:  #include const.h
 34:  #include types.h
 35:  #include comm.h
 36:  #include error.h
 37:  #include ivec.h
 38: #include "bss_malloc.h"
 39:  #include queue.h
 40:  #include gs.h
 41: #ifdef MLSRC
 42: #include "ml_include.h"
 43: #endif
 44:  #include blas.h
 45:  #include xxt.h

 47: #define LEFT  -1
 48: #define RIGHT  1
 49: #define BOTH   0
 50: #define MAX_FORTRAN_HANDLES  10

 52: typedef struct xxt_solver_info {
 53:   int n, m, n_global, m_global;
 54:   int nnz, max_nnz, msg_buf_sz;
 55:   int *nsep, *lnsep, *fo, nfo, *stages;
 56:   int *col_sz, *col_indices;
 57:   REAL **col_vals, *x, *solve_uu, *solve_w;
 58:   int nsolves;
 59:   REAL tot_solve_time;
 60: } xxt_info;

 62: typedef struct matvec_info {
 63:   int n, m, n_global, m_global;
 64:   int *local2global;
 65:   gs_ADT gs_handle;
 66:   PetscErrorCode (*matvec)(struct matvec_info*,REAL*,REAL*);
 67:   void *grid_data;
 68: } mv_info;

 70: struct xxt_CDT{
 71:   int id;
 72:   int ns;
 73:   int level;
 74:   xxt_info *info;
 75:   mv_info  *mvi;
 76: };

 78: static int n_xxt=0;
 79: static int n_xxt_handles=0;

 81: /* prototypes */
 82: static void do_xxt_solve(xxt_ADT xxt_handle, REAL *rhs);
 83: static void check_init(void);
 84: static void check_handle(xxt_ADT xxt_handle);
 85: static void det_separators(xxt_ADT xxt_handle);
 86: static void do_matvec(mv_info *A, REAL *v, REAL *u);
 87: static int xxt_generate(xxt_ADT xxt_handle);
 88: static int do_xxt_factor(xxt_ADT xxt_handle);
 89: static mv_info *set_mvi(int *local2global, int n, int m, void *matvec, void *grid_data);
 90: #ifdef MLSRC
 91: void ML_XXT_solve(xxt_ADT xxt_handle, int lx, double *x, int lb, double *b);
 92: int  ML_XXT_factor(xxt_ADT xxt_handle, int *local2global, int n, int m,
 93:                    void *matvec, void *grid_data, int grid_tag, ML *my_ml);
 94: #endif


 97: /*************************************xxt.c************************************
 98: Function: XXT_new()

100: Input :
101: Output:
102: Return:
103: Description:
104: **************************************xxt.c***********************************/
105: xxt_ADT 
106: XXT_new(void)
107: {
108:   xxt_ADT xxt_handle;


111: #ifdef DEBUG
112:   error_msg_warning("XXT_new() :: start %d\n",n_xxt_handles);
113: #endif

115:   /* rolling count on n_xxt ... pot. problem here */
116:   n_xxt_handles++;
117:   xxt_handle       = (xxt_ADT)bss_malloc(sizeof(struct xxt_CDT));
118:   xxt_handle->id   = ++n_xxt;
119:   xxt_handle->info = NULL; xxt_handle->mvi  = NULL;

121: #ifdef DEBUG
122:   error_msg_warning("XXT_new() :: end   %d\n",n_xxt_handles);
123: #endif

125:   return(xxt_handle);
126: }


129: /*************************************xxt.c************************************
130: Function: XXT_factor()

132: Input :
133: Output:
134: Return:
135: Description:
136: **************************************xxt.c***********************************/
137: int
138: XXT_factor(xxt_ADT xxt_handle, /* prev. allocated xxt  handle */
139:            int *local2global,  /* global column mapping       */
140:            int n,              /* local num rows              */
141:            int m,              /* local num cols              */
142:            void *matvec,       /* b_loc=A_local.x_loc         */
143:            void *grid_data     /* grid data for matvec        */
144:            )
145: {
146: #ifdef DEBUG
147:   int flag;


150:   error_msg_warning("XXT_factor() :: start %d\n",n_xxt_handles);
151: #endif

153:   check_init();
154:   check_handle(xxt_handle);

156:   /* only 2^k for now and all nodes participating */
157:   if ((1<<(xxt_handle->level=i_log2_num_nodes))!=num_nodes)
158:     {error_msg_fatal("only 2^k for now and MPI_COMM_WORLD!!! %d != %d\n",1<<i_log2_num_nodes,num_nodes);}

160:   /* space for X info */
161:   xxt_handle->info = (xxt_info*)bss_malloc(sizeof(xxt_info));

163:   /* set up matvec handles */
164:   xxt_handle->mvi  = set_mvi(local2global, n, m, matvec, grid_data);

166:   /* matrix is assumed to be of full rank */
167:   /* LATER we can reset to indicate rank def. */
168:   xxt_handle->ns=0;

170:   /* determine separators and generate firing order - NB xxt info set here */
171:   det_separators(xxt_handle);

173: #ifdef DEBUG
174:   flag = do_xxt_factor(xxt_handle);
175:   error_msg_warning("XXT_factor() :: end   %d (flag=%d)\n",n_xxt_handles,flag);
176:   return(flag);
177: #else
178:   return(do_xxt_factor(xxt_handle));
179: #endif
180: }


183: /*************************************xxt.c************************************
184: Function: XXT_solve

186: Input :
187: Output:
188: Return:
189: Description:
190: **************************************xxt.c***********************************/
191: int
192: XXT_solve(xxt_ADT xxt_handle, double *x, double *b)
193: {
194: #ifdef INFO
195:   REAL vals[3], work[3];
196:   int op[] = {NON_UNIFORM,GL_MIN,GL_MAX,GL_ADD};
197: #endif


200: #ifdef DEBUG
201:   error_msg_warning("XXT_solve() :: start %d\n",n_xxt_handles);
202: #endif

204:   check_init();
205:   check_handle(xxt_handle);

207:   /* need to copy b into x? */
208:   if (b)
209:     {rvec_copy(x,b,xxt_handle->mvi->n);}
210:   do_xxt_solve(xxt_handle,x);

212: #ifdef DEBUG
213:   error_msg_warning("XXT_solve() :: end   %d\n",n_xxt_handles);
214: #endif

216:   return(0);
217: }


220: /*************************************xxt.c************************************
221: Function: XXT_free()

223: Input :
224: Output:
225: Return:
226: Description:
227: **************************************xxt.c***********************************/
228: int
229: XXT_free(xxt_ADT xxt_handle)
230: {
231: #ifdef DEBUG
232:   error_msg_warning("XXT_free() :: start %d\n",n_xxt_handles);
233: #endif

235:   check_init();
236:   check_handle(xxt_handle);
237:   n_xxt_handles--;

239:   bss_free(xxt_handle->info->nsep);
240:   bss_free(xxt_handle->info->lnsep);
241:   bss_free(xxt_handle->info->fo);
242:   bss_free(xxt_handle->info->stages);
243:   bss_free(xxt_handle->info->solve_uu);
244:   bss_free(xxt_handle->info->solve_w);
245:   bss_free(xxt_handle->info->x);
246:   bss_free(xxt_handle->info->col_vals);
247:   bss_free(xxt_handle->info->col_sz);
248:   bss_free(xxt_handle->info->col_indices);
249:   bss_free(xxt_handle->info);
250:   bss_free(xxt_handle->mvi->local2global);
251:    gs_free(xxt_handle->mvi->gs_handle);
252:   bss_free(xxt_handle->mvi);
253:   bss_free(xxt_handle);

255: 
256: #ifdef DEBUG
257:   error_msg_warning("perm frees = %d\n",perm_frees());
258:   error_msg_warning("perm calls = %d\n",perm_calls());
259:   error_msg_warning("bss frees  = %d\n",bss_frees());
260:   error_msg_warning("bss calls  = %d\n",bss_calls());
261:   error_msg_warning("XXT_free() :: end   %d\n",n_xxt_handles);
262: #endif

264:   /* if the check fails we nuke */
265:   /* if NULL pointer passed to bss_free we nuke */
266:   /* if the calls to free fail that's not my problem */
267:   return(0);
268: }


271: #ifdef MLSRC
272: /*************************************xxt.c************************************
273: Function: ML_XXT_factor()

275: Input :
276: Output:
277: Return:
278: Description:

280: ML requires that the solver call be checked in
281: **************************************xxt.c***********************************/
282: PetscErrorCode
283: ML_XXT_factor(xxt_ADT xxt_handle,  /* prev. allocated xxt  handle */
284:                 int *local2global, /* global column mapping       */
285:                 int n,             /* local num rows              */
286:                 int m,             /* local num cols              */
287:                 void *matvec,      /* b_loc=A_local.x_loc         */
288:                 void *grid_data,   /* grid data for matvec        */
289:                 int grid_tag,      /* grid tag for ML_Set_CSolve  */
290:                 ML *my_ml          /* ML handle                   */
291:                 )
292: {
293: #ifdef DEBUG
294:   int flag;
295: #endif


298: #ifdef DEBUG
299:   error_msg_warning("ML_XXT_factor() :: start %d\n",n_xxt_handles);
300: #endif

302:   check_init();
303:   check_handle(xxt_handle);
304:   if (my_ml->comm->ML_mypid!=my_id)
305:     {error_msg_fatal("ML_XXT_factor bad my_id %d\t%d\n",
306:                      my_ml->comm->ML_mypid,my_id);}
307:   if (my_ml->comm->ML_nprocs!=num_nodes)
308:     {error_msg_fatal("ML_XXT_factor bad np %d\t%d\n",
309:                      my_ml->comm->ML_nprocs,num_nodes);}

311:   my_ml->SingleLevel[grid_tag].csolve->func->external = ML_XXT_solve;
312:   my_ml->SingleLevel[grid_tag].csolve->func->ML_id = ML_EXTERNAL;
313:   my_ml->SingleLevel[grid_tag].csolve->data = xxt_handle;

315:   /* done ML specific stuff ... back to reg sched pgm */
316: #ifdef DEBUG
317:   flag = XXT_factor(xxt_handle, local2global, n, m, matvec, grid_data);
318:   error_msg_warning("ML_XXT_factor() :: end   %d (flag=%d)\n",n_xxt_handles,flag);
319:   return(flag);
320: #else
321:   return(XXT_factor(xxt_handle, local2global, n, m, matvec, grid_data));
322: #endif
323: }


326: /*************************************xxt.c************************************
327: Function: ML_XXT_solve

329: Input :
330: Output:
331: Return:
332: Description:
333: **************************************xxt.c***********************************/
334: void 
335: ML_XXT_solve(xxt_ADT xxt_handle, int lx, double *sol, int lb, double *rhs)
336: {
337:   XXT_solve(xxt_handle, sol, rhs);
338: }
339: #endif


342: /*************************************xxt.c************************************
343: Function: 

345: Input : 
346: Output: 
347: Return: 
348: Description:  
349: **************************************xxt.c***********************************/
350: int
351: XXT_stats(xxt_ADT xxt_handle)
352: {
353:   int  op[] = {NON_UNIFORM,GL_MIN,GL_MAX,GL_ADD,GL_MIN,GL_MAX,GL_ADD,GL_MIN,GL_MAX,GL_ADD};
354:   int fop[] = {NON_UNIFORM,GL_MIN,GL_MAX,GL_ADD};
355:   int   vals[9],  work[9];
356:   REAL fvals[3], fwork[3];


359: #ifdef DEBUG
360:   error_msg_warning("xxt_stats() :: begin\n");
361: #endif

363:   check_init();
364:   check_handle(xxt_handle);

366:   /* if factorization not done there are no stats */
367:   if (!xxt_handle->info||!xxt_handle->mvi)
368:     {
369:       if (!my_id)
370:         {printf("XXT_stats() :: no stats available!\n");}
371:       return 1;
372:     }

374:   vals[0]=vals[1]=vals[2]=xxt_handle->info->nnz;
375:   vals[3]=vals[4]=vals[5]=xxt_handle->mvi->n;
376:   vals[6]=vals[7]=vals[8]=xxt_handle->info->msg_buf_sz;
377:   giop(vals,work,sizeof(op)/sizeof(op[0])-1,op);

379:   fvals[0]=fvals[1]=fvals[2]
380:     =xxt_handle->info->tot_solve_time/xxt_handle->info->nsolves++;
381:   grop(fvals,fwork,sizeof(fop)/sizeof(fop[0])-1,fop);

383:   if (!my_id)
384:     {
385:       printf("%d :: min   xxt_nnz=%d\n",my_id,vals[0]);
386:       printf("%d :: max   xxt_nnz=%d\n",my_id,vals[1]);
387:       printf("%d :: avg   xxt_nnz=%g\n",my_id,1.0*vals[2]/num_nodes);
388:       printf("%d :: tot   xxt_nnz=%d\n",my_id,vals[2]);
389:       printf("%d :: xxt   C(2d)  =%g\n",my_id,vals[2]/(pow(1.0*vals[5],1.5)));
390:       printf("%d :: xxt   C(3d)  =%g\n",my_id,vals[2]/(pow(1.0*vals[5],1.6667)));
391:       printf("%d :: min   xxt_n  =%d\n",my_id,vals[3]);
392:       printf("%d :: max   xxt_n  =%d\n",my_id,vals[4]);
393:       printf("%d :: avg   xxt_n  =%g\n",my_id,1.0*vals[5]/num_nodes);
394:       printf("%d :: tot   xxt_n  =%d\n",my_id,vals[5]);
395:       printf("%d :: min   xxt_buf=%d\n",my_id,vals[6]);
396:       printf("%d :: max   xxt_buf=%d\n",my_id,vals[7]);
397:       printf("%d :: avg   xxt_buf=%g\n",my_id,1.0*vals[8]/num_nodes);
398:       printf("%d :: min   xxt_slv=%g\n",my_id,fvals[0]);
399:       printf("%d :: max   xxt_slv=%g\n",my_id,fvals[1]);
400:       printf("%d :: avg   xxt_slv=%g\n",my_id,fvals[2]/num_nodes);
401:     }

403: #ifdef DEBUG
404:   error_msg_warning("xxt_stats() :: end\n");
405: #endif

407:   return(0);
408: }


411: /*************************************xxt.c************************************
412: Function: do_xxt_factor

414: Input : 
415: Output: 
416: Return: 
417: Description: get A_local, local portion of global coarse matrix which 
418: is a row dist. nxm matrix w/ n<m.
419:    o my_ml holds address of ML struct associated w/A_local and coarse grid
420:    o local2global holds global number of column i (i=0,...,m-1)
421:    o local2global holds global number of row    i (i=0,...,n-1)
422:    o mylocmatvec performs A_local . vec_local (note that gs is performed using 
423:    gs_init/gop).

425: mylocmatvec = my_ml->Amat[grid_tag].matvec->external;
426: mylocmatvec (void :: void *data, double *in, double *out)
427: **************************************xxt.c***********************************/
428: static
429: int
430: do_xxt_factor(xxt_ADT xxt_handle)
431: {
432:   int flag;


435: #ifdef DEBUG
436:   error_msg_warning("do_xxt_factor() :: begin\n");
437: #endif

439:   flag=xxt_generate(xxt_handle);

441: #ifdef INFO
442:   XXT_stats(xxt_handle);
443:   bss_stats();
444:   perm_stats();
445: #endif

447: #ifdef DEBUG
448:   error_msg_warning("do_xxt_factor() :: end\n");
449: #endif

451:   return(flag);
452: }


455: /*************************************xxt.c************************************
456: Function: 

458: Input : 
459: Output: 
460: Return: 
461: Description:  
462: **************************************xxt.c***********************************/
463: static
464: int
465: xxt_generate(xxt_ADT xxt_handle)
466: {
467:   int i,j,k,idex;
468:   int dim, col;
469:   REAL *u, *uu, *v, *z, *w, alpha, alpha_w;
470:   int *segs;
471:   int op[] = {GL_ADD,0};
472:   int off, len;
473:   REAL *x_ptr;
474:   int *iptr, flag;
475:   int start=0, end, work;
476:   int op2[] = {GL_MIN,0};
477:   gs_ADT gs_handle;
478:   int *nsep, *lnsep, *fo;
479:   int a_n=xxt_handle->mvi->n;
480:   int a_m=xxt_handle->mvi->m;
481:   int *a_local2global=xxt_handle->mvi->local2global;
482:   int level;
483:   int xxt_nnz=0, xxt_max_nnz=0;
484:   int n, m;
485:   int *col_sz, *col_indices, *stages;
486:   REAL **col_vals, *x;
487:   int n_global;
488:   int xxt_zero_nnz=0;
489:   int xxt_zero_nnz_0=0;


492: #ifdef DEBUG
493:   error_msg_warning("xxt_generate() :: begin\n");
494: #endif

496:   n=xxt_handle->mvi->n;
497:   nsep=xxt_handle->info->nsep;
498:   lnsep=xxt_handle->info->lnsep;
499:   fo=xxt_handle->info->fo;
500:   end=lnsep[0];
501:   level=xxt_handle->level;
502:   gs_handle=xxt_handle->mvi->gs_handle;

504:   /* is there a null space? */
505:   /* LATER add in ability to detect null space by checking alpha */
506:   for (i=0, j=0; i<=level; i++)
507:     {j+=nsep[i];}

509:   m = j-xxt_handle->ns;
510:   if (m!=j)
511:     {printf("xxt_generate() :: null space exists %d %d %d\n",m,j,xxt_handle->ns);}

513:   /* get and initialize storage for x local         */
514:   /* note that x local is nxm and stored by columns */
515:   col_sz = (int*) bss_malloc(m*INT_LEN);
516:   col_indices = (int*) bss_malloc((2*m+1)*sizeof(int));
517:   col_vals = (REAL **) bss_malloc(m*sizeof(REAL *));
518:   for (i=j=0; i<m; i++, j+=2)
519:     {
520:       col_indices[j]=col_indices[j+1]=col_sz[i]=-1;
521:       col_vals[i] = NULL;
522:     }
523:   col_indices[j]=-1;

525:   /* size of separators for each sub-hc working from bottom of tree to top */
526:   /* this looks like nsep[]=segments */
527:   stages = (int*) bss_malloc((level+1)*INT_LEN);
528:   segs   = (int*) bss_malloc((level+1)*INT_LEN);
529:   ivec_zero(stages,level+1);
530:   ivec_copy(segs,nsep,level+1);
531:   for (i=0; i<level; i++)
532:     {segs[i+1] += segs[i];}
533:   stages[0] = segs[0];

535:   /* temporary vectors  */
536:   u  = (REAL *) bss_malloc(n*sizeof(REAL));
537:   z  = (REAL *) bss_malloc(n*sizeof(REAL));
538:   v  = (REAL *) bss_malloc(a_m*sizeof(REAL));
539:   uu = (REAL *) bss_malloc(m*sizeof(REAL));
540:   w  = (REAL *) bss_malloc(m*sizeof(REAL));

542:   /* extra nnz due to replication of vertices across separators */
543:   for (i=1, j=0; i<=level; i++)
544:     {j+=nsep[i];}

546:   /* storage for sparse x values */
547:   n_global = xxt_handle->info->n_global;
548:   xxt_max_nnz = (int)(2.5*pow(1.0*n_global,1.6667) + j*n/2)/num_nodes;
549:   x = (REAL *) bss_malloc(xxt_max_nnz*sizeof(REAL));
550:   xxt_nnz = 0;

552:   /* LATER - can embed next sep to fire in gs */
553:   /* time to make the donuts - generate X factor */
554:   for (dim=i=j=0;i<m;i++)
555:     {
556:       /* time to move to the next level? */
557:       while (i==segs[dim])
558:         {
559: #ifdef SAFE          
560:           if (dim==level)
561:             {error_msg_fatal("dim about to exceed level\n"); break;}
562: #endif

564:           stages[dim++]=i;
565:           end+=lnsep[dim];
566:         }
567:       stages[dim]=i;

569:       /* which column are we firing? */
570:       /* i.e. set v_l */
571:       /* use new seps and do global min across hc to determine which one to fire */
572:       (start<end) ? (col=fo[start]) : (col=INT_MAX);
573:       giop_hc(&col,&work,1,op2,dim);

575:       /* shouldn't need this */
576:       if (col==INT_MAX)
577:         {
578:           error_msg_warning("hey ... col==INT_MAX??\n");
579:           continue;
580:         }

582:       /* do I own it? I should */
583:       rvec_zero(v ,a_m);
584:       if (col==fo[start])
585:         {
586:           start++;
587:           idex=ivec_linear_search(col, a_local2global, a_n);
588:           if (idex!=-1)
589:             {v[idex] = 1.0; j++;}
590:           else
591:             {error_msg_fatal("NOT FOUND!\n");}
592:         }
593:       else
594:         {
595:           idex=ivec_linear_search(col, a_local2global, a_m);
596:           if (idex!=-1)
597:             {v[idex] = 1.0;}
598:         }

600:       /* perform u = A.v_l */
601:       rvec_zero(u,n);
602:       do_matvec(xxt_handle->mvi,v,u);

604:       /* uu =  X^T.u_l (local portion) */
605:       /* technically only need to zero out first i entries */
606:       /* later turn this into an XXT_solve call ? */
607:       rvec_zero(uu,m);
608:       x_ptr=x;
609:       iptr = col_indices;
610:       for (k=0; k<i; k++)
611:         {
612:           off = *iptr++;
613:           len = *iptr++;

615: #if   BLAS||CBLAS
616:           uu[k] = dot(len,u+off,1,x_ptr,1);
617: #else
618:           uu[k] = rvec_dot(u+off,x_ptr,len);
619: #endif
620:           x_ptr+=len;
621:         }


624:       /* uu = X^T.u_l (comm portion) */
625:       ssgl_radd  (uu, w, dim, stages);

627:       /* z = X.uu */
628:       rvec_zero(z,n);
629:       x_ptr=x;
630:       iptr = col_indices;
631:       for (k=0; k<i; k++)
632:         {
633:           off = *iptr++;
634:           len = *iptr++;

636: #if   BLAS||CBLAS
637:           axpy(len,uu[k],x_ptr,1,z+off,1);
638: #else
639:           rvec_axpy(z+off,x_ptr,uu[k],len);
640: #endif
641:           x_ptr+=len;
642:         }

644:       /* compute v_l = v_l - z */
645:       rvec_zero(v+a_n,a_m-a_n);
646: #if   BLAS&&CBLAS
647:       axpy(n,-1.0,z,1,v,1);
648: #else
649:       rvec_axpy(v,z,-1.0,n);
650: #endif

652:       /* compute u_l = A.v_l */
653:       if (a_n!=a_m)
654:         {gs_gop_hc(gs_handle,v,"+\0",dim);}
655:       rvec_zero(u,n);
656:       do_matvec(xxt_handle->mvi,v,u);

658:       /* compute sqrt(alpha) = sqrt(v_l^T.u_l) - local portion */
659: #if   BLAS&&CBLAS
660:       alpha = dot(n,u,1,v,1);
661: #else
662:       alpha = rvec_dot(u,v,n);
663: #endif
664:       /* compute sqrt(alpha) = sqrt(v_l^T.u_l) - comm portion */
665:       grop_hc(&alpha, &alpha_w, 1, op, dim);

667:       alpha = (REAL) sqrt((double)alpha);

669:       /* check for small alpha                             */
670:       /* LATER use this to detect and determine null space */
671: #ifdef tmpr8
672:       if (fabs(alpha)<1.0e-14)
673:         {error_msg_fatal("bad alpha! %g\n",alpha);}
674: #else
675:       if (fabs((double) alpha) < 1.0e-6)
676:         {error_msg_fatal("bad alpha! %g\n",alpha);}
677: #endif

679:       /* compute v_l = v_l/sqrt(alpha) */
680:       rvec_scale(v,1.0/alpha,n);

682:       /* add newly generated column, v_l, to X */
683:       flag = 1;
684:       off=len=0;
685:       for (k=0; k<n; k++)
686:         {
687:           if (v[k]!=0.0)
688:             {
689:               len=k;
690:               if (flag)
691:                 {off=k; flag=0;}
692:             }
693:         }

695:       len -= (off-1);

697:       if (len>0)
698:         {
699:           if ((xxt_nnz+len)>xxt_max_nnz)
700:             {
701:               error_msg_warning("increasing space for X by 2x!\n");
702:               xxt_max_nnz *= 2;
703:               x_ptr = (REAL *) bss_malloc(xxt_max_nnz*sizeof(REAL));
704:               rvec_copy(x_ptr,x,xxt_nnz);
705:               bss_free(x);
706:               x = x_ptr;
707:               x_ptr+=xxt_nnz;
708:             }
709:           xxt_nnz += len;
710:           rvec_copy(x_ptr,v+off,len);

712:           /* keep track of number of zeros */
713:           if (dim)
714:             {
715:               for (k=0; k<len; k++)
716:                 {
717:                   if (x_ptr[k]==0.0)
718:                     {xxt_zero_nnz++;}
719:                 }
720:             }
721:           else
722:             {
723:               for (k=0; k<len; k++)
724:                 {
725:                   if (x_ptr[k]==0.0)
726:                     {xxt_zero_nnz_0++;}
727:                 }
728:             }
729:           col_indices[2*i] = off;
730:           col_sz[i] = col_indices[2*i+1] = len;
731:           col_vals[i] = x_ptr;
732:         }
733:       else
734:         {
735:           col_indices[2*i] = 0;
736:           col_sz[i] = col_indices[2*i+1] = 0;
737:           col_vals[i] = x_ptr;
738:         }
739:     }

741:   /* close off stages for execution phase */
742:   while (dim!=level)
743:     {
744:       stages[dim++]=i;
745:       error_msg_warning("disconnected!!! dim(%d)!=level(%d)\n",dim,level);
746:     }
747:   stages[dim]=i;

749:   xxt_handle->info->n=xxt_handle->mvi->n;
750:   xxt_handle->info->m=m;
751:   xxt_handle->info->nnz=xxt_nnz;
752:   xxt_handle->info->max_nnz=xxt_max_nnz;
753:   xxt_handle->info->msg_buf_sz=stages[level]-stages[0];
754:   xxt_handle->info->solve_uu = (REAL *) bss_malloc(m*sizeof(REAL));
755:   xxt_handle->info->solve_w  = (REAL *) bss_malloc(m*sizeof(REAL));
756:   xxt_handle->info->x=x;
757:   xxt_handle->info->col_vals=col_vals;
758:   xxt_handle->info->col_sz=col_sz;
759:   xxt_handle->info->col_indices=col_indices;
760:   xxt_handle->info->stages=stages;
761:   xxt_handle->info->nsolves=0;
762:   xxt_handle->info->tot_solve_time=0.0;

764:   bss_free(segs);
765:   bss_free(u);
766:   bss_free(v);
767:   bss_free(uu);
768:   bss_free(z);
769:   bss_free(w);

771: #ifdef DEBUG
772:   error_msg_warning("xxt_generate() :: end\n");
773: #endif

775:   return(0);
776: }


779: /*************************************xxt.c************************************
780: Function: 

782: Input : 
783: Output: 
784: Return: 
785: Description:  
786: **************************************xxt.c***********************************/
787: static
788: void
789: do_xxt_solve(xxt_ADT xxt_handle, register REAL *uc)
790: {
791:   register int off, len, *iptr;
792:   int level       =xxt_handle->level;
793:   int n           =xxt_handle->info->n;
794:   int m           =xxt_handle->info->m;
795:   int *stages     =xxt_handle->info->stages;
796:   int *col_indices=xxt_handle->info->col_indices;
797:   register REAL *x_ptr, *uu_ptr;
798: #if   BLAS||CBLAS
799:   REAL zero=0.0;
800: #endif
801:   REAL *solve_uu=xxt_handle->info->solve_uu;
802:   REAL *solve_w =xxt_handle->info->solve_w;
803:   REAL *x       =xxt_handle->info->x;

805: #ifdef DEBUG
806:   error_msg_warning("do_xxt_solve() :: begin\n");
807: #endif

809:   uu_ptr=solve_uu;
810: #if   BLAS||CBLAS
811:   copy(m,&zero,0,uu_ptr,1);
812: #else
813:   rvec_zero(uu_ptr,m);
814: #endif

816:   /* x  = X.Y^T.b */
817:   /* uu = Y^T.b */
818:   for (x_ptr=x,iptr=col_indices; *iptr!=-1; x_ptr+=len)
819:     {
820:       off=*iptr++; len=*iptr++;
821: #if   BLAS||CBLAS
822:       *uu_ptr++ = dot(len,uc+off,1,x_ptr,1);
823: #else
824:       *uu_ptr++ = rvec_dot(uc+off,x_ptr,len);
825: #endif
826:     }

828:   /* comunication of beta */
829:   uu_ptr=solve_uu;
830:   if (level) {ssgl_radd(uu_ptr, solve_w, level, stages);}

832: #if   BLAS||CBLAS
833:   copy(n,&zero,0,uc,1);
834: #else
835:   rvec_zero(uc,n);
836: #endif

838:   /* x = X.uu */
839:   for (x_ptr=x,iptr=col_indices; *iptr!=-1; x_ptr+=len)
840:     {
841:       off=*iptr++; len=*iptr++;
842: #if   BLAS||CBLAS
843:       axpy(len,*uu_ptr++,x_ptr,1,uc+off,1);
844: #else
845:       rvec_axpy(uc+off,x_ptr,*uu_ptr++,len);
846: #endif
847:     }

849: #ifdef DEBUG
850:   error_msg_warning("do_xxt_solve() :: end\n");
851: #endif
852: }


/*************************************xxt.c************************************
Function: check_init

Input : none
Output: none
Return: none
Description: called at the top of every public entry point; it invokes
comm_init(). The perm/bss allocator initialisation calls are deliberately left
disabled below.
**************************************xxt.c***********************************/
static
void
check_init(void)
{
#ifdef DEBUG
  error_msg_warning("check_init() :: start %d\n",n_xxt_handles);
#endif

  comm_init();

  /* not called here:
       perm_init();
       bss_init();
  */

#ifdef DEBUG
  error_msg_warning("check_init() :: end   %d\n",n_xxt_handles);
#endif
}


883: /*************************************xxt.c************************************
884: Function: check_handle()

886: Input :
887: Output:
888: Return:
889: Description:
890: **************************************xxt.c***********************************/
891: static
892: void 
893: check_handle(xxt_ADT xxt_handle)
894: {
895: #ifdef SAFE
896:   int vals[2], work[2], op[] = {NON_UNIFORM,GL_MIN,GL_MAX};
897: #endif


900: #ifdef DEBUG
901:   error_msg_warning("check_handle() :: start %d\n",n_xxt_handles);
902: #endif

904:   if (xxt_handle==NULL)
905:     {error_msg_fatal("check_handle() :: bad handle :: NULL %d\n",xxt_handle);}

907: #ifdef SAFE
908:   vals[0]=vals[1]=xxt_handle->id;
909:   giop(vals,work,sizeof(op)/sizeof(op[0])-1,op);
910:   if ((vals[0]!=vals[1])||(xxt_handle->id<=0))
911:     {error_msg_fatal("check_handle() :: bad handle :: id mismatch min/max %d/%d %d\n",
912:                      vals[0],vals[1], xxt_handle->id);}
913: #endif

915: #ifdef DEBUG
916:   error_msg_warning("check_handle() :: end   %d\n",n_xxt_handles);
917: #endif
918: }


921: /*************************************xxt.c************************************
922: Function: det_separators

924: Input :
925: Output:
926: Return:
927: Description:
928:   det_separators(xxt_handle, local2global, n, m, mylocmatvec, grid_data);
929: **************************************xxt.c***********************************/
930: static 
931: void 
932: det_separators(xxt_ADT xxt_handle)
933: {
934:   int i, ct, id;
935:   int mask, edge, *iptr;
936:   int *dir, *used;
937:   int sum[4], w[4];
938:   REAL rsum[4], rw[4];
939:   int op[] = {GL_ADD,0};
940:   REAL *lhs, *rhs;
941:   int *nsep, *lnsep, *fo, nfo=0;
942:   gs_ADT gs_handle=xxt_handle->mvi->gs_handle;
943:   int *local2global=xxt_handle->mvi->local2global;
944:   int  n=xxt_handle->mvi->n;
945:   int  m=xxt_handle->mvi->m;
946:   int level=xxt_handle->level;
947:   int shared=FALSE;

949: #ifdef DEBUG
950:   error_msg_warning("det_separators() :: start %d %d %d\n",level,n,m);
951: #endif
952: 
953:   dir  = (int*)bss_malloc(INT_LEN*(level+1));
954:   nsep = (int*)bss_malloc(INT_LEN*(level+1));
955:   lnsep= (int*)bss_malloc(INT_LEN*(level+1));
956:   fo   = (int*)bss_malloc(INT_LEN*(n+1));
957:   used = (int*)bss_malloc(INT_LEN*n);

959:   ivec_zero(dir  ,level+1);
960:   ivec_zero(nsep ,level+1);
961:   ivec_zero(lnsep,level+1);
962:   ivec_set (fo   ,-1,n+1);
963:   ivec_zero(used,n);

965:   lhs  = (double*)bss_malloc(REAL_LEN*m);
966:   rhs  = (double*)bss_malloc(REAL_LEN*m);

968:   /* determine the # of unique dof */
969:   rvec_zero(lhs,m);
970:   rvec_set(lhs,1.0,n);
971:   gs_gop_hc(gs_handle,lhs,"+\0",level);
972:   rvec_zero(rsum,2);
973:   for (ct=i=0;i<n;i++)
974:     {
975:       if (lhs[i]!=0.0)
976:         {rsum[0]+=1.0/lhs[i]; rsum[1]+=lhs[i];}
977:     }
978:   grop_hc(rsum,rw,2,op,level);
979:   rsum[0]+=0.1;
980:   rsum[1]+=0.1;
981:   /*  if (!my_id)
982:     {
983:       printf("xxt n unique = %d (%g)\n",(int) rsum[0], rsum[0]);
984:       printf("xxt n shared = %d (%g)\n",(int) rsum[1], rsum[1]);
985:       }*/

987:   if (fabs(rsum[0]-rsum[1])>EPS)
988:     {shared=TRUE;}

990:   xxt_handle->info->n_global=xxt_handle->info->m_global=(int) rsum[0];
991:   xxt_handle->mvi->n_global =xxt_handle->mvi->m_global =(int) rsum[0];

993:   /* determine separator sets top down */
994:   if (shared)
995:     {
996:       for (iptr=fo+n,id=my_id,mask=num_nodes>>1,edge=level;edge>0;edge--,mask>>=1)
997:         {
998:           /* set rsh of hc, fire, and collect lhs responses */
999:           (id<mask) ? rvec_zero(lhs,m) : rvec_set(lhs,1.0,m);
1000:           gs_gop_hc(gs_handle,lhs,"+\0",edge);
1001: 
1002:           /* set lsh of hc, fire, and collect rhs responses */
1003:           (id<mask) ? rvec_set(rhs,1.0,m) : rvec_zero(rhs,m);
1004:           gs_gop_hc(gs_handle,rhs,"+\0",edge);
1005: 
1006:           for (i=0;i<n;i++)
1007:             {
1008:               if (id< mask)
1009:                 {
1010:                   if (lhs[i]!=0.0)
1011:                     {lhs[i]=1.0;}
1012:                 }
1013:               if (id>=mask)
1014:                 {
1015:                   if (rhs[i]!=0.0)
1016:                     {rhs[i]=1.0;}
1017:                 }
1018:             }

1020:           if (id< mask)
1021:             {gs_gop_hc(gs_handle,lhs,"+\0",edge-1);}
1022:           else
1023:             {gs_gop_hc(gs_handle,rhs,"+\0",edge-1);}

1025:           /* count number of dofs I own that have signal and not in sep set */
1026:           rvec_zero(rsum,4);
1027:           for (ivec_zero(sum,4),ct=i=0;i<n;i++)
1028:             {
1029:               if (!used[i])
1030:                 {
1031:                   /* number of unmarked dofs on node */
1032:                   ct++;
1033:                   /* number of dofs to be marked on lhs hc */
1034:                   if (id< mask)
1035:                     {
1036:                       if (lhs[i]!=0.0)
1037:                         {sum[0]++; rsum[0]+=1.0/lhs[i];}
1038:                     }
1039:                   /* number of dofs to be marked on rhs hc */
1040:                   if (id>=mask)
1041:                     {
1042:                       if (rhs[i]!=0.0)
1043:                         {sum[1]++; rsum[1]+=1.0/rhs[i];}
1044:                     }
1045:                 }
1046:             }

1048:           /* go for load balance - choose half with most unmarked dofs, bias LHS */
1049:           (id<mask) ? (sum[2]=ct) : (sum[3]=ct);
1050:           (id<mask) ? (rsum[2]=ct) : (rsum[3]=ct);
1051:           giop_hc(sum,w,4,op,edge);
1052:           grop_hc(rsum,rw,4,op,edge);
1053:           rsum[0]+=0.1; rsum[1]+=0.1; rsum[2]+=0.1; rsum[3]+=0.1;

1055:           if (id<mask)
1056:             {
1057:               /* mark dofs I own that have signal and not in sep set */
1058:               for (ct=i=0;i<n;i++)
1059:                 {
1060:                   if ((!used[i])&&(lhs[i]!=0.0))
1061:                     {
1062:                       ct++; nfo++;

1064:                       if (nfo>n)
1065:                         {error_msg_fatal("nfo about to exceed n\n");}

1067:                       *--iptr = local2global[i];
1068:                       used[i]=edge;
1069:                     }
1070:                 }
1071:               if (ct>1) {ivec_sort(iptr,ct);}

1073:               lnsep[edge]=ct;
1074:               nsep[edge]=(int) rsum[0];
1075:               dir [edge]=LEFT;
1076:             }

1078:           if (id>=mask)
1079:             {
1080:               /* mark dofs I own that have signal and not in sep set */
1081:               for (ct=i=0;i<n;i++)
1082:                 {
1083:                   if ((!used[i])&&(rhs[i]!=0.0))
1084:                     {
1085:                       ct++; nfo++;

1087:                       if (nfo>n)
1088:                         {error_msg_fatal("nfo about to exceed n\n");}

1090:                       *--iptr = local2global[i];
1091:                       used[i]=edge;
1092:                     }
1093:                 }
1094:               if (ct>1) {ivec_sort(iptr,ct);}

1096:               lnsep[edge]=ct;
1097:               nsep[edge]= (int) rsum[1];
1098:               dir [edge]=RIGHT;
1099:             }

1101:           /* LATER or we can recur on these to order seps at this level */
1102:           /* do we need full set of separators for this?                */

1104:           /* fold rhs hc into lower */
1105:           if (id>=mask)
1106:             {id-=mask;}
1107:         }
1108:     }
1109:   else
1110:     {
1111:       for (iptr=fo+n,id=my_id,mask=num_nodes>>1,edge=level;edge>0;edge--,mask>>=1)
1112:         {
1113:           /* set rhs of hc, fire, and collect lhs responses */
1114:           (id<mask) ? rvec_zero(lhs,m) : rvec_set(lhs,1.0,m);
1115:           gs_gop_hc(gs_handle,lhs,"+\0",edge);

1117:           /* set lhs of hc, fire, and collect rhs responses */
1118:           (id<mask) ? rvec_set(rhs,1.0,m) : rvec_zero(rhs,m);
1119:           gs_gop_hc(gs_handle,rhs,"+\0",edge);

1121:           /* count number of dofs I own that have signal and not in sep set */
1122:           for (ivec_zero(sum,4),ct=i=0;i<n;i++)
1123:             {
1124:               if (!used[i])
1125:                 {
1126:                   /* number of unmarked dofs on node */
1127:                   ct++;
1128:                   /* number of dofs to be marked on lhs hc */
1129:                   if ((id< mask)&&(lhs[i]!=0.0)) {sum[0]++;}
1130:                   /* number of dofs to be marked on rhs hc */
1131:                   if ((id>=mask)&&(rhs[i]!=0.0)) {sum[1]++;}
1132:                 }
1133:             }

1135:           /* go for load balance - choose half with most unmarked dofs, bias LHS */
1136:           (id<mask) ? (sum[2]=ct) : (sum[3]=ct);
1137:           giop_hc(sum,w,4,op,edge);

1139:           /* lhs hc wins */
1140:           if (sum[2]>=sum[3])
1141:             {
1142:               if (id<mask)
1143:                 {
1144:                   /* mark dofs I own that have signal and not in sep set */
1145:                   for (ct=i=0;i<n;i++)
1146:                     {
1147:                       if ((!used[i])&&(lhs[i]!=0.0))
1148:                         {
1149:                           ct++; nfo++;
1150:                           *--iptr = local2global[i];
1151:                           used[i]=edge;
1152:                         }
1153:                     }
1154:                   if (ct>1) {ivec_sort(iptr,ct);}
1155:                   lnsep[edge]=ct;
1156:                 }
1157:               nsep[edge]=sum[0];
1158:               dir [edge]=LEFT;
1159:             }
1160:           /* rhs hc wins */
1161:           else
1162:             {
1163:               if (id>=mask)
1164:                 {
1165:                   /* mark dofs I own that have signal and not in sep set */
1166:                   for (ct=i=0;i<n;i++)
1167:                     {
1168:                       if ((!used[i])&&(rhs[i]!=0.0))
1169:                         {
1170:                           ct++; nfo++;
1171:                           *--iptr = local2global[i];
1172:                           used[i]=edge;
1173:                         }
1174:                     }
1175:                   if (ct>1) {ivec_sort(iptr,ct);}
1176:                   lnsep[edge]=ct;
1177:                 }
1178:               nsep[edge]=sum[1];
1179:               dir [edge]=RIGHT;
1180:             }
1181:           /* LATER or we can recur on these to order seps at this level */
1182:           /* do we need full set of separators for this?                */

1184:           /* fold rhs hc into lower */
1185:           if (id>=mask)
1186:             {id-=mask;}
1187:         }
1188:     }


1191:   /* level 0 is on processor case - so mark the remainder */
1192:   for (ct=i=0;i<n;i++)
1193:     {
1194:       if (!used[i])
1195:         {
1196:           ct++; nfo++;
1197:           *--iptr = local2global[i];
1198:           used[i]=edge;
1199:         }
1200:     }
1201:   if (ct>1) {ivec_sort(iptr,ct);}
1202:   lnsep[edge]=ct;
1203:   nsep [edge]=ct;
1204:   dir  [edge]=LEFT;

1206:   xxt_handle->info->nsep=nsep;
1207:   xxt_handle->info->lnsep=lnsep;
1208:   xxt_handle->info->fo=fo;
1209:   xxt_handle->info->nfo=nfo;

1211:   bss_free(dir);
1212:   bss_free(lhs);
1213:   bss_free(rhs);
1214:   bss_free(used);

1216: #ifdef DEBUG  
1217:   error_msg_warning("det_separators() :: end\n");
1218: #endif
1219: }


1222: /*************************************xxt.c************************************
1223: Function: set_mvi

1225: Input :
1226: Output:
1227: Return:
1228: Description:
1229: **************************************xxt.c***********************************/
1230: static
1231: mv_info *set_mvi(int *local2global, int n, int m, void *matvec, void *grid_data)
1232: {
1233:   mv_info *mvi;


1236: #ifdef DEBUG
1237:   error_msg_warning("set_mvi() :: start\n");
1238: #endif

1240:   mvi = (mv_info*)bss_malloc(sizeof(mv_info));
1241:   mvi->n=n;
1242:   mvi->m=m;
1243:   mvi->n_global=-1;
1244:   mvi->m_global=-1;
1245:   mvi->local2global=(int*)bss_malloc((m+1)*INT_LEN);
1246:   ivec_copy(mvi->local2global,local2global,m);
1247:   mvi->local2global[m] = INT_MAX;
1248:   mvi->matvec=(PetscErrorCode (*)(mv_info*,REAL*,REAL*))matvec;
1249:   mvi->grid_data=grid_data;

1251:   /* set xxt communication handle to perform restricted matvec */
1252:   mvi->gs_handle = gs_init(local2global, m, num_nodes);

1254: #ifdef DEBUG
1255:   error_msg_warning("set_mvi() :: end   \n");
1256: #endif
1257: 
1258:   return(mvi);
1259: }


1262: /*************************************xxt.c************************************
1263: Function: set_mvi

1265: Input :
1266: Output:
1267: Return:
1268: Description:

1270:       computes u = A.v 
1271:       do_matvec(xxt_handle->mvi,v,u);
1272: **************************************xxt.c***********************************/
1273: static
1274: void do_matvec(mv_info *A, REAL *v, REAL *u)
1275: {
1276:   A->matvec((mv_info*)A->grid_data,v,u);
1277: }