基于纯SQL训练机器学习模型

译者 | 朱先忠

审校 | 梁策 孙淑娟

在​​《用纯SQL在BigQuery上实现深层神经网络》​​一文中,作者声称使用纯SQL方式实现了一个深层神经网络模型。但在我打开他的​​GitHub代码仓库​​分析后发现,他是使用Python来实现迭代训练的,而这并不是真正的纯SQL方式。

在本文中,我将分享我是如何在​​开源分布式SQL数据库TiDB​​上用纯SQL方式训练机器学习模型的。主要步骤包括:

  1. 选择Iris数据集
  2. (https://scikit-learn.org/stable/auto_examples/datasets/plot_iris_dataset.html)。
  3. 选择softmax逻辑回归模型用于训练。
  4. 编写SQL语句来实现模型推理。
  5. 开始模型训练。

在测试中,我训练了一个softmax逻辑回归模型。在测试期间,我发现TiDB不允许在递归公共表表达式(CTE)中使用子查询和聚合函数。通过修改TiDB的代码,我绕过了这些限制,成功地训练了一个模型,并在Iris数据集上获得了98%的准确率。

为什么我选择TiDB来实现机器学习模型?

TiDB 5.1引入了许多新功能,包括ANSI SQL 99标准的通用表表达式(CTE)。我们可以使用CTE作为临时视图的语句来解耦复杂的SQL语句,并更高效地开发代码。此外,递归CTE可以引用自身,这对于改进SQL功能非常重要。CTE和窗口函数使SQL成为一种图灵完备的语言。

【说明】因为递归CTE可以“迭代”,所以我想尝试一下,看看是否可以使用纯SQL在TiDB上实现机器学习模型训练和推理。

鸢尾花(Iris)数据集

我选择使用scikit-learn的Iris数据集。该数据集包含3种类型,每种类型有50条记录,一共150条。每个记录有4个特征:萼片长度(SL)、萼片宽度(SW)、花瓣长度(PL)和花瓣宽度(PW)。我们可以利用这些特征来预测鸢尾花是否属于山鸢尾(Iris-setosa) 、 变色鸢尾(Iris-versicolor)和维吉尼亚鸢尾(Iris-virginica)。

以CSV格式下载数据后,我将其导入了TiDB数据库。使用的SQL脚本如下:

createtable iris(sl float, sw float, pl float, pw float, type varchar(16));
LOAD DATA LOCAL INFILE 'iris.csv'INTOTABLE iris FIELDS TERMINATED BY',' LINES TERMINATED BY'\n';
select*from iris limit10;
+------+------+------+------+------------------+
| sl | sw | pl | pw | type |
+-------+--------+-------+--------+------------+
|5.1|3.5|1.4|0.2| Iris-setosa |
|4.9|3|1.4|0.2| Iris-setosa |
|4.7|3.2|1.3|0.2| Iris-setosa |
|4.6|3.1|1.5|0.2| Iris-setosa |
|5|3.6|1.4|0.2| Iris-setosa |
|5.4|3.9|1.7|0.4| Iris-setosa |
|4.6|3.4|1.4|0.3| Iris-setosa |
|5|3.4|1.5|0.2| Iris-setosa |
|4.4|2.9|1.4|0.2| Iris-setosa |
|4.9|3.1|1.5|0.1| Iris-setosa |
+------+-------+-------+-------+--------------+
10 rows inset(0.00 sec)
select type,count(*)from iris groupby type;
+-------------------+------------------+
| type |count(*)|
+-------------------+-----------------+
| Iris-versicolor |50|
| Iris-setosa |50|
| Iris-virginica |50|
+-------------------+----------------+
3 rows inset(0.00 sec)

Softmax逻辑回归

我选择了一个简单的机器学习模型:用于多类分类的Softmax逻辑回归。在Softmax回归中:

成本函数是:

梯度是:

因此,我们可以使用梯度下降来升级梯度:

模型推理

我编写了一条SQL语句来实现推理。基于上面定义的模型和数据,输入数据x有五个维度(SL、SW、PL、PW和一个常数1.0),输出使用了一种热编码。SQL脚本如下:

createtable data(
x0 decimal(35,30), x1 decimal(35,30), x2 decimal(35,30), x3 decimal(35,30), x4 decimal(35,30),
y0 decimal(35,30), y1 decimal(35,30), y2 decimal(35,30)
);
insertinto data
select
sl, sw, pl, pw,1.0,
case when type='Iris-setosa'then 1 else 0 end,
case when type='Iris-versicolor'then 1 else 0 end,
case when type='Iris-virginica'then 1 else 0 end
from iris;

共有15个参数(3种类型*5个维度)。SQL脚本如下:

createtable weight(
w00 decimal(35,30), w01 decimal(35,30), w02 decimal(35,30), w03 decimal(35,30), w04 decimal(35,30),
w10 decimal(35,30), w11 decimal(35,30), w12 decimal(35,30), w13 decimal(35,30), w14 decimal(35,30),
w20 decimal(35,30), w21 decimal(35,30), w22 decimal(35,30), w23 decimal(35,30), w24 decimal(35,30));

我将输入数据初始化为0.1、0.2、0.3。为了便于演示,我使用了不同的数字。将它们全部初始化为0.1是可以的。SQL脚本如下:

insertinto weight values(
0.1,0.1,0.1,0.1,0.1,
0.2,0.2,0.2,0.2,0.2,
0.3,0.3,0.3,0.3,0.3);

接下来,我编写了一条SQL语句来计算数据推断结果的准确性。为了更好地理解,我使用伪代码来描述这个过程:

weight =(
w00, w01, w02, w03, w04,
w10, w11, w12, w13, w14,
w20, w21, w22, w23, w24
)
for data(x0, x1, x2, x3, x4, y0, y1, y2)in all Data:
exp0 = exp(x0 * w00, x1 * w01, x2 * w02, x3 * w03, x4 * w04)
exp1 = exp(x0 * w10, x1 * w11, x2 * w12, x3 * w13, x4 * w14)
exp2 = exp(x0 * w20, x1 * w21, x2 * w22, x3 * w23, x4 * w24)
sum_exp = exp0 + exp1 + exp2

// softmax
p0 = exp0 / sum_exp
p1 = exp1 / sum_exp
p2 = exp2 / sum_exp

//推理结果
r0 = p0 > p1 and p0 > p2
r1 = p1 > p0 and p1 > p2
r2 = p2 > p0 and p2 > p1

data.correct=(y0 == r0 and y1 == r1 and y2 == r2)
return sum(Data.correct)/count(Data)

在上面的代码中,我计算了每行数据中的元素。为了对样本进行推断:

  1. 我计算出加权向量的EXP。
  2. 并且计算出softmax值。
  3. 然后,选择p0、p1和p2中最大的一个作为1,并将其余的设置为0。

如果样本的推断结果与其原始分类一致,则预测正确。然后,我将所有样本的正确数量相加,得到最终的准确率。

下面的代码显示了SQL语句的实现。我将每一行数据加上一个权重(只有一行数据),计算每一行的推断结果,并将正确的样本数相加:

select sum(y0 = r0 and y1 = r1 and y2 = r2)/count(*)
from
(select
y0, y1, y2,
p0 > p1 and p0 > p2 as r0, p1 > p0 and p1 > p2 as r1, p2 > p0 and p2 > p1 as r2
from
(select
y0, y1, y2,
e0/(e0+e1+e2)as p0, e1/(e0+e1+e2)as p1, e2/(e0+e1+e2)as p2
from
(select
y0, y1, y2,
exp(
w00 * x0 + w01 * x1 + w02 * x2 + w03 * x3 + w04 * x4
)as e0,
exp(
w10 * x0 + w11 * x1 + w12 * x2 + w13 * x3 + w14 * x4
)as e1,
exp(
w20 * x0 + w21 * x1 + w22 * x2 + w23 * x3 + w24 * x4
)as e2
from data, weight) t1
)t2
    )t3;

上面的SQL语句几乎一步一步地实现了伪代码的计算过程。我得到了如下结果:

+-------------------------------------------------------------+
| sum(y0 = r0 and y1 = r1 and y2 = r2)/count(*)|
+-------------------------------------------------------------+
|0.3333|
+-------------------------------------------------------------+
1 row inset(0.01 sec)

接下来,我开始学习模型参数。

模型训练

注意:为了简化问题,我没有考虑“训练集”和“验证集”问题,而是把所有的数据仅用于进行训练。

我编写了伪代码,然后在此基础上编写了一条SQL语句:

weight =(
w00, w01, w02, w03, w04,
w10, w11, w12, w13, w14,
w20, w21, w22, w23, w24
)
for iter in iterations:
sum00 =0
sum01 =0
...
sum23 =0
sum24 =0
for data(x0, x1, x2, x3, x4, y0, y1, y2)in all Data:
exp0 = exp(x0 * w00, x1 * w01, x2 * w02, x3 * w03, x4 * w04)
exp1 = exp(x0 * w10, x1 * w11, x2 * w12, x3 * w13, x4 * w14)
exp2 = exp(x0 * w20, x1 * w21, x2 * w22, x3 * w23, x4 * w24)
sum_exp = exp0 + exp1 + exp2
// softmax
p0 = y0 - exp0 / sum_exp
p1 = y1 - exp1 / sum_exp
p2 = y2 - exp2 / sum_exp
sum00 += p0 * x0
sum01 += p0 * x1
sum02 += p0 * x2
...
sum23 += p2 * x3
sum24 += p2 * x4
w00 = w00 + learning_rate * sum00 / Data.size
w01 = w01 + learning_rate * sum01 / Data.size
...
w23 = w23 + learning_rate * sum23 / Data.size
    w24 = w24 + learning_rate * sum24 / Data.size

因为我手动扩展了sum和w向量,所以这段代码看起来有点麻烦。然后,我开始编写SQL训练代码。首先,我编写了一条只用一次迭代的SQL语句。

我设置了如下所示的学习速率和样本数:

set @lr =0.1;
Query OK,0 rows affected (0.00 sec)
set @dsize =150;
Query OK,0 rows affected (0.00 sec)

代码迭代了一次:

select
w00 + @lr * sum(d00)/ @dsize as w00, w01 + @lr * sum(d01)/ @dsize as w01, w02 + @lr * sum(d02)/ @dsize as w02, w03 + @lr * sum(d03)/ @dsize as w03, w04 + @lr * sum(d04)/ @dsize as w04 ,
w10 + @lr * sum(d10)/ @dsize as w10, w11 + @lr * sum(d11)/ @dsize as w11, w12 + @lr * sum(d12)/ @dsize as w12, w13 + @lr * sum(d13)/ @dsize as w13, w14 + @lr * sum(d14)/ @dsize as w14,
w20 + @lr * sum(d20)/ @dsize as w20, w21 + @lr * sum(d21)/ @dsize as w21, w22 + @lr * sum(d22)/ @dsize as w22, w23 + @lr * sum(d23)/ @dsize as w23, w24 + @lr * sum(d24)/ @dsize as w24
from
(select
w00, w01, w02, w03, w04,
w10, w11, w12, w13, w14,
w20, w21, w22, w23, w24,
p0 * x0 as d00, p0 * x1 as d01, p0 * x2 as d02, p0 * x3 as d03, p0 * x4 as d04,
p1 * x0 as d10, p1 * x1 as d11, p1 * x2 as d12, p1 * x3 as d13, p1 * x4 as d14,
p2 * x0 as d20, p2 * x1 as d21, p2 * x2 as d22, p2 * x3 as d23, p2 * x4 as d24
from
(select
w00, w01, w02, w03, w04,
w10, w11, w12, w13, w14,
w20, w21, w22, w23, w24,
x0, x1, x2, x3, x4,
y0 - e0/(e0+e1+e2)as p0, y1 - e1/(e0+e1+e2)as p1, y2 - e2/(e0+e1+e2)as p2
from
(select
w00, w01, w02, w03, w04,
w10, w11, w12, w13, w14,
w20, w21, w22, w23, w24,
x0, x1, x2, x3, x4, y0, y1, y2,
exp(
w00 * x0 + w01 * x1 + w02 * x2 + w03 * x3 + w04 * x4
)as e0,
exp(
w10 * x0 + w11 * x1 + w12 * x2 + w13 * x3 + w14 * x4
)as e1,
exp(
w20 * x0 + w21 * x1 + w22 * x2 + w23 * x3 + w24 * x4
)as e2
from data, weight) t1
)t2
    )t3;

一次迭代后,输出结果是模型参数,如下所示:

以下是核心代码部分,我使用递归CTE进行迭代训练:

set @num_iterations =1000;
Query OK,0 rows affected (0.00 sec)

其核心思想是,每次迭代的输入都是前一次迭代的结果,此外我添加了一个增量迭代变量来控制迭代次数。总体框架代码是:

with recursive cte(iter, weight)as
(
select1, init_weight
union all
select iter+1, new_weight
from cte
where ites < @num_iterations

接下来,我将迭代的SQL语句与这个迭代框架结合在一起。为了提高计算精度,我在中间结果中添加了类型转换:

with recursive weight( iter,
w00, w01, w02, w03, w04,
w10, w11, w12, w13, w14,
w20, w21, w22, w23, w24)as
(
select1,
cast(0.1asDECIMAL(35,30)), cast(0.1asDECIMAL(35,30)), cast (0.1asDECIMAL(35,30)), cast(0.1asDECIMAL(35,30)), cast(0.1asDECIMAL(35,30)),
cast(0.1asDECIMAL(35,30)), cast(0.1asDECIMAL(35,30)), cast(0.1asDECIMAL(35,30)), cast(0.1asDECIMAL(35,30)), cast(0.1asDECIMAL(35,30)),
cast(0.1asDECIMAL(35,30)), cast(0.1asDECIMAL(35,30)), cast(0.1asDECIMAL(35,30)), cast(0.1asDECIMAL(35,30)), cast(0.1asDECIMAL(35,30))
union all
select
iter +1,
w00 + @lr * cast(sum(d00)asDECIMAL(35,30))/ @dsize as w00, w01 + @lr * cast(sum(d01)asDECIMAL(35,30))/ @dsize as w01, w02 + @lr * cast(sum(d02)asDECIMAL(35,30))/ @dsize as w02, w03 + @lr * cast(sum(d03)asDECIMAL(35,30))/ @dsize as w03, w04 + @lr * cast(sum(d04)asDECIMAL(35,30))/ @dsize as w04 ,
w10 + @lr * cast(sum(d10)asDECIMAL(35,30))/ @dsize as w10, w11 + @lr * cast(sum(d11)asDECIMAL(35,30))/ @dsize as w11, w12 + @lr * cast(sum(d12)asDECIMAL(35,30))/ @dsize as w12, w13 + @lr * cast(sum(d13)asDECIMAL(35,30))/ @dsize as w13, w14 + @lr * cast(sum(d14)asDECIMAL(35,30))/ @dsize as w14,
w20 + @lr * cast(sum(d20)asDECIMAL(35,30))/ @dsize as w20, w21 + @lr * cast(sum(d21)asDECIMAL(35,30))/ @dsize as w21, w22 + @lr * cast(sum(d22)asDECIMAL(35,30))/ @dsize as w22, w23 + @lr * cast(sum(d23)asDECIMAL(35,30))/ @dsize as w23, w24 + @lr * cast(sum(d24)asDECIMAL(35,30))/ @dsize as w24
from
(select
iter, w00, w01, w02, w03, w04,
w10, w11, w12, w13, w14,
w20, w21, w22, w23, w24,
p0 * x0 as d00, p0 * x1 as d01, p0 * x2 as d02, p0 * x3 as d03, p0 * x4 as d04,
p1 * x0 as d10, p1 * x1 as d11, p1 * x2 as d12, p1 * x3 as d13, p1 * x4 as d14,
p2 * x0 as d20, p2 * x1 as d21, p2 * x2 as d22, p2 * x3 as d23, p2 * x4 as d24
from
(select
iter, w00, w01, w02, w03, w04,
w10, w11, w12, w13, w14,
w20, w21, w22, w23, w24,
x0, x1, x2, x3, x4,
y0 - e0/(e0+e1+e2)as p0, y1 - e1/(e0+e1+e2)as p1, y2 - e2/(e0+e1+e2)as p2
from
(select
iter, w00, w01, w02, w03, w04,
w10, w11, w12, w13, w14,
w20, w21, w22, w23, w24,
x0, x1, x2, x3, x4, y0, y1, y2,
exp(
w00 * x0 + w01 * x1 + w02 * x2 + w03 * x3 + w04 * x4
)as e0,
exp(
w10 * x0 + w11 * x1 + w12 * x2 + w13 * x3 + w14 * x4
)as e1,
exp(
w20 * x0 + w21 * x1 + w22 * x2 + w23 * x3 + w24 * x4
)as e2
from data, weight where iter < @num_iterations) t1
)t2
)t3
havingcount(*)>0
)
select*from weight where iter = @num_iterations;

这个代码块和上面一次迭代的代码块之间有两个区别。在此代码块中:

  • 在 data join weight后面,我添加了where iter <@num_iterations以便控制迭代次数和要输出的iter + 1 as iter列。
  • 添加了count(*)>0,以防止聚合在最后没有输入数据时输出数据。此错误可能会导致迭代失败。

上述代码运行结果是:

ERROR 3577(HY000):In recursive query block of Recursive Common Table Expression 'weight', the recursive table must be referenced only once,andnotin any subquery

这表明递归CTE不允许在递归部分使用子查询。不过,我可以合并上面所有的子查询。但是,即使在我手动合并它们之后还是得到了以下错误提示:

ERROR 3575(HY000): Recursive Common Table Expression 'cte' can contain neither aggregation nor window functions in recursive query block

这表明不允许使用聚合函数。然后,我决定改变TiDB的实现代码。

根据​​提案​​​中的介绍,递归CTE的实现遵循了TiDB的基本执行框架。在咨询​​PingCAP​​的研发人员黄文军(Wenjun Huang)之后,我了解到子查询和聚合函数不被允许的原因有两个:

  • MySQL不允许这样做。
  • 如果允许,会有很多复杂的特殊情形需要克服。

但我只是想测试一下这些功能。为此,我暂时删除了​​diff​​中对子查询和聚合函数的检查。

最后,我再次执行修改后的代码,输出结果如下:

成功了!经过1000次迭代,我得到了参数。

接下来,我使用新参数重新计算正确的速率:

+--------------------------------------------------------------+
| sum(y0 = r0 and y1 = r1 and y2 = r2)/count(*)|
+--------------------------------------------------------------+
|0.9867|
+--------------------------------------------------------------+
1 row inset(0.02 sec)

这一次,准确率达到了98%。

结论

通过使用TiDB 5.1中的递归CTE,我成功地使用纯SQL在TiDB上训练了softmax逻辑回归模型。

在测试期间,我发现TiDB的递归CTE不允许子查询和聚合函数,所以我修改了TiDB的代码以绕过这些限制。最后,我成功地训练了一个模型,并在Iris数据集上获得了98%的准确率。

最后,作为补充,在我的上述工作中还总结了下面几个想法:

  • 在做了一些测试之后,我发现PostgreSQL和MySQL都不支持递归CTE中的聚合函数,可能是因为有一些棘手的情形难以处理吧。
  • 在这次测试中,我手动扩展了向量的所有维度。事实上,我还编写了一个不需要扩展所有维度的实现。例如,数据表的模式是(idx,dim,value),但在这个实现中,权重表需要连接两次。这意味着需要在CTE中访问两次,为此还需要修改TiDB执行器的实现代码。由于这一原因,我没有在本文中讨论这一实现。但事实上,这种实现更通用,可以用它来处理MNIST数据集等更多维度的模型。

原文标题:​​I Trained a Machine Learning Model in Pure SQL​​,作者:Mingcong Han

文章来源网络,作者:运维,如若转载,请注明出处:https://shuyeidc.com/wp/279801.html<

(0)
运维的头像运维
上一篇2025-05-12 05:42
下一篇 2025-05-12 05:43

相关推荐

  • 个人主题怎么制作?

    制作个人主题是一个将个人风格、兴趣或专业领域转化为视觉化或结构化内容的过程,无论是用于个人博客、作品集、社交媒体账号还是品牌形象,核心都是围绕“个人特色”展开,以下从定位、内容规划、视觉设计、技术实现四个维度,详细拆解制作个人主题的完整流程,明确主题定位:找到个人特色的核心主题定位是所有工作的起点,需要先回答……

    2025-11-20
    0
  • 社群营销管理关键是什么?

    社群营销的核心在于通过建立有温度、有价值、有归属感的社群,实现用户留存、转化和品牌传播,其管理需贯穿“目标定位-内容运营-用户互动-数据驱动-风险控制”全流程,以下从五个维度展开详细说明:明确社群定位与目标社群管理的首要任务是精准定位,需明确社群的核心价值(如行业交流、产品使用指导、兴趣分享等)、目标用户画像……

    2025-11-20
    0
  • 香港公司网站备案需要什么材料?

    香港公司进行网站备案是一个涉及多部门协调、流程相对严谨的过程,尤其需兼顾中国内地与香港两地的监管要求,由于香港公司注册地与中国内地不同,其网站若主要服务内地用户或使用内地服务器,需根据服务器位置、网站内容性质等,选择对应的备案路径(如工信部ICP备案或公安备案),以下从备案主体资格、流程步骤、材料准备、注意事项……

    2025-11-20
    0
  • 如何企业上云推广

    企业上云已成为数字化转型的核心战略,但推广过程中需结合行业特性、企业痛点与市场需求,构建系统性、多维度的推广体系,以下从市场定位、策略设计、执行落地及效果优化四个维度,详细拆解企业上云推广的实践路径,精准定位:明确目标企业与核心价值企业上云并非“一刀切”的方案,需先锁定目标客户群体,提炼差异化价值主张,客户分层……

    2025-11-20
    0
  • PS设计搜索框的实用技巧有哪些?

    在PS中设计一个美观且功能性的搜索框需要结合创意构思、视觉设计和用户体验考量,以下从设计思路、制作步骤、细节优化及交互预览等方面详细说明,帮助打造符合需求的搜索框,设计前的规划明确使用场景:根据网站或APP的整体风格确定搜索框的调性,例如极简风适合细线条和纯色,科技感适合渐变和发光效果,电商类则可能需要突出搜索……

    2025-11-20
    0

发表回复

您的邮箱地址不会被公开。必填项已用 * 标注