// Copyright 2021 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$import math
$assert IN_PTRS in ["MULTI", "REUSE"]
$assert OUT_PTRS in ["MULTI", "SWITCH", "MOV"]
$assert SIZE in [8, 16, 32, 64]
$TILE_SIZE = int(128/SIZE)
$NUM_ITERS = int(math.log2(TILE_SIZE))

#include <immintrin.h>

#include <assert.h>

#include <xnnpack/common.h>
#include <xnnpack/math.h>
#include <xnnpack/transpose.h>

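// Transposes a block_height x block_width block of ${SIZE}-bit elements in
// ${TILE_SIZE}x${TILE_SIZE} tiles using SSE2 unpack instructions.
// IN_PTRS selects how input rows are addressed (MULTI: one pointer per row,
// REUSE: a single pointer bumped by input_stride). OUT_PTRS selects how the
// transposed rows are written back (MULTI: one pointer per row, SWITCH: a
// fall-through switch over the remaining rows, MOV: row pointers recomputed
// from a single base pointer for every tile).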
void xnn_x${SIZE}_transpose_ukernel__${TILE_SIZE}x${TILE_SIZE}_${IN_PTRS.lower()}_${OUT_PTRS.lower()}_sse2(
const uint${SIZE}_t* input,
uint${SIZE}_t* output,
size_t input_stride,
size_t output_stride,
size_t block_width,
size_t block_height)
{
assert(output_stride >= block_height * sizeof(uint${SIZE}_t));
assert(input_stride >= block_width * sizeof(uint${SIZE}_t));
const size_t tile_height = ${TILE_SIZE};
const size_t tile_width = ${TILE_SIZE};
const size_t tile_hbytes = tile_height * sizeof(uint${SIZE}_t);
const size_t tile_wbytes = tile_width * sizeof(uint${SIZE}_t);
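// After one full pass down the block (one column of tiles), input_reset returns
// the input pointer(s) to the top of the next tile column, and output_reset moves
// the output pointer(s) down ${TILE_SIZE} output rows and undoes the horizontal
// advance accumulated during the pass.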
const size_t input_reset = tile_wbytes - round_down_po2(block_height, tile_height) * input_stride;
$if IN_PTRS == "MULTI":
const size_t input_offset = tile_height * input_stride;
const size_t output_reset = tile_width * output_stride - round_down_po2(block_height, 2) * sizeof(uint${SIZE}_t);
$if IN_PTRS == "MULTI":
const uint${SIZE}_t* i0 = input;
$for N in range(1, TILE_SIZE):
const uint${SIZE}_t* i${N} = (const uint${SIZE}_t*) ((uintptr_t) i${N-1} + input_stride);
$else:
const uint${SIZE}_t* i0 = input;
$if OUT_PTRS == "MULTI":
uint${SIZE}_t* o0 = (uint${SIZE}_t*) output;
$for N in range(1, TILE_SIZE):
uint${SIZE}_t* o${N} = (uint${SIZE}_t*) ((uintptr_t) o${N-1} + output_stride);
$else:
uint${SIZE}_t* o = (uint${SIZE}_t*) output;
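// Each iteration of this loop transposes one column of ${TILE_SIZE}x${TILE_SIZE}
// tiles, i.e. up to ${TILE_SIZE} input columns.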
do {
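// MULTI output: alias rows beyond block_width to o0 so every row can be stored
// unconditionally; stores run from the highest row index down, so the last write
// to o0 carries the valid row-0 data.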
$if OUT_PTRS == "MULTI":
if XNN_UNPREDICTABLE(block_width < 2) {
o1 = o0;
}
$for N in range(2, TILE_SIZE, 2):
if XNN_UNPREDICTABLE(block_width <= ${N}) {
o${N} = o0;
}
if XNN_UNPREDICTABLE(block_width < ${N+2}) {
o${N+1} = o0;
}
$if OUT_PTRS in ["SWITCH", "MOV"]:
const size_t rem = min(block_width - 1, ${TILE_SIZE-1});
const size_t oN_stride = rem * output_stride;
size_t bh = block_height;
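// Transpose full ${TILE_SIZE}x${TILE_SIZE} tiles down the height of the block:
// load ${TILE_SIZE} input rows, interleave them, and store ${TILE_SIZE}
// transposed rows.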
for (; bh >= ${TILE_SIZE}; bh -= ${TILE_SIZE}) {
$for N in range(TILE_SIZE):
$if IN_PTRS == "REUSE":
const __m128i v${NUM_ITERS}_${N} = _mm_loadu_si128((const __m128i*) i0);
i0 = (uint${SIZE}_t*) ((uintptr_t) i0 + input_stride);
$else:
const __m128i v${NUM_ITERS}_${N} = _mm_loadu_si128((const __m128i*) i${N});
i${N} = (uint${SIZE}_t*) ((uintptr_t) i${N} + input_offset);
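// ${NUM_ITERS} rounds of unpacklo/unpackhi interleaves turn the loaded rows
// v${NUM_ITERS}_* into the transposed rows v0_*.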
$for N in range(TILE_SIZE >> 1):
const __m128i v${NUM_ITERS-1}_${N*2} = _mm_unpacklo_epi${SIZE}(v${NUM_ITERS}_${N*2}, v${NUM_ITERS}_${N*2+1});
const __m128i v${NUM_ITERS-1}_${N*2+1} = _mm_unpackhi_epi${SIZE}(v${NUM_ITERS}_${N*2}, v${NUM_ITERS}_${N*2+1});
$if NUM_ITERS>=2:
$for N in range(0, TILE_SIZE, 4):
const __m128i v${NUM_ITERS-2}_${N} = _mm_unpacklo_epi${SIZE*2}(v${NUM_ITERS-1}_${N}, v${NUM_ITERS-1}_${N+2});
const __m128i v${NUM_ITERS-2}_${N+1} = _mm_unpackhi_epi${SIZE*2}(v${NUM_ITERS-1}_${N}, v${NUM_ITERS-1}_${N+2});
const __m128i v${NUM_ITERS-2}_${N+2} = _mm_unpacklo_epi${SIZE*2}(v${NUM_ITERS-1}_${N+1}, v${NUM_ITERS-1}_${N+3});
const __m128i v${NUM_ITERS-2}_${N+3} = _mm_unpackhi_epi${SIZE*2}(v${NUM_ITERS-1}_${N+1}, v${NUM_ITERS-1}_${N+3});
$if NUM_ITERS>=3:
$for M in range(0, TILE_SIZE, 8):
$for N in range(0, 4):
const __m128i v${NUM_ITERS-3}_${M+2*N} = _mm_unpacklo_epi${SIZE*4}(v${NUM_ITERS-2}_${M+N}, v${NUM_ITERS-2}_${M+N+4});
const __m128i v${NUM_ITERS-3}_${M+2*N+1} = _mm_unpackhi_epi${SIZE*4}(v${NUM_ITERS-2}_${M+N}, v${NUM_ITERS-2}_${M+N+4});
$if NUM_ITERS>=4:
$for N in range(TILE_SIZE >> 1):
const __m128i v0_${N*2} = _mm_unpacklo_epi64(v1_${N}, v1_${N+8});
const __m128i v0_${N*2+1} = _mm_unpackhi_epi64(v1_${N}, v1_${N+8});
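// Store the ${TILE_SIZE} transposed rows: SWITCH falls through from the last
// valid row down to row 0, MOV rebuilds per-row pointers from o (aliasing rows
// beyond block_width to the last valid row), and MULTI advances one pointer per
// row by tile_hbytes.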
$if OUT_PTRS == "SWITCH":
uint${SIZE}_t* oN = (uint${SIZE}_t*) ((uintptr_t) o + oN_stride);
switch (rem) {
$for N in reversed(range(2, TILE_SIZE)):
case ${N}:
_mm_storeu_si128((__m128i*) oN, v0_${N});
oN = (uint${SIZE}_t*) ((uintptr_t) oN - output_stride);
case 1:
_mm_storeu_si128((__m128i*) oN, v0_1);
case 0:
_mm_storeu_si128((__m128i*) o, v0_0);
o = (uint${SIZE}_t*) ((uintptr_t) o + tile_hbytes);
break;
default:
XNN_UNREACHABLE;
}
$elif OUT_PTRS == "MOV":
uint${SIZE}_t* o${TILE_SIZE-1} = (uint${SIZE}_t*) ((uintptr_t) o + oN_stride);
_mm_storeu_si128((__m128i*) o${TILE_SIZE-1}, v0_${TILE_SIZE-1});
$for N in reversed(range(2, TILE_SIZE, 2)):
uint${SIZE}_t *o${N} = (uint${SIZE}_t*) ((uintptr_t) o${N+1} - output_stride);
if XNN_UNPREDICTABLE(block_width <= ${N+1}) {
o${N} = o${TILE_SIZE-1};
}
_mm_storeu_si128((__m128i*) o${N}, v0_${N});
uint${SIZE}_t *o${N-1} = (uint${SIZE}_t*) ((uintptr_t) o${N} - output_stride);
if XNN_UNPREDICTABLE(block_width < ${N+1}) {
o${N-1} = o${TILE_SIZE-1};
}
_mm_storeu_si128((__m128i*) o${N-1}, v0_${N-1});
_mm_storeu_si128((__m128i*) o, v0_0);
o = (uint${SIZE}_t*) ((uintptr_t) o + tile_hbytes);
$else:
$for N in reversed(range(TILE_SIZE)):
_mm_storeu_si128((__m128i*) o${N}, v0_${N});
o${N} = (uint${SIZE}_t*) ((uintptr_t) o${N} + tile_hbytes);
}
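// Handle the bh (< ${TILE_SIZE}) leftover rows of this tile column.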
if (bh != 0) {
$if IN_PTRS == "REUSE":
const __m128i v${NUM_ITERS}_0 = _mm_loadu_si128((const __m128i*) i0);
$for N in range(1, TILE_SIZE - 1, 2):
const uint${SIZE}_t *i${N} = (const uint${SIZE}_t*) ((uintptr_t) i${N-1} + input_stride);
if XNN_UNPREDICTABLE(bh < ${N+1}) {
i${N} = i0;
}
const __m128i v${NUM_ITERS}_${N} = _mm_loadu_si128((const __m128i*) i${N});
const uint${SIZE}_t *i${N+1} = (const uint${SIZE}_t*) ((uintptr_t) i${N} + input_stride);
if XNN_UNPREDICTABLE(bh <= ${N+1}) {
i${N+1} = i0;
}
const __m128i v${NUM_ITERS}_${N+1} = _mm_loadu_si128((const __m128i*) i${N+1});
$else:
const __m128i v${NUM_ITERS}_0 = _mm_loadu_si128((const __m128i*) i0);
$for N in range(1, TILE_SIZE - 1, 2):
if XNN_UNPREDICTABLE(bh < ${N+1}) {
i${N} = i0;
}
const __m128i v${NUM_ITERS}_${N} = _mm_loadu_si128((const __m128i*) i${N});
if XNN_UNPREDICTABLE(bh <= ${N+1}) {
i${N+1} = i0;
}
const __m128i v${NUM_ITERS}_${N+1} = _mm_loadu_si128((const __m128i*) i${N+1});
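// At most ${TILE_SIZE-1} rows remain here, so the tile's last input row is never
// loaded; the output elements it would supply are never stored, so an undefined
// vector is safe.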
const __m128i v${NUM_ITERS}_${TILE_SIZE-1} = _mm_undefined_si128();
$CONST = "const "
$if NUM_ITERS == 1:
$CONST = ""
$for N in range(TILE_SIZE >> 1):
${CONST}__m128i v${NUM_ITERS-1}_${N*2} = _mm_unpacklo_epi${SIZE}(v${NUM_ITERS}_${N*2}, v${NUM_ITERS}_${N*2+1});
${CONST}__m128i v${NUM_ITERS-1}_${N*2+1} = _mm_unpackhi_epi${SIZE}(v${NUM_ITERS}_${N*2}, v${NUM_ITERS}_${N*2+1});
$if NUM_ITERS == 2:
$CONST = ""
$if NUM_ITERS>=2:
$for N in range(0, TILE_SIZE, 4):
${CONST}__m128i v${NUM_ITERS-2}_${N} = _mm_unpacklo_epi${SIZE*2}(v${NUM_ITERS-1}_${N}, v${NUM_ITERS-1}_${N+2});
${CONST}__m128i v${NUM_ITERS-2}_${N+1} = _mm_unpackhi_epi${SIZE*2}(v${NUM_ITERS-1}_${N}, v${NUM_ITERS-1}_${N+2});
${CONST}__m128i v${NUM_ITERS-2}_${N+2} = _mm_unpacklo_epi${SIZE*2}(v${NUM_ITERS-1}_${N+1}, v${NUM_ITERS-1}_${N+3});
${CONST}__m128i v${NUM_ITERS-2}_${N+3} = _mm_unpackhi_epi${SIZE*2}(v${NUM_ITERS-1}_${N+1}, v${NUM_ITERS-1}_${N+3});
$if NUM_ITERS == 3:
$CONST = ""
$if NUM_ITERS>=3:
$for M in range(0, TILE_SIZE, 8):
$for N in range(0, 4):
${CONST}__m128i v${NUM_ITERS-3}_${M+2*N} = _mm_unpacklo_epi${SIZE*4}(v${NUM_ITERS-2}_${M+N}, v${NUM_ITERS-2}_${M+N+4});
${CONST}__m128i v${NUM_ITERS-3}_${M+2*N+1} = _mm_unpackhi_epi${SIZE*4}(v${NUM_ITERS-2}_${M+N}, v${NUM_ITERS-2}_${M+N+4});
$if NUM_ITERS>=4:
$for N in range(TILE_SIZE >> 1):
__m128i v0_${N*2} = _mm_unpacklo_epi64(v1_${N}, v1_${N+8});
__m128i v0_${N*2+1} = _mm_unpackhi_epi64(v1_${N}, v1_${N+8});
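// Store the remaining bh elements of each transposed row in power-of-two chunks,
// widest first, following the set bits of bh.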
if (bh & ${TILE_SIZE>>1}) {
$if OUT_PTRS == "SWITCH":
uint${SIZE}_t* oN = (uint${SIZE}_t*) ((uintptr_t) o + oN_stride);
switch (rem) {
$for N in reversed(range(2, TILE_SIZE)):
case ${N}:
_mm_storel_epi64((__m128i*) oN, v0_${N});
oN = (uint${SIZE}_t*) ((uintptr_t) oN - output_stride);
case 1:
_mm_storel_epi64((__m128i*) oN, v0_1);
case 0:
_mm_storel_epi64((__m128i*) o, v0_0);
break;
default:
XNN_UNREACHABLE;
}
$if NUM_ITERS > 1:
o += ${TILE_SIZE>>1};
$elif OUT_PTRS == "MOV":
uint${SIZE}_t* o${TILE_SIZE-1} = (uint${SIZE}_t*) ((uintptr_t) o + oN_stride);
_mm_storel_epi64((__m128i*) o${TILE_SIZE-1}, v0_${TILE_SIZE-1});
$for N in reversed(range(2, TILE_SIZE, 2)):
uint${SIZE}_t *o${N} = (uint${SIZE}_t*) ((uintptr_t) o${N+1} - output_stride);
if XNN_UNPREDICTABLE(block_width <= ${N+1}) {
o${N} = o${TILE_SIZE-1};
}
_mm_storel_epi64((__m128i*) o${N}, v0_${N});
uint${SIZE}_t *o${N-1} = (uint${SIZE}_t*) ((uintptr_t) o${N} - output_stride);
if XNN_UNPREDICTABLE(block_width < ${N+1}) {
o${N-1} = o${TILE_SIZE-1};
}
_mm_storel_epi64((__m128i*) o${N-1}, v0_${N-1});
_mm_storel_epi64((__m128i*) o, v0_0);
$if NUM_ITERS > 1:
o += ${TILE_SIZE>>1};
$else:
$for N in reversed(range(TILE_SIZE)):
_mm_storel_epi64((__m128i*) o${N}, v0_${N});
$if NUM_ITERS>1:
o${N} += ${TILE_SIZE>>1};
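// The low halves have been stored; move the high 64 bits of each row down for
// the narrower stores that follow.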
$if NUM_ITERS > 1:
$for N in range(TILE_SIZE):
v0_${N} = _mm_unpackhi_epi64(v0_${N}, v0_${N});
}
$if NUM_ITERS>1:
if (bh & ${TILE_SIZE>>2}) {
$if OUT_PTRS == "SWITCH":
uint${SIZE}_t* oN = (uint${SIZE}_t*) ((uintptr_t) o + oN_stride);
switch (rem) {
$for N in reversed(range(2, TILE_SIZE)):
case ${N}:
*((int*) oN) = _mm_cvtsi128_si32(v0_${N});
oN = (uint${SIZE}_t*) ((uintptr_t) oN - output_stride);
case 1:
*((int*) oN) = _mm_cvtsi128_si32(v0_1);
case 0:
*((int*) o) = _mm_cvtsi128_si32(v0_0);
break;
default:
XNN_UNREACHABLE;
}
$if NUM_ITERS > 2:
o += ${TILE_SIZE>>2};
$elif OUT_PTRS == "MOV":
uint${SIZE}_t* o${TILE_SIZE-1} = (uint${SIZE}_t*) ((uintptr_t) o + oN_stride);
*((int*) o${TILE_SIZE-1}) = _mm_cvtsi128_si32(v0_${TILE_SIZE-1});
$for N in reversed(range(2, TILE_SIZE, 2)):
uint${SIZE}_t *o${N} = (uint${SIZE}_t*) ((uintptr_t) o${N+1} - output_stride);
if XNN_UNPREDICTABLE(block_width <= ${N+1}) {
o${N} = o${TILE_SIZE-1};
}
*((int*) o${N}) = _mm_cvtsi128_si32(v0_${N});
uint${SIZE}_t *o${N-1} = (uint${SIZE}_t*) ((uintptr_t) o${N} - output_stride);
if XNN_UNPREDICTABLE(block_width < ${N+1}) {
o${N-1} = o${TILE_SIZE-1};
}
*((int*) o${N-1}) = _mm_cvtsi128_si32(v0_${N-1});
*((int*) o) = _mm_cvtsi128_si32(v0_0);
$if NUM_ITERS > 2:
o += ${TILE_SIZE>>2};
$else:
$for N in reversed(range(TILE_SIZE)):
*((int*) o${N}) = _mm_cvtsi128_si32(v0_${N});
$if NUM_ITERS>2:
o${N} += ${TILE_SIZE>>2};
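// Drop the 32 bits just stored and bring the next elements into the low lane.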
$if NUM_ITERS > 2:
$for N in range(TILE_SIZE):
v0_${N} = _mm_srli_epi64(v0_${N}, 32);
}
$if NUM_ITERS>2:
if (bh & ${TILE_SIZE>>3}) {
$if OUT_PTRS == "SWITCH":
uint${SIZE}_t* oN = (uint${SIZE}_t*) ((uintptr_t) o + oN_stride);
switch (rem) {
$for N in reversed(range(2, TILE_SIZE)):
case ${N}:
*((uint16_t*) oN) = (uint16_t) _mm_cvtsi128_si32(v0_${N});
oN = (uint${SIZE}_t*) ((uintptr_t) oN - output_stride);
case 1:
*((uint16_t*) oN) = (uint16_t) _mm_cvtsi128_si32(v0_1);
case 0:
*((uint16_t*) o) = (uint16_t) _mm_cvtsi128_si32(v0_0);
break;
default:
XNN_UNREACHABLE;
}
$if NUM_ITERS>3:
o += ${TILE_SIZE>>3};
$elif OUT_PTRS == "MOV":
uint${SIZE}_t* o${TILE_SIZE-1} = (uint${SIZE}_t*) ((uintptr_t) o + oN_stride);
*((uint16_t*) o${TILE_SIZE-1}) = (uint16_t) _mm_cvtsi128_si32(v0_${TILE_SIZE-1});
$for N in reversed(range(2, TILE_SIZE, 2)):
uint${SIZE}_t *o${N} = (uint${SIZE}_t*) ((uintptr_t) o${N+1} - output_stride);
if XNN_UNPREDICTABLE(block_width <= ${N+1}) {
o${N} = o${TILE_SIZE-1};
}
*((uint16_t*) o${N}) = (uint16_t) _mm_cvtsi128_si32(v0_${N});
uint${SIZE}_t *o${N-1} = (uint${SIZE}_t*) ((uintptr_t) o${N} - output_stride);
if XNN_UNPREDICTABLE(block_width < ${N+1}) {
o${N-1} = o${TILE_SIZE-1};
}
*((uint16_t*) o${N-1}) = (uint16_t) _mm_cvtsi128_si32(v0_${N-1});
*((uint16_t*) o) = (uint16_t) _mm_cvtsi128_si32(v0_0);
$if NUM_ITERS > 3:
o += ${TILE_SIZE>>3};
$else:
$for N in reversed(range(TILE_SIZE)):
*((uint16_t*) o${N}) = (uint16_t) _mm_cvtsi128_si32(v0_${N});
$if NUM_ITERS>3:
o${N} += ${TILE_SIZE>>3};
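// Drop the 16 bits just stored; for an odd bh the last element is now in the
// low byte.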
$if NUM_ITERS>3:
$for N in range(TILE_SIZE):
v0_${N} = _mm_srli_epi32(v0_${N}, 16);
}
$if SIZE == 8:
if (bh & 1) {
$if OUT_PTRS == "SWITCH":
uint${SIZE}_t* oN = (uint${SIZE}_t*) ((uintptr_t) o + oN_stride);
switch (rem) {
$for N in reversed(range(2, TILE_SIZE)):
case ${N}:
*oN = (uint8_t) _mm_cvtsi128_si32(v0_${N});
oN = (uint${SIZE}_t*) ((uintptr_t) oN - output_stride);
case 1:
*oN = (uint8_t) _mm_cvtsi128_si32(v0_1);
case 0:
*o = (uint8_t) _mm_cvtsi128_si32(v0_0);
break;
default:
XNN_UNREACHABLE;
}
$elif OUT_PTRS == "MOV":
uint${SIZE}_t* o${TILE_SIZE-1} = (uint${SIZE}_t*) ((uintptr_t) o + oN_stride);
*((uint8_t*) o${TILE_SIZE-1}) = (uint8_t) _mm_cvtsi128_si32(v0_${TILE_SIZE-1});
$for N in reversed(range(2, TILE_SIZE, 2)):
uint${SIZE}_t *o${N} = (uint${SIZE}_t*) ((uintptr_t) o${N+1} - output_stride);
if XNN_UNPREDICTABLE(block_width <= ${N+1}) {
o${N} = o${TILE_SIZE-1};
}
*((uint8_t*) o${N}) = (uint8_t) _mm_cvtsi128_si32(v0_${N});
uint${SIZE}_t *o${N-1} = (uint${SIZE}_t*) ((uintptr_t) o${N} - output_stride);
if XNN_UNPREDICTABLE(block_width < ${N+1}) {
o${N-1} = o${TILE_SIZE-1};
}
*((uint8_t*) o${N-1}) = (uint8_t) _mm_cvtsi128_si32(v0_${N-1});
*((uint8_t*) o) = (uint8_t) _mm_cvtsi128_si32(v0_0);
$else:
$for N in reversed(range(TILE_SIZE)):
*o${N} = (uint8_t) _mm_cvtsi128_si32(v0_${N});
}
}
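// Advance the input and output pointers to the next column of tiles.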
$if IN_PTRS == "MULTI":
i0 = (const uint${SIZE}_t*) ((uintptr_t) i0 + input_reset);
$for N in range(1, TILE_SIZE):
i${N} = (const uint${SIZE}_t*) ((uintptr_t) i${N-1} + input_stride);
$else:
i0 = (const uint${SIZE}_t*) ((uintptr_t) i0 + input_reset);
$if OUT_PTRS == "MULTI":
o0 = (uint${SIZE}_t*) ((uintptr_t) o0 + output_reset);
$for N in range(1, TILE_SIZE):
o${N} = (uint${SIZE}_t*) ((uintptr_t) o${N} + output_reset);
$else:
o = (uint${SIZE}_t*) ((uintptr_t) o + output_reset);
block_width = doz(block_width, tile_width);
} while (block_width != 0);
}