Actual source code: xyt.c
2: /*************************************xyt.c************************************
3: Module Name: xyt
4: Module Info:
6: author: Henry M. Tufo III
7: e-mail: hmt@asci.uchicago.edu
8: contact:
9: +--------------------------------+--------------------------------+
10: |MCS Division - Building 221 |Department of Computer Science |
11: |Argonne National Laboratory |Ryerson 152 |
12: |9700 S. Cass Avenue |The University of Chicago |
13: |Argonne, IL 60439 |Chicago, IL 60637 |
14: |(630) 252-5354/5986 ph/fx |(773) 702-6019/8487 ph/fx |
15: +--------------------------------+--------------------------------+
17: Last Modification: 3.20.01
18: **************************************xyt.c***********************************/
21: /*************************************xyt.c************************************
22: NOTES ON USAGE:
24: **************************************xyt.c***********************************/
25: #include <stdio.h>
26: #include <stdlib.h>
27: #include <limits.h>
28: #include <float.h>
29: #include <math.h>
31: #include petsc.h
33: #include const.h
34: #include types.h
35: #include comm.h
36: #include error.h
37: #include ivec.h
38: #include "bss_malloc.h"
39: #include queue.h
40: #include gs.h
41: #ifdef MLSRC
42: #include "ml_include.h"
43: #endif
44: #include blas.h
45: #include xyt.h
/* tags recorded per hypercube level (det_separators() stores LEFT/RIGHT in dir[]) */
#define LEFT   (-1)
#define RIGHT  (1)
#define BOTH   (0)
/* NOTE(review): not referenced in the visible part of this file */
#define MAX_FORTRAN_HANDLES  (10)
52: typedef struct xyt_solver_info {
53: int n, m, n_global, m_global;
54: int nnz, max_nnz, msg_buf_sz;
55: int *nsep, *lnsep, *fo, nfo, *stages;
56: int *xcol_sz, *xcol_indices;
57: REAL **xcol_vals, *x, *solve_uu, *solve_w;
58: int *ycol_sz, *ycol_indices;
59: REAL **ycol_vals, *y;
60: int nsolves;
61: REAL tot_solve_time;
62: } xyt_info;
64:
65: typedef struct matvec_info {
66: int n, m, n_global, m_global;
67: int *local2global;
68: gs_ADT gs_handle;
69: PetscErrorCode (*matvec)(struct matvec_info*,REAL*,REAL*);
70: void *grid_data;
71: } mv_info;
73: struct xyt_CDT{
74: int id;
75: int ns;
76: int level;
77: xyt_info *info;
78: mv_info *mvi;
79: };
81: static int n_xyt=0;
82: static int n_xyt_handles=0;
84: /* prototypes */
85: static void do_xyt_solve(xyt_ADT xyt_handle, REAL *rhs);
86: static void check_init(void);
87: static void check_handle(xyt_ADT xyt_handle);
88: static void det_separators(xyt_ADT xyt_handle);
89: static void do_matvec(mv_info *A, REAL *v, REAL *u);
90: static int xyt_generate(xyt_ADT xyt_handle);
91: static int do_xyt_factor(xyt_ADT xyt_handle);
92: static mv_info *set_mvi(int *local2global, int n, int m, void *matvec, void *grid_data);
93: #ifdef MLSRC
94: void ML_XYT_solve(xyt_ADT xyt_handle, int lx, double *x, int lb, double *b);
95: PetscErrorCode ML_XYT_factor(xyt_ADT xyt_handle, int *local2global, int n, int m,
96: void *matvec, void *grid_data, int grid_tag, ML *my_ml);
97: #endif
100: /*************************************xyt.c************************************
101: Function: XYT_new()
103: Input :
104: Output:
105: Return:
106: Description:
107: **************************************xyt.c***********************************/
108: xyt_ADT
109: XYT_new(void)
110: {
111: xyt_ADT xyt_handle;
114: #ifdef DEBUG
115: error_msg_warning("XYT_new() :: start %d\n",n_xyt_handles);
116: #endif
118: /* rolling count on n_xyt ... pot. problem here */
119: n_xyt_handles++;
120: xyt_handle = (xyt_ADT)bss_malloc(sizeof(struct xyt_CDT));
121: xyt_handle->id = ++n_xyt;
122: xyt_handle->info = NULL;
123: xyt_handle->mvi = NULL;
125: #ifdef DEBUG
126: error_msg_warning("XYT_new() :: end %d\n",n_xyt_handles);
127: #endif
129: return(xyt_handle);
130: }
133: /*************************************xyt.c************************************
134: Function: XYT_factor()
136: Input :
137: Output:
138: Return:
139: Description:
140: **************************************xyt.c***********************************/
141: int
142: XYT_factor(xyt_ADT xyt_handle, /* prev. allocated xyt handle */
143: int *local2global, /* global column mapping */
144: int n, /* local num rows */
145: int m, /* local num cols */
146: void *matvec, /* b_loc=A_local.x_loc */
147: void *grid_data /* grid data for matvec */
148: )
149: {
150: #ifdef DEBUG
151: int flag;
154: error_msg_warning("XYT_factor() :: start %d\n",n_xyt_handles);
155: #endif
157: check_init();
158: check_handle(xyt_handle);
160: /* only 2^k for now and all nodes participating */
161: if ((1<<(xyt_handle->level=i_log2_num_nodes))!=num_nodes)
162: {error_msg_fatal("only 2^k for now and MPI_COMM_WORLD!!! %d != %d\n",1<<i_log2_num_nodes,num_nodes);}
164: /* space for X info */
165: xyt_handle->info = (xyt_info*)bss_malloc(sizeof(xyt_info));
167: /* set up matvec handles */
168: xyt_handle->mvi = set_mvi(local2global, n, m, matvec, grid_data);
170: /* matrix is assumed to be of full rank */
171: /* LATER we can reset to indicate rank def. */
172: xyt_handle->ns=0;
174: /* determine separators and generate firing order - NB xyt info set here */
175: det_separators(xyt_handle);
177: #ifdef DEBUG
178: flag = do_xyt_factor(xyt_handle);
179: error_msg_warning("XYT_factor() :: end %d (flag=%d)\n",n_xyt_handles,flag);
180: return(flag);
181: #else
182: return(do_xyt_factor(xyt_handle));
183: #endif
184: }
187: /*************************************xyt.c************************************
188: Function: XYT_solve
190: Input :
191: Output:
192: Return:
193: Description:
194: **************************************xyt.c***********************************/
195: int
196: XYT_solve(xyt_ADT xyt_handle, double *x, double *b)
197: {
198: #if defined( NXSRC) && defined(TIMING)
199: double dclock(), time=0.0;
200: #elif defined(MPISRC) && defined(TIMING)
201: double MPI_Wtime(), time=0.0;
202: #endif
203: #ifdef INFO
204: REAL vals[3], work[3];
205: int op[] = {NON_UNIFORM,GL_MIN,GL_MAX,GL_ADD};
206: #endif
209: #ifdef DEBUG
210: error_msg_warning("XYT_solve() :: start %d\n",n_xyt_handles);
211: #endif
213: check_init();
214: check_handle(xyt_handle);
216: /* need to copy b into x? */
217: if (b)
218: {rvec_copy(x,b,xyt_handle->mvi->n);}
219: do_xyt_solve(xyt_handle,x);
221: #ifdef DEBUG
222: error_msg_warning("XYT_solve() :: end %d\n",n_xyt_handles);
223: #endif
225: return(0);
226: }
229: /*************************************xyt.c************************************
230: Function: XYT_free()
232: Input :
233: Output:
234: Return:
235: Description:
236: **************************************xyt.c***********************************/
237: int
238: XYT_free(xyt_ADT xyt_handle)
239: {
240: #ifdef DEBUG
241: error_msg_warning("XYT_free() :: start %d\n",n_xyt_handles);
242: #endif
244: check_init();
245: check_handle(xyt_handle);
246: n_xyt_handles--;
248: bss_free(xyt_handle->info->nsep);
249: bss_free(xyt_handle->info->lnsep);
250: bss_free(xyt_handle->info->fo);
251: bss_free(xyt_handle->info->stages);
252: bss_free(xyt_handle->info->solve_uu);
253: bss_free(xyt_handle->info->solve_w);
254: bss_free(xyt_handle->info->x);
255: bss_free(xyt_handle->info->xcol_vals);
256: bss_free(xyt_handle->info->xcol_sz);
257: bss_free(xyt_handle->info->xcol_indices);
258: bss_free(xyt_handle->info->y);
259: bss_free(xyt_handle->info->ycol_vals);
260: bss_free(xyt_handle->info->ycol_sz);
261: bss_free(xyt_handle->info->ycol_indices);
262: bss_free(xyt_handle->info);
263: bss_free(xyt_handle->mvi->local2global);
264: gs_free(xyt_handle->mvi->gs_handle);
265: bss_free(xyt_handle->mvi);
266: bss_free(xyt_handle);
268:
269: #ifdef DEBUG
270: error_msg_warning("perm frees = %d\n",perm_frees());
271: error_msg_warning("perm calls = %d\n",perm_calls());
272: error_msg_warning("bss frees = %d\n",bss_frees());
273: error_msg_warning("bss calls = %d\n",bss_calls());
274: error_msg_warning("XYT_free() :: end %d\n",n_xyt_handles);
275: #endif
277: /* if the check fails we nuke */
278: /* if NULL pointer passed to bss_free we nuke */
279: /* if the calls to free fail that's not my problem */
280: return(0);
281: }
284: #ifdef MLSRC
285: /*************************************xyt.c************************************
286: Function: ML_XYT_factor()
288: Input :
289: Output:
290: Return:
291: Description:
293: ML requires that the solver call be checked in
294: **************************************xyt.c***********************************/
295: PetscErrorCode
296: ML_XYT_factor(xyt_ADT xyt_handle, /* prev. allocated xyt handle */
297: int *local2global, /* global column mapping */
298: int n, /* local num rows */
299: int m, /* local num cols */
300: void *matvec, /* b_loc=A_local.x_loc */
301: void *grid_data, /* grid data for matvec */
302: int grid_tag, /* grid tag for ML_Set_CSolve */
303: ML *my_ml /* ML handle */
304: )
305: {
306: #ifdef DEBUG
307: int flag;
308: #endif
311: #ifdef DEBUG
312: error_msg_warning("ML_XYT_factor() :: start %d\n",n_xyt_handles);
313: #endif
315: check_init();
316: check_handle(xyt_handle);
317: if (my_ml->comm->ML_mypid!=my_id)
318: {error_msg_fatal("ML_XYT_factor bad my_id %d\t%d\n",
319: my_ml->comm->ML_mypid,my_id);}
320: if (my_ml->comm->ML_nprocs!=num_nodes)
321: {error_msg_fatal("ML_XYT_factor bad np %d\t%d\n",
322: my_ml->comm->ML_nprocs,num_nodes);}
324: my_ml->SingleLevel[grid_tag].csolve->func->external = ML_XYT_solve;
325: my_ml->SingleLevel[grid_tag].csolve->func->ML_id = ML_EXTERNAL;
326: my_ml->SingleLevel[grid_tag].csolve->data = xyt_handle;
328: /* done ML specific stuff ... back to reg sched pgm */
329: #ifdef DEBUG
330: flag = XYT_factor(xyt_handle, local2global, n, m, matvec, grid_data);
331: error_msg_warning("ML_XYT_factor() :: end %d (flag=%d)\n",n_xyt_handles,flag);
332: return(flag);
333: #else
334: return(XYT_factor(xyt_handle, local2global, n, m, matvec, grid_data));
335: #endif
336: }
339: /*************************************xyt.c************************************
340: Function: ML_XYT_solve
342: Input :
343: Output:
344: Return:
345: Description:
346: **************************************xyt.c***********************************/
347: void
348: ML_XYT_solve(xyt_ADT xyt_handle, int lx, double *sol, int lb, double *rhs)
349: {
350: XYT_solve(xyt_handle, sol, rhs);
351: }
352: #endif
355: /*************************************xyt.c************************************
356: Function:
358: Input :
359: Output:
360: Return:
361: Description:
362: **************************************xyt.c***********************************/
363: int
364: XYT_stats(xyt_ADT xyt_handle)
365: {
366: int op[] = {NON_UNIFORM,GL_MIN,GL_MAX,GL_ADD,GL_MIN,GL_MAX,GL_ADD,GL_MIN,GL_MAX,GL_ADD};
367: int fop[] = {NON_UNIFORM,GL_MIN,GL_MAX,GL_ADD};
368: int vals[9], work[9];
369: REAL fvals[3], fwork[3];
372: #ifdef DEBUG
373: error_msg_warning("xyt_stats() :: begin\n");
374: #endif
376: check_init();
377: check_handle(xyt_handle);
379: /* if factorization not done there are no stats */
380: if (!xyt_handle->info||!xyt_handle->mvi)
381: {
382: if (!my_id)
383: {printf("XYT_stats() :: no stats available!\n");}
384: return 1;
385: }
387: vals[0]=vals[1]=vals[2]=xyt_handle->info->nnz;
388: vals[3]=vals[4]=vals[5]=xyt_handle->mvi->n;
389: vals[6]=vals[7]=vals[8]=xyt_handle->info->msg_buf_sz;
390: giop(vals,work,sizeof(op)/sizeof(op[0])-1,op);
392: fvals[0]=fvals[1]=fvals[2]
393: =xyt_handle->info->tot_solve_time/xyt_handle->info->nsolves++;
394: grop(fvals,fwork,sizeof(fop)/sizeof(fop[0])-1,fop);
396: if (!my_id)
397: {
398: printf("%d :: min xyt_nnz=%d\n",my_id,vals[0]);
399: printf("%d :: max xyt_nnz=%d\n",my_id,vals[1]);
400: printf("%d :: avg xyt_nnz=%g\n",my_id,1.0*vals[2]/num_nodes);
401: printf("%d :: tot xyt_nnz=%d\n",my_id,vals[2]);
402: printf("%d :: xyt C(2d) =%g\n",my_id,vals[2]/(pow(1.0*vals[5],1.5)));
403: printf("%d :: xyt C(3d) =%g\n",my_id,vals[2]/(pow(1.0*vals[5],1.6667)));
404: printf("%d :: min xyt_n =%d\n",my_id,vals[3]);
405: printf("%d :: max xyt_n =%d\n",my_id,vals[4]);
406: printf("%d :: avg xyt_n =%g\n",my_id,1.0*vals[5]/num_nodes);
407: printf("%d :: tot xyt_n =%d\n",my_id,vals[5]);
408: printf("%d :: min xyt_buf=%d\n",my_id,vals[6]);
409: printf("%d :: max xyt_buf=%d\n",my_id,vals[7]);
410: printf("%d :: avg xyt_buf=%g\n",my_id,1.0*vals[8]/num_nodes);
411: printf("%d :: min xyt_slv=%g\n",my_id,fvals[0]);
412: printf("%d :: max xyt_slv=%g\n",my_id,fvals[1]);
413: printf("%d :: avg xyt_slv=%g\n",my_id,fvals[2]/num_nodes);
414: }
416: #ifdef DEBUG
417: error_msg_warning("xyt_stats() :: end\n");
418: #endif
420: return(0);
421: }
424: /*************************************xyt.c************************************
425: Function: do_xyt_factor
427: Input :
428: Output:
429: Return:
430: Description: get A_local, local portion of global coarse matrix which
431: is a row dist. nxm matrix w/ n<m.
432: o my_ml holds address of ML struct associated w/A_local and coarse grid
433: o local2global holds global number of column i (i=0,...,m-1)
434: o local2global holds global number of row i (i=0,...,n-1)
435: o mylocmatvec performs A_local . vec_local (note that gs is performed using
436: gs_init/gop).
438: mylocmatvec = my_ml->Amat[grid_tag].matvec->external;
439: mylocmatvec (void :: void *data, double *in, double *out)
440: **************************************xyt.c***********************************/
441: static
442: int
443: do_xyt_factor(xyt_ADT xyt_handle)
444: {
445: int flag;
448: #ifdef DEBUG
449: error_msg_warning("do_xyt_factor() :: begin\n");
450: #endif
452: flag=xyt_generate(xyt_handle);
454: #ifdef INFO
455: XYT_stats(xyt_handle);
456: bss_stats();
457: perm_stats();
458: #endif
460: #ifdef DEBUG
461: error_msg_warning("do_xyt_factor() :: end\n");
462: #endif
464: return(flag);
465: }
468: /*************************************xyt.c************************************
469: Function:
471: Input :
472: Output:
473: Return:
474: Description:
475: **************************************xyt.c***********************************/
476: static
477: int
478: xyt_generate(xyt_ADT xyt_handle)
479: {
480: int i,j,k,idx;
481: int dim, col;
482: REAL *u, *uu, *v, *z, *w, alpha, alpha_w;
483: int *segs;
484: int op[] = {GL_ADD,0};
485: int off, len;
486: REAL *x_ptr, *y_ptr;
487: int *iptr, flag;
488: int start=0, end, work;
489: int op2[] = {GL_MIN,0};
490: gs_ADT gs_handle;
491: int *nsep, *lnsep, *fo;
492: int a_n=xyt_handle->mvi->n;
493: int a_m=xyt_handle->mvi->m;
494: int *a_local2global=xyt_handle->mvi->local2global;
495: int level;
496: int n, m;
497: int *xcol_sz, *xcol_indices, *stages;
498: REAL **xcol_vals, *x;
499: int *ycol_sz, *ycol_indices;
500: REAL **ycol_vals, *y;
501: int n_global;
502: int xt_nnz=0, xt_max_nnz=0;
503: int yt_nnz=0, yt_max_nnz=0;
504: int xt_zero_nnz =0;
505: int xt_zero_nnz_0=0;
506: int yt_zero_nnz =0;
507: int yt_zero_nnz_0=0;
510: #ifdef DEBUG
511: error_msg_warning("xyt_generate() :: begin\n");
512: #endif
514: n=xyt_handle->mvi->n;
515: nsep=xyt_handle->info->nsep;
516: lnsep=xyt_handle->info->lnsep;
517: fo=xyt_handle->info->fo;
518: end=lnsep[0];
519: level=xyt_handle->level;
520: gs_handle=xyt_handle->mvi->gs_handle;
522: /* is there a null space? */
523: /* LATER add in ability to detect null space by checking alpha */
524: for (i=0, j=0; i<=level; i++)
525: {j+=nsep[i];}
527: m = j-xyt_handle->ns;
528: if (m!=j)
529: {printf("xyt_generate() :: null space exists %d %d %d\n",m,j,xyt_handle->ns);}
531: error_msg_warning("xyt_generate() :: X(%d,%d)\n",n,m);
533: /* get and initialize storage for x local */
534: /* note that x local is nxm and stored by columns */
535: xcol_sz = (int*) bss_malloc(m*INT_LEN);
536: xcol_indices = (int*) bss_malloc((2*m+1)*sizeof(int));
537: xcol_vals = (REAL **) bss_malloc(m*sizeof(REAL *));
538: for (i=j=0; i<m; i++, j+=2)
539: {
540: xcol_indices[j]=xcol_indices[j+1]=xcol_sz[i]=-1;
541: xcol_vals[i] = NULL;
542: }
543: xcol_indices[j]=-1;
545: /* get and initialize storage for y local */
546: /* note that y local is nxm and stored by columns */
547: ycol_sz = (int*) bss_malloc(m*INT_LEN);
548: ycol_indices = (int*) bss_malloc((2*m+1)*sizeof(int));
549: ycol_vals = (REAL **) bss_malloc(m*sizeof(REAL *));
550: for (i=j=0; i<m; i++, j+=2)
551: {
552: ycol_indices[j]=ycol_indices[j+1]=ycol_sz[i]=-1;
553: ycol_vals[i] = NULL;
554: }
555: ycol_indices[j]=-1;
557: /* size of separators for each sub-hc working from bottom of tree to top */
558: /* this looks like nsep[]=segments */
559: stages = (int*) bss_malloc((level+1)*INT_LEN);
560: segs = (int*) bss_malloc((level+1)*INT_LEN);
561: ivec_zero(stages,level+1);
562: ivec_copy(segs,nsep,level+1);
563: for (i=0; i<level; i++)
564: {segs[i+1] += segs[i];}
565: stages[0] = segs[0];
567: /* temporary vectors */
568: u = (REAL *) bss_malloc(n*sizeof(REAL));
569: z = (REAL *) bss_malloc(n*sizeof(REAL));
570: v = (REAL *) bss_malloc(a_m*sizeof(REAL));
571: uu = (REAL *) bss_malloc(m*sizeof(REAL));
572: w = (REAL *) bss_malloc(m*sizeof(REAL));
574: /* extra nnz due to replication of vertices across separators */
575: for (i=1, j=0; i<=level; i++)
576: {j+=nsep[i];}
578: /* storage for sparse x values */
579: n_global = xyt_handle->info->n_global;
580: xt_max_nnz = yt_max_nnz = (int)(2.5*pow(1.0*n_global,1.6667) + j*n/2)/num_nodes;
581: x = (REAL *) bss_malloc(xt_max_nnz*sizeof(REAL));
582: y = (REAL *) bss_malloc(yt_max_nnz*sizeof(REAL));
584: /* LATER - can embed next sep to fire in gs */
585: /* time to make the donuts - generate X factor */
586: for (dim=i=j=0;i<m;i++)
587: {
588: /* time to move to the next level? */
589: while (i==segs[dim])
590: {
591: #ifdef SAFE
592: if (dim==level)
593: {error_msg_fatal("dim about to exceed level\n"); break;}
594: #endif
596: stages[dim++]=i;
597: end+=lnsep[dim];
598: }
599: stages[dim]=i;
601: /* which column are we firing? */
602: /* i.e. set v_l */
603: /* use new seps and do global min across hc to determine which one to fire */
604: (start<end) ? (col=fo[start]) : (col=INT_MAX);
605: giop_hc(&col,&work,1,op2,dim);
607: /* shouldn't need this */
608: if (col==INT_MAX)
609: {
610: error_msg_warning("hey ... col==INT_MAX??\n");
611: continue;
612: }
614: /* do I own it? I should */
615: rvec_zero(v ,a_m);
616: if (col==fo[start])
617: {
618: start++;
619: idx=ivec_linear_search(col, a_local2global, a_n);
620: if (idx!=-1)
621: {v[idx] = 1.0; j++;}
622: else
623: {error_msg_fatal("NOT FOUND!\n");}
624: }
625: else
626: {
627: idx=ivec_linear_search(col, a_local2global, a_m);
628: if (idx!=-1)
629: {v[idx] = 1.0;}
630: }
632: /* perform u = A.v_l */
633: rvec_zero(u,n);
634: do_matvec(xyt_handle->mvi,v,u);
636: /* uu = X^T.u_l (local portion) */
637: /* technically only need to zero out first i entries */
638: /* later turn this into an XYT_solve call ? */
639: rvec_zero(uu,m);
640: y_ptr=y;
641: iptr = ycol_indices;
642: for (k=0; k<i; k++)
643: {
644: off = *iptr++;
645: len = *iptr++;
647: #if BLAS||CBLAS
648: uu[k] = dot(len,u+off,1,y_ptr,1);
649: #else
650: uu[k] = rvec_dot(u+off,y_ptr,len);
651: #endif
652: y_ptr+=len;
653: }
655: /* uu = X^T.u_l (comm portion) */
656: ssgl_radd (uu, w, dim, stages);
658: /* z = X.uu */
659: rvec_zero(z,n);
660: x_ptr=x;
661: iptr = xcol_indices;
662: for (k=0; k<i; k++)
663: {
664: off = *iptr++;
665: len = *iptr++;
667: #if BLAS||CBLAS
668: axpy(len,uu[k],x_ptr,1,z+off,1);
669: #else
670: rvec_axpy(z+off,x_ptr,uu[k],len);
671: #endif
672: x_ptr+=len;
673: }
675: /* compute v_l = v_l - z */
676: rvec_zero(v+a_n,a_m-a_n);
677: #if BLAS||CBLAS
678: axpy(n,-1.0,z,1,v,1);
679: #else
680: rvec_axpy(v,z,-1.0,n);
681: #endif
683: /* compute u_l = A.v_l */
684: if (a_n!=a_m)
685: {gs_gop_hc(gs_handle,v,"+\0",dim);}
686: rvec_zero(u,n);
687: do_matvec(xyt_handle->mvi,v,u);
689: /* compute sqrt(alpha) = sqrt(u_l^T.u_l) - local portion */
690: #if BLAS||CBLAS
691: alpha = ddot(n,u,1,u,1);
692: #else
693: alpha = rvec_dot(u,u,n);
694: #endif
695: /* compute sqrt(alpha) = sqrt(u_l^T.u_l) - comm portion */
696: grop_hc(&alpha, &alpha_w, 1, op, dim);
698: alpha = (REAL) sqrt((double)alpha);
700: /* check for small alpha */
701: /* LATER use this to detect and determine null space */
702: #ifdef tmpr8
703: if (fabs(alpha)<1.0e-14)
704: {error_msg_fatal("bad alpha! %g\n",alpha);}
705: #else
706: if (fabs((double) alpha) < 1.0e-6)
707: {error_msg_fatal("bad alpha! %g\n",alpha);}
708: #endif
710: /* compute v_l = v_l/sqrt(alpha) */
711: rvec_scale(v,1.0/alpha,n);
712: rvec_scale(u,1.0/alpha,n);
714: /* add newly generated column, v_l, to X */
715: flag = 1;
716: off=len=0;
717: for (k=0; k<n; k++)
718: {
719: if (v[k]!=0.0)
720: {
721: len=k;
722: if (flag)
723: {off=k; flag=0;}
724: }
725: }
727: len -= (off-1);
729: if (len>0)
730: {
731: if ((xt_nnz+len)>xt_max_nnz)
732: {
733: error_msg_warning("increasing space for X by 2x!\n");
734: xt_max_nnz *= 2;
735: x_ptr = (REAL *) bss_malloc(xt_max_nnz*sizeof(REAL));
736: rvec_copy(x_ptr,x,xt_nnz);
737: bss_free(x);
738: x = x_ptr;
739: x_ptr+=xt_nnz;
740: }
741: xt_nnz += len;
742: rvec_copy(x_ptr,v+off,len);
744: /* keep track of number of zeros */
745: if (dim)
746: {
747: for (k=0; k<len; k++)
748: {
749: if (x_ptr[k]==0.0)
750: {xt_zero_nnz++;}
751: }
752: }
753: else
754: {
755: for (k=0; k<len; k++)
756: {
757: if (x_ptr[k]==0.0)
758: {xt_zero_nnz_0++;}
759: }
760: }
761: xcol_indices[2*i] = off;
762: xcol_sz[i] = xcol_indices[2*i+1] = len;
763: xcol_vals[i] = x_ptr;
764: }
765: else
766: {
767: xcol_indices[2*i] = 0;
768: xcol_sz[i] = xcol_indices[2*i+1] = 0;
769: xcol_vals[i] = x_ptr;
770: }
773: /* add newly generated column, u_l, to Y */
774: flag = 1;
775: off=len=0;
776: for (k=0; k<n; k++)
777: {
778: if (u[k]!=0.0)
779: {
780: len=k;
781: if (flag)
782: {off=k; flag=0;}
783: }
784: }
786: len -= (off-1);
788: if (len>0)
789: {
790: if ((yt_nnz+len)>yt_max_nnz)
791: {
792: error_msg_warning("increasing space for Y by 2x!\n");
793: yt_max_nnz *= 2;
794: y_ptr = (REAL *) bss_malloc(yt_max_nnz*sizeof(REAL));
795: rvec_copy(y_ptr,y,yt_nnz);
796: bss_free(y);
797: y = y_ptr;
798: y_ptr+=yt_nnz;
799: }
800: yt_nnz += len;
801: rvec_copy(y_ptr,u+off,len);
803: /* keep track of number of zeros */
804: if (dim)
805: {
806: for (k=0; k<len; k++)
807: {
808: if (y_ptr[k]==0.0)
809: {yt_zero_nnz++;}
810: }
811: }
812: else
813: {
814: for (k=0; k<len; k++)
815: {
816: if (y_ptr[k]==0.0)
817: {yt_zero_nnz_0++;}
818: }
819: }
820: ycol_indices[2*i] = off;
821: ycol_sz[i] = ycol_indices[2*i+1] = len;
822: ycol_vals[i] = y_ptr;
823: }
824: else
825: {
826: ycol_indices[2*i] = 0;
827: ycol_sz[i] = ycol_indices[2*i+1] = 0;
828: ycol_vals[i] = y_ptr;
829: }
830: }
832: /* close off stages for execution phase */
833: while (dim!=level)
834: {
835: stages[dim++]=i;
836: error_msg_warning("disconnected!!! dim(%d)!=level(%d)\n",dim,level);
837: }
838: stages[dim]=i;
840: xyt_handle->info->n=xyt_handle->mvi->n;
841: xyt_handle->info->m=m;
842: xyt_handle->info->nnz=xt_nnz + yt_nnz;
843: xyt_handle->info->max_nnz=xt_max_nnz + yt_max_nnz;
844: xyt_handle->info->msg_buf_sz=stages[level]-stages[0];
845: xyt_handle->info->solve_uu = (REAL *) bss_malloc(m*sizeof(REAL));
846: xyt_handle->info->solve_w = (REAL *) bss_malloc(m*sizeof(REAL));
847: xyt_handle->info->x=x;
848: xyt_handle->info->xcol_vals=xcol_vals;
849: xyt_handle->info->xcol_sz=xcol_sz;
850: xyt_handle->info->xcol_indices=xcol_indices;
851: xyt_handle->info->stages=stages;
852: xyt_handle->info->y=y;
853: xyt_handle->info->ycol_vals=ycol_vals;
854: xyt_handle->info->ycol_sz=ycol_sz;
855: xyt_handle->info->ycol_indices=ycol_indices;
857: bss_free(segs);
858: bss_free(u);
859: bss_free(v);
860: bss_free(uu);
861: bss_free(z);
862: bss_free(w);
864: #ifdef DEBUG
865: error_msg_warning("xyt_generate() :: end\n");
866: #endif
868: return(0);
869: }
872: /*************************************xyt.c************************************
873: Function:
875: Input :
876: Output:
877: Return:
878: Description:
879: **************************************xyt.c***********************************/
880: static
881: void
882: do_xyt_solve(xyt_ADT xyt_handle, register REAL *uc)
883: {
884: register int off, len, *iptr;
885: int level =xyt_handle->level;
886: int n =xyt_handle->info->n;
887: int m =xyt_handle->info->m;
888: int *stages =xyt_handle->info->stages;
889: int *xcol_indices=xyt_handle->info->xcol_indices;
890: int *ycol_indices=xyt_handle->info->ycol_indices;
891: register REAL *x_ptr, *y_ptr, *uu_ptr;
892: #if BLAS||CBLAS
893: REAL zero=0.0;
894: #endif
895: REAL *solve_uu=xyt_handle->info->solve_uu;
896: REAL *solve_w =xyt_handle->info->solve_w;
897: REAL *x =xyt_handle->info->x;
898: REAL *y =xyt_handle->info->y;
900: #ifdef DEBUG
901: error_msg_warning("do_xyt_solve() :: begin\n");
902: #endif
904: uu_ptr=solve_uu;
905: #if BLAS||CBLAS
906: copy(m,&zero,0,uu_ptr,1);
907: #else
908: rvec_zero(uu_ptr,m);
909: #endif
911: /* x = X.Y^T.b */
912: /* uu = Y^T.b */
913: for (y_ptr=y,iptr=ycol_indices; *iptr!=-1; y_ptr+=len)
914: {
915: off=*iptr++; len=*iptr++;
916: #if BLAS||CBLAS
917: *uu_ptr++ = dot(len,uc+off,1,y_ptr,1);
918: #else
919: *uu_ptr++ = rvec_dot(uc+off,y_ptr,len);
920: #endif
921: }
923: /* comunication of beta */
924: uu_ptr=solve_uu;
925: if (level) {ssgl_radd(uu_ptr, solve_w, level, stages);}
927: #if BLAS&&CBLAS
928: copy(n,&zero,0,uc,1);
929: #else
930: rvec_zero(uc,n);
931: #endif
933: /* x = X.uu */
934: for (x_ptr=x,iptr=xcol_indices; *iptr!=-1; x_ptr+=len)
935: {
936: off=*iptr++; len=*iptr++;
937: #if BLAS&&CBLAS
938: axpy(len,*uu_ptr++,x_ptr,1,uc+off,1);
939: #else
940: rvec_axpy(uc+off,x_ptr,*uu_ptr++,len);
941: #endif
942: }
944: #ifdef DEBUG
945: error_msg_warning("do_xyt_solve() :: end\n");
946: #endif
947: }
/*************************************xyt.c************************************
Function: check_init

Input : none
Output: none
Return: none
Description: called at the top of every public entry point; runs
comm_init().  The perm/bss allocator initialisation is intentionally left
disabled.
**************************************xyt.c***********************************/
static
void
check_init(void)
{
#ifdef DEBUG
  error_msg_warning("check_init() :: start %d\n",n_xyt_handles);
#endif

  /* NOTE(review): presumably safe to call repeatedly - confirm in comm.c */
  comm_init();

  /* disabled:
     perm_init();
     bss_init();
  */

#ifdef DEBUG
  error_msg_warning("check_init() :: end   %d\n",n_xyt_handles);
#endif
}
978: /*************************************xyt.c************************************
979: Function: check_handle()
981: Input :
982: Output:
983: Return:
984: Description:
985: **************************************xyt.c***********************************/
986: static
987: void
988: check_handle(xyt_ADT xyt_handle)
989: {
990: #ifdef SAFE
991: int vals[2], work[2], op[] = {NON_UNIFORM,GL_MIN,GL_MAX};
992: #endif
995: #ifdef DEBUG
996: error_msg_warning("check_handle() :: start %d\n",n_xyt_handles);
997: #endif
999: if (xyt_handle==NULL)
1000: {error_msg_fatal("check_handle() :: bad handle :: NULL %d\n",xyt_handle);}
1002: #ifdef SAFE
1003: vals[0]=vals[1]=xyt_handle->id;
1004: giop(vals,work,sizeof(op)/sizeof(op[0])-1,op);
1005: if ((vals[0]!=vals[1])||(xyt_handle->id<=0))
1006: {error_msg_fatal("check_handle() :: bad handle :: id mismatch min/max %d/%d %d\n",
1007: vals[0],vals[1], xyt_handle->id);}
1008: #endif
1010: #ifdef DEBUG
1011: error_msg_warning("check_handle() :: end %d\n",n_xyt_handles);
1012: #endif
1013: }
1016: /*************************************xyt.c************************************
1017: Function: det_separators
1019: Input :
1020: Output:
1021: Return:
1022: Description:
1023: det_separators(xyt_handle, local2global, n, m, mylocmatvec, grid_data);
1024: **************************************xyt.c***********************************/
1025: static
1026: void
1027: det_separators(xyt_ADT xyt_handle)
1028: {
1029: int i, ct, id;
1030: int mask, edge, *iptr;
1031: int *dir, *used;
1032: int sum[4], w[4];
1033: REAL rsum[4], rw[4];
1034: int op[] = {GL_ADD,0};
1035: REAL *lhs, *rhs;
1036: int *nsep, *lnsep, *fo, nfo=0;
1037: gs_ADT gs_handle=xyt_handle->mvi->gs_handle;
1038: int *local2global=xyt_handle->mvi->local2global;
1039: int n=xyt_handle->mvi->n;
1040: int m=xyt_handle->mvi->m;
1041: int level=xyt_handle->level;
1042: int shared=FALSE;
1044: #ifdef DEBUG
1045: error_msg_warning("det_separators() :: start %d %d %d\n",level,n,m);
1046: #endif
1047:
1048: dir = (int*)bss_malloc(INT_LEN*(level+1));
1049: nsep = (int*)bss_malloc(INT_LEN*(level+1));
1050: lnsep= (int*)bss_malloc(INT_LEN*(level+1));
1051: fo = (int*)bss_malloc(INT_LEN*(n+1));
1052: used = (int*)bss_malloc(INT_LEN*n);
1054: ivec_zero(dir ,level+1);
1055: ivec_zero(nsep ,level+1);
1056: ivec_zero(lnsep,level+1);
1057: ivec_set (fo ,-1,n+1);
1058: ivec_zero(used,n);
1060: lhs = (double*)bss_malloc(REAL_LEN*m);
1061: rhs = (double*)bss_malloc(REAL_LEN*m);
1063: /* determine the # of unique dof */
1064: rvec_zero(lhs,m);
1065: rvec_set(lhs,1.0,n);
1066: gs_gop_hc(gs_handle,lhs,"+\0",level);
1067: error_msg_warning("done first gs_gop_hc\n");
1068: rvec_zero(rsum,2);
1069: for (ct=i=0;i<n;i++)
1070: {
1071: if (lhs[i]!=0.0)
1072: {rsum[0]+=1.0/lhs[i]; rsum[1]+=lhs[i];}
1074: if (lhs[i]!=1.0)
1075: {
1076: shared=TRUE;
1077: }
1078: }
1080: grop_hc(rsum,rw,2,op,level);
1081: rsum[0]+=0.1;
1082: rsum[1]+=0.1;
1084: /*
1085: if (!my_id)
1086: {
1087: printf("xyt n unique = %d (%g)\n",(int) rsum[0], rsum[0]);
1088: printf("xyt n shared = %d (%g)\n",(int) rsum[1], rsum[1]);
1089: }
1090: */
1092: xyt_handle->info->n_global=xyt_handle->info->m_global=(int) rsum[0];
1093: xyt_handle->mvi->n_global =xyt_handle->mvi->m_global =(int) rsum[0];
1095: /* determine separator sets top down */
1096: if (shared)
1097: {
1098: /* solution is to do as in the symmetric shared case but then */
1099: /* pick the sub-hc with the most free dofs and do a mat-vec */
1100: /* and pick up the responses on the other sub-hc from the */
1101: /* initial separator set obtained from the symm. shared case */
1102: error_msg_fatal("shared dof separator determination not ready ... see hmt!!!\n");
1103: for (iptr=fo+n,id=my_id,mask=num_nodes>>1,edge=level;edge>0;edge--,mask>>=1)
1104: {
1105: /* set rsh of hc, fire, and collect lhs responses */
1106: (id<mask) ? rvec_zero(lhs,m) : rvec_set(lhs,1.0,m);
1107: gs_gop_hc(gs_handle,lhs,"+\0",edge);
1108:
1109: /* set lsh of hc, fire, and collect rhs responses */
1110: (id<mask) ? rvec_set(rhs,1.0,m) : rvec_zero(rhs,m);
1111: gs_gop_hc(gs_handle,rhs,"+\0",edge);
1112:
1113: for (i=0;i<n;i++)
1114: {
1115: if (id< mask)
1116: {
1117: if (lhs[i]!=0.0)
1118: {lhs[i]=1.0;}
1119: }
1120: if (id>=mask)
1121: {
1122: if (rhs[i]!=0.0)
1123: {rhs[i]=1.0;}
1124: }
1125: }
1127: if (id< mask)
1128: {gs_gop_hc(gs_handle,lhs,"+\0",edge-1);}
1129: else
1130: {gs_gop_hc(gs_handle,rhs,"+\0",edge-1);}
1132: /* count number of dofs I own that have signal and not in sep set */
1133: rvec_zero(rsum,4);
1134: for (ivec_zero(sum,4),ct=i=0;i<n;i++)
1135: {
1136: if (!used[i])
1137: {
1138: /* number of unmarked dofs on node */
1139: ct++;
1140: /* number of dofs to be marked on lhs hc */
1141: if (id< mask)
1142: {
1143: if (lhs[i]!=0.0)
1144: {sum[0]++; rsum[0]+=1.0/lhs[i];}
1145: }
1146: /* number of dofs to be marked on rhs hc */
1147: if (id>=mask)
1148: {
1149: if (rhs[i]!=0.0)
1150: {sum[1]++; rsum[1]+=1.0/rhs[i];}
1151: }
1152: }
1153: }
1155: /* go for load balance - choose half with most unmarked dofs, bias LHS */
1156: (id<mask) ? (sum[2]=ct) : (sum[3]=ct);
1157: (id<mask) ? (rsum[2]=ct) : (rsum[3]=ct);
1158: giop_hc(sum,w,4,op,edge);
1159: grop_hc(rsum,rw,4,op,edge);
1160: rsum[0]+=0.1; rsum[1]+=0.1; rsum[2]+=0.1; rsum[3]+=0.1;
1162: if (id<mask)
1163: {
1164: /* mark dofs I own that have signal and not in sep set */
1165: for (ct=i=0;i<n;i++)
1166: {
1167: if ((!used[i])&&(lhs[i]!=0.0))
1168: {
1169: ct++; nfo++;
1171: if (nfo>n)
1172: {error_msg_fatal("nfo about to exceed n\n");}
1174: *--iptr = local2global[i];
1175: used[i]=edge;
1176: }
1177: }
1178: if (ct>1) {ivec_sort(iptr,ct);}
1180: lnsep[edge]=ct;
1181: nsep[edge]=(int) rsum[0];
1182: dir [edge]=LEFT;
1183: }
1185: if (id>=mask)
1186: {
1187: /* mark dofs I own that have signal and not in sep set */
1188: for (ct=i=0;i<n;i++)
1189: {
1190: if ((!used[i])&&(rhs[i]!=0.0))
1191: {
1192: ct++; nfo++;
1194: if (nfo>n)
1195: {error_msg_fatal("nfo about to exceed n\n");}
1197: *--iptr = local2global[i];
1198: used[i]=edge;
1199: }
1200: }
1201: if (ct>1) {ivec_sort(iptr,ct);}
1203: lnsep[edge]=ct;
1204: nsep[edge]= (int) rsum[1];
1205: dir [edge]=RIGHT;
1206: }
1208: /* LATER or we can recur on these to order seps at this level */
1209: /* do we need full set of separators for this? */
1211: /* fold rhs hc into lower */
1212: if (id>=mask)
1213: {id-=mask;}
1214: }
1215: }
1216: else
1217: {
1218: for (iptr=fo+n,id=my_id,mask=num_nodes>>1,edge=level;edge>0;edge--,mask>>=1)
1219: {
1220: /* set rsh of hc, fire, and collect lhs responses */
1221: (id<mask) ? rvec_zero(lhs,m) : rvec_set(lhs,1.0,m);
1222: gs_gop_hc(gs_handle,lhs,"+\0",edge);
1224: /* set lsh of hc, fire, and collect rhs responses */
1225: (id<mask) ? rvec_set(rhs,1.0,m) : rvec_zero(rhs,m);
1226: gs_gop_hc(gs_handle,rhs,"+\0",edge);
1228: /* count number of dofs I own that have signal and not in sep set */
1229: for (ivec_zero(sum,4),ct=i=0;i<n;i++)
1230: {
1231: if (!used[i])
1232: {
1233: /* number of unmarked dofs on node */
1234: ct++;
1235: /* number of dofs to be marked on lhs hc */
1236: if ((id< mask)&&(lhs[i]!=0.0)) {sum[0]++;}
1237: /* number of dofs to be marked on rhs hc */
1238: if ((id>=mask)&&(rhs[i]!=0.0)) {sum[1]++;}
1239: }
1240: }
1242: /* for the non-symmetric case we need separators of width 2 */
1243: /* so take both sides */
1244: (id<mask) ? (sum[2]=ct) : (sum[3]=ct);
1245: giop_hc(sum,w,4,op,edge);
1247: ct=0;
1248: if (id<mask)
1249: {
1250: /* mark dofs I own that have signal and not in sep set */
1251: for (i=0;i<n;i++)
1252: {
1253: if ((!used[i])&&(lhs[i]!=0.0))
1254: {
1255: ct++; nfo++;
1256: *--iptr = local2global[i];
1257: used[i]=edge;
1258: }
1259: }
1260: /* LSH hc summation of ct should be sum[0] */
1261: }
1262: else
1263: {
1264: /* mark dofs I own that have signal and not in sep set */
1265: for (i=0;i<n;i++)
1266: {
1267: if ((!used[i])&&(rhs[i]!=0.0))
1268: {
1269: ct++; nfo++;
1270: *--iptr = local2global[i];
1271: used[i]=edge;
1272: }
1273: }
1274: /* RSH hc summation of ct should be sum[1] */
1275: }
1277: if (ct>1) {ivec_sort(iptr,ct);}
1278: lnsep[edge]=ct;
1279: nsep[edge]=sum[0]+sum[1];
1280: dir [edge]=BOTH;
1282: /* LATER or we can recur on these to order seps at this level */
1283: /* do we need full set of separators for this? */
1285: /* fold rhs hc into lower */
1286: if (id>=mask)
1287: {id-=mask;}
1288: }
1289: }
1291: /* level 0 is on processor case - so mark the remainder */
1292: for (ct=i=0;i<n;i++)
1293: {
1294: if (!used[i])
1295: {
1296: ct++; nfo++;
1297: *--iptr = local2global[i];
1298: used[i]=edge;
1299: }
1300: }
1301: if (ct>1) {ivec_sort(iptr,ct);}
1302: lnsep[edge]=ct;
1303: nsep [edge]=ct;
1304: dir [edge]=BOTH;
1306: xyt_handle->info->nsep=nsep;
1307: xyt_handle->info->lnsep=lnsep;
1308: xyt_handle->info->fo=fo;
1309: xyt_handle->info->nfo=nfo;
1311: bss_free(dir);
1312: bss_free(lhs);
1313: bss_free(rhs);
1314: bss_free(used);
1316: #ifdef DEBUG
1317: error_msg_warning("det_separators() :: end\n");
1318: #endif
1319: }
1322: /*************************************xyt.c************************************
1323: Function: set_mvi
1325: Input :
1326: Output:
1327: Return:
1328: Description:
1329: **************************************xyt.c***********************************/
1330: static
1331: mv_info *set_mvi(int *local2global, int n, int m, void *matvec, void *grid_data)
1332: {
1333: mv_info *mvi;
1336: #ifdef DEBUG
1337: error_msg_warning("set_mvi() :: start\n");
1338: #endif
1340: mvi = (mv_info*)bss_malloc(sizeof(mv_info));
1341: mvi->n=n;
1342: mvi->m=m;
1343: mvi->n_global=-1;
1344: mvi->m_global=-1;
1345: mvi->local2global=(int*)bss_malloc((m+1)*INT_LEN);
1346: ivec_copy(mvi->local2global,local2global,m);
1347: mvi->local2global[m] = INT_MAX;
1348: mvi->matvec=(PetscErrorCode (*)(mv_info*,REAL*,REAL*))matvec;
1349: mvi->grid_data=grid_data;
1351: /* set xyt communication handle to perform restricted matvec */
1352: mvi->gs_handle = gs_init(local2global, m, num_nodes);
1354: #ifdef DEBUG
1355: error_msg_warning("set_mvi() :: end \n");
1356: #endif
1357:
1358: return(mvi);
1359: }
1362: /*************************************xyt.c************************************
1363: Function: set_mvi
1365: Input :
1366: Output:
1367: Return:
1368: Description:
1370: computes u = A.v
1371: do_matvec(xyt_handle->mvi,v,u);
1372: **************************************xyt.c***********************************/
1373: static
1374: void do_matvec(mv_info *A, REAL *v, REAL *u)
1375: {
1376: A->matvec((mv_info*)A->grid_data,v,u);
1377: }