/* Header list reconstructed from the identifiers used below (assumption:
 * the original names are missing; these nine cover everything referenced). */
#include <assert.h>
#include <limits.h>
#include <math.h>
#include <stdarg.h>
#include <stddef.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#define unused(a) ((void)(a))
#define cast(t, p) ((t)(p))
#define casts(p) cast(short,p)
#define casti(p) cast(int,p)
#define castu(p) cast(unsigned,p)
#define castl(p) cast(long,p)
#define castll(p) cast(long long,p)
#define castull(p) cast(unsigned long long,p)
#define castf(p) cast(float,p)
#define castd(p) cast(double,p)
#define castsz(p) cast(size_t,p)
#define castss(p) cast(ssize_t,p)
#define recast(T,p) ((T)cast(void*,(p)))

#define szof(a) ((int)sizeof(a))
#define cntof(a) ((int)(sizeof(a) / sizeof((a)[0])))

#define flag(n) ((1u) << (n))
#define min(a,b) ((a) < (b) ? (a) : (b))
#define max(a,b) ((a) > (b) ? (a) : (b))
#define clamp(a, v, b) (max(min(b, v), a))
/* XOR swap: both arguments must be distinct objects */
#define iswap(x,y) do {((x) ^= (y), (y) ^= (x), (x) ^= (y)); }while(0)

/* alignment helpers: `a` must be a power of two */
#define align_mask(a) ((a)-1)
#define align_down_masked(n, m) ((n) & ~(m))
#define align_down(n, a) align_down_masked(n, align_mask(a))
#define align_up(n, a) align_down((n) + align_mask(a), (a))

#define xglue(x, y) x##y
#define glue(x, y) xglue(x, y)
#define uniqid(name) glue(name, __LINE__)

static void
die(const char *fmt, ...)
{
  va_list args;
  va_start(args, fmt);
  vfprintf(stderr, fmt, args);
  fprintf(stderr, "\n");
  va_end(args);
  exit(1);
}
/* ---------------------------------------------------------------------------
 *                                Platform
 * --------------------------------------------------------------------------- */
#ifdef _MSC_VER
#include <intrin.h>
#define alignto(x) __declspec(align(x))
#define bit_cnt(u) __popcnt(u)
#define bit_cnt64(u) __popcnt64(u)
static int
bit_ffs32(unsigned int u) {
  unsigned long i = 0; /* _BitScanForward stores the index in an unsigned long */
  _BitScanForward(&i, u);
  return casti(i);
}
static int
bit_ffs64(unsigned long long u) {
  unsigned long i = 0;
  _BitScanForward64(&i, u);
  return casti(i);
}
#else /* GCC, CLANG */
#define alignto(x) __attribute__((aligned(x)))
#define bit_cnt(u) __builtin_popcount(u)
#define bit_cnt64(u) __builtin_popcountll(u)
#define bit_ffs32(u) __builtin_ctz(u)
#define bit_ffs64(u) __builtin_ctzll(u)
#endif

#ifdef _WIN32 /* Windows */
#include "ntsecapi.h"
static unsigned long long
sys_rnd64(void) {
  unsigned long long rnd = 0;
  if (!RtlGenRandom(&rnd, sizeof(rnd))) {
    fprintf(stderr, "failed to generate system random number\n");
    exit(1);
  }
  return rnd;
}
#else /* UNIX */
#include <fcntl.h>  /* open, O_RDONLY */
#include <unistd.h> /* read, close */
static unsigned long long
sys_rnd64(void) {
  ssize_t res;
  unsigned long long rnd = 0;
  int fp = open("/dev/urandom", O_RDONLY);
  if (fp == -1) {
    fprintf(stderr, "failed to access system random number\n");
    exit(1);
  }
  res = read(fp, cast(char*, &rnd), sizeof(rnd));
  if (res < szof(rnd)) {
    fprintf(stderr, "failed to generate system random number\n");
    exit(1);
  }
  close(fp);
  return rnd;
}
#endif

#ifdef __x86_64__ /* SSE */
#define SSE_ALIGN_BYTES 16
#define SSE_ALIGN alignto(SSE_ALIGN_BYTES)
#include <emmintrin.h>

/* Returns a pointer to the first `chr` in [s, e) or `e` if there is none.
 * Loads 16 bytes at a time, so the buffer must be readable up to 15 bytes
 * past `e` (file_load pads its allocation for this reason). */
static const char*
str_chr(const char *s, const char *e, int chr) {
  static const unsigned char ovr_msk[32] = {
    255, 255, 255, 255, 255, 255, 255, 255,
    255, 255, 255, 255, 255, 255, 255, 255,
    0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0,
  };
  __m128i m = _mm_set1_epi8(chr & 0xff);
  for (; s < e; s += 16) {
    int r = (int)(e - s);
    r = r > 16 ?
      16 : r;
    /* mask off the bytes that lie past the end of the range */
    __m128i o = _mm_loadu_si128((const __m128i *)(ovr_msk + 16 - r));
    __m128i d = _mm_loadu_si128((const __m128i *)(const void*)s);
    __m128i v = _mm_and_si128(d, o);
    unsigned msk = _mm_movemask_epi8(_mm_cmpeq_epi8(v,m));
    if (msk) {
      /* lowest set bit = first matching byte of this block */
      return s + bit_ffs32(msk);
    }
  }
  return e;
}
/* Counts the '\n' bytes in the first n bytes of s. */
static int
line_cnt(const char *s, int n) {
  static const unsigned char ovr_msk[32] = {
    255, 255, 255, 255, 255, 255, 255, 255,
    255, 255, 255, 255, 255, 255, 255, 255,
    0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0,
  };
  int cnt = 0;
  const char *e = s + n;
  __m128i m = _mm_set1_epi8('\n');
  for (; s < e; s += 16) {
    int r = casti(e - s);
    int l = r > 16 ? 16 : r;
    __m128i o = _mm_loadu_si128((const __m128i *)(ovr_msk + 16 - l));
    __m128i d = _mm_loadu_si128((const __m128i *)(const void*)s);
    __m128i v = _mm_and_si128(d, o);
    unsigned msk = _mm_movemask_epi8(_mm_cmpeq_epi8(v,m));
    cnt += bit_cnt(msk);
  }
  return cnt;
}
#elif defined(__arm__) || defined(__aarch64__) /* ARM NEON */
#include <arm_neon.h>
#define SSE_ALIGN_BYTES 16
#define SSE_ALIGN alignto(SSE_ALIGN_BYTES)

static const char*
str_chr(const char *s, const char *e, int chr) {
  static const unsigned char ovr_msk[32] = {
    255, 255, 255, 255, 255, 255, 255, 255,
    255, 255, 255, 255, 255, 255, 255, 255,
    0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0,
  };
  uint8x16_t m = vdupq_n_u8(chr & 0xff);
  for (; s < e; s += 16) {
    int r = (int)(e - s);
    r = r > 16 ?
      16 : r;
    uint8x16_t o = vld1q_u8(ovr_msk + 16 - r);
    uint8x16_t d = vld1q_u8((const unsigned char*)s);
    uint8x16_t v = vandq_u8(d, o);
    uint8x16_t c = vceqq_u8(v, m);
    uint64x2_t p = vreinterpretq_u64_u8(c);
    /* each matching byte is 0xff, so bit index / 8 = byte index */
    uint64_t vlo = vgetq_lane_u64(p, 0);
    if (vlo) {
      return s + ((bit_ffs64(vlo)) >> 3);
    }
    uint64_t vhi = vgetq_lane_u64(p, 1);
    if (vhi) {
      return s + 8 + ((bit_ffs64(vhi)) >> 3);
    }
  }
  return e;
}
/* Counts the '\n' bytes in the first n bytes of s. */
static int
line_cnt(const char *s, int n) {
  static const unsigned char ovr_msk[32] = {
    255, 255, 255, 255, 255, 255, 255, 255,
    255, 255, 255, 255, 255, 255, 255, 255,
    0, 0, 0, 0, 0, 0, 0, 0,
    0, 0, 0, 0, 0, 0, 0, 0,
  };
  int cnt = 0;
  const char *e = s + n;
  uint8x16_t m = vdupq_n_u8('\n');
  for (; s < e; s += 16) {
    int r = casti(e - s);
    int l = r > 16 ? 16 : r;
    uint8x16_t o = vld1q_u8(ovr_msk + 16 - l);
    uint8x16_t d = vld1q_u8((const unsigned char*)s);
    uint8x16_t v = vandq_u8(d, o);
    uint8x16_t c = vceqq_u8(v, m);
    uint64x2_t p = vreinterpretq_u64_u8(c);
    /* 8 set bits per matching byte */
    cnt += bit_cnt64(vgetq_lane_u64(p, 0)) >> 3;
    cnt += bit_cnt64(vgetq_lane_u64(p, 1)) >> 3;
  }
  return cnt;
}
#else /* standard c */
#define SSE_ALIGN_BYTES 16 /* file_load() pads its buffer by this amount */

/* Word-at-a-time search; assumes a 32-bit unsigned and little-endian
 * byte order, and may read up to 3 bytes past `e`. */
static const char*
str_chr(const char *s, const char *e, int c) {
  unsigned m = 0x01010101u * (unsigned)(c & 0xff);
  for (; s < e && ((uintptr_t)s & 0x03); ++s) {
    if ((unsigned char)s[0] == (unsigned char)c) {
      return s;
    }
  }
  for (; s < e; s += 4) {
    unsigned v, w, x;
    memcpy(&v, s, sizeof(v));
    w = v ^ m; /* bytes equal to c become zero */
    x = (w - 0x01010101u) & ~w & 0x80808080u; /* flags the zero bytes */
    if (x) {
      const char *p = s + (bit_ffs32(x) >> 3);
      return p < e ? p : e; /* ignore matches past the end */
    }
  }
  return e;
}
/* Counts lines (newlines + 1); the SIMD variants above count newlines only. */
static int
line_cnt(const char *str, int len) {
  int cnt = 1;
  for (int i = 0; i < len; ++i) {
    if (str[i] == '\n') {
      cnt++;
    }
  }
  return cnt;
}
#endif

/* ---------------------------------------------------------------------------
 *                                String
 * --------------------------------------------------------------------------- */
struct str {
  const char *str;
  const char *end;
  int len;
};
#define str_rhs(s, n) str_sub(s, min((s).len, n), (s).len)
#define str_lhs(s, n) str_sub(s, 0, min((s).len, n))
#define str_cut_lhs(s, n) *(s) = str_rhs(*(s), n)
#define str_cut_rhs(s, n) \
  *(s) = str_lhs(*(s), n)
#define for_str_tok(it, rest, src, delim) \
  for ((rest) = (src), (it) = str_split_cut(&(rest), (delim)); \
    (it).len; (it) = str_split_cut(&(rest), (delim)))

static struct str
str(const char *p, int len) {
  struct str s = {0};
  s.str = p;
  s.end = p + len;
  s.len = len;
  return s;
}
static struct str
str_sub(struct str s, int from, int to) {
  int b = min(from, to);
  int e = max(from, to);
  struct str r = {0};
  r.str = s.str + min(b, s.len);
  r.end = s.str + min(e, s.len);
  r.len = casti(r.end - r.str);
  return r;
}
/* Returns the part of *s before the first `delim` and advances *s past it.
 * Without a delimiter the whole string is returned and *s is cleared. */
static struct str
str_split_cut(struct str *s, int delim) {
  const char *at = str_chr(s->str, s->end, delim);
  if (at < s->end) {
    int p = casti(at - s->str);
    struct str res = str_lhs(*s, p);
    str_cut_lhs(s, p + 1);
    return res;
  } else {
    struct str res = *s;
    memset(s, 0, sizeof(*s));
    return res;
  }
}
/* ---------------------------------------------------------------------------
 *                            Command Arguments
 * --------------------------------------------------------------------------- */
#define CMD_ARGC() argc_
#define cmd_arg_opt_str(argv, x) ((argv[0][1] == '\0' && argv[1] == 0)?\
  ((x), (char *)0) : (brk_ = 1, (argv[0][1] != '\0') ?\
    (&argv[0][1]) : (argc--, argv++, argv[0])))
#define cmd_arg_opt_int(argv,x) cmd_arg_int(cmd_arg_opt_str(argv,x))
#define cmd_arg_opt_flt(argv,x) cmd_arg_flt(cmd_arg_opt_str(argv,x))
#define CMD_ARG_BEGIN(argv0, argc, argv) \
  for (argv0 = *argv, argv++, argc--; argv[0] && argv[0][1] && argv[0][0] == '-'; argc--, argv++) {\
    char argc_, **argv_; int brk_;\
    if (argv[0][1] == '-' && argv[0][2] == '\0') {argv++; argc--; break;}\
    for (brk_ = 0, argv[0]++, argv_ = argv; argv[0][0] && !brk_; argv[0]++) {\
      if (argv_ != argv) break;\
      argc_ = argv[0][0];\
      switch (argc_)
#define CMD_ARG_END }}

static int
cmd_arg_int(const char *str) {
  char *ep = 0;
  long n = strtol(str, &ep, 10);
  if (*ep != '\0' || ep == str) {
    die("Invalid argument number: %s\n", str);
  }
  if (n < INT_MIN || n > INT_MAX) {
    die("Argument number: %ld is out of range\n", n);
  }
  return casti(n);
}
static float
cmd_arg_flt(const char *str) {
  char *ep = 0;
  float n = strtof(str, &ep);
  if (*ep != '\0' || ep == str) {
    die("Invalid argument number: %s\n", str);
  }
  return n;
}
/* ---------------------------------------------------------------------------
 *                                Utility
 * --------------------------------------------------------------------------- */
/* Byte-wise swap; the array size is negative (a compile error) when the
 * two operands differ in size. */
#define swap(x,y) do { \
  unsigned char uniqid(t)[szof(x) == szof(y) ? szof(x) : -1]; \
  memcpy(uniqid(t),&y,sizeof(x)); \
  memcpy(&y,&x,sizeof(x)); \
  memcpy(&x,uniqid(t),sizeof(x)); \
} while(0)

/* Reorders a[0..n) in place according to the permutation p[0..n). */
#define arr_shfl(a,n,p) do { \
  for (int uniqid(i) = 0; uniqid(i) < n; ++uniqid(i)) { \
    if (p[uniqid(i)] >= 0) { \
      int uniqid(j) = uniqid(i); \
      while (p[uniqid(j)] != uniqid(i)) { \
        const int uniqid(d) = p[uniqid(j)]; \
        swap(a[uniqid(j)], a[uniqid(d)]); \
        p[uniqid(j)] = -1 - uniqid(d); \
        uniqid(j) = uniqid(d); \
      } p[uniqid(j)] = -1 - p[uniqid(j)]; \
    } \
  }} while (0)

static void*
xalloc(int siz) {
  void *mem = calloc(castsz(siz), 1);
  if (!mem) {
    die("Out of Memory");
  }
  return mem;
}
static unsigned long long
rnd_gen(unsigned long long x, int n) {
  return x + castull(n) * 0x9E3779B97F4A7C15llu;
}
static unsigned long long
rnd_mix(unsigned long long z) {
  z = (z ^ (z >> 30)) * 0xBF58476D1CE4E5B9llu;
  z = (z ^ (z >> 27)) * 0x94D049BB133111EBllu;
  return z ^ (z >> 31llu);
}
static unsigned long long
rnd_split_mix(unsigned long long *x, int i) {
  *x = rnd_gen(*x, i);
  return rnd_mix(*x);
}
static unsigned long long
rnd(unsigned long long *x) {
  return rnd_split_mix(x, 1);
}
static unsigned
rndu(unsigned long long *x) {
  unsigned long long z = rnd(x);
  return castu(z & 0xffffffffu);
}
/* Uniform integer in [min(mini,maxi), max(mini,maxi)], using rejection
 * sampling to avoid modulo bias. */
static unsigned
rnduu(unsigned long long *x, unsigned mini, unsigned maxi) {
  unsigned lo = min(mini, maxi);
  unsigned hi = max(mini, maxi);
  unsigned rng = castu(-1);
  unsigned n = hi - lo + 1u;
  if (n == 1u) {
    return mini;
  } else if(n == 0u) {
    return rndu(x);
  } else {
    unsigned v = 0;
    unsigned remainder = rng % n;
    do {v = rndu(x);} while(v >= rng - remainder);
    return lo + v % n; /* offset from the lower bound, also when mini > maxi */
  }
}
static float
rndf01(unsigned long long *x) {
  unsigned u = rndu(x);
  double du = castd(u);
  double div = castd((unsigned)-1);
  return castf(du/div);
}
static float
rnduf(unsigned long long *x, float mini, float maxi) {
  float lo = min(mini, maxi);
  float hi = max(mini, maxi);
  unsigned u = rndu(x);
  float rng = hi - lo;
  double du = castd(u);
  double div = castd((unsigned)-1);
  return lo + rng * castf(du/div);
}
/* Fills seq with the identity permutation 0..cnt-1. */
static void
seq_lin(int *seq, int cnt) {
  int i = 0;
  for (i = 0; i < cnt; ++i) {
    seq[i] = i;
  }
}
/* Fisher-Yates shuffle of seq[0..n). */
static void
seq_rnd(int *seq, int n, unsigned long long *r) {
  int i = 0;
  for (i = n - 1; i > 0; --i) {
    unsigned at = rndu(r) % castu(i + 1);
    if (at != castu(i)) { /* an XOR swap of an element with itself zeroes it */
      iswap(seq[i], seq[at]);
    }
  }
}
static char*
file_load(int *siz, const char *path) {
  FILE *fd = 0;
  char *mem = 0;
  size_t ret = 0;

  /* open file (binary mode, so the size from ftell matches fread) */
  assert(path);
  fd = fopen(path, "rb");
  if (!fd) {
    die("Unable to open file: %s\n", path);
  }
  /* calculate file size */
  fseek(fd, 0, SEEK_END);
  *siz = casti(ftell(fd));
  if (*siz < 0) {
    die("Unable to access file: %s\n", path);
  }
  if (*siz == 0) {
    die("File is empty: %s\n", path);
  }
  fseek(fd, 0, SEEK_SET);

  /* read file into memory; the zeroed padding lets the 16-byte SIMD loads
   * run past the end of the data safely (xalloc dies on failure) */
  mem = xalloc(*siz + SSE_ALIGN_BYTES);
  ret = fread(mem, 1, castsz(*siz), fd);
  if (ret < castsz(*siz)) {
    die("file fread failed: %d\n", casti(ret));
  }
  fclose(fd);
  return mem;
}
/* ---------------------------------------------------------------------------
 *
 *                              Neural Network
 *
 * --------------------------------------------------------------------------- */
struct nn_ctx {
  /* in */
  void *mem;
  int in_cnt;
  int out_cnt;
  int hid_cnt;
  /* out */
  int n_in_cnt;
  int n_out_cnt;
  int n_hid_cnt;
  int w_in_hid_cnt;
  int w_hid_out_cnt;
  float *n_in;
  float *n_hid;
  float *n_out;
  int *outs;
  float *w_in_hid;
  float *w_hid_out;
  /* intern */
  unsigned setup:1;
};
static int
nn_req_net_siz(struct nn_ctx *ctx) {
  /* the "+ 1" entries hold the bias neurons */
  int n_in_siz = szof(float) * (ctx->in_cnt + 1);
  int n_hid_siz = szof(float) * (ctx->hid_cnt + 1);
  int n_out_siz =
    szof(float) * ctx->out_cnt;
  int out_siz = szof(int) * ctx->out_cnt;
  int wih_siz = szof(float) * ((ctx->in_cnt + 1) * (ctx->hid_cnt + 1));
  int woh_siz = szof(float) * ((ctx->hid_cnt + 1) * ctx->out_cnt);
  return n_in_siz + n_hid_siz + n_out_siz + out_siz + wih_siz + woh_siz;
}
static void
nn__setup_net(struct nn_ctx *ctx) {
  ctx->setup = 1;
  ctx->n_in_cnt = ctx->in_cnt + 1;
  ctx->n_out_cnt = ctx->out_cnt;
  ctx->n_hid_cnt = ctx->hid_cnt + 1;
  ctx->w_in_hid_cnt = (ctx->n_in_cnt * ctx->n_hid_cnt);
  ctx->w_hid_out_cnt = (ctx->n_out_cnt * ctx->n_hid_cnt);

  /* setup memory: carve all arrays out of the single block */
  ctx->n_in = ctx->mem;
  ctx->n_hid = ctx->n_in + ctx->n_in_cnt;
  ctx->n_out = ctx->n_hid + ctx->n_hid_cnt;
  ctx->w_in_hid = ctx->n_out + ctx->n_out_cnt;
  ctx->w_hid_out = ctx->w_in_hid + ctx->w_in_hid_cnt;
  ctx->outs = cast(int*, ctx->w_hid_out + ctx->w_hid_out_cnt);

  /* bias neurons */
  ctx->n_in[ctx->in_cnt] = -1.0f;
  ctx->n_hid[ctx->hid_cnt] = -1.0f;
}
static int
nn__in_hide_weight_idx(const struct nn_ctx *ctx, int in_idx, int hid_idx) {
  /* Row stride is hid_cnt + 1 so that hid_idx == hid_cnt does not alias the
   * next row; this matches the w_in_hid_cnt entries that are allocated. */
  return in_idx * (ctx->hid_cnt + 1) + hid_idx;
}
static int
nn__hide_out_weight_idx(const struct nn_ctx *ctx, int hid_idx, int out_idx) {
  return hid_idx * ctx->out_cnt + out_idx;
}
static void
nn_init(struct nn_ctx *ctx) {
  int ii, hi, oi;
  unsigned long long rnd_gen = sys_rnd64();
  float dist = (2.4f / ctx->in_cnt);
  if (!ctx->setup) {
    nn__setup_net(ctx);
  }
  /* random weights in [-dist, dist]; the bounds stay inside the
   * n_in_cnt x n_hid_cnt and n_hid_cnt x out_cnt weight tables */
  for (ii = 0; ii < ctx->n_in_cnt; ++ii) {
    for (hi = 0; hi < ctx->n_hid_cnt; ++hi) {
      int wi = nn__in_hide_weight_idx(ctx, ii, hi);
      ctx->w_in_hid[wi] = (rndf01(&rnd_gen) * 2.0f - 1.0f) * dist;
    }
  }
  for (hi = 0; hi < ctx->n_hid_cnt; ++hi) {
    for (oi = 0; oi < ctx->out_cnt; ++oi) {
      int wi = nn__hide_out_weight_idx(ctx, hi, oi);
      ctx->w_hid_out[wi] = (rndf01(&rnd_gen) * 2.0f - 1.0f) * dist;
    }
  }
}
static float
nn__sigmoid(float x) {
  return 1.0f / ( 1.0f + expf(-x));
}
/* Maps an activation to a class: 0, 1, or -1 when it is undecided. */
static int
nn__clamp(float x) {
  if (x < 0.1f) return 0;
  else if (x > 0.9f) return 1;
  else return -1;
}
static int*
nn_eval(struct nn_ctx *ctx, const float *in) {
  int oi, ii, hi = 0;
  memcpy(ctx->n_in,
    in, sizeof(float) * castsz(ctx->in_cnt));
  for (hi = 0; hi < ctx->hid_cnt; ++hi) {
    /* calculate weighted sum of pattern and bias neuron */
    ctx->n_hid[hi] = 0;
    for (ii = 0; ii <= ctx->in_cnt; ++ii) {
      int wi = nn__in_hide_weight_idx(ctx, ii, hi);
      ctx->n_hid[hi] += ctx->n_in[ii] * ctx->w_in_hid[wi];
    }
    ctx->n_hid[hi] = nn__sigmoid(ctx->n_hid[hi]);
  }
  for (oi = 0; oi < ctx->out_cnt; ++oi) {
    /* calculate output values - include bias neuron */
    ctx->n_out[oi] = 0;
    for (hi = 0; hi <= ctx->hid_cnt; ++hi) {
      int wi = nn__hide_out_weight_idx(ctx, hi, oi);
      ctx->n_out[oi] += ctx->n_hid[hi] * ctx->w_hid_out[wi];
    }
    ctx->n_out[oi] = nn__sigmoid(ctx->n_out[oi]);
    ctx->outs[oi] = nn__clamp(ctx->n_out[oi]);
  }
  return ctx->outs;
}
/* ---------------------------------------------------------------------------
 *                                  Data
 * --------------------------------------------------------------------------- */
struct nn_train_elm {
  float *in;
  int *out;
};
struct nn_train_set {
  int begin;
  int end;
  int cnt;
};
struct nn_train_data {
  /* in */
  void *mem;
  int in_cnt;
  int hid_cnt;
  int out_cnt;
  int elm_cnt;
  /* out */
  struct nn_train_elm *elms;
  struct nn_train_set train;
  struct nn_train_set gen;
  struct nn_train_set val;
  int *seq;
  /* intern */
  unsigned setup:1;
};
static int
nn_req_train_data_siz(const struct nn_train_data *trn) {
  int es = szof(struct nn_train_elm) * trn->elm_cnt;
  int ins = szof(float) * trn->in_cnt * trn->elm_cnt;
  int outs = szof(int) * trn->out_cnt * trn->elm_cnt;
  int shfl = szof(int) * trn->elm_cnt;
  return es + ins + shfl + outs;
}
static void
nn__setup_train_data(struct nn_train_data *ctx) {
  int i;
  float *ins;
  int *outs;
  /* memory layout: elements | shuffle sequence | inputs | outputs */
  ctx->elms = recast(struct nn_train_elm*, ctx->mem);
  ctx->seq = recast(int*, ctx->elms + ctx->elm_cnt);
  ins = recast(float*, (ctx->seq + ctx->elm_cnt));
  outs = recast(int*, (ins + ctx->in_cnt * ctx->elm_cnt));
  for (i = 0; i < ctx->elm_cnt; ++i, ins += ctx->in_cnt, outs += ctx->out_cnt) {
    ctx->elms[i].in = ins;
    ctx->elms[i].out = outs;
  }
  ctx->setup = 1;
}
static void
nn__split_train_data_set(struct nn_train_data *dat, int cnt) {
  /* 60% training, 20% generalization, remainder validation */
  float elm_cnt = castf(cnt);
  int train_cnt = casti(0.6f * elm_cnt);
  int gen_cnt = casti(ceil(0.2f * elm_cnt));

  dat->train.begin = 0;
  dat->train.end = train_cnt;
  dat->train.cnt = train_cnt;

  dat->gen.begin = train_cnt;
  dat->gen.end = dat->gen.begin + gen_cnt;
  dat->gen.cnt = gen_cnt;

  dat->val.begin = dat->gen.end;
  dat->val.end = cnt;
  dat->val.cnt = cnt - dat->val.begin;
}
static void
nn__load_train_data_elm(struct nn_train_elm *elm, const struct str *tok,
    int in_cnt, int out_cnt) {
  int i = 0;
  struct str it, _;
  for_str_tok(it, _, *tok, ',') {
    char *ep = 0;
    char sav = 0;
    float f = 0.0f;
    if (i >= in_cnt + out_cnt) {
      break;
    }
    /* terminate the token in place for strtof, then restore the byte */
    sav = *it.end;
    *(char*)it.end = '\0';
    f = strtof(it.str, &ep);
    *(char*)it.end = sav;
    if (i < in_cnt) {
      elm->in[i] = f;
    } else {
      elm->out[i - in_cnt] = casti(f);
    }
    i++;
  }
}
static void
nn_train_data_load(struct nn_train_data* dat, const char *file, int len) {
  int i = 0;
  unsigned long long rnd_gen = sys_rnd64();
  struct str ln, _, in = str(file,len);
  if (!dat->setup) {
    nn__setup_train_data(dat);
  }
  for_str_tok(ln, _, in, '\n') {
    if (i >= dat->elm_cnt) {
      break; /* never write past the allocated elements */
    }
    if (ln.len > 2) {
      struct nn_train_elm *elm = &dat->elms[i];
      nn__load_train_data_elm(elm, &ln, dat->in_cnt, dat->out_cnt);
      i++;
    }
  }
  /* shuffle: seq must hold a valid permutation before arr_shfl applies it */
  seq_lin(dat->seq, dat->elm_cnt);
  seq_rnd(dat->seq, dat->elm_cnt, &rnd_gen);
  arr_shfl(dat->elms, dat->elm_cnt, dat->seq);
  nn__split_train_data_set(dat, dat->elm_cnt);
}
/* ---------------------------------------------------------------------------
 *                                Trainer
 * --------------------------------------------------------------------------- */
struct nn_trn {
  /* in */
  void *mem;
  /* in: settings */
  unsigned use_batch:1;
  float learn_rate;
  float momentum;
  float tar_acc;
  float acc;
  /* out */
  float trn_acc;
  float trn_mse;
  float gen_acc;
  float gen_mse;
  float val_acc;
  float val_mse;
  unsigned done:1;
  /* intern */
  unsigned setup:1;
  float *dt_in_hid;
  float *dt_hid_out;
  float *err_hid;
  float *err_out;
};
static int
nn_req_trn_siz(struct nn_ctx *ctx) {
  int ih = szof(float) * ctx->w_in_hid_cnt;
  int oh = szof(float) *
    ctx->w_hid_out_cnt;
  int nh = szof(float) * ctx->n_hid_cnt;
  int no = szof(float) * ctx->n_out_cnt;
  return ih + oh + nh + no;
}
static void
nn_trn_init(struct nn_trn *trn, struct nn_ctx *ctx) {
  trn->setup = 1;
  trn->learn_rate = (trn->learn_rate == 0.0f) ? 0.001f : trn->learn_rate;
  trn->momentum = (trn->momentum == 0.0f) ? 0.9f : trn->momentum;
  trn->acc = (trn->acc == 0.0f) ? 90.0f : trn->acc;

  /* float pointer arithmetic already scales by sizeof(float), so the
   * offsets are element counts, not byte counts */
  trn->dt_in_hid = trn->mem;
  trn->dt_hid_out = trn->dt_in_hid + ctx->w_in_hid_cnt;
  trn->err_hid = trn->dt_hid_out + ctx->w_hid_out_cnt;
  trn->err_out = trn->err_hid + ctx->n_hid_cnt;
}
static void
nn__trn_update_weights(struct nn_trn *trn, struct nn_ctx *ctx) {
  int ii, hi, oi;
  /* input -> hidden weights */
  for (ii = 0; ii <= ctx->in_cnt; ++ii) {
    for (hi = 0; hi <= ctx->hid_cnt; ++hi) {
      int w_idx = nn__in_hide_weight_idx(ctx, ii, hi);
      ctx->w_in_hid[w_idx] += trn->dt_in_hid[w_idx];
      if (trn->use_batch) {
        trn->dt_in_hid[w_idx] = 0;
      }
    }
  }
  /* hidden -> output weights */
  for (hi = 0; hi <= ctx->hid_cnt; ++hi) {
    for (oi = 0; oi < ctx->out_cnt; ++oi) {
      int wi = nn__hide_out_weight_idx(ctx, hi, oi);
      ctx->w_hid_out[wi] += trn->dt_hid_out[wi];
      if (trn->use_batch) {
        trn->dt_hid_out[wi] = 0;
      }
    }
  }
}
/* Error gradient of a sigmoid output node: out * (1 - out) * (target - out). */
static float
nn__trn_out_err_gradient(float tar_val, float out_val) {
  return out_val * (1.0f - out_val) * (tar_val - out_val);
}
/* Error gradient of hidden node hi: sigmoid derivative times the weighted
 * sum of the output error gradients it feeds. */
static float
nn__trn_hid_err_gradient(struct nn_trn *trn, struct nn_ctx *ctx, int hi) {
  int oi = 0;
  float w_sum = 0;
  for (oi = 0; oi < ctx->out_cnt; ++oi) {
    int w_idx = nn__hide_out_weight_idx(ctx, hi, oi);
    w_sum += ctx->w_hid_out[w_idx] * trn->err_out[oi];
  }
  return ctx->n_hid[hi] * (1.0f - ctx->n_hid[hi]) * w_sum;
}
static void
nn__trn_backpropergate(struct nn_trn *trn, struct nn_ctx *ctx,
    const int *exp_out) {
  int ii = 0, oi = 0, hi = 0;
  /* modify deltas between hidden and output layers */
  for (oi = 0; oi < ctx->out_cnt; ++oi) {
    /* get error gradient for output node */
    trn->err_out[oi] =
      nn__trn_out_err_gradient(castf(exp_out[oi]), ctx->n_out[oi]);
    /* for all nodes in hidden layer and bias neuron calculate change in weight */
    for (hi = 0; hi <= ctx->hid_cnt; ++hi) {
      int wi = nn__hide_out_weight_idx(ctx, hi, oi);
      if (trn->use_batch) {
        trn->dt_hid_out[wi] += trn->learn_rate * ctx->n_hid[hi] * trn->err_out[oi];
      } else {
        trn->dt_hid_out[wi] = trn->learn_rate * ctx->n_hid[hi] * trn->err_out[oi] +
          trn->momentum * trn->dt_hid_out[wi];
      }
    }
  }
  /* modify deltas between input and hidden layers */
  for (hi = 0; hi <= ctx->hid_cnt; ++hi) {
    trn->err_hid[hi] = nn__trn_hid_err_gradient(trn, ctx, hi);
    for (ii = 0; ii <= ctx->in_cnt; ++ii) {
      int wi = nn__in_hide_weight_idx(ctx, ii, hi);
      if (trn->use_batch) {
        trn->dt_in_hid[wi] += trn->learn_rate * ctx->n_in[ii] * trn->err_hid[hi];
      } else {
        trn->dt_in_hid[wi] = trn->learn_rate * ctx->n_in[ii] * trn->err_hid[hi] +
          trn->momentum * trn->dt_in_hid[wi];
      }
    }
  }
  /* If using stochastic learning update the weights immediately */
  if (!trn->use_batch) {
    nn__trn_update_weights(trn, ctx);
  }
}
/* Evaluates the network on one data set and reports accuracy (percent of
 * fully correct elements) and mean squared error. */
static void
nn__trn_tst(float *mse, float *acc, struct nn_trn *trn, struct nn_ctx *ctx,
    const struct nn_train_data *dat, const struct nn_train_set *set) {
  int oi, i = 0;
  float false_cnt = 0;
  *mse = *acc = 0;
  for (i = set->begin; i < set->end; ++i) {
    int is_correct = 1;
    struct nn_train_elm *elm = dat->elms + i;
    nn_eval(ctx, elm->in);
    for (oi = 0; oi < ctx->out_cnt; ++oi) {
      if (ctx->outs[oi] != elm->out[oi]) {
        is_correct = 0;
      }
      *mse += castf(pow((ctx->n_out[oi] - elm->out[oi]), 2));
    }
    if (!is_correct) {
      false_cnt += 1.0f;
    }
  }
  *acc = 100.0f - (false_cnt / castf(set->cnt)) * 100.0f;
  *mse = *mse / (ctx->out_cnt * castf(set->cnt));
}
/* Runs one training epoch; returns 0 once the target accuracy is reached. */
static int
nn_trn(struct nn_trn *trn, struct nn_ctx *ctx, struct nn_train_data *dat) {
  int i = 0, oi = 0;
  float mse = 0;
  float incorrect_entries = 0;
  if (trn->trn_acc >= trn->tar_acc && trn->gen_acc >= trn->tar_acc) {
    trn->done = 1;
    return 0;
  }
  for (i = dat->train.begin; i < dat->train.end; ++i) {
    int is_correct = 1;
    struct nn_train_elm *elm = dat->elms + i;
    nn_eval(ctx, elm->in);
    nn__trn_backpropergate(trn, ctx, elm->out);
    for (oi = 0; oi < ctx->out_cnt; ++oi) {
      if (ctx->outs[oi] != elm->out[oi]) {
        is_correct = 0;
      }
      mse += castf(pow((ctx->n_out[oi] - elm->out[oi]), 2));
    }
    if (!is_correct) {
      incorrect_entries += 1.0f;
    }
  }
  if (trn->use_batch) {
    nn__trn_update_weights(trn, ctx);
  }
  trn->trn_acc = 100.0f - (incorrect_entries / castf(dat->train.cnt)) * 100.0f;
  trn->trn_mse = mse / (ctx->out_cnt * castf(dat->train.cnt));
  nn__trn_tst(&trn->gen_mse, &trn->gen_acc, trn, ctx, dat, &dat->gen);
  return 1;
}
static void
nn_trn_tst(struct nn_trn *trn, struct nn_ctx *ctx, struct nn_train_data *dat) {
  nn__trn_tst(&trn->val_mse, &trn->val_acc, trn, ctx, dat, &dat->val);
}
/* ---------------------------------------------------------------------------
 *
 *                                  App
 *
 * --------------------------------------------------------------------------- */
static void
usage(const char *app) {
  die("\n"
    "usage: %s [options] data in hidden out\n"
    "\n"
    " arguments:\n"
    "\n"
    " data, Path to training data file\n"
    " in, Number of input neurons\n"
    " hidden, Number of hidden neurons\n"
    " out, Number of output neurons\n"
    "\n"
    " options:\n"
    " -n , epoch count (150)\n"
    " -l , Learning rate (0.001)\n"
    " -a , Desired Accuracy (90)\n"
    " -m , Momentum (0.9)\n"
    " -o , Code out file path\n"
    " -b enable batch learning\n"
    " -h help message\n"
    "\n",
    app
  );
  exit(1);
}
/* Emits a standalone nn_eval() (NEON intrinsics) with the trained weights
 * baked in. Weights are written one row per neuron, bias weight last, so a
 * row holds in_cnt + 1 (resp. hid_cnt + 1) values. */
static void
write_result(FILE *fp, const struct nn_ctx *ctx) {
  int hi, ii, oi, n = 0;
  static const char code[] =
    "  for (hi = 0; hi < hid_cnt; ++hi) {\n"
    "    float32x4_t sum = vdupq_n_f32(0.0f);\n"
    "    for (ii = 0; ii <= in_cnt; ii += 4) {\n"
    "      float32x4_t in = vld1q_f32(n_in + ii);\n"
    "      float32x4_t w = vld1q_f32(w_ih + hi * (in_cnt + 1) + ii);\n"
    "      float32x4_t m = vmulq_f32(in, w);\n"
    "      sum = vaddq_f32(sum, m);\n"
    "    }\n"
    "    sum = vpaddq_f32(sum,sum);\n"
    "    sum = vpaddq_f32(sum,sum);\n"
    "    n_hid[hi] = vgetq_lane_f32(sum,0);\n"
    "    n_hid[hi] = 1.0f / (1.0f + expf(-n_hid[hi]));\n"
    "  \
}\n"
    "  for (oi = 0; oi < out_cnt; ++oi) {\n"
    "    float32x4_t sum = vdupq_n_f32(0.0f);\n"
    "    for (hi = 0; hi <= hid_cnt; hi += 4) {\n"
    "      float32x4_t in = vld1q_f32(n_hid + hi);\n"
    "      float32x4_t w = vld1q_f32(w_ho + oi * (hid_cnt + 1) + hi);\n"
    "      float32x4_t m = vmulq_f32(in, w);\n"
    "      sum = vaddq_f32(sum, m);\n"
    "    }\n"
    "    sum = vpaddq_f32(sum,sum);\n"
    "    sum = vpaddq_f32(sum,sum);\n"
    "    n_out[oi] = vgetq_lane_f32(sum,0);\n"
    "    n_out[oi] = 1.0f / (1.0f + expf(-n_out[oi]));\n"
    "    if (n_out[oi] < 0.1f) {\n"
    "      out[oi] = 0.0f;\n"
    "    } else if (n_out[oi] > 0.9f) {\n"
    "      out[oi] = 1.0f;\n"
    "    } else {\n"
    "      out[oi] = n_out[oi];\n"
    "    }\n"
    "  }\n";

  fprintf(fp, "static void nn_eval(float *out, const float *in){\n");
  fprintf(fp, " static const int in_cnt = %d;\n", ctx->in_cnt);
  fprintf(fp, " static const int hid_cnt = %d;\n", ctx->hid_cnt);
  fprintf(fp, " static const int out_cnt = %d;\n", ctx->out_cnt);

  /* input -> hidden weights, seven values per line */
  fprintf(fp, " static const float w_ih[] = {\n");
  fprintf(fp, " ");
  for (hi = 0; hi < ctx->hid_cnt; ++hi) {
    for (ii = 0; ii <= ctx->in_cnt; ++ii) {
      int wi = nn__in_hide_weight_idx(ctx, ii, hi);
      fprintf(fp, "%f,", ctx->w_in_hid[wi]);
      if ((++n % 7 == 0)) {
        fprintf(fp, "\n ");
      }
    }
  }
  n = 0;
  fprintf(fp, "\n };\n");

  /* hidden -> output weights */
  fprintf(fp, " static const float w_ho[] = {\n");
  fprintf(fp, " ");
  for (oi = 0; oi < ctx->out_cnt; ++oi) {
    for (hi = 0; hi <= ctx->hid_cnt; ++hi) {
      int wi = nn__hide_out_weight_idx(ctx, hi, oi);
      fprintf(fp, "%f,", ctx->w_hid_out[wi]);
      if ((++n % 7 == 0)) {
        fprintf(fp, "\n ");
      }
    }
  }
  fprintf(fp, "\n };\n");

  /* neuron arrays are rounded up to a multiple of 4 for the vector loads */
  fprintf(fp, " float n_in[%d] = {0};\n",align_up(ctx->n_in_cnt, 4));
  fprintf(fp, " float n_hid[%d] = {0};\n", align_up(ctx->n_hid_cnt, 4));
  fprintf(fp, " float n_out[%d] = {0};\n",ctx->out_cnt);
  fprintf(fp, " int hi, ii, oi;\n\n");
  fprintf(fp, " memcpy(n_in, in, (size_t)in_cnt * sizeof(float));\n");
  fprintf(fp, " n_in[%d] = n_hid[%d] = -1;\n", ctx->in_cnt, ctx->hid_cnt);
  fprintf(fp, "%s", code);
  fprintf(fp, "}\n");
}
extern int
main(int argc, char **argv) {
  const char *app = 0;
  const char *out_file =
    0;
  int epochs = 150, epoch = 0;

  struct nn_train_data dat = {0};
  struct nn_ctx ctx = {0};
  struct nn_trn trn = {0};
  trn.use_batch = 0;
  trn.tar_acc = 90.0f;
  trn.momentum = 0.9f;
  trn.learn_rate = 0.001f;

  /* Command Arguments */
  CMD_ARG_BEGIN(app, argc, argv){
  case 'h': default: usage(app); break;
  case 'b': trn.use_batch = 1; break;
  case 'n': epochs = cmd_arg_opt_int(argv,usage(app)); break;
  case 'l': trn.learn_rate = cmd_arg_opt_flt(argv,usage(app)); break;
  case 'm': trn.momentum = cmd_arg_opt_flt(argv,usage(app)); break;
  case 'o': out_file = cmd_arg_opt_str(argv,usage(app)); break;
  case 'a': {
    float v = cmd_arg_opt_flt(argv,usage(app));
    trn.tar_acc = clamp(0.0f, v, 100.0f);
  } break;
  } CMD_ARG_END;
  if (argc < 4) {
    usage(app);
  }
  {
    /* Training Data */
    int file_siz = 0;
    char *file = file_load(&file_siz, argv[0]);
    dat.in_cnt = cmd_arg_int(argv[1]);
    dat.hid_cnt = cmd_arg_int(argv[2]);
    dat.out_cnt = cmd_arg_int(argv[3]);
    dat.elm_cnt = line_cnt(file, file_siz);
    dat.mem = xalloc(nn_req_train_data_siz(&dat));
    nn_train_data_load(&dat, file, file_siz);
    free(file);
  }
  /* Network */
  ctx.in_cnt = dat.in_cnt;
  ctx.out_cnt = dat.out_cnt;
  ctx.hid_cnt = dat.hid_cnt;
  ctx.mem = xalloc(nn_req_net_siz(&ctx));
  nn_init(&ctx);

  /* Trainer */
  trn.mem = xalloc(nn_req_trn_siz(&ctx));
  nn_trn_init(&trn, &ctx);

  printf("\nNeural Network Training Starting:\n");
  printf("=================================================================\n");
  printf("\tLearning Rate: %f\n", trn.learn_rate);
  printf("\tMomentum: %f\n", trn.momentum);
  printf("\tEpochs: %d\n", epochs);
  printf("\tInput Neurons: %d\n", ctx.in_cnt);
  printf("\tHidden Neurons: %d\n", ctx.hid_cnt);
  printf("\tOutput Neurons: %d\n", ctx.out_cnt);
  printf("=================================================================\n");
  for (epoch = 0; epoch < epochs; ++epoch) {
    nn_trn(&trn, &ctx, &dat);
    if (trn.done) break;
    printf("\tEpoch: %d Train Accuracy: %f Train MSE: %f General Accuracy: %f General MSE: %f\n",
      epoch, trn.trn_acc, trn.trn_mse, trn.gen_acc,
      trn.gen_mse);
  }
  /* final score on the held-out validation set */
  nn_trn_tst(&trn, &ctx, &dat);
  printf("\nNeural Network Training Complete!!!\n");
  printf("\tElapsed Epochs: %d\n", epoch);
  printf("\tAccuracy: %f\n", trn.val_acc);
  printf("\tMSE: %f\n\n", trn.val_mse);

  if (out_file) {
    FILE *fp = fopen(out_file, "w");
    if (fp) {
      write_result(fp, &ctx);
      fclose(fp);
    } else {
      printf("[error] failed to open output file!\n");
    }
  }
  /* Cleanup */
  free(trn.mem);
  free(ctx.mem);
  free(dat.mem);
  return 0;
}