用C++模板创造任意维度任意类型动态数组




前言

最近在写一个矩阵计算的库(很low),在写自己的二维数组的时候突然想到能不能自己用C++中类和模板来造一个任意维度的动态数组呢?于是趁着有空就写了一下,期间查了不少资料,记录下来以备忘

目录

  1. 表示
  2. 重载[]运算符
  3. 实现任意参数数目的函数
  4. 处理一些显而易见的异常
  5. 尾声
  6. 附录

1. 表示

这个是没有什么技术含量的,因此这里也不会展开,基本就是动态管理一个一维的线形数组来储存数据,然后再存一个数组来装下每一个维度的长度的信息即可。

2. 重载 [] 运算符

为了让访问更方便,重载 [] 运算符是非常有必要的。但是这里有一个问题就是 [] 是可以重载的运算符,但是[][]以及[][][]以及更多的 [] 并不是一个单独的运算符,也就是说我们只能重载出一维的运算符。这里我们可以借鉴一下C/C++中原生的多维数组的样子,每一次进行[]后都返回一个低一个纬度的数组。

但是这样做有个问题,由于每一次[]都要创建一个数组的实例, 都会至少首先分配一个长度为dimension(维度数)的数组。这带来的资源浪费和性能消耗都是不能容忍的。 为了解决这个问题,我们在这个数组类里的默认的构造函数除了初始化值之外不能进行任何诸如内存分配的操作,让直接用默认的没有任何参数构造出来的类的实例里面只包含一个int* (储存维度的长度信息)和一个T*(储存数据),而这两个指针的值均为nullptr才可以。然后我们在创造这个将要返回的类的时候只需要把这两个值进行赋值,这样就几乎不会产生额外的消耗了。

上面的做法看起来很好,但是还是存在一个问题的,那就是析构的问题。注意到这里因为[]产生的临时对象用的数据的实际存储区是原来的对象的一部分,因此这里创建的临时的对象在析构的时候不能执行任何delete[]的操作,否则将会破坏原来的那个对象,因此需要加一个标志位来判断一下这个类是不是像上面的那一种方式创建的临时的对象,如果是的话析构的时候仅仅把两个指针赋值为nullptr或者是什么都不干。

现在重载的运算符都可以工作了,但是还是存在一个问题, 经过[]后返回的值是一个class mat的类型,而不是实际上我们当dimension走到1 或者是0 的时候想要的T* 或者是 T的类型,因此要针对这俩类型写一个类型转换的函数,然后在实际使用的时候利用C++的隐式类型转换来实现自如的应用。代码很简单,长这个样子:


operator T*()
{
    return data;
}

operator T()
{
    return data[0];
}

3. 实现任意参数函数

我们在创建的时候肯定不希望采用先定义一个数组,在数组里面写好每一个维度的信息,这样太麻烦了。为此,我们可以利用C++的可变参数模板来实现构造函数接受任意个数的参数,这样就可以直接把每一个维度的长度信息直接写在参数列表里面了,使用起来会更加方便。

具体做法也很简单,废话不多说,这里直接上代码:


template<typename ...Args>
Mat(Args... args)
{
    int list[] = {args...};
    init();

    size = new int[dimension];
    for(int i=0;i<dimension;i++)
    {
        size[i] = list[i];
    }

    int totalsize = 1;
    for(int i=0;i<dimension;i++)
    {
        totalsize *= size[i];
    }

    data = new T[totalsize];

    Created = 1;
}

上面的…有两个作用,当…位于左边的时候是用来创建一个参数包,位于右边的时候是用来把这个参数包展开,至于如何展开这个参数包并使用你可以参照维基百科的这个页面,在这里采用的是最简便的初始化列表展开的方法。这时候args的参数都会在list这个列表中,从而后面就可以很方便的进行下一步的处理

4. 处理一些显而易见的异常

我们的目标是尽量让编译器发现更多的错误,因此不应该万事依靠运行时的检测来检查异常。为了做到这个,C++提供的静态断言(static_assert)很好的给我们提供了一种方法,关于详细的资料你可以参照cppreference上的这个页面(http://zh.cppreference.com/w/cpp/language/static_assert),这里只是简单的说一下

定义差不多张这个样子:

static_assert ( bool_constexpr , message )

如果bool_constexpr的值是TRUE,没有任何问题,继续编译,如果是FALSE,就会抛出一个内容为message的编译错误

于是我们就可以在operator []里面针对dimension做一些判断,同时针对我们前面定义的类型转换符里面的dimension做一些检测,防止出现错误。

5. 尾声

现在基本所有的东西都搞定了,剩下的就是考虑一些其他的不正常的地方,以及规范一下代码风格之类的东西了。这里因为我实际上是完全用不到这个东西的,就这样算了吧。另外现在这个类用起来就是这样的:


int main(void)
{
    Mat<int, 3> test(4, 5, 6);
    test.at(1, 2, 3) = 5;
    int out = test[1][2][3];
    cout<<out<<endl;

    return 0;
}

看起来还是挺方便的hhh

6. 附录:

如果你想看一下具体的代码,这里有一份非常非常low的代码(求轻喷)可以参考一下


#include <iostream>
#include <windows.h>

using namespace std;

template<typename T, int dimension>
class Mat
{
public:
    int* size;
public:
    Mat(void)
    {
        init();
    }
    template<typename ...Args>
    Mat(Args... args)
    {
        int list[] = {args...};
        init();

        size = new int[dimension];
        for(int i=0;i<dimension;i++)
        {
            size[i] = list[i];
        }

        int totalsize = 1;
        for(int i=0;i<dimension;i++)
        {
            totalsize *= size[i];
        }

        data = new T[totalsize];

        Created = 1;
    }

    ~Mat()
    {
        if(Created)
        {
            if(size != nullptr)
            {
                delete[] size;
            }
            if(data != nullptr)
            {
                delete[] data;
            }
        }
    }

    Mat<T, dimension-1> operator[] (int num)
    {
        static_assert(dimension > 0, "operator [] for 0 dimension mat is not allowed!");
        Mat<T, dimension-1> temp;

        temp.size = size+1;
        int position=1;

        for(int i=1;i<dimension;i++)
        {
            position*=size[i];
        }
        temp.SetDataPointer(data + position*num);
        return temp;
    }

    template<typename... Args>
    T& at(Args... args)
    {
        int position[] = {args...};
        int LinePos = 0;
        for(int i = 1; i <= dimension; i++)
        {
            int CurrentDimensionSize = 1;
            for(int j = i; j < dimension; j++)
            {
                CurrentDimensionSize *= size[j];
            }

            LinePos += CurrentDimensionSize*position[i-1];
        }
        return *(data + LinePos);
    }

    void SetDataPointer(T* _data)
    {
        data = _data;
    }

    operator T*()
    {
        static_assert(dimension == 1, "Convert Mat to int* is illegal since dimension is not 1!");
        return data;
    }

    operator T()
    {
        static_assert(dimension == 0, "Convert Mat to int is illeagal since dimension is not 0!");
        return data[0];
    }

    void operator= (T i)
    {
        static_assert(dimension == 0, "Assign int value to not 0 dimension mat is not allowed!");
        data[0] = i;
    }
private:
    int init(void)
    {
        size = nullptr;
        data = nullptr;
        Created = 0;
        return 0;
    }
private:
    T* data;
    int Created;
};

int main(void)
{
    Mat<int, 3> test(4, 5, 6);
    test.at(1, 2, 3) = 5;
    int out = test[1][2][3];
    cout<<out<<endl;

    return 0;
}

评论

发表评论

电子邮件地址不会被公开。 必填项已用*标注