SQLAlchemy 分表实践

去年年底利用工作之余开发了一个进销存相关的 SAAS 项目,ORM 用的 SQLAlchemy,并且进行了一些分表操作,这里来做个简单的记录(也只能是简单记录了,我是小半年前进行的分表调研)。

我没有直接使用 继承自 db.Model 的 ORM 类来操作数据库,而是在其之上又封装了一层,将更具体的一些数据库操作进行了封装。

举个例子,我有一个继承自 db.ModelItem 类,同时还有一个自己的 Item 类,然后在自己的 Item 类中引用继承自 db.Model 的类,为了防止名称冲突,在引用时我会将继承自 db.Model 的类叫做 SqlaItem

分表是如何实现的呢,我通过 SQLAlchemyautomap_base() 将数据库中所有表进行了映射,然后自己实现了分表函数,通过分表函数得到分表的名称,然后动态的拿到那个表所对应的 ORM。

直接看代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
# -*- coding: utf-8 -*-
from __future__ import unicode_literals, absolute_import
from sqlalchemy import create_engine
from sqlalchemy.ext.automap import automap_base
from fuxi.config import SQLALCHEMY_DATABASE_URI
engine = create_engine(SQLALCHEMY_DATABASE_URI)
def ab_cls():
ab_cls = automap_base()
ab_cls.prepare(engine, reflect=True)
return ab_cls
def _get_ab_cls():
if not getattr(_get_ab_cls, '_ab_cls', None):
_ab_cls = ab_cls()
_get_ab_cls._ab_cls = _ab_cls
return _get_ab_cls._ab_cls
ab_cls = _get_ab_cls()

所有分了表的类,都要通过 ab_cls 来获取表映射出来的对象。

还是以 Item 为例,看一下我的相关代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
@classmethod
def _get_db_branch_name_by_user(cls, user_id):
branch = str(user_id)[-1]
return branch
@classmethod
def _get_item_dao(cls, branch_name):
tablename = 'item_%s' % branch_name
dao = getattr(ab_cls.classes, tablename)
return dao
@classmethod
def get_dao(cls, user_id):
branch = cls._get_db_branch_name_by_user(user_id)
return cls._get_item_dao(branch)

我通过这几个方法实现了获取分表映射的功能,在具体使用时,可以直接用 get_dao(user_id) 获取表映射(我是通过用户ID的规则进行的分表)。

随便看一个操作:

1
2
3
4
5
6
7
8
9
10
@classmethod
def gets_by_name(cls, user_id, name):
SqlaItem = cls.get_dao(user_id)
cond = (SqlaItem.user_id == user_id)
cond &= (SqlaItem.status != ItemStatus.DELETE.value)
if name:
cond &= (SqlaItem.name == name)
query = db.session.query(SqlaItem).filter(cond)
return [cls.init_from_sqla(x) for x in query]

我先通过 get_dao(user_id) 获取到这个用户数据所在的表的映射,然后就可进行各种 CURD 操作了。

也就是说,在我的项目中其实有两种获取 SqlaXxxx 的方法,如果没有分表,那么直接用继承自 db.Model 的类即可,如果是分了表的,就用动态映射出来的,所以后者实际上是不需要写继承自 db.Model 的类的,但是为了在初始化时生成所有表结构,我还是写了这些类,只不过这些类所对应的表都是分表中的第一张表,以 Item 为例

1
2
3
4
5
6
7
class Item(db.Model):
"""
商品
分表策略: 用户ID最后 1 位
"""
__tablename__ = 'item_1'
...

这样的话,我在执行 create_db 时所有需要分表的第一个表都会被建好,这个时候,我只需要再写个简单的脚本,就可以帮我把剩余的表建出来了,因为我所有分表结尾都是以下划线1或者01组成的,意思是,如果表需要分成 10 个,那么对应的第一个表的名称就是 xxx_1,如果需要分成 100 个,那么对应的第一个表的名称就是 xxx_01,所以我根据这个规则写了生成剩余表的脚本,如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
# -*- coding: utf-8 -*-
from __future__ import unicode_literals, absolute_import
import os
from urlparse import urlparse
import MySQLdb
from envcfg.json.fuxi import SQLALCHEMY_DATABASE_URI
p = urlparse(SQLALCHEMY_DATABASE_URI)
DATABASE_HOST = p.hostname
DATABASE_USER = p.username
DATABASE_PASSWD = p.password if p.password else ""
def _get_table_name(table):
table_name_start = table.find('`') + 1
table_name_end = table.find('`', table_name_start)
return table[table_name_start:table_name_end]
def create_table(table):
"""
首先根据表名确定当前表是否需要分表
如果表名是已1结尾,那么就创建0-9结尾的表
如果表名是已01结尾,那么就创建00-99结尾的表
如果以上情况都不是,直接创建
"""
db = MySQLdb.connect(DATABASE_HOST, DATABASE_USER, DATABASE_PASSWD, "fuxi")
cursor = db.cursor()
assert table.startswith('CREATE TABLE')
table_name = _get_table_name(table)
if table_name.endswith('01'):
replace_table = table.replace(table_name, '%s')
for i in range(100):
sharding_table_name = '%s_%02d' % (table_name[:-3], i)
try:
cursor.execute(replace_table % sharding_table_name)
print '%s create success' % sharding_table_name
except Exception as e:
if e.args[0] != 1050:
raise e
print '%s exsit' % sharding_table_name
elif table_name.endswith('1'):
replace_table = table.replace(table_name, '%s')
for i in range(10):
sharding_table_name = '%s_%s' % (table_name[:-2], i)
try:
cursor.execute(replace_table % sharding_table_name)
print '%s create success' % sharding_table_name
except Exception as e:
if e.args[0] != 1050:
raise e
print '%s exsit' % sharding_table_name
else:
try:
cursor.execute(table)
print '%s create success' % table_name
except Exception as e:
if e.args[0] != 1050:
raise e
print '%s exsit' % table_name
db.close()
def gets_all_tables():
with open(os.path.dirname(os.path.realpath(__file__)) + '/fuxi.sql', 'r') as f:
content = f.read()
tables = content.split('\n\n')
for table in tables:
create_table(table)
if __name__ == '__main__':
gets_all_tables()

这种方式有个弊端,在项目启动时就需要将所有表结构读入到内存中,直接的表现是启动比较慢,占用内存比较多。

我不觉得这是个最佳方案,所以如果有更好的方案或者有任何疑问请通过邮件(jiapan.china#gmail.com)的方式告诉我,谢谢。