diff options
Diffstat (limited to 'src/Math/Tensor.hpp')
| -rw-r--r-- | src/Math/Tensor.hpp | 43 |
1 files changed, 43 insertions, 0 deletions
diff --git a/src/Math/Tensor.hpp b/src/Math/Tensor.hpp new file mode 100644 index 0000000..e373c1f --- /dev/null +++ b/src/Math/Tensor.hpp @@ -0,0 +1,43 @@ +#pragma once + +template<size_t O, typename T, size_t ...S> +struct Tensor { + static constexpr size_t element_size = (S*...); + + template<typename ...Args> + using EnableArgs = std::enable_if_t<sizeof...(Args) == O, bool>; + template<typename ...Args> + using EnableArgsPerElement = std::enable_if_t<sizeof...(Args) == element_size, bool>; + using Enable = std::enable_if_t<sizeof...(S) == O, bool>; + + template<Enable = true> + Tensor() : elements{} {} + + template<Enable = true> + explicit Tensor(T scalar) { + std::fill(elements, elements + element_size, scalar); + } + + template<Enable = true, typename ...Args, EnableArgsPerElement<Args...> = true> + Tensor(Args... args) : elements{ args... } {} + + template<typename ...Args, EnableArgs<Args...> = true> + auto& operator()(Args... args) { + return elements[pos(args...)]; + } + + template<typename ...Args, EnableArgs<Args...> = true> + static constexpr size_t pos(Args... args) { + size_t positions[O] = {static_cast<size_t>(args)...}; + size_t dimensions[O] = {S...}; + + size_t p = 0; + for (int i = 0; i < O; i++) { + p *= dimensions[i]; + p += positions[i]; + } + return p; + } + + T elements[element_size]; +}; |
