From 43ecfce7d20360a5fdc53e5ced266eccc8723242 Mon Sep 17 00:00:00 2001
From: Nicholas Noll
Date: Fri, 29 May 2020 14:41:05 -0700
Subject: blas code update

---
Summary of what this patch changes (text in this region is not part of the
commit message and is not treated as diff content):

  * include/libmath.h: math·Vector and math·Matrix rename `data` to `d` and
    gain an `inc` field; math·Matrix replaces `uint32 kind` with `uint state`;
    the math·slicev macro is removed; an `iota` macro and the mat·trans /
    mat·symm / mat·posdef / mat·negdef enum are added.
  * include/libmath/blas.h: the blas·Flag typedef is removed; level 2 and
    level 3 prototypes move from a type suffix to a type prefix (for example
    blas·gemvf becomes blas·fgemv) and take `uint` flags plus increments.
  * include/libn.h: the static system allocator wrappers become extern
    declarations (with ·realloc added); mem·Reallocator is added; the
    str·append* functions return int; panicf becomes a function and vpanicf
    is added.
  * include/u.h: arrend is added next to arrlen.

Review note: two added lines differ from the commit as originally recorded.
`iota` now has a fully parenthesized expansion, so it is safe inside larger
expressions, and the last parameter of blas·fspmv is named `incy` to agree
with blas·dspmv. Line counts are unchanged; the blob ids on the `index` lines
still refer to the original commit.

 include/libmath.h      | 17 +++++++++++++----
 include/libmath/blas.h | 46 ++++++++++++++++++++++++++--------------------
 include/libn.h         | 38 +++++++++++++++++---------------------
 include/u.h            |  1 +
 4 files changed, 57 insertions(+), 45 deletions(-)

(limited to 'include')

diff --git a/include/libmath.h b/include/libmath.h
index 40ae4ee..c605d24 100644
--- a/include/libmath.h
+++ b/include/libmath.h
@@ -134,17 +134,26 @@ float math·truncf(float);
 
 typedef struct math·Vector
 {
-    double *data;
+    double *d;
     int len;
+    int inc;
 } math·Vector;
 
-#define math·slicev(vec, lo, hi) (struct math·Vector){.len=((hi)-(lo)), .data=((vec).data + (lo))}
+#define iota(x) (1 << (x))
+enum
+{
+    mat·trans = iota(1),
+    mat·symm = iota(2),
+    mat·posdef = iota(3),
+    mat·negdef = iota(4),
+};
 
 typedef struct math·Matrix
 {
-    double *data;
-    uint32 kind;
+    double *d;
+    uint state;
     int dim[2];
+    int inc;
 } math·Matrix;
 
 // TODO: tensor ala numpy
diff --git a/include/libmath/blas.h b/include/libmath/blas.h
index 83acb2c..b8930a8 100644
--- a/include/libmath/blas.h
+++ b/include/libmath/blas.h
@@ -10,8 +10,6 @@ enum
     blas·LeftSide = 16u,
 };
 
-typedef uint32 blas·Flag;
-
 /*
  * Floats
  */
@@ -31,17 +29,21 @@ int blas·fargmax(int len, float *x, int inc);
 int blas·fargmin(int len, float *x, int inc);
 
 // level 2
-void blas·tpmvf(blas·Flag f, int n, float *m, float *x);
-error blas·gemvf(int nrow, int ncol, float a, float *m, int incm, float *x, int incx, float b, float *y, int incy) ;
-void blas·tpsvf(blas·Flag f, int n, float *m, float *x);
-void blas·gerf(int nrow, int ncol, float a, float *x, int incx, float *y, int incy, float *m, int incm);
-void blas·herf(int n, float a, float *x, float *m);
-void blas·syrf(int nrow, int ncol, float a, float *x, float *m);
+error blas·fgemv(uint f, int nrow, int ncol, float a, float *m, int incm, float *x, int incx, float b, float *y, int incy);
+void blas·fsymv(uint f, int n, float a, float* m, int incm, float *x, int incx, float b, float *y, int incy);
+void blas·fspmv(uint f, int n, float a, float* m, float *x, int incx, float b, float *y, int incy);
+void blas·ftrmv(uint f, int n, float* m, int incm, float *x, int incx);
+void blas·ftpmv(uint f, int n, float *m, float *x, int incx);
+void blas·ftrsv(uint f, int n, float *m, int incm, float *x, int incx);
+void blas·ftpsv(uint f, int n, float *m, float *x, int incx);
+void blas·fger(int nrow, int ncol, float a, float *x, int incx, float *y, int incy, float *m, int incm);
+void blas·fsyr(uint f, int n, float a, float *x, int incx, float *m, int incm);
+void blas·fspr(uint f, int n, float a, float *x, int incx, float *m);
 
 // level 3
-void blas·gemmf(int n1, int n2, int n3, float a, float *m1, float *m2, float b, float *m3);
-void blas·trmmf(blas·Flag f, int nrow, int ncol, float a, float *m1, float *m2);
-void blas·trsmf(blas·Flag f, int nrow, int ncol, float a, float *m1, float *m2);
+void blas·fgemm(uint tr1, uint tr2, int n1, int n2, int n3, float a, float *m1, int inc1, float *m2, int inc2, float b, float *m3, int inc3);
+void blas·ftrmm(uint f, int nrow, int ncol, float a, float *m1, float *m2);
+void blas·ftrsm(uint f, int nrow, int ncol, float a, float *m1, float *m2);
 
 /*
  * Doubles
@@ -62,17 +64,21 @@ int blas·dargmax(int len, double *x, int inc);
 int blas·dargmin(int len, double *x, int inc);
 
 // level 2
-void blas·tpmvd(blas·Flag f, int n, double *m, double *x);
-error blas·gemvd(int nrow, int ncol, double a, double *m, int incm, double *x, int incx, double b, double *y, int incy) ;
-void blas·tpsvd(blas·Flag f, int n, double *m, double *x);
-void blas·gerd(int nrow, int ncol, double a, double *x, int incx, double *y, int incy, double *m, int incm);
-void blas·herd(int n, double a, double *x, double *m);
-void blas·syrd(int nrow, int ncol, double a, double *x, double *m);
+error blas·dgemv(uint f, int nrow, int ncol, double a, double *m, int incm, double *x, int incx, double b, double *y, int incy);
+void blas·dsymv(uint f, int n, double a, double* m, int incm, double *x, int incx, double b, double *y, int incy);
+void blas·dspmv(uint f, int n, double a, double* m, double *x, int incx, double b, double *y, int incy);
+void blas·dtrmv(uint f, int n, double *m, int incm, double *x, int incx);
+void blas·dtpmv(uint f, int n, double *m, double *x, int incx);
+void blas·dtrsv(uint f, int n, double *m, int incm, double *x, int incx);
+void blas·dtpsv(uint f, int n, double *m, double *x, int incx);
+void blas·dger(int nrow, int ncol, double a, double *x, int incx, double *y, int incy, double *m, int incm);
+void blas·dsyr(uint flag, int n, double a, double *x, int incx, double *m, int incm);
+void blas·dspr(uint flag, int n, double a, double *x, int incx, double *m);
 
 // level 3
-void blas·gemmd(int n1, int n2, int n3, double a, double *m1, double *m2, double b, double *m3);
-void blas·trmmd(blas·Flag f, int nrow, int ncol, double a, double *m1, double *m2);
-void blas·trsmd(blas·Flag f, int nrow, int ncol, double a, double *m1, double *m2);
+void blas·dgemm(uint tr1, uint tr2, int n1, int n2, int n3, double a, double *m1, int inc1, double *m2, int inc2, double b, double *m3, int inc3);
+void blas·dtrmm(uint f, int nrow, int ncol, double a, double *m1, double *m2);
+void blas·dtrsm(uint f, int nrow, int ncol, double a, double *m1, double *m2);
 
 /*
  * TODO: Complex
diff --git a/include/libn.h b/include/libn.h
index 835d2c1..721defd 100644
--- a/include/libn.h
+++ b/include/libn.h
@@ -65,21 +65,17 @@ typedef struct mem·Allocator {
     void (*free)(void *iface, void *ptr);
 } mem·Allocator;
 
-/* system implementation */
-static
-void ·free(void* _, void* ptr) {
-    return free(ptr);
-}
-
-static
-void *·alloc(void* _, uint n, ulong size) {
-    return malloc(n*size);
-}
+typedef struct mem·Reallocator {
+    void *(*alloc)(void *iface, uint n, ulong size);
+    void *(*realloc)(void *iface, void *ptr, uint n, ulong size);
+    void (*free)(void *iface, void *ptr);
+} mem·Reallocator;
 
-static
-void *·calloc(void* _, uint n, ulong size) {
-    return calloc(n, size);
-}
+/* system implementation */
+extern void ·free(void* _, void *ptr);
+extern void *·alloc(void* _, uint n, ulong size);
+extern void *·calloc(void* _, uint n, ulong size);
+extern void *·realloc(void* _, void *ptr, uint n, ulong size);
 
 // TODO(nnoll): Allow for nil iterfaces?
 static
@@ -123,10 +119,10 @@ int str·cap(const string s);
 void str·clear(string *s);
 void str·grow(string *s, vlong delta);
 void str·fit(string *s);
-void str·appendlen(string *s, vlong len, const byte *b);
-void str·append(string *s, const byte* b);
-void str·appendf(string *s, const byte* fmt, ...);
-void str·appendbyte(string *s, const byte b);
+int str·appendlen(string *s, vlong len, const byte *b);
+int str·append(string *s, const byte* b);
+int str·appendf(string *s, const byte* fmt, ...);
+int str·appendbyte(string *s, const byte b);
 bool str·equals(const string s, const string t);
 int str·find(string s, const byte* substr);
 void str·lower(string s);
@@ -344,9 +340,9 @@ vlong gz·seek(gz·Stream *s, long off, enum SeekPos whence);
 
 // -----------------------------------------------------------------------------
 // error handling functions
-void errorf(const byte* fmt, ...);
-
-#define panicf(...) (errorf(__VA_ARGS__), assert(0))
+void errorf(byte* fmt, ...);
+void panicf(byte *fmt, ...);
+void vpanicf(byte *fmt, va_list args);
 
 // -----------------------------------------------------------------------------
 // sorting
diff --git a/include/u.h b/include/u.h
index 3fafcda..8044fe0 100644
--- a/include/u.h
+++ b/include/u.h
@@ -51,6 +51,7 @@ typedef int error;
 #endif
 
 #define arrlen(Array) (sizeof(Array) / sizeof((Array)[0]))
+#define arrend(Array) ((Array) + arrlen(Array))
 #define MAX(x, y) ((x) >= (y) ? (x) : (y))
 #define MIN(x, y) ((x) < (y) ? (x) : (y))
 
-- 
cgit v1.2.1