Skip to content

Instantly share code, notes, and snippets.

@clausecker
Created August 18, 2020 23:51
Show Gist options
  • Select an option

  • Save clausecker/c6df8b6e10593b2b7034185061ca1b90 to your computer and use it in GitHub Desktop.

Select an option

Save clausecker/c6df8b6e10593b2b7034185061ca1b90 to your computer and use it in GitHub Desktop.

Revisions

  1. clausecker created this gist Aug 18, 2020.
    230 changes: 230 additions & 0 deletions count8asm15.s
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,230 @@
    // b:a = a+b+c, v31.16b used for scratch space
    .macro csa, a, b, c
    eor v31.16b, \a\().16b, \b\().16b
    eor \a\().16b, v31.16b, \c\().16b
    bit \b\().16b, \c\().16b, v31.16b
    .endm

    // d:a = a+b+c
    .macro csac a, b, c, d
    eor \d\().16b, \a\().16b, \b\().16b
    eor \a\().16b, \d\().16b, \c\().16b
    bsl \d\().16b, \c\().16b, \b\().16b
    .endm

    .type count8asm15, @function
    .globl count8asm15
    // X0: counts
    // X1: buf
    // X2: len
    count8asm15:
    ld1 {V16.2D-V19.2D}, [X0] // load counts into V16-19

    ldr d29, .Lmask // bit mask into v29.8b
    // ldr d29, =0x8040201008040201 // (unfortunately unsupported by LLVM)
    movi v30.8b, #0 // scalar counter vector v30.8b

    cmp x2, #16 // enough data for at least one vector iteration?
    blt 2f // if not, go straight to scalar tail

    // scalar head to reach 16 byte alignment
    and x3, x2, #15 // how far we are off the alignment
    cbz x3, 3f // skip scalar head if already aligned
    sub x2, x2, x3 // apply alignment to x2
    tst x3, #1 // unroll loop once (duff style)
    bne 1f

    // scalar loop: process one byte at a time
    0: ld1r {v2.8b}, [x1], #1 // load same byte to each elem of v2
    cmtst v2.8b, v29.8b, v2.8b // set counter bytes to 0 or -1 according to [x0] bits
    sub v30.8b, v30.8b, v2.8b // and increment counters

    1: ld1r {v2.8b}, [x1], #1 // load same byte to each elem of v2
    cmtst v2.8b, v29.8b, v2.8b // set counter bytes to 0 or -1 according to [x0] bits
    sub v30.8b, v30.8b, v2.8b // and increment counters

    subs x3, x3, #2
    bgt 0b

    3: movi v20.16b, #0x11 // bit masks for getting out the values
    movi v21.16b, #0x22
    movi v22.16b, #0x44
    movi v23.16b, #0x88
    movi v24.16b, #0x0f

    cmp x2, #15*16 // enough data to process 240 bytes?
    blt 1f

    // 15-fold CSA reduction
    0: ld1 {v0.16b-v2.16b}, [x1], #3*16
    ld1 {v3.16b-v6.16b}, [x1], #4*16
    ld1 {v25.16b-v28.16b}, [x1], #4*16
    csa v0, v1, v2
    csac v0, v3, v4, v2
    csac v0, v5, v6, v3
    ld1 {v4.16b-v7.16b}, [x1], #4*16
    csa v1, v2, v3
    csa v0, v25, v26
    csa v0, v27, v28
    csac v1, v25, v27, v3
    csa v0, v4, v5
    csa v0, v6, v7
    csa v1, v4, v6
    csa v2, v3, v4

    // V3:V2:V1:V0 = SUM([x1,#0*16]...[x1,#14*16])
    // approach: first fold V3...V0 into four registers, each
    // acumulating two pairs of V3:V2:V1:V0 (one per nibble)
    // then, unravel the pairs and sum up

    // partial counts register: V4--V7
    // group counts into nibbles
    and v4.16b, v0.16b, v20.16b
    ushr v0.16b, v0.16b, #1
    shl v25.16b, v1.16b, #1
    bit v4.16b, v25.16b, v21.16b
    shl v25.16b, v2.16b, #2
    bit v4.16b, v25.16b, v22.16b
    shl v25.16b, v3.16b, #3
    bit v4.16b, v25.16b, v23.16b

    and v5.16b, v0.16b, v20.16b
    ushr v0.16b, v0.16b, #1
    bit v5.16b, v1.16b, v21.16b
    ushr v1.16b, v1.16b, #1
    shl v25.16b, v2.16b, #1
    bit v5.16b, v25.16b, v22.16b
    shl v25.16b, v3.16b, #2
    bit v5.16b, v25.16b, v23.16b

    and v6.16b, v0.16b, v20.16b
    ushr v0.16b, v0.16b, #1
    bit v6.16b, v1.16b, v21.16b
    ushr v1.16b, v1.16b, #1
    bit v6.16b, v2.16b, v22.16b
    ushr v2.16b, v2.16b, #1
    shl v25.16b, v3.16b, #1
    bit v6.16b, v25.16b, v23.16b

    and v7.16b, v0.16b, v20.16b
    bit v7.16b, v1.16b, v21.16b
    bit v7.16b, v2.16b, v22.16b
    bit v7.16b, v3.16b, v23.16b

    // extra nibbles and add horizontally
    and v0.16b, v4.16b, v24.16b
    ushr v4.16b, v4.16b, #4
    uaddlv h0, v0.16b
    uaddlv h4, v4.16b

    and v1.16b, v5.16b, v24.16b
    ushr v5.16b, v5.16b, #4
    uaddlv h1, v1.16b
    uaddlv h5, v5.16b

    and v2.16b, v6.16b, v24.16b
    ushr v6.16b, v6.16b, #4
    uaddlv h2, v2.16b
    uaddlv h6, v6.16b

    and v3.16b, v7.16b, v24.16b
    ushr v7.16b, v7.16b, #4
    uaddlv h3, v3.16b
    uaddlv h7, v7.16b

    // add sums to counters
    ins v0.d[1], v1.d[0]
    ins v2.d[1], v3.d[0]
    ins v4.d[1], v5.d[0]
    ins v6.d[1], v7.d[0]

    add v16.2d, v16.2d, v0.2d
    add v17.2d, v17.2d, v2.2d
    add v18.2d, v18.2d, v4.2d
    add v19.2d, v19.2d, v6.2d

    subs x2, x2, #15*16
    bgt 0b

    1: cmp x2, #16 // enough data to process 16 bytes?
    blt 2f

    // single vector
    0: ld1 {v0.16b}, [x1], #16 // load 16 byte

    and v1.16b, v0.16b, v20.16b
    ushr v0.16b, v0.16b, #1
    uaddlv h2, v1.16b
    and v1.16b, v0.16b, v20.16b
    ushr v0.16b, v0.16b, #1
    uaddlv h1, v1.16b
    ins v2.d[1], v1.d[0]
    add v16.2d, v16.2d, v2.2d

    and v1.16b, v0.16b, v20.16b
    ushr v0.16b, v0.16b, #1
    uaddlv h2, v1.16b
    and v1.16b, v0.16b, v20.16b
    ushr v0.16b, v0.16b, #1
    uaddlv h1, v1.16b
    ins v2.d[1], v1.d[0]
    add v17.2d, v17.2d, v2.2d

    and v1.16b, v0.16b, v20.16b
    ushr v0.16b, v0.16b, #1
    uaddlv h2, v1.16b
    and v1.16b, v0.16b, v20.16b
    ushr v0.16b, v0.16b, #1
    uaddlv h1, v1.16b
    ins v2.d[1], v1.d[0]
    add v18.2d, v18.2d, v2.2d

    and v1.16b, v0.16b, v20.16b
    ushr v0.16b, v0.16b, #1
    uaddlv h2, v1.16b
    and v1.16b, v0.16b, v20.16b
    uaddlv h1, v1.16b
    ins v2.d[1], v1.d[0]
    add v19.2d, v19.2d, v2.2d

    subs x2, x2, #16
    bgt 0b

    // scalar tail
    2: cbz x2, 1f // any bytes left to process?
    tst x2, #1 // unroll loop once (duff style)
    bne 2f

    // scalar loop: process one byte at a time
    0: ld1r {v2.8b}, [x1], #1 // load same byte to each elem of v2
    cmtst v2.8b, v29.8b, v2.8b // set counter bytes to 0 or -1 according to [x0] bits
    sub v30.8b, v30.8b, v2.8b // and increment counters

    2: ld1r {v2.8b}, [x1], #1 // load same byte to each elem of v2
    cmtst v2.8b, v29.8b, v2.8b // set counter bytes to 0 or -1 according to [x0] bits
    sub v30.8b, v30.8b, v2.8b // and increment counters

    subs x2, x2, #2
    bgt 0b

    // unpack temp vector and add to counters
    uxtl v30.8h, v30.8b
    uxtl v29.4s, v30.4h
    uxtl2 v30.4s, v30.8h // v29:v30 holds 4 S counters

    uxtl v2.2d, v29.2s
    add v16.2d, v16.2d, v2.2d
    uxtl2 v2.2d, v29.4s
    add v17.2d, v17.2d, v2.2d

    uxtl v2.2d, v30.2s
    add v18.2d, v18.2d, v2.2d
    uxtl2 v2.2d, v30.4s
    add v19.2d, v19.2d, v2.2d

    1: st1 {v16.2D-v19.2D}, [x0] // write counters back
    ret

    // kludge for LLVM compatibility
    .balign 8
    .Lmask: .8byte 0x8040201008040201
    97 changes: 97 additions & 0 deletions harness.c
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,97 @@
    #define _XOPEN_SOURCE 700
    #include <stdio.h>
    #include <stdlib.h>
    #include <string.h>
    #include <time.h>

    extern void
    count8reference(long long counts[restrict 8], const unsigned char *restrict buf, size_t len)
    {
    size_t i;
    int j;

    for (i = 0; i < len; i++)
    for (j = 0; j < 8; j++)
    counts[j] += buf[i] >> j & 1;
    }

    extern void count8asm15(long long [restrict 8], const unsigned char *restrict, size_t);

    /*
    * Compute the difference of two struct timespec.
    */
    static struct timespec
    tsdiff(struct timespec a, struct timespec b)
    {
    a.tv_sec -= b.tv_sec;
    a.tv_nsec -= b.tv_nsec;
    if (a.tv_nsec < 0) {
    a.tv_sec -= 1;
    a.tv_nsec += 1000000000;
    }

    return (a);
    }

    /* perform a benchmark */
    static void benchmark(const unsigned char *buf, size_t len, const char *name,
    void (*pospopcnt)(long long[restrict 8], const unsigned char *restrict, size_t))
    {
    struct timespec diff, start, end;
    double dur;
    int i, n = 1;
    long long naive_accum[8], asm_accum[8];

    memset(naive_accum, 0, sizeof naive_accum);
    memset(asm_accum, 0, sizeof asm_accum);

    count8reference(asm_accum, buf, len);
    pospopcnt(naive_accum, buf, len);

    if (memcmp(asm_accum, naive_accum, sizeof asm_accum) != 0)
    printf("%s\tmismatch\n", name);

    do {

    clock_gettime(CLOCK_REALTIME, &start);
    for (i = 0; i < n; i++)
    pospopcnt(asm_accum, buf, len);
    clock_gettime(CLOCK_REALTIME, &end);
    diff = tsdiff(end, start);
    n <<= 1;
    } while (diff.tv_sec == 0);

    n >>= 1;
    dur = diff.tv_sec + diff.tv_nsec / 1000000000.0;
    dur /= n;
    printf("%s\t%g B/s\n", name, len / dur);
    }

    extern int
    main(int argc, char *argv[])
    {
    size_t len = 8192;
    FILE *random;

    unsigned char *buf;

    if (argc > 1)
    len = atoll(argv[1]) + 31 & ~31LL;

    buf = malloc(len);
    if (buf == NULL) {
    perror("malloc");
    return (EXIT_FAILURE);
    }

    random = fopen("/dev/urandom", "rb");
    if (random == NULL) {
    perror("/dev/urandom");
    return (EXIT_FAILURE);
    }

    fread(buf, 1, len, random);

    benchmark(buf, len, "naive", count8reference);
    benchmark(buf, len, "asm15", count8asm15);
    }