这里只是给出一个思路,或许对于未来解决问题有一些参考意义。
仿 jap 的写法
这种写法很像 java 环境中的 jpa,如果引入模版引擎,则可以大幅增强实用性。
但是,在 python 环境中,这不符合主流的 orm 框架。
潜在风险:代码检测的时候,可能会被误判,因为我们定义了一大堆空的函数。
# 注解式事务 start --------------------------------------------- @update(sql='update `t_temp` set `desc`= :desc where (`id`= :id) limit 1') def modify(params: dict = none) -> int: pass @query(sql='select * from `t_temp` where (`id`= :id) limit 1', result_type=dict) def querybyid(params: dict = none) -> list: pass @query(sql='select * from `t_temp` where (`id`= :id) limit 1', result_type=dict) def querybyid2(id: int) -> list: pass @transactional() def test_annotation(): ret = modify({'id': 18, 'desc': 'or 1=1'}) print(ret) result = querybyid2(18) print(result)
代码封装
import inspect import logger_factory import typing from sqlalchemy import create_engine, text from sqlalchemy.orm import sessionmaker, scoped_session from sqlalchemy.engine import result, cursorresult logger = logger_factory.get_logger() # 定义数据库连接字符串 database_uri = 'mysql+pymysql://{username}:{password}@{host}:{port}/{dbname}?charset=utf8mb4' # 替换为你的数据库用户名、密码、主机、端口和数据库名 username = 'root' password = 'root' host = 'localhost' port = '3306' dbname = 'med' # 创建数据库引擎,使用连接池 engine = create_engine( database_uri.format( username=username, password=password, host=host, port=port, dbname=dbname ), echo=false, # 如果设置为true,sqlalchemy将打印所有执行的sql语句,通常用于调试 pool_size=10, # 连接池大小 max_overflow=20, # 超过连接池大小外最多创建的连接数 pool_timeout=30, # 连接池中没有线程可用时,在抛出异常前等待的时间 pool_recycle=3600 # 多少秒之后对连接进行一次回收(重置) ) # do a test with engine.connect() as con: rs = con.execute(text('select 1')) rs.fetchone() logger.debug('create engine succeed!') # session-maker session = sessionmaker(bind=engine) # thread safe session-maker dbsession = scoped_session(session) # with session() as session: # # 获取数据库连接 # connection = session.connection() # savepoint = connection.begin_nested() # print(savepoint) def geteffectrows(result: result) -> int: r""" 获取受影响行数 这里有点问题:源码部分 rowcount 是一个 callable,但实际应该是 int; 这里绕一点,确保不会出问题,如果返回 -1,说明出现了意料之外的情况 :param result: 结果集 :return: 受影响行数 """ if isinstance(result, cursorresult): effect_row = result.rowcount if isinstance(effect_row, int): return effect_row if callable(effect_row): return effect_row() return -1 def resultasdict(result: result) -> list: r""" 将查询结果转换为 dict-list :param result: 结果集 :return: dict 列表 """ keys = result.keys() ret = list() for item in result.fetchall(): ret.append(dict(zip(keys, item))) return ret def execute(sql: str, params: dict = none) -> result: r""" 执行一条查询语句 :param sql: 查询语句 :param params: 参数 :return: 结果集 """ if sql is none: raise valueerror('sql cannot be none') logger.debug('execute sql: ' + sql) logger.debug('parameter : ' + str(params)) return dbsession().execute(text(sql), params) def executequery(sql: str, params: dict = none, result_type: type = tuple) -> typing.sequence: r""" 执行一个查询 :param sql: sql :param params: dict :param result_type: 结果集类型,可选:tuple、dict :return: 序列 """ result = execute(sql, params) if result_type == dict: return resultasdict(result) pass # default return_type tuple-list return result.fetchall() def executeupdate(sql: str, params: dict = none) -> int: r""" 执行一个查询 :param sql: sql 执行语句 :param params: dict 查询参数 :return: 受影响行数 """ result = execute(sql, params) return geteffectrows(result) def transactional(rollback: type = exception): r""" 注解式事务 用法类似于 spring 环境下的 @transactional 注解 注意: 事务控制在 session 级别,不能兼容事务嵌套的场景(理想状态下,应当通过 save-point 实现) 推荐: 如果遇到很复杂的事务嵌套,显式调用 session,手动控制事务 :param rollback: 指定触发回滚的异常类型 :return: 装饰器函数 """ def decorator(func): def call(*args, **kwargs): session = none try: session = dbsession() ret = func(*args, **kwargs) session.commit() return ret except rollback as e: if session: session.rollback() logger.exception(f'transaction exception, rollback: {str(e)}') raise finally: if session: session.close() return call return decorator pass def update(sql: str = none): r""" 注解式查询,e.g.:: @update(sql='update `t_temp` set `desc`= :desc where (`id`= :id) limit 1') def modify(params: dict = none) -> int: pass :param sql: 要执行的 sql :return: decorator """ def decorator(func): def call(*args, **kwargs): result = execute(sql, args[0]) return geteffectrows(result) return call return decorator pass def query(sql: str = none, result_type: type = tuple): r""" 注解式查询,e.g.:: e.g.:: @query(sql='select * from `t_temp` where (`id`= :id) limit 1', result_type=dict) def querybyid2(id: int) -> list: pass :param sql: 要执行的 sql :param result_type: 结果集类型,可选:tuple、dict :return: decorator """ def decorator(func): def call(*args, **kwargs): if sql is none: raise valueerror('sql cannot be none') first = args[0] if isinstance(first, dict): result = dbsession().execute(text(sql), args) else: names = inspect.signature(func).parameters.values() params = dict() for idx, name in enumerate(names): params[name.name] = args[idx] print(params) result = dbsession().execute(text(sql), params) if result_type == dict: keys = result.keys() ret = list() for item in result.fetchall(): ret.append(dict(zip(keys, item))) return ret # default return_type tuple pass return result.fetchall() return call return decorator pass @transactional() def test_transaction(): r""" 测试注解式事务 :return: none """ session = dbsession() session.execute(text("update `t_temp` set `desc`= :desc where (`id`= :id) limit 1"), {'id': 18, 'desc': 'or 1=3'}) session.execute(text("update `t_temp` set `desc`= :desc where (`id`= :id) limit 1"), {'id': 18, 'desc': 'or 1=4'}) # raise exception raise syntaxerror('syntax error') @transactional() def test_api(): r""" 测试封装过的函数 :return: none """ execute("update `t_temp` set `desc`= :desc where (`id`= :id) limit 1", {'id': 18, 'desc': 'or 1=1'}) execute("update `t_temp` set `desc`= :desc where (`id`= :id) limit 1", {'id': 18, 'desc': 'or 1=2'}) # raise exception raise syntaxerror('syntax error')
到此这篇关于python - sqlachemy另类用法的文章就介绍到这了,更多相关python sqlachemy另类用法内容请搜索代码网以前的文章或继续浏览下面的相关文章希望大家以后多多支持代码网!
发表评论