summary refs log tree commit diff
path: root/src/Math/Tensor.hpp
blob: 95c454801a2dfebb731f621df5ca5078dc66b4b8 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
#pragma once

/// Fixed-size tensor whose order and extents are compile-time constants.
///
/// Elements live in a single flat array inside the object. Multi-dimensional
/// indices are mapped to that array in row-major order (the last index varies
/// fastest), see pos().
///
/// @tparam O  Order of the tensor, i.e. the number of indices needed to
///            address one element. Must equal the number of extents in S.
/// @tparam T  Element type.
/// @tparam S  Extent of each dimension, first (slowest-varying) dimension first.
template<uint O, typename T, uint ...S>
struct Tensor {
    // Reject a mismatching order/extent list with a readable diagnostic
    // instead of the opaque enable_if failure the Enable alias would produce.
    static_assert(sizeof...(S) == O,
                  "Tensor: the number of extents S... must equal the order O");

    /// Total number of stored elements (product of all extents).
    /// Each extent is converted to USize before multiplying so the product is
    /// not computed (and possibly wrapped) at the width of uint.
    static constexpr USize element_size = (static_cast<USize>(S) * ...);

    /// Enabled only when exactly O arguments are supplied (one index per dimension).
    template<typename ...Args>
    using EnableArgs = std::enable_if_t<sizeof...(Args) == O, Bool>;
    /// Enabled only when exactly element_size arguments are supplied (one value per element).
    template<typename ...Args>
    using EnableArgsPerElement = std::enable_if_t<sizeof...(Args) == element_size, Bool>;
    /// Well-formed only when the number of extents matches the order.
    /// Kept for code that names this alias; the static_assert above is the
    /// check that produces the user-facing error.
    using Enable = std::enable_if_t<sizeof...(S) == O, Bool>;

    /// Creates a tensor with every element value-initialized.
    Tensor() : elements{} {}

    /// Creates a tensor with every element set to @p scalar.
    explicit Tensor(T scalar) {
        std::fill(elements, elements + element_size, scalar);
    }

    /// Creates a tensor from exactly element_size values, given in storage
    /// (row-major) order. Each argument is converted to T with static_cast.
    template<typename ...Args, EnableArgsPerElement<Args...> = true>
    Tensor(Args... args) : elements{ static_cast<T>(args)... } {}

    /// Returns a mutable reference to the element at the given indices.
    /// Requires exactly O indices; no bounds checking is performed.
    template<typename ...Args, EnableArgs<Args...> = true>
    auto& operator()(Args... args) {
        return elements[pos(args...)];
    }

    /// Returns a read-only reference to the element at the given indices.
    /// Requires exactly O indices; no bounds checking is performed.
    template<typename ...Args, EnableArgs<Args...> = true>
    const auto& operator()(Args... args) const {
        return elements[pos(args...)];
    }

    /// Maps O per-dimension indices to the offset of the element in the flat
    /// storage array, using row-major order.
    ///
    /// @param args  One index per dimension, first dimension first. Each is
    ///              converted to USize; indices are not range-checked.
    /// @return      Offset into elements.
    template<typename ...Args, EnableArgs<Args...> = true>
    static constexpr USize pos(Args... args) {
        const USize positions[O] = {static_cast<USize>(args)...};
        const USize dimensions[O] = {S...};

        // Horner-style accumulation: ((i0 * S1 + i1) * S2 + i2) ...
        // The first multiplication is by S0 but acts on 0, so S0 never
        // contributes to the offset.
        USize p = 0;
        for (uint i = 0; i < O; i++) {
            p = p * dimensions[i] + positions[i];
        }
        return p;
    }

    /// Flat row-major storage for all element_size elements.
    T elements[element_size];
};